- admin-web: TaskSelector 重构为按域+层全链路展示,新增同步检查功能 - admin-web: TaskConfig 动态加载 Flow/处理模式定义,DWD 表过滤内嵌域面板 - admin-web: App hydrate 完成前显示 loading,避免误跳 /login - backend: 新增 /tasks/sync-check 对比后端与 ETL 真实注册表 - backend: 新增 /tasks/flows 返回 Flow 和处理模式定义 - apps/mcp-server: 新增 MCP Server 模块(百炼 AI PostgreSQL 只读查询) - scripts/server: 新增 setup-server-git.py + server-exclude.txt - docs: 更新 LAUNCH-CHECKLIST 添加 Git 排除配置步骤 - pyproject.toml: workspace members 新增 mcp-server
413 lines
14 KiB
Python
413 lines
14 KiB
Python
import os
|
||
import re
|
||
import contextlib
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import sqlparse
|
||
from dotenv import load_dotenv
|
||
from psycopg_pool import ConnectionPool
|
||
|
||
from mcp.server.fastmcp import FastMCP
|
||
from mcp.server.transport_security import TransportSecuritySettings
|
||
|
||
from starlette.applications import Starlette
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from starlette.requests import Request
|
||
from starlette.responses import JSONResponse
|
||
from starlette.routing import Mount
|
||
|
||
# Load configuration. Intended precedence: .env.local > sibling .env > project-root .env
_here = Path(__file__).resolve().parent
_root = _here.parent.parent  # apps/mcp-server -> apps -> NeoZQYY

# .env.local is loaded with override=True; the two .env files are loaded with
# override=False, in this order.
load_dotenv(dotenv_path=_here / ".env.local", override=True)
load_dotenv(dotenv_path=_here / ".env", override=False)
load_dotenv(dotenv_path=_root / ".env", override=False)
|
||
|
||
|
||
# ----------------------------
|
||
# 工具:环境变量解析(避免 int("") 报错)
|
||
# ----------------------------
|
||
def env_str(name: str, default: str = "", required: bool = False) -> str:
    """Read environment variable *name* as a whitespace-stripped string.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset; it is stripped too.
        required: When true, an empty result (after stripping) is an error.

    Returns:
        The stripped value, or the stripped default.

    Raises:
        RuntimeError: If ``required`` is true and the result is empty.
    """
    value = os.environ.get(name, default)
    if value is None:
        value = default
    if isinstance(value, str):
        value = value.strip()
    if required and value in (None, ""):
        raise RuntimeError(f"Missing required env var: {name}")
    return value
|
||
|
||
|
||
def env_int(name: str, default: Optional[int] = None, required: bool = False) -> int:
    """Read environment variable *name* as an integer.

    An unset or blank variable falls back to ``default``. ``required`` only
    affects the wording of the error raised when there is no default either.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.
        required: Selects the "required" wording of the missing-variable error.

    Returns:
        The parsed integer, or ``default``.

    Raises:
        RuntimeError: If the variable is unset/blank with no default, or if its
            value is not a valid integer.
    """
    raw = os.getenv(name, "")
    text = raw.strip()
    if text:
        try:
            return int(text)
        except ValueError as exc:
            raise RuntimeError(f"Invalid int env var {name}={raw!r}") from exc

    if default is not None:
        return default

    label = "Missing required env var" if required else "Missing env var"
    raise RuntimeError(f"{label}: {name}")
|
||
|
||
|
||
# ----------------------------
|
||
# 配置(用环境变量注入)
|
||
# MCP_PG_* 优先(独立部署),回退到项目公共 DB_* / PG_NAME
|
||
# ----------------------------
|
||
# Each connection setting prefers its MCP_PG_* variable and falls back to the
# shared project variable; the last fallback in a chain is the required one.
PGHOST = env_str("MCP_PG_HOST") or env_str("DB_HOST", required=True)
PGPORT = env_int("MCP_PG_PORT", default=0) or env_int("DB_PORT", default=5432)
PGDATABASE = (
    env_str("MCP_PG_DATABASE")
    or env_str("ETL_DB_NAME")
    or env_str("PG_NAME", required=True)
)
PGUSER = env_str("MCP_PG_USER") or env_str("DB_USER", required=True)
PGPASSWORD = env_str("MCP_PG_PASSWORD") or env_str("DB_PASSWORD", required=True)

