在准备环境前提交此次全部更改。

This commit is contained in:
Neo
2026-02-19 08:35:13 +08:00
parent ded6dfb9d8
commit 4eac07da47
1387 changed files with 6107191 additions and 33002 deletions

View File

@@ -0,0 +1 @@
"""认证模块JWT 令牌管理与 FastAPI 依赖注入。"""

View File

@@ -0,0 +1,67 @@
"""
FastAPI 依赖注入:从 JWT 提取当前用户信息。
用法:
@router.get("/protected")
async def protected_endpoint(user: CurrentUser = Depends(get_current_user)):
print(user.user_id, user.site_id)
"""
from dataclasses import dataclass
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError
from app.auth.jwt import decode_access_token
# Bearer token 提取器
_bearer_scheme = HTTPBearer(auto_error=True)
@dataclass(frozen=True)
class CurrentUser:
    """Current-user context extracted from a validated access JWT."""
    # Admin user's primary key, taken from the JWT "sub" claim.
    user_id: int
    # Store/site the user belongs to, taken from the "site_id" claim.
    site_id: int
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme),
) -> CurrentUser:
    """FastAPI dependency: validate the bearer JWT and return the caller.

    Extracts the token from the Authorization header, decodes it as an
    access token, and builds a CurrentUser from the ``sub`` and
    ``site_id`` claims.

    Raises:
        HTTPException 401: when the token is invalid/expired, a required
            claim is missing, or a claim value cannot be coerced to int.
    """
    token = credentials.credentials
    try:
        payload = decode_access_token(token)
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的令牌",
            headers={"WWW-Authenticate": "Bearer"},
        )
    user_id_raw = payload.get("sub")
    site_id_raw = payload.get("site_id")
    if user_id_raw is None or site_id_raw is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="令牌缺少必要字段",
            headers={"WWW-Authenticate": "Bearer"},
        )
    try:
        user_id = int(user_id_raw)
        # FIX: the original coerced only user_id and passed site_id through
        # unchecked, even though CurrentUser declares site_id as int.
        site_id = int(site_id_raw)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="令牌中 user_id 格式无效",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return CurrentUser(user_id=user_id, site_id=site_id)

View File

@@ -0,0 +1,112 @@
"""
JWT 令牌生成、验证与解码。
- access_token短期有效默认 30 分钟),用于 API 请求认证
- refresh_token长期有效默认 7 天),用于刷新 access_token
- payload 包含 user_id、site_id、令牌类型access / refresh
- 密码哈希直接使用 bcrypt 库passlib 与 bcrypt>=4.1 存在兼容性问题)
"""
from datetime import datetime, timedelta, timezone
import bcrypt
from jose import JWTError, jwt
from app import config
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against its stored bcrypt hash."""
    candidate = plain_password.encode("utf-8")
    stored = hashed_password.encode("utf-8")
    return bcrypt.checkpw(candidate, stored)
def hash_password(password: str) -> str:
    """Return the bcrypt hash of *password* as a UTF-8 string."""
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")
def create_access_token(user_id: int, site_id: int) -> str:
    """Issue a short-lived access token.

    payload: sub=user_id, site_id, type=access, exp
    """
    lifetime = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "sub": str(user_id),
        "site_id": site_id,
        "type": "access",
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
def create_refresh_token(user_id: int, site_id: int) -> str:
    """Issue a long-lived refresh token.

    payload: sub=user_id, site_id, type=refresh, exp
    """
    lifetime = timedelta(days=config.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
    claims = {
        "sub": str(user_id),
        "site_id": site_id,
        "type": "refresh",
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
def create_token_pair(user_id: int, site_id: int) -> dict[str, str]:
    """Build the access + refresh token pair handed out on login."""
    tokens: dict[str, str] = {
        "access_token": create_access_token(user_id, site_id),
        "refresh_token": create_refresh_token(user_id, site_id),
    }
    tokens["token_type"] = "bearer"
    return tokens
def decode_token(token: str) -> dict:
    """Decode and verify a JWT.

    Returns the payload dict (contains sub, site_id, type, exp).

    Raises:
        JWTError: when the token is invalid or expired.
    """
    # jwt.decode already raises JWTError on any validation failure; the
    # original wrapped it in a redundant `try/except JWTError: raise`,
    # which was a no-op and has been removed.
    return jwt.decode(
        token, config.JWT_SECRET_KEY, algorithms=[config.JWT_ALGORITHM]
    )
def decode_access_token(token: str) -> dict:
    """Decode *token* and require it to be an access token.

    Raises JWTError when the type claim is not "access".
    """
    claims = decode_token(token)
    if claims.get("type") == "access":
        return claims
    raise JWTError("令牌类型不是 access")
def decode_refresh_token(token: str) -> dict:
    """Decode *token* and require it to be a refresh token.

    Raises JWTError when the type claim is not "refresh".
    """
    claims = decode_token(token)
    if claims.get("type") == "refresh":
        return claims
    raise JWTError("令牌类型不是 refresh")

View File

@@ -29,7 +29,37 @@ DB_HOST: str = get("DB_HOST", "localhost")
DB_PORT: str = get("DB_PORT", "5432")
DB_USER: str = get("DB_USER", "")
DB_PASSWORD: str = get("DB_PASSWORD", "")
APP_DB_NAME: str = get("APP_DB_NAME", "zqyy_app")
# CHANGE 2026-02-15 | 默认指向测试库,生产环境通过 .env 覆盖
APP_DB_NAME: str = get("APP_DB_NAME", "test_zqyy_app")
# ---- JWT 认证 ----
JWT_SECRET_KEY: str = get("JWT_SECRET_KEY", "") # 生产环境必须设置
JWT_ALGORITHM: str = get("JWT_ALGORITHM", "HS256")
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int(get("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = int(get("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# ---- ETL 数据库连接参数(可独立配置,缺省时复用 zqyy_app 的连接参数) ----
ETL_DB_HOST: str = get("ETL_DB_HOST") or DB_HOST
ETL_DB_PORT: str = get("ETL_DB_PORT") or DB_PORT
ETL_DB_USER: str = get("ETL_DB_USER") or DB_USER
ETL_DB_PASSWORD: str = get("ETL_DB_PASSWORD") or DB_PASSWORD
# CHANGE 2026-02-15 | 默认指向测试库,生产环境通过 .env 覆盖
ETL_DB_NAME: str = get("ETL_DB_NAME", "test_etl_feiqiu")
# ---- CORS ----
# 逗号分隔的允许来源列表;缺省允许 Vite 开发服务器
CORS_ORIGINS: list[str] = [
o.strip()
for o in get("CORS_ORIGINS", "http://localhost:5173").split(",")
if o.strip()
]
# ---- ETL 项目路径 ----
# ETL CLI 的工作目录(子进程 cwd缺省时按 monorepo 相对路径推算
ETL_PROJECT_PATH: str = get(
"ETL_PROJECT_PATH",
str(Path(__file__).resolve().parents[3] / "apps" / "etl" / "connectors" / "feiqiu"),
)
# ---- 通用 ----
TIMEZONE: str = get("TIMEZONE", "Asia/Shanghai")

View File

@@ -1,14 +1,30 @@
"""
zqyy_app 数据库连接
数据库连接
使用 psycopg2 直连 PostgreSQL不引入 ORM。
连接参数从环境变量读取(经 config 模块加载)。
提供两类连接:
- get_connection()zqyy_app 读写连接(用户/队列/调度等业务数据)
- get_etl_readonly_connection(site_id)etl_feiqiu 只读连接(数据库查看器),
自动设置 RLS site_id 隔离
"""
import psycopg2
from psycopg2.extensions import connection as PgConnection
from app.config import APP_DB_NAME, DB_HOST, DB_PASSWORD, DB_PORT, DB_USER
from app.config import (
APP_DB_NAME,
DB_HOST,
DB_PASSWORD,
DB_PORT,
DB_USER,
ETL_DB_HOST,
ETL_DB_NAME,
ETL_DB_PASSWORD,
ETL_DB_PORT,
ETL_DB_USER,
)
def get_connection() -> PgConnection:
@@ -24,3 +40,43 @@ def get_connection() -> PgConnection:
password=DB_PASSWORD,
dbname=APP_DB_NAME,
)
def get_etl_readonly_connection(site_id: int | str) -> PgConnection:
    """Open a read-only connection to the ETL database (etl_feiqiu).

    After connecting, the session is configured so that:
    1. default_transaction_read_only = on — every transaction is read-only.
    2. app.current_site_id is set for RLS store isolation.

    The caller is responsible for closing the connection. Typical usage::

        conn = get_etl_readonly_connection(site_id)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT ...")
        finally:
            conn.close()
    """
    conn = psycopg2.connect(
        host=ETL_DB_HOST,
        port=ETL_DB_PORT,
        user=ETL_DB_USER,
        password=ETL_DB_PASSWORD,
        dbname=ETL_DB_NAME,
    )
    try:
        conn.autocommit = False
        with conn.cursor() as cur:
            # Session-scoped: blocks writes for the connection's lifetime.
            cur.execute("SET default_transaction_read_only = on")
            # BUG FIX: the original used `SET LOCAL app.current_site_id`,
            # whose effect ends at the conn.commit() below — later queries
            # on this connection then ran with NO RLS site filter.
            # set_config(..., is_local => false) is session-scoped and
            # survives the commit.
            cur.execute(
                "SELECT set_config('app.current_site_id', %s, false)",
                (str(site_id),),
            )
        conn.commit()
    except Exception:
        # Don't leak the connection when session setup fails.
        conn.close()
        raise
    return conn

View File

@@ -1,20 +1,66 @@
"""
NeoZQYY 后端 API 入口
基于 FastAPI 构建,为微信小程序提供 RESTful API。
基于 FastAPI 构建,为管理后台和微信小程序提供 RESTful API。
OpenAPI 文档自动生成于 /docsSwagger UI和 /redocReDoc
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app import config
# CHANGE 2026-02-19 | 新增 xcx_test 路由MVP 验证)+ wx_callback 路由(微信消息推送)
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback
from app.services.scheduler import scheduler
from app.services.task_queue import task_queue
from app.ws.logs import ws_router
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start background services on startup and
    stop them gracefully on shutdown.
    """
    # Startup: bring up the task queue, then the scheduler.
    task_queue.start()
    scheduler.start()
    yield
    # Shutdown: scheduler is stopped first (presumably so no new work is
    # enqueued mid-shutdown — confirm against scheduler's implementation),
    # then the queue; both stops are awaited for a clean exit.
    await scheduler.stop()
    await task_queue.stop()
app = FastAPI(
title="NeoZQYY API",
description="台球门店运营助手 — 微信小程序后端 API",
description="台球门店运营助手 — 后端 API(管理后台 + 微信小程序)",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
# ---- CORS 中间件 ----
# 允许来源从环境变量 CORS_ORIGINS 读取,缺省允许 Vite 开发服务器 (localhost:5173)
app.add_middleware(
CORSMiddleware,
allow_origins=config.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---- 路由注册 ----
app.include_router(auth.router)
app.include_router(tasks.router)
app.include_router(execution.router)
app.include_router(schedules.router)
app.include_router(env_config.router)
app.include_router(db_viewer.router)
app.include_router(etl_status.router)
app.include_router(ws_router)
app.include_router(xcx_test.router)
app.include_router(wx_callback.router)
@app.get("/health", tags=["系统"])
async def health_check():

View File

@@ -0,0 +1,97 @@
"""
认证路由:登录与令牌刷新。
- POST /api/auth/login — 验证用户名密码,返回 JWT 令牌对
- POST /api/auth/refresh — 用刷新令牌换取新的访问令牌
"""
import logging
from fastapi import APIRouter, HTTPException, status
from jose import JWTError
from app.auth.jwt import (
create_access_token,
create_token_pair,
decode_refresh_token,
verify_password,
)
from app.database import get_connection
from app.schemas.auth import LoginRequest, RefreshRequest, TokenResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/auth", tags=["认证"])
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest):
    """Authenticate against admin_users and issue a JWT token pair.

    Responds 401 when the user is unknown, the password is wrong, or the
    account is disabled (is_active = false).
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, password_hash, site_id, is_active "
                "FROM admin_users WHERE username = %s",
                (body.username,),
            )
            record = cur.fetchone()
    finally:
        conn.close()

    def _unauthorized(message: str) -> HTTPException:
        # All failure modes map to the same status code.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=message,
        )

    if record is None:
        raise _unauthorized("用户名或密码错误")
    user_id, password_hash, site_id, is_active = record
    if not is_active:
        raise _unauthorized("账号已被禁用")
    if not verify_password(body.password, password_hash):
        raise _unauthorized("用户名或密码错误")
    return TokenResponse(**create_token_pair(user_id, site_id))
@router.post("/refresh", response_model=TokenResponse)
async def refresh(body: RefreshRequest):
    """Exchange a valid refresh token for a new access token.

    The refresh token itself is returned unchanged — the client keeps
    holding it until it expires.
    """
    try:
        claims = decode_refresh_token(body.refresh_token)
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的刷新令牌",
        )
    return TokenResponse(
        access_token=create_access_token(int(claims["sub"]), claims["site_id"]),
        refresh_token=body.refresh_token,
        token_type="bearer",
    )

View File

