Files
Neo-ZQYY/apps/mcp-server/server.py

413 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import contextlib
import hmac
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import sqlparse
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings
from psycopg_pool import ConnectionPool
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
# Load configuration from dotenv files.
#   * .env.local is loaded with override=True, so it wins even over variables that
#     are already set in the process environment.
#   * The sibling .env and then the repository-root .env use override=False: they
#     only fill in variables that are still unset, so the sibling file wins over
#     the root one.
_here = Path(__file__).resolve().parent
_root = _here.parent.parent  # apps/mcp-server -> apps -> repository root
for _dotenv_file, _override in (
    (_here / ".env.local", True),
    (_here / ".env", False),
    (_root / ".env", False),
):
    load_dotenv(_dotenv_file, override=_override)
del _dotenv_file, _override
# ----------------------------
# Helpers: environment-variable parsing (avoids errors such as int("") on blanks)
# ----------------------------
def env_str(name: str, default: str = "", required: bool = False) -> str:
    """Read a string environment variable, trimmed of surrounding whitespace.

    Args:
        name: environment variable to read.
        default: value used when the variable is unset. An explicitly empty
            variable is NOT replaced by the default; it yields "".
        required: when true, an empty (or None) final value is an error.

    Returns:
        The stripped value.

    Raises:
        RuntimeError: if ``required`` and the value is None or empty.
    """
    value = os.getenv(name, default)
    if value is None:
        value = default
    if isinstance(value, str):
        value = value.strip()
    if required and value in (None, ""):
        raise RuntimeError(f"Missing required env var: {name}")
    return value
def env_int(name: str, default: Optional[int] = None, required: bool = False) -> int:
    """Read an integer environment variable.

    An unset or whitespace-only variable falls back to ``default``. This happens
    even when ``required`` is true; ``required`` only changes the error message
    used when there is no default.

    Raises:
        RuntimeError: if the variable is unset/blank and ``default`` is None, or
            if its value is not a valid integer (chained from the ValueError).
    """
    raw = os.getenv(name, "")
    if not raw.strip():
        if default is not None:
            return default
        if required:
            raise RuntimeError(f"Missing required env var: {name}")
        raise RuntimeError(f"Missing env var: {name}")
    try:
        return int(raw.strip())
    except ValueError as e:
        # Report the raw (unstripped) value so stray whitespace is visible.
        raise RuntimeError(f"Invalid int env var {name}={raw!r}") from e
# ----------------------------
# Configuration (injected through environment variables)
# MCP_PG_* take precedence (standalone deployment); otherwise fall back to the
# project-wide DB_* / ETL_DB_NAME / PG_NAME split settings.
# ----------------------------
PGHOST = env_str("MCP_PG_HOST", default="") or env_str("DB_HOST", required=True)
PGPORT = env_int("MCP_PG_PORT", default=0) or env_int("DB_PORT", default=5432)
PGDATABASE = env_str("MCP_PG_DATABASE", default="") or env_str("ETL_DB_NAME", default="") or env_str("PG_NAME", required=True)
PGUSER = env_str("MCP_PG_USER", default="") or env_str("DB_USER", required=True)
PGPASSWORD = env_str("MCP_PG_PASSWORD", default="") or env_str("DB_PASSWORD", required=True)
MCP_TOKEN = env_str("MCP_TOKEN", default="")  # auth token; empty disables authentication
MAX_ROWS = env_int("MCP_MAX_ROWS", default=500)  # default row limit for query_sql
PORT = env_int("PORT", default=9000)  # uvicorn listen port
# The six-layer schema architecture of the etl_feiqiu database.
ALLOWED_SCHEMAS = ("ods", "dwd", "dws", "core", "meta", "app")
ALLOWED_SCHEMA_SET = set(ALLOWED_SCHEMAS)