# Auth token; may be empty, in which case authentication is not enabled.
MCP_TOKEN = env_str("MCP_TOKEN")
# Default row limit for query_sql.
MAX_ROWS = env_int("MCP_MAX_ROWS", default=500)
# uvicorn listen port.
PORT = env_int("PORT", default=9000)

# The six-layer schema architecture of the etl_feiqiu database.
ALLOWED_SCHEMAS = ("ods", "dwd", "dws", "core", "meta", "app")
ALLOWED_SCHEMA_SET = set(ALLOWED_SCHEMAS)
|
||
|
||
def _conninfo_value(value: Any) -> str:
    """Quote *value* for a libpq keyword/value connection string.

    The value is wrapped in single quotes, with backslashes and single quotes
    inside it escaped by a backslash.
    """
    escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


# psycopg DSN in keyword/value form. Every value is quoted and escaped, so a
# password (or any other setting) containing spaces, quotes or backslashes can
# neither break the string nor smuggle extra settings into it.
DSN = " ".join(
    f"{keyword}={_conninfo_value(setting)}"
    for keyword, setting in (
        ("host", PGHOST),
        ("port", PGPORT),
        ("dbname", PGDATABASE),
        ("user", PGUSER),
        ("password", PGPASSWORD),
    )
)

# Connection pool: deliberately created with open=False (not open=True) to avoid
# exceptions from __del__ cleanup when the interpreter exits.
pool = ConnectionPool(conninfo=DSN, min_size=1, max_size=10, timeout=60, open=False)
|
||
|
||
|
||
# ----------------------------
|
||
# SQL 只读门禁(最终底线仍是 DB 只读账号)
|
||
# ----------------------------
|
||
# Keywords that must not appear anywhere in a submitted statement. `into` is
# included because `SELECT ... INTO new_table` creates a table even though the
# statement starts with SELECT.
FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|call|execute|do|into)\b",
    re.IGNORECASE,
)

# Additionally restrict explicit cross-schema access (to avoid privilege escalation).
# Matches `qualifier.name` patterns but skips single-letter qualifiers such as the
# aliases in `t.id` or `o.amount`.
# NOTE(review): qualifiers of two or more characters are always captured, so a
# longer table alias (e.g. `ord.amount`) is treated as a schema reference.
SCHEMA_QUAL = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]{1,})\s*\.\s*[a-zA-Z_]", re.IGNORECASE)
|
||
|
||
|
||
def _is_probably_readonly(sql: str) -> bool:
    """Heuristically decide whether *sql* is a single read-only statement.

    The text is accepted only when all of the following hold:

    * it contains none of the ``FORBIDDEN`` keywords;
    * it parses to exactly one statement;
    * the first non-whitespace token of that statement is ``select``,
      ``with``, ``show`` or ``explain``.

    This is a best-effort gate, not a guarantee.

    Args:
        sql: SQL text to inspect.

    Returns:
        True if the text passes every check above, otherwise False.
    """
    if FORBIDDEN.search(sql):
        return False

    # Ignore blank fragments the parser may yield (e.g. after a trailing ";").
    statements = [stmt for stmt in sqlparse.parse(sql) if str(stmt).strip()]

    # Exactly one statement is allowed: only the first statement's leading
    # keyword is inspected below, so anything after a ";" would otherwise reach
    # the database unchecked (e.g. "select 1; set ...").
    if len(statements) != 1:
        return False

    for tok in statements[0].tokens:
        if tok.is_whitespace:
            continue
        # Only the first non-whitespace token decides the outcome.
        return str(tok).strip().lower() in ("select", "with", "show", "explain")
    return False