@@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-
"""数据库查看器 API
提供 4 个端点:
- GET /api/db/schemas — 返回 Schema 列表
- GET /api/db/schemas/{name}/tables — 返回表列表和行数
- GET /api/db/tables/{schema}/{table}/columns — 返回列定义
- POST /api/db/query — 只读 SQL 执行
所有端点需要 JWT 认证。
使用 get_etl_readonly_connection(site_id) 确保 RLS 隔离。
"""
from __future__ import annotations
import logging
import re
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg2 import errors as pg_errors, OperationalError
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_etl_readonly_connection
from app.schemas.db_viewer import (
ColumnInfo,
QueryRequest,
QueryResponse,
SchemaInfo,
TableInfo,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/db", tags=["数据库查看器"])
# Write-operation keywords (case-insensitive).
# NOTE(review): this blocklist misses other mutating statements
# (ALTER, CREATE, GRANT, REVOKE, COPY, ...). The connection itself is
# opened with default_transaction_read_only = on, so this regex is
# defense in depth / a friendlier error message, not the actual guard —
# confirm that remains true if the connection helper ever changes.
_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE)\b",
    re.IGNORECASE,
)
# Maximum number of rows returned per query.
_MAX_ROWS = 1000
# Statement timeout in seconds.
_QUERY_TIMEOUT_SEC = 30
# ── GET /api/db/schemas ──────────────────────────────────────
@router.get("/schemas", response_model=list[SchemaInfo])
async def list_schemas(
    user: CurrentUser = Depends(get_current_user),
) -> list[SchemaInfo]:
    """List schemas in the ETL database (system schemas excluded)."""
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
                ORDER BY schema_name
                """
            )
            names = [record[0] for record in cur.fetchall()]
    finally:
        conn.close()
    return [SchemaInfo(name=schema_name) for schema_name in names]
# ── GET /api/db/schemas/{name}/tables ────────────────────────
@router.get("/schemas/{name}/tables", response_model=list[TableInfo])
async def list_tables(
    name: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[TableInfo]:
    """List the base tables of a schema together with row-count stats."""
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT
                    t.table_name,
                    s.n_live_tup
                FROM information_schema.tables t
                LEFT JOIN pg_stat_user_tables s
                    ON s.schemaname = t.table_schema
                    AND s.relname = t.table_name
                WHERE t.table_schema = %s
                AND t.table_type = 'BASE TABLE'
                ORDER BY t.table_name
                """,
                (name,),
            )
            records = cur.fetchall()
    finally:
        conn.close()
    # n_live_tup comes from pg_stat_user_tables — an estimate, not COUNT(*).
    return [TableInfo(name=tbl, row_count=live) for tbl, live in records]
# ── GET /api/db/tables/{schema}/{table}/columns ──────────────
@router.get(
    "/tables/{schema}/{table}/columns",
    response_model=list[ColumnInfo],
)
async def list_columns(
    schema: str,
    table: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[ColumnInfo]:
    """Return the column definitions of one table, in ordinal order."""
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (schema, table),
            )
            records = cur.fetchall()
    finally:
        conn.close()
    columns: list[ColumnInfo] = []
    for col_name, col_type, nullable, default in records:
        columns.append(
            ColumnInfo(
                name=col_name,
                data_type=col_type,
                # information_schema reports nullability as 'YES'/'NO'.
                is_nullable=nullable == "YES",
                column_default=default,
            )
        )
    return columns
# ── POST /api/db/query ───────────────────────────────────────
@router.post("/query", response_model=QueryResponse)
async def execute_query(
    body: QueryRequest,
    user: CurrentUser = Depends(get_current_user),
) -> QueryResponse:
    """Execute a read-only SQL statement.

    Safety measures:
    1. Reject statements containing write keywords
       (INSERT / UPDATE / DELETE / DROP / TRUNCATE).
    2. Cap the result at 1000 rows.
    3. Apply a 30-second statement timeout.
    """
    sql = body.sql.strip()
    if not sql:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="SQL 语句不能为空",
        )
    # Friendly rejection of obvious writes (the connection is read-only).
    if _WRITE_KEYWORDS.search(sql):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只允许只读查询,禁止 INSERT / UPDATE / DELETE / DROP / TRUNCATE 操作",
        )
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            # SET LOCAL lasts for the current transaction, which also
            # contains the user statement below.
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{_QUERY_TIMEOUT_SEC}s",),
            )
            try:
                cur.execute(sql)
            except pg_errors.QueryCanceled:
                raise HTTPException(
                    status_code=status.HTTP_408_REQUEST_TIMEOUT,
                    detail=f"查询超时(超过 {_QUERY_TIMEOUT_SEC} 秒)",
                )
            except Exception as exc:
                # Syntax or other execution errors.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"SQL 执行错误: {exc}",
                )
            description = cur.description
            column_names = [col[0] for col in description] if description else []
            fetched = cur.fetchmany(_MAX_ROWS)
    except HTTPException:
        raise
    except OperationalError as exc:
        # Connection-level failures.
        logger.error("数据库查看器连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="数据库连接错误",
        )
    finally:
        conn.close()
    # Tuples -> lists so the payload is JSON-serializable.
    serializable = [list(record) for record in fetched]
    return QueryResponse(
        columns=column_names,
        rows=serializable,
        row_count=len(serializable),
    )

View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
"""环境配置 API
提供 3 个端点:
- GET /api/env-config — 读取 .env敏感值掩码
- PUT /api/env-config — 验证并写入 .env
- GET /api/env-config/export — 导出去敏感值的配置文件
所有端点需要 JWT 认证。
敏感键判定:键名中包含 PASSWORD、TOKEN、SECRET、DSN不区分大小写
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from app.auth.dependencies import CurrentUser, get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/env-config", tags=["环境配置"])
# .env 文件路径:项目根目录
_ENV_PATH = Path(__file__).resolve().parents[3] / ".env"
# 敏感键关键词(不区分大小写)
_SENSITIVE_KEYWORDS = ("PASSWORD", "TOKEN", "SECRET", "DSN")
_MASK = "****"
# ── Pydantic models ──────────────────────────────────────────
class EnvEntry(BaseModel):
    """A single environment-variable key/value pair."""
    key: str
    value: str
class EnvConfigResponse(BaseModel):
    """GET response: list of entries (sensitive values are masked)."""
    entries: list[EnvEntry]
class EnvConfigUpdateRequest(BaseModel):
    """PUT request body: list of entries to apply."""
    entries: list[EnvEntry]
# ── 工具函数 ─────────────────────────────────────────────────
def _is_sensitive(key: str) -> bool:
    """Return True when the key name looks like it holds a secret."""
    key_upper = key.upper()
    for keyword in _SENSITIVE_KEYWORDS:
        if keyword in key_upper:
            return True
    return False
def _parse_env(content: str) -> list[dict]:
"""解析 .env 文件内容,返回行级结构。
每行分为三种类型:
- comment: 注释行或空行(原样保留)
- entry: 键值对
"""
lines: list[dict] = []
for raw_line in content.splitlines():
stripped = raw_line.strip()
if not stripped or stripped.startswith("#"):
lines.append({"type": "comment", "raw": raw_line})
else:
# 支持 KEY=VALUE 和 KEY="VALUE" 格式
match = re.match(r'^([A-Za-z_][A-Za-z0-9_]*)=(.*)', raw_line)
if match:
key = match.group(1)
value = match.group(2).strip()
# 去除引号包裹
if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
value = value[1:-1]
lines.append({"type": "entry", "key": key, "value": value, "raw": raw_line})
else:
# 无法解析的行当作注释保留
lines.append({"type": "comment", "raw": raw_line})
return lines
def _read_env_file(path: Path) -> str:
    """Return the .env file content; respond 404 when the file is missing."""
    if path.exists():
        return path.read_text(encoding="utf-8")
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=".env 文件不存在",
    )
def _write_env_file(path: Path, content: str) -> None:
    """Write *content* to the .env file.

    Raises HTTPException 500 when the OS-level write fails; the
    underlying error is logged server-side, not exposed to the client.
    """
    try:
        path.write_text(content, encoding="utf-8")
    except OSError as exc:
        logger.error("写入 .env 文件失败: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="写入 .env 文件失败",
        )
def _validate_entries(entries: list[EnvEntry]) -> None:
    """Validate entry keys; raise 422 with a row-numbered message on failure."""
    key_pattern = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
    for row, entry in enumerate(entries, start=1):
        if not entry.key:
            raise HTTPException(
                status_code=422,
                detail=f"第 {row} 行:键名不能为空",
            )
        if key_pattern.match(entry.key) is None:
            raise HTTPException(
                status_code=422,
                detail=f"第 {row} 行:键名 '{entry.key}' 格式无效(仅允许字母、数字、下划线,且不能以数字开头)",
            )
# ── GET /api/env-config — 读取 ───────────────────────────────
@router.get("", response_model=EnvConfigResponse)
async def get_env_config(
    user: CurrentUser = Depends(get_current_user),
) -> EnvConfigResponse:
    """Read the .env file; sensitive values are masked in the response."""
    parsed = _parse_env(_read_env_file(_ENV_PATH))
    masked = [
        EnvEntry(
            key=line["key"],
            value=_MASK if _is_sensitive(line["key"]) else line["value"],
        )
        for line in parsed
        if line["type"] == "entry"
    ]
    return EnvConfigResponse(entries=masked)
# ── PUT /api/env-config — 写入 ───────────────────────────────
@router.put("", response_model=EnvConfigResponse)
async def update_env_config(
    body: EnvConfigUpdateRequest,
    user: CurrentUser = Depends(get_current_user),
) -> EnvConfigResponse:
    """Validate and write the .env file.

    Comment and blank lines from the original file are preserved.
    Existing keys are updated in place; new keys are appended at the end.
    Entries whose value equals the mask (****) are skipped so the stored
    secret is kept unchanged.
    """
    _validate_entries(body.entries)
    # Parse the current file, when there is one.
    existing = (
        _parse_env(_ENV_PATH.read_text(encoding="utf-8"))
        if _ENV_PATH.exists()
        else []
    )
    # Values to apply; masked entries are excluded on purpose.
    updates: dict[str, str] = {
        entry.key: entry.value for entry in body.entries if entry.value != _MASK
    }
    rewritten: list[str] = []
    known: set[str] = set()
    for line in existing:
        if line["type"] == "comment":
            rewritten.append(line["raw"])
            continue
        key = line["key"]
        known.add(key)
        if key in updates:
            rewritten.append(f"{key}={updates[key]}")
        else:
            # Keep the original line (covers mask-skipped sensitive keys).
            rewritten.append(line["raw"])
    # Append keys that were not present in the original file.
    for entry in body.entries:
        if entry.key not in known and entry.value != _MASK:
            rewritten.append(f"{entry.key}={entry.value}")
    new_content = "\n".join(rewritten)
    if rewritten:
        new_content += "\n"
    _write_env_file(_ENV_PATH, new_content)
    # Echo back the stored configuration, again with secrets masked.
    response_entries = [
        EnvEntry(
            key=line["key"],
            value=_MASK if _is_sensitive(line["key"]) else line["value"],
        )
        for line in _parse_env(new_content)
        if line["type"] == "entry"
    ]
    return EnvConfigResponse(entries=response_entries)
# ── GET /api/env-config/export — 导出 ────────────────────────
@router.get("/export")
async def export_env_config(
    user: CurrentUser = Depends(get_current_user),
) -> PlainTextResponse:
    """Download the configuration file with sensitive values masked out."""
    parsed = _parse_env(_read_env_file(_ENV_PATH))
    exported: list[str] = []
    for line in parsed:
        if line["type"] == "entry" and _is_sensitive(line["key"]):
            exported.append(f"{line['key']}={_MASK}")
        else:
            # Comments and non-sensitive entries pass through verbatim.
            exported.append(line["raw"])
    text = "\n".join(exported)
    if exported:
        text += "\n"
    return PlainTextResponse(
        content=text,
        media_type="text/plain",
        headers={"Content-Disposition": "attachment; filename=env-config.txt"},
    )

View File

@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
"""ETL 状态监控 API
提供 2 个端点:
- GET /api/etl-status/cursors — 返回各任务的数据游标(最后抓取时间、记录数)
- GET /api/etl-status/recent-runs — 返回最近 50 条任务执行记录
所有端点需要 JWT 认证。
游标端点查询 ETL 数据库meta.etl_cursor
执行记录端点查询 zqyy_app 数据库task_execution_log
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg2 import OperationalError
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection, get_etl_readonly_connection
from app.schemas.etl_status import CursorInfo, RecentRun
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/etl-status", tags=["ETL 状态"])
# 最近执行记录条数上限
_RECENT_RUNS_LIMIT = 50
# ── GET /api/etl-status/cursors ──────────────────────────────
@router.get("/cursors", response_model=list[CursorInfo])
async def list_cursors(
    user: CurrentUser = Depends(get_current_user),
) -> list[CursorInfo]:
    """Return the latest data cursor for each task.

    Reads meta.etl_cursor in the ETL database; returns an empty list
    (rather than an error) when that table does not exist.
    """
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT EXISTS (
                    SELECT 1
                    FROM information_schema.tables
                    WHERE table_schema = 'meta'
                    AND table_name = 'etl_cursor'
                )
                """
            )
            (table_present,) = cur.fetchone()
            if not table_present:
                return []
            cur.execute(
                """
                SELECT task_code, last_fetch_time, record_count
                FROM meta.etl_cursor
                ORDER BY task_code
                """
            )
            cursor_rows = cur.fetchall()
    except OperationalError as exc:
        logger.error("ETL 游标查询连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="ETL 数据库连接错误",
        )
    finally:
        conn.close()
    return [
        CursorInfo(
            task_code=code,
            last_fetch_time=None if fetched_at is None else str(fetched_at),
            record_count=count,
        )
        for code, fetched_at, count in cursor_rows
    ]
# ── GET /api/etl-status/recent-runs ──────────────────────────
@router.get("/recent-runs", response_model=list[RecentRun])
async def list_recent_runs(
    user: CurrentUser = Depends(get_current_user),
) -> list[RecentRun]:
    """Return the 50 most recent task executions for the current site.

    Reads task_execution_log in the zqyy_app database, filtered by
    site_id and ordered by started_at DESC.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, task_codes, status, started_at,
                       finished_at, duration_ms, exit_code
                FROM task_execution_log
                WHERE site_id = %s
                ORDER BY started_at DESC
                LIMIT %s
                """,
                (user.site_id, _RECENT_RUNS_LIMIT),
            )
            run_rows = cur.fetchall()
    except OperationalError as exc:
        logger.error("执行记录查询连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="数据库连接错误",
        )
    finally:
        conn.close()
    results: list[RecentRun] = []
    for run_id, codes, run_status, started, finished, duration, exit_code in run_rows:
        results.append(
            RecentRun(
                id=str(run_id),
                task_codes=list(codes) if codes else [],
                status=run_status,
                started_at=str(started),
                finished_at=str(finished) if finished is not None else None,
                duration_ms=duration,
                exit_code=exit_code,
            )
        )
    return results

View File

@@ -0,0 +1,281 @@
# -*- coding: utf-8 -*-
"""执行与队列 API
提供 8 个端点:
- POST /api/execution/run — 直接执行任务
- GET /api/execution/queue — 获取当前队列(按 site_id 过滤)
- POST /api/execution/queue — 添加到队列
- PUT /api/execution/queue/reorder — 重排队列
- DELETE /api/execution/queue/{id} — 删除队列任务
- POST /api/execution/{id}/cancel — 取消执行中的任务
- GET /api/execution/history — 执行历史(按 site_id 过滤)
- GET /api/execution/{id}/logs — 获取历史日志
所有端点需要 JWT 认证site_id 从 JWT 提取。
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
from app.schemas.execution import (
ExecutionHistoryItem,
ExecutionLogsResponse,
ExecutionRunResponse,
QueueTaskResponse,
ReorderRequest,
)
from app.schemas.tasks import TaskConfigSchema
from app.services.task_executor import task_executor
from app.services.task_queue import task_queue
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/execution", tags=["任务执行"])
# ── POST /api/execution/run — run a task immediately ─────────
# Keep strong references to fire-and-forget tasks: the event loop only
# holds weak references, so an otherwise-unreferenced Task created by
# asyncio.create_task can be garbage-collected before it finishes.
_background_tasks: set = set()


