229 lines
6.9 KiB
Python
229 lines
6.9 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""数据库查看器 API

Database viewer API.

Provides 4 endpoints:
- GET /api/db/schemas — returns the list of schemas
- GET /api/db/schemas/{name}/tables — returns the tables and their row counts
- GET /api/db/tables/{schema}/{table}/columns — returns the column definitions
- POST /api/db/query — read-only SQL execution

All endpoints require JWT authentication.
get_etl_readonly_connection(site_id) is used to ensure RLS isolation.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import re
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from psycopg2 import errors as pg_errors, OperationalError
|
||
|
||
from app.auth.dependencies import CurrentUser, get_current_user
|
||
from app.database import get_etl_readonly_connection
|
||
from app.schemas.db_viewer import (
|
||
ColumnInfo,
|
||
QueryRequest,
|
||
QueryResponse,
|
||
SchemaInfo,
|
||
TableInfo,
|
||
)
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router collecting every database-viewer endpoint under the /api/db prefix.
router = APIRouter(
    prefix="/api/db",
    tags=["数据库查看器"],
)
|
||
|
||
# Write-operation keywords, matched case-insensitively as whole words.
# NOTE(review): this is a plain-text denylist; other statements that are not
# read-only (e.g. ALTER / CREATE / GRANT) are not listed here — confirm that
# the read-only connection is what actually prevents them.
_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE)\b",
    flags=re.IGNORECASE,
)

# Upper bound on the number of rows returned by a single query.
_MAX_ROWS = 1000

# Query timeout, in seconds.
_QUERY_TIMEOUT_SEC = 30
|
||
|
||
|
||
# ── GET /api/db/schemas ──────────────────────────────────────
|
||
|
||
@router.get("/schemas", response_model=list[SchemaInfo])
async def list_schemas(
    user: CurrentUser = Depends(get_current_user),
) -> list[SchemaInfo]:
    """Return the schemas of the ETL database.

    The names are read from ``information_schema.schemata``, skipping
    ``pg_catalog``, ``information_schema`` and ``pg_toast``, and are sorted
    by name.

    Args:
        user: The current user, resolved by the ``get_current_user``
            dependency; its ``site_id`` selects the read-only connection.

    Returns:
        One ``SchemaInfo`` per schema name.
    """
    connection = get_etl_readonly_connection(user.site_id)
    try:
        with connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
                ORDER BY schema_name
                """
            )
            records = cursor.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        connection.close()

    schemas: list[SchemaInfo] = []
    for record in records:
        schemas.append(SchemaInfo(name=record[0]))
    return schemas
|
||
|
||
|
||
# ── GET /api/db/schemas/{name}/tables ────────────────────────
|
||
|
||
@router.get("/schemas/{name}/tables", response_model=list[TableInfo])
async def list_tables(
    name: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[TableInfo]:
    """Return the base tables of one schema together with their row counts.

    Table names come from ``information_schema.tables`` (``BASE TABLE`` only).
    The row count is ``n_live_tup`` from ``pg_stat_user_tables``, i.e. it is
    read from statistics rather than computed with ``COUNT(*)``. Because the
    statistics view is LEFT JOINed, a table without a matching statistics row
    yields a ``NULL`` count.

    Args:
        name: Schema name taken from the URL path; passed to the query as a
            bound parameter.
        user: The current user, resolved by the ``get_current_user``
            dependency; its ``site_id`` selects the read-only connection.

    Returns:
        One ``TableInfo`` per table, ordered by table name.
    """
    connection = get_etl_readonly_connection(user.site_id)
    try:
        with connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT
                    t.table_name,
                    s.n_live_tup
                FROM information_schema.tables t
                LEFT JOIN pg_stat_user_tables s
                    ON s.schemaname = t.table_schema
                    AND s.relname = t.table_name
                WHERE t.table_schema = %s
                    AND t.table_type = 'BASE TABLE'
                ORDER BY t.table_name
                """,
                (name,),
            )
            records = cursor.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        connection.close()

    tables: list[TableInfo] = []
    for table_name, live_rows in records:
        tables.append(TableInfo(name=table_name, row_count=live_rows))
    return tables
