import os
import re
import hmac
import contextlib
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import sqlparse
from dotenv import load_dotenv
from psycopg_pool import ConnectionPool

from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings

from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount

# Load configuration. Precedence: .env.local > sibling .env > project-root .env
# (.env.local overrides the process environment; the other two only fill gaps).
_here = Path(__file__).resolve().parent
_root = _here.parent.parent  # apps/mcp-server -> apps -> NeoZQYY
load_dotenv(_here / ".env.local", override=True)
load_dotenv(_here / ".env", override=False)
load_dotenv(_root / ".env", override=False)


# ----------------------------
# Helpers: environment variable parsing (avoids int("") errors)
# ----------------------------
def env_str(name: str, default: str = "", required: bool = False) -> str:
    """Return the stripped value of env var *name*, or *default* when unset.

    Raises:
        RuntimeError: if *required* is true and the resulting value is empty.
    """
    value = os.getenv(name, default)
    if isinstance(value, str):
        value = value.strip()
    if required and (value is None or value == ""):
        raise RuntimeError(f"Missing required env var: {name}")
    return value


def env_int(name: str, default: Optional[int] = None, required: bool = False) -> int:
    """Return env var *name* parsed as an int, or *default* when unset/blank.

    Raises:
        RuntimeError: if the variable is unset/blank and no default is given,
            or if its value is not a valid integer.
    """
    raw = os.getenv(name, "")
    text = raw.strip()
    if not text:
        if default is not None:
            return default
        if required:
            raise RuntimeError(f"Missing required env var: {name}")
        raise RuntimeError(f"Missing env var: {name}")
    try:
        return int(text)
    except ValueError as e:
        raise RuntimeError(f"Invalid int env var {name}={raw!r}") from e


def _conninfo_quote(value: Any) -> str:
    """Quote *value* for a libpq key=value connection string.

    libpq requires values containing spaces (or empty values) to be wrapped in
    single quotes, with embedded single quotes and backslashes escaped by a
    backslash. Quoting unconditionally is always valid.
    """
    escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def _constant_time_equals(left: str, right: str) -> bool:
    """Compare two strings without leaking the mismatch position via timing."""
    # Encode first: hmac.compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(left.encode("utf-8"), right.encode("utf-8"))


# ----------------------------
# Configuration (injected through environment variables)
# MCP_PG_* takes precedence (standalone deployment) and falls back to the
# project-wide DB_* / ETL_DB_NAME / PG_NAME variables (split configuration).
# ----------------------------
PGHOST = env_str("MCP_PG_HOST", default="") or env_str("DB_HOST", required=True)
# default=0 is falsy, so an unset MCP_PG_PORT falls through to DB_PORT.
PGPORT = env_int("MCP_PG_PORT", default=0) or env_int("DB_PORT", default=5432)
PGDATABASE = (
    env_str("MCP_PG_DATABASE", default="")
    or env_str("ETL_DB_NAME", default="")
    or env_str("PG_NAME", required=True)
)
PGUSER = env_str("MCP_PG_USER", default="") or env_str("DB_USER", required=True)
PGPASSWORD = env_str("MCP_PG_PASSWORD", default="") or env_str("DB_PASSWORD", required=True)

MCP_TOKEN = env_str("MCP_TOKEN", default="")  # auth token (empty: auth disabled)
MAX_ROWS = env_int("MCP_MAX_ROWS", default=500)  # default row limit for query_sql
PORT = env_int("PORT", default=9000)  # uvicorn port

# The six-layer schema architecture of the etl_feiqiu database.
ALLOWED_SCHEMAS = ("ods", "dwd", "dws", "core", "meta", "app")
ALLOWED_SCHEMA_SET = set(ALLOWED_SCHEMAS)

# psycopg DSN. String values are quoted so that passwords (or other values)
# containing spaces, quotes or backslashes do not corrupt the conninfo string.
DSN = (
    f"host={_conninfo_quote(PGHOST)} port={PGPORT} dbname={_conninfo_quote(PGDATABASE)} "
    f"user={_conninfo_quote(PGUSER)} password={_conninfo_quote(PGPASSWORD)}"
)

# Connection pool: do not use open=True (avoids exceptions raised by __del__
# cleanup at interpreter exit); the pool is opened explicitly in `lifespan`.
pool = ConnectionPool(conninfo=DSN, min_size=1, max_size=10, timeout=60, open=False)

# ----------------------------
# Read-only SQL gate (the real last line of defence is a read-only DB account)
# ----------------------------
FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|call|execute|do)\b",
    re.IGNORECASE,
)