@router.post("/run", response_model=ExecutionRunResponse)
async def run_task(
    config: TaskConfigSchema,
    user: CurrentUser = Depends(get_current_user),
) -> ExecutionRunResponse:
    """Execute a task directly, bypassing the queue.

    store_id is injected from the JWT; a new execution_id is created and
    the subprocess is launched asynchronously so the response returns
    immediately.
    """
    config = config.model_copy(update={"store_id": user.site_id})
    execution_id = str(uuid.uuid4())
    # FIX: the original discarded the Task returned by create_task,
    # risking silent garbage collection mid-run. Retain a reference and
    # drop it automatically once the task completes.
    task = asyncio.create_task(
        task_executor.execute(
            config=config,
            execution_id=execution_id,
            site_id=user.site_id,
        )
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return ExecutionRunResponse(
        execution_id=execution_id,
        message="任务已提交执行",
    )
# ── GET /api/execution/queue — 获取当前队列 ───────────────────
@router.get("/queue", response_model=list[QueueTaskResponse])
async def get_queue(
    user: CurrentUser = Depends(get_current_user),
) -> list[QueueTaskResponse]:
    """Return the pending execution queue for the current site."""
    pending = task_queue.list_pending(user.site_id)
    responses: list[QueueTaskResponse] = []
    for item in pending:
        responses.append(
            QueueTaskResponse(
                id=item.id,
                site_id=item.site_id,
                config=item.config,
                status=item.status,
                position=item.position,
                created_at=item.created_at,
                started_at=item.started_at,
                finished_at=item.finished_at,
                exit_code=item.exit_code,
                error_message=item.error_message,
            )
        )
    return responses
# ── POST /api/execution/queue — 添加到队列 ───────────────────
@router.post("/queue", response_model=QueueTaskResponse, status_code=status.HTTP_201_CREATED)
async def enqueue_task(
    config: TaskConfigSchema,
    user: CurrentUser = Depends(get_current_user),
) -> QueueTaskResponse:
    """Append a task configuration to the execution queue."""
    # store_id always comes from the JWT, never from the client payload.
    config = config.model_copy(update={"store_id": user.site_id})
    task_id = task_queue.enqueue(config, user.site_id)
    # Read the freshly created row back so the response is complete.
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, site_id, config, status, position,
                       created_at, started_at, finished_at,
                       exit_code, error_message
                FROM task_queue WHERE id = %s
                """,
                (task_id,),
            )
            row = cur.fetchone()
            conn.commit()
    finally:
        conn.close()
    if row is None:
        raise HTTPException(status_code=500, detail="入队后查询失败")
    (
        rec_id, rec_site, rec_config, rec_status, rec_position,
        rec_created, rec_started, rec_finished, rec_exit, rec_error,
    ) = row
    # The config column may come back as a dict (JSONB) or a JSON string.
    if not isinstance(rec_config, dict):
        rec_config = json.loads(rec_config)
    return QueueTaskResponse(
        id=str(rec_id),
        site_id=rec_site,
        config=rec_config,
        status=rec_status,
        position=rec_position,
        created_at=rec_created,
        started_at=rec_started,
        finished_at=rec_finished,
        exit_code=rec_exit,
        error_message=rec_error,
    )
# ── PUT /api/execution/queue/reorder — 重排队列 ──────────────
@router.put("/queue/reorder")
async def reorder_queue(
    body: ReorderRequest,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Move a queued task to a new position.

    Delegates to task_queue.reorder, scoped to the caller's site_id.
    """
    task_queue.reorder(body.task_id, body.new_position, user.site_id)
    return {"message": "队列已重排"}
# ── DELETE /api/execution/queue/{id} — 删除队列任务 ──────────
@router.delete("/queue/{task_id}")
async def delete_queue_task(
    task_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Remove a pending task from the queue.

    Only tasks still in the pending state can be deleted; anything else
    yields a 409.
    """
    if task_queue.delete(task_id, user.site_id):
        return {"message": "任务已从队列中删除"}
    raise HTTPException(
        status_code=status.HTTP_409_CONFLICT,
        detail="任务不存在或非待执行状态,无法删除",
    )
# ── POST /api/execution/{id}/cancel — cancel a running task ──
@router.post("/{execution_id}/cancel")
async def cancel_execution(
    execution_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Cancel a task that is currently executing.

    Returns 404 when the execution does not exist or has already finished.
    """
    # NOTE(review): unlike the other endpoints here, this call is not
    # scoped by user.site_id — task_executor.cancel() receives only the
    # execution_id. If the executor does not check ownership internally,
    # an authenticated user from another site could cancel this
    # execution. Confirm against task_executor's implementation.
    cancelled = await task_executor.cancel(execution_id)
    if not cancelled:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="执行任务不存在或已完成",
        )
    return {"message": "已发送取消信号"}
# ── GET /api/execution/history — 执行历史 ────────────────────
@router.get("/history", response_model=list[ExecutionHistoryItem])
async def get_execution_history(
    limit: int = Query(default=50, ge=1, le=200),
    user: CurrentUser = Depends(get_current_user),
) -> list[ExecutionHistoryItem]:
    """Return execution history for the current site, newest first."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, site_id, task_codes, status, started_at,
                       finished_at, exit_code, duration_ms, command, summary
                FROM task_execution_log
                WHERE site_id = %s
                ORDER BY started_at DESC
                LIMIT %s
                """,
                (user.site_id, limit),
            )
            history_rows = cur.fetchall()
            conn.commit()
    finally:
        conn.close()
    items: list[ExecutionHistoryItem] = []
    for record in history_rows:
        items.append(
            ExecutionHistoryItem(
                id=str(record[0]),
                site_id=record[1],
                task_codes=record[2] or [],
                status=record[3],
                started_at=record[4],
                finished_at=record[5],
                exit_code=record[6],
                duration_ms=record[7],
                command=record[8],
                summary=record[9],
            )
        )
    return items
# ── GET /api/execution/{id}/logs — 获取历史日志 ──────────────
@router.get("/{execution_id}/logs", response_model=ExecutionLogsResponse)
async def get_execution_logs(
execution_id: str,
user: CurrentUser = Depends(get_current_user),
) -> ExecutionLogsResponse:
"""获取指定执行的完整日志。
优先从内存缓冲区读取(执行中),否则从数据库读取(已完成)。
"""
# 先尝试内存缓冲区(执行中的任务)
if task_executor.is_running(execution_id):
lines = task_executor.get_logs(execution_id)
return ExecutionLogsResponse(
execution_id=execution_id,
output_log="\n".join(lines) if lines else None,
)
# 从数据库读取已完成任务的日志
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT output_log, error_log
FROM task_execution_log
WHERE id = %s AND site_id = %s
""",
(execution_id, user.site_id),
)
row = cur.fetchone()
conn.commit()
finally:
conn.close()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="执行记录不存在",
)
return ExecutionLogsResponse(
execution_id=execution_id,
output_log=row[0],
error_log=row[1],
)

View File

@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
"""调度任务 CRUD API
提供 5 个端点:
- GET /api/schedules — 列表(按 site_id 过滤)
- POST /api/schedules — 创建
- PUT /api/schedules/{id} — 更新
- DELETE /api/schedules/{id} — 删除
- PATCH /api/schedules/{id}/toggle — 启用/禁用
所有端点需要 JWT 认证site_id 从 JWT 提取。
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
from app.schemas.schedules import (
CreateScheduleRequest,
ScheduleResponse,
UpdateScheduleRequest,
)
from app.services.scheduler import calculate_next_run
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/schedules", tags=["调度管理"])
def _row_to_response(row) -> ScheduleResponse:
    """Map a scheduled_tasks row (in _SELECT_COLS order) to a ScheduleResponse."""
    def _as_dict(value):
        # JSONB columns may arrive as dicts or as raw JSON strings.
        return value if isinstance(value, dict) else json.loads(value)

    return ScheduleResponse(
        id=str(row[0]),
        site_id=row[1],
        name=row[2],
        task_codes=row[3] or [],
        task_config=_as_dict(row[4]),
        schedule_config=_as_dict(row[5]),
        enabled=row[6],
        last_run_at=row[7],
        next_run_at=row[8],
        run_count=row[9],
        last_status=row[10],
        created_at=row[11],
        updated_at=row[12],
    )
# 查询列列表,复用于多个端点
_SELECT_COLS = """
id, site_id, name, task_codes, task_config, schedule_config,
enabled, last_run_at, next_run_at, run_count, last_status,
created_at, updated_at
"""
# ── GET /api/schedules — 列表 ────────────────────────────────
@router.get("", response_model=list[ScheduleResponse])
async def list_schedules(
    user: CurrentUser = Depends(get_current_user),
) -> list[ScheduleResponse]:
    """Return every scheduled task for the current site, newest first."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT {_SELECT_COLS} FROM scheduled_tasks WHERE site_id = %s ORDER BY created_at DESC",
                (user.site_id,),
            )
            schedule_rows = cur.fetchall()
            conn.commit()
    finally:
        conn.close()
    return [_row_to_response(record) for record in schedule_rows]
# ── POST /api/schedules — 创建 ──────────────────────────────
@router.post("", response_model=ScheduleResponse, status_code=status.HTTP_201_CREATED)
async def create_schedule(
    body: CreateScheduleRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Insert a new scheduled task, pre-computing its next_run_at."""
    next_run = calculate_next_run(body.schedule_config, datetime.now(timezone.utc))
    insert_values = (
        user.site_id,
        body.name,
        body.task_codes,
        json.dumps(body.task_config),
        body.schedule_config.model_dump_json(),
        body.schedule_config.enabled,
        next_run,
    )
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                f"""
                INSERT INTO scheduled_tasks
                (site_id, name, task_codes, task_config, schedule_config, enabled, next_run_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                RETURNING {_SELECT_COLS}
                """,
                insert_values,
            )
            created = cur.fetchone()
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    return _row_to_response(created)
# ── PUT /api/schedules/{id} — 更新 ──────────────────────────
@router.put("/{schedule_id}", response_model=ScheduleResponse)
async def update_schedule(
    schedule_id: str,
    body: UpdateScheduleRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Partially update a schedule; only fields present in the request change.

    Raises:
        HTTPException 422: when the request contains no updatable field.
        HTTPException 404: when no row matches (schedule_id, site_id).
    """
    # Build the dynamic SET clause.  The column fragments are fixed string
    # literals, so interpolating set_clause below is not an injection risk;
    # all values travel through query parameters.
    set_parts: list[str] = []
    params: list = []
    if body.name is not None:
        set_parts.append("name = %s")
        params.append(body.name)
    if body.task_codes is not None:
        set_parts.append("task_codes = %s")
        params.append(body.task_codes)
    if body.task_config is not None:
        set_parts.append("task_config = %s")
        params.append(json.dumps(body.task_config))
    if body.schedule_config is not None:
        set_parts.append("schedule_config = %s")
        params.append(body.schedule_config.model_dump_json())
        # A changed schedule invalidates the stored next_run_at, so it is
        # recomputed from the new configuration.
        now = datetime.now(timezone.utc)
        next_run = calculate_next_run(body.schedule_config, now)
        set_parts.append("next_run_at = %s")
        params.append(next_run)
    if not set_parts:
        raise HTTPException(
            status_code=422,
            detail="至少需要提供一个更新字段",
        )
    set_parts.append("updated_at = NOW()")
    set_clause = ", ".join(set_parts)
    # WHERE parameters go last, after all SET values.
    params.extend([schedule_id, user.site_id])
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                f"""
                UPDATE scheduled_tasks
                SET {set_clause}
                WHERE id = %s AND site_id = %s
                RETURNING {_SELECT_COLS}
                """,
                params,
            )
            row = cur.fetchone()
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    # RETURNING produced no row → nothing matched this id/site pair.
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="调度任务不存在",
        )
    return _row_to_response(row)
# ── DELETE /api/schedules/{id} — 删除 ────────────────────────
@router.delete("/{schedule_id}")
async def delete_schedule(
    schedule_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Delete a scheduled task owned by the caller's site."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "DELETE FROM scheduled_tasks WHERE id = %s AND site_id = %s",
                (schedule_id, user.site_id),
            )
            removed = cur.rowcount
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    if not removed:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="调度任务不存在",
        )
    return {"message": "调度任务已删除"}
# ── PATCH /api/schedules/{id}/toggle — 启用/禁用 ─────────────
@router.patch("/{schedule_id}/toggle", response_model=ScheduleResponse)
async def toggle_schedule(
    schedule_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Flip a schedule's enabled flag.

    Disabling clears next_run_at; enabling recomputes it from the stored
    schedule_config.

    Raises:
        HTTPException 404: when no row matches (schedule_id, site_id).
    """
    conn = get_connection()
    try:
        # Step 1: read the current flag and the stored schedule configuration.
        with conn.cursor() as cur:
            cur.execute(
                "SELECT enabled, schedule_config FROM scheduled_tasks WHERE id = %s AND site_id = %s",
                (schedule_id, user.site_id),
            )
            row = cur.fetchone()
        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="调度任务不存在",
            )
        current_enabled = row[0]
        new_enabled = not current_enabled
        if new_enabled:
            # Enabling: recompute next_run_at from the stored config.
            # schedule_config may come back as dict (JSONB adapter) or raw JSON text.
            schedule_config_raw = row[1] if isinstance(row[1], dict) else json.loads(row[1])
            # NOTE(review): import is function-local — presumably to avoid a
            # circular import at module load time; confirm before hoisting.
            from app.schemas.schedules import ScheduleConfigSchema
            schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
            now = datetime.now(timezone.utc)
            next_run = calculate_next_run(schedule_cfg, now)
        else:
            # Disabling: NULL next_run_at so the scheduler skips this task.
            next_run = None
        # Step 2: persist the new state and return the full updated row.
        with conn.cursor() as cur:
            cur.execute(
                f"""
                UPDATE scheduled_tasks
                SET enabled = %s, next_run_at = %s, updated_at = NOW()
                WHERE id = %s AND site_id = %s
                RETURNING {_SELECT_COLS}
                """,
                (new_enabled, next_run, schedule_id, user.site_id),
            )
            updated_row = cur.fetchone()
        conn.commit()
    except HTTPException:
        # The 404 above is not a DB failure; re-raise untouched.
        raise
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    return _row_to_response(updated_row)

