Files
Neo-ZQYY/apps/backend/app/routers/db_viewer.py

229 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""数据库查看器 API
提供 4 个端点:
- GET /api/db/schemas — 返回 Schema 列表
- GET /api/db/schemas/{name}/tables — 返回表列表和行数
- GET /api/db/tables/{schema}/{table}/columns — 返回列定义
- POST /api/db/query — 只读 SQL 执行
所有端点需要 JWT 认证。
使用 get_etl_readonly_connection(site_id) 确保 RLS 隔离。
"""
from __future__ import annotations
import logging
import re
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg2 import errors as pg_errors, OperationalError
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_etl_readonly_connection
from app.schemas.db_viewer import (
ColumnInfo,
QueryRequest,
QueryResponse,
SchemaInfo,
TableInfo,
)
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/db", tags=["数据库查看器"])

# Write-operation keywords, matched as whole words regardless of case.
_WRITE_KEYWORDS = re.compile(r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE)\b", flags=re.IGNORECASE)

# Upper bound on the number of rows a query may return.
_MAX_ROWS = 1000

# Query timeout, in seconds.
_QUERY_TIMEOUT_SEC = 30
# ── GET /api/db/schemas ──────────────────────────────────────
@router.get("/schemas", response_model=list[SchemaInfo])
async def list_schemas(
    user: CurrentUser = Depends(get_current_user),
) -> list[SchemaInfo]:
    """List the schemas of the ETL database.

    The ``pg_catalog``, ``information_schema`` and ``pg_toast`` schemas are
    filtered out and the remaining names are ordered by ``schema_name``.

    Args:
        user: Authenticated caller; ``user.site_id`` is passed to
            ``get_etl_readonly_connection`` to obtain the connection.

    Returns:
        One ``SchemaInfo`` per schema name returned by the query.
    """
    schema_query = """
        SELECT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
        ORDER BY schema_name
        """
    connection = get_etl_readonly_connection(user.site_id)
    try:
        with connection.cursor() as cursor:
            cursor.execute(schema_query)
            schemas = []
            for record in cursor.fetchall():
                schemas.append(SchemaInfo(name=record[0]))
            return schemas
    finally:
        # Always release the connection, whether or not the query succeeded.
        connection.close()
# ── GET /api/db/schemas/{name}/tables ────────────────────────
@router.get("/schemas/{name}/tables", response_model=list[TableInfo])
async def list_tables(
    name: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[TableInfo]:
    """List the base tables of one schema together with their row counts.

    Row counts are read from ``pg_stat_user_tables.n_live_tup``. Because the
    statistics view is LEFT JOINed, a table without a matching statistics row
    yields ``None`` as its count.

    Args:
        name: Schema name taken from the URL path.
        user: Authenticated caller; ``user.site_id`` is passed to
            ``get_etl_readonly_connection`` to obtain the connection.

    Returns:
        One ``TableInfo`` per base table, ordered by table name.
    """
    tables_query = """
        SELECT
            t.table_name,
            s.n_live_tup
        FROM information_schema.tables t
        LEFT JOIN pg_stat_user_tables s
            ON s.schemaname = t.table_schema
            AND s.relname = t.table_name
        WHERE t.table_schema = %s
            AND t.table_type = 'BASE TABLE'
        ORDER BY t.table_name
        """
    connection = get_etl_readonly_connection(user.site_id)
    try:
        with connection.cursor() as cursor:
            # The schema name is bound as a query parameter, never interpolated.
            cursor.execute(tables_query, (name,))
            tables = []
            for record in cursor.fetchall():
                tables.append(TableInfo(name=record[0], row_count=record[1]))
            return tables
    finally:
        # Always release the connection, whether or not the query succeeded.
        connection.close()
# ── GET /api/db/tables/{schema}/{table}/columns ──────────────
@router.get(
    "/tables/{schema}/{table}/columns",
    response_model=list[ColumnInfo],
)
async def list_columns(
    schema: str,
    table: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[ColumnInfo]:
    """Return the column definitions of one table.

    Args:
        schema: Schema name taken from the URL path.
        table: Table name taken from the URL path.
        user: Authenticated caller; ``user.site_id`` is passed to
            ``get_etl_readonly_connection`` to obtain the connection.

    Returns:
        One ``ColumnInfo`` per column, ordered by ``ordinal_position``.
        ``is_nullable`` is ``True`` only when the catalog reports ``"YES"``.
    """
    columns_query = """
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """
    connection = get_etl_readonly_connection(user.site_id)
    try:
        with connection.cursor() as cursor:
            # Both identifiers are bound as query parameters, never interpolated.
            cursor.execute(columns_query, (schema, table))
            column_infos = []
            for record in cursor.fetchall():
                nullable = record[2] == "YES"
                column_infos.append(
                    ColumnInfo(
                        name=record[0],
                        data_type=record[1],
                        is_nullable=nullable,
                        column_default=record[3],
                    )
                )
            return column_infos
    finally:
        # Always release the connection, whether or not the query succeeded.
        connection.close()
# ── POST /api/db/query ───────────────────────────────────────
@router.post("/query", response_model=QueryResponse)
async def execute_query(
    body: QueryRequest,
    user: CurrentUser = Depends(get_current_user),
) -> QueryResponse:
    """Execute a read-only SQL statement and return its result set.

    Safety measures:
      1. Statements containing INSERT / UPDATE / DELETE / DROP / TRUNCATE as
         a whole word (case-insensitive) are rejected before execution.
      2. At most ``_MAX_ROWS`` rows are returned.
      3. ``statement_timeout`` is set to ``_QUERY_TIMEOUT_SEC`` seconds before
         the statement runs.

    Args:
        body: Request payload; ``body.sql`` is the statement to run.
        user: Authenticated caller; ``user.site_id`` is passed to
            ``get_etl_readonly_connection`` to obtain the connection.

    Returns:
        A ``QueryResponse`` with the column names, the fetched rows (each row
        converted to a list) and the number of rows returned. A statement
        that produces no result set yields empty columns and rows.

    Raises:
        HTTPException: 400 when the SQL is empty, contains a write keyword,
            or fails to execute; 408 when the statement is cancelled by the
            timeout; 500 when an ``OperationalError`` occurs outside the
            statement's own execution (setting the timeout, fetching rows).
    """
    sql = body.sql.strip()
    if not sql:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="SQL 语句不能为空",
        )

    # Reject write operations up front.
    if _WRITE_KEYWORDS.search(sql):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只允许只读查询,禁止 INSERT / UPDATE / DELETE / DROP / TRUNCATE 操作",
        )

    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            # Apply the query timeout before running the caller's statement.
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{_QUERY_TIMEOUT_SEC}s",),
            )
            try:
                cur.execute(sql)
            except pg_errors.QueryCanceled as exc:
                raise HTTPException(
                    status_code=status.HTTP_408_REQUEST_TIMEOUT,
                    detail=f"查询超时(超过 {_QUERY_TIMEOUT_SEC} 秒)",
                ) from exc
            except Exception as exc:
                # SQL syntax errors and any other execution failure.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"SQL 执行错误: {exc}",
                ) from exc

            if cur.description is None:
                # The statement produced no result set (e.g. ``SET ...``).
                # Calling fetchmany() here would raise ProgrammingError and
                # surface as an unhandled 500, so return an empty result.
                columns = []
                rows_list = []
            else:
                columns = [desc[0] for desc in cur.description]
                # Cap the number of rows and turn tuples into lists so they
                # serialize cleanly to JSON.
                rows_list = [list(row) for row in cur.fetchmany(_MAX_ROWS)]

            return QueryResponse(
                columns=columns,
                rows=rows_list,
                row_count=len(rows_list),
            )
    except HTTPException:
        # Already translated above; pass through unchanged.
        raise
    except OperationalError as exc:
        # Connection-level failure.
        logger.error("数据库查看器连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="数据库连接错误",
        ) from exc
    finally:
        conn.close()