# Additionally forbid explicit cross-schema access (prevents privilege escape).
# Matches a `qualifier.name` pattern where the qualifier is either a bare
# identifier (group 2) or a double-quoted identifier (group 1); qualifiers of a
# single character are excluded so aliases such as t.id / o.amount still work.
# The name after the dot may itself be double-quoted.
SCHEMA_QUAL = re.compile(
    r'(?:"([^"]{2,})"|\b([a-zA-Z_][a-zA-Z0-9_]{1,}))\s*\.\s*["a-zA-Z_]',
    re.IGNORECASE,
)


def _has_sql_content(stmt: Any) -> bool:
    """Return True if a parsed statement holds anything besides comments/`;`."""
    stripped = sqlparse.format(str(stmt), strip_comments=True)
    return stripped.strip(" \t\r\n;") != ""


def _is_probably_readonly(sql: str) -> bool:
    """Heuristically decide whether *sql* is a single read-only statement.

    Returns False when the text contains a forbidden keyword, cannot be parsed,
    contains more than one statement, or does not start with
    select / with / show / explain.
    """
    if FORBIDDEN.search(sql):
        return False

    parsed = sqlparse.parse(sql)
    if not parsed:
        return False

    first_stmt, *extra = parsed
    # Only the first statement's leading keyword is inspected below, so any
    # further statement (e.g. "select 1; set role ...") would slip through.
    if any(_has_sql_content(stmt) for stmt in extra):
        return False

    for tok in first_stmt.tokens:
        if tok.is_whitespace:
            continue
        # The first non-whitespace token decides the verdict.
        return str(tok).strip().lower() in ("select", "with", "show", "explain")
    return False


def _validate_schema(schema: str) -> Optional[Dict[str, Any]]:
    """Return an error payload if *schema* is not whitelisted, else None."""
    if schema in ALLOWED_SCHEMA_SET:
        return None
    return {"error": f"schema 不允许:{schema}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}


def _reject_cross_schema(sql: str, allowed_schema: str) -> Optional[Dict[str, Any]]:
    """Return an error payload if *sql* references a non-whitelisted schema.

    Simple protection: every explicit qualifier (``xxx.``) must be one of the
    whitelisted business schemas or a system schema. This is not a parser-level
    policy, but it blocks the vast majority of cross-schema attempts.

    Bare qualifiers are compared case-insensitively (PostgreSQL folds them to
    lower case); double-quoted qualifiers are compared exactly, because quoting
    preserves case. *allowed_schema* is accepted for interface compatibility
    and is currently not consulted.
    """
    # All business schemas plus the system schemas are acceptable.
    safe = ALLOWED_SCHEMA_SET | {"pg_catalog", "information_schema"}
    safe_lower = {name.lower() for name in safe}

    offenders = set()
    for match in SCHEMA_QUAL.finditer(sql or ""):
        quoted, bare = match.group(1), match.group(2)
        if quoted is not None:
            if quoted not in safe:
                offenders.add(quoted)
        elif bare.lower() not in safe_lower:
            offenders.add(bare)

    bad = sorted(offenders)
    if bad:
        return {"error": f"SQL 被拒绝:检测到不允许的 schema 引用 {bad},仅允许 {sorted(ALLOWED_SCHEMA_SET)} / 系统 schema。"}
    return None


# ----------------------------
# FastMCP: Streamable HTTP + JSON responses
# ----------------------------
mcp = FastMCP(
    "postgres-mcp",
    stateless_http=True,
    json_response=True,
    transport_security=TransportSecuritySettings(
        enable_dns_rebinding_protection=True,
        allowed_hosts=[
            # Important: allow each host both without and with a port.
            "mcp.langlangzhuoqiu.cn",
            "mcp.langlangzhuoqiu.cn:*",
            "localhost",
            "localhost:*",
            "127.0.0.1",
            "127.0.0.1:*",
            "100.64.0.4",
            "100.64.0.4:*",
            "100.64.0.1",
            "100.64.0.1:*",
            "106.52.16.235",
            "106.52.16.235:*",
        ],
        allowed_origins=[
            "https://mcp.langlangzhuoqiu.cn",
            "https://mcp.langlangzhuoqiu.cn:*",
            "http://localhost",
            "http://localhost:*",
            "http://127.0.0.1",
            "http://127.0.0.1:*",
        ],
    ),
)


# ----------------------------
# Tools: operate on the six-layer schemas of etl_feiqiu
# NOTE: each tool's docstring is published by FastMCP as the tool description,
# i.e. it is runtime text, so the docstrings below are kept verbatim.
# ----------------------------
# List the tables (optionally also views) of one whitelisted schema.
@mcp.tool()
def list_tables(schema: str = "dwd", include_views: bool = False) -> Dict[str, Any]:
    """列出指定 schema(ods/dwd/dws/core/meta/app)下的表(可选包含视图)"""
    err = _validate_schema(schema)
    if err:
        return err

    table_types = ("BASE TABLE", "VIEW") if include_views else ("BASE TABLE",)
    sql = """
    SELECT table_name, table_type
    FROM information_schema.tables
    WHERE table_schema = %s
      AND table_type = ANY(%s)
    ORDER BY table_name;
    """
    with pool.connection() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, (schema, list(table_types)))
            rows = cur.fetchall()

    return {
        "schema": schema,
        "include_views": include_views,
        "tables": [{"name": r[0], "type": r[1]} for r in rows],
        "table_count": len(rows),
    }