View File

@@ -0,0 +1,209 @@
# -*- coding: utf-8 -*-
"""任务注册表 & 配置 API
提供 4 个端点:
- GET /api/tasks/registry — 按业务域分组的任务列表
- GET /api/tasks/dwd-tables — 按业务域分组的 DWD 表定义
- GET /api/tasks/flows — 7 种 Flow + 3 种处理模式
- POST /api/tasks/validate — 验证 TaskConfig 并返回 CLI 命令预览
所有端点需要 JWT 认证。validate 端点从 JWT 注入 store_id。
"""
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.auth.dependencies import CurrentUser, get_current_user
from app.config import ETL_PROJECT_PATH
from app.schemas.tasks import (
FlowDefinition,
ProcessingModeDefinition,
TaskConfigSchema,
)
from app.services.cli_builder import cli_builder
from app.services.task_registry import (
DWD_TABLES,
FLOW_LAYER_MAP,
get_dwd_tables_grouped_by_domain,
get_tasks_grouped_by_domain,
)
router = APIRouter(prefix="/api/tasks", tags=["任务配置"])
# ── 响应模型 ──────────────────────────────────────────────────
class TaskItem(BaseModel):
    """One registered ETL task as exposed by the registry endpoint."""
    code: str  # unique task identifier
    name: str  # human-readable display name
    description: str
    domain: str  # business domain used for grouping
    layer: str
    requires_window: bool  # whether the task needs a time window
    is_ods: bool
    is_dimension: bool
    default_enabled: bool
    is_common: bool
class DwdTableItem(BaseModel):
    """One DWD table definition as exposed by the dwd-tables endpoint."""
    table_name: str
    display_name: str
    domain: str  # business domain used for grouping
    ods_source: str  # source ODS table feeding this DWD table
    is_dimension: bool
class TaskRegistryResponse(BaseModel):
    """Task list grouped by business domain (domain → tasks)."""
    groups: dict[str, list[TaskItem]]
class DwdTablesResponse(BaseModel):
    """DWD table definitions grouped by business domain (domain → tables)."""
    groups: dict[str, list[DwdTableItem]]
class FlowsResponse(BaseModel):
    """Flow definitions plus processing-mode definitions."""
    flows: list[FlowDefinition]
    processing_modes: list[ProcessingModeDefinition]
class ValidateRequest(BaseModel):
    """Validation request body — reuses TaskConfigSchema; store_id is injected server-side."""
    config: TaskConfigSchema
class ValidateResponse(BaseModel):
    """Validation result plus a preview of the generated CLI command."""
    valid: bool
    command: str  # display string of the command ("" when invalid)
    command_args: list[str]  # argv-style list ([] when invalid)
    errors: list[str]  # validation error messages ([] when valid)
# ── Static flow definitions (7 pipelines) ────────────────────────────
FLOW_DEFINITIONS: list[FlowDefinition] = [
    FlowDefinition(id="api_ods", name="API → ODS", layers=["ODS"]),
    FlowDefinition(id="api_ods_dwd", name="API → ODS → DWD", layers=["ODS", "DWD"]),
    FlowDefinition(id="api_full", name="API → ODS → DWD → DWS汇总 → DWS指数", layers=["ODS", "DWD", "DWS", "INDEX"]),
    FlowDefinition(id="ods_dwd", name="ODS → DWD", layers=["DWD"]),
    FlowDefinition(id="dwd_dws", name="DWD → DWS汇总", layers=["DWS"]),
    FlowDefinition(id="dwd_dws_index", name="DWD → DWS汇总 → DWS指数", layers=["DWS", "INDEX"]),
    FlowDefinition(id="dwd_index", name="DWD → DWS指数", layers=["INDEX"]),
]
# Static processing-mode definitions (3 modes).
PROCESSING_MODE_DEFINITIONS: list[ProcessingModeDefinition] = [
    ProcessingModeDefinition(id="increment_only", name="仅增量处理", description="只处理新增和变更的数据"),
    ProcessingModeDefinition(id="verify_only", name="仅校验修复", description="校验现有数据并修复不一致(可选'校验前从 API 获取'"),
    ProcessingModeDefinition(id="increment_verify", name="增量 + 校验修复", description="先增量处理,再校验并修复"),
]
# ── 端点 ──────────────────────────────────────────────────────
@router.get("/registry", response_model=TaskRegistryResponse)
async def get_task_registry(
user: CurrentUser = Depends(get_current_user),
) -> TaskRegistryResponse:
"""返回按业务域分组的任务列表"""
grouped = get_tasks_grouped_by_domain()
return TaskRegistryResponse(
groups={
domain: [
TaskItem(
code=t.code,
name=t.name,
description=t.description,
domain=t.domain,
layer=t.layer,
requires_window=t.requires_window,
is_ods=t.is_ods,
is_dimension=t.is_dimension,
default_enabled=t.default_enabled,
is_common=t.is_common,
)
for t in tasks
]
for domain, tasks in grouped.items()
}
)
@router.get("/dwd-tables", response_model=DwdTablesResponse)
async def get_dwd_tables(
user: CurrentUser = Depends(get_current_user),
) -> DwdTablesResponse:
"""返回按业务域分组的 DWD 表定义"""
grouped = get_dwd_tables_grouped_by_domain()
return DwdTablesResponse(
groups={
domain: [
DwdTableItem(
table_name=t.table_name,
display_name=t.display_name,
domain=t.domain,
ods_source=t.ods_source,
is_dimension=t.is_dimension,
)
for t in tables
]
for domain, tables in grouped.items()
}
)
@router.get("/flows", response_model=FlowsResponse)
async def get_flows(
user: CurrentUser = Depends(get_current_user),
) -> FlowsResponse:
"""返回 7 种 Flow 定义和 3 种处理模式定义"""
return FlowsResponse(
flows=FLOW_DEFINITIONS,
processing_modes=PROCESSING_MODE_DEFINITIONS,
)
@router.post("/validate", response_model=ValidateResponse)
async def validate_task_config(
body: ValidateRequest,
user: CurrentUser = Depends(get_current_user),
) -> ValidateResponse:
"""验证 TaskConfig 并返回生成的 CLI 命令预览
从 JWT 注入 store_id前端无需传递。
"""
config = body.config.model_copy(update={"store_id": user.site_id})
errors: list[str] = []
# 验证 Flow ID
if config.pipeline not in FLOW_LAYER_MAP:
errors.append(f"无效的执行流程: {config.pipeline}")
# 验证任务列表非空
if not config.tasks:
errors.append("任务列表不能为空")
if errors:
return ValidateResponse(
valid=False,
command="",
command_args=[],
errors=errors,
)
cmd_args = cli_builder.build_command(config, ETL_PROJECT_PATH)
cmd_str = cli_builder.build_command_string(config, ETL_PROJECT_PATH)
return ValidateResponse(
valid=True,
command=cmd_str,
command_args=cmd_args,
errors=[],
)

View File

@@ -0,0 +1,104 @@
# AI_CHANGELOG
# - 2026-02-19 | Prompt: 配置微信消息推送 | 新增微信消息推送回调接口,支持 GET 验签 + POST 消息接收
"""
微信消息推送回调接口
处理两类请求:
1. GET — 微信服务器验证(配置时触发一次)
2. POST — 接收微信推送的消息/事件
安全模式下需要解密消息体,当前先用明文模式跑通,后续切安全模式。
"""
import hashlib
import logging
from fastapi import APIRouter, Query, Request, Response
from app.config import get
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/wx", tags=["微信回调"])
# Token is read from the environment and must match the value configured in
# the WeChat admin console.
# Put it in apps/backend/.env.local as: WX_CALLBACK_TOKEN=<your token>
WX_CALLBACK_TOKEN: str = get("WX_CALLBACK_TOKEN", "")
def _check_signature(signature: str, timestamp: str, nonce: str) -> bool:
    """Check that a callback request genuinely comes from the WeChat server.

    Per the WeChat verification scheme: sort [token, timestamp, nonce]
    lexicographically, concatenate, SHA-1 hash, and compare the hex digest
    against ``signature``.

    Returns:
        True when the signature matches; False on mismatch or when the
        shared token is not configured (logged as an error).
    """
    import hmac  # local import keeps this security fix self-contained

    if not WX_CALLBACK_TOKEN:
        logger.error("WX_CALLBACK_TOKEN 未配置")
        return False
    joined = "".join(sorted([WX_CALLBACK_TOKEN, timestamp, nonce]))
    digest = hashlib.sha1(joined.encode("utf-8")).hexdigest()
    # BUGFIX: use a constant-time comparison instead of `==` so the check
    # cannot leak digest prefixes through response timing.
    return hmac.compare_digest(digest, signature)
@router.get("/callback")
async def verify(
signature: str = Query(...),
timestamp: str = Query(...),
nonce: str = Query(...),
echostr: str = Query(...),
):
"""
微信服务器验证接口。
配置消息推送时微信会发 GET 请求,验签通过后原样返回 echostr。
"""
if _check_signature(signature, timestamp, nonce):
logger.info("微信回调验证通过")
# 必须原样返回 echostr纯文本不能包裹 JSON
return Response(content=echostr, media_type="text/plain")
else:
logger.warning("微信回调验签失败: signature=%s", signature)
return Response(content="signature mismatch", status_code=403)
@router.post("/callback")
async def receive_message(
request: Request,
signature: str = Query(""),
timestamp: str = Query(""),
nonce: str = Query(""),
):
"""
接收微信推送的消息/事件。
当前为明文模式,直接解析 JSON 包体。
后续切安全模式时需增加 AES 解密逻辑。
"""
# 验签POST 也带 signature 参数)
if not _check_signature(signature, timestamp, nonce):
logger.warning("消息推送验签失败")
return Response(content="signature mismatch", status_code=403)
# 解析消息体
body = await request.body()
content_type = request.headers.get("content-type", "")
if "json" in content_type:
import json
try:
data = json.loads(body)
except json.JSONDecodeError:
data = {"raw": body.decode("utf-8", errors="replace")}
else:
# XML 格式暂不解析,记录原文
data = {"raw_xml": body.decode("utf-8", errors="replace")}
logger.info("收到微信推送: MsgType=%s, Event=%s",
data.get("MsgType", "?"), data.get("Event", "?"))
# TODO: 根据 MsgType/Event 分发处理(客服消息、订阅事件等)
# 当前统一返回 success
return Response(content="success", media_type="text/plain")

View File

@@ -0,0 +1,37 @@
# AI_CHANGELOG
# - 2026-02-19 | Prompt: 小程序 MVP 全链路验证 | 新增 /api/xcx-test 接口,从 test."xcx-test" 表读取 ti 列第一行
"""
小程序 MVP 验证接口
从 test_zqyy_app 库的 test."xcx-test" 表读取数据,
用于验证小程序 → 后端 → 数据库全链路连通性。
"""
from fastapi import APIRouter, HTTPException
from app.database import get_connection
router = APIRouter(prefix="/api/xcx-test", tags=["小程序MVP"])
@router.get("")
async def get_xcx_test():
"""
读取 test."xcx-test" 表 ti 列第一行。
用于小程序 MVP 全链路验证:小程序 → API → DB → 返回数据。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# CHANGE 2026-02-19 | 读取 test schema 下的 xcx-test 表
# 表名含连字符,必须用双引号包裹
cur.execute('SELECT ti FROM test."xcx-test" LIMIT 1')
row = cur.fetchone()
finally:
conn.close()
if row is None:
raise HTTPException(status_code=404, detail="无数据")
return {"ti": row[0]}

View File

@@ -0,0 +1,30 @@
"""
认证相关 Pydantic 模型。
- LoginRequest登录请求体
- TokenResponse令牌响应体
- RefreshRequest刷新令牌请求体
"""
from pydantic import BaseModel, Field
class LoginRequest(BaseModel):
    """Login request body."""
    username: str = Field(..., min_length=1, max_length=64, description="用户名")
    password: str = Field(..., min_length=1, description="密码")
class RefreshRequest(BaseModel):
    """Token-refresh request body."""
    refresh_token: str = Field(..., min_length=1, description="刷新令牌")
class TokenResponse(BaseModel):
    """Token pair returned by login/refresh."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"  # fixed OAuth2-style scheme name

View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""数据库查看器 Pydantic 模型
定义 Schema 浏览、表结构查看、SQL 查询的请求/响应模型。
"""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel
class SchemaInfo(BaseModel):
    """A database schema entry."""
    name: str
class TableInfo(BaseModel):
    """A table entry, optionally with its row count."""
    name: str
    row_count: int | None = None  # None when the count was not computed
class ColumnInfo(BaseModel):
    """A column definition."""
    name: str
    data_type: str
    is_nullable: bool
    column_default: str | None = None  # None when the column has no default
class QueryRequest(BaseModel):
    """Ad-hoc SQL query request."""
    sql: str
class QueryResponse(BaseModel):
    """Ad-hoc SQL query response."""
    columns: list[str]  # column names, in result order
    rows: list[list[Any]]  # row values, positionally matching `columns`
    row_count: int

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""ETL 状态监控 Pydantic 模型
定义游标信息和最近执行记录的响应模型。
"""
from __future__ import annotations
from pydantic import BaseModel
class CursorInfo(BaseModel):
    """ETL cursor info (last fetch state of one task)."""
    task_code: str
    last_fetch_time: str | None = None
    record_count: int | None = None
class RecentRun(BaseModel):
    """One recent ETL execution record."""
    id: str
    task_codes: list[str]
    status: str
    started_at: str
    finished_at: str | None = None  # None while still running
    duration_ms: int | None = None
    exit_code: int | None = None

View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""执行与队列相关的 Pydantic 模型
用于 execution 路由的请求/响应序列化。
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel
class ReorderRequest(BaseModel):
    """Queue reorder request."""
    task_id: str
    new_position: int  # target position within the queue
class QueueTaskResponse(BaseModel):
    """One queued task."""
    id: str
    site_id: int
    config: dict[str, Any]  # serialized task configuration
    status: str
    position: int  # position within the queue
    created_at: datetime | None = None
    started_at: datetime | None = None
    finished_at: datetime | None = None
    exit_code: int | None = None
    error_message: str | None = None
class ExecutionRunResponse(BaseModel):
    """Response for a directly started execution."""
    execution_id: str
    message: str