def _conninfo_quote(value: object) -> str:
    """Quote one value for a libpq keyword/value connection string.

    The value is wrapped in single quotes, and backslashes and single quotes
    inside it are backslash-escaped. A password (or host, user, ...) containing
    spaces, quotes or ``key=value`` text can therefore neither break the DSN nor
    inject extra connection keywords.
    """
    text = str(value)
    return "'" + text.replace("\\", "\\\\").replace("'", "\\'") + "'"


# psycopg DSN. Every value is quoted, so special characters in the password are safe.
DSN = " ".join(
    f"{key}={_conninfo_quote(val)}"
    for key, val in (
        ("host", PGHOST),
        ("port", PGPORT),
        ("dbname", PGDATABASE),
        ("user", PGUSER),
        ("password", PGPASSWORD),
    )
)
# Connection pool created with open=False. Opening it at import time is avoided
# because the pool's __del__ cleanup can raise during interpreter shutdown; it is
# opened and closed explicitly at application startup/shutdown instead.
pool = ConnectionPool(conninfo=DSN, min_size=1, max_size=10, timeout=60, open=False)
# ----------------------------
# Read-only SQL gate (the real last line of defence is still a read-only DB account)
# ----------------------------
# Keywords that can write data or run arbitrary statements. They are matched as
# whole words anywhere in the text, so identifiers such as update_time or
# created_by are not affected. "into" and "merge" close two write paths that begin
# with an otherwise allowed keyword: `SELECT ... INTO new_table` and
# `EXPLAIN ANALYZE MERGE INTO ...` (EXPLAIN ANALYZE really executes the statement).
FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|call|execute|do|into|merge)\b",
    re.IGNORECASE,
)
# Explicit schema qualification (prefix.name), used to block cross-schema access.
# Group 1 is the prefix exactly as written:
#   * unquoted: 2+ identifier characters. Single-letter aliases (t.id, o.amount)
#     are therefore not treated as schemas; multi-letter aliases still are.
#   * double-quoted ("public", "a""b"): any length, quotes kept, so quoting can
#     no longer be used to slip a schema past the check.
# The character after the dot may be a letter/underscore or an opening quote,
# which also catches public."users".
SCHEMA_QUAL = re.compile(
    r'("(?:[^"]|"")+"|\b[a-zA-Z_][a-zA-Z0-9_$]+)\s*\.\s*["a-zA-Z_]',
    re.IGNORECASE,
)
def _is_probably_readonly(sql: str) -> bool:
    """Heuristically decide whether ``sql`` is a single read-only statement.

    Returns False when:
      * any FORBIDDEN keyword appears anywhere in the text;
      * the text contains no statement, or more than one statement. The driver
        may run several ';'-separated statements from one parameterless
        execute() call, and previously only the first one was inspected;
      * the first keyword (after whitespace and comments) is not
        SELECT / WITH / SHOW / EXPLAIN.
    """
    if FORBIDDEN.search(sql):
        return False
    # Ignore fragments that hold only whitespace/comments (e.g. "select 1; -- note").
    statements = [
        stmt
        for stmt in sqlparse.parse(sql)
        if stmt.token_first(skip_ws=True, skip_cm=True) is not None
    ]
    if len(statements) != 1:
        return False
    first = statements[0].token_first(skip_ws=True, skip_cm=True)
    return str(first).strip().lower() in ("select", "with", "show", "explain")
