feat: TaskSelector v2 全链路展示 + 同步检查 + MCP Server + 服务器 Git 排除
- admin-web: TaskSelector 重构为按域+层全链路展示,新增同步检查功能 - admin-web: TaskConfig 动态加载 Flow/处理模式定义,DWD 表过滤内嵌域面板 - admin-web: App hydrate 完成前显示 loading,避免误跳 /login - backend: 新增 /tasks/sync-check 对比后端与 ETL 真实注册表 - backend: 新增 /tasks/flows 返回 Flow 和处理模式定义 - apps/mcp-server: 新增 MCP Server 模块(百炼 AI PostgreSQL 只读查询) - scripts/server: 新增 setup-server-git.py + server-exclude.txt - docs: 更新 LAUNCH-CHECKLIST 添加 Git 排除配置步骤 - pyproject.toml: workspace members 新增 mcp-server
This commit is contained in:
412
apps/mcp-server/server.py
Normal file
412
apps/mcp-server/server.py
Normal file
@@ -0,0 +1,412 @@
|
||||
import os
|
||||
import re
|
||||
import contextlib
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import sqlparse
|
||||
from dotenv import load_dotenv
|
||||
from psycopg_pool import ConnectionPool
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.transport_security import TransportSecuritySettings
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Mount
|
||||
|
||||
# 加载配置:.env.local > 同级 .env > 项目根 .env
|
||||
_here = Path(__file__).resolve().parent
|
||||
_root = _here.parent.parent # apps/mcp-server -> apps -> NeoZQYY
|
||||
load_dotenv(_here / ".env.local", override=True)
|
||||
load_dotenv(_here / ".env", override=False)
|
||||
load_dotenv(_root / ".env", override=False)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 工具:环境变量解析(避免 int("") 报错)
|
||||
# ----------------------------
|
||||
def env_str(name: str, default: str = "", required: bool = False) -> str:
|
||||
v = os.getenv(name, default)
|
||||
v = v if v is not None else default
|
||||
v = v.strip() if isinstance(v, str) else v
|
||||
if required and (v is None or v == ""):
|
||||
raise RuntimeError(f"Missing required env var: {name}")
|
||||
return v
|
||||
|
||||
|
||||
def env_int(name: str, default: Optional[int] = None, required: bool = False) -> int:
|
||||
raw = os.getenv(name, "")
|
||||
if raw is None or raw.strip() == "":
|
||||
if required and default is None:
|
||||
raise RuntimeError(f"Missing required env var: {name}")
|
||||
if default is None:
|
||||
raise RuntimeError(f"Missing env var: {name}")
|
||||
return default
|
||||
try:
|
||||
return int(raw.strip())
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Invalid int env var {name}={raw!r}") from e
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 配置(用环境变量注入)
|
||||
# MCP_PG_* 优先(独立部署),回退到项目公共 DB_* / PG_NAME
|
||||
# ----------------------------
|
||||
PGHOST = env_str("MCP_PG_HOST", default="") or env_str("DB_HOST", required=True)
|
||||
PGPORT = env_int("MCP_PG_PORT", default=0) or env_int("DB_PORT", default=5432)
|
||||
PGDATABASE = env_str("MCP_PG_DATABASE", default="") or env_str("ETL_DB_NAME", default="") or env_str("PG_NAME", required=True)
|
||||
PGUSER = env_str("MCP_PG_USER", default="") or env_str("DB_USER", required=True)
|
||||
PGPASSWORD = env_str("MCP_PG_PASSWORD", default="") or env_str("DB_PASSWORD", required=True)
|
||||
|
||||
MCP_TOKEN = env_str("MCP_TOKEN", default="") # 鉴权 token(可空:不启用鉴权)
|
||||
MAX_ROWS = env_int("MCP_MAX_ROWS", default=500) # query_sql 默认最大行数
|
||||
PORT = env_int("PORT", default=9000) # uvicorn 端口
|
||||
|
||||
# etl_feiqiu 库的六层 schema 架构
|
||||
ALLOWED_SCHEMAS = ("ods", "dwd", "dws", "core", "meta", "app")
|
||||
ALLOWED_SCHEMA_SET = set(ALLOWED_SCHEMAS)
|
||||
|
||||
# psycopg DSN(如果密码包含空格等特殊字符,建议改用 URL 形式并做编码)
|
||||
DSN = (
|
||||
f"host={PGHOST} port={PGPORT} dbname={PGDATABASE} "
|
||||
f"user={PGUSER} password={PGPASSWORD}"
|
||||
)
|
||||
|
||||
# 连接池:不要 open=True(避免解释器退出时 __del__ 清理触发异常)
|
||||
pool = ConnectionPool(conninfo=DSN, min_size=1, max_size=10, timeout=60, open=False)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# SQL 只读门禁(最终底线仍是 DB 只读账号)
|
||||
# ----------------------------
|
||||
FORBIDDEN = re.compile(
|
||||
r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|call|execute|do)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# 额外禁止显式跨 schema 访问(避免越权)
|
||||
# 匹配 schema.table 模式,但排除单字母别名(如 t.id、o.amount)
|
||||
SCHEMA_QUAL = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]{1,})\s*\.\s*[a-zA-Z_]", re.IGNORECASE)
|
||||
|
||||
|
||||
def _is_probably_readonly(sql: str) -> bool:
|
||||
if FORBIDDEN.search(sql):
|
||||
return False
|
||||
parsed = sqlparse.parse(sql)
|
||||
if not parsed:
|
||||
return False
|
||||
stmt = parsed[0]
|
||||
for tok in stmt.tokens:
|
||||
if tok.is_whitespace:
|
||||
continue
|
||||
first = str(tok).strip().lower()
|
||||
return first in ("select", "with", "show", "explain")
|
||||
return False
|
||||
|
||||
|
||||
def _validate_schema(schema: str) -> Optional[Dict[str, Any]]:
|
||||
if schema not in ALLOWED_SCHEMA_SET:
|
||||
return {"error": f"schema 不允许:{schema}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}
|
||||
return None
|
||||
|
||||
|
||||
def _reject_cross_schema(sql: str, allowed_schema: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
简单防护:如果出现显式 schema 前缀(xxx.),要求必须是白名单内的 schema 或系统 schema。
|
||||
注:这不是 SQL parser 级别的严格策略,但能挡住绝大多数越权写法。
|
||||
"""
|
||||
matches = set(m.group(1) for m in SCHEMA_QUAL.finditer(sql or ""))
|
||||
# 允许所有业务 schema + 系统 schema
|
||||
safe = ALLOWED_SCHEMA_SET | {"pg_catalog", "information_schema"}
|
||||
bad = sorted([s for s in matches if s.lower() not in {a.lower() for a in safe}])
|
||||
if bad:
|
||||
return {"error": f"SQL 被拒绝:检测到不允许的 schema 引用 {bad},仅允许 {sorted(ALLOWED_SCHEMA_SET)} / 系统 schema。"}
|
||||
return None
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# FastMCP:Streamable HTTP + JSON 响应
|
||||
# ----------------------------
|
||||
mcp = FastMCP(
|
||||
"postgres-mcp",
|
||||
stateless_http=True,
|
||||
json_response=True,
|
||||
transport_security=TransportSecuritySettings(
|
||||
enable_dns_rebinding_protection=True,
|
||||
allowed_hosts=[
|
||||
# 关键:既允许不带端口,也允许带端口
|
||||
"mcp.langlangzhuoqiu.cn",
|
||||
"mcp.langlangzhuoqiu.cn:*",
|
||||
"localhost",
|
||||
"localhost:*",
|
||||
"127.0.0.1",
|
||||
"127.0.0.1:*",
|
||||
"100.64.0.4",
|
||||
"100.64.0.4:*",
|
||||
"100.64.0.1",
|
||||
"100.64.0.1:*",
|
||||
"106.52.16.235",
|
||||
"106.52.16.235:*",
|
||||
],
|
||||
allowed_origins=[
|
||||
"https://mcp.langlangzhuoqiu.cn",
|
||||
"https://mcp.langlangzhuoqiu.cn:*",
|
||||
"http://localhost",
|
||||
"http://localhost:*",
|
||||
"http://127.0.0.1",
|
||||
"http://127.0.0.1:*",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Tools:面向 etl_feiqiu 六层 schema
|
||||
# ----------------------------
|
||||
@mcp.tool()
|
||||
def list_tables(schema: str = "dwd", include_views: bool = False) -> Dict[str, Any]:
|
||||
"""列出指定 schema(ods/dwd/dws/core/meta/app)下的表(可选包含视图)"""
|
||||
err = _validate_schema(schema)
|
||||
if err:
|
||||
return err
|
||||
|
||||
table_types = ("BASE TABLE", "VIEW") if include_views else ("BASE TABLE",)
|
||||
sql = """
|
||||
SELECT table_name, table_type
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = %s AND table_type = ANY(%s)
|
||||
ORDER BY table_name;
|
||||
"""
|
||||
with pool.connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, list(table_types)))
|
||||
rows = cur.fetchall()
|
||||
|
||||
return {
|
||||
"schema": schema,
|
||||
"include_views": include_views,
|
||||
"tables": [{"name": r[0], "type": r[1]} for r in rows],
|
||||
"table_count": len(rows),
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def describe_table(table: str, schema: str = "dwd") -> Dict[str, Any]:
|
||||
"""查看表结构(字段、类型、是否可空、默认值)"""
|
||||
err = _validate_schema(schema)
|
||||
if err:
|
||||
return err
|
||||
|
||||
sql = """
|
||||
SELECT column_name, data_type, is_nullable, column_default, ordinal_position
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema=%s AND table_name=%s
|
||||
ORDER BY ordinal_position;
|
||||
"""
|
||||
with pool.connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, table))
|
||||
rows = cur.fetchall()
|
||||
|
||||
return {
|
||||
"schema": schema,
|
||||
"table": table,
|
||||
"columns": [
|
||||
{"name": r[0], "type": r[1], "nullable": r[2], "default": r[3], "position": r[4]}
|
||||
for r in rows
|
||||
],
|
||||
"column_count": len(rows),
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def describe_schemas(
|
||||
schemas: Optional[List[str]] = None,
|
||||
include_views: bool = False,
|
||||
max_tables_per_schema: int = 500,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
返回 ods/dwd/dws/core/meta/app schema 下的表结构(含主键)。
|
||||
不传 schemas 则返回全部六个 schema。
|
||||
"""
|
||||
schemas = schemas or list(ALLOWED_SCHEMAS)
|
||||
|
||||
invalid = [s for s in schemas if s not in ALLOWED_SCHEMA_SET]
|
||||
if invalid:
|
||||
return {"error": f"存在不允许的 schema:{invalid}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}
|
||||
|
||||
table_types = ("BASE TABLE", "VIEW") if include_views else ("BASE TABLE",)
|
||||
|
||||
with pool.connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 1) 表清单
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT table_schema, table_name, table_type
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = ANY(%s)
|
||||
AND table_type = ANY(%s)
|
||||
ORDER BY table_schema, table_name;
|
||||
""",
|
||||
(schemas, list(table_types)),
|
||||
)
|
||||
table_rows = cur.fetchall()
|
||||
|
||||
tables_by_schema: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
|
||||
for s, t, tt in table_rows:
|
||||
if len(tables_by_schema[s]) < max_tables_per_schema:
|
||||
tables_by_schema[s].append((t, tt))
|
||||
|
||||
# 2) 所有列(一次性取;如表非常多,可考虑拆分/分页)
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT table_schema, table_name, column_name, data_type, is_nullable, column_default, ordinal_position
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = ANY(%s)
|
||||
ORDER BY table_schema, table_name, ordinal_position;
|
||||
""",
|
||||
(schemas,),
|
||||
)
|
||||
col_rows = cur.fetchall()
|
||||
|
||||
cols_map: Dict[Tuple[str, str], List[Dict[str, Any]]] = defaultdict(list)
|
||||
for s, t, c, dt, nul, dft, pos in col_rows:
|
||||
cols_map[(s, t)].append(
|
||||
{"name": c, "type": dt, "nullable": nul, "default": dft, "position": pos}
|
||||
)
|
||||
|
||||
# 3) 主键
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT kcu.table_schema, kcu.table_name, kcu.column_name, kcu.ordinal_position
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
AND tc.table_name = kcu.table_name
|
||||
WHERE tc.constraint_type = 'PRIMARY KEY'
|
||||
AND tc.table_schema = ANY(%s)
|
||||
ORDER BY kcu.table_schema, kcu.table_name, kcu.ordinal_position;
|
||||
""",
|
||||
(schemas,),
|
||||
)
|
||||
pk_rows = cur.fetchall()
|
||||
|
||||
pk_map: Dict[Tuple[str, str], List[str]] = defaultdict(list)
|
||||
for s, t, col, _pos in pk_rows:
|
||||
pk_map[(s, t)].append(col)
|
||||
|
||||
# 4) 组装
|
||||
result: Dict[str, Any] = {
|
||||
"schemas": {},
|
||||
"include_views": include_views,
|
||||
"limits": {"max_tables_per_schema": max_tables_per_schema},
|
||||
}
|
||||
|
||||
for s in schemas:
|
||||
schema_tables = tables_by_schema.get(s, [])
|
||||
result["schemas"][s] = {"table_count": len(schema_tables), "tables": {}}
|
||||
for t, tt in schema_tables:
|
||||
key = (s, t)
|
||||
result["schemas"][s]["tables"][t] = {
|
||||
"type": tt,
|
||||
"primary_key": pk_map.get(key, []),
|
||||
"columns": cols_map.get(key, []),
|
||||
"column_count": len(cols_map.get(key, [])),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def query_sql(schema: str, sql: str, max_rows: int = MAX_ROWS) -> Dict[str, Any]:
|
||||
"""
|
||||
在指定 schema 内执行只读 SQL(会 SET LOCAL search_path),并限制显式跨 schema 引用。
|
||||
"""
|
||||
err = _validate_schema(schema)
|
||||
if err:
|
||||
return err
|
||||
|
||||
sql = (sql or "").strip().rstrip(";")
|
||||
if not _is_probably_readonly(sql):
|
||||
return {"error": "SQL 被拒绝:仅允许只读(select/with/show/explain)并禁止危险关键字。"}
|
||||
|
||||
cross = _reject_cross_schema(sql, allowed_schema=schema)
|
||||
if cross:
|
||||
return cross
|
||||
|
||||
with pool.connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
# schema 已白名单校验,可安全拼接
|
||||
cur.execute(f"SET LOCAL search_path TO {schema}")
|
||||
cur.execute(sql)
|
||||
|
||||
cols = [d.name for d in (cur.description or [])]
|
||||
rows = cur.fetchmany(max_rows + 1)
|
||||
|
||||
truncated = len(rows) > max_rows
|
||||
rows = rows[:max_rows]
|
||||
|
||||
safe_rows: List[List[Any]] = []
|
||||
for r in rows:
|
||||
safe_rows.append([v if isinstance(v, (int, float, str, bool)) or v is None else str(v) for v in r])
|
||||
|
||||
return {
|
||||
"schema": schema,
|
||||
"columns": cols,
|
||||
"rows": safe_rows,
|
||||
"row_count": len(safe_rows),
|
||||
"truncated": truncated,
|
||||
"max_rows": max_rows,
|
||||
}
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 鉴权 Middleware(支持 Bearer 或 query token)
|
||||
# ----------------------------
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if MCP_TOKEN and request.url.path.startswith("/mcp"):
|
||||
auth = request.headers.get("authorization", "")
|
||||
token_q = request.query_params.get("token", "")
|
||||
if auth != f"Bearer {MCP_TOKEN}" and token_q != MCP_TOKEN:
|
||||
return JSONResponse({"error": "unauthorized"}, status_code=401)
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# lifespan:显式 open/close pool,并运行 session_manager
|
||||
# ----------------------------
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette):
|
||||
pool.open(wait=True, timeout=30)
|
||||
try:
|
||||
async with mcp.session_manager.run():
|
||||
yield
|
||||
finally:
|
||||
# 避免解释器退出阶段 __del__ 清理导致异常
|
||||
pool.close(timeout=5)
|
||||
|
||||
|
||||
# MCP endpoint:/mcp(默认 streamable_http_path="/mcp")
|
||||
app = Starlette(
|
||||
routes=[Mount("/", app=mcp.streamable_http_app())],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host="0.0.0.0",
|
||||
port=PORT,
|
||||
proxy_headers=True,
|
||||
forwarded_allow_ips="*",
|
||||
)
|
||||
Reference in New Issue
Block a user