class ExecutionHistoryItem(BaseModel):
    """One execution history record."""
    id: str
    site_id: int
    task_codes: list[str]
    status: str
    started_at: datetime
    finished_at: datetime | None = None  # None while still running
    exit_code: int | None = None
    duration_ms: int | None = None
    command: str | None = None  # CLI command that was executed
    summary: dict[str, Any] | None = None
class ExecutionLogsResponse(BaseModel):
    """Execution log payload (stdout/stderr text)."""
    execution_id: str
    output_log: str | None = None
    error_log: str | None = None

View File

@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""调度配置 Pydantic 模型
定义 ScheduleConfigSchema 及相关模型,供调度服务和路由使用。
"""
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel
class ScheduleConfigSchema(BaseModel):
    """Schedule configuration — supports 5 schedule types.

    Only the fields relevant to the chosen schedule_type are consumed;
    the others keep their defaults.
    """
    schedule_type: Literal["once", "interval", "daily", "weekly", "cron"]
    interval_value: int = 1  # interval count, used with interval_unit
    interval_unit: Literal["minutes", "hours", "days"] = "hours"
    daily_time: str = "04:00"  # "HH:MM" for daily schedules
    weekly_days: list[int] = [1]  # ISO weekdays (1=Monday ... 7=Sunday)
    weekly_time: str = "04:00"  # "HH:MM" for weekly schedules
    cron_expression: str = "0 4 * * *"  # 5-field cron (limited parser support)
    enabled: bool = True
    start_date: str | None = None
    end_date: str | None = None
class CreateScheduleRequest(BaseModel):
    """Create-schedule request body."""
    name: str
    task_codes: list[str]
    task_config: dict[str, Any]
    schedule_config: ScheduleConfigSchema
class UpdateScheduleRequest(BaseModel):
    """Update-schedule request body (all fields optional; None = unchanged)."""
    name: str | None = None
    task_codes: list[str] | None = None
    task_config: dict[str, Any] | None = None
    schedule_config: ScheduleConfigSchema | None = None
class ScheduleResponse(BaseModel):
    """Scheduled-task response row."""
    id: str
    site_id: int
    name: str
    task_codes: list[str]
    task_config: dict[str, Any]
    schedule_config: dict[str, Any]  # stored JSON, returned as-is (not re-validated)
    enabled: bool
    last_run_at: datetime | None = None
    next_run_at: datetime | None = None  # None when disabled or one-shot done
    run_count: int
    last_status: str | None = None
    created_at: datetime
    updated_at: datetime

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
"""任务配置 Pydantic 模型
定义 TaskConfigSchema 及相关模型,用于前后端传输和 CLIBuilder 消费。
"""
from typing import Any
from pydantic import BaseModel, model_validator
class TaskConfigSchema(BaseModel):
    """Task configuration — wire format between frontend/backend, consumed by CLIBuilder.

    Field → CLI argument mapping (per the original author's notes):
    - pipeline → --pipeline (flow id, one of 7)
      NOTE(review): cli_builder emits ``--flow`` for this field — confirm
      which flag name is current.
    - processing_mode → --processing-mode (one of 3 modes)
    - tasks → --tasks (comma-joined)
    - dry_run → --dry-run (boolean flag)
    - window_mode → selects lookback vs custom window (frontend-side only,
      not mapped to a CLI argument itself)
    - window_start → --window-start
    - window_end → --window-end
    - window_split → --window-split
    - window_split_days → --window-split-days
    - lookback_hours → --lookback-hours
    - overlap_seconds → --overlap-seconds
    - fetch_before_verify → --fetch-before-verify (boolean flag)
    - store_id → --store-id (injected by the backend from the JWT; never
      supplied by the client)
    - dwd_only_tables → reserved for extra_args / future extension
    """
    tasks: list[str]
    pipeline: str = "api_ods_dwd"
    processing_mode: str = "increment_only"
    dry_run: bool = False
    window_mode: str = "lookback"  # "lookback" or custom-window mode
    window_start: str | None = None
    window_end: str | None = None
    window_split: str | None = None
    window_split_days: int | None = None
    lookback_hours: int = 24
    overlap_seconds: int = 600
    fetch_before_verify: bool = False
    skip_ods_when_fetch_before_verify: bool = False
    ods_use_local_json: bool = False
    store_id: int | None = None
    dwd_only_tables: list[str] | None = None
    force_full: bool = False
    extra_args: dict[str, Any] = {}

    @model_validator(mode="after")
    def validate_window(self) -> "TaskConfigSchema":
        """Reject windows whose end precedes their start.

        Uses string comparison — correct for ISO-8601-style timestamps.
        NOTE(review): assumes both bounds share the same format; confirm.
        """
        if self.window_start and self.window_end:
            if self.window_end < self.window_start:
                raise ValueError("window_end 不能早于 window_start")
        return self
class FlowDefinition(BaseModel):
    """Execution flow (pipeline) definition."""
    id: str
    name: str
    layers: list[str]  # warehouse layers this flow covers
class ProcessingModeDefinition(BaseModel):
    """Processing mode definition."""
    id: str
    name: str
    description: str

View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
"""CLI 命令构建器
从 gui/utils/cli_builder.py 迁移,适配后端 TaskConfigSchema。
将 TaskConfigSchema 转换为 ETL CLI 命令行参数列表。
支持:
- 7 种 Flowapi_ods / api_ods_dwd / api_full / ods_dwd / dwd_dws / dwd_dws_index / dwd_index
- 3 种处理模式increment_only / verify_only / increment_verify
- 自动注入 --store-id 参数
"""
from typing import Any
from ..schemas.tasks import TaskConfigSchema
# Valid flow identifiers.
VALID_FLOWS: set[str] = {
    "api_ods",
    "api_ods_dwd",
    "api_full",
    "ods_dwd",
    "dwd_dws",
    "dwd_dws_index",
    "dwd_index",
}
# Valid processing-mode identifiers.
VALID_PROCESSING_MODES: set[str] = {
    "increment_only",
    "verify_only",
    "increment_verify",
}
# extra_args keys the CLI understands (value-style and boolean-style).
# Keys not in this set are silently dropped by CLIBuilder.build_command.
CLI_SUPPORTED_ARGS: set[str] = {
    # value-style arguments
    "pg_dsn", "pg_host", "pg_port", "pg_name",
    "pg_user", "pg_password", "api_base", "api_token", "api_timeout",
    "api_page_size", "api_retry_max",
    "export_root", "log_root", "fetch_root",
    "ingest_source", "idle_start", "idle_end",
    "data_source", "pipeline_flow",
    "window_split_unit",
    # boolean-style flags
    "force_window_override", "write_pretty_json", "allow_empty_advance",
}
class CLIBuilder:
    """Translate a TaskConfigSchema into an ETL CLI argument list."""

    def build_command(
        self,
        config: "TaskConfigSchema",
        etl_project_path: str,
        python_executable: str = "python",
    ) -> list[str]:
        """Build the full CLI argument vector.

        Shape:
            [python, -m, cli.main, --flow, {flow_id}, --tasks, ...,
             --store-id, {site_id}, ...]

        Args:
            config: task configuration (Pydantic model).
            etl_project_path: ETL project root; used as the subprocess cwd
                by callers and never embedded in the command itself.
            python_executable: interpreter to launch, defaults to "python".

        Returns:
            The argument list, ready for subprocess execution.
        """
        cmd: list[str] = [python_executable, "-m", "cli.main"]
        # -- Flow (execution pipeline) --
        cmd.extend(["--flow", config.pipeline])
        # -- Processing mode --
        if config.processing_mode:
            cmd.extend(["--processing-mode", config.processing_mode])
        # -- Task list (comma-joined) --
        if config.tasks:
            cmd.extend(["--tasks", ",".join(config.tasks)])
        # -- Fetch-before-verify (only meaningful in verify_only mode) --
        if config.fetch_before_verify and config.processing_mode == "verify_only":
            cmd.append("--fetch-before-verify")
        # -- Time window --
        if config.window_mode == "lookback":
            # Lookback mode.
            if config.lookback_hours is not None:
                cmd.extend(["--lookback-hours", str(config.lookback_hours)])
            if config.overlap_seconds is not None:
                cmd.extend(["--overlap-seconds", str(config.overlap_seconds)])
        else:
            # Explicit window bounds.
            if config.window_start:
                cmd.extend(["--window-start", config.window_start])
            if config.window_end:
                cmd.extend(["--window-end", config.window_end])
        # -- Window splitting --
        if config.window_split and config.window_split != "none":
            cmd.extend(["--window-split", config.window_split])
        if config.window_split_days is not None:
            cmd.extend(["--window-split-days", str(config.window_split_days)])
        # -- Dry-run --
        if config.dry_run:
            cmd.append("--dry-run")
        # -- Force full reprocessing --
        if config.force_full:
            cmd.append("--force-full")
        # -- Local JSON ingestion maps to the offline data source --
        if config.ods_use_local_json:
            cmd.extend(["--data-source", "offline"])
        # -- Store id (auto-injected upstream from the JWT) --
        if config.store_id is not None:
            cmd.extend(["--store-id", str(config.store_id)])
        # -- Extra args: pass through whitelisted keys only --
        for key, value in config.extra_args.items():
            if value is not None and key in CLI_SUPPORTED_ARGS:
                arg_name = f"--{key.replace('_', '-')}"
                if isinstance(value, bool):
                    if value:
                        cmd.append(arg_name)
                else:
                    cmd.extend([arg_name, str(value)])
        return cmd

    def build_command_string(
        self,
        config: "TaskConfigSchema",
        etl_project_path: str,
        python_executable: str = "python",
    ) -> str:
        """Render the command as a display string (for logs / UI preview).

        Arguments containing spaces or double quotes are wrapped in double
        quotes. BUGFIX: embedded double quotes are now escaped as \\" —
        previously they were left bare inside the wrapping quotes, producing
        a malformed preview string.
        """
        cmd = self.build_command(config, etl_project_path, python_executable)
        quoted: list[str] = []
        for arg in cmd:
            if " " in arg or '"' in arg:
                escaped = arg.replace('"', '\\"')
                quoted.append(f'"{escaped}"')
            else:
                quoted.append(arg)
        return " ".join(quoted)
# Module-level singleton shared by routers and services.
cli_builder = CLIBuilder()

View File

@@ -0,0 +1,303 @@
# -*- coding: utf-8 -*-
"""调度器服务
后台 asyncio 循环,每 30 秒检查一次到期的调度任务,
将其 TaskConfig 入队到 TaskQueue。
核心逻辑:
- check_and_enqueue():查询 enabled=true 且 next_run_at <= now 的调度任务
- start() / stop():管理后台循环生命周期
- _calculate_next_run():根据 ScheduleConfig 计算下次执行时间
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
from ..database import get_connection
from ..schemas.schedules import ScheduleConfigSchema
from ..schemas.tasks import TaskConfigSchema
from .task_queue import task_queue
logger = logging.getLogger(__name__)
# Scheduler polling interval (seconds).
SCHEDULER_POLL_INTERVAL = 30
def _parse_time(time_str: str) -> tuple[int, int]:
    """Parse an "HH:MM" string into an (hour, minute) tuple."""
    parts = time_str.split(":")
    return int(parts[0]), int(parts[1])


def calculate_next_run(
    schedule_config: ScheduleConfigSchema,
    now: datetime | None = None,
) -> datetime | None:
    """Compute the next execution time for a schedule.

    Args:
        schedule_config: the schedule configuration.
        now: reference time (defaults to UTC now; injectable for tests).

    Returns:
        Next execution time (UTC), or None when the schedule should not run
        again ("once" schedules, or an unparseable configuration).
    """
    if now is None:
        now = datetime.now(timezone.utc)
    stype = schedule_config.schedule_type

    if stype == "once":
        # One-shot schedules are never re-scheduled.
        return None

    if stype == "interval":
        unit_map = {
            "minutes": timedelta(minutes=schedule_config.interval_value),
            "hours": timedelta(hours=schedule_config.interval_value),
            "days": timedelta(days=schedule_config.interval_value),
        }
        delta = unit_map.get(schedule_config.interval_unit)
        if delta is None:
            logger.warning("未知的 interval_unit: %s", schedule_config.interval_unit)
            return None
        return now + delta

    if stype == "daily":
        hour, minute = _parse_time(schedule_config.daily_time)
        # BUGFIX: this previously always scheduled for *tomorrow*, skipping
        # today's slot even when daily_time had not passed yet.  Now today's
        # slot is used when it is still in the future, consistent with the
        # day-of-week handling in _parse_simple_cron.
        target = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
        if target <= now:
            target += timedelta(days=1)
        return target

    if stype == "weekly":
        hour, minute = _parse_time(schedule_config.weekly_time)
        days = sorted(schedule_config.weekly_days) if schedule_config.weekly_days else [1]
        # ISO weekday: 1=Monday ... 7=Sunday
        current_weekday = now.isoweekday()
        today_target = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
        for day in days:
            # BUGFIX: also allow *today* when the weekly time is still ahead;
            # the original skipped same-day slots unconditionally.
            if day == current_weekday and today_target > now:
                return today_target
            if day > current_weekday:
                return today_target + timedelta(days=day - current_weekday)
        # No remaining slot this week — wrap to the first configured day next week.
        first_day = days[0]
        return today_target + timedelta(days=7 - current_weekday + first_day)

    if stype == "cron":
        # Minimal cron support: "minute hour * * *" style expressions.
        # Full cron syntax would require the croniter package.
        return _parse_simple_cron(schedule_config.cron_expression, now)

    logger.warning("未知的 schedule_type: %s", stype)
    return None
def _parse_simple_cron(expression: str, now: datetime) -> datetime | None:
"""简单 cron 解析器,支持基本的 5 字段格式。
支持的格式:
- "M H * * *" → 每天 H:M
- "M H * * D" → 每周 D 的 H:MD 为 0-60=Sunday
- 其他格式回退到每天 04:00
不支持范围、列表、步进等高级语法。如需完整 cron 支持,
可在 pyproject.toml 中添加 croniter 依赖。
"""
parts = expression.strip().split()
if len(parts) != 5:
logger.warning("无法解析 cron 表达式: %s,回退到明天 04:00", expression)
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
minute_str, hour_str, dom, month, dow = parts
try:
minute = int(minute_str) if minute_str != "*" else 0
hour = int(hour_str) if hour_str != "*" else 0
except ValueError:
logger.warning("cron 表达式时间字段无法解析: %s,回退到明天 04:00", expression)
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
# 如果指定了 day-of-week非 *
if dow != "*":
try:
cron_dow = int(dow) # 0=Sunday, 1=Monday, ..., 6=Saturday
except ValueError:
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
# 转换为 ISO weekday1=Monday, 7=Sunday
iso_dow = 7 if cron_dow == 0 else cron_dow
current_iso = now.isoweekday()
if iso_dow > current_iso:
delta_days = iso_dow - current_iso
elif iso_dow < current_iso:
delta_days = 7 - current_iso + iso_dow
else:
# 同一天,看时间是否已过
target_today = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
if now < target_today:
delta_days = 0
else:
delta_days = 7
next_dt = now + timedelta(days=delta_days)
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
# 每天定时dom=* month=* dow=*
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
class Scheduler:
    """PostgreSQL-backed periodic scheduler.

    A background asyncio loop checks for due schedules every
    SCHEDULER_POLL_INTERVAL seconds and enqueues their TaskConfig into the
    shared TaskQueue.
    """

    def __init__(self) -> None:
        # Flag polled by the background loop to know when to exit.
        self._running = False
        # Handle of the background asyncio task (None until start()).
        self._loop_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Core: find due schedules and enqueue them
    # ------------------------------------------------------------------
    def check_and_enqueue(self) -> int:
        """Enqueue every schedule with enabled=true and next_run_at <= now.

        Returns:
            Number of tasks enqueued in this pass.
        """
        conn = get_connection()
        enqueued = 0
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, site_id, task_config, schedule_config
                    FROM scheduled_tasks
                    WHERE enabled = TRUE
                      AND next_run_at IS NOT NULL
                      AND next_run_at <= NOW()
                    ORDER BY next_run_at ASC
                    """
                )
                rows = cur.fetchall()
            for row in rows:
                task_id = str(row[0])
                site_id = row[1]
                # JSONB columns may come back as dict or as raw JSON text.
                task_config_raw = row[2] if isinstance(row[2], dict) else json.loads(row[2])
                schedule_config_raw = row[3] if isinstance(row[3], dict) else json.loads(row[3])
                try:
                    config = TaskConfigSchema(**task_config_raw)
                    schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
                except Exception:
                    # A bad stored config must not abort the whole pass.
                    logger.exception("调度任务 [%s] 配置反序列化失败,跳过", task_id)
                    continue
                # Enqueue into the shared task queue.
                try:
                    queue_id = task_queue.enqueue(config, site_id)
                    logger.info(
                        "调度任务 [%s] 入队成功 → queue_id=%s site_id=%s",
                        task_id, queue_id, site_id,
                    )
                    enqueued += 1
                except Exception:
                    logger.exception("调度任务 [%s] 入队失败", task_id)
                    continue
                # Bookkeeping: stamp last run, bump the counter and compute
                # the next occurrence for this schedule.
                now = datetime.now(timezone.utc)
                next_run = calculate_next_run(schedule_cfg, now)
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        UPDATE scheduled_tasks
                        SET last_run_at = NOW(),
                            run_count = run_count + 1,
                            next_run_at = %s,
                            last_status = 'enqueued',
                            updated_at = NOW()
                        WHERE id = %s
                        """,
                        (next_run, task_id),
                    )
                conn.commit()
        except Exception:
            logger.exception("check_and_enqueue 执行异常")
            try:
                conn.rollback()
            except Exception:
                # Connection may already be unusable; closing below suffices.
                pass
        finally:
            conn.close()
        if enqueued > 0:
            logger.info("本轮调度检查:%d 个任务入队", enqueued)
        return enqueued

    # ------------------------------------------------------------------
    # Background loop
    # ------------------------------------------------------------------
    async def _loop(self) -> None:
        """Poll for due schedules every SCHEDULER_POLL_INTERVAL seconds."""
        self._running = True
        logger.info("Scheduler 后台循环启动(间隔 %ds", SCHEDULER_POLL_INTERVAL)
        while self._running:
            try:
                # Run the blocking DB work in the default thread pool so the
                # event loop stays responsive.
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, self.check_and_enqueue)
            except Exception:
                logger.exception("Scheduler 循环迭代异常")
            await asyncio.sleep(SCHEDULER_POLL_INTERVAL)
        logger.info("Scheduler 后台循环停止")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self) -> None:
        """Start the background loop (called from the FastAPI lifespan)."""
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self._loop())
            logger.info("Scheduler 已启动")

    async def stop(self) -> None:
        """Stop the background loop and await its cancellation."""
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass
        self._loop_task = None
        logger.info("Scheduler 已停止")
