# -*- coding: utf-8 -*- """数据库查看器 API 提供 4 个端点: - GET /api/db/schemas — 返回 Schema 列表 - GET /api/db/schemas/{name}/tables — 返回表列表和行数 - GET /api/db/tables/{schema}/{table}/columns — 返回列定义 - POST /api/db/query — 只读 SQL 执行 所有端点需要 JWT 认证。 使用 get_etl_readonly_connection(site_id) 确保 RLS 隔离。 """ from __future__ import annotations import logging import re from fastapi import APIRouter, Depends, HTTPException, status from psycopg2 import errors as pg_errors, OperationalError from app.auth.dependencies import CurrentUser, get_current_user from app.database import get_etl_readonly_connection from app.schemas.db_viewer import ( ColumnInfo, QueryRequest, QueryResponse, SchemaInfo, TableInfo, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/db", tags=["数据库查看器"]) # 写操作关键词(不区分大小写) _WRITE_KEYWORDS = re.compile( r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE)\b", re.IGNORECASE, ) # 查询结果行数上限 _MAX_ROWS = 1000 # 查询超时(秒) _QUERY_TIMEOUT_SEC = 30 # ── GET /api/db/schemas ────────────────────────────────────── @router.get("/schemas", response_model=list[SchemaInfo]) async def list_schemas( user: CurrentUser = Depends(get_current_user), ) -> list[SchemaInfo]: """返回 ETL 数据库中的 Schema 列表。""" conn = get_etl_readonly_connection(user.site_id) try: with conn.cursor() as cur: cur.execute( """ SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') ORDER BY schema_name """ ) rows = cur.fetchall() return [SchemaInfo(name=row[0]) for row in rows] finally: conn.close() # ── GET /api/db/schemas/{name}/tables ──────────────────────── @router.get("/schemas/{name}/tables", response_model=list[TableInfo]) async def list_tables( name: str, user: CurrentUser = Depends(get_current_user), ) -> list[TableInfo]: """返回指定 Schema 下所有表的名称和行数统计。""" conn = get_etl_readonly_connection(user.site_id) try: with conn.cursor() as cur: cur.execute( """ SELECT t.table_name, s.n_live_tup FROM information_schema.tables t LEFT JOIN pg_stat_user_tables s ON s.schemaname = t.table_schema AND s.relname = t.table_name WHERE t.table_schema = %s AND t.table_type = 'BASE TABLE' ORDER BY t.table_name """, (name,), ) rows = cur.fetchall() return [ TableInfo(name=row[0], row_count=row[1]) for row in rows ] finally: conn.close() # ── GET /api/db/tables/{schema}/{table}/columns ────────────── @router.get( "/tables/{schema}/{table}/columns", response_model=list[ColumnInfo], ) async def list_columns( schema: str, table: str, user: CurrentUser = Depends(get_current_user), ) -> list[ColumnInfo]: """返回指定表的列定义。""" conn = get_etl_readonly_connection(user.site_id) try: with conn.cursor() as cur: cur.execute( """ SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_schema = %s AND table_name = %s ORDER BY ordinal_position """, (schema, table), ) rows = cur.fetchall() return [ ColumnInfo( name=row[0], data_type=row[1], is_nullable=row[2] == "YES", column_default=row[3], ) for row in rows ] finally: conn.close() # ── POST /api/db/query ─────────────────────────────────────── @router.post("/query", response_model=QueryResponse) async def execute_query( body: QueryRequest, user: CurrentUser = Depends(get_current_user), ) -> QueryResponse: """只读 SQL 执行。 安全措施: 1. 拦截写操作关键词(INSERT / UPDATE / DELETE / DROP / TRUNCATE) 2. 限制返回行数上限 1000 行 3. 设置查询超时 30 秒 """ sql = body.sql.strip() if not sql: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="SQL 语句不能为空", ) # 拦截写操作 if _WRITE_KEYWORDS.search(sql): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="只允许只读查询,禁止 INSERT / UPDATE / DELETE / DROP / TRUNCATE 操作", ) conn = get_etl_readonly_connection(user.site_id) try: with conn.cursor() as cur: # 设置查询超时 cur.execute( "SET LOCAL statement_timeout = %s", (f"{_QUERY_TIMEOUT_SEC}s",), ) try: cur.execute(sql) except pg_errors.QueryCanceled: raise HTTPException( status_code=status.HTTP_408_REQUEST_TIMEOUT, detail=f"查询超时(超过 {_QUERY_TIMEOUT_SEC} 秒)", ) except Exception as exc: # SQL 语法错误或其他执行错误 raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"SQL 执行错误: {exc}", ) # 提取列名 columns = ( [desc[0] for desc in cur.description] if cur.description else [] ) # 限制返回行数 rows = cur.fetchmany(_MAX_ROWS) # 将元组转为列表,便于 JSON 序列化 rows_list = [list(row) for row in rows] return QueryResponse( columns=columns, rows=rows_list, row_count=len(rows_list), ) except HTTPException: raise except OperationalError as exc: # 连接级错误 logger.error("数据库查看器连接错误: %s", exc) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="数据库连接错误", ) finally: conn.close()