|
||
|
||
|
||
def _validate_schema(schema: str) -> Optional[Dict[str, Any]]:
    """Check *schema* against the allow-list.

    Returns:
        None when the schema is allowed, otherwise an ``{"error": ...}`` dict
        that names the rejected schema and lists the allowed ones.
    """
    if schema in ALLOWED_SCHEMA_SET:
        return None
    return {"error": f"schema 不允许:{schema}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}
|
||
|
||
|
||
def _reject_cross_schema(sql: str, allowed_schema: str) -> Optional[Dict[str, Any]]:
    """Reject SQL that uses an explicit qualifier outside the permitted schemas.

    Simple protection: every ``xxx.`` prefix found by ``SCHEMA_QUAL`` must be
    one of the allow-listed business schemas or a system schema
    (``pg_catalog`` / ``information_schema``), compared case-insensitively.
    This is not a parser-level policy, but it blocks most cross-schema attempts.

    Args:
        sql: SQL text to scan; ``None`` or empty is treated as empty text.
        allowed_schema: Accepted for the caller's convenience; not consulted by
            the check, which allows every allow-listed schema.

    Returns:
        None when nothing objectionable is found, otherwise an
        ``{"error": ...}`` dict listing the offending qualifiers.
    """
    # All business schemas plus the system schemas, lower-cased once.
    permitted = {
        name.lower()
        for name in ALLOWED_SCHEMA_SET | {"pg_catalog", "information_schema"}
    }
    referenced = {hit.group(1) for hit in SCHEMA_QUAL.finditer(sql or "")}
    bad = sorted(qualifier for qualifier in referenced if qualifier.lower() not in permitted)
    if not bad:
        return None
    return {"error": f"SQL 被拒绝:检测到不允许的 schema 引用 {bad},仅允许 {sorted(ALLOWED_SCHEMA_SET)} / 系统 schema。"}
|
||
|
||
|
||
# ----------------------------
|
||
# FastMCP:Streamable HTTP + JSON 响应
|
||
# ----------------------------
|
||
def _with_any_port(names: List[str]) -> List[str]:
    """Return each entry of *names* immediately followed by its ``:*`` variant."""
    expanded: List[str] = []
    for name in names:
        expanded.extend((name, f"{name}:*"))
    return expanded


mcp = FastMCP(
    "postgres-mcp",
    stateless_http=True,
    json_response=True,
    transport_security=TransportSecuritySettings(
        enable_dns_rebinding_protection=True,
        # Key point: every entry is allowed both without a port and with one.
        allowed_hosts=_with_any_port(
            [
                "mcp.langlangzhuoqiu.cn",
                "localhost",
                "127.0.0.1",
                "100.64.0.4",
                "100.64.0.1",
                "106.52.16.235",
            ]
        ),
        allowed_origins=_with_any_port(
            [
                "https://mcp.langlangzhuoqiu.cn",
                "http://localhost",
                "http://127.0.0.1",
            ]
        ),
    ),
)
|
||
|
||
|
||
# ----------------------------
|
||
# Tools:面向 etl_feiqiu 六层 schema
|
||
# ----------------------------
|
||
@mcp.tool()
def list_tables(schema: str = "dwd", include_views: bool = False) -> Dict[str, Any]:
    """列出指定 schema(ods/dwd/dws/core/meta/app)下的表(可选包含视图)"""
    # The docstring above is kept verbatim: it is presumably surfaced to MCP
    # clients as the tool description.
    #
    # Lists the tables of one allow-listed schema, optionally including views.
    # Returns an {"error": ...} dict for a schema outside the allow-list,
    # otherwise a dict with "schema", "include_views", "tables" (name/type
    # pairs ordered by table name) and "table_count".
    rejection = _validate_schema(schema)
    if rejection:
        return rejection

    wanted_types = ["BASE TABLE"]
    if include_views:
        wanted_types.append("VIEW")

    query = """
        SELECT table_name, table_type
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = ANY(%s)
        ORDER BY table_name;
    """
    with pool.connection() as conn, conn.cursor() as cur:
        cur.execute(query, (schema, wanted_types))
        found = cur.fetchall()

    tables = [{"name": row[0], "type": row[1]} for row in found]
    return {
        "schema": schema,
        "include_views": include_views,
        "tables": tables,
        "table_count": len(found),
    }
|
||
|
||
|
||
@mcp.tool()
def describe_table(table: str, schema: str = "dwd") -> Dict[str, Any]:
    """查看表结构(字段、类型、是否可空、默认值)"""
    # The docstring above is kept verbatim: it is presumably surfaced to MCP
    # clients as the tool description.
    #
    # Describes the columns (name, type, nullability, default, position) of one
    # table in an allow-listed schema, ordered by ordinal position. Returns an
    # {"error": ...} dict for a schema outside the allow-list. A table with no
    # matching rows simply yields an empty "columns" list.
    rejection = _validate_schema(schema)
    if rejection:
        return rejection

    query = """
        SELECT column_name, data_type, is_nullable, column_default, ordinal_position
        FROM information_schema.columns
        WHERE table_schema=%s AND table_name=%s
        ORDER BY ordinal_position;
    """
    with pool.connection() as conn, conn.cursor() as cur:
        cur.execute(query, (schema, table))
        found = cur.fetchall()

    columns = []
    for row in found:
        columns.append(
            {
                "name": row[0],
                "type": row[1],
                "nullable": row[2],
                "default": row[3],
                "position": row[4],
            }
        )

    return {
        "schema": schema,
        "table": table,
        "columns": columns,
        "column_count": len(found),
    }
|
||
|
||
|
||
@mcp.tool()
def describe_schemas(
    schemas: Optional[List[str]] = None,
    include_views: bool = False,
    max_tables_per_schema: int = 500,
) -> Dict[str, Any]:
    """
    返回 ods/dwd/dws/core/meta/app schema 下的表结构(含主键)。
    不传 schemas 则返回全部六个 schema。
    """
    # The docstring above is kept verbatim: it is presumably surfaced to MCP
    # clients as the tool description.
    #
    # Returns, per requested schema, its tables with their type, primary-key
    # columns and column definitions. When `schemas` is empty or omitted, all
    # allow-listed schemas are described. Any schema outside the allow-list
    # yields an {"error": ...} dict instead. At most `max_tables_per_schema`
    # tables (in table-name order) are reported for each schema.
    schemas = schemas or list(ALLOWED_SCHEMAS)

    invalid = [name for name in schemas if name not in ALLOWED_SCHEMA_SET]
    if invalid:
        return {"error": f"存在不允许的 schema:{invalid}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}

    wanted_types = ["BASE TABLE", "VIEW"] if include_views else ["BASE TABLE"]

    with pool.connection() as conn, conn.cursor() as cur:
        # 1) Table inventory
        cur.execute(
            """
            SELECT table_schema, table_name, table_type
            FROM information_schema.tables
            WHERE table_schema = ANY(%s)
              AND table_type = ANY(%s)
            ORDER BY table_schema, table_name;
            """,
            (schemas, wanted_types),
        )
        table_rows = cur.fetchall()

        # 2) All columns (fetched in one go; with very many tables this could
        #    be split or paginated)
        cur.execute(
            """
            SELECT table_schema, table_name, column_name, data_type, is_nullable, column_default, ordinal_position
            FROM information_schema.columns
            WHERE table_schema = ANY(%s)
            ORDER BY table_schema, table_name, ordinal_position;
            """,
            (schemas,),
        )
        col_rows = cur.fetchall()

        # 3) Primary keys
        cur.execute(
            """
            SELECT kcu.table_schema, kcu.table_name, kcu.column_name, kcu.ordinal_position
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
              ON tc.constraint_name = kcu.constraint_name
             AND tc.table_schema = kcu.table_schema
             AND tc.table_name = kcu.table_name
            WHERE tc.constraint_type = 'PRIMARY KEY'
              AND tc.table_schema = ANY(%s)
            ORDER BY kcu.table_schema, kcu.table_name, kcu.ordinal_position;
            """,
            (schemas,),
        )
        pk_rows = cur.fetchall()

    # Group tables by schema, keeping only the first max_tables_per_schema each.
    tables_by_schema: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for schema_name, table_name, table_type in table_rows:
        bucket = tables_by_schema[schema_name]
        if len(bucket) < max_tables_per_schema:
            bucket.append((table_name, table_type))

    # Column definitions keyed by (schema, table), in ordinal order.
    cols_map: Dict[Tuple[str, str], List[Dict[str, Any]]] = defaultdict(list)
    for schema_name, table_name, column, data_type, nullable, default, position in col_rows:
        cols_map[(schema_name, table_name)].append(
            {
                "name": column,
                "type": data_type,
                "nullable": nullable,
                "default": default,
                "position": position,
            }
        )

    # Primary-key column names keyed by (schema, table), in key order.
    pk_map: Dict[Tuple[str, str], List[str]] = defaultdict(list)
    for schema_name, table_name, column, _position in pk_rows:
        pk_map[(schema_name, table_name)].append(column)

    # 4) Assemble the response
    per_schema: Dict[str, Any] = {}
    for schema_name in schemas:
        listed = tables_by_schema.get(schema_name, [])
        tables: Dict[str, Any] = {}
        for table_name, table_type in listed:
            key = (schema_name, table_name)
            columns = cols_map.get(key, [])
            tables[table_name] = {
                "type": table_type,
                "primary_key": pk_map.get(key, []),
                "columns": columns,
                "column_count": len(columns),
            }
        per_schema[schema_name] = {"table_count": len(listed), "tables": tables}

    return {
        "schemas": per_schema,
        "include_views": include_views,
        "limits": {"max_tables_per_schema": max_tables_per_schema},
    }
|
||
|
||
|
||
@mcp.tool()
def query_sql(schema: str, sql: str, max_rows: int = MAX_ROWS) -> Dict[str, Any]:
    """
    在指定 schema 内执行只读 SQL(会 SET LOCAL search_path),并限制显式跨 schema 引用。
    """
    # The docstring above is kept verbatim: it is presumably surfaced to MCP
    # clients as the tool description.
    #
    # Runs one read-only statement with search_path set to `schema` and returns
    # at most `max_rows` rows. Validation failures (schema not allow-listed,
    # negative max_rows, statement not judged read-only, disallowed schema
    # qualifier) are reported as an {"error": ...} dict rather than raised.
    # On success the dict holds "schema", "columns", "rows", "row_count",
    # "truncated" and "max_rows".
    err = _validate_schema(schema)
    if err:
        return err

    # A negative limit would make the fetch size and the slice below meaningless.
    if max_rows < 0:
        return {"error": f"max_rows 不能为负数:{max_rows}"}

    sql = (sql or "").strip().rstrip(";")
    if not _is_probably_readonly(sql):
        return {"error": "SQL 被拒绝:仅允许只读(select/with/show/explain)并禁止危险关键字。"}

    cross = _reject_cross_schema(sql, allowed_schema=schema)
    if cross:
        return cross

    with pool.connection() as conn:
        with conn.cursor() as cur:
            # Defense in depth: mark the transaction read-only so the database
            # itself refuses writes that slip past the keyword checks.
            # NOTE(review): like the SET LOCAL below, this relies on the
            # connection not being in autocommit mode — confirm the pool config.
            cur.execute("SET TRANSACTION READ ONLY")
            # schema has passed the allow-list check, so interpolation is safe.
            cur.execute(f"SET LOCAL search_path TO {schema}")
            cur.execute(sql)

            cols = [d.name for d in (cur.description or [])]
            # Fetch one row beyond the limit to detect truncation.
            rows = cur.fetchmany(max_rows + 1)

    truncated = len(rows) > max_rows
    rows = rows[:max_rows]

    # Keep None and plain scalars as they are; stringify everything else.
    safe_rows: List[List[Any]] = [
        [v if isinstance(v, (int, float, str, bool)) or v is None else str(v) for v in r]
        for r in rows
    ]

    return {
        "schema": schema,
        "columns": cols,
        "rows": safe_rows,
        "row_count": len(safe_rows),
        "truncated": truncated,
        "max_rows": max_rows,
    }
|
||
|
||
|
||
# ----------------------------
|
||
# 鉴权 Middleware(支持 Bearer 或 query token)
|
||
# ----------------------------
|
||
class AuthMiddleware(BaseHTTPMiddleware):
    """Token check for paths starting with ``/mcp``.

    Active only when ``MCP_TOKEN`` is non-empty. A request passes when its
    ``Authorization`` header equals ``Bearer <MCP_TOKEN>`` or its ``token``
    query parameter equals ``MCP_TOKEN``; otherwise it receives a 401 JSON
    response. All other requests are forwarded unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        """Authenticate *request* if required, then pass it down the chain."""
        guarded = bool(MCP_TOKEN) and request.url.path.startswith("/mcp")
        if guarded:
            header_ok = request.headers.get("authorization", "") == f"Bearer {MCP_TOKEN}"
            query_ok = request.query_params.get("token", "") == MCP_TOKEN
            if not (header_ok or query_ok):
                return JSONResponse({"error": "unauthorized"}, status_code=401)
        return await call_next(request)
|
||
|
||
|
||
# ----------------------------
|
||
# lifespan:显式 open/close pool,并运行 session_manager
|
||
# ----------------------------
|
||
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    """Application lifespan: open the pool, run the MCP session manager, close the pool.

    The pool is opened inside the ``try`` so that ``pool.close()`` also runs
    when opening fails or times out.
    """
    try:
        pool.open(wait=True, timeout=30)
        async with mcp.session_manager.run():
            yield
    finally:
        # Close explicitly to avoid exceptions from __del__ cleanup during
        # interpreter shutdown.
        pool.close(timeout=5)
|
||
|
||
|
||
# MCP endpoint: /mcp (the default streamable_http_path is "/mcp")
_mcp_routes = [Mount("/", app=mcp.streamable_http_app())]
app = Starlette(routes=_mcp_routes, lifespan=lifespan)
app.add_middleware(AuthMiddleware)
|
||
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # Listen on every interface at the configured port, with proxy headers
    # enabled and forwarded headers accepted from any address.
    server_options = {
        "host": "0.0.0.0",
        "port": PORT,
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
    }
    uvicorn.run(app, **server_options)
|