# Module-level singleton wired into the FastAPI lifespan.
scheduler = Scheduler()

View File

@@ -0,0 +1,391 @@
# -*- coding: utf-8 -*-
"""ETL 任务执行器
通过 asyncio.create_subprocess_exec 启动 ETL CLI 子进程,
逐行读取 stdout/stderr 并广播到 WebSocket 订阅者,
执行完成后将结果写入 task_execution_log 表。
设计要点:
- 每个 execution_id 对应一个子进程,存储在 _processes 字典中
- 日志行存储在内存缓冲区 _log_buffers 中
- WebSocket 订阅者通过 asyncio.Queue 接收实时日志
- Windows 兼容:取消时使用 process.terminate() 而非 SIGTERM
"""
from __future__ import annotations
import asyncio
import logging
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from typing import Any
from ..config import ETL_PROJECT_PATH
from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema
from ..services.cli_builder import cli_builder
logger = logging.getLogger(__name__)
class TaskExecutor:
"""管理 ETL CLI 子进程的生命周期"""
    def __init__(self) -> None:
        # execution_id → subprocess.Popen handle of the CLI process
        self._processes: dict[str, subprocess.Popen] = {}
        # execution_id → accumulated log lines (stdout + stderr interleaved)
        self._log_buffers: dict[str, list[str]] = {}
        # execution_id → WebSocket subscriber queues (str = log line, None = end)
        self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}
# ------------------------------------------------------------------
# WebSocket 订阅管理
# ------------------------------------------------------------------
def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
"""注册一个 WebSocket 订阅者,返回用于读取日志行的 Queue。
Queue 中推送 str 表示日志行None 表示执行结束。
"""
if execution_id not in self._subscribers:
self._subscribers[execution_id] = set()
queue: asyncio.Queue[str | None] = asyncio.Queue()
self._subscribers[execution_id].add(queue)
return queue
def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
"""移除一个 WebSocket 订阅者。"""
subs = self._subscribers.get(execution_id)
if subs:
subs.discard(queue)
if not subs:
del self._subscribers[execution_id]
def _broadcast(self, execution_id: str, line: str) -> None:
"""向所有订阅者广播一行日志。"""
subs = self._subscribers.get(execution_id)
if subs:
for q in subs:
q.put_nowait(line)
def _broadcast_end(self, execution_id: str) -> None:
"""通知所有订阅者执行已结束(发送 None 哨兵)。"""
subs = self._subscribers.get(execution_id)
if subs:
for q in subs:
q.put_nowait(None)
# ------------------------------------------------------------------
# 日志缓冲区
# ------------------------------------------------------------------
def get_logs(self, execution_id: str) -> list[str]:
"""获取指定执行的内存日志缓冲区(副本)。"""
return list(self._log_buffers.get(execution_id, []))
# ------------------------------------------------------------------
# 执行状态查询
# ------------------------------------------------------------------
def is_running(self, execution_id: str) -> bool:
"""判断指定执行是否仍在运行。"""
proc = self._processes.get(execution_id)
if proc is None:
return False
return proc.poll() is None
def get_running_ids(self) -> list[str]:
"""返回当前所有运行中的 execution_id 列表。"""
return [eid for eid, p in self._processes.items() if p.returncode is None]
# ------------------------------------------------------------------
# 核心执行
# ------------------------------------------------------------------
async def execute(
self,
config: TaskConfigSchema,
execution_id: str,
queue_id: str | None = None,
site_id: int | None = None,
) -> None:
"""以子进程方式调用 ETL CLI。
使用 subprocess.Popen + 线程读取,兼容 Windows避免
asyncio.create_subprocess_exec 在 Windows 上的 NotImplementedError
"""
cmd = cli_builder.build_command(
config, ETL_PROJECT_PATH, python_executable=sys.executable
)
command_str = " ".join(cmd)
effective_site_id = site_id or config.store_id
logger.info(
"启动 ETL 子进程 [%s]: %s (cwd=%s)",
execution_id, command_str, ETL_PROJECT_PATH,
)
self._log_buffers[execution_id] = []
started_at = datetime.now(timezone.utc)
t0 = time.monotonic()
self._write_execution_log(
execution_id=execution_id,
queue_id=queue_id,
site_id=effective_site_id,
task_codes=config.tasks,
status="running",
started_at=started_at,
command=command_str,
)
exit_code: int | None = None
status = "running"
stdout_lines: list[str] = []
stderr_lines: list[str] = []
try:
# 构建额外环境变量DWD 表过滤通过环境变量注入)
extra_env: dict[str, str] = {}
if config.dwd_only_tables:
extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)
# 在线程池中运行子进程,兼容 Windows
exit_code = await asyncio.get_event_loop().run_in_executor(
None,
self._run_subprocess,
cmd,
execution_id,
stdout_lines,
stderr_lines,
extra_env or None,
)
if exit_code == 0:
status = "success"
else:
status = "failed"
logger.info(
"ETL 子进程 [%s] 退出exit_code=%s, status=%s",
execution_id, exit_code, status,
)
except asyncio.CancelledError:
status = "cancelled"
logger.info("ETL 子进程 [%s] 已取消", execution_id)
# 尝试终止子进程
proc = self._processes.get(execution_id)
if proc and proc.poll() is None:
proc.terminate()
except Exception as exc:
status = "failed"
import traceback
tb = traceback.format_exc()
stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
stderr_lines.append(tb)
logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
finally:
elapsed_ms = int((time.monotonic() - t0) * 1000)
finished_at = datetime.now(timezone.utc)
self._broadcast_end(execution_id)
self._processes.pop(execution_id, None)
self._update_execution_log(
execution_id=execution_id,
status=status,
finished_at=finished_at,
exit_code=exit_code,
duration_ms=elapsed_ms,
output_log="\n".join(stdout_lines),
error_log="\n".join(stderr_lines),
)
def _run_subprocess(
self,
cmd: list[str],
execution_id: str,
stdout_lines: list[str],
stderr_lines: list[str],
extra_env: dict[str, str] | None = None,
) -> int:
"""在线程中运行子进程并逐行读取输出。"""
import os
env = os.environ.copy()
# 强制子进程使用 UTF-8 输出,避免 Windows GBK 乱码
env["PYTHONIOENCODING"] = "utf-8"
if extra_env:
env.update(extra_env)
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=ETL_PROJECT_PATH,
env=env,
text=True,
encoding="utf-8",
errors="replace",
)
self._processes[execution_id] = proc
def read_stream(
stream, stream_name: str, collector: list[str],
) -> None:
"""逐行读取流并广播。"""
for raw_line in stream:
line = raw_line.rstrip("\n").rstrip("\r")
tagged = f"[{stream_name}] {line}"
buf = self._log_buffers.get(execution_id)
if buf is not None:
buf.append(tagged)
collector.append(line)
self._broadcast(execution_id, tagged)
t_out = threading.Thread(
target=read_stream, args=(proc.stdout, "stdout", stdout_lines),
daemon=True,
)
t_err = threading.Thread(
target=read_stream, args=(proc.stderr, "stderr", stderr_lines),
daemon=True,
)
t_out.start()
t_err.start()
proc.wait()
t_out.join(timeout=5)
t_err.join(timeout=5)
return proc.returncode
# ------------------------------------------------------------------
# 取消
# ------------------------------------------------------------------
async def cancel(self, execution_id: str) -> bool:
"""向子进程发送终止信号。
Returns:
True 表示成功发送终止信号False 表示进程不存在或已退出。
"""
proc = self._processes.get(execution_id)
if proc is None:
return False
# subprocess.Popen: poll() 返回 None 表示仍在运行
if proc.poll() is not None:
return False
logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
try:
proc.terminate()
except ProcessLookupError:
return False
return True
# ------------------------------------------------------------------
# 数据库操作(同步,在线程池中执行也可,此处简单直连)
# ------------------------------------------------------------------
@staticmethod
def _write_execution_log(
*,
execution_id: str,
queue_id: str | None,
site_id: int | None,
task_codes: list[str],
status: str,
started_at: datetime,
command: str,
) -> None:
"""插入一条执行日志记录running 状态)。"""
try:
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO task_execution_log
(id, queue_id, site_id, task_codes, status,
started_at, command)
VALUES (%s, %s, %s, %s, %s, %s, %s)
""",
(
execution_id,
queue_id,
site_id or 0,
task_codes,
status,
started_at,
command,
),
)
conn.commit()
finally:
conn.close()
except Exception:
logger.exception("写入 execution_log 失败 [%s]", execution_id)
@staticmethod
def _update_execution_log(
*,
execution_id: str,
status: str,
finished_at: datetime,
exit_code: int | None,
duration_ms: int,
output_log: str,
error_log: str,
) -> None:
"""更新执行日志记录(完成状态)。"""
try:
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE task_execution_log
SET status = %s,
finished_at = %s,
exit_code = %s,
duration_ms = %s,
output_log = %s,
error_log = %s
WHERE id = %s
""",
(
status,
finished_at,
exit_code,
duration_ms,
output_log,
error_log,
execution_id,
),
)
conn.commit()
finally:
conn.close()
except Exception:
logger.exception("更新 execution_log 失败 [%s]", execution_id)
# ------------------------------------------------------------------
# 清理
# ------------------------------------------------------------------
def cleanup(self, execution_id: str) -> None:
"""清理指定执行的内存资源(日志缓冲区和订阅者)。
通常在确认日志已持久化后调用。
"""
self._log_buffers.pop(execution_id, None)
self._subscribers.pop(execution_id, None)
# Module-level singleton shared by the application.
task_executor = TaskExecutor()

View File