|
||
|
||
|
||
# ── GET /api/db/tables/{schema}/{table}/columns ──────────────
|
||
|
||
@router.get(
    "/tables/{schema}/{table}/columns",
    response_model=list[ColumnInfo],
)
async def list_columns(
    schema: str,
    table: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[ColumnInfo]:
    """Return the column definitions of one table.

    Columns are read from ``information_schema.columns`` in ordinal order.

    Args:
        schema: Schema name taken from the URL path (bound parameter).
        table: Table name taken from the URL path (bound parameter).
        user: The current user, resolved by the ``get_current_user``
            dependency; its ``site_id`` selects the read-only connection.

    Returns:
        One ``ColumnInfo`` per column. ``is_nullable`` is ``True`` exactly
        when the catalog reports the string ``"YES"``.
    """
    connection = get_etl_readonly_connection(user.site_id)
    try:
        with connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (schema, table),
            )
            records = cursor.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        connection.close()

    columns: list[ColumnInfo] = []
    for column_name, data_type, nullable_flag, default_expr in records:
        columns.append(
            ColumnInfo(
                name=column_name,
                data_type=data_type,
                is_nullable=(nullable_flag == "YES"),
                column_default=default_expr,
            )
        )
    return columns
|
||
|
||
|
||
# ── POST /api/db/query ───────────────────────────────────────
|
||
|
||
@router.post("/query", response_model=QueryResponse)
async def execute_query(
    body: QueryRequest,
    user: CurrentUser = Depends(get_current_user),
) -> QueryResponse:
    """Execute a user-supplied SQL statement on the read-only ETL connection.

    Safeguards:
      1. Statements containing a write keyword (INSERT / UPDATE / DELETE /
         DROP / TRUNCATE) are rejected before the database is contacted.
         The check is a whole-word regex search over the raw text, so it
         also matches those words inside string literals or identifiers.
      2. At most ``_MAX_ROWS`` rows are returned.
      3. ``statement_timeout`` is set to ``_QUERY_TIMEOUT_SEC`` seconds.

    Args:
        body: Request payload; ``body.sql`` is the statement to run.
        user: The current user, resolved by the ``get_current_user``
            dependency; its ``site_id`` selects the read-only connection.

    Returns:
        A ``QueryResponse`` with the column names, the fetched rows (each row
        as a list) and the number of rows returned. A statement that produces
        no result set yields empty ``columns`` and ``rows``.

    Raises:
        HTTPException: 400 when the SQL is empty, contains a write keyword,
            or fails to execute; 408 when the statement is cancelled by the
            timeout; 500 when an ``OperationalError`` is raised outside the
            statement execution itself.
    """
    sql = body.sql.strip()
    if not sql:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="SQL 语句不能为空",
        )

    # Reject write operations up front.
    if _WRITE_KEYWORDS.search(sql):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只允许只读查询,禁止 INSERT / UPDATE / DELETE / DROP / TRUNCATE 操作",
        )

    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            # Apply the query timeout.
            # NOTE(review): SET LOCAL presumably relies on the connection
            # running inside a transaction (not autocommit) — confirm.
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{_QUERY_TIMEOUT_SEC}s",),
            )

            try:
                cur.execute(sql)
            except pg_errors.QueryCanceled as exc:
                raise HTTPException(
                    status_code=status.HTTP_408_REQUEST_TIMEOUT,
                    detail=f"查询超时(超过 {_QUERY_TIMEOUT_SEC} 秒)",
                ) from exc
            except Exception as exc:
                # SQL syntax errors and any other execution failure.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"SQL 执行错误: {exc}",
                ) from exc

            if cur.description is None:
                # The statement produced no result set; fetching from the
                # cursor would raise, so report an empty result instead.
                columns = []
                rows_list = []
            else:
                columns = [desc[0] for desc in cur.description]
                # Cap the number of returned rows and convert each tuple to a
                # list for JSON serialisation.
                rows_list = [list(row) for row in cur.fetchmany(_MAX_ROWS)]

            return QueryResponse(
                columns=columns,
                rows=rows_list,
                row_count=len(rows_list),
            )
    except HTTPException:
        # Already mapped to an HTTP response above; pass through unchanged.
        raise
    except OperationalError as exc:
        # Connection-level error.
        logger.error("数据库查看器连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="数据库连接错误",
        ) from exc
    finally:
        conn.close()
|