# Describe one table: column name, type, nullability, default and position.
@mcp.tool()
def describe_table(table: str, schema: str = "dwd") -> Dict[str, Any]:
    """查看表结构(字段、类型、是否可空、默认值)"""
    err = _validate_schema(schema)
    if err:
        return err

    sql = """
    SELECT column_name, data_type, is_nullable, column_default, ordinal_position
    FROM information_schema.columns
    WHERE table_schema=%s AND table_name=%s
    ORDER BY ordinal_position;
    """
    with pool.connection() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, (schema, table))
            rows = cur.fetchall()

    return {
        "schema": schema,
        "table": table,
        "columns": [
            {"name": r[0], "type": r[1], "nullable": r[2], "default": r[3], "position": r[4]}
            for r in rows
        ],
        "column_count": len(rows),
    }


# Describe every table (columns + primary key) of the requested schemas;
# all six whitelisted schemas when `schemas` is omitted or empty.
@mcp.tool()
def describe_schemas(
    schemas: Optional[List[str]] = None,
    include_views: bool = False,
    max_tables_per_schema: int = 500,
) -> Dict[str, Any]:
    """
    返回 ods/dwd/dws/core/meta/app schema 下的表结构(含主键)。
    不传 schemas 则返回全部六个 schema。
    """
    schemas = schemas or list(ALLOWED_SCHEMAS)
    invalid = [s for s in schemas if s not in ALLOWED_SCHEMA_SET]
    if invalid:
        return {"error": f"存在不允许的 schema:{invalid}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}

    table_types = ("BASE TABLE", "VIEW") if include_views else ("BASE TABLE",)

    with pool.connection() as conn:
        with conn.cursor() as cur:
            # 1) Table list
            cur.execute(
                """
                SELECT table_schema, table_name, table_type
                FROM information_schema.tables
                WHERE table_schema = ANY(%s)
                  AND table_type = ANY(%s)
                ORDER BY table_schema, table_name;
                """,
                (schemas, list(table_types)),
            )
            table_rows = cur.fetchall()

            # Keep at most max_tables_per_schema tables per schema
            # (rows arrive ordered by table name).
            tables_by_schema: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
            for s, t, tt in table_rows:
                if len(tables_by_schema[s]) < max_tables_per_schema:
                    tables_by_schema[s].append((t, tt))

            # 2) All columns (fetched in one go; with very many tables,
            #    consider splitting / paging)
            cur.execute(
                """
                SELECT table_schema, table_name, column_name, data_type,
                       is_nullable, column_default, ordinal_position
                FROM information_schema.columns
                WHERE table_schema = ANY(%s)
                ORDER BY table_schema, table_name, ordinal_position;
                """,
                (schemas,),
            )
            col_rows = cur.fetchall()

            cols_map: Dict[Tuple[str, str], List[Dict[str, Any]]] = defaultdict(list)
            for s, t, c, dt, nul, dft, pos in col_rows:
                cols_map[(s, t)].append(
                    {"name": c, "type": dt, "nullable": nul, "default": dft, "position": pos}
                )

            # 3) Primary keys
            cur.execute(
                """
                SELECT kcu.table_schema, kcu.table_name, kcu.column_name, kcu.ordinal_position
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                  ON tc.constraint_name = kcu.constraint_name
                 AND tc.table_schema = kcu.table_schema
                 AND tc.table_name = kcu.table_name
                WHERE tc.constraint_type = 'PRIMARY KEY'
                  AND tc.table_schema = ANY(%s)
                ORDER BY kcu.table_schema, kcu.table_name, kcu.ordinal_position;
                """,
                (schemas,),
            )
            pk_rows = cur.fetchall()

            pk_map: Dict[Tuple[str, str], List[str]] = defaultdict(list)
            for s, t, col, _pos in pk_rows:
                pk_map[(s, t)].append(col)

    # 4) Assemble the response
    result: Dict[str, Any] = {
        "schemas": {},
        "include_views": include_views,
        "limits": {"max_tables_per_schema": max_tables_per_schema},
    }
    for s in schemas:
        schema_tables = tables_by_schema.get(s, [])
        result["schemas"][s] = {"table_count": len(schema_tables), "tables": {}}
        for t, tt in schema_tables:
            key = (s, t)
            result["schemas"][s]["tables"][t] = {
                "type": tt,
                "primary_key": pk_map.get(key, []),
                "columns": cols_map.get(key, []),
                "column_count": len(cols_map.get(key, [])),
            }
    return result


# Run one read-only statement with search_path pinned to `schema`; returns at
# most `max_rows` rows plus a `truncated` flag. Negative `max_rows` is treated
# as 0.
@mcp.tool()
def query_sql(schema: str, sql: str, max_rows: int = MAX_ROWS) -> Dict[str, Any]:
    """
    在指定 schema 内执行只读 SQL(会 SET LOCAL search_path),并限制显式跨 schema 引用。
    """
    err = _validate_schema(schema)
    if err:
        return err

    sql = (sql or "").strip().rstrip(";")
    if not _is_probably_readonly(sql):
        return {"error": "SQL 被拒绝:仅允许只读(select/with/show/explain)并禁止危险关键字。"}

    cross = _reject_cross_schema(sql, allowed_schema=schema)
    if cross:
        return cross

    # A negative limit would make fetchmany()/slicing silently drop rows.
    max_rows = max(0, max_rows)

    with pool.connection() as conn:
        with conn.cursor() as cur:
            # `schema` passed the whitelist check above, so interpolating it is safe.
            cur.execute(f"SET LOCAL search_path TO {schema}")
            cur.execute(sql)
            cols = [d.name for d in (cur.description or [])]
            # Fetch one extra row to detect truncation.
            rows = cur.fetchmany(max_rows + 1)

    truncated = len(rows) > max_rows
    rows = rows[:max_rows]

    # Values that are not plain JSON scalars are stringified.
    safe_rows: List[List[Any]] = [
        [v if isinstance(v, (int, float, str, bool)) or v is None else str(v) for v in r]
        for r in rows
    ]

    return {
        "schema": schema,
        "columns": cols,
        "rows": safe_rows,
        "row_count": len(safe_rows),
        "truncated": truncated,
        "max_rows": max_rows,
    }


# ----------------------------
# Auth middleware (accepts a Bearer header or a `token` query parameter)
# ----------------------------
class AuthMiddleware(BaseHTTPMiddleware):
    """Reject requests under /mcp that do not carry the configured token.

    Authentication is skipped entirely when MCP_TOKEN is empty.
    """

    async def dispatch(self, request: Request, call_next):
        if MCP_TOKEN and request.url.path.startswith("/mcp"):
            auth = request.headers.get("authorization", "")
            token_q = request.query_params.get("token", "")
            # Constant-time comparison so response timing does not reveal
            # how much of a guessed token is correct.
            header_ok = _constant_time_equals(auth, f"Bearer {MCP_TOKEN}")
            query_ok = _constant_time_equals(token_q, MCP_TOKEN)
            if not (header_ok or query_ok):
                return JSONResponse({"error": "unauthorized"}, status_code=401)
        return await call_next(request)


# ----------------------------
# lifespan: explicitly open/close the pool and run the session manager
# ----------------------------
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    """Open the DB pool, run the MCP session manager, then close the pool."""
    pool.open(wait=True, timeout=30)
    try:
        async with mcp.session_manager.run():
            yield
    finally:
        # Close explicitly to avoid exceptions from __del__ cleanup at
        # interpreter exit.
        pool.close(timeout=5)


# MCP endpoint: /mcp (default streamable_http_path="/mcp")
app = Starlette(
    routes=[Mount("/", app=mcp.streamable_http_app())],
    lifespan=lifespan,
)
app.add_middleware(AuthMiddleware)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=PORT,
        proxy_headers=True,
        forwarded_allow_ips="*",
    )