@@ -0,0 +1,486 @@
# -*- coding: utf-8 -*-
"""任务队列服务
基于 PostgreSQL task_queue 表实现 FIFO 队列,支持:
- enqueue入队自动分配 position当前最大 + 1
- dequeue取出 position 最小的 pending 任务
- reorder调整任务在队列中的位置
- delete删除 pending 任务
- process_loop后台协程队列非空且无运行中任务时自动取出执行
所有操作按 site_id 过滤,实现门店隔离。
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any
from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema
logger = logging.getLogger(__name__)
# Polling interval of the background processing loop, in seconds.
POLL_INTERVAL_SECONDS = 2
@dataclass
class QueuedTask:
    """In-memory view of one row in the ``task_queue`` table."""

    id: str                            # queue task id (UUID string)
    site_id: int                       # owning site (per-site isolation)
    config: dict[str, Any]             # JSON-decoded task configuration
    status: str                        # e.g. pending / running / failed (values written by this module)
    position: int                      # FIFO position within the site's pending queue
    created_at: Any = None             # timestamps are passed through as returned by the DB driver
    started_at: Any = None
    finished_at: Any = None
    exit_code: int | None = None       # subprocess exit code once finished
    error_message: str | None = None   # failure detail, if any
class TaskQueue:
    """PostgreSQL-backed FIFO task queue with per-site isolation.

    Supports enqueue / dequeue / reorder / delete plus a background
    coroutine (``process_loop``) that automatically executes pending
    tasks when a site has no running task.
    """

    def __init__(self) -> None:
        self._running = False
        self._loop_task: asyncio.Task | None = None
        # Fix: strong references to fire-and-forget execution tasks.
        # asyncio only keeps weak references to tasks, so an unreferenced
        # create_task() result may be garbage collected mid-run.
        self._bg_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Enqueue
    # ------------------------------------------------------------------
    def enqueue(self, config: TaskConfigSchema, site_id: int) -> str:
        """Append a task configuration to the queue (auto-assigned position).

        Args:
            config: Task configuration.
            site_id: Site id (per-site isolation).

        Returns:
            The new queue task id (UUID string).
        """
        task_id = str(uuid.uuid4())
        config_json = config.model_dump(mode="json")
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Current max position among this site's pending tasks;
                # the new task goes to the end.
                cur.execute(
                    """
                    SELECT COALESCE(MAX(position), 0)
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    """,
                    (site_id,),
                )
                max_pos = cur.fetchone()[0]
                new_pos = max_pos + 1
                cur.execute(
                    """
                    INSERT INTO task_queue (id, site_id, config, status, position)
                    VALUES (%s, %s, %s, 'pending', %s)
                    """,
                    (task_id, site_id, json.dumps(config_json), new_pos),
                )
            conn.commit()
        finally:
            conn.close()
        logger.info("任务入队 [%s] site_id=%s position=%s", task_id, site_id, new_pos)
        return task_id

    # ------------------------------------------------------------------
    # Dequeue
    # ------------------------------------------------------------------
    def dequeue(self, site_id: int) -> QueuedTask | None:
        """Pop the pending task with the smallest position; mark it running.

        Args:
            site_id: Site id.

        Returns:
            The dequeued task, or ``None`` if the queue is empty.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Select and lock the lowest-position pending task;
                # SKIP LOCKED avoids blocking concurrent workers.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                    """,
                    (site_id,),
                )
                row = cur.fetchone()
                if row is None:
                    conn.commit()
                    return None
                task = QueuedTask(
                    id=str(row[0]),
                    site_id=row[1],
                    config=row[2] if isinstance(row[2], dict) else json.loads(row[2]),
                    status=row[3],
                    position=row[4],
                    created_at=row[5],
                    started_at=row[6],
                    finished_at=row[7],
                    exit_code=row[8],
                    error_message=row[9],
                )
                # Transition to running inside the same transaction.
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'running', started_at = NOW()
                    WHERE id = %s
                    """,
                    (task.id,),
                )
            conn.commit()
        finally:
            conn.close()
        task.status = "running"
        logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
        return task

    # ------------------------------------------------------------------
    # Reorder
    # ------------------------------------------------------------------
    def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
        """Move a pending task to ``new_position`` within its site's queue.

        Only pending tasks are reorderable.  The target task is moved to
        ``new_position``; remaining pending tasks keep their relative order
        and are renumbered contiguously (1-based).

        Args:
            task_id: Id of the task to move.
            new_position: Target position (1-based).
            site_id: Site id.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # All pending tasks of this site, in current order.
                cur.execute(
                    """
                    SELECT id FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
                task_ids = [str(r[0]) for r in rows]
                if task_id not in task_ids:
                    conn.commit()
                    return
                # Remove the target and re-insert at the clamped new slot.
                task_ids.remove(task_id)
                # new_position is 1-based; convert to 0-based index and clamp.
                insert_idx = max(0, min(new_position - 1, len(task_ids)))
                task_ids.insert(insert_idx, task_id)
                # Renumber contiguously (1-based) in the new order.
                for idx, tid in enumerate(task_ids, start=1):
                    cur.execute(
                        "UPDATE task_queue SET position = %s WHERE id = %s",
                        (idx, tid),
                    )
            conn.commit()
        finally:
            conn.close()
        logger.info(
            "任务重排 [%s] → position=%s site_id=%s",
            task_id, new_position, site_id,
        )

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------
    def delete(self, task_id: str, site_id: int) -> bool:
        """Delete a pending task.

        Args:
            task_id: Task id.
            site_id: Site id.

        Returns:
            True if deleted; False if the task does not exist or is not
            in pending state.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM task_queue
                    WHERE id = %s AND site_id = %s AND status = 'pending'
                    """,
                    (task_id, site_id),
                )
                deleted = cur.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        if deleted:
            logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
        else:
            logger.warning(
                "任务删除失败 [%s] site_id=%s(不存在或非 pending",
                task_id, site_id,
            )
        return deleted

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def list_pending(self, site_id: int) -> list[QueuedTask]:
        """List all pending tasks for a site, ordered by position ascending."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [
            QueuedTask(
                id=str(r[0]),
                site_id=r[1],
                config=r[2] if isinstance(r[2], dict) else json.loads(r[2]),
                status=r[3],
                position=r[4],
                created_at=r[5],
                started_at=r[6],
                finished_at=r[7],
                exit_code=r[8],
                error_message=r[9],
            )
            for r in rows
        ]

    def has_running(self, site_id: int) -> bool:
        """Return True if the site has any task in running state."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM task_queue
                        WHERE site_id = %s AND status = 'running'
                    )
                    """,
                    (site_id,),
                )
                result = cur.fetchone()[0]
            conn.commit()
        finally:
            conn.close()
        return result

    # ------------------------------------------------------------------
    # Background processing loop
    # ------------------------------------------------------------------
    async def process_loop(self) -> None:
        """Background coroutine: auto-dequeue and execute pending tasks.

        Loop:
            1. Collect every site_id that has pending tasks.
            2. For each site with no running task, dequeue and execute.
            3. Sleep ``POLL_INTERVAL_SECONDS`` and repeat.
        """
        # Deferred import to avoid a circular dependency.
        from .task_executor import task_executor

        self._running = True
        logger.info("TaskQueue process_loop 启动")
        while self._running:
            try:
                await self._process_once(task_executor)
            except Exception:
                logger.exception("process_loop 迭代异常")
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
        logger.info("TaskQueue process_loop 停止")

    async def _process_once(self, executor: Any) -> None:
        """Single pass: scan every site's pending queue and launch work."""
        site_ids = self._get_pending_site_ids()
        for site_id in site_ids:
            # At most one running task per site at a time.
            if self.has_running(site_id):
                continue
            task = self.dequeue(site_id)
            if task is None:
                continue
            config = TaskConfigSchema(**task.config)
            execution_id = str(uuid.uuid4())
            logger.info(
                "process_loop 自动执行 [%s] queue_id=%s site_id=%s",
                execution_id, task.id, site_id,
            )
            # Launch asynchronously so the loop is not blocked.
            handle = asyncio.create_task(
                self._execute_and_update(
                    executor, config, execution_id, task.id, site_id,
                )
            )
            # Fix: keep a strong reference until completion so the task
            # cannot be garbage collected mid-run (asyncio.create_task docs).
            self._bg_tasks.add(handle)
            handle.add_done_callback(self._bg_tasks.discard)

    async def _execute_and_update(
        self,
        executor: Any,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str,
        site_id: int,
    ) -> None:
        """Execute one task and synchronize the queue row afterwards."""
        try:
            await executor.execute(
                config=config,
                execution_id=execution_id,
                queue_id=queue_id,
                site_id=site_id,
            )
            # Mirror the executor's final result back onto task_queue.
            self._update_queue_status_from_log(queue_id)
        except Exception:
            logger.exception("队列任务执行异常 [%s]", queue_id)
            self._mark_failed(queue_id, "执行过程中发生未捕获异常")

    def _get_pending_site_ids(self) -> list[int]:
        """Return every site_id that currently has pending tasks."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT DISTINCT site_id FROM task_queue
                    WHERE status = 'pending'
                    """
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [r[0] for r in rows]

    def _update_queue_status_from_log(self, queue_id: str) -> None:
        """Copy the latest execution result from task_execution_log to task_queue."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT status, finished_at, exit_code, error_log
                    FROM task_execution_log
                    WHERE queue_id = %s
                    ORDER BY started_at DESC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                row = cur.fetchone()
                if row:
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = %s, finished_at = %s,
                            exit_code = %s, error_message = %s
                        WHERE id = %s
                        """,
                        (row[0], row[1], row[2], row[3], queue_id),
                    )
            conn.commit()
        finally:
            conn.close()

    def _mark_failed(self, queue_id: str, error_message: str) -> None:
        """Mark a queued task as failed with the given message."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE id = %s
                    """,
                    (error_message, queue_id),
                )
            conn.commit()
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self) -> None:
        """Start the background loop (call from the FastAPI lifespan)."""
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self.process_loop())
            logger.info("TaskQueue 后台循环已启动")

    async def stop(self) -> None:
        """Stop the background loop and await its cancellation."""
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass
        self._loop_task = None
        logger.info("TaskQueue 后台循环已停止")
# Module-level singleton shared by the application.
task_queue = TaskQueue()

View File