def _validate_schema(schema: str) -> Optional[Dict[str, Any]]:
    """Check ``schema`` against the ALLOWED_SCHEMA_SET whitelist.

    Returns:
        None when the schema is allowed, otherwise an ``{"error": ...}`` payload
        naming the rejected schema and the sorted whitelist.
    """
    if schema in ALLOWED_SCHEMA_SET:
        return None
    return {"error": f"schema 不允许:{schema}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}
def _schema_ref_name(prefix: str) -> str:
    """Return the schema name PostgreSQL would resolve for a matched prefix.

    Double-quoted prefixes keep their exact spelling, with ``""`` unescaped to
    ``"``. Unquoted prefixes are folded to lower case.
    """
    if len(prefix) >= 2 and prefix[0] == '"' and prefix[-1] == '"':
        return prefix[1:-1].replace('""', '"')
    return prefix.lower()


def _reject_cross_schema(sql: str, allowed_schema: str) -> Optional[Dict[str, Any]]:
    """
    Lightweight guard on explicit schema prefixes (``xxx.``).

    Every prefix found by SCHEMA_QUAL must be a whitelisted business schema or a
    system schema (pg_catalog / information_schema). This is a regex heuristic,
    not a SQL-parser-level policy, but it blocks the common ways of naming
    another schema.

    Args:
        sql: statement text to inspect; None is treated as empty.
        allowed_schema: currently unused. Any whitelisted schema is accepted,
            not only this one. The parameter is kept for call-site compatibility.

    Returns:
        An ``{"error": ...}`` payload listing the offending prefixes as written,
        or None when every prefix is allowed.
    """
    # Build the comparison set once (not once per matched prefix).
    safe = {name.lower() for name in ALLOWED_SCHEMA_SET | {"pg_catalog", "information_schema"}}
    prefixes = {m.group(1) for m in SCHEMA_QUAL.finditer(sql or "")}
    bad = sorted(p for p in prefixes if _schema_ref_name(p) not in safe)
    if bad:
        return {"error": f"SQL 被拒绝:检测到不允许的 schema 引用 {bad},仅允许 {sorted(ALLOWED_SCHEMA_SET)} / 系统 schema。"}
    return None
# ----------------------------
# FastMCP server: Streamable HTTP transport with JSON responses
# ----------------------------
# DNS-rebinding protection is enabled. allowed_hosts / allowed_origins are the
# Host / Origin allow-lists it uses. Every entry is listed twice: bare, and with
# a ":*" port wildcard, so requests both with and without an explicit port match.
mcp = FastMCP(
    "postgres-mcp",
    stateless_http=True,
    json_response=True,
    transport_security=TransportSecuritySettings(
        enable_dns_rebinding_protection=True,
        allowed_hosts=[
            entry
            for host in (
                "mcp.langlangzhuoqiu.cn",
                "localhost",
                "127.0.0.1",
                "100.64.0.4",
                "100.64.0.1",
                "106.52.16.235",
            )
            for entry in (host, f"{host}:*")
        ],
        allowed_origins=[
            entry
            for origin in (
                "https://mcp.langlangzhuoqiu.cn",
                "http://localhost",
                "http://127.0.0.1",
            )
            for entry in (origin, f"{origin}:*")
        ],
    ),
)
# ----------------------------
# Tools for the six-layer etl_feiqiu schemas
# ----------------------------
@mcp.tool()
def list_tables(schema: str = "dwd", include_views: bool = False) -> Dict[str, Any]:
    """列出指定 schemaods/dwd/dws/core/meta/app下的表可选包含视图"""
    # NOTE: FastMCP uses a tool's docstring as its client-facing description, so the
    # docstring above is left verbatim.
    # Lists base tables of a whitelisted schema (plus views when include_views is
    # true) from information_schema.tables, ordered by name. A non-whitelisted
    # schema yields an {"error": ...} payload instead.
    validation_error = _validate_schema(schema)
    if validation_error:
        return validation_error

    wanted_types = ["BASE TABLE"]
    if include_views:
        wanted_types.append("VIEW")

    query = """
        SELECT table_name, table_type
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = ANY(%s)
        ORDER BY table_name;
    """
    with pool.connection() as conn, conn.cursor() as cur:
        cur.execute(query, (schema, wanted_types))
        fetched = cur.fetchall()

    return {
        "schema": schema,
        "include_views": include_views,
        "tables": [{"name": name, "type": kind} for name, kind in fetched],
        "table_count": len(fetched),
    }
@mcp.tool()
def describe_table(table: str, schema: str = "dwd") -> Dict[str, Any]:
    """查看表结构(字段、类型、是否可空、默认值)"""
    # NOTE: FastMCP uses a tool's docstring as its client-facing description, so the
    # docstring above is left verbatim.
    # Reads information_schema.columns for schema.table, ordered by ordinal
    # position. An unknown table yields an empty column list (column_count 0),
    # not an error.
    problem = _validate_schema(schema)
    if problem:
        return problem

    query = """
        SELECT column_name, data_type, is_nullable, column_default, ordinal_position
        FROM information_schema.columns
        WHERE table_schema=%s AND table_name=%s
        ORDER BY ordinal_position;
    """
    with pool.connection() as conn, conn.cursor() as cur:
        cur.execute(query, (schema, table))
        records = cur.fetchall()

    # Result keys, in the same order as the SELECT list above.
    field_names = ("name", "type", "nullable", "default", "position")
    return {
        "schema": schema,
        "table": table,
        "columns": [dict(zip(field_names, rec)) for rec in records],
        "column_count": len(records),
    }
def _fetch_schema_tables(
    cur: Any, schemas: List[str], table_types: Tuple[str, ...]
) -> List[Tuple[str, str, str]]:
    """Fetch (schema, table, table_type) rows for ``schemas``, ordered by schema then table."""
    cur.execute(
        """
        SELECT table_schema, table_name, table_type
        FROM information_schema.tables
        WHERE table_schema = ANY(%s)
          AND table_type = ANY(%s)
        ORDER BY table_schema, table_name;
        """,
        (schemas, list(table_types)),
    )
    return cur.fetchall()


def _fetch_schema_columns(
    cur: Any, schemas: List[str]
) -> Dict[Tuple[str, str], List[Dict[str, Any]]]:
    """Fetch all columns information_schema.columns lists for ``schemas``, grouped by (schema, table)."""
    # Single query for all columns. With very many tables this could be split or paged.
    cur.execute(
        """
        SELECT table_schema, table_name, column_name, data_type, is_nullable, column_default, ordinal_position
        FROM information_schema.columns
        WHERE table_schema = ANY(%s)
        ORDER BY table_schema, table_name, ordinal_position;
        """,
        (schemas,),
    )
    cols_map: Dict[Tuple[str, str], List[Dict[str, Any]]] = defaultdict(list)
    for s, t, c, dt, nul, dft, pos in cur.fetchall():
        cols_map[(s, t)].append(
            {"name": c, "type": dt, "nullable": nul, "default": dft, "position": pos}
        )
    return cols_map


def _fetch_schema_primary_keys(
    cur: Any, schemas: List[str]
) -> Dict[Tuple[str, str], List[str]]:
    """Fetch primary-key column names per (schema, table), in key-column order."""
    cur.execute(
        """
        SELECT kcu.table_schema, kcu.table_name, kcu.column_name, kcu.ordinal_position
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = ANY(%s)
        ORDER BY kcu.table_schema, kcu.table_name, kcu.ordinal_position;
        """,
        (schemas,),
    )
    pk_map: Dict[Tuple[str, str], List[str]] = defaultdict(list)
    for s, t, col, _pos in cur.fetchall():
        pk_map[(s, t)].append(col)
    return pk_map


@mcp.tool()
def describe_schemas(
    schemas: Optional[List[str]] = None,
    include_views: bool = False,
    max_tables_per_schema: int = 500,
) -> Dict[str, Any]:
    """
    返回 ods/dwd/dws/core/meta/app schema 下的表结构(含主键)。
    不传 schemas 则返回全部六个 schema。
    """
    # NOTE: FastMCP uses a tool's docstring as its client-facing description, so the
    # docstring above is left verbatim.
    #
    # Describe the tables (plus views when include_views) of the requested schemas:
    # type, primary-key columns and columns for each. None or an empty list means all
    # six whitelisted schemas; any non-whitelisted name yields an {"error": ...}
    # payload. At most max_tables_per_schema tables per schema are included, in name
    # order. Per schema, "table_count" is the number included, "total_table_count" is
    # the number that exist, and "truncated" says whether some were left out.
    schemas = schemas or list(ALLOWED_SCHEMAS)
    invalid = [s for s in schemas if s not in ALLOWED_SCHEMA_SET]
    if invalid:
        return {"error": f"存在不允许的 schema{invalid}。仅允许:{sorted(ALLOWED_SCHEMA_SET)}"}
    table_types = ("BASE TABLE", "VIEW") if include_views else ("BASE TABLE",)

    with pool.connection() as conn:
        with conn.cursor() as cur:
            table_rows = _fetch_schema_tables(cur, schemas, table_types)
            cols_map = _fetch_schema_columns(cur, schemas)
            pk_map = _fetch_schema_primary_keys(cur, schemas)
    # The connection is back in the pool here; the rest is pure Python.

    # Apply the per-schema cap (rows arrive sorted by table name), but count every
    # table so that truncation is visible to the caller.
    tables_by_schema: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    total_by_schema: Dict[str, int] = defaultdict(int)
    for s, t, tt in table_rows:
        total_by_schema[s] += 1
        if len(tables_by_schema[s]) < max_tables_per_schema:
            tables_by_schema[s].append((t, tt))

    result: Dict[str, Any] = {
        "schemas": {},
        "include_views": include_views,
        "limits": {"max_tables_per_schema": max_tables_per_schema},
    }
    for s in schemas:
        schema_tables = tables_by_schema.get(s, [])
        total = total_by_schema.get(s, 0)
        result["schemas"][s] = {
            "table_count": len(schema_tables),
            "tables": {
                t: {
                    "type": tt,
                    "primary_key": pk_map.get((s, t), []),
                    "columns": cols_map.get((s, t), []),
                    "column_count": len(cols_map.get((s, t), [])),
                }
                for t, tt in schema_tables
            },
            "total_table_count": total,
            "truncated": total > len(schema_tables),
        }
    return result
@mcp.tool()
def query_sql(schema: str, sql: str, max_rows: int = MAX_ROWS) -> Dict[str, Any]:
    """
    在指定 schema 内执行只读 SQL会 SET LOCAL search_path并限制显式跨 schema 引用。
    """
    # NOTE: FastMCP uses a tool's docstring as its client-facing description, so the
    # docstring above is left verbatim.
    #
    # Execute one read-only statement inside `schema`:
    #   1. `schema` must be whitelisted (_validate_schema);
    #   2. the SQL must pass the read-only heuristic (_is_probably_readonly) and the
    #      explicit schema-prefix check (_reject_cross_schema);
    #   3. it runs in a READ ONLY transaction with search_path pinned to `schema`.
    # Returns the column names, at most `max_rows` rows and a `truncated` flag. Values
    # that are not int/float/str/bool/None are converted with str(). Rejections are
    # returned as {"error": ...} payloads; database errors from the driver propagate.
    err = _validate_schema(schema)
    if err:
        return err
    sql = (sql or "").strip().rstrip(";")
    if not _is_probably_readonly(sql):
        return {"error": "SQL 被拒绝仅允许只读select/with/show/explain并禁止危险关键字。"}
    cross = _reject_cross_schema(sql, allowed_schema=schema)
    if cross:
        return cross
    # A negative limit would hand fetchmany() a non-positive size and make the slice
    # below count from the end. Clamp to 0; 0 still returns the column names (and
    # truncated=True when the result has rows).
    max_rows = max(0, int(max_rows))
    with pool.connection() as conn:
        with conn.cursor() as cur:
            # Defence in depth: make the server itself refuse writes (SELECT ... INTO,
            # EXPLAIN ANALYZE of DML, nextval(), ...) that slip past the keyword
            # heuristic. Must be the first statement of the transaction.
            cur.execute("SET TRANSACTION READ ONLY")
            # `schema` passed the whitelist check above, so interpolating it is safe.
            cur.execute(f"SET LOCAL search_path TO {schema}")
            cur.execute(sql)
            cols = [d.name for d in (cur.description or [])]
            # Fetch one extra row to detect truncation without reading the whole result.
            rows = cur.fetchmany(max_rows + 1)
    truncated = len(rows) > max_rows
    safe_rows: List[List[Any]] = [
        [v if v is None or isinstance(v, (int, float, str, bool)) else str(v) for v in r]
        for r in rows[:max_rows]
    ]
    return {
        "schema": schema,
        "columns": cols,
        "rows": safe_rows,
        "row_count": len(safe_rows),
        "truncated": truncated,
        "max_rows": max_rows,
    }
# ----------------------------
# Auth middleware: accepts a Bearer header or a ?token= query parameter
# ----------------------------
class AuthMiddleware(BaseHTTPMiddleware):
    """Reject unauthenticated requests to /mcp* paths when MCP_TOKEN is set.

    A request is accepted if it carries either ``Authorization: Bearer <token>``
    (scheme matched case-insensitively) or a ``?token=<token>`` query parameter.
    Paths that do not start with /mcp, and every path when MCP_TOKEN is empty,
    pass through unchanged. Failures get a 401 JSON ``{"error": "unauthorized"}``.
    """

    @staticmethod
    def _token_matches(candidate: str) -> bool:
        # Constant-time comparison, so response timing does not reveal how much of a
        # guessed token was correct. Compare bytes: hmac.compare_digest raises
        # TypeError for str arguments containing non-ASCII characters.
        return hmac.compare_digest(candidate.encode("utf-8"), MCP_TOKEN.encode("utf-8"))

    async def dispatch(self, request: Request, call_next):
        if MCP_TOKEN and request.url.path.startswith("/mcp"):
            scheme, _, credentials = request.headers.get("authorization", "").partition(" ")
            header_ok = scheme.lower() == "bearer" and self._token_matches(credentials)
            query_ok = self._token_matches(request.query_params.get("token", ""))
            if not (header_ok or query_ok):
                return JSONResponse({"error": "unauthorized"}, status_code=401)
        return await call_next(request)
# ----------------------------
# Lifespan: open/close the pool explicitly and run the MCP session manager
# ----------------------------
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    """Starlette lifespan handler.

    Startup:
      * open the connection pool (``open(wait=True, timeout=30)``);
      * enter FastMCP's session manager for the lifetime of the app.

    Shutdown:
      * leave the session manager first;
      * then close the pool (5 s timeout).

    Closing the pool here, instead of leaving it to ``__del__`` at interpreter
    exit, avoids errors during interpreter teardown.
    """
    async with contextlib.AsyncExitStack() as stack:
        pool.open(wait=True, timeout=30)
        # Registered only after a successful open(); runs last on exit (LIFO).
        stack.callback(pool.close, timeout=5)
        await stack.enter_async_context(mcp.session_manager.run())
        yield
# The MCP endpoint is /mcp: FastMCP's streamable HTTP app serves it at its default
# streamable_http_path, and that app is mounted at the site root here.
app = Starlette(
    lifespan=lifespan,
    routes=[Mount("/", app=mcp.streamable_http_app())],
)
# Token check for /mcp* requests (it lets everything through when MCP_TOKEN is empty).
app.add_middleware(AuthMiddleware)
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when the file runs as a script.
    import uvicorn

    # Listen on all interfaces. proxy_headers=True with forwarded_allow_ips="*"
    # makes uvicorn trust X-Forwarded-* headers from ANY peer.
    # NOTE(review): that is only safe if this port is reachable solely through the
    # reverse proxy; confirm the deployment.
    server_options = {
        "host": "0.0.0.0",
        "port": PORT,
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
    }
    uvicorn.run(app, **server_options)