@@ -0,0 +1,221 @@
# -*- coding: utf-8 -*-
"""静态任务注册表
从 ETL orchestration/task_registry.py 提取的任务元数据硬编码副本。
后端不直接导入 ETL 代码,避免引入重量级依赖链。
业务域分组逻辑:按任务代码前缀 / 目标表语义归类,与 GUI 保持一致。
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass(frozen=True)
class TaskDefinition:
    """Metadata describing a single ETL task."""

    code: str            # unique task code, e.g. "ODS_MEMBER"
    name: str            # short display name
    description: str     # human-readable description
    domain: str  # business domain: member / settlement / assistant / goods / tables / group-buy / inventory / finance / index / utility
    layer: str  # ODS / DWD / DWS / INDEX / UTILITY
    requires_window: bool = True   # presumably: task takes an extraction time window — TODO confirm against the CLI builder
    is_ods: bool = False
    is_dimension: bool = False
    default_enabled: bool = True
    is_common: bool = True  # False marks utility/manual tasks excluded from the common list
@dataclass(frozen=True)
class DwdTableDefinition:
    """Metadata for one DWD-layer table."""

    table_name: str       # fully-qualified table name (including schema)
    display_name: str     # human-readable name
    domain: str           # business domain (same taxonomy as TaskDefinition)
    ods_source: str       # the ODS source table it is loaded from
    is_dimension: bool = False  # True for dimension tables, False for fact tables
# ── ODS 任务定义 ──────────────────────────────────────────────
ODS_TASKS: list[TaskDefinition] = [
TaskDefinition("ODS_ASSISTANT_ACCOUNT", "助教账号", "抽取助教账号主数据", "助教", "ODS", is_ods=True),
TaskDefinition("ODS_ASSISTANT_LEDGER", "助教服务记录", "抽取助教服务流水", "助教", "ODS", is_ods=True),
TaskDefinition("ODS_ASSISTANT_ABOLISH", "助教取消记录", "抽取助教取消/作废记录", "助教", "ODS", is_ods=True),
TaskDefinition("ODS_SETTLEMENT_RECORDS", "结算记录", "抽取订单结算记录", "结算", "ODS", is_ods=True),
TaskDefinition("ODS_SETTLEMENT_TICKET", "结账小票", "抽取结账小票明细", "结算", "ODS", is_ods=True),
TaskDefinition("ODS_TABLE_USE", "台费流水", "抽取台费使用流水", "台桌", "ODS", is_ods=True),
TaskDefinition("ODS_TABLE_FEE_DISCOUNT", "台费折扣", "抽取台费折扣记录", "台桌", "ODS", is_ods=True),
TaskDefinition("ODS_TABLES", "台桌主数据", "抽取门店台桌信息", "台桌", "ODS", is_ods=True, requires_window=False),
TaskDefinition("ODS_PAYMENT", "支付流水", "抽取支付交易记录", "结算", "ODS", is_ods=True),
TaskDefinition("ODS_REFUND", "退款流水", "抽取退款交易记录", "结算", "ODS", is_ods=True),
TaskDefinition("ODS_PLATFORM_COUPON", "平台券核销", "抽取平台优惠券核销记录", "团购", "ODS", is_ods=True),
TaskDefinition("ODS_MEMBER", "会员主数据", "抽取会员档案", "会员", "ODS", is_ods=True),
TaskDefinition("ODS_MEMBER_CARD", "会员储值卡", "抽取会员储值卡信息", "会员", "ODS", is_ods=True),
TaskDefinition("ODS_MEMBER_BALANCE", "会员余额变动", "抽取会员余额变动记录", "会员", "ODS", is_ods=True),
TaskDefinition("ODS_RECHARGE_SETTLE", "充值结算", "抽取充值结算记录", "会员", "ODS", is_ods=True),
TaskDefinition("ODS_GROUP_PACKAGE", "团购套餐", "抽取团购套餐定义", "团购", "ODS", is_ods=True, requires_window=False),
TaskDefinition("ODS_GROUP_BUY_REDEMPTION", "团购核销", "抽取团购核销记录", "团购", "ODS", is_ods=True),
TaskDefinition("ODS_INVENTORY_STOCK", "库存快照", "抽取商品库存汇总", "库存", "ODS", is_ods=True, requires_window=False),
TaskDefinition("ODS_INVENTORY_CHANGE", "库存变动", "抽取库存出入库记录", "库存", "ODS", is_ods=True),
TaskDefinition("ODS_GOODS_CATEGORY", "商品分类", "抽取商品分类树", "商品", "ODS", is_ods=True, requires_window=False),
TaskDefinition("ODS_STORE_GOODS", "门店商品", "抽取门店商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
TaskDefinition("ODS_STORE_GOODS_SALES", "商品销售", "抽取门店商品销售记录", "商品", "ODS", is_ods=True),
TaskDefinition("ODS_TENANT_GOODS", "租户商品", "抽取租户级商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
]
# ── DWD 任务定义 ──────────────────────────────────────────────
DWD_TASKS: list[TaskDefinition] = [
TaskDefinition("DWD_LOAD_FROM_ODS", "DWD 装载", "从 ODS 装载至 DWD维度 SCD2 + 事实增量)", "通用", "DWD", requires_window=False),
TaskDefinition("DWD_QUALITY_CHECK", "DWD 质量检查", "对 DWD 层数据执行质量校验", "通用", "DWD", requires_window=False, is_common=False),
]
# ── DWS 任务定义 ──────────────────────────────────────────────
DWS_TASKS: list[TaskDefinition] = [
TaskDefinition("DWS_BUILD_ORDER_SUMMARY", "订单汇总构建", "构建订单汇总宽表", "结算", "DWS"),
TaskDefinition("DWS_ASSISTANT_DAILY", "助教日报", "汇总助教每日业绩", "助教", "DWS"),
TaskDefinition("DWS_ASSISTANT_MONTHLY", "助教月报", "汇总助教月度业绩", "助教", "DWS"),
TaskDefinition("DWS_ASSISTANT_CUSTOMER", "助教客户分析", "汇总助教-客户关系", "助教", "DWS"),
TaskDefinition("DWS_ASSISTANT_SALARY", "助教工资计算", "计算助教工资", "助教", "DWS"),
TaskDefinition("DWS_ASSISTANT_FINANCE", "助教财务汇总", "汇总助教财务数据", "助教", "DWS"),
TaskDefinition("DWS_MEMBER_CONSUMPTION", "会员消费分析", "汇总会员消费数据", "会员", "DWS"),
TaskDefinition("DWS_MEMBER_VISIT", "会员到店分析", "汇总会员到店频次", "会员", "DWS"),
TaskDefinition("DWS_FINANCE_DAILY", "财务日报", "汇总每日财务数据", "财务", "DWS"),
TaskDefinition("DWS_FINANCE_RECHARGE", "充值汇总", "汇总充值数据", "财务", "DWS"),
TaskDefinition("DWS_FINANCE_INCOME_STRUCTURE", "收入结构", "分析收入结构", "财务", "DWS"),
TaskDefinition("DWS_FINANCE_DISCOUNT_DETAIL", "折扣明细", "汇总折扣明细", "财务", "DWS"),
# CHANGE [2026-02-19] intent: 同步 ETL 侧合并——原 DWS_RETENTION_CLEANUP / DWS_MV_REFRESH_* 已合并为 DWS_MAINTENANCE
TaskDefinition("DWS_MAINTENANCE", "DWS 维护", "刷新物化视图 + 清理过期留存数据", "通用", "DWS", requires_window=False, is_common=False),
]
# ── INDEX 任务定义 ────────────────────────────────────────────
INDEX_TASKS: list[TaskDefinition] = [
TaskDefinition("DWS_WINBACK_INDEX", "回流指数 (WBI)", "计算会员回流指数", "指数", "INDEX"),
TaskDefinition("DWS_NEWCONV_INDEX", "新客转化指数 (NCI)", "计算新客转化指数", "指数", "INDEX"),
TaskDefinition("DWS_ML_MANUAL_IMPORT", "手动导入 (ML)", "手动导入机器学习数据", "指数", "INDEX", requires_window=False, is_common=False),
TaskDefinition("DWS_RELATION_INDEX", "关系指数 (RS)", "计算助教-客户关系指数", "指数", "INDEX"),
]
# ── 工具类任务定义 ────────────────────────────────────────────
UTILITY_TASKS: list[TaskDefinition] = [
TaskDefinition("MANUAL_INGEST", "手动导入", "从本地 JSON 文件手动导入数据", "工具", "UTILITY", requires_window=False, is_common=False),
TaskDefinition("INIT_ODS_SCHEMA", "初始化 ODS Schema", "创建 ODS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
TaskDefinition("INIT_DWD_SCHEMA", "初始化 DWD Schema", "创建 DWD 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
TaskDefinition("INIT_DWS_SCHEMA", "初始化 DWS Schema", "创建 DWS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
TaskDefinition("ODS_JSON_ARCHIVE", "ODS JSON 归档", "归档 ODS 原始 JSON 文件", "工具", "UTILITY", requires_window=False, is_common=False),
TaskDefinition("CHECK_CUTOFF", "游标检查", "检查各任务数据游标截止点", "工具", "UTILITY", requires_window=False, is_common=False),
TaskDefinition("SEED_DWS_CONFIG", "DWS 配置种子", "初始化 DWS 配置数据", "工具", "UTILITY", requires_window=False, is_common=False),
TaskDefinition("DATA_INTEGRITY_CHECK", "数据完整性校验", "校验跨层数据完整性", "工具", "UTILITY", requires_window=False, is_common=False),
]
# ── 全量任务列表 ──────────────────────────────────────────────
ALL_TASKS: list[TaskDefinition] = ODS_TASKS + DWD_TASKS + DWS_TASKS + INDEX_TASKS + UTILITY_TASKS
# 按 code 索引,便于快速查找
_TASK_BY_CODE: dict[str, TaskDefinition] = {t.code: t for t in ALL_TASKS}
def get_all_tasks() -> list[TaskDefinition]:
    """Return the full task registry (every layer, in declaration order)."""
    return ALL_TASKS
def get_task_by_code(code: str) -> TaskDefinition | None:
    """Look up a task definition by its code (case-insensitive)."""
    normalized = code.upper()
    return _TASK_BY_CODE.get(normalized)
def get_tasks_grouped_by_domain() -> dict[str, list[TaskDefinition]]:
    """Group every registered task by its business domain."""
    grouped: dict[str, list[TaskDefinition]] = {}
    for task in ALL_TASKS:
        bucket = grouped.setdefault(task.domain, [])
        bucket.append(task)
    return grouped
def get_tasks_by_layer(layer: str) -> list[TaskDefinition]:
    """Return all tasks belonging to the given layer (case-insensitive)."""
    wanted = layer.upper()
    return [task for task in ALL_TASKS if task.layer == wanted]
# ── Flow → layer mapping ──────────────────────────────────────────────
# Layers covered by each flow; used by the frontend to filter selectable tasks.
FLOW_LAYER_MAP: dict[str, list[str]] = {
    "api_ods": ["ODS"],
    "api_ods_dwd": ["ODS", "DWD"],
    "api_full": ["ODS", "DWD", "DWS", "INDEX"],
    "ods_dwd": ["DWD"],
    "dwd_dws": ["DWS"],
    "dwd_dws_index": ["DWS", "INDEX"],
    "dwd_index": ["INDEX"],
}
def get_compatible_tasks(flow_id: str) -> list[TaskDefinition]:
    """Return the tasks whose layer is covered by the given flow."""
    allowed_layers = set(FLOW_LAYER_MAP.get(flow_id, []))
    return [task for task in ALL_TASKS if task.layer in allowed_layers]
# ── DWD table definitions ─────────────────────────────────────────────
DWD_TABLES: list[DwdTableDefinition] = [
    # Dimension tables
    DwdTableDefinition("dwd.dim_site", "门店维度", "台桌", "ods.table_fee_transactions", is_dimension=True),
    DwdTableDefinition("dwd.dim_site_ex", "门店维度(扩展)", "台桌", "ods.table_fee_transactions", is_dimension=True),
    DwdTableDefinition("dwd.dim_table", "台桌维度", "台桌", "ods.site_tables_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_table_ex", "台桌维度(扩展)", "台桌", "ods.site_tables_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_assistant", "助教维度", "助教", "ods.assistant_accounts_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_assistant_ex", "助教维度(扩展)", "助教", "ods.assistant_accounts_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_member", "会员维度", "会员", "ods.member_profiles", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_ex", "会员维度(扩展)", "会员", "ods.member_profiles", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_card_account", "会员储值卡维度", "会员", "ods.member_stored_value_cards", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_card_account_ex", "会员储值卡维度(扩展)", "会员", "ods.member_stored_value_cards", is_dimension=True),
    DwdTableDefinition("dwd.dim_tenant_goods", "租户商品维度", "商品", "ods.tenant_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_tenant_goods_ex", "租户商品维度(扩展)", "商品", "ods.tenant_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_store_goods", "门店商品维度", "商品", "ods.store_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_store_goods_ex", "门店商品维度(扩展)", "商品", "ods.store_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_goods_category", "商品分类维度", "商品", "ods.stock_goods_category_tree", is_dimension=True),
    DwdTableDefinition("dwd.dim_groupbuy_package", "团购套餐维度", "团购", "ods.group_buy_packages", is_dimension=True),
    DwdTableDefinition("dwd.dim_groupbuy_package_ex", "团购套餐维度(扩展)", "团购", "ods.group_buy_packages", is_dimension=True),
    # Fact tables
    DwdTableDefinition("dwd.dwd_settlement_head", "结算主表", "结算", "ods.settlement_records"),
    DwdTableDefinition("dwd.dwd_settlement_head_ex", "结算主表(扩展)", "结算", "ods.settlement_records"),
    DwdTableDefinition("dwd.dwd_table_fee_log", "台费流水", "台桌", "ods.table_fee_transactions"),
    DwdTableDefinition("dwd.dwd_table_fee_log_ex", "台费流水(扩展)", "台桌", "ods.table_fee_transactions"),
    DwdTableDefinition("dwd.dwd_table_fee_adjust", "台费折扣", "台桌", "ods.table_fee_discount_records"),
    DwdTableDefinition("dwd.dwd_table_fee_adjust_ex", "台费折扣(扩展)", "台桌", "ods.table_fee_discount_records"),
    DwdTableDefinition("dwd.dwd_store_goods_sale", "商品销售", "商品", "ods.store_goods_sales_records"),
    DwdTableDefinition("dwd.dwd_store_goods_sale_ex", "商品销售(扩展)", "商品", "ods.store_goods_sales_records"),
    DwdTableDefinition("dwd.dwd_assistant_service_log", "助教服务流水", "助教", "ods.assistant_service_records"),
    DwdTableDefinition("dwd.dwd_assistant_service_log_ex", "助教服务流水(扩展)", "助教", "ods.assistant_service_records"),
    DwdTableDefinition("dwd.dwd_assistant_trash_event", "助教取消事件", "助教", "ods.assistant_cancellation_records"),
    DwdTableDefinition("dwd.dwd_assistant_trash_event_ex", "助教取消事件(扩展)", "助教", "ods.assistant_cancellation_records"),
    DwdTableDefinition("dwd.dwd_member_balance_change", "会员余额变动", "会员", "ods.member_balance_changes"),
    DwdTableDefinition("dwd.dwd_member_balance_change_ex", "会员余额变动(扩展)", "会员", "ods.member_balance_changes"),
    DwdTableDefinition("dwd.dwd_groupbuy_redemption", "团购核销", "团购", "ods.group_buy_redemption_records"),
    DwdTableDefinition("dwd.dwd_groupbuy_redemption_ex", "团购核销(扩展)", "团购", "ods.group_buy_redemption_records"),
    DwdTableDefinition("dwd.dwd_platform_coupon_redemption", "平台券核销", "团购", "ods.platform_coupon_redemption_records"),
    DwdTableDefinition("dwd.dwd_platform_coupon_redemption_ex", "平台券核销(扩展)", "团购", "ods.platform_coupon_redemption_records"),
    DwdTableDefinition("dwd.dwd_recharge_order", "充值订单", "会员", "ods.recharge_settlements"),
    DwdTableDefinition("dwd.dwd_recharge_order_ex", "充值订单(扩展)", "会员", "ods.recharge_settlements"),
    DwdTableDefinition("dwd.dwd_payment", "支付流水", "结算", "ods.payment_transactions"),
    DwdTableDefinition("dwd.dwd_refund", "退款流水", "结算", "ods.refund_transactions"),
    DwdTableDefinition("dwd.dwd_refund_ex", "退款流水(扩展)", "结算", "ods.refund_transactions"),
]
def get_dwd_tables_grouped_by_domain() -> dict[str, list[DwdTableDefinition]]:
    """Group the DWD table definitions by business domain."""
    grouped: dict[str, list[DwdTableDefinition]] = {}
    for table in DWD_TABLES:
        grouped.setdefault(table.domain, []).append(table)
    return grouped

View File

View File

@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""WebSocket 日志推送端点
提供 WS /ws/logs/{execution_id} 端点,实时推送 ETL 任务执行日志。
客户端连接后,先发送已有的历史日志行,再实时推送新日志,
直到执行结束(收到 None 哨兵)或客户端断开。
设计要点:
- 利用 TaskExecutor 已有的 subscribe/unsubscribe 机制
- 连接时先回放内存缓冲区中的历史日志,避免丢失已产生的行
- 通过 asyncio.Queue 接收实时日志None 表示执行结束
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from ..services.task_executor import task_executor
# Module-level logger.
logger = logging.getLogger(__name__)
# Router exposing the real-time log WebSocket endpoint; mounted by the app.
ws_router = APIRouter()
@ws_router.websocket("/ws/logs/{execution_id}")
async def ws_logs(websocket: WebSocket, execution_id: str) -> None:
    """Stream execution logs for ``execution_id`` over WebSocket.

    Flow:
        1. Accept the WebSocket connection.
        2. Replay log lines already held in the in-memory buffer.
        3. Keep pushing new lines from the TaskExecutor subscription.
        4. Close on the ``None`` sentinel (execution finished) or when
           the client disconnects.
    """
    await websocket.accept()
    logger.info("WebSocket 连接已建立: execution_id=%s", execution_id)
    # Subscribe before replaying so no line is lost in between.
    # NOTE(review): a line produced between subscribe() and get_logs() ends up
    # both in the replayed buffer and in the queue, so clients may see
    # occasional duplicate lines — confirm whether that is acceptable.
    queue = task_executor.subscribe(execution_id)
    try:
        # Replay the history already captured in the in-memory buffer.
        for line in task_executor.get_logs(execution_id):
            await websocket.send_text(line)
        # Even if the task is no longer running, keep the connection open:
        # the queue may still hold messages produced just before it finished.
        while True:
            msg = await queue.get()
            if msg is None:
                # End-of-execution sentinel.
                break
            await websocket.send_text(msg)
    except WebSocketDisconnect:
        logger.info("WebSocket 客户端断开: execution_id=%s", execution_id)
    except Exception:
        logger.exception("WebSocket 异常: execution_id=%s", execution_id)
    finally:
        task_executor.unsubscribe(execution_id, queue)
        # Close defensively — the client may already be gone; ignore errors.
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket 连接已清理: execution_id=%s", execution_id)