在准备环境前提交此次全部更改。
This commit is contained in:
48
apps/backend/.env.local
Normal file
48
apps/backend/.env.local
Normal file
@@ -0,0 +1,48 @@
|
||||
# ==============================================================================
|
||||
# NeoZQYY 后端 .env.local — 私有覆盖层
|
||||
# ==============================================================================
|
||||
# 后端 config.py 以 override=True 加载此文件,优先级高于根 .env
|
||||
# 敏感值禁止提交;本文件已在 .gitignore 中排除
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 业务数据库(zqyy_app)
|
||||
# ------------------------------------------------------------------------------
|
||||
# DB_HOST / DB_PORT / DB_USER / DB_PASSWORD 继承自根 .env,无需重复
|
||||
# CHANGE 2026-02-15 | 默认指向测试库,生产环境切换为 zqyy_app
|
||||
APP_DB_NAME=test_zqyy_app
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# ETL 数据库(后端只读访问,用于数据库查看器)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 与 zqyy_app 同实例时可省略 ETL_DB_HOST/PORT/USER/PASSWORD,自动复用
|
||||
# CHANGE 2026-02-15 | 默认指向测试库,生产环境切换为 etl_feiqiu
|
||||
ETL_DB_NAME=test_etl_feiqiu
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# JWT 认证
|
||||
# ------------------------------------------------------------------------------
|
||||
JWT_SECRET_KEY=change-me-in-production
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# CORS(逗号分隔)
|
||||
# ------------------------------------------------------------------------------
|
||||
CORS_ORIGINS=http://localhost:5173
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 微信消息推送(与微信后台填写的 Token 一致)
|
||||
# CHANGE 2026-02-19 | 新增微信消息推送回调 Token
|
||||
# ------------------------------------------------------------------------------
|
||||
WX_CALLBACK_TOKEN=LLZQwx2026push
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 通用
|
||||
# ------------------------------------------------------------------------------
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# ETL 项目路径(子进程 cwd,缺省按 monorepo 相对路径推算)
|
||||
# ------------------------------------------------------------------------------
|
||||
# ETL_PROJECT_PATH=C:/NeoZQYY/apps/etl/connectors/feiqiu
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
## 内部结构
|
||||
|
||||
```
|
||||
apps/backend/
|
||||
├── app/
|
||||
│ ├── main.py # FastAPI 入口,启用 OpenAPI 文档
|
||||
@@ -16,21 +16,22 @@ apps/backend/
|
||||
├── tests/ # 后端测试
|
||||
├── pyproject.toml # 依赖声明
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## 启动
|
||||
|
||||
```bash
|
||||
# 确保已在根目录执行 uv sync --all-packages
|
||||
cd apps/backend
|
||||
uv run uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload
```
|
||||
|
||||
API 文档自动生成于 http://localhost:8000/docs
|
||||
|
||||
## 依赖
|
||||
|
||||
- fastapi>=0.115, uvicorn[standard]>=0.34
- psycopg2-binary>=2.9, python-dotenv>=1.0
|
||||
- neozqyy-shared(workspace 引用)
|
||||
|
||||
## Roadmap
|
||||
|
||||
1
apps/backend/app/auth/__init__.py
Normal file
1
apps/backend/app/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""认证模块:JWT 令牌管理与 FastAPI 依赖注入。"""
|
||||
67
apps/backend/app/auth/dependencies.py
Normal file
67
apps/backend/app/auth/dependencies.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
FastAPI 依赖注入:从 JWT 提取当前用户信息。
|
||||
|
||||
用法:
|
||||
@router.get("/protected")
|
||||
async def protected_endpoint(user: CurrentUser = Depends(get_current_user)):
|
||||
print(user.user_id, user.site_id)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError
|
||||
|
||||
from app.auth.jwt import decode_access_token
|
||||
|
||||
# Bearer token 提取器
|
||||
_bearer_scheme = HTTPBearer(auto_error=True)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CurrentUser:
    """Immutable per-request user context decoded from a JWT."""

    # Primary key of the admin user (from the JWT "sub" claim).
    user_id: int
    # Store the user belongs to (from the JWT "site_id" claim).
    site_id: int
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme),
) -> CurrentUser:
    """
    FastAPI dependency: extract the JWT from the Authorization header,
    validate it, and return the decoded user context.

    Raises HTTP 401 (with a ``WWW-Authenticate: Bearer`` header) when the
    token is invalid, lacks required claims, or carries malformed ids.
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every failure mode shares the same 401 shape; build it in one place.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        payload = decode_access_token(credentials.credentials)
    except JWTError as exc:
        # Chain the cause so logs keep the underlying jose error.
        raise _unauthorized("无效的令牌") from exc

    user_id_raw = payload.get("sub")
    site_id_raw = payload.get("site_id")

    if user_id_raw is None or site_id_raw is None:
        raise _unauthorized("令牌缺少必要字段")

    try:
        user_id = int(user_id_raw)
        # BUGFIX: coerce site_id as well — CurrentUser declares site_id: int,
        # but the original passed the raw claim through unconverted.
        site_id = int(site_id_raw)
    except (TypeError, ValueError):
        raise _unauthorized("令牌中 user_id 格式无效") from None

    return CurrentUser(user_id=user_id, site_id=site_id)
|
||||
112
apps/backend/app/auth/jwt.py
Normal file
112
apps/backend/app/auth/jwt.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
JWT 令牌生成、验证与解码。
|
||||
|
||||
- access_token:短期有效(默认 30 分钟),用于 API 请求认证
|
||||
- refresh_token:长期有效(默认 7 天),用于刷新 access_token
|
||||
- payload 包含 user_id、site_id、令牌类型(access / refresh)
|
||||
- 密码哈希直接使用 bcrypt 库(passlib 与 bcrypt>=4.1 存在兼容性问题)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app import config
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the stored bcrypt hash."""
    plain_bytes = plain_password.encode("utf-8")
    hashed_bytes = hashed_password.encode("utf-8")
    return bcrypt.checkpw(plain_bytes, hashed_bytes)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt."""
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")
|
||||
|
||||
|
||||
def create_access_token(user_id: int, site_id: int) -> str:
    """
    Issue a short-lived access token.

    Payload claims: sub=str(user_id), site_id, type="access", exp.
    """
    now = datetime.now(timezone.utc)
    lifetime = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "sub": str(user_id),
        "site_id": site_id,
        "type": "access",
        "exp": now + lifetime,
    }
    return jwt.encode(claims, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: int, site_id: int) -> str:
    """
    Issue a long-lived refresh token.

    Payload claims: sub=str(user_id), site_id, type="refresh", exp.
    """
    now = datetime.now(timezone.utc)
    lifetime = timedelta(days=config.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
    claims = {
        "sub": str(user_id),
        "site_id": site_id,
        "type": "refresh",
        "exp": now + lifetime,
    }
    return jwt.encode(claims, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_token_pair(user_id: int, site_id: int) -> dict[str, str]:
    """Issue an access_token + refresh_token pair for the given user/site."""
    access = create_access_token(user_id, site_id)
    refresh = create_refresh_token(user_id, site_id)
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
    }
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
    """
    Decode and verify a JWT.

    Returns the payload dict (sub, site_id, type, exp).
    Raises JWTError when the token is invalid or expired.
    """
    # jwt.decode already raises JWTError on invalid/expired tokens; the
    # original's ``try/except JWTError: raise`` wrapper was a no-op.
    return jwt.decode(
        token, config.JWT_SECRET_KEY, algorithms=[config.JWT_ALGORITHM]
    )
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict:
    """
    Decode *token* and require it to be an access token.

    Raises JWTError when the "type" claim is not "access".
    """
    claims = decode_token(token)
    if claims.get("type") != "access":
        raise JWTError("令牌类型不是 access")
    return claims
|
||||
|
||||
|
||||
def decode_refresh_token(token: str) -> dict:
    """
    Decode *token* and require it to be a refresh token.

    Raises JWTError when the "type" claim is not "refresh".
    """
    claims = decode_token(token)
    if claims.get("type") != "refresh":
        raise JWTError("令牌类型不是 refresh")
    return claims
|
||||
@@ -29,7 +29,37 @@ DB_HOST: str = get("DB_HOST", "localhost")
|
||||
DB_PORT: str = get("DB_PORT", "5432")
DB_USER: str = get("DB_USER", "")
DB_PASSWORD: str = get("DB_PASSWORD", "")
# CHANGE 2026-02-15 | defaults to the test database; production overrides via .env.
# FIX: the stale pre-change assignment (default "zqyy_app") was left in place
# above its replacement, redefining APP_DB_NAME; only the new default is kept.
APP_DB_NAME: str = get("APP_DB_NAME", "test_zqyy_app")

# ---- JWT auth ----
JWT_SECRET_KEY: str = get("JWT_SECRET_KEY", "")  # must be set in production
JWT_ALGORITHM: str = get("JWT_ALGORITHM", "HS256")
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int(get("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = int(get("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "7"))

# ---- ETL database connection (each falls back to the zqyy_app parameters) ----
ETL_DB_HOST: str = get("ETL_DB_HOST") or DB_HOST
ETL_DB_PORT: str = get("ETL_DB_PORT") or DB_PORT
ETL_DB_USER: str = get("ETL_DB_USER") or DB_USER
ETL_DB_PASSWORD: str = get("ETL_DB_PASSWORD") or DB_PASSWORD
# CHANGE 2026-02-15 | defaults to the test database; production overrides via .env
ETL_DB_NAME: str = get("ETL_DB_NAME", "test_etl_feiqiu")

# ---- CORS ----
# Comma-separated allow-list; defaults to the Vite dev server.
CORS_ORIGINS: list[str] = [
    o.strip()
    for o in get("CORS_ORIGINS", "http://localhost:5173").split(",")
    if o.strip()
]

# ---- ETL project path ----
# Working directory (subprocess cwd) for the ETL CLI; when unset, derived
# from the monorepo-relative location of this file.
ETL_PROJECT_PATH: str = get(
    "ETL_PROJECT_PATH",
    str(Path(__file__).resolve().parents[3] / "apps" / "etl" / "connectors" / "feiqiu"),
)

# ---- General ----
TIMEZONE: str = get("TIMEZONE", "Asia/Shanghai")
@@ -1,14 +1,30 @@
|
||||
"""
|
||||
数据库连接
|
||||
|
||||
使用 psycopg2 直连 PostgreSQL,不引入 ORM。
|
||||
连接参数从环境变量读取(经 config 模块加载)。
|
||||
|
||||
提供两类连接:
|
||||
- get_connection():zqyy_app 读写连接(用户/队列/调度等业务数据)
|
||||
- get_etl_readonly_connection(site_id):etl_feiqiu 只读连接(数据库查看器),
|
||||
自动设置 RLS site_id 隔离
|
||||
"""
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
|
||||
from app.config import APP_DB_NAME, DB_HOST, DB_PASSWORD, DB_PORT, DB_USER
|
||||
from app.config import (
|
||||
APP_DB_NAME,
|
||||
DB_HOST,
|
||||
DB_PASSWORD,
|
||||
DB_PORT,
|
||||
DB_USER,
|
||||
ETL_DB_HOST,
|
||||
ETL_DB_NAME,
|
||||
ETL_DB_PASSWORD,
|
||||
ETL_DB_PORT,
|
||||
ETL_DB_USER,
|
||||
)
|
||||
|
||||
|
||||
def get_connection() -> PgConnection:
|
||||
@@ -24,3 +40,43 @@ def get_connection() -> PgConnection:
|
||||
password=DB_PASSWORD,
|
||||
dbname=APP_DB_NAME,
|
||||
)
|
||||
|
||||
|
||||
def get_etl_readonly_connection(site_id: int | str) -> PgConnection:
    """
    Open a read-only connection to the ETL database (etl_feiqiu).

    After connecting, this function executes:
      1. SET default_transaction_read_only = on  — session-wide write ban
      2. SET app.current_site_id = '<site_id>'   — enables RLS store isolation

    BUGFIX: the original used ``SET LOCAL app.current_site_id`` and then
    called ``conn.commit()``.  Per PostgreSQL semantics, SET LOCAL only lasts
    until the end of the current transaction, so the commit immediately
    discarded the setting and every subsequent query on this connection ran
    WITHOUT the RLS site filter.  A plain session-level SET persists for the
    connection's lifetime.

    The caller is responsible for closing the connection. Typical usage::

        conn = get_etl_readonly_connection(site_id)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT ...")
        finally:
            conn.close()
    """
    conn = psycopg2.connect(
        host=ETL_DB_HOST,
        port=ETL_DB_PORT,
        user=ETL_DB_USER,
        password=ETL_DB_PASSWORD,
        dbname=ETL_DB_NAME,
    )
    try:
        conn.autocommit = False
        with conn.cursor() as cur:
            # Session-level read-only: blocks any write statement.
            cur.execute("SET default_transaction_read_only = on")
            # Session-level RLS isolation: set the current store id so it
            # survives transaction boundaries (see BUGFIX note above).
            cur.execute("SET app.current_site_id = %s", (str(site_id),))
        conn.commit()
    except Exception:
        # Never leak a half-configured connection.
        conn.close()
        raise
    return conn
|
||||
|
||||
@@ -1,20 +1,66 @@
|
||||
"""
|
||||
NeoZQYY 后端 API 入口
|
||||
|
||||
基于 FastAPI 构建,为管理后台和微信小程序提供 RESTful API。
|
||||
OpenAPI 文档自动生成于 /docs(Swagger UI)和 /redoc(ReDoc)。
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app import config
|
||||
# CHANGE 2026-02-19 | 新增 xcx_test 路由(MVP 验证)+ wx_callback 路由(微信消息推送)
|
||||
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback
|
||||
from app.services.scheduler import scheduler
|
||||
from app.services.task_queue import task_queue
|
||||
from app.ws.logs import ws_router
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start background services on boot and stop
    them gracefully on shutdown."""
    # Startup: bring up the worker queue and the scheduler before serving.
    task_queue.start()
    scheduler.start()
    try:
        yield
    finally:
        # Shutdown in reverse start order.  ROBUSTNESS FIX: run cleanup in a
        # ``finally`` so background services are stopped even when the app
        # body raises; the original skipped shutdown on error.
        await scheduler.stop()
        await task_queue.stop()
|
||||
|
||||
|
||||
# FIX: the FastAPI(...) call contained two ``description=`` keyword arguments
# (a stale pre-change line left beside its replacement), which is a
# SyntaxError; only the current description is kept.
app = FastAPI(
    title="NeoZQYY API",
    description="台球门店运营助手 — 后端 API(管理后台 + 微信小程序)",
    version="0.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# ---- CORS middleware ----
# Allowed origins come from CORS_ORIGINS; defaults to the Vite dev server
# (localhost:5173).
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---- Router registration ----
app.include_router(auth.router)
app.include_router(tasks.router)
app.include_router(execution.router)
app.include_router(schedules.router)
app.include_router(env_config.router)
app.include_router(db_viewer.router)
app.include_router(etl_status.router)
app.include_router(ws_router)
app.include_router(xcx_test.router)
app.include_router(wx_callback.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["系统"])
|
||||
async def health_check():
|
||||
|
||||
97
apps/backend/app/routers/auth.py
Normal file
97
apps/backend/app/routers/auth.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
认证路由:登录与令牌刷新。
|
||||
|
||||
- POST /api/auth/login — 验证用户名密码,返回 JWT 令牌对
|
||||
- POST /api/auth/refresh — 用刷新令牌换取新的访问令牌
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from jose import JWTError
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_access_token,
|
||||
create_token_pair,
|
||||
decode_refresh_token,
|
||||
verify_password,
|
||||
)
|
||||
from app.database import get_connection
|
||||
from app.schemas.auth import LoginRequest, RefreshRequest, TokenResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["认证"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest):
    """
    User login.

    Looks the user up in admin_users and checks the password; on success
    returns a JWT access/refresh token pair.
    - Unknown user or wrong password: 401
    - Disabled account (is_active=false): 401
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, password_hash, site_id, is_active "
                "FROM admin_users WHERE username = %s",
                (body.username,),
            )
            row = cur.fetchone()
    finally:
        conn.close()

    # Identical message for "no such user" and "wrong password" so the
    # response does not reveal which usernames exist.
    bad_credentials = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="用户名或密码错误",
    )

    if row is None:
        raise bad_credentials

    user_id, password_hash, site_id, is_active = row

    if not is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="账号已被禁用",
        )

    if not verify_password(body.password, password_hash):
        raise bad_credentials

    return TokenResponse(**create_token_pair(user_id, site_id))
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
async def refresh(body: RefreshRequest):
    """
    Refresh the access token.

    Validates the refresh token; on success only a new access_token is
    issued (the refresh token is returned unchanged and kept by the client).
    """
    try:
        claims = decode_refresh_token(body.refresh_token)
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的刷新令牌",
        )

    # New access token only; the refresh token is echoed back untouched.
    return TokenResponse(
        access_token=create_access_token(int(claims["sub"]), claims["site_id"]),
        refresh_token=body.refresh_token,
        token_type="bearer",
    )
|
||||
228
apps/backend/app/routers/db_viewer.py
Normal file
228
apps/backend/app/routers/db_viewer.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器 API
|
||||
|
||||
提供 4 个端点:
|
||||
- GET /api/db/schemas — 返回 Schema 列表
|
||||
- GET /api/db/schemas/{name}/tables — 返回表列表和行数
|
||||
- GET /api/db/tables/{schema}/{table}/columns — 返回列定义
|
||||
- POST /api/db/query — 只读 SQL 执行
|
||||
|
||||
所有端点需要 JWT 认证。
|
||||
使用 get_etl_readonly_connection(site_id) 确保 RLS 隔离。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from psycopg2 import errors as pg_errors, OperationalError
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_etl_readonly_connection
|
||||
from app.schemas.db_viewer import (
|
||||
ColumnInfo,
|
||||
QueryRequest,
|
||||
QueryResponse,
|
||||
SchemaInfo,
|
||||
TableInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/db", tags=["数据库查看器"])
|
||||
|
||||
# 写操作关键词(不区分大小写)
|
||||
_WRITE_KEYWORDS = re.compile(
|
||||
r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# 查询结果行数上限
|
||||
_MAX_ROWS = 1000
|
||||
|
||||
# 查询超时(秒)
|
||||
_QUERY_TIMEOUT_SEC = 30
|
||||
|
||||
|
||||
# ── GET /api/db/schemas ──────────────────────────────────────
|
||||
|
||||
@router.get("/schemas", response_model=list[SchemaInfo])
async def list_schemas(
    user: CurrentUser = Depends(get_current_user),
) -> list[SchemaInfo]:
    """List the schemas in the ETL database (system schemas excluded)."""
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
                ORDER BY schema_name
                """
            )
            return [SchemaInfo(name=schema) for (schema,) in cur.fetchall()]
    finally:
        conn.close()
|
||||
|
||||
|
||||
# ── GET /api/db/schemas/{name}/tables ────────────────────────
|
||||
|
||||
@router.get("/schemas/{name}/tables", response_model=list[TableInfo])
async def list_tables(
    name: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[TableInfo]:
    """Return every base table in the given schema with its live-row estimate."""
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT
                    t.table_name,
                    s.n_live_tup
                FROM information_schema.tables t
                LEFT JOIN pg_stat_user_tables s
                    ON s.schemaname = t.table_schema
                    AND s.relname = t.table_name
                WHERE t.table_schema = %s
                    AND t.table_type = 'BASE TABLE'
                ORDER BY t.table_name
                """,
                (name,),
            )
            return [
                TableInfo(name=table_name, row_count=live_rows)
                for table_name, live_rows in cur.fetchall()
            ]
    finally:
        conn.close()
|
||||
|
||||
|
||||
# ── GET /api/db/tables/{schema}/{table}/columns ──────────────
|
||||
|
||||
@router.get(
    "/tables/{schema}/{table}/columns",
    response_model=list[ColumnInfo],
)
async def list_columns(
    schema: str,
    table: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[ColumnInfo]:
    """Return the column definitions of the given table."""
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (schema, table),
            )
            return [
                ColumnInfo(
                    name=col_name,
                    data_type=col_type,
                    is_nullable=nullable == "YES",
                    column_default=default,
                )
                for col_name, col_type, nullable, default in cur.fetchall()
            ]
    finally:
        conn.close()
|
||||
|
||||
|
||||
# ── POST /api/db/query ───────────────────────────────────────
|
||||
|
||||
@router.post("/query", response_model=QueryResponse)
async def execute_query(
    body: QueryRequest,
    user: CurrentUser = Depends(get_current_user),
) -> QueryResponse:
    """Execute a read-only SQL statement.

    Safety measures:
    1. Reject statements containing write keywords
       (INSERT / UPDATE / DELETE / DROP / TRUNCATE)
    2. Cap the result at 1000 rows
    3. Apply a 30-second statement timeout

    NOTE(review): the keyword filter is best-effort — e.g. ALTER / CREATE /
    COPY are not listed; the real protection is the read-only connection.
    """
    sql = body.sql.strip()
    if not sql:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="SQL 语句不能为空",
        )

    # Reject obvious write statements up front.
    if _WRITE_KEYWORDS.search(sql):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只允许只读查询,禁止 INSERT / UPDATE / DELETE / DROP / TRUNCATE 操作",
        )

    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            # Transaction-scoped timeout; covers the user query below.
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{_QUERY_TIMEOUT_SEC}s",),
            )

            try:
                cur.execute(sql)
            except pg_errors.QueryCanceled as exc:
                # FIX: chain the cause (raise ... from) so logs retain the
                # underlying driver error; the original dropped it.
                raise HTTPException(
                    status_code=status.HTTP_408_REQUEST_TIMEOUT,
                    detail=f"查询超时(超过 {_QUERY_TIMEOUT_SEC} 秒)",
                ) from exc
            except Exception as exc:
                # Syntax or other execution error in the submitted SQL.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"SQL 执行错误: {exc}",
                ) from exc

            # Column names, when the statement produced a result set.
            columns = (
                [desc[0] for desc in cur.description]
                if cur.description
                else []
            )

            # Cap the result set; tuples become lists for JSON serialization.
            rows_list = [list(row) for row in cur.fetchmany(_MAX_ROWS)]

            return QueryResponse(
                columns=columns,
                rows=rows_list,
                row_count=len(rows_list),
            )
    except HTTPException:
        raise
    except OperationalError as exc:
        # Connection-level failure.
        logger.error("数据库查看器连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="数据库连接错误",
        ) from exc
    finally:
        conn.close()
|
||||
240
apps/backend/app/routers/env_config.py
Normal file
240
apps/backend/app/routers/env_config.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置 API
|
||||
|
||||
提供 3 个端点:
|
||||
- GET /api/env-config — 读取 .env,敏感值掩码
|
||||
- PUT /api/env-config — 验证并写入 .env
|
||||
- GET /api/env-config/export — 导出去敏感值的配置文件
|
||||
|
||||
所有端点需要 JWT 认证。
|
||||
敏感键判定:键名中包含 PASSWORD、TOKEN、SECRET、DSN(不区分大小写)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/env-config", tags=["环境配置"])
|
||||
|
||||
# .env 文件路径:项目根目录
|
||||
_ENV_PATH = Path(__file__).resolve().parents[3] / ".env"
|
||||
|
||||
# 敏感键关键词(不区分大小写)
|
||||
_SENSITIVE_KEYWORDS = ("PASSWORD", "TOKEN", "SECRET", "DSN")
|
||||
|
||||
_MASK = "****"
|
||||
|
||||
|
||||
# ── Pydantic 模型 ────────────────────────────────────────────
|
||||
|
||||
class EnvEntry(BaseModel):
    """A single environment-variable key/value pair."""
    # Variable name, e.g. "DB_HOST".
    key: str
    # Raw (or masked) value.
    value: str
|
||||
|
||||
|
||||
class EnvConfigResponse(BaseModel):
    """GET response: the list of key/value entries."""
    entries: list[EnvEntry]
|
||||
|
||||
|
||||
class EnvConfigUpdateRequest(BaseModel):
    """PUT request body: the list of key/value entries to apply."""
    entries: list[EnvEntry]
|
||||
|
||||
|
||||
# ── 工具函数 ─────────────────────────────────────────────────
|
||||
|
||||
def _is_sensitive(key: str) -> bool:
    """Return True when the key name looks like a secret (contains any of
    the _SENSITIVE_KEYWORDS markers, case-insensitively)."""
    key_upper = key.upper()
    return any(marker in key_upper for marker in _SENSITIVE_KEYWORDS)
|
||||
|
||||
|
||||
def _parse_env(content: str) -> list[dict]:
|
||||
"""解析 .env 文件内容,返回行级结构。
|
||||
|
||||
每行分为三种类型:
|
||||
- comment: 注释行或空行(原样保留)
|
||||
- entry: 键值对
|
||||
"""
|
||||
lines: list[dict] = []
|
||||
for raw_line in content.splitlines():
|
||||
stripped = raw_line.strip()
|
||||
if not stripped or stripped.startswith("#"):
|
||||
lines.append({"type": "comment", "raw": raw_line})
|
||||
else:
|
||||
# 支持 KEY=VALUE 和 KEY="VALUE" 格式
|
||||
match = re.match(r'^([A-Za-z_][A-Za-z0-9_]*)=(.*)', raw_line)
|
||||
if match:
|
||||
key = match.group(1)
|
||||
value = match.group(2).strip()
|
||||
# 去除引号包裹
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
|
||||
value = value[1:-1]
|
||||
lines.append({"type": "entry", "key": key, "value": value, "raw": raw_line})
|
||||
else:
|
||||
# 无法解析的行当作注释保留
|
||||
lines.append({"type": "comment", "raw": raw_line})
|
||||
return lines
|
||||
|
||||
|
||||
def _read_env_file(path: Path) -> str:
    """Return the .env file's text; raises 404 when it does not exist."""
    if path.exists():
        return path.read_text(encoding="utf-8")
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=".env 文件不存在",
    )
|
||||
|
||||
|
||||
def _write_env_file(path: Path, content: str) -> None:
    """Write *content* to the .env file; raises 500 on OS-level failure."""
    try:
        path.write_text(content, encoding="utf-8")
    except OSError as exc:
        logger.error("写入 .env 文件失败: %s", exc)
        # FIX: chain the cause so the original OSError is preserved in logs.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="写入 .env 文件失败",
        ) from exc
|
||||
|
||||
|
||||
def _validate_entries(entries: list[EnvEntry]) -> None:
    """Validate key format for every entry; raises 422 on the first bad one."""
    key_re = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
    for idx, entry in enumerate(entries):
        if not entry.key:
            raise HTTPException(
                status_code=422,
                detail=f"第 {idx + 1} 行:键名不能为空",
            )
        if not key_re.match(entry.key):
            raise HTTPException(
                status_code=422,
                detail=f"第 {idx + 1} 行:键名 '{entry.key}' 格式无效(仅允许字母、数字、下划线,且不能以数字开头)",
            )
|
||||
|
||||
|
||||
# ── GET /api/env-config — 读取 ───────────────────────────────
|
||||
|
||||
@router.get("", response_model=EnvConfigResponse)
async def get_env_config(
    user: CurrentUser = Depends(get_current_user),
) -> EnvConfigResponse:
    """Read the .env file; sensitive values are shown masked."""
    parsed = _parse_env(_read_env_file(_ENV_PATH))
    entries = [
        EnvEntry(
            key=line["key"],
            value=_MASK if _is_sensitive(line["key"]) else line["value"],
        )
        for line in parsed
        if line["type"] == "entry"
    ]
    return EnvConfigResponse(entries=entries)
|
||||
|
||||
|
||||
# ── PUT /api/env-config — 写入 ───────────────────────────────
|
||||
|
||||
@router.put("", response_model=EnvConfigResponse)
async def update_env_config(
    body: EnvConfigUpdateRequest,
    user: CurrentUser = Depends(get_current_user),
) -> EnvConfigResponse:
    """Validate the entries and write them to the .env file.

    Comment and blank lines in the original file are preserved.  Existing
    keys are updated in place; unknown keys are appended at the end.
    Entries whose value is the mask (****) are skipped so the original
    secret is kept.
    """
    _validate_entries(body.entries)

    # Parse the current file, if any.
    parsed = (
        _parse_env(_ENV_PATH.read_text(encoding="utf-8"))
        if _ENV_PATH.exists()
        else []
    )

    # Values to apply; masked entries are excluded so secrets survive
    # a read-modify-write round trip.
    updates = {e.key: e.value for e in body.entries if e.value != _MASK}

    existing_keys: set[str] = set()
    out: list[str] = []
    for line in parsed:
        if line["type"] != "entry":
            out.append(line["raw"])
            continue
        key = line["key"]
        existing_keys.add(key)
        # Updated value when supplied; otherwise keep the original raw line
        # (this also covers sensitive keys skipped via the mask).
        out.append(f"{key}={updates[key]}" if key in updates else line["raw"])

    # Append keys that were not present in the original file.
    for e in body.entries:
        if e.key not in existing_keys and e.value != _MASK:
            out.append(f"{e.key}={e.value}")

    new_content = "\n".join(out) + ("\n" if out else "")
    _write_env_file(_ENV_PATH, new_content)

    # Echo back the stored configuration with sensitive values masked.
    entries = [
        EnvEntry(
            key=line["key"],
            value=_MASK if _is_sensitive(line["key"]) else line["value"],
        )
        for line in _parse_env(new_content)
        if line["type"] == "entry"
    ]
    return EnvConfigResponse(entries=entries)
|
||||
|
||||
|
||||
# ── GET /api/env-config/export — 导出 ────────────────────────
|
||||
|
||||
@router.get("/export")
async def export_env_config(
    user: CurrentUser = Depends(get_current_user),
) -> PlainTextResponse:
    """Download the config file with sensitive values replaced by the mask."""
    parsed = _parse_env(_read_env_file(_ENV_PATH))

    out: list[str] = []
    for line in parsed:
        if line["type"] == "entry" and _is_sensitive(line["key"]):
            out.append(f"{line['key']}={_MASK}")
        else:
            # Comments and non-sensitive entries are exported verbatim.
            out.append(line["raw"])

    export_content = "\n".join(out) + ("\n" if out else "")
    return PlainTextResponse(
        content=export_content,
        media_type="text/plain",
        headers={"Content-Disposition": "attachment; filename=env-config.txt"},
    )
|
||||
134
apps/backend/app/routers/etl_status.py
Normal file
134
apps/backend/app/routers/etl_status.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 状态监控 API
|
||||
|
||||
提供 2 个端点:
|
||||
- GET /api/etl-status/cursors — 返回各任务的数据游标(最后抓取时间、记录数)
|
||||
- GET /api/etl-status/recent-runs — 返回最近 50 条任务执行记录
|
||||
|
||||
所有端点需要 JWT 认证。
|
||||
游标端点查询 ETL 数据库(meta.etl_cursor),
|
||||
执行记录端点查询 zqyy_app 数据库(task_execution_log)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from psycopg2 import OperationalError
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_connection, get_etl_readonly_connection
|
||||
from app.schemas.etl_status import CursorInfo, RecentRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/etl-status", tags=["ETL 状态"])
|
||||
|
||||
# 最近执行记录条数上限
|
||||
_RECENT_RUNS_LIMIT = 50
|
||||
|
||||
|
||||
# ── GET /api/etl-status/cursors ──────────────────────────────
|
||||
|
||||
@router.get("/cursors", response_model=list[CursorInfo])
async def list_cursors(
    user: CurrentUser = Depends(get_current_user),
) -> list[CursorInfo]:
    """Return the latest data cursor for each ODS table.

    Queries the meta.etl_cursor table in the ETL database.
    Returns an empty list (rather than an error) when the table does not exist.
    """
    # Read-only connection to this site's ETL database.
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            # CHANGE 2026-02-15 | Aligned with the new etl_feiqiu six-layer layout: etl_admin -> meta
            # Probe for the cursor table first so a fresh database yields [] instead of a SQL error.
            cur.execute(
                """
                SELECT EXISTS (
                    SELECT 1
                    FROM information_schema.tables
                    WHERE table_schema = 'meta'
                    AND table_name = 'etl_cursor'
                )
                """
            )
            exists = cur.fetchone()[0]
            if not exists:
                return []

            cur.execute(
                """
                SELECT task_code, last_fetch_time, record_count
                FROM meta.etl_cursor
                ORDER BY task_code
                """
            )
            rows = cur.fetchall()

            # CursorInfo declares last_fetch_time as str | None, so stringify here.
            return [
                CursorInfo(
                    task_code=row[0],
                    last_fetch_time=str(row[1]) if row[1] is not None else None,
                    record_count=row[2],
                )
                for row in rows
            ]
    except OperationalError as exc:
        # Connection-level failures become a uniform 500; details stay in the log.
        logger.error("ETL 游标查询连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="ETL 数据库连接错误",
        )
    finally:
        conn.close()
|
||||
|
||||
|
||||
# ── GET /api/etl-status/recent-runs ──────────────────────────
|
||||
|
||||
@router.get("/recent-runs", response_model=list[RecentRun])
async def list_recent_runs(
    user: CurrentUser = Depends(get_current_user),
) -> list[RecentRun]:
    """Return the most recent 50 task execution records.

    Queries the task_execution_log table in the zqyy_app database,
    filtered by site_id and ordered by started_at DESC.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, task_codes, status, started_at,
                       finished_at, duration_ms, exit_code
                FROM task_execution_log
                WHERE site_id = %s
                ORDER BY started_at DESC
                LIMIT %s
                """,
                (user.site_id, _RECENT_RUNS_LIMIT),
            )
            rows = cur.fetchall()

            # RecentRun declares timestamps as strings, so convert here.
            return [
                RecentRun(
                    id=str(row[0]),
                    task_codes=list(row[1]) if row[1] else [],
                    status=row[2],
                    started_at=str(row[3]),
                    finished_at=str(row[4]) if row[4] is not None else None,
                    duration_ms=row[5],
                    exit_code=row[6],
                )
                for row in rows
            ]
    except OperationalError as exc:
        logger.error("执行记录查询连接错误: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="数据库连接错误",
        )
    finally:
        conn.close()
|
||||
281
apps/backend/app/routers/execution.py
Normal file
281
apps/backend/app/routers/execution.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""执行与队列 API
|
||||
|
||||
提供 8 个端点:
|
||||
- POST /api/execution/run — 直接执行任务
|
||||
- GET /api/execution/queue — 获取当前队列(按 site_id 过滤)
|
||||
- POST /api/execution/queue — 添加到队列
|
||||
- PUT /api/execution/queue/reorder — 重排队列
|
||||
- DELETE /api/execution/queue/{id} — 删除队列任务
|
||||
- POST /api/execution/{id}/cancel — 取消执行中的任务
|
||||
- GET /api/execution/history — 执行历史(按 site_id 过滤)
|
||||
- GET /api/execution/{id}/logs — 获取历史日志
|
||||
|
||||
所有端点需要 JWT 认证,site_id 从 JWT 提取。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_connection
|
||||
from app.schemas.execution import (
|
||||
ExecutionHistoryItem,
|
||||
ExecutionLogsResponse,
|
||||
ExecutionRunResponse,
|
||||
QueueTaskResponse,
|
||||
ReorderRequest,
|
||||
)
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_executor import task_executor
|
||||
from app.services.task_queue import task_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/execution", tags=["任务执行"])
|
||||
|
||||
|
||||
# ── POST /api/execution/run — run a task directly ─────────────

# Strong references to in-flight background executions.
# asyncio.create_task() only gives the event loop a weak reference to the
# task, so without this set a running execution could be garbage-collected
# before it finishes (documented pitfall in the asyncio docs).
_background_tasks: set[asyncio.Task] = set()


@router.post("/run", response_model=ExecutionRunResponse)
async def run_task(
    config: TaskConfigSchema,
    user: CurrentUser = Depends(get_current_user),
) -> ExecutionRunResponse:
    """Execute a task directly (bypassing the queue).

    Injects store_id from the JWT, creates an execution_id, then starts
    the subprocess asynchronously without blocking the response.
    """
    # Never trust a caller-supplied store_id; always overwrite from the JWT.
    config = config.model_copy(update={"store_id": user.site_id})
    execution_id = str(uuid.uuid4())

    # Fire-and-forget: start the execution without blocking the response,
    # but hold a strong reference until the task completes.
    task = asyncio.create_task(
        task_executor.execute(
            config=config,
            execution_id=execution_id,
            site_id=user.site_id,
        )
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return ExecutionRunResponse(
        execution_id=execution_id,
        message="任务已提交执行",
    )
|
||||
|
||||
|
||||
# ── GET /api/execution/queue — 获取当前队列 ───────────────────
|
||||
|
||||
@router.get("/queue", response_model=list[QueueTaskResponse])
async def get_queue(
    user: CurrentUser = Depends(get_current_user),
) -> list[QueueTaskResponse]:
    """Return the pending execution queue for the current site."""
    responses: list[QueueTaskResponse] = []
    for item in task_queue.list_pending(user.site_id):
        responses.append(
            QueueTaskResponse(
                id=item.id,
                site_id=item.site_id,
                config=item.config,
                status=item.status,
                position=item.position,
                created_at=item.created_at,
                started_at=item.started_at,
                finished_at=item.finished_at,
                exit_code=item.exit_code,
                error_message=item.error_message,
            )
        )
    return responses
|
||||
|
||||
|
||||
# ── POST /api/execution/queue — 添加到队列 ───────────────────
|
||||
|
||||
@router.post("/queue", response_model=QueueTaskResponse, status_code=status.HTTP_201_CREATED)
async def enqueue_task(
    config: TaskConfigSchema,
    user: CurrentUser = Depends(get_current_user),
) -> QueueTaskResponse:
    """Add a task configuration to the execution queue."""
    # store_id always comes from the JWT, never from the request body.
    config = config.model_copy(update={"store_id": user.site_id})
    task_id = task_queue.enqueue(config, user.site_id)

    # Re-read the freshly created row so the response carries the
    # DB-assigned fields (position, created_at, ...).
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, site_id, config, status, position,
                       created_at, started_at, finished_at,
                       exit_code, error_message
                FROM task_queue WHERE id = %s
                """,
                (task_id,),
            )
            row = cur.fetchone()
            conn.commit()
    finally:
        conn.close()

    if row is None:
        # Should not happen right after a successful enqueue.
        raise HTTPException(status_code=500, detail="入队后查询失败")

    # jsonb may arrive as a dict (psycopg2 default) or as raw JSON text.
    config_data = row[2] if isinstance(row[2], dict) else json.loads(row[2])
    return QueueTaskResponse(
        id=str(row[0]),
        site_id=row[1],
        config=config_data,
        status=row[3],
        position=row[4],
        created_at=row[5],
        started_at=row[6],
        finished_at=row[7],
        exit_code=row[8],
        error_message=row[9],
    )
|
||||
|
||||
|
||||
# ── PUT /api/execution/queue/reorder — 重排队列 ──────────────
|
||||
|
||||
@router.put("/queue/reorder")
async def reorder_queue(
    body: ReorderRequest,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Move a queued task to a new position within the current site's queue."""
    task_queue.reorder(
        body.task_id,
        body.new_position,
        user.site_id,
    )
    return {"message": "队列已重排"}
|
||||
|
||||
|
||||
# ── DELETE /api/execution/queue/{id} — 删除队列任务 ──────────
|
||||
|
||||
@router.delete("/queue/{task_id}")
async def delete_queue_task(
    task_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Remove a task from the queue; only 'pending' rows may be removed."""
    removed = task_queue.delete(task_id, user.site_id)
    if removed:
        return {"message": "任务已从队列中删除"}
    # Either the id is unknown for this site or the task already started.
    raise HTTPException(
        status_code=status.HTTP_409_CONFLICT,
        detail="任务不存在或非待执行状态,无法删除",
    )
|
||||
|
||||
|
||||
# ── POST /api/execution/{id}/cancel — 取消执行 ──────────────
|
||||
|
||||
@router.post("/{execution_id}/cancel")
async def cancel_execution(
    execution_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Send a cancellation signal to a running execution."""
    ok = await task_executor.cancel(execution_id)
    if ok:
        return {"message": "已发送取消信号"}
    # Unknown id, or the execution already finished.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="执行任务不存在或已完成",
    )
|
||||
|
||||
|
||||
# ── GET /api/execution/history — 执行历史 ────────────────────
|
||||
|
||||
@router.get("/history", response_model=list[ExecutionHistoryItem])
async def get_execution_history(
    limit: int = Query(default=50, ge=1, le=200),
    user: CurrentUser = Depends(get_current_user),
) -> list[ExecutionHistoryItem]:
    """Return execution history for the current site, newest first.

    Args:
        limit: Maximum number of rows to return (1-200, default 50).
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, site_id, task_codes, status, started_at,
                       finished_at, exit_code, duration_ms, command, summary
                FROM task_execution_log
                WHERE site_id = %s
                ORDER BY started_at DESC
                LIMIT %s
                """,
                (user.site_id, limit),
            )
            rows = cur.fetchall()
            conn.commit()
    finally:
        conn.close()

    return [
        ExecutionHistoryItem(
            id=str(row[0]),
            site_id=row[1],
            task_codes=row[2] or [],
            status=row[3],
            started_at=row[4],
            finished_at=row[5],
            exit_code=row[6],
            duration_ms=row[7],
            command=row[8],
            summary=row[9],
        )
        for row in rows
    ]
|
||||
|
||||
|
||||
# ── GET /api/execution/{id}/logs — 获取历史日志 ──────────────
|
||||
|
||||
@router.get("/{execution_id}/logs", response_model=ExecutionLogsResponse)
async def get_execution_logs(
    execution_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> ExecutionLogsResponse:
    """Return the full log of one execution.

    Prefers the in-memory buffer (execution still running); otherwise
    reads the persisted log of a finished execution from the database.
    """
    # Running task: logs only exist in the executor's in-memory buffer.
    if task_executor.is_running(execution_id):
        lines = task_executor.get_logs(execution_id)
        return ExecutionLogsResponse(
            execution_id=execution_id,
            output_log="\n".join(lines) if lines else None,
        )

    # Finished task: read from the database.  The site_id filter prevents
    # reading another store's execution logs.
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT output_log, error_log
                FROM task_execution_log
                WHERE id = %s AND site_id = %s
                """,
                (execution_id, user.site_id),
            )
            row = cur.fetchone()
            conn.commit()
    finally:
        conn.close()

    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="执行记录不存在",
        )

    return ExecutionLogsResponse(
        execution_id=execution_id,
        output_log=row[0],
        error_log=row[1],
    )
|
||||
293
apps/backend/app/routers/schedules.py
Normal file
293
apps/backend/app/routers/schedules.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度任务 CRUD API
|
||||
|
||||
提供 5 个端点:
|
||||
- GET /api/schedules — 列表(按 site_id 过滤)
|
||||
- POST /api/schedules — 创建
|
||||
- PUT /api/schedules/{id} — 更新
|
||||
- DELETE /api/schedules/{id} — 删除
|
||||
- PATCH /api/schedules/{id}/toggle — 启用/禁用
|
||||
|
||||
所有端点需要 JWT 认证,site_id 从 JWT 提取。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_connection
|
||||
from app.schemas.schedules import (
|
||||
CreateScheduleRequest,
|
||||
ScheduleResponse,
|
||||
UpdateScheduleRequest,
|
||||
)
|
||||
from app.services.scheduler import calculate_next_run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/schedules", tags=["调度管理"])
|
||||
|
||||
|
||||
|
||||
def _row_to_response(row) -> ScheduleResponse:
    """Map a scheduled_tasks row (column order per _SELECT_COLS) to a ScheduleResponse."""
    # jsonb columns may arrive as dicts (psycopg2 default) or as raw JSON text.
    task_config = row[4] if isinstance(row[4], dict) else json.loads(row[4])
    schedule_config = row[5] if isinstance(row[5], dict) else json.loads(row[5])
    return ScheduleResponse(
        id=str(row[0]),
        site_id=row[1],
        name=row[2],
        task_codes=row[3] or [],
        task_config=task_config,
        schedule_config=schedule_config,
        enabled=row[6],
        last_run_at=row[7],
        next_run_at=row[8],
        run_count=row[9],
        last_status=row[10],
        created_at=row[11],
        updated_at=row[12],
    )
|
||||
|
||||
|
||||
# Column list reused by several endpoints.  Keep the order in sync with the
# positional indexing in _row_to_response.
_SELECT_COLS = """
    id, site_id, name, task_codes, task_config, schedule_config,
    enabled, last_run_at, next_run_at, run_count, last_status,
    created_at, updated_at
"""
|
||||
|
||||
|
||||
# ── GET /api/schedules — 列表 ────────────────────────────────
|
||||
|
||||
@router.get("", response_model=list[ScheduleResponse])
async def list_schedules(
    user: CurrentUser = Depends(get_current_user),
) -> list[ScheduleResponse]:
    """Return every scheduled task belonging to the current site."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # _SELECT_COLS is a trusted module constant; user data is bound as a parameter.
            cur.execute(
                f"SELECT {_SELECT_COLS} FROM scheduled_tasks WHERE site_id = %s ORDER BY created_at DESC",
                (user.site_id,),
            )
            rows = cur.fetchall()
            conn.commit()
    finally:
        conn.close()

    return [_row_to_response(row) for row in rows]
|
||||
|
||||
|
||||
# ── POST /api/schedules — 创建 ──────────────────────────────
|
||||
|
||||
@router.post("", response_model=ScheduleResponse, status_code=status.HTTP_201_CREATED)
async def create_schedule(
    body: CreateScheduleRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Create a scheduled task; next_run_at is computed automatically."""
    now = datetime.now(timezone.utc)
    next_run = calculate_next_run(body.schedule_config, now)

    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # The f-string only interpolates the trusted _SELECT_COLS constant;
            # every user-supplied value goes through bound parameters.
            cur.execute(
                f"""
                INSERT INTO scheduled_tasks
                    (site_id, name, task_codes, task_config, schedule_config, enabled, next_run_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                RETURNING {_SELECT_COLS}
                """,
                (
                    user.site_id,
                    body.name,
                    body.task_codes,
                    json.dumps(body.task_config),
                    body.schedule_config.model_dump_json(),
                    body.schedule_config.enabled,
                    next_run,
                ),
            )
            row = cur.fetchone()
            conn.commit()
    except Exception:
        # Roll back the INSERT before propagating.
        conn.rollback()
        raise
    finally:
        conn.close()

    return _row_to_response(row)
|
||||
|
||||
|
||||
# ── PUT /api/schedules/{id} — 更新 ──────────────────────────
|
||||
|
||||
@router.put("/{schedule_id}", response_model=ScheduleResponse)
async def update_schedule(
    schedule_id: str,
    body: UpdateScheduleRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Update a scheduled task; only the fields present in the request change.

    Raises:
        HTTPException 422: no updatable field was provided.
        HTTPException 404: the task does not exist for this site.
    """
    # Build the dynamic SET clause.  The clause itself only ever contains
    # fixed column names; all user values are bound as parameters.
    set_parts: list[str] = []
    params: list = []

    if body.name is not None:
        set_parts.append("name = %s")
        params.append(body.name)
    if body.task_codes is not None:
        set_parts.append("task_codes = %s")
        params.append(body.task_codes)
    if body.task_config is not None:
        set_parts.append("task_config = %s")
        params.append(json.dumps(body.task_config))
    if body.schedule_config is not None:
        set_parts.append("schedule_config = %s")
        params.append(body.schedule_config.model_dump_json())
        # A new schedule config implies recomputing next_run_at.
        now = datetime.now(timezone.utc)
        next_run = calculate_next_run(body.schedule_config, now)
        set_parts.append("next_run_at = %s")
        params.append(next_run)

    if not set_parts:
        raise HTTPException(
            status_code=422,
            detail="至少需要提供一个更新字段",
        )

    set_parts.append("updated_at = NOW()")
    set_clause = ", ".join(set_parts)
    # WHERE-clause parameters come last, matching the statement below.
    params.extend([schedule_id, user.site_id])

    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                f"""
                UPDATE scheduled_tasks
                SET {set_clause}
                WHERE id = %s AND site_id = %s
                RETURNING {_SELECT_COLS}
                """,
                params,
            )
            row = cur.fetchone()
            conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    # RETURNING yields no row when the id/site pair matched nothing.
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="调度任务不存在",
        )

    return _row_to_response(row)
|
||||
|
||||
|
||||
# ── DELETE /api/schedules/{id} — 删除 ────────────────────────
|
||||
|
||||
@router.delete("/{schedule_id}")
async def delete_schedule(
    schedule_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Delete a scheduled task owned by the current site."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "DELETE FROM scheduled_tasks WHERE id = %s AND site_id = %s",
                (schedule_id, user.site_id),
            )
            # rowcount tells us whether anything matched the id/site pair.
            deleted = cur.rowcount
            conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    if deleted == 0:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="调度任务不存在",
        )

    return {"message": "调度任务已删除"}
|
||||
|
||||
|
||||
# ── PATCH /api/schedules/{id}/toggle — 启用/禁用 ─────────────
|
||||
|
||||
@router.patch("/{schedule_id}/toggle", response_model=ScheduleResponse)
async def toggle_schedule(
    schedule_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Flip a scheduled task between enabled and disabled.

    Disabling sets next_run_at to NULL; enabling recomputes next_run_at
    from the stored schedule_config.
    """
    conn = get_connection()
    try:
        # First read the current state and the stored schedule config.
        with conn.cursor() as cur:
            cur.execute(
                "SELECT enabled, schedule_config FROM scheduled_tasks WHERE id = %s AND site_id = %s",
                (schedule_id, user.site_id),
            )
            row = cur.fetchone()

        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="调度任务不存在",
            )

        current_enabled = row[0]
        new_enabled = not current_enabled

        if new_enabled:
            # Enabling: recompute next_run_at from the stored config.
            # jsonb may arrive as a dict or as raw JSON text.
            schedule_config_raw = row[1] if isinstance(row[1], dict) else json.loads(row[1])
            from app.schemas.schedules import ScheduleConfigSchema
            schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
            now = datetime.now(timezone.utc)
            next_run = calculate_next_run(schedule_cfg, now)
        else:
            # Disabling: clear next_run_at.
            next_run = None

        with conn.cursor() as cur:
            cur.execute(
                f"""
                UPDATE scheduled_tasks
                SET enabled = %s, next_run_at = %s, updated_at = NOW()
                WHERE id = %s AND site_id = %s
                RETURNING {_SELECT_COLS}
                """,
                (new_enabled, next_run, schedule_id, user.site_id),
            )
            updated_row = cur.fetchone()
            conn.commit()
    except HTTPException:
        # Let the 404 through untouched; the generic handler below must not
        # convert it into a rollback + re-raise of something else.
        raise
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    return _row_to_response(updated_row)
|
||||
209
apps/backend/app/routers/tasks.py
Normal file
209
apps/backend/app/routers/tasks.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表 & 配置 API
|
||||
|
||||
提供 4 个端点:
|
||||
- GET /api/tasks/registry — 按业务域分组的任务列表
|
||||
- GET /api/tasks/dwd-tables — 按业务域分组的 DWD 表定义
|
||||
- GET /api/tasks/flows — 7 种 Flow + 3 种处理模式
|
||||
- POST /api/tasks/validate — 验证 TaskConfig 并返回 CLI 命令预览
|
||||
|
||||
所有端点需要 JWT 认证。validate 端点从 JWT 注入 store_id。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.config import ETL_PROJECT_PATH
|
||||
from app.schemas.tasks import (
|
||||
FlowDefinition,
|
||||
ProcessingModeDefinition,
|
||||
TaskConfigSchema,
|
||||
)
|
||||
from app.services.cli_builder import cli_builder
|
||||
from app.services.task_registry import (
|
||||
DWD_TABLES,
|
||||
FLOW_LAYER_MAP,
|
||||
get_dwd_tables_grouped_by_domain,
|
||||
get_tasks_grouped_by_domain,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["任务配置"])
|
||||
|
||||
|
||||
# ── 响应模型 ──────────────────────────────────────────────────
|
||||
|
||||
class TaskItem(BaseModel):
    """One task entry in the registry response."""
    code: str
    name: str
    description: str
    domain: str
    layer: str
    requires_window: bool
    is_ods: bool
    is_dimension: bool
    default_enabled: bool
    is_common: bool
|
||||
|
||||
|
||||
class DwdTableItem(BaseModel):
    """One DWD table entry, including its source ODS table."""
    table_name: str
    display_name: str
    domain: str
    ods_source: str
    is_dimension: bool
|
||||
|
||||
|
||||
class TaskRegistryResponse(BaseModel):
    """Task list grouped by business domain."""
    groups: dict[str, list[TaskItem]]
|
||||
|
||||
|
||||
class DwdTablesResponse(BaseModel):
    """DWD table definitions grouped by business domain."""
    groups: dict[str, list[DwdTableItem]]
|
||||
|
||||
|
||||
class FlowsResponse(BaseModel):
    """Flow definitions plus processing-mode definitions."""
    flows: list[FlowDefinition]
    processing_modes: list[ProcessingModeDefinition]
|
||||
|
||||
|
||||
class ValidateRequest(BaseModel):
    """Validation request body — reuses TaskConfigSchema; store_id is injected by the backend."""
    config: TaskConfigSchema
|
||||
|
||||
|
||||
class ValidateResponse(BaseModel):
    """Validation result plus CLI command preview."""
    valid: bool
    command: str
    command_args: list[str]
    errors: list[str]
|
||||
|
||||
|
||||
# ── Flow definitions (static) ─────────────────────────────────

# The 7 supported execution flows.  `layers` lists the warehouse layers each
# flow touches; the ids are validated against FLOW_LAYER_MAP in /validate.
FLOW_DEFINITIONS: list[FlowDefinition] = [
    FlowDefinition(id="api_ods", name="API → ODS", layers=["ODS"]),
    FlowDefinition(id="api_ods_dwd", name="API → ODS → DWD", layers=["ODS", "DWD"]),
    FlowDefinition(id="api_full", name="API → ODS → DWD → DWS汇总 → DWS指数", layers=["ODS", "DWD", "DWS", "INDEX"]),
    FlowDefinition(id="ods_dwd", name="ODS → DWD", layers=["DWD"]),
    FlowDefinition(id="dwd_dws", name="DWD → DWS汇总", layers=["DWS"]),
    FlowDefinition(id="dwd_dws_index", name="DWD → DWS汇总 → DWS指数", layers=["DWS", "INDEX"]),
    FlowDefinition(id="dwd_index", name="DWD → DWS指数", layers=["INDEX"]),
]

# The 3 processing modes exposed to the frontend.
PROCESSING_MODE_DEFINITIONS: list[ProcessingModeDefinition] = [
    ProcessingModeDefinition(id="increment_only", name="仅增量处理", description="只处理新增和变更的数据"),
    ProcessingModeDefinition(id="verify_only", name="仅校验修复", description="校验现有数据并修复不一致(可选'校验前从 API 获取')"),
    ProcessingModeDefinition(id="increment_verify", name="增量 + 校验修复", description="先增量处理,再校验并修复"),
]
|
||||
|
||||
|
||||
# ── 端点 ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/registry", response_model=TaskRegistryResponse)
async def get_task_registry(
    user: CurrentUser = Depends(get_current_user),
) -> TaskRegistryResponse:
    """Return the task list grouped by business domain."""
    grouped = get_tasks_grouped_by_domain()
    # Re-shape registry objects into API response items, keeping the grouping.
    return TaskRegistryResponse(
        groups={
            domain: [
                TaskItem(
                    code=t.code,
                    name=t.name,
                    description=t.description,
                    domain=t.domain,
                    layer=t.layer,
                    requires_window=t.requires_window,
                    is_ods=t.is_ods,
                    is_dimension=t.is_dimension,
                    default_enabled=t.default_enabled,
                    is_common=t.is_common,
                )
                for t in tasks
            ]
            for domain, tasks in grouped.items()
        }
    )
|
||||
|
||||
|
||||
@router.get("/dwd-tables", response_model=DwdTablesResponse)
async def get_dwd_tables(
    user: CurrentUser = Depends(get_current_user),
) -> DwdTablesResponse:
    """Return the DWD table definitions grouped by business domain."""
    grouped = get_dwd_tables_grouped_by_domain()
    return DwdTablesResponse(
        groups={
            domain: [
                DwdTableItem(
                    table_name=t.table_name,
                    display_name=t.display_name,
                    domain=t.domain,
                    ods_source=t.ods_source,
                    is_dimension=t.is_dimension,
                )
                for t in tables
            ]
            for domain, tables in grouped.items()
        }
    )
|
||||
|
||||
|
||||
@router.get("/flows", response_model=FlowsResponse)
async def get_flows(
    user: CurrentUser = Depends(get_current_user),
) -> FlowsResponse:
    """Return the 7 flow definitions and the 3 processing-mode definitions."""
    # Both lists are static module-level constants.
    response = FlowsResponse(
        flows=FLOW_DEFINITIONS,
        processing_modes=PROCESSING_MODE_DEFINITIONS,
    )
    return response
|
||||
|
||||
|
||||
@router.post("/validate", response_model=ValidateResponse)
async def validate_task_config(
    body: ValidateRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ValidateResponse:
    """Validate a TaskConfig and return the generated CLI command preview.

    store_id is injected from the JWT; the frontend never supplies it.
    """
    config = body.config.model_copy(update={"store_id": user.site_id})
    errors: list[str] = []

    # The flow id must be one of the known flows.
    if config.pipeline not in FLOW_LAYER_MAP:
        errors.append(f"无效的执行流程: {config.pipeline}")

    # The task list must not be empty.
    if not config.tasks:
        errors.append("任务列表不能为空")

    if errors:
        return ValidateResponse(
            valid=False,
            command="",
            command_args=[],
            errors=errors,
        )

    # Build both the argv form and the display string of the CLI command.
    cmd_args = cli_builder.build_command(config, ETL_PROJECT_PATH)
    cmd_str = cli_builder.build_command_string(config, ETL_PROJECT_PATH)

    return ValidateResponse(
        valid=True,
        command=cmd_str,
        command_args=cmd_args,
        errors=[],
    )
|
||||
104
apps/backend/app/routers/wx_callback.py
Normal file
104
apps/backend/app/routers/wx_callback.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-19 | Prompt: 配置微信消息推送 | 新增微信消息推送回调接口,支持 GET 验签 + POST 消息接收
|
||||
|
||||
"""
|
||||
微信消息推送回调接口
|
||||
|
||||
处理两类请求:
|
||||
1. GET — 微信服务器验证(配置时触发一次)
|
||||
2. POST — 接收微信推送的消息/事件
|
||||
|
||||
安全模式下需要解密消息体,当前先用明文模式跑通,后续切安全模式。
|
||||
"""
|
||||
|
||||
import hashlib
import hmac
import logging

from fastapi import APIRouter, Query, Request, Response

from app.config import get
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/wx", tags=["微信回调"])
|
||||
|
||||
# Token 从环境变量读取,与微信后台填写的一致
|
||||
# 放在 apps/backend/.env.local 中:WX_CALLBACK_TOKEN=你自定义的token
|
||||
WX_CALLBACK_TOKEN: str = get("WX_CALLBACK_TOKEN", "")
|
||||
|
||||
|
||||
def _check_signature(signature: str, timestamp: str, nonce: str) -> bool:
    """Verify that a request genuinely came from the WeChat server.

    Per the WeChat message-push spec: lexicographically sort
    [token, timestamp, nonce], concatenate, SHA-1 the result and compare
    it against ``signature``.

    Returns False (rather than raising) when the token is not configured,
    so every caller uniformly answers 403.
    """
    if not WX_CALLBACK_TOKEN:
        logger.error("WX_CALLBACK_TOKEN 未配置")
        return False

    items = sorted([WX_CALLBACK_TOKEN, timestamp, nonce])
    digest = hashlib.sha1("".join(items).encode("utf-8")).hexdigest()
    # Constant-time comparison: `signature` is attacker-controlled, and a
    # plain `==` would leak match length via timing.
    return hmac.compare_digest(digest, signature)
|
||||
|
||||
|
||||
@router.get("/callback")
async def verify(
    signature: str = Query(...),
    timestamp: str = Query(...),
    nonce: str = Query(...),
    echostr: str = Query(...),
):
    """WeChat server verification endpoint.

    WeChat sends a single GET when message push is configured; once the
    signature checks out, ``echostr`` must be echoed back verbatim.
    """
    if not _check_signature(signature, timestamp, nonce):
        logger.warning("微信回调验签失败: signature=%s", signature)
        return Response(content="signature mismatch", status_code=403)

    logger.info("微信回调验证通过")
    # Must return echostr as-is (plain text, never wrapped in JSON).
    return Response(content=echostr, media_type="text/plain")
|
||||
|
||||
|
||||
@router.post("/callback")
async def receive_message(
    request: Request,
    signature: str = Query(""),
    timestamp: str = Query(""),
    nonce: str = Query(""),
):
    """Receive a message/event pushed by WeChat.

    Currently plaintext mode: the body is parsed directly as JSON.
    AES decryption must be added when switching to safe mode.
    """
    # WeChat signs POSTs too; reject anything that fails verification.
    if not _check_signature(signature, timestamp, nonce):
        logger.warning("消息推送验签失败")
        return Response(content="signature mismatch", status_code=403)

    # Parse the message body.
    body = await request.body()
    content_type = request.headers.get("content-type", "")

    if "json" in content_type:
        import json
        try:
            data = json.loads(body)
        except json.JSONDecodeError:
            # Keep the raw payload so malformed pushes can still be inspected.
            data = {"raw": body.decode("utf-8", errors="replace")}
    else:
        # XML payloads are not parsed yet; keep the raw text.
        data = {"raw_xml": body.decode("utf-8", errors="replace")}

    logger.info("收到微信推送: MsgType=%s, Event=%s",
                data.get("MsgType", "?"), data.get("Event", "?"))

    # TODO: dispatch on MsgType/Event (customer-service messages, subscribe events, ...).
    # For now always acknowledge with the literal "success".
    return Response(content="success", media_type="text/plain")
|
||||
37
apps/backend/app/routers/xcx_test.py
Normal file
37
apps/backend/app/routers/xcx_test.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-19 | Prompt: 小程序 MVP 全链路验证 | 新增 /api/xcx-test 接口,从 test."xcx-test" 表读取 ti 列第一行
|
||||
|
||||
"""
|
||||
小程序 MVP 验证接口
|
||||
|
||||
从 test_zqyy_app 库的 test."xcx-test" 表读取数据,
|
||||
用于验证小程序 → 后端 → 数据库全链路连通性。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from app.database import get_connection
|
||||
|
||||
router = APIRouter(prefix="/api/xcx-test", tags=["小程序MVP"])
|
||||
|
||||
|
||||
@router.get("")
async def get_xcx_test():
    """Read the first row of column ti from test."xcx-test".

    Used for the mini-program MVP end-to-end check:
    mini program → API → DB → response.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # CHANGE 2026-02-19 | Read the xcx-test table in the test schema.
            # The table name contains a hyphen, so it must be double-quoted.
            cur.execute('SELECT ti FROM test."xcx-test" LIMIT 1')
            row = cur.fetchone()
    finally:
        conn.close()

    if row is None:
        raise HTTPException(status_code=404, detail="无数据")

    return {"ti": row[0]}
|
||||
30
apps/backend/app/schemas/auth.py
Normal file
30
apps/backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
认证相关 Pydantic 模型。
|
||||
|
||||
- LoginRequest:登录请求体
|
||||
- TokenResponse:令牌响应体
|
||||
- RefreshRequest:刷新令牌请求体
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Login request body."""

    username: str = Field(..., min_length=1, max_length=64, description="用户名")
    password: str = Field(..., min_length=1, description="密码")
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
"""刷新令牌请求。"""
|
||||
|
||||
refresh_token: str = Field(..., min_length=1, description="刷新令牌")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""令牌响应。"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
42
apps/backend/app/schemas/db_viewer.py
Normal file
42
apps/backend/app/schemas/db_viewer.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
"""Database-viewer Pydantic models.

Request/response models for schema browsing, table-structure inspection,
and raw SQL query execution.
"""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel


class SchemaInfo(BaseModel):
    """A database schema entry."""
    name: str


class TableInfo(BaseModel):
    """A table entry (with optional row-count statistics)."""
    name: str
    # None when no row count was computed for this table.
    row_count: int | None = None


class ColumnInfo(BaseModel):
    """A column definition."""
    name: str
    data_type: str
    is_nullable: bool
    # Textual column default; None when the column has no default.
    column_default: str | None = None


class QueryRequest(BaseModel):
    """A raw SQL query request."""
    sql: str


class QueryResponse(BaseModel):
    """A SQL query result set."""
    columns: list[str]
    rows: list[list[Any]]
    row_count: int
|
||||
27
apps/backend/app/schemas/etl_status.py
Normal file
27
apps/backend/app/schemas/etl_status.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
"""ETL status-monitoring Pydantic models.

Response models for cursor information and recent execution records.
"""

from __future__ import annotations

from pydantic import BaseModel


class CursorInfo(BaseModel):
    """ETL cursor info (last fetch state of a single task)."""
    task_code: str
    # Presumably an ISO timestamp string; None when no fetch is recorded — confirm upstream.
    last_fetch_time: str | None = None
    record_count: int | None = None


class RecentRun(BaseModel):
    """A recent execution record."""
    id: str
    task_codes: list[str]
    status: str
    started_at: str
    # None while the run is still in progress.
    finished_at: str | None = None
    duration_ms: int | None = None
    exit_code: int | None = None
|
||||
59
apps/backend/app/schemas/execution.py
Normal file
59
apps/backend/app/schemas/execution.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
"""Execution and queue Pydantic models.

Request/response serialization for the execution router.
"""

from __future__ import annotations

from datetime import datetime
from typing import Any

from pydantic import BaseModel


class ReorderRequest(BaseModel):
    """Queue-reorder request."""
    task_id: str
    new_position: int


class QueueTaskResponse(BaseModel):
    """Queued-task response."""
    id: str
    site_id: int
    config: dict[str, Any]
    status: str
    position: int
    created_at: datetime | None = None
    started_at: datetime | None = None
    finished_at: datetime | None = None
    exit_code: int | None = None
    error_message: str | None = None


class ExecutionRunResponse(BaseModel):
    """Response for a directly-started execution."""
    execution_id: str
    message: str


class ExecutionHistoryItem(BaseModel):
    """One execution-history record."""
    id: str
    site_id: int
    task_codes: list[str]
    status: str
    started_at: datetime
    # The following stay None while the run is in progress or unrecorded.
    finished_at: datetime | None = None
    exit_code: int | None = None
    duration_ms: int | None = None
    command: str | None = None
    summary: dict[str, Any] | None = None


class ExecutionLogsResponse(BaseModel):
    """Execution-logs response."""
    execution_id: str
    output_log: str | None = None
    error_log: str | None = None
|
||||
61
apps/backend/app/schemas/schedules.py
Normal file
61
apps/backend/app/schemas/schedules.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# -*- coding: utf-8 -*-
"""Schedule-configuration Pydantic models.

Defines ScheduleConfigSchema and related models used by the scheduler
service and the schedules routes.
"""

from datetime import datetime
from typing import Any, Literal

from pydantic import BaseModel


class ScheduleConfigSchema(BaseModel):
    """Schedule configuration — supports 5 schedule types.

    Only the fields relevant to the chosen schedule_type are consulted;
    the rest keep their defaults.
    """

    schedule_type: Literal["once", "interval", "daily", "weekly", "cron"]
    interval_value: int = 1
    interval_unit: Literal["minutes", "hours", "days"] = "hours"
    # HH:MM, consumed by the "daily" type.
    daily_time: str = "04:00"
    # ISO weekdays (1=Monday ... 7=Sunday), consumed by the "weekly" type.
    weekly_days: list[int] = [1]
    weekly_time: str = "04:00"
    # Consumed by the "cron" type; the scheduler parses only a simple 5-field subset.
    cron_expression: str = "0 4 * * *"
    enabled: bool = True
    # NOTE(review): start_date/end_date are not consulted by the current
    # scheduler loop — confirm their intended use.
    start_date: str | None = None
    end_date: str | None = None


class CreateScheduleRequest(BaseModel):
    """Create-schedule request."""

    name: str
    task_codes: list[str]
    task_config: dict[str, Any]
    schedule_config: ScheduleConfigSchema


class UpdateScheduleRequest(BaseModel):
    """Update-schedule request (all fields optional; None means unchanged)."""

    name: str | None = None
    task_codes: list[str] | None = None
    task_config: dict[str, Any] | None = None
    schedule_config: ScheduleConfigSchema | None = None


class ScheduleResponse(BaseModel):
    """Scheduled-task response."""

    id: str
    site_id: int
    name: str
    task_codes: list[str]
    task_config: dict[str, Any]
    schedule_config: dict[str, Any]
    enabled: bool
    last_run_at: datetime | None = None
    next_run_at: datetime | None = None
    run_count: int
    last_status: str | None = None
    created_at: datetime
    updated_at: datetime
|
||||
73
apps/backend/app/schemas/tasks.py
Normal file
73
apps/backend/app/schemas/tasks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# -*- coding: utf-8 -*-
"""Task-configuration Pydantic models.

Defines TaskConfigSchema and related models, used for frontend/backend
transport and consumed by CLIBuilder.
"""

from typing import Any

from pydantic import BaseModel, model_validator


class TaskConfigSchema(BaseModel):
    """Task configuration — frontend/backend transport format.

    Field → CLI-argument mapping:
    - pipeline → --pipeline (Flow ID, one of 7)
    - processing_mode → --processing-mode (one of 3 processing modes)
    - tasks → --tasks (comma-separated)
    - dry_run → --dry-run (boolean flag)
    - window_mode → chooses between lookback and custom time windows
      (frontend-side logic; not mapped to a CLI argument directly)
    - window_start → --window-start
    - window_end → --window-end
    - window_split → --window-split
    - window_split_days → --window-split-days
    - lookback_hours → --lookback-hours
    - overlap_seconds → --overlap-seconds
    - fetch_before_verify → --fetch-before-verify (boolean flag)
    - store_id → --store-id (injected by the backend from the JWT;
      not sent by the frontend)
    - dwd_only_tables → passed via extra_args or reserved for future use
    """

    tasks: list[str]
    pipeline: str = "api_ods_dwd"
    processing_mode: str = "increment_only"
    dry_run: bool = False
    window_mode: str = "lookback"
    window_start: str | None = None
    window_end: str | None = None
    window_split: str | None = None
    window_split_days: int | None = None
    lookback_hours: int = 24
    overlap_seconds: int = 600
    fetch_before_verify: bool = False
    skip_ods_when_fetch_before_verify: bool = False
    ods_use_local_json: bool = False
    store_id: int | None = None
    dwd_only_tables: list[str] | None = None
    force_full: bool = False
    # Pydantic deep-copies non-hashable defaults per instance, so this
    # mutable {} default is safe (unlike a plain function default).
    extra_args: dict[str, Any] = {}

    @model_validator(mode="after")
    def validate_window(self) -> "TaskConfigSchema":
        """Validate the time window: the end must not precede the start.

        Comparison is lexicographic on the strings, which is correct for
        ISO-8601-formatted timestamps.
        """
        if self.window_start and self.window_end:
            if self.window_end < self.window_start:
                raise ValueError("window_end 不能早于 window_start")
        return self


class FlowDefinition(BaseModel):
    """Execution flow (Flow) definition."""

    id: str
    name: str
    layers: list[str]


class ProcessingModeDefinition(BaseModel):
    """Processing-mode definition."""

    id: str
    name: str
    description: str
|
||||
1
apps/backend/app/services/__init__.py
Normal file
1
apps/backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
158
apps/backend/app/services/cli_builder.py
Normal file
158
apps/backend/app/services/cli_builder.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# -*- coding: utf-8 -*-
"""CLI command builder.

Migrated from gui/utils/cli_builder.py and adapted to the backend
TaskConfigSchema. Converts a TaskConfigSchema into an ETL CLI
argument list.

Supports:
- 7 flows (api_ods / api_ods_dwd / api_full / ods_dwd / dwd_dws / dwd_dws_index / dwd_index)
- 3 processing modes (increment_only / verify_only / increment_verify)
- automatic --store-id injection
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Annotation-only import: keeps this module import-light at runtime and
    # avoids a potential import cycle with the schemas package.
    from ..schemas.tasks import TaskConfigSchema

# Valid Flow IDs
VALID_FLOWS: set[str] = {
    "api_ods",
    "api_ods_dwd",
    "api_full",
    "ods_dwd",
    "dwd_dws",
    "dwd_dws_index",
    "dwd_index",
}

# Valid processing modes
VALID_PROCESSING_MODES: set[str] = {
    "increment_only",
    "verify_only",
    "increment_verify",
}

# extra_args keys the CLI supports (value-typed + boolean-typed)
CLI_SUPPORTED_ARGS: set[str] = {
    # value-typed arguments
    "pg_dsn", "pg_host", "pg_port", "pg_name",
    "pg_user", "pg_password", "api_base", "api_token", "api_timeout",
    "api_page_size", "api_retry_max",
    "export_root", "log_root", "fetch_root",
    "ingest_source", "idle_start", "idle_end",
    "data_source", "pipeline_flow",
    "window_split_unit",
    # boolean-typed arguments
    "force_window_override", "write_pretty_json", "allow_empty_advance",
}


class CLIBuilder:
    """Converts a TaskConfigSchema into an ETL CLI argument list."""

    def build_command(
        self,
        config: TaskConfigSchema,
        etl_project_path: str,
        python_executable: str = "python",
    ) -> list[str]:
        """Build the full CLI argument list.

        Produced shape:
        [python, -m, cli.main, --flow, {flow_id}, --tasks, ..., --store-id, {site_id}, ...]

        Args:
            config: Task configuration object (Pydantic model).
            etl_project_path: ETL project root (used as cwd by the caller;
                not embedded in the command itself).
            python_executable: Python executable path, defaults to "python".

        Returns:
            The command-line argument list.
        """
        cmd: list[str] = [python_executable, "-m", "cli.main"]

        # -- Flow (execution pipeline) --
        cmd.extend(["--flow", config.pipeline])

        # -- Processing mode --
        if config.processing_mode:
            cmd.extend(["--processing-mode", config.processing_mode])

        # -- Task list --
        if config.tasks:
            cmd.extend(["--tasks", ",".join(config.tasks)])

        # -- Fetch from API before verification (verify_only mode only) --
        if config.fetch_before_verify and config.processing_mode == "verify_only":
            cmd.append("--fetch-before-verify")

        # -- Time window --
        if config.window_mode == "lookback":
            # Lookback mode
            if config.lookback_hours is not None:
                cmd.extend(["--lookback-hours", str(config.lookback_hours)])
            if config.overlap_seconds is not None:
                cmd.extend(["--overlap-seconds", str(config.overlap_seconds)])
        else:
            # Custom time window
            if config.window_start:
                cmd.extend(["--window-start", config.window_start])
            if config.window_end:
                cmd.extend(["--window-end", config.window_end])

        # -- Window splitting --
        if config.window_split and config.window_split != "none":
            cmd.extend(["--window-split", config.window_split])
            if config.window_split_days is not None:
                cmd.extend(["--window-split-days", str(config.window_split_days)])

        # -- Dry-run --
        if config.dry_run:
            cmd.append("--dry-run")

        # -- Force full reprocessing --
        if config.force_full:
            cmd.append("--force-full")

        # -- Local JSON mode → --data-source offline --
        if config.ods_use_local_json:
            cmd.extend(["--data-source", "offline"])

        # -- Store ID (auto-injected) --
        if config.store_id is not None:
            cmd.extend(["--store-id", str(config.store_id)])

        # -- Extra arguments (only keys the CLI actually supports) --
        for key, value in config.extra_args.items():
            if value is not None and key in CLI_SUPPORTED_ARGS:
                arg_name = f"--{key.replace('_', '-')}"
                if isinstance(value, bool):
                    if value:
                        cmd.append(arg_name)
                else:
                    cmd.extend([arg_name, str(value)])

        return cmd

    def build_command_string(
        self,
        config: TaskConfigSchema,
        etl_project_path: str,
        python_executable: str = "python",
    ) -> str:
        """Build a display/log-friendly command string.

        Arguments containing whitespace or quotes are wrapped in double
        quotes; embedded double quotes are backslash-escaped so the
        rendered string stays unambiguous (previously an embedded quote
        produced a malformed quoted token like "a"b").
        """
        cmd = self.build_command(config, etl_project_path, python_executable)
        quoted: list[str] = []
        for arg in cmd:
            if " " in arg or '"' in arg:
                escaped = arg.replace('"', '\\"')
                quoted.append(f'"{escaped}"')
            else:
                quoted.append(arg)
        return " ".join(quoted)


# Module-level singleton
cli_builder = CLIBuilder()
|
||||
303
apps/backend/app/services/scheduler.py
Normal file
303
apps/backend/app/services/scheduler.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度器服务
|
||||
|
||||
后台 asyncio 循环,每 30 秒检查一次到期的调度任务,
|
||||
将其 TaskConfig 入队到 TaskQueue。
|
||||
|
||||
核心逻辑:
|
||||
- check_and_enqueue():查询 enabled=true 且 next_run_at <= now 的调度任务
|
||||
- start() / stop():管理后台循环生命周期
|
||||
- _calculate_next_run():根据 ScheduleConfig 计算下次执行时间
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from ..database import get_connection
|
||||
from ..schemas.schedules import ScheduleConfigSchema
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
from .task_queue import task_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 调度器轮询间隔(秒)
|
||||
SCHEDULER_POLL_INTERVAL = 30
|
||||
|
||||
|
||||
def _parse_time(time_str: str) -> tuple[int, int]:
|
||||
"""解析 HH:MM 格式的时间字符串,返回 (hour, minute)。"""
|
||||
parts = time_str.split(":")
|
||||
return int(parts[0]), int(parts[1])
|
||||
|
||||
|
||||
def calculate_next_run(
|
||||
schedule_config: ScheduleConfigSchema,
|
||||
now: datetime | None = None,
|
||||
) -> datetime | None:
|
||||
"""根据调度配置计算下次执行时间。
|
||||
|
||||
Args:
|
||||
schedule_config: 调度配置
|
||||
now: 当前时间(默认 UTC now),方便测试注入
|
||||
|
||||
Returns:
|
||||
下次执行时间(UTC),once 类型返回 None 表示不再执行
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
stype = schedule_config.schedule_type
|
||||
|
||||
if stype == "once":
|
||||
# 一次性任务执行后不再调度
|
||||
return None
|
||||
|
||||
if stype == "interval":
|
||||
unit_map = {
|
||||
"minutes": timedelta(minutes=schedule_config.interval_value),
|
||||
"hours": timedelta(hours=schedule_config.interval_value),
|
||||
"days": timedelta(days=schedule_config.interval_value),
|
||||
}
|
||||
delta = unit_map.get(schedule_config.interval_unit)
|
||||
if delta is None:
|
||||
logger.warning("未知的 interval_unit: %s", schedule_config.interval_unit)
|
||||
return None
|
||||
return now + delta
|
||||
|
||||
if stype == "daily":
|
||||
hour, minute = _parse_time(schedule_config.daily_time)
|
||||
# 计算明天的 daily_time
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
if stype == "weekly":
|
||||
hour, minute = _parse_time(schedule_config.weekly_time)
|
||||
days = sorted(schedule_config.weekly_days) if schedule_config.weekly_days else [1]
|
||||
# ISO weekday: 1=Monday ... 7=Sunday
|
||||
current_weekday = now.isoweekday()
|
||||
|
||||
# 找到下一个匹配的 weekday
|
||||
for day in days:
|
||||
if day > current_weekday:
|
||||
delta_days = day - current_weekday
|
||||
next_dt = now + timedelta(days=delta_days)
|
||||
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# 本周没有更晚的 weekday,跳到下周第一个
|
||||
first_day = days[0]
|
||||
delta_days = 7 - current_weekday + first_day
|
||||
next_dt = now + timedelta(days=delta_days)
|
||||
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
if stype == "cron":
|
||||
# 简单 cron 解析:仅支持 "minute hour * * *" 格式(每日定时)
|
||||
# 复杂 cron 表达式可后续引入 croniter 库
|
||||
return _parse_simple_cron(schedule_config.cron_expression, now)
|
||||
|
||||
logger.warning("未知的 schedule_type: %s", stype)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_simple_cron(expression: str, now: datetime) -> datetime | None:
|
||||
"""简单 cron 解析器,支持基本的 5 字段格式。
|
||||
|
||||
支持的格式:
|
||||
- "M H * * *" → 每天 H:M
|
||||
- "M H * * D" → 每周 D 的 H:M(D 为 0-6,0=Sunday)
|
||||
- 其他格式回退到每天 04:00
|
||||
|
||||
不支持范围、列表、步进等高级语法。如需完整 cron 支持,
|
||||
可在 pyproject.toml 中添加 croniter 依赖。
|
||||
"""
|
||||
parts = expression.strip().split()
|
||||
if len(parts) != 5:
|
||||
logger.warning("无法解析 cron 表达式: %s,回退到明天 04:00", expression)
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
|
||||
|
||||
minute_str, hour_str, dom, month, dow = parts
|
||||
|
||||
try:
|
||||
minute = int(minute_str) if minute_str != "*" else 0
|
||||
hour = int(hour_str) if hour_str != "*" else 0
|
||||
except ValueError:
|
||||
logger.warning("cron 表达式时间字段无法解析: %s,回退到明天 04:00", expression)
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
|
||||
|
||||
# 如果指定了 day-of-week(非 *)
|
||||
if dow != "*":
|
||||
try:
|
||||
cron_dow = int(dow) # 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||
except ValueError:
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# 转换为 ISO weekday(1=Monday, 7=Sunday)
|
||||
iso_dow = 7 if cron_dow == 0 else cron_dow
|
||||
current_iso = now.isoweekday()
|
||||
|
||||
if iso_dow > current_iso:
|
||||
delta_days = iso_dow - current_iso
|
||||
elif iso_dow < current_iso:
|
||||
delta_days = 7 - current_iso + iso_dow
|
||||
else:
|
||||
# 同一天,看时间是否已过
|
||||
target_today = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
if now < target_today:
|
||||
delta_days = 0
|
||||
else:
|
||||
delta_days = 7
|
||||
|
||||
next_dt = now + timedelta(days=delta_days)
|
||||
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# 每天定时(dom=* month=* dow=*)
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
|
||||
class Scheduler:
    """PostgreSQL-backed periodic scheduler.

    A background asyncio loop checks for due tasks every
    SCHEDULER_POLL_INTERVAL seconds and enqueues their TaskConfig into
    the TaskQueue.
    """

    def __init__(self) -> None:
        # True while the background loop should keep iterating.
        self._running = False
        # Handle to the background asyncio task; None when not started.
        self._loop_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Core: find due tasks and enqueue them
    # ------------------------------------------------------------------

    def check_and_enqueue(self) -> int:
        """Enqueue every schedule with enabled=true and next_run_at <= now.

        Returns:
            Number of tasks enqueued in this pass.
        """
        conn = get_connection()
        enqueued = 0
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, site_id, task_config, schedule_config
                    FROM scheduled_tasks
                    WHERE enabled = TRUE
                      AND next_run_at IS NOT NULL
                      AND next_run_at <= NOW()
                    ORDER BY next_run_at ASC
                    """
                )
                rows = cur.fetchall()

            for row in rows:
                task_id = str(row[0])
                site_id = row[1]
                # JSONB columns may already arrive as dicts depending on the
                # driver; fall back to json.loads for plain-text payloads.
                task_config_raw = row[2] if isinstance(row[2], dict) else json.loads(row[2])
                schedule_config_raw = row[3] if isinstance(row[3], dict) else json.loads(row[3])

                try:
                    config = TaskConfigSchema(**task_config_raw)
                    schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
                except Exception:
                    # Bad stored config: skip this schedule, keep the loop alive.
                    logger.exception("调度任务 [%s] 配置反序列化失败,跳过", task_id)
                    continue

                # Enqueue into the shared task queue.
                try:
                    queue_id = task_queue.enqueue(config, site_id)
                    logger.info(
                        "调度任务 [%s] 入队成功 → queue_id=%s site_id=%s",
                        task_id, queue_id, site_id,
                    )
                    enqueued += 1
                except Exception:
                    logger.exception("调度任务 [%s] 入队失败", task_id)
                    continue

                # Record the run and compute the next fire time; a None
                # next_run_at (e.g. "once" schedules) disables further runs.
                now = datetime.now(timezone.utc)
                next_run = calculate_next_run(schedule_cfg, now)

                with conn.cursor() as cur:
                    cur.execute(
                        """
                        UPDATE scheduled_tasks
                        SET last_run_at = NOW(),
                            run_count = run_count + 1,
                            next_run_at = %s,
                            last_status = 'enqueued',
                            updated_at = NOW()
                        WHERE id = %s
                        """,
                        (next_run, task_id),
                    )
                # Commit per task so one failure does not roll back
                # previously recorded runs.
                conn.commit()

        except Exception:
            logger.exception("check_and_enqueue 执行异常")
            try:
                conn.rollback()
            except Exception:
                pass
        finally:
            conn.close()

        if enqueued > 0:
            logger.info("本轮调度检查:%d 个任务入队", enqueued)
        return enqueued

    # ------------------------------------------------------------------
    # Background loop
    # ------------------------------------------------------------------

    async def _loop(self) -> None:
        """Background asyncio loop; polls every SCHEDULER_POLL_INTERVAL seconds."""
        self._running = True
        logger.info("Scheduler 后台循环启动(间隔 %ds)", SCHEDULER_POLL_INTERVAL)

        while self._running:
            try:
                # Run the blocking DB work in the default thread pool so the
                # event loop is never blocked.
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, self.check_and_enqueue)
            except Exception:
                logger.exception("Scheduler 循环迭代异常")

            await asyncio.sleep(SCHEDULER_POLL_INTERVAL)

        logger.info("Scheduler 后台循环停止")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self) -> None:
        """Start the background loop (called from the FastAPI lifespan)."""
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self._loop())
            logger.info("Scheduler 已启动")

    async def stop(self) -> None:
        """Stop the background loop and wait for its task to finish."""
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                # Expected when the sleeping loop is cancelled.
                pass
        self._loop_task = None
        logger.info("Scheduler 已停止")
|
||||
|
||||
|
||||
# 全局单例
|
||||
scheduler = Scheduler()
|
||||
391
apps/backend/app/services/task_executor.py
Normal file
391
apps/backend/app/services/task_executor.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 任务执行器
|
||||
|
||||
通过 asyncio.create_subprocess_exec 启动 ETL CLI 子进程,
|
||||
逐行读取 stdout/stderr 并广播到 WebSocket 订阅者,
|
||||
执行完成后将结果写入 task_execution_log 表。
|
||||
|
||||
设计要点:
|
||||
- 每个 execution_id 对应一个子进程,存储在 _processes 字典中
|
||||
- 日志行存储在内存缓冲区 _log_buffers 中
|
||||
- WebSocket 订阅者通过 asyncio.Queue 接收实时日志
|
||||
- Windows 兼容:取消时使用 process.terminate() 而非 SIGTERM
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..config import ETL_PROJECT_PATH
|
||||
from ..database import get_connection
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
from ..services.cli_builder import cli_builder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskExecutor:
|
||||
"""管理 ETL CLI 子进程的生命周期"""
|
||||
|
||||
    def __init__(self) -> None:
        # execution_id → subprocess.Popen (live and recently-finished processes)
        self._processes: dict[str, subprocess.Popen] = {}
        # execution_id → list[str] (mixed stdout + stderr log lines)
        self._log_buffers: dict[str, list[str]] = {}
        # execution_id → set[asyncio.Queue] (WebSocket subscriber channels)
        self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket 订阅管理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
|
||||
"""注册一个 WebSocket 订阅者,返回用于读取日志行的 Queue。
|
||||
|
||||
Queue 中推送 str 表示日志行,None 表示执行结束。
|
||||
"""
|
||||
if execution_id not in self._subscribers:
|
||||
self._subscribers[execution_id] = set()
|
||||
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._subscribers[execution_id].add(queue)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
|
||||
"""移除一个 WebSocket 订阅者。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
subs.discard(queue)
|
||||
if not subs:
|
||||
del self._subscribers[execution_id]
|
||||
|
||||
def _broadcast(self, execution_id: str, line: str) -> None:
|
||||
"""向所有订阅者广播一行日志。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
for q in subs:
|
||||
q.put_nowait(line)
|
||||
|
||||
def _broadcast_end(self, execution_id: str) -> None:
|
||||
"""通知所有订阅者执行已结束(发送 None 哨兵)。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
for q in subs:
|
||||
q.put_nowait(None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 日志缓冲区
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_logs(self, execution_id: str) -> list[str]:
|
||||
"""获取指定执行的内存日志缓冲区(副本)。"""
|
||||
return list(self._log_buffers.get(execution_id, []))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 执行状态查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_running(self, execution_id: str) -> bool:
|
||||
"""判断指定执行是否仍在运行。"""
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc is None:
|
||||
return False
|
||||
return proc.poll() is None
|
||||
|
||||
def get_running_ids(self) -> list[str]:
|
||||
"""返回当前所有运行中的 execution_id 列表。"""
|
||||
return [eid for eid, p in self._processes.items() if p.returncode is None]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 核心执行
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    async def execute(
        self,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str | None = None,
        site_id: int | None = None,
    ) -> None:
        """Run the ETL CLI as a subprocess and record its outcome.

        Uses subprocess.Popen plus reader threads for Windows
        compatibility (asyncio.create_subprocess_exec raises
        NotImplementedError on the Windows proactor setup used here).

        Args:
            config: Task configuration to translate into CLI arguments.
            queue_id: Originating queue entry, if any.
            site_id: Overrides config.store_id for the log record when given.
        """
        cmd = cli_builder.build_command(
            config, ETL_PROJECT_PATH, python_executable=sys.executable
        )
        command_str = " ".join(cmd)
        # site_id wins over the config's store_id; both may be None.
        effective_site_id = site_id or config.store_id

        logger.info(
            "启动 ETL 子进程 [%s]: %s (cwd=%s)",
            execution_id, command_str, ETL_PROJECT_PATH,
        )

        self._log_buffers[execution_id] = []
        started_at = datetime.now(timezone.utc)
        # Monotonic clock for duration measurement (immune to wall-clock jumps).
        t0 = time.monotonic()

        # Record the run as "running" before the process starts.
        self._write_execution_log(
            execution_id=execution_id,
            queue_id=queue_id,
            site_id=effective_site_id,
            task_codes=config.tasks,
            status="running",
            started_at=started_at,
            command=command_str,
        )

        exit_code: int | None = None
        status = "running"
        stdout_lines: list[str] = []
        stderr_lines: list[str] = []

        try:
            # Extra environment variables (DWD table filter is injected via env).
            extra_env: dict[str, str] = {}
            if config.dwd_only_tables:
                extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)

            # Run the blocking subprocess in the thread pool (Windows-safe).
            exit_code = await asyncio.get_event_loop().run_in_executor(
                None,
                self._run_subprocess,
                cmd,
                execution_id,
                stdout_lines,
                stderr_lines,
                extra_env or None,
            )

            if exit_code == 0:
                status = "success"
            else:
                status = "failed"

            logger.info(
                "ETL 子进程 [%s] 退出,exit_code=%s, status=%s",
                execution_id, exit_code, status,
            )

        except asyncio.CancelledError:
            status = "cancelled"
            logger.info("ETL 子进程 [%s] 已取消", execution_id)
            # Best-effort: terminate the subprocess if it is still alive.
            proc = self._processes.get(execution_id)
            if proc and proc.poll() is None:
                proc.terminate()
        except Exception as exc:
            status = "failed"
            import traceback
            tb = traceback.format_exc()
            # Preserve the launch/execution failure in the stored error log.
            stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
            stderr_lines.append(tb)
            logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
        finally:
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            finished_at = datetime.now(timezone.utc)

            # Wake subscribers and drop the process handle regardless of outcome.
            self._broadcast_end(execution_id)
            self._processes.pop(execution_id, None)

            self._update_execution_log(
                execution_id=execution_id,
                status=status,
                finished_at=finished_at,
                exit_code=exit_code,
                duration_ms=elapsed_ms,
                output_log="\n".join(stdout_lines),
                error_log="\n".join(stderr_lines),
            )
|
||||
|
||||
    def _run_subprocess(
        self,
        cmd: list[str],
        execution_id: str,
        stdout_lines: list[str],
        stderr_lines: list[str],
        extra_env: dict[str, str] | None = None,
    ) -> int:
        """Run the subprocess in the calling (worker) thread, streaming output.

        Lines are tagged, appended to the in-memory buffer and broadcast
        to subscribers as they arrive; returns the process exit code.
        """
        import os
        env = os.environ.copy()
        # Force UTF-8 output in the child to avoid GBK mojibake on Windows.
        env["PYTHONIOENCODING"] = "utf-8"
        if extra_env:
            env.update(extra_env)

        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=ETL_PROJECT_PATH,
            env=env,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
        self._processes[execution_id] = proc

        def read_stream(
            stream, stream_name: str, collector: list[str],
        ) -> None:
            """Read one stream line by line; tag, buffer and broadcast each line."""
            for raw_line in stream:
                line = raw_line.rstrip("\n").rstrip("\r")
                tagged = f"[{stream_name}] {line}"
                buf = self._log_buffers.get(execution_id)
                if buf is not None:
                    buf.append(tagged)
                collector.append(line)
                # NOTE(review): put_nowait is invoked from this worker thread
                # on asyncio.Queue objects, which are not thread-safe —
                # confirm the subscriber side tolerates this.
                self._broadcast(execution_id, tagged)

        # Two daemon reader threads prevent pipe-buffer deadlock when both
        # streams produce output.
        t_out = threading.Thread(
            target=read_stream, args=(proc.stdout, "stdout", stdout_lines),
            daemon=True,
        )
        t_err = threading.Thread(
            target=read_stream, args=(proc.stderr, "stderr", stderr_lines),
            daemon=True,
        )
        t_out.start()
        t_err.start()

        proc.wait()
        # Bounded joins: readers are daemons, so a stuck stream cannot hang us.
        t_out.join(timeout=5)
        t_err.join(timeout=5)

        return proc.returncode
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 取消
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def cancel(self, execution_id: str) -> bool:
|
||||
"""向子进程发送终止信号。
|
||||
|
||||
Returns:
|
||||
True 表示成功发送终止信号,False 表示进程不存在或已退出。
|
||||
"""
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc is None:
|
||||
return False
|
||||
# subprocess.Popen: poll() 返回 None 表示仍在运行
|
||||
if proc.poll() is not None:
|
||||
return False
|
||||
|
||||
logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
|
||||
try:
|
||||
proc.terminate()
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 数据库操作(同步,在线程池中执行也可,此处简单直连)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    @staticmethod
    def _write_execution_log(
        *,
        execution_id: str,
        queue_id: str | None,
        site_id: int | None,
        task_codes: list[str],
        status: str,
        started_at: datetime,
        command: str,
    ) -> None:
        """Insert one execution-log row (caller passes 'running' status).

        Best-effort: any database error is logged and swallowed so a
        bookkeeping failure never aborts the actual task execution.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        INSERT INTO task_execution_log
                            (id, queue_id, site_id, task_codes, status,
                             started_at, command)
                        VALUES (%s, %s, %s, %s, %s, %s, %s)
                        """,
                        (
                            execution_id,
                            queue_id,
                            # None site is stored as 0 — presumably a
                            # "no site" sentinel; confirm against schema.
                            site_id or 0,
                            task_codes,
                            status,
                            started_at,
                            command,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("写入 execution_log 失败 [%s]", execution_id)
|
||||
|
||||
    @staticmethod
    def _update_execution_log(
        *,
        execution_id: str,
        status: str,
        finished_at: datetime,
        exit_code: int | None,
        duration_ms: int,
        output_log: str,
        error_log: str,
    ) -> None:
        """Update an execution-log row with its final result.

        Best-effort, mirroring _write_execution_log: database errors are
        logged and swallowed.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        UPDATE task_execution_log
                        SET status = %s,
                            finished_at = %s,
                            exit_code = %s,
                            duration_ms = %s,
                            output_log = %s,
                            error_log = %s
                        WHERE id = %s
                        """,
                        (
                            status,
                            finished_at,
                            exit_code,
                            duration_ms,
                            output_log,
                            error_log,
                            execution_id,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("更新 execution_log 失败 [%s]", execution_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 清理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def cleanup(self, execution_id: str) -> None:
|
||||
"""清理指定执行的内存资源(日志缓冲区和订阅者)。
|
||||
|
||||
通常在确认日志已持久化后调用。
|
||||
"""
|
||||
self._log_buffers.pop(execution_id, None)
|
||||
self._subscribers.pop(execution_id, None)
|
||||
|
||||
|
||||
# Module-level singleton shared by the API routes and the task queue.
task_executor = TaskExecutor()
|
||||
486
apps/backend/app/services/task_queue.py
Normal file
486
apps/backend/app/services/task_queue.py
Normal file
@@ -0,0 +1,486 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务队列服务
|
||||
|
||||
基于 PostgreSQL task_queue 表实现 FIFO 队列,支持:
|
||||
- enqueue:入队,自动分配 position(当前最大 + 1)
|
||||
- dequeue:取出 position 最小的 pending 任务
|
||||
- reorder:调整任务在队列中的位置
|
||||
- delete:删除 pending 任务
|
||||
- process_loop:后台协程,队列非空且无运行中任务时自动取出执行
|
||||
|
||||
所有操作按 site_id 过滤,实现门店隔离。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..database import get_connection
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Polling interval of the background processing loop, in seconds.
POLL_INTERVAL_SECONDS = 2
|
||||
|
||||
|
||||
@dataclass
class QueuedTask:
    """In-memory view of one row of the task_queue table."""

    id: str                 # queue task ID (UUID string)
    site_id: int            # owning site (tenant isolation)
    config: dict[str, Any]  # JSON-decoded task configuration
    status: str             # 'pending' / 'running' / terminal states (e.g. 'failed')
    position: int           # FIFO position within the site's pending queue
    created_at: Any = None  # timestamps are driver-native values; left untyped
    started_at: Any = None
    finished_at: Any = None
    exit_code: int | None = None       # subprocess exit code once finished
    error_message: str | None = None   # failure detail, if any
|
||||
|
||||
|
||||
class TaskQueue:
    """PostgreSQL-backed FIFO task queue with per-site isolation."""

    def __init__(self) -> None:
        # Flag polled by process_loop; cleared by stop().
        self._running = False
        # Handle of the background polling coroutine, if started.
        self._loop_task: asyncio.Task | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 入队
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def enqueue(self, config: TaskConfigSchema, site_id: int) -> str:
        """Append a task configuration to the site's queue.

        The new task is placed after the site's current pending tasks
        (position = current max + 1).

        Args:
            config: task configuration.
            site_id: site ID (tenant isolation).

        Returns:
            The newly created queue task ID (UUID string).
        """
        task_id = str(uuid.uuid4())
        config_json = config.model_dump(mode="json")

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Largest pending position for this site; the new task
                # goes to the tail of the queue.
                cur.execute(
                    """
                    SELECT COALESCE(MAX(position), 0)
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    """,
                    (site_id,),
                )
                max_pos = cur.fetchone()[0]
                new_pos = max_pos + 1

                cur.execute(
                    """
                    INSERT INTO task_queue (id, site_id, config, status, position)
                    VALUES (%s, %s, %s, 'pending', %s)
                    """,
                    (task_id, site_id, json.dumps(config_json), new_pos),
                )
            conn.commit()
        finally:
            conn.close()

        logger.info("任务入队 [%s] site_id=%s position=%s", task_id, site_id, new_pos)
        return task_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 出队
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def dequeue(self, site_id: int) -> QueuedTask | None:
        """Pop the lowest-position pending task and mark it 'running'.

        Uses SELECT ... FOR UPDATE SKIP LOCKED so concurrent workers
        never grab the same row; select and status update happen in one
        transaction.

        Args:
            site_id: site ID.

        Returns:
            The claimed QueuedTask, or None when the queue is empty.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Claim and lock the lowest-position pending task.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                    """,
                    (site_id,),
                )
                row = cur.fetchone()
                if row is None:
                    # Commit to release the transaction even on miss.
                    conn.commit()
                    return None

                task = QueuedTask(
                    id=str(row[0]),
                    site_id=row[1],
                    # config may arrive pre-decoded (jsonb) or as text.
                    config=row[2] if isinstance(row[2], dict) else json.loads(row[2]),
                    status=row[3],
                    position=row[4],
                    created_at=row[5],
                    started_at=row[6],
                    finished_at=row[7],
                    exit_code=row[8],
                    error_message=row[9],
                )

                # Flip the claimed row to 'running' before releasing the lock.
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'running', started_at = NOW()
                    WHERE id = %s
                    """,
                    (task.id,),
                )
            conn.commit()
        finally:
            conn.close()

        # Keep the in-memory object in sync with the DB update above.
        task.status = "running"
        logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
        return task
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 重排
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
        """Move a pending task to a new position in its site's queue.

        Only pending tasks are affected. The target task is moved to
        new_position; the remaining pending tasks keep their relative
        order and are renumbered contiguously from 1.

        Args:
            task_id: ID of the task to move.
            new_position: target position (1-based; clamped into range).
            site_id: site ID.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # All pending tasks of the site, in current queue order.
                cur.execute(
                    """
                    SELECT id FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
                task_ids = [str(r[0]) for r in rows]

                if task_id not in task_ids:
                    # Unknown or non-pending task: nothing to do.
                    conn.commit()
                    return

                # Remove the target, then re-insert it at the new slot.
                task_ids.remove(task_id)
                # new_position is 1-based; clamp to a valid 0-based index.
                insert_idx = max(0, min(new_position - 1, len(task_ids)))
                task_ids.insert(insert_idx, task_id)

                # Renumber positions contiguously (1-based) in new order.
                for idx, tid in enumerate(task_ids, start=1):
                    cur.execute(
                        "UPDATE task_queue SET position = %s WHERE id = %s",
                        (idx, tid),
                    )
                conn.commit()
        finally:
            conn.close()

        logger.info(
            "任务重排 [%s] → position=%s site_id=%s",
            task_id, new_position, site_id,
        )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 删除
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def delete(self, task_id: str, site_id: int) -> bool:
|
||||
"""删除 pending 状态的任务。
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
site_id: 门店 ID
|
||||
|
||||
Returns:
|
||||
True 表示成功删除,False 表示任务不存在或非 pending 状态。
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM task_queue
|
||||
WHERE id = %s AND site_id = %s AND status = 'pending'
|
||||
""",
|
||||
(task_id, site_id),
|
||||
)
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if deleted:
|
||||
logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
|
||||
else:
|
||||
logger.warning(
|
||||
"任务删除失败 [%s] site_id=%s(不存在或非 pending)",
|
||||
task_id, site_id,
|
||||
)
|
||||
return deleted
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def list_pending(self, site_id: int) -> list[QueuedTask]:
        """Return all pending tasks of a site, ordered by position ascending."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return [
            QueuedTask(
                id=str(r[0]),
                site_id=r[1],
                # config may arrive pre-decoded (jsonb) or as text.
                config=r[2] if isinstance(r[2], dict) else json.loads(r[2]),
                status=r[3],
                position=r[4],
                created_at=r[5],
                started_at=r[6],
                finished_at=r[7],
                exit_code=r[8],
                error_message=r[9],
            )
            for r in rows
        ]
|
||||
|
||||
def has_running(self, site_id: int) -> bool:
|
||||
"""检查指定门店是否有 running 状态的任务。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM task_queue
|
||||
WHERE site_id = %s AND status = 'running'
|
||||
)
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
result = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 后台处理循环
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    async def process_loop(self) -> None:
        """Background coroutine: auto-execute queued tasks.

        Per iteration:
          1. collect every site_id that has pending tasks
          2. for each site without a running task, dequeue and execute
          3. sleep POLL_INTERVAL_SECONDS and repeat
        """
        # Imported lazily to break the circular import with task_executor.
        from .task_executor import task_executor

        self._running = True
        logger.info("TaskQueue process_loop 启动")

        while self._running:
            try:
                await self._process_once(task_executor)
            except Exception:
                # Keep the loop alive: one bad iteration must not kill
                # background processing.
                logger.exception("process_loop 迭代异常")

            await asyncio.sleep(POLL_INTERVAL_SECONDS)

        logger.info("TaskQueue process_loop 停止")
|
||||
|
||||
async def _process_once(self, executor: Any) -> None:
|
||||
"""单次处理:扫描所有门店的 pending 队列并执行。"""
|
||||
site_ids = self._get_pending_site_ids()
|
||||
|
||||
for site_id in site_ids:
|
||||
if self.has_running(site_id):
|
||||
continue
|
||||
|
||||
task = self.dequeue(site_id)
|
||||
if task is None:
|
||||
continue
|
||||
|
||||
config = TaskConfigSchema(**task.config)
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
||||
logger.info(
|
||||
"process_loop 自动执行 [%s] queue_id=%s site_id=%s",
|
||||
execution_id, task.id, site_id,
|
||||
)
|
||||
|
||||
# 异步启动执行(不阻塞循环)
|
||||
asyncio.create_task(
|
||||
self._execute_and_update(
|
||||
executor, config, execution_id, task.id, site_id,
|
||||
)
|
||||
)
|
||||
|
||||
    async def _execute_and_update(
        self,
        executor: Any,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str,
        site_id: int,
    ) -> None:
        """Run one task via the executor, then sync the queue row.

        Args:
            executor: TaskExecutor singleton.
            config: parsed task configuration.
            execution_id: ID for the execution-log record.
            queue_id: ID of the task_queue row being processed.
            site_id: site the task belongs to.
        """
        try:
            await executor.execute(
                config=config,
                execution_id=execution_id,
                queue_id=queue_id,
                site_id=site_id,
            )
            # Propagate the executor's result (recorded in
            # task_execution_log) back onto the task_queue row.
            self._update_queue_status_from_log(queue_id)
        except Exception:
            logger.exception("队列任务执行异常 [%s]", queue_id)
            self._mark_failed(queue_id, "执行过程中发生未捕获异常")
|
||||
|
||||
def _get_pending_site_ids(self) -> list[int]:
|
||||
"""获取所有有 pending 任务的 site_id 列表。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT site_id FROM task_queue
|
||||
WHERE status = 'pending'
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
    def _update_queue_status_from_log(self, queue_id: str) -> None:
        """Copy the latest execution result from task_execution_log
        onto the matching task_queue row.

        If no log row exists for the queue_id, the queue row is left
        untouched (it stays 'running').
        # NOTE(review): confirm something else eventually resolves a
        # row stuck in 'running' when the log write failed.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Most recent execution for this queue entry.
                cur.execute(
                    """
                    SELECT status, finished_at, exit_code, error_log
                    FROM task_execution_log
                    WHERE queue_id = %s
                    ORDER BY started_at DESC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                row = cur.fetchone()
                if row:
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = %s, finished_at = %s,
                            exit_code = %s, error_message = %s
                        WHERE id = %s
                        """,
                        (row[0], row[1], row[2], row[3], queue_id),
                    )
            conn.commit()
        finally:
            conn.close()
|
||||
|
||||
def _mark_failed(self, queue_id: str, error_message: str) -> None:
|
||||
"""将队列任务标记为 failed。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_queue
|
||||
SET status = 'failed', finished_at = NOW(),
|
||||
error_message = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(error_message, queue_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def start(self) -> None:
        """Start the background loop (called from the FastAPI lifespan).

        Idempotent: calling again while the loop is alive is a no-op.
        """
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self.process_loop())
            logger.info("TaskQueue 后台循环已启动")
|
||||
|
||||
    async def stop(self) -> None:
        """Stop the background loop and wait for it to wind down."""
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                # Expected: the loop was cancelled on purpose.
                pass
        self._loop_task = None
        logger.info("TaskQueue 后台循环已停止")
|
||||
|
||||
|
||||
# Module-level singleton shared by the API routes and the app lifespan.
task_queue = TaskQueue()
|
||||
221
apps/backend/app/services/task_registry.py
Normal file
221
apps/backend/app/services/task_registry.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""静态任务注册表
|
||||
|
||||
从 ETL orchestration/task_registry.py 提取的任务元数据硬编码副本。
|
||||
后端不直接导入 ETL 代码,避免引入重量级依赖链。
|
||||
|
||||
业务域分组逻辑:按任务代码前缀 / 目标表语义归类,与 GUI 保持一致。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TaskDefinition:
    """Metadata for a single ETL task."""

    code: str           # unique task code, e.g. "ODS_MEMBER"
    name: str           # human-readable name (zh-CN, shown in the UI)
    description: str    # short description (zh-CN)
    domain: str         # business domain: 会员 / 结算 / 助教 / 商品 / 台桌 / 团购 / 库存 / 财务 / 指数 / 工具
    layer: str          # ODS / DWD / DWS / INDEX / UTILITY
    requires_window: bool = True   # task takes a time window — assumed from usage; confirm against ETL registry
    is_ods: bool = False           # True for raw-extraction (ODS) tasks
    is_dimension: bool = False     # True for dimension-table tasks
    default_enabled: bool = True   # enabled by default in task selections
    is_common: bool = True         # "common" task flag; False marks utility / manual-only tasks
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DwdTableDefinition:
    """Metadata for one DWD-layer table."""

    table_name: str     # fully-qualified table name (including schema)
    display_name: str   # human-readable name (zh-CN)
    domain: str         # business domain (same vocabulary as TaskDefinition)
    ods_source: str     # the ODS source table this DWD table is loaded from
    is_dimension: bool = False  # True for dimension tables, False for facts
|
||||
|
||||
|
||||
# ── ODS task definitions ──────────────────────────────────────

ODS_TASKS: list[TaskDefinition] = [
    TaskDefinition("ODS_ASSISTANT_ACCOUNT", "助教账号", "抽取助教账号主数据", "助教", "ODS", is_ods=True),
    TaskDefinition("ODS_ASSISTANT_LEDGER", "助教服务记录", "抽取助教服务流水", "助教", "ODS", is_ods=True),
    TaskDefinition("ODS_ASSISTANT_ABOLISH", "助教取消记录", "抽取助教取消/作废记录", "助教", "ODS", is_ods=True),
    TaskDefinition("ODS_SETTLEMENT_RECORDS", "结算记录", "抽取订单结算记录", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_SETTLEMENT_TICKET", "结账小票", "抽取结账小票明细", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_TABLE_USE", "台费流水", "抽取台费使用流水", "台桌", "ODS", is_ods=True),
    TaskDefinition("ODS_TABLE_FEE_DISCOUNT", "台费折扣", "抽取台费折扣记录", "台桌", "ODS", is_ods=True),
    # Master-data snapshots below carry requires_window=False (full pulls).
    TaskDefinition("ODS_TABLES", "台桌主数据", "抽取门店台桌信息", "台桌", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_PAYMENT", "支付流水", "抽取支付交易记录", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_REFUND", "退款流水", "抽取退款交易记录", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_PLATFORM_COUPON", "平台券核销", "抽取平台优惠券核销记录", "团购", "ODS", is_ods=True),
    TaskDefinition("ODS_MEMBER", "会员主数据", "抽取会员档案", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_MEMBER_CARD", "会员储值卡", "抽取会员储值卡信息", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_MEMBER_BALANCE", "会员余额变动", "抽取会员余额变动记录", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_RECHARGE_SETTLE", "充值结算", "抽取充值结算记录", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_GROUP_PACKAGE", "团购套餐", "抽取团购套餐定义", "团购", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_GROUP_BUY_REDEMPTION", "团购核销", "抽取团购核销记录", "团购", "ODS", is_ods=True),
    TaskDefinition("ODS_INVENTORY_STOCK", "库存快照", "抽取商品库存汇总", "库存", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_INVENTORY_CHANGE", "库存变动", "抽取库存出入库记录", "库存", "ODS", is_ods=True),
    TaskDefinition("ODS_GOODS_CATEGORY", "商品分类", "抽取商品分类树", "商品", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_STORE_GOODS", "门店商品", "抽取门店商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_STORE_GOODS_SALES", "商品销售", "抽取门店商品销售记录", "商品", "ODS", is_ods=True),
    TaskDefinition("ODS_TENANT_GOODS", "租户商品", "抽取租户级商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
]

# ── DWD task definitions ──────────────────────────────────────

DWD_TASKS: list[TaskDefinition] = [
    TaskDefinition("DWD_LOAD_FROM_ODS", "DWD 装载", "从 ODS 装载至 DWD(维度 SCD2 + 事实增量)", "通用", "DWD", requires_window=False),
    TaskDefinition("DWD_QUALITY_CHECK", "DWD 质量检查", "对 DWD 层数据执行质量校验", "通用", "DWD", requires_window=False, is_common=False),
]

# ── DWS task definitions ──────────────────────────────────────

DWS_TASKS: list[TaskDefinition] = [
    TaskDefinition("DWS_BUILD_ORDER_SUMMARY", "订单汇总构建", "构建订单汇总宽表", "结算", "DWS"),
    TaskDefinition("DWS_ASSISTANT_DAILY", "助教日报", "汇总助教每日业绩", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_MONTHLY", "助教月报", "汇总助教月度业绩", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_CUSTOMER", "助教客户分析", "汇总助教-客户关系", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_SALARY", "助教工资计算", "计算助教工资", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_FINANCE", "助教财务汇总", "汇总助教财务数据", "助教", "DWS"),
    TaskDefinition("DWS_MEMBER_CONSUMPTION", "会员消费分析", "汇总会员消费数据", "会员", "DWS"),
    TaskDefinition("DWS_MEMBER_VISIT", "会员到店分析", "汇总会员到店频次", "会员", "DWS"),
    TaskDefinition("DWS_FINANCE_DAILY", "财务日报", "汇总每日财务数据", "财务", "DWS"),
    TaskDefinition("DWS_FINANCE_RECHARGE", "充值汇总", "汇总充值数据", "财务", "DWS"),
    TaskDefinition("DWS_FINANCE_INCOME_STRUCTURE", "收入结构", "分析收入结构", "财务", "DWS"),
    TaskDefinition("DWS_FINANCE_DISCOUNT_DETAIL", "折扣明细", "汇总折扣明细", "财务", "DWS"),
    # CHANGE [2026-02-19] intent: mirror the ETL-side merge — the former
    # DWS_RETENTION_CLEANUP / DWS_MV_REFRESH_* tasks are now DWS_MAINTENANCE.
    TaskDefinition("DWS_MAINTENANCE", "DWS 维护", "刷新物化视图 + 清理过期留存数据", "通用", "DWS", requires_window=False, is_common=False),
]

# ── INDEX task definitions ────────────────────────────────────

INDEX_TASKS: list[TaskDefinition] = [
    TaskDefinition("DWS_WINBACK_INDEX", "回流指数 (WBI)", "计算会员回流指数", "指数", "INDEX"),
    TaskDefinition("DWS_NEWCONV_INDEX", "新客转化指数 (NCI)", "计算新客转化指数", "指数", "INDEX"),
    TaskDefinition("DWS_ML_MANUAL_IMPORT", "手动导入 (ML)", "手动导入机器学习数据", "指数", "INDEX", requires_window=False, is_common=False),
    TaskDefinition("DWS_RELATION_INDEX", "关系指数 (RS)", "计算助教-客户关系指数", "指数", "INDEX"),
]

# ── Utility task definitions ──────────────────────────────────

UTILITY_TASKS: list[TaskDefinition] = [
    TaskDefinition("MANUAL_INGEST", "手动导入", "从本地 JSON 文件手动导入数据", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("INIT_ODS_SCHEMA", "初始化 ODS Schema", "创建 ODS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("INIT_DWD_SCHEMA", "初始化 DWD Schema", "创建 DWD 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("INIT_DWS_SCHEMA", "初始化 DWS Schema", "创建 DWS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("ODS_JSON_ARCHIVE", "ODS JSON 归档", "归档 ODS 原始 JSON 文件", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("CHECK_CUTOFF", "游标检查", "检查各任务数据游标截止点", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("SEED_DWS_CONFIG", "DWS 配置种子", "初始化 DWS 配置数据", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("DATA_INTEGRITY_CHECK", "数据完整性校验", "校验跨层数据完整性", "工具", "UTILITY", requires_window=False, is_common=False),
]

# ── Combined registry ─────────────────────────────────────────

ALL_TASKS: list[TaskDefinition] = ODS_TASKS + DWD_TASKS + DWS_TASKS + INDEX_TASKS + UTILITY_TASKS

# Index by code for O(1) lookup (codes are unique across all lists).
_TASK_BY_CODE: dict[str, TaskDefinition] = {t.code: t for t in ALL_TASKS}
|
||||
|
||||
|
||||
def get_all_tasks() -> list[TaskDefinition]:
    """Return the full task registry (the shared module-level list)."""
    return ALL_TASKS
|
||||
|
||||
|
||||
def get_task_by_code(code: str) -> TaskDefinition | None:
    """Look up a task definition by its code, case-insensitively."""
    normalized = code.upper()
    return _TASK_BY_CODE.get(normalized)
|
||||
|
||||
|
||||
def get_tasks_grouped_by_domain() -> dict[str, list[TaskDefinition]]:
    """Group all task definitions by business domain.

    Domains appear in first-seen order; tasks keep registry order.
    """
    grouped: dict[str, list[TaskDefinition]] = {}
    for task in ALL_TASKS:
        if task.domain not in grouped:
            grouped[task.domain] = []
        grouped[task.domain].append(task)
    return grouped
|
||||
|
||||
|
||||
def get_tasks_by_layer(layer: str) -> list[TaskDefinition]:
    """Return every task belonging to the given layer (case-insensitive)."""
    wanted = layer.upper()
    matching: list[TaskDefinition] = []
    for task in ALL_TASKS:
        if task.layer == wanted:
            matching.append(task)
    return matching
|
||||
|
||||
|
||||
# ── Flow → layer mapping ─────────────────────────────────────
# Layers included in each flow; used by the frontend to filter which
# tasks are selectable for a given flow.

FLOW_LAYER_MAP: dict[str, list[str]] = {
    "api_ods": ["ODS"],
    "api_ods_dwd": ["ODS", "DWD"],
    "api_full": ["ODS", "DWD", "DWS", "INDEX"],
    "ods_dwd": ["DWD"],
    "dwd_dws": ["DWS"],
    "dwd_dws_index": ["DWS", "INDEX"],
    "dwd_index": ["INDEX"],
}
|
||||
|
||||
|
||||
def get_compatible_tasks(flow_id: str) -> list[TaskDefinition]:
    """Return the tasks whose layer is part of the given flow.

    Unknown flow IDs yield an empty list.
    """
    allowed_layers = set(FLOW_LAYER_MAP.get(flow_id, []))
    return [task for task in ALL_TASKS if task.layer in allowed_layers]
|
||||
|
||||
|
||||
# ── DWD table definitions ─────────────────────────────────────

DWD_TABLES: list[DwdTableDefinition] = [
    # Dimension tables (each has an "_ex" extension variant except
    # dim_goods_category — presumably intentional; confirm with ETL).
    DwdTableDefinition("dwd.dim_site", "门店维度", "台桌", "ods.table_fee_transactions", is_dimension=True),
    DwdTableDefinition("dwd.dim_site_ex", "门店维度(扩展)", "台桌", "ods.table_fee_transactions", is_dimension=True),
    DwdTableDefinition("dwd.dim_table", "台桌维度", "台桌", "ods.site_tables_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_table_ex", "台桌维度(扩展)", "台桌", "ods.site_tables_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_assistant", "助教维度", "助教", "ods.assistant_accounts_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_assistant_ex", "助教维度(扩展)", "助教", "ods.assistant_accounts_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_member", "会员维度", "会员", "ods.member_profiles", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_ex", "会员维度(扩展)", "会员", "ods.member_profiles", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_card_account", "会员储值卡维度", "会员", "ods.member_stored_value_cards", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_card_account_ex", "会员储值卡维度(扩展)", "会员", "ods.member_stored_value_cards", is_dimension=True),
    DwdTableDefinition("dwd.dim_tenant_goods", "租户商品维度", "商品", "ods.tenant_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_tenant_goods_ex", "租户商品维度(扩展)", "商品", "ods.tenant_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_store_goods", "门店商品维度", "商品", "ods.store_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_store_goods_ex", "门店商品维度(扩展)", "商品", "ods.store_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_goods_category", "商品分类维度", "商品", "ods.stock_goods_category_tree", is_dimension=True),
    DwdTableDefinition("dwd.dim_groupbuy_package", "团购套餐维度", "团购", "ods.group_buy_packages", is_dimension=True),
    DwdTableDefinition("dwd.dim_groupbuy_package_ex", "团购套餐维度(扩展)", "团购", "ods.group_buy_packages", is_dimension=True),
    # Fact tables (note: dwd_payment has no "_ex" variant here —
    # presumably intentional; confirm with the ETL schema).
    DwdTableDefinition("dwd.dwd_settlement_head", "结算主表", "结算", "ods.settlement_records"),
    DwdTableDefinition("dwd.dwd_settlement_head_ex", "结算主表(扩展)", "结算", "ods.settlement_records"),
    DwdTableDefinition("dwd.dwd_table_fee_log", "台费流水", "台桌", "ods.table_fee_transactions"),
    DwdTableDefinition("dwd.dwd_table_fee_log_ex", "台费流水(扩展)", "台桌", "ods.table_fee_transactions"),
    DwdTableDefinition("dwd.dwd_table_fee_adjust", "台费折扣", "台桌", "ods.table_fee_discount_records"),
    DwdTableDefinition("dwd.dwd_table_fee_adjust_ex", "台费折扣(扩展)", "台桌", "ods.table_fee_discount_records"),
    DwdTableDefinition("dwd.dwd_store_goods_sale", "商品销售", "商品", "ods.store_goods_sales_records"),
    DwdTableDefinition("dwd.dwd_store_goods_sale_ex", "商品销售(扩展)", "商品", "ods.store_goods_sales_records"),
    DwdTableDefinition("dwd.dwd_assistant_service_log", "助教服务流水", "助教", "ods.assistant_service_records"),
    DwdTableDefinition("dwd.dwd_assistant_service_log_ex", "助教服务流水(扩展)", "助教", "ods.assistant_service_records"),
    DwdTableDefinition("dwd.dwd_assistant_trash_event", "助教取消事件", "助教", "ods.assistant_cancellation_records"),
    DwdTableDefinition("dwd.dwd_assistant_trash_event_ex", "助教取消事件(扩展)", "助教", "ods.assistant_cancellation_records"),
    DwdTableDefinition("dwd.dwd_member_balance_change", "会员余额变动", "会员", "ods.member_balance_changes"),
    DwdTableDefinition("dwd.dwd_member_balance_change_ex", "会员余额变动(扩展)", "会员", "ods.member_balance_changes"),
    DwdTableDefinition("dwd.dwd_groupbuy_redemption", "团购核销", "团购", "ods.group_buy_redemption_records"),
    DwdTableDefinition("dwd.dwd_groupbuy_redemption_ex", "团购核销(扩展)", "团购", "ods.group_buy_redemption_records"),
    DwdTableDefinition("dwd.dwd_platform_coupon_redemption", "平台券核销", "团购", "ods.platform_coupon_redemption_records"),
    DwdTableDefinition("dwd.dwd_platform_coupon_redemption_ex", "平台券核销(扩展)", "团购", "ods.platform_coupon_redemption_records"),
    DwdTableDefinition("dwd.dwd_recharge_order", "充值订单", "会员", "ods.recharge_settlements"),
    DwdTableDefinition("dwd.dwd_recharge_order_ex", "充值订单(扩展)", "会员", "ods.recharge_settlements"),
    DwdTableDefinition("dwd.dwd_payment", "支付流水", "结算", "ods.payment_transactions"),
    DwdTableDefinition("dwd.dwd_refund", "退款流水", "结算", "ods.refund_transactions"),
    DwdTableDefinition("dwd.dwd_refund_ex", "退款流水(扩展)", "结算", "ods.refund_transactions"),
]
|
||||
|
||||
|
||||
def get_dwd_tables_grouped_by_domain() -> dict[str, list[DwdTableDefinition]]:
    """Group DWD table definitions by business domain.

    Domains appear in first-seen order; tables keep registry order.
    """
    grouped: dict[str, list[DwdTableDefinition]] = {}
    for table in DWD_TABLES:
        if table.domain not in grouped:
            grouped[table.domain] = []
        grouped[table.domain].append(table)
    return grouped
|
||||
0
apps/backend/app/ws/__init__.py
Normal file
0
apps/backend/app/ws/__init__.py
Normal file
68
apps/backend/app/ws/logs.py
Normal file
68
apps/backend/app/ws/logs.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
提供 WS /ws/logs/{execution_id} 端点,实时推送 ETL 任务执行日志。
|
||||
客户端连接后,先发送已有的历史日志行,再实时推送新日志,
|
||||
直到执行结束(收到 None 哨兵)或客户端断开。
|
||||
|
||||
设计要点:
|
||||
- 利用 TaskExecutor 已有的 subscribe/unsubscribe 机制
|
||||
- 连接时先回放内存缓冲区中的历史日志,避免丢失已产生的行
|
||||
- 通过 asyncio.Queue 接收实时日志,None 表示执行结束
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from ..services.task_executor import task_executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ws_router = APIRouter()
|
||||
|
||||
|
||||
@ws_router.websocket("/ws/logs/{execution_id}")
async def ws_logs(websocket: WebSocket, execution_id: str) -> None:
    """Stream execution logs for the given execution_id in real time.

    Flow:
      1. accept the WebSocket connection
      2. replay log lines already held in the in-memory buffer
      3. subscribe to the TaskExecutor and push new lines as they arrive
      4. close on the None sentinel (execution finished) or client disconnect
    """
    await websocket.accept()
    logger.info("WebSocket 连接已建立: execution_id=%s", execution_id)

    # Subscribe first so no line is lost between replay and live push.
    queue = task_executor.subscribe(execution_id)

    try:
        # Replay history already captured for this execution.
        for line in task_executor.get_logs(execution_id):
            await websocket.send_text(line)

        # Even if the task already ended, keep reading: the queue may
        # still contain unconsumed messages (ending with the sentinel).
        while True:
            msg = await queue.get()
            if msg is None:
                # Sentinel: execution finished.
                break
            await websocket.send_text(msg)

    except WebSocketDisconnect:
        logger.info("WebSocket 客户端断开: execution_id=%s", execution_id)
    except Exception:
        logger.exception("WebSocket 异常: execution_id=%s", execution_id)
    finally:
        task_executor.unsubscribe(execution_id, queue)
        # Close defensively — the client may already be gone.
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket 连接已清理: execution_id=%s", execution_id)
|
||||
24
apps/backend/doc/开放平台证书.cer
Normal file
24
apps/backend/doc/开放平台证书.cer
Normal file
@@ -0,0 +1,24 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIID9DCCAtygAwIBAgIUaB2siLoT1Nb+u9K0aL18avodFRYwDQYJKoZIhvcNAQEL
|
||||
BQAwbTELMAkGA1UEBhMCQ04xEjAQBgNVBAgMCUd1YW5nRG9uZzERMA8GA1UEBwwI
|
||||
U2hlblpoZW4xEDAOBgNVBAoMB1RlbmNlbnQxEDAOBgNVBAsMB1RlbmNlbnQxEzAR
|
||||
BgNVBAMMCnRzbTAwMDAwMDYwHhcNMjUxMTA0MTA1NzQ0WhcNMzUxMTAyMTA1NzQ0
|
||||
WjCBmDEeMBwGCSqGSIb3DQEJARYPd2VpeGlubXBAcXEuY29tMRswGQYDVQQDDBJ3
|
||||
eDdjMDc3OTNkODI3MzI5MjExFTATBgNVBAoMDFRlbmNlbnQgSW5jLjEOMAwGA1UE
|
||||
CwwFV3hnTXAxCzAJBgNVBAYTAkNOMRIwEAYDVQQIDAlHdWFuZ0RvbmcxETAPBgNV
|
||||
BAcMCFNoZW5aaGVuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA8Ig7
|
||||
m6qvcw+0ncNZBUg9Xfk+6WTKSj7cCwcuD66JgYPQ4pVreqCTzusc//E+EGapXyR3
|
||||
fJZl+AL2sVWRIzLa1+dt9+45YBGm3lSbvarlbsdVMrYwFRW+d/vZgxcXfcS1VmVP
|
||||
NPEw7DaAkWlvUkFmUGdNzH+QLsuXTZKWEtHtmSo7us9HTJYV/aEH2uJpsHE4A3fP
|
||||
Vbc/wy1EwLt48o0ZDzpiPZiqn+nSrSXEqBPOEwzICxnHCJRpEH01RxBJGdTSouzF
|
||||
pfncMeEpGfFGH8GW8IQYzzvrvYDbprsnVMfHNAo0MMGK+iyWyAOFMqrkvSb962x7
|
||||
KoXLDH9OfmFRNse9WQIDAQABo2AwXjAdBgNVHQ4EFgQUccFm3WWmKA/m+uXeW8Xe
|
||||
jfNsX+owHwYDVR0jBBgwFoAURM1183H4z2eJGmP5z/zWI7tjRZAwDAYDVR0TAQH/
|
||||
BAIwADAOBgNVHQ8BAf8EBAMCBsAwDQYJKoZIhvcNAQELBQADggEBAIXfGfQARyxC
|
||||
Ptut+rOccdq8TawasZT7o7TnAGCCTPsAWCd5RTAXse65mSGM6oxjQsppZxtYz4Kx
|
||||
TLySl91Vok2nMH1jBoWPx9WoFyU6zCkmOkq7zWvEU23FR1Quq0QB0fmHrVMNQqxA
|
||||
LKkUuUFTa1wmVuYaKtcz5LAaj+GmgrY3kTIWg81ybPF/Hibkz0zWh54SLBc64Ha6
|
||||
zfNXDffq3vVVo04DKZW8Erd9nZL0F/w2u6+MpTl5CrAYzSZyDcNiIGbSrYpYYRt9
|
||||
JagGAn/ZZD93SnOiMcRCsNfNq4LisSf6AUMSA3F9Rw8iuxas5lDBf073pEy2vWjG
|
||||
VSp+Vio/oEY=
|
||||
-----END CERTIFICATE-----
|
||||
38
apps/backend/doc/微信开放平台 小程序 配置.md
Normal file
38
apps/backend/doc/微信开放平台 小程序 配置.md
Normal file
@@ -0,0 +1,38 @@
|
||||
## 开放平台证书
|
||||
开放平台证书编号:06e9682660ce742bb45ef278ae941af0
|
||||
证书已下载。
|
||||
|
||||
编号 密钥类型 密钥明文
|
||||
901b24f9af7b1421b80ebd5df9094141 对称密钥 D1fK6Zib6UOG10bM4WWhjsbMImCNXz7Mxq/0oRREGmw=
|
||||
59347d0e3cd661af9e90f7def5b6ca00 非对称密钥 -----BEGIN PUBLIC KEY----- MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzMrQ8iGGop0JEvcS/dZu sVUgjHeHqtds506Ftq/0WylTGKG9yByrY6eSKHttkBs/IyqG6JU9TtmskvVBam9B BmwOVyHkXATlXwEIkhyShu459p0dQzKwFaiygtj/fvxzWvurXVf1UcbIYVP1u7T0 E6yqatUkhmeaIyBsuw7vp7yYxpJoxsR2t70cDOTnfmHzl47FqSmq8xQRl/Cyw4FJ +1NS3i3cQBYkxjxGJ8Q3lhd5xd8IjwzDf+v/UgEhmoduZDAr/HkOwcrh3ihCCkLC DsfqV6qJrtqHhq5PHog3TqaaF6I9vk04FQTHn07XzIMciBDauYdD/Mq3B+ZBOKX4 gQIDAQAB -----END PUBLIC KEY-----
|
||||
|
||||
|
||||
|
||||
## 安全管理配置
|
||||
https://developers.weixin.qq.com/miniprogram/dev/framework/share_signature.html
|
||||
|
||||
## 消息推送配置
|
||||
(暂未处理)
|
||||
|
||||
填写的URL需要正确响应微信发送的Token验证,填写说明请阅读消息推送服务器配置指南。https://developers.weixin.qq.com/miniprogram/dev/framework/server-ability/message-push.html
|
||||
|
||||
|
||||
|
||||
URL(服务器地址):
|
||||
https://push.langlangzhuoqiu.cn
|
||||
|
||||
Token(令牌):
|
||||
nCmGmzINYfaqf5jDGKdkDO2AJ9C0VWNe
|
||||
|
||||
EncodingAESKey(消息加密密钥):
|
||||
2SZwWe90vG121o1l7RRMbtGt8GNvA1Juf727a3m7nZX
|
||||
|
||||
消息加密方式:
|
||||
安全模式 (消息包为纯密文,需要加密和解密。)
|
||||
|
||||
数据格式:
|
||||
JSON
|
||||
|
||||
|
||||
## 业务域名
|
||||
https://api.langlangzhuoqiu.cn
|
||||
@@ -1,10 +1,34 @@
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-15 | Prompt: 让 FastAPI 成功启动 | 补全运行依赖(fastapi/uvicorn/psycopg2-binary/python-dotenv),使后端可通过 uv run uvicorn 启动
|
||||
# - 风险:依赖版本变更可能影响其他 workspace 成员;验证:uv sync --all-packages && uv run uvicorn app.main:app
|
||||
|
||||
[project]
|
||||
name = "zqyy-backend"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10"
|
||||
# CHANGE 2026-02-15 | intent: 补全后端运行依赖,原先仅声明 neozqyy-shared 导致 uvicorn/fastapi 缺失无法启动
|
||||
# assumptions: 版本下限与 tech.md 记录的核心依赖一致;uvicorn[standard] 包含 uvloop/httptools 等性能依赖
|
||||
dependencies = [
|
||||
"neozqyy-shared",
|
||||
"fastapi>=0.115",
|
||||
"uvicorn[standard]>=0.34",
|
||||
"psycopg2-binary>=2.9",
|
||||
"python-dotenv>=1.0",
|
||||
"python-jose[cryptography]>=3.3",
|
||||
"bcrypt>=4.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
neozqyy-shared = { workspace = true }
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"hypothesis>=6.100",
|
||||
"httpx>=0.27",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
|
||||
62
apps/backend/tests/test_auth_dependencies.py
Normal file
62
apps/backend/tests/test_auth_dependencies.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
FastAPI 依赖注入 get_current_user 单元测试。
|
||||
|
||||
通过 FastAPI TestClient 验证 Authorization header 处理。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.auth.jwt import create_access_token, create_refresh_token
|
||||
|
||||
# 构造一个最小 FastAPI 应用用于测试依赖注入
|
||||
_test_app = FastAPI()
|
||||
|
||||
|
||||
@_test_app.get("/protected")
async def protected_route(user: CurrentUser = Depends(get_current_user)):
    # Minimal protected endpoint: echoes the identity resolved by
    # get_current_user so tests can assert on the decoded token claims.
    return {"user_id": user.user_id, "site_id": user.site_id}
|
||||
|
||||
|
||||
client = TestClient(_test_app)
|
||||
|
||||
|
||||
class TestGetCurrentUser:
    """Exercise the get_current_user dependency through a real TestClient."""

    def test_valid_access_token(self):
        # A token issued by create_access_token must authenticate and
        # round-trip its user_id / site_id claims.
        token = create_access_token(user_id=10, site_id=100)
        resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["user_id"] == 10
        assert data["site_id"] == 100

    def test_missing_auth_header_returns_401(self):
        """A missing Authorization header is rejected (401 or 403)."""
        resp = client.get("/protected")
        assert resp.status_code in (401, 403)

    def test_invalid_token_returns_401(self):
        # A structurally invalid JWT must not authenticate.
        resp = client.get(
            "/protected", headers={"Authorization": "Bearer invalid.token.here"}
        )
        assert resp.status_code == 401

    def test_refresh_token_rejected(self):
        """A refresh token must not grant access to protected endpoints."""
        token = create_refresh_token(user_id=1, site_id=1)
        resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 401

    def test_current_user_is_frozen_dataclass(self):
        """CurrentUser is immutable (frozen dataclass)."""
        user = CurrentUser(user_id=1, site_id=2)
        assert user.user_id == 1
        assert user.site_id == 2
        with pytest.raises(AttributeError):
            user.user_id = 99  # type: ignore[misc]
|
||||
147
apps/backend/tests/test_auth_jwt.py
Normal file
147
apps/backend/tests/test_auth_jwt.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
JWT 认证模块单元测试。
|
||||
|
||||
覆盖:令牌生成、验证、过期、类型校验、密码哈希、依赖注入。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from jose import jwt as jose_jwt
|
||||
|
||||
# 测试前设置 JWT_SECRET_KEY,避免空密钥
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
create_token_pair,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from app import config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 密码哈希
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPasswordHashing:
    """hash_password / verify_password round-trip behaviour."""

    def test_hash_and_verify(self):
        # The hash produced for a password must verify against that password.
        raw = "my_secure_password"
        hashed = hash_password(raw)
        assert verify_password(raw, hashed)

    def test_wrong_password_rejected(self):
        # A different password must not verify against the stored hash.
        hashed = hash_password("correct")
        assert not verify_password("wrong", hashed)

    def test_hash_is_not_plaintext(self):
        # Hashing must never return the plaintext itself.
        raw = "plaintext123"
        hashed = hash_password(raw)
        assert hashed != raw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌生成与解码
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenCreation:
    """Claims carried by freshly issued access / refresh tokens."""

    def test_access_token_contains_expected_fields(self):
        token = create_access_token(user_id=1, site_id=100)
        payload = decode_token(token)
        assert payload["sub"] == "1"
        assert payload["site_id"] == 100
        assert payload["type"] == "access"
        assert "exp" in payload

    def test_refresh_token_contains_expected_fields(self):
        token = create_refresh_token(user_id=2, site_id=200)
        payload = decode_token(token)
        assert payload["sub"] == "2"
        assert payload["site_id"] == 200
        assert payload["type"] == "refresh"
        assert "exp" in payload

    def test_token_pair_returns_both_tokens(self):
        pair = create_token_pair(user_id=3, site_id=300)
        assert "access_token" in pair
        assert "refresh_token" in pair
        assert pair["token_type"] == "bearer"

        # The two tokens must carry different "type" claims.
        access_payload = decode_token(pair["access_token"])
        refresh_payload = decode_token(pair["refresh_token"])
        assert access_payload["type"] == "access"
        assert refresh_payload["type"] == "refresh"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌类型校验
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenTypeValidation:
    """Type-specific decoders accept only their own token type."""

    def test_decode_access_token_rejects_refresh(self):
        """The access decoder rejects a refresh token."""
        token = create_refresh_token(user_id=1, site_id=1)
        with pytest.raises(Exception):
            decode_access_token(token)

    def test_decode_refresh_token_rejects_access(self):
        """The refresh decoder rejects an access token."""
        token = create_access_token(user_id=1, site_id=1)
        with pytest.raises(Exception):
            decode_refresh_token(token)

    def test_decode_access_token_accepts_access(self):
        token = create_access_token(user_id=5, site_id=50)
        payload = decode_access_token(token)
        assert payload["sub"] == "5"
        assert payload["site_id"] == 50

    def test_decode_refresh_token_accepts_refresh(self):
        token = create_refresh_token(user_id=6, site_id=60)
        payload = decode_refresh_token(token)
        assert payload["sub"] == "6"
        assert payload["site_id"] == 60
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌过期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenExpiry:
    """Expired tokens must not decode."""

    def test_expired_token_rejected(self):
        """Manually craft an already-expired token and verify decoding fails."""
        payload = {
            "sub": "1",
            "site_id": 1,
            "type": "access",
            "exp": int(time.time()) - 10,  # expired 10 seconds ago
        }
        token = jose_jwt.encode(
            payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM
        )
        with pytest.raises(Exception):
            decode_token(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 无效令牌
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInvalidToken:
    """Malformed or wrongly-signed tokens must not decode."""

    def test_garbage_token_rejected(self):
        with pytest.raises(Exception):
            decode_token("not.a.valid.jwt")

    def test_wrong_secret_rejected(self):
        """A token signed with a different secret must be rejected."""
        payload = {"sub": "1", "site_id": 1, "type": "access", "exp": int(time.time()) + 3600}
        token = jose_jwt.encode(payload, "wrong-secret", algorithm="HS256")
        with pytest.raises(Exception):
            decode_token(token)
|
||||
137
apps/backend/tests/test_auth_properties.py
Normal file
137
apps/backend/tests/test_auth_properties.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
认证模块属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证认证系统的通用正确性属性:
|
||||
- Property 2: 无效凭据始终被拒绝
|
||||
- Property 3: 有效 JWT 令牌授权访问
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-property-tests")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.auth.jwt import create_access_token
|
||||
from app.main import app
|
||||
from app.routers.auth import router
|
||||
|
||||
# 确保路由已挂载
|
||||
if router not in [r for r in app.routes]:
|
||||
app.include_router(router)
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 用户名策略:1~64 字符的可打印字符串(排除控制字符)
|
||||
_username_st = st.text(
|
||||
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
|
||||
min_size=1,
|
||||
max_size=64,
|
||||
)
|
||||
|
||||
# 密码策略:1~128 字符的可打印字符串
|
||||
_password_st = st.text(
|
||||
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
|
||||
min_size=1,
|
||||
max_size=128,
|
||||
)
|
||||
|
||||
# user_id 策略:正整数
|
||||
_user_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
# site_id 策略:正整数
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**63 - 1)
|
||||
|
||||
|
||||
def _mock_db_returning(row):
    """Build a mock get_connection result whose cursor context yields *row* from fetchone()."""
    fake_cursor = MagicMock()
    fake_cursor.fetchone.return_value = row

    fake_conn = MagicMock()
    cursor_cm = fake_conn.cursor.return_value
    cursor_cm.__enter__ = lambda _self: fake_cursor
    cursor_cm.__exit__ = MagicMock(return_value=False)
    return fake_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 2: 无效凭据始终被拒绝
|
||||
# **Validates: Requirements 1.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=100)
@given(username=_username_st, password=_password_st)
@patch("app.routers.auth.get_connection")
def test_invalid_credentials_always_rejected(mock_get_conn, username, password):
    """
    Property 2: invalid credentials are always rejected.

    For any username/password combination, when the user does not exist in
    the database (fetchone returns None), the login endpoint must always
    respond with HTTP 401.
    """
    # Mock the database to return None — user not found.
    mock_get_conn.return_value = _mock_db_returning(None)

    resp = client.post(
        "/api/auth/login",
        json={"username": username, "password": password},
    )
    assert resp.status_code == 401, (
        f"期望 401,实际 {resp.status_code},username={username!r}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 3: 有效 JWT 令牌授权访问
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
|
||||
def _run_async(coro):
    """Drive *coro* to completion on a private event loop from sync code.

    Avoids the DeprecationWarning emitted by get_event_loop() in newer
    Python versions; the loop is always closed afterwards.
    """
    private_loop = asyncio.new_event_loop()
    try:
        result = private_loop.run_until_complete(coro)
    finally:
        private_loop.close()
    return result
|
||||
|
||||
|
||||
@settings(max_examples=100)
@given(user_id=_user_id_st, site_id=_site_id_st)
def test_valid_jwt_grants_access(user_id, site_id):
    """
    Property 3: a valid JWT grants access.

    For any user_id and site_id, an unexpired access_token issued by the
    system must be parsed by the get_current_user dependency into a
    CurrentUser whose user_id and site_id match those it was issued with.
    """
    # Issue a valid access token.
    token = create_access_token(user_id=user_id, site_id=site_id)

    # Call the dependency directly to validate token parsing.
    credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
    result = _run_async(get_current_user(credentials))

    assert isinstance(result, CurrentUser)
    assert result.user_id == user_id, (
        f"user_id 不匹配:期望 {user_id},实际 {result.user_id}"
    )
    assert result.site_id == site_id, (
        f"site_id 不匹配:期望 {site_id},实际 {result.site_id}"
    )
|
||||
167
apps/backend/tests/test_auth_router.py
Normal file
167
apps/backend/tests/test_auth_router.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
认证路由单元测试。
|
||||
|
||||
覆盖:登录成功/失败、刷新令牌、账号禁用等场景。
|
||||
通过 mock 数据库连接避免依赖真实 PostgreSQL。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
hash_password,
|
||||
)
|
||||
from app.main import app
|
||||
from app.routers.auth import router
|
||||
|
||||
# 注册路由到 app(测试时确保路由已挂载)
|
||||
if router not in [r for r in app.routes]:
|
||||
app.include_router(router)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# 测试用固定数据
|
||||
_TEST_PASSWORD = "correct_password"
|
||||
_TEST_HASH = hash_password(_TEST_PASSWORD)
|
||||
_TEST_USER_ROW = (1, _TEST_HASH, 100, True) # id, password_hash, site_id, is_active
|
||||
_DISABLED_USER_ROW = (2, _TEST_HASH, 200, False)
|
||||
|
||||
|
||||
def _mock_db_returning(row):
    """Return a mocked DB connection; its cursor context manager's fetchone() yields *row*."""
    stub_cursor = MagicMock()
    stub_cursor.fetchone.return_value = row

    stub_conn = MagicMock()
    ctx = stub_conn.cursor.return_value
    ctx.__enter__ = lambda _unused: stub_cursor
    ctx.__exit__ = MagicMock(return_value=False)
    return stub_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/login
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLogin:
    """POST /api/auth/login behaviour against a mocked database."""

    @patch("app.routers.auth.get_connection")
    def test_login_success(self, mock_get_conn):
        # Known user + correct password → 200 with a full token pair.
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": _TEST_PASSWORD},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"

        # The access_token payload carries the expected user_id and site_id.
        payload = decode_access_token(data["access_token"])
        assert payload["sub"] == "1"
        assert payload["site_id"] == 100

    @patch("app.routers.auth.get_connection")
    def test_login_user_not_found(self, mock_get_conn):
        """Unknown user yields 401."""
        mock_get_conn.return_value = _mock_db_returning(None)

        resp = client.post(
            "/api/auth/login",
            json={"username": "nonexistent", "password": "whatever"},
        )
        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_wrong_password(self, mock_get_conn):
        """Wrong password yields 401."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": "wrong_password"},
        )
        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_disabled_account(self, mock_get_conn):
        """A disabled account yields 401."""
        mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "disabled_user", "password": _TEST_PASSWORD},
        )
        assert resp.status_code == 401
        assert "禁用" in resp.json()["detail"]

    def test_login_missing_username(self):
        """A missing username field yields 422 (validation error)."""
        resp = client.post("/api/auth/login", json={"password": "test"})
        assert resp.status_code == 422

    def test_login_empty_password(self):
        """An empty password yields 422."""
        resp = client.post(
            "/api/auth/login", json={"username": "admin", "password": ""}
        )
        assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRefresh:
    """POST /api/auth/refresh behaviour."""

    def test_refresh_success(self):
        """A valid refresh_token is exchanged for a new access_token."""
        refresh = create_refresh_token(user_id=5, site_id=50)

        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": refresh}
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        # The refresh_token is returned unchanged.
        assert data["refresh_token"] == refresh
        assert data["token_type"] == "bearer"

        # The new access_token carries the correct claims.
        payload = decode_access_token(data["access_token"])
        assert payload["sub"] == "5"
        assert payload["site_id"] == 50

    def test_refresh_with_invalid_token(self):
        """An invalid token yields 401."""
        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": "garbage.token.here"}
        )
        assert resp.status_code == 401
        assert "无效的刷新令牌" in resp.json()["detail"]

    def test_refresh_with_access_token_rejected(self):
        """Refreshing with an access_token must be rejected."""
        from app.auth.jwt import create_access_token

        access = create_access_token(user_id=1, site_id=1)
        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": access}
        )
        assert resp.status_code == 401

    def test_refresh_missing_token(self):
        """A missing refresh_token field yields 422."""
        resp = client.post("/api/auth/refresh", json={})
        assert resp.status_code == 422
||||
259
apps/backend/tests/test_cli_builder.py
Normal file
259
apps/backend/tests/test_cli_builder.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLIBuilder 单元测试
|
||||
|
||||
覆盖:7 种 Flow、3 种处理模式、时间窗口、store_id 自动注入、extra_args 等。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
|
||||
|
||||
|
||||
@pytest.fixture
def builder() -> CLIBuilder:
    # A fresh CLIBuilder per test — presumably stateless; verify if tests
    # start sharing builder state.
    return CLIBuilder()
|
||||
|
||||
|
||||
ETL_PATH = "/fake/etl/project"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 基本命令结构
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBasicCommand:
    """Overall shape of the generated command line."""

    def test_minimal_command(self, builder: CLIBuilder):
        """Minimal config yields python -m cli.main --pipeline ... --processing-mode ..."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"])
        cmd = builder.build_command(config, ETL_PATH)
        assert cmd[:3] == ["python", "-m", "cli.main"]
        assert "--pipeline" in cmd
        assert "--processing-mode" in cmd

    def test_custom_python_executable(self, builder: CLIBuilder):
        # The caller can override the interpreter used in the command.
        config = TaskConfigSchema(tasks=["ODS_MEMBER"])
        cmd = builder.build_command(config, ETL_PATH, python_executable="python3")
        assert cmd[0] == "python3"

    def test_tasks_joined_by_comma(self, builder: CLIBuilder):
        # Multiple tasks collapse into one comma-joined --tasks value.
        config = TaskConfigSchema(tasks=["ODS_MEMBER", "ODS_PAYMENT", "ODS_REFUND"])
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--tasks")
        assert cmd[idx + 1] == "ODS_MEMBER,ODS_PAYMENT,ODS_REFUND"

    def test_empty_tasks_no_tasks_arg(self, builder: CLIBuilder):
        """An empty task list must not produce a --tasks argument."""
        config = TaskConfigSchema(tasks=[])
        cmd = builder.build_command(config, ETL_PATH)
        assert "--tasks" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7 种 Flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlows:
    """Every supported pipeline flow is accepted and forwarded verbatim."""

    @pytest.mark.parametrize("flow_id", sorted(VALID_FLOWS))
    def test_all_flows_accepted(self, builder: CLIBuilder, flow_id: str):
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], pipeline=flow_id)
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--pipeline")
        assert cmd[idx + 1] == flow_id

    def test_default_flow_is_api_ods_dwd(self, builder: CLIBuilder):
        # When no pipeline is specified, the default is api_ods_dwd.
        config = TaskConfigSchema(tasks=["ODS_MEMBER"])
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--pipeline")
        assert cmd[idx + 1] == "api_ods_dwd"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3 种处理模式
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessingModes:
    """Processing-mode handling, including the --fetch-before-verify flag."""

    @pytest.mark.parametrize("mode", sorted(VALID_PROCESSING_MODES))
    def test_all_modes_accepted(self, builder: CLIBuilder, mode: str):
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], processing_mode=mode)
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--processing-mode")
        assert cmd[idx + 1] == mode

    def test_fetch_before_verify_only_in_verify_mode(self, builder: CLIBuilder):
        """--fetch-before-verify is only emitted in verify_only mode."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            processing_mode="verify_only",
            fetch_before_verify=True,
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--fetch-before-verify" in cmd

    def test_fetch_before_verify_ignored_in_increment_mode(self, builder: CLIBuilder):
        """In increment_only mode, fetch_before_verify=True must not emit the flag."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            processing_mode="increment_only",
            fetch_before_verify=True,
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--fetch-before-verify" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 时间窗口
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTimeWindow:
    """Time-window argument generation for lookback vs custom modes, plus splitting."""

    def test_lookback_mode_generates_lookback_args(self, builder: CLIBuilder):
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            window_mode="lookback",
            lookback_hours=48,
            overlap_seconds=1200,
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx_lb = cmd.index("--lookback-hours")
        assert cmd[idx_lb + 1] == "48"
        idx_ol = cmd.index("--overlap-seconds")
        assert cmd[idx_ol + 1] == "1200"
        # Lookback mode must not emit --window-start / --window-end.
        assert "--window-start" not in cmd
        assert "--window-end" not in cmd

    def test_custom_mode_generates_window_args(self, builder: CLIBuilder):
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            window_mode="custom",
            window_start="2026-01-01",
            window_end="2026-01-31",
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx_s = cmd.index("--window-start")
        assert cmd[idx_s + 1] == "2026-01-01"
        idx_e = cmd.index("--window-end")
        assert cmd[idx_e + 1] == "2026-01-31"
        # Custom mode must not emit --lookback-hours / --overlap-seconds.
        assert "--lookback-hours" not in cmd
        assert "--overlap-seconds" not in cmd

    def test_window_split_with_days(self, builder: CLIBuilder):
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            window_split="day",
            window_split_days=10,
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--window-split")
        assert cmd[idx + 1] == "day"
        idx_d = cmd.index("--window-split-days")
        assert cmd[idx_d + 1] == "10"

    def test_window_split_none_not_generated(self, builder: CLIBuilder):
        """window_split='none' must not emit a --window-split argument."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], window_split="none")
        cmd = builder.build_command(config, ETL_PATH)
        assert "--window-split" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# store_id 自动注入
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStoreId:
    """store_id auto-injection into the command line."""

    def test_store_id_injected(self, builder: CLIBuilder):
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=42)
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--store-id")
        assert cmd[idx + 1] == "42"

    def test_store_id_none_not_generated(self, builder: CLIBuilder):
        # A None store_id must not produce a --store-id argument.
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=None)
        cmd = builder.build_command(config, ETL_PATH)
        assert "--store-id" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dry_run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDryRun:
    """The --dry-run flag mirrors the dry_run config field."""

    def test_dry_run_flag(self, builder: CLIBuilder):
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=True)
        cmd = builder.build_command(config, ETL_PATH)
        assert "--dry-run" in cmd

    def test_no_dry_run_flag(self, builder: CLIBuilder):
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=False)
        cmd = builder.build_command(config, ETL_PATH)
        assert "--dry-run" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extra_args
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtraArgs:
    """Whitelisted extra_args handling: value args, bool flags, unknown keys."""

    def test_supported_value_arg(self, builder: CLIBuilder):
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"pg_dsn": "postgresql://localhost/test"},
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--pg-dsn")
        assert cmd[idx + 1] == "postgresql://localhost/test"

    def test_supported_bool_arg(self, builder: CLIBuilder):
        # True booleans become bare flags.
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"force_window_override": True},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--force-window-override" in cmd

    def test_unsupported_arg_ignored(self, builder: CLIBuilder):
        """Keys not in CLI_SUPPORTED_ARGS must be ignored."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"unknown_param": "value"},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--unknown-param" not in cmd

    def test_none_value_ignored(self, builder: CLIBuilder):
        # A None value must not produce an argument.
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"pg_dsn": None},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--pg-dsn" not in cmd

    def test_false_bool_arg_not_generated(self, builder: CLIBuilder):
        # A False boolean must not produce a flag.
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"force_window_override": False},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--force-window-override" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_command_string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildCommandString:
    """build_command_string: human-readable rendering of the CLI command."""

    def test_returns_string(self, builder: CLIBuilder):
        cfg = TaskConfigSchema(tasks=["ODS_MEMBER"])
        rendered = builder.build_command_string(cfg, ETL_PATH)
        assert isinstance(rendered, str)
        assert "python -m cli.main" in rendered

    def test_quotes_args_with_spaces(self, builder: CLIBuilder):
        cfg = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"pg_dsn": "host=localhost dbname=test"},
        )
        rendered = builder.build_command_string(cfg, ETL_PATH)
        # Values containing whitespace must be wrapped in quotes.
        assert '"host=localhost dbname=test"' in rendered
|
||||
94
apps/backend/tests/test_database.py
Normal file
94
apps/backend/tests/test_database.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
数据库连接模块单元测试。
|
||||
|
||||
覆盖:ETL 只读连接的创建、RLS site_id 设置、只读模式、异常处理。
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from app.database import get_etl_readonly_connection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_etl_readonly_connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEtlReadonlyConnection:
    """ETL read-only connection: connect params, read-only setup, RLS isolation."""

    @patch("app.database.psycopg2.connect")
    def test_sets_readonly_and_site_id(self, mock_connect):
        """After connecting, SET read_only then SET LOCAL site_id run, in that order."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        # Wire the cursor's context-manager protocol by hand so
        # `with conn.cursor() as cur:` yields mock_cursor.
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn

        conn = get_etl_readonly_connection(site_id=42)

        # autocommit disabled: both SET statements live in one transaction.
        assert mock_conn.autocommit is False

        # Exactly these two SET statements, in this order.
        executed = [c.args[0] for c in mock_cursor.execute.call_args_list]
        assert "SET default_transaction_read_only = on" in executed[0]
        assert "SET LOCAL app.current_site_id" in executed[1]

        # site_id is passed as a bound parameter (SQL-injection safe) and is
        # stringified by the implementation.
        site_id_call = mock_cursor.execute.call_args_list[1]
        assert site_id_call.args[1] == ("42",)

        # Setup transaction committed; the raw connection is returned.
        mock_conn.commit.assert_called_once()
        assert conn is mock_conn

    @patch("app.database.psycopg2.connect")
    def test_accepts_string_site_id(self, mock_connect):
        """A string site_id must work as well as an int."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn

        get_etl_readonly_connection(site_id="99")

        site_id_call = mock_cursor.execute.call_args_list[1]
        assert site_id_call.args[1] == ("99",)

    @patch("app.database.psycopg2.connect")
    def test_closes_connection_on_setup_error(self, mock_connect):
        """If a SET statement fails, the connection is closed and the error re-raised."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_cursor.execute.side_effect = Exception("SET failed")
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn

        with pytest.raises(Exception, match="SET failed"):
            get_etl_readonly_connection(site_id=1)

        # No leaked connection on failure.
        mock_conn.close.assert_called_once()

    @patch("app.database.psycopg2.connect")
    def test_uses_etl_config_params(self, mock_connect):
        """The ETL_DB_* settings must drive the connection."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn

        get_etl_readonly_connection(site_id=1)

        connect_kwargs = mock_connect.call_args.kwargs
        # Default ETL database name is etl_feiqiu.
        assert connect_kwargs["dbname"] == "etl_feiqiu"
|
||||
139
apps/backend/tests/test_db_viewer_properties.py
Normal file
139
apps/backend/tests/test_db_viewer_properties.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证数据库查看器的通用正确性属性:
|
||||
- Property 17: SQL 写操作拦截
|
||||
- Property 18: SQL 查询结果行数限制
|
||||
|
||||
测试策略:
|
||||
- Property 17: 生成包含写操作关键词(随机大小写混合)的 SQL 字符串,
|
||||
验证 _WRITE_KEYWORDS 正则表达式能匹配到
|
||||
- Property 18: 生成随机长度的行列表(可能超过 1000 行),
|
||||
验证截取前 _MAX_ROWS 个元素后长度 <= 1000
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-db-viewer-properties")
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.routers.db_viewer import _WRITE_KEYWORDS, _MAX_ROWS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Write-operation keywords the router must block.
_WRITE_OPS = ["INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE"]

# SQL prefix/suffix filler: short text of letters/digits/symbols only (NUL
# excluded), so it adds no extra structure around the embedded keyword.
_sql_filler_st = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "S"),
        blacklist_characters="\x00",
    ),
    min_size=0,
    max_size=50,
)

# A write keyword with randomised per-character casing: pick a keyword, draw
# one boolean per character, then upper/lower-case each character by its flag.
_random_case_keyword_st = st.sampled_from(_WRITE_OPS).flatmap(
    lambda kw: st.tuples(
        st.just(kw),
        st.lists(
            st.booleans(),
            min_size=len(kw),
            max_size=len(kw),
        ),
    ).map(
        lambda pair: "".join(
            c.upper() if flag else c.lower()
            for c, flag in zip(pair[0], pair[1])
        )
    )
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 17: SQL 写操作拦截
|
||||
# **Validates: Requirements 7.5**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=200)
@given(
    prefix=_sql_filler_st,
    keyword=_random_case_keyword_st,
    suffix=_sql_filler_st,
)
def test_write_keywords_always_detected(prefix, keyword, suffix):
    """Property 17: SQL write-operation interception.

    Any SQL containing INSERT / UPDATE / DELETE / DROP / TRUNCATE
    (case-insensitive) must be matched by the _WRITE_KEYWORDS regex.

    Strategy: embed one randomly-cased write keyword between random filler,
    separated by spaces so the \\b word boundary can match.
    """
    sql = " ".join((prefix, keyword, suffix))

    hit = _WRITE_KEYWORDS.search(sql)
    assert hit is not None, (
        f"正则表达式未能匹配到写操作关键词:sql={sql!r}, keyword={keyword!r}"
    )
    # Whatever the regex captured must itself be a write keyword.
    assert hit.group(1).upper() in _WRITE_OPS, (
        f"匹配到的关键词 '{hit.group(1)}' 不在写操作列表中"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 18: SQL 查询结果行数限制
|
||||
# **Validates: Requirements 7.4**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# A single simulated database row: 1-5 cells of int / short str / None.
_row_st = st.lists(
    st.one_of(st.integers(), st.text(max_size=20), st.none()),
    min_size=1,
    max_size=5,
)

# Row-list strategy: 0-3000 rows, deliberately able to exceed _MAX_ROWS.
_rows_st = st.lists(_row_st, min_size=0, max_size=3000)
|
||||
|
||||
|
||||
@settings(max_examples=200)
@given(rows=_rows_st)
def test_row_count_never_exceeds_max(rows):
    """Property 18: query result row cap.

    Taking the first _MAX_ROWS elements of an arbitrary row list must never
    yield more than 1000 rows — the same guarantee cur.fetchmany(_MAX_ROWS)
    gives at the database cursor.
    """
    truncated = rows[:_MAX_ROWS]

    assert len(truncated) <= _MAX_ROWS, (
        f"截取后行数 {len(truncated)} 超过上限 {_MAX_ROWS}"
    )

    if len(rows) <= _MAX_ROWS:
        # Short inputs survive untouched.
        assert len(truncated) == len(rows), (
            f"原始行数 {len(rows)} <= {_MAX_ROWS},截取后应保留全部,"
            f"实际 {len(truncated)}"
        )
    else:
        # Long inputs are cut to exactly the cap.
        assert len(truncated) == _MAX_ROWS, (
            f"原始行数 {len(rows)} > {_MAX_ROWS},截取后应恰好为 {_MAX_ROWS},"
            f"实际 {len(truncated)}"
        )
|
||||
321
apps/backend/tests/test_db_viewer_router.py
Normal file
321
apps/backend/tests/test_db_viewer_router.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器路由单元测试
|
||||
|
||||
覆盖 4 个端点:
|
||||
- GET /api/db/schemas
|
||||
- GET /api/db/schemas/{name}/tables
|
||||
- GET /api/db/tables/{schema}/{table}/columns
|
||||
- POST /api/db/query
|
||||
|
||||
通过 mock 绕过数据库连接,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from psycopg2 import errors as pg_errors
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
# Fixed authenticated identity injected into every request in this module.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    # FastAPI dependency override: always authenticate as _TEST_USER.
    return _TEST_USER


# Install the auth override app-wide, then build one shared test client.
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Patch target for the router's DB connection factory.
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
|
||||
|
||||
|
||||
def _make_mock_conn(rows, description=None):
    """Build a mock DB connection whose cursor yields *rows* and *description*."""
    cursor = MagicMock()
    cursor.fetchall.return_value = rows
    cursor.fetchmany.return_value = rows
    cursor.description = description

    connection = MagicMock()
    # Hand-wire the context-manager protocol so `with conn.cursor() as cur:`
    # hands back our cursor and never suppresses exceptions.
    ctx = connection.cursor.return_value
    ctx.__enter__ = lambda self: cursor
    ctx.__exit__ = MagicMock(return_value=False)
    return connection, cursor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListSchemas:
    """GET /api/db/schemas"""

    @patch(_MOCK_CONN)
    def test_returns_schema_list(self, mock_get_conn):
        conn, _cur = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas")
        assert resp.status_code == 200
        payload = resp.json()
        assert len(payload) == 3
        assert payload[0]["name"] == "dwd"
        assert payload[2]["name"] == "ods"

        # The authenticated user's site_id must be forwarded to the factory,
        # and the connection closed afterwards.
        mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
        conn.close.assert_called_once()

    @patch(_MOCK_CONN)
    def test_empty_schemas(self, mock_get_conn):
        conn, _cur = _make_mock_conn([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas/{name}/tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListTables:
    """GET /api/db/schemas/{name}/tables"""

    @patch(_MOCK_CONN)
    def test_returns_tables_with_row_count(self, mock_get_conn):
        fixtures = [("dim_member", 1500), ("fact_order", 32000)]
        conn, _cur = _make_mock_conn(fixtures)
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas/dwd/tables")
        assert resp.status_code == 200
        tables = resp.json()
        assert len(tables) == 2
        # Each fixture row maps 1:1 onto a (name, row_count) payload entry.
        for got, (name, count) in zip(tables, fixtures):
            assert got["name"] == name
            assert got["row_count"] == count

    @patch(_MOCK_CONN)
    def test_null_row_count(self, mock_get_conn):
        """pg_stat_user_tables may lack statistics; row_count is then None."""
        conn, _cur = _make_mock_conn([("new_table", None)])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas/ods/tables")
        assert resp.status_code == 200
        assert resp.json()[0]["row_count"] is None

    @patch(_MOCK_CONN)
    def test_empty_schema(self, mock_get_conn):
        conn, _cur = _make_mock_conn([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas/empty_schema/tables")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/tables/{schema}/{table}/columns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListColumns:
    """GET /api/db/tables/{schema}/{table}/columns"""

    @patch(_MOCK_CONN)
    def test_returns_column_definitions(self, mock_get_conn):
        conn, _cur = _make_mock_conn([
            ("id", "bigint", "NO", None),
            ("name", "character varying", "YES", None),
            ("created_at", "timestamp with time zone", "NO", "now()"),
        ])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/tables/dwd/dim_member/columns")
        assert resp.status_code == 200
        cols = resp.json()
        assert len(cols) == 3

        # information_schema's "NO"/"YES" is translated into a boolean.
        first = cols[0]
        assert first["name"] == "id"
        assert first["data_type"] == "bigint"
        assert first["is_nullable"] is False
        assert first["column_default"] is None

        assert cols[1]["is_nullable"] is True
        assert cols[2]["column_default"] == "now()"

    @patch(_MOCK_CONN)
    def test_empty_table(self, mock_get_conn):
        conn, _cur = _make_mock_conn([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/tables/dwd/nonexistent/columns")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/db/query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecuteQuery:
    """POST /api/db/query: read-only SQL execution with guard rails."""

    @patch(_MOCK_CONN)
    def test_successful_select(self, mock_get_conn):
        description = [("id",), ("name",)]
        conn, cur = _make_mock_conn(
            [(1, "Alice"), (2, "Bob")],
            description=description,
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
        assert resp.status_code == 200
        data = resp.json()
        # Column names come from cursor.description; rows become JSON arrays.
        assert data["columns"] == ["id", "name"]
        assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
        assert data["row_count"] == 2

    @patch(_MOCK_CONN)
    def test_empty_result(self, mock_get_conn):
        description = [("id",)]
        conn, cur = _make_mock_conn([], description=description)
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id"]
        assert data["rows"] == []
        assert data["row_count"] == 0

    # -- write-operation interception --

    @pytest.mark.parametrize("keyword", [
        "INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
        "insert", "update", "delete", "drop", "truncate",
        "Insert", "Update", "Delete", "Drop", "Truncate",
    ])
    def test_blocks_write_operations(self, keyword):
        # Every casing of a write keyword is rejected before touching the DB.
        resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
        assert resp.status_code == 400
        assert "只读" in resp.json()["detail"] or "禁止" in resp.json()["detail"]

    def test_blocks_mixed_case_write(self):
        resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
        assert resp.status_code == 400

    def test_blocks_write_in_subquery(self):
        """A write keyword anywhere in the SQL must be intercepted."""
        resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
        assert resp.status_code == 400

    # -- empty SQL --

    def test_empty_sql(self):
        resp = client.post("/api/db/query", json={"sql": ""})
        assert resp.status_code == 400

    def test_whitespace_only_sql(self):
        resp = client.post("/api/db/query", json={"sql": "   "})
        assert resp.status_code == 400

    # -- SQL syntax errors --

    @patch(_MOCK_CONN)
    def test_sql_syntax_error(self, mock_get_conn):
        conn = MagicMock()
        mock_cur = MagicMock()
        # First execute (statement_timeout setup) succeeds; second raises.
        mock_cur.execute.side_effect = [None, Exception("syntax error at or near \"SELEC\"")]
        mock_cur.description = None
        conn.cursor.return_value.__enter__ = lambda s: mock_cur
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
        assert resp.status_code == 400
        assert "SQL 执行错误" in resp.json()["detail"]

    # -- query timeout --

    @patch(_MOCK_CONN)
    def test_query_timeout(self, mock_get_conn):
        conn = MagicMock()
        mock_cur = MagicMock()
        # QueryCanceled is what psycopg2 raises when statement_timeout fires.
        mock_cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
        mock_cur.description = None
        conn.cursor.return_value.__enter__ = lambda s: mock_cur
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
        assert resp.status_code == 408
        assert "超时" in resp.json()["detail"]

    # -- row cap --

    @patch(_MOCK_CONN)
    def test_row_limit(self, mock_get_conn):
        """fetchmany must be invoked with the 1000-row cap."""
        description = [("id",)]
        conn, cur = _make_mock_conn(
            [(i,) for i in range(1000)],
            description=description,
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
        assert resp.status_code == 200
        cur.fetchmany.assert_called_once_with(1000)

    # -- timeout setup --

    @patch(_MOCK_CONN)
    def test_sets_statement_timeout(self, mock_get_conn):
        """statement_timeout must be set before the user query runs."""
        description = [("id",)]
        conn, cur = _make_mock_conn([(1,)], description=description)
        mock_get_conn.return_value = conn

        client.post("/api/db/query", json={"sql": "SELECT 1"})

        # The very first execute call is the timeout setup.
        first_call = cur.execute.call_args_list[0]
        assert "statement_timeout" in first_call[0][0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDbViewerAuth:
    """All db-viewer endpoints must reject unauthenticated requests."""

    def test_requires_auth(self):
        """With the auth override removed, every endpoint returns 401/403."""
        saved = app.dependency_overrides.pop(get_current_user, None)
        try:
            cases = [
                ("GET", "/api/db/schemas"),
                ("GET", "/api/db/schemas/dwd/tables"),
                ("GET", "/api/db/tables/dwd/dim_member/columns"),
                ("POST", "/api/db/query"),
            ]
            for verb, path in cases:
                kwargs = {"json": {"sql": "SELECT 1"}} if verb == "POST" else {}
                resp = client.request(verb, path, **kwargs)
                assert resp.status_code in (401, 403), f"{verb} {path} 应需要认证"
        finally:
            # Restore the override so sibling tests keep their fake identity.
            if saved:
                app.dependency_overrides[get_current_user] = saved
|
||||
191
apps/backend/tests/test_env_config_properties.py
Normal file
191
apps/backend/tests/test_env_config_properties.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证环境配置管理的通用正确性属性:
|
||||
- Property 15: .env 解析与敏感值掩码
|
||||
- Property 16: .env 写入往返一致性
|
||||
|
||||
测试策略:
|
||||
- Property 15: 生成随机 .env 内容(含敏感和非敏感键),验证 _parse_env + _is_sensitive 对敏感值掩码
|
||||
- Property 16: 生成随机键值对,序列化为 .env 格式后再解析,验证往返一致性
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-env-config-properties")
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.routers.env_config import _parse_env, _is_sensitive, _MASK, _SENSITIVE_KEYWORDS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Legal env-var key name: letter or underscore first, then letters, digits
# or underscores.
_key_start_char = st.sampled_from(
    list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_")
)
_key_rest_char = st.sampled_from(
    list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_")
)

_env_key_st = st.builds(
    lambda first, rest: first + rest,
    first=_key_start_char,
    rest=st.text(alphabet=_key_rest_char, min_size=0, max_size=30),
)

# Values: printable text without newlines; quotes and '#' are excluded to
# avoid parsing ambiguity (quoting rules / inline comments).
_env_value_st = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S"),
        blacklist_characters='\n\r"\'#',
    ),
    min_size=0,
    max_size=50,
)

# Sensitive keys: a sensitive keyword embedded inside a random key name.
_sensitive_keyword_st = st.sampled_from(list(_SENSITIVE_KEYWORDS))

# NOTE(review): _sensitive_key_st is not referenced by any test in this
# module (_safe_sensitive_key_st below is used instead); kept for reuse.
# FIX: the original filter read `len(k) > 0 and k[0].isalpha() or k[0] == "_"`,
# which parses as `(len(k) > 0 and k[0].isalpha()) or k[0] == "_"` and lets
# the `or` arm escape the length guard. The intended grouping is explicit now;
# generated keys are never empty (the keyword is non-empty), so the accepted
# set is unchanged.
_sensitive_key_st = st.builds(
    lambda prefix, kw, suffix: prefix + kw + suffix,
    prefix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
    kw=_sensitive_keyword_st,
    suffix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
).filter(lambda k: len(k) > 0 and (k[0].isalpha() or k[0] == "_"))

# Sensitive keys guaranteed to start with a letter: PREFIX + "_" + keyword.
_safe_sensitive_key_st = st.builds(
    lambda prefix, kw: prefix + "_" + kw,
    prefix=st.sampled_from(["DB", "API", "ETL", "APP", "MY"]),
    kw=_sensitive_keyword_st,
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 15: .env 解析与敏感值掩码
|
||||
# **Validates: Requirements 6.1, 6.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(
    sensitive_keys=st.lists(_safe_sensitive_key_st, min_size=1, max_size=5, unique=True),
    sensitive_values=st.lists(
        st.text(min_size=1, max_size=30, alphabet=st.characters(
            whitelist_categories=("L", "N"),
        )),
        min_size=1, max_size=5,
    ),
    normal_keys=st.lists(_env_key_st, min_size=1, max_size=5, unique=True),
    normal_values=st.lists(_env_value_st, min_size=1, max_size=5),
)
def test_sensitive_values_masked(sensitive_keys, sensitive_values, normal_keys, normal_values):
    """Property 15: .env parsing with sensitive-value masking.

    For .env content containing sensitive keys (PASSWORD, TOKEN, SECRET, DSN),
    the key/value list the API returns must replace those values with the mask
    and never expose the originals.
    """
    # Keep sensitive and normal key sets disjoint.
    normal_keys_filtered = [k for k in normal_keys if k not in sensitive_keys]
    assume(len(normal_keys_filtered) >= 1)

    # Key and value lists may differ in length: repeat each value list until
    # it is long enough, then truncate to align one value per key.
    s_vals = (sensitive_values * ((len(sensitive_keys) // len(sensitive_values)) + 1))[:len(sensitive_keys)]
    n_vals = (normal_values * ((len(normal_keys_filtered) // len(normal_values)) + 1))[:len(normal_keys_filtered)]

    # Assemble the .env content, sensitive entries first.
    lines = []
    for k, v in zip(sensitive_keys, s_vals):
        lines.append(f"{k}={v}")
    for k, v in zip(normal_keys_filtered, n_vals):
        lines.append(f"{k}={v}")
    env_content = "\n".join(lines) + "\n"

    # Parse, keeping only real key=value entries (not comments/blank lines).
    parsed = _parse_env(env_content)
    entries = [line for line in parsed if line["type"] == "entry"]

    # Reproduce the GET endpoint's masking logic.
    masked_entries = {}
    for entry in entries:
        if _is_sensitive(entry["key"]):
            masked_entries[entry["key"]] = _MASK
        else:
            masked_entries[entry["key"]] = entry["value"]

    # Sensitive keys must be masked and must not leak their original values.
    for k, v in zip(sensitive_keys, s_vals):
        assert k in masked_entries, f"敏感键 {k} 应出现在解析结果中"
        assert masked_entries[k] == _MASK, (
            f"敏感键 {k} 的值应为掩码 '{_MASK}',实际为 '{masked_entries[k]}'"
        )
        # Original secret never survives masking.
        assert masked_entries[k] != v, (
            f"敏感键 {k} 的原始值 '{v}' 不应出现在掩码结果中"
        )

    # Non-sensitive keys must come back verbatim.
    for k, v in zip(normal_keys_filtered, n_vals):
        if not _is_sensitive(k):
            assert k in masked_entries, f"普通键 {k} 应出现在解析结果中"
            assert masked_entries[k] == v, (
                f"普通键 {k} 的值应为 '{v}',实际为 '{masked_entries[k]}'"
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 16: .env 写入往返一致性
|
||||
# **Validates: Requirements 6.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(
    entries=st.lists(
        st.tuples(_env_key_st, _env_value_st),
        min_size=1,
        max_size=10,
        unique_by=lambda t: t[0],  # unique keys
    ),
)
def test_env_write_read_round_trip(entries):
    """Property 16: .env write/read round-trip consistency.

    A valid key/value set (no comments or blank lines), once serialized into
    .env format and parsed back, must yield an equivalent key/value set.
    """
    # Values are stripped on parse, so normalise before comparing; also
    # defensively drop empty keys (the strategy already forbids them).
    normalised = [(key, value.strip()) for key, value in entries if key]

    assume(normalised)

    # Simulated write: one KEY=value line per entry.
    serialized = "".join(f"{key}={value}\n" for key, value in normalised)

    round_tripped = {
        item["key"]: item["value"]
        for item in _parse_env(serialized)
        if item["type"] == "entry"
    }

    # Every written pair must survive the round trip unchanged.
    for key, value in normalised:
        assert key in round_tripped, (
            f"键 '{key}' 应出现在解析结果中,实际键集合: {list(round_tripped.keys())}"
        )
        assert round_tripped[key] == value, (
            f"键 '{key}' 的值不一致:写入='{value}',解析='{round_tripped[key]}'"
        )

    # No extra or missing keys.
    assert len(round_tripped) == len(normalised), (
        f"解析结果键数量 {len(round_tripped)} 应等于写入数量 {len(normalised)}"
    )
|
||||
291
apps/backend/tests/test_env_config_router.py
Normal file
291
apps/backend/tests/test_env_config_router.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置路由单元测试
|
||||
|
||||
覆盖 3 个端点:GET / PUT / GET /export
|
||||
通过 mock 绕过文件 I/O,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
# Fixed authenticated identity for every request in this module.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    # FastAPI dependency override: always authenticate as _TEST_USER.
    return _TEST_USER


# Install the auth override app-wide, then build one shared test client.
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Simulated .env file content: mixes plain keys, sensitive keys
# (PASSWORD / SECRET / DSN), comments and a blank line.
_SAMPLE_ENV = """\
# 数据库配置
DB_HOST=localhost
DB_PORT=5432
DB_PASSWORD=super_secret_123
JWT_SECRET_KEY=my-jwt-secret

# ETL 配置
ETL_DB_DSN=postgresql://user:pass@host/db
TIMEZONE=Asia/Shanghai
"""

# Patch target for the router's Path object pointing at the .env file.
_MOCK_ENV_PATH = "app.routers.env_config._ENV_PATH"
|
||||
|
||||
|
||||
def _mock_path(content: str | None = _SAMPLE_ENV, exists: bool = True):
    """Build a mock Path whose exists()/read_text() return canned values."""
    fake_path = MagicMock()
    fake_path.exists.return_value = exists
    # Only stub read_text when there is content; a non-existent file's
    # read_text is never expected to be called.
    if content is not None:
        fake_path.read_text.return_value = content
    return fake_path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/env-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetEnvConfig:
    """GET /api/env-config: entries returned with sensitive values masked."""

    @patch(_MOCK_ENV_PATH)
    def test_returns_entries_with_masked_sensitive(self, mock_path_obj):
        # NOTE(review): reassigning __class__ to the MagicMock type looks like
        # a leftover isinstance workaround — no visible effect here; confirm
        # whether it can be removed.
        mock_path_obj.__class__ = type(MagicMock())
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = _SAMPLE_ENV

        resp = client.get("/api/env-config")
        assert resp.status_code == 200
        data = resp.json()
        entries = {e["key"]: e["value"] for e in data["entries"]}

        # Non-sensitive values come back verbatim.
        assert entries["DB_HOST"] == "localhost"
        assert entries["DB_PORT"] == "5432"
        assert entries["TIMEZONE"] == "Asia/Shanghai"

        # Sensitive values are masked.
        assert entries["DB_PASSWORD"] == "****"
        assert entries["JWT_SECRET_KEY"] == "****"
        assert entries["ETL_DB_DSN"] == "****"

    @patch(_MOCK_ENV_PATH)
    def test_file_not_found(self, mock_path_obj):
        # Missing .env file -> 404.
        mock_path_obj.exists.return_value = False

        resp = client.get("/api/env-config")
        assert resp.status_code == 404

    @patch(_MOCK_ENV_PATH)
    def test_empty_file(self, mock_path_obj):
        # An existing but empty file yields an empty entry list.
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = ""

        resp = client.get("/api/env-config")
        assert resp.status_code == 200
        assert resp.json()["entries"] == []

    @patch(_MOCK_ENV_PATH)
    def test_comments_and_blank_lines_excluded(self, mock_path_obj):
        # Comments and blank lines are parsed but not reported as entries.
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "# comment\n\nKEY=val\n"

        resp = client.get("/api/env-config")
        assert resp.status_code == 200
        entries = resp.json()["entries"]
        assert len(entries) == 1
        assert entries[0]["key"] == "KEY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/env-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateEnvConfig:
|
||||
|
||||
    @patch(_MOCK_ENV_PATH)
    def test_update_existing_key(self, mock_path_obj):
        # Existing keys are rewritten in place with their new values.
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "DB_HOST=localhost\nDB_PORT=5432\n"

        resp = client.put("/api/env-config", json={
            "entries": [
                {"key": "DB_HOST", "value": "192.168.1.1"},
                {"key": "DB_PORT", "value": "5433"},
            ]
        })
        assert resp.status_code == 200

        # Inspect the text handed to Path.write_text.
        written = mock_path_obj.write_text.call_args[0][0]
        assert "DB_HOST=192.168.1.1" in written
        assert "DB_PORT=5433" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_add_new_key(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_HOST=localhost\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_HOST", "value": "localhost"},
|
||||
{"key": "NEW_KEY", "value": "new_value"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "NEW_KEY=new_value" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_masked_value_preserves_original(self, mock_path_obj):
|
||||
"""掩码值(****)不应覆盖原始敏感值。"""
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_PASSWORD=real_secret\nDB_HOST=localhost\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_PASSWORD", "value": "****"},
|
||||
{"key": "DB_HOST", "value": "newhost"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
# 原始密码应保留
|
||||
assert "DB_PASSWORD=real_secret" in written
|
||||
assert "DB_HOST=newhost" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_preserves_comments(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "# 注释行\nDB_HOST=localhost\n\n# 另一个注释\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "DB_HOST", "value": "newhost"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "# 注释行" in written
|
||||
assert "# 另一个注释" in written
|
||||
|
||||
def test_invalid_key_format(self):
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "123BAD", "value": "val"}]
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_empty_key(self):
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "", "value": "val"}]
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_file_not_exists_creates_new(self, mock_path_obj):
|
||||
"""文件不存在时,应创建新文件。"""
|
||||
mock_path_obj.exists.return_value = False
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "NEW_KEY", "value": "value"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "NEW_KEY=value" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_update_sensitive_with_new_value(self, mock_path_obj):
|
||||
"""显式提供新密码时应更新。"""
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_PASSWORD=old_secret\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "DB_PASSWORD", "value": "new_secret"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "DB_PASSWORD=new_secret" in written
|
||||
|
||||
# 返回值中敏感键仍然掩码
|
||||
entries = {e["key"]: e["value"] for e in resp.json()["entries"]}
|
||||
assert entries["DB_PASSWORD"] == "****"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/env-config/export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExportEnvConfig:
    """GET /api/env-config/export behaviour."""

    @patch(_MOCK_ENV_PATH)
    def test_export_masks_sensitive(self, mock_path_obj):
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = _SAMPLE_ENV

        resp = client.get("/api/env-config/export")
        assert resp.status_code == 200
        # Served as a plain-text attachment download.
        assert resp.headers["content-type"].startswith("text/plain")
        assert "attachment" in resp.headers.get("content-disposition", "")

        body = resp.text
        # Non-sensitive values survive verbatim.
        assert "DB_HOST=localhost" in body
        assert "TIMEZONE=Asia/Shanghai" in body

        # Secret material never leaks; masked placeholders appear instead.
        assert "super_secret_123" not in body
        assert "my-jwt-secret" not in body
        assert "DB_PASSWORD=****" in body
        assert "JWT_SECRET_KEY=****" in body

    @patch(_MOCK_ENV_PATH)
    def test_export_preserves_comments(self, mock_path_obj):
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = _SAMPLE_ENV

        exported = client.get("/api/env-config/export").text
        assert "# 数据库配置" in exported
        assert "# ETL 配置" in exported

    @patch(_MOCK_ENV_PATH)
    def test_export_file_not_found(self, mock_path_obj):
        mock_path_obj.exists.return_value = False
        assert client.get("/api/env-config/export").status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnvConfigAuth:
    """All env-config endpoints must reject unauthenticated callers."""

    def test_requires_auth(self):
        """With the auth override removed, every endpoint answers 401/403."""
        # Temporarily drop the override, restoring it afterwards.
        original = app.dependency_overrides.pop(get_current_user, None)
        try:
            cases = [
                ("GET", "/api/env-config"),
                ("PUT", "/api/env-config"),
                ("GET", "/api/env-config/export"),
            ]
            for method, url in cases:
                resp = client.request(method, url)
                assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
        finally:
            if original:
                app.dependency_overrides[get_current_user] = original
|
||||
246
apps/backend/tests/test_etl_status_router.py
Normal file
246
apps/backend/tests/test_etl_status_router.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# -*- coding: utf-8 -*-
"""Unit tests for the ETL status router.

Covers two endpoints:
- GET /api/etl-status/cursors
- GET /api/etl-status/recent-runs

Database access is mocked out so only routing logic is exercised.
"""

import os

# Must be in place before any app module is imported.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from unittest.mock import patch, MagicMock

from fastapi.testclient import TestClient

from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app

# Fixed identity used by every request in this module.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Dependency override returning the fixed test user."""
    return _TEST_USER


app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Patch targets for the two connection factories used by the router.
_MOCK_ETL_CONN = "app.routers.etl_status.get_etl_readonly_connection"
_MOCK_APP_CONN = "app.routers.etl_status.get_connection"
|
||||
|
||||
|
||||
def _make_mock_conn(rows):
|
||||
"""构造 mock 数据库连接,cursor 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = rows
|
||||
mock_cur.fetchone.return_value = None
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cur
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/etl-status/cursors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListCursors:
    """GET /api/etl-status/cursors"""

    @patch(_MOCK_ETL_CONN)
    def test_returns_cursor_list(self, mock_get_conn):
        conn, cur = _make_mock_conn([
            ("ODS_FETCH_ORDERS", "2024-06-15 10:30:00+08", 1500),
            ("ODS_FETCH_MEMBERS", "2024-06-15 09:00:00+08", 800),
        ])
        # fetchone backs the EXISTS probe for the cursor table.
        cur.fetchone.return_value = (True,)
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200

        data = resp.json()
        assert len(data) == 2
        first, second = data
        assert first["task_code"] == "ODS_FETCH_ORDERS"
        assert first["last_fetch_time"] == "2024-06-15 10:30:00+08"
        assert first["record_count"] == 1500
        assert second["task_code"] == "ODS_FETCH_MEMBERS"

        # The site_id from the JWT must reach the connection factory.
        mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
        conn.close.assert_called_once()

    @patch(_MOCK_ETL_CONN)
    def test_table_not_exists_returns_empty(self, mock_get_conn):
        """A missing etl_admin.etl_cursor table yields an empty list."""
        conn, cur = _make_mock_conn([])
        cur.fetchone.return_value = (False,)
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch(_MOCK_ETL_CONN)
    def test_null_fields(self, mock_get_conn):
        """Cursor columns may be NULL for tasks that never ran."""
        conn, cur = _make_mock_conn([("ODS_FETCH_INVENTORY", None, None)])
        cur.fetchone.return_value = (True,)
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200

        row = resp.json()[0]
        assert row["task_code"] == "ODS_FETCH_INVENTORY"
        assert row["last_fetch_time"] is None
        assert row["record_count"] is None

    @patch(_MOCK_ETL_CONN)
    def test_empty_cursors(self, mock_get_conn):
        """Table exists but holds no rows."""
        conn, cur = _make_mock_conn([])
        cur.fetchone.return_value = (True,)
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/etl-status/recent-runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListRecentRuns:
    """GET /api/etl-status/recent-runs"""

    @patch(_MOCK_APP_CONN)
    def test_returns_recent_runs(self, mock_get_conn):
        conn, _cur = _make_mock_conn([
            (
                "a1b2c3d4-0000-0000-0000-000000000001",
                ["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"],
                "success",
                "2024-06-15 10:30:00+08",
                "2024-06-15 10:35:00+08",
                300000,
                0,
            ),
            (
                "a1b2c3d4-0000-0000-0000-000000000002",
                ["DWS_AGGREGATE"],
                "failed",
                "2024-06-15 09:00:00+08",
                "2024-06-15 09:01:00+08",
                60000,
                1,
            ),
        ])
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200

        runs = resp.json()
        assert len(runs) == 2

        ok_run, failed_run = runs
        assert ok_run["id"] == "a1b2c3d4-0000-0000-0000-000000000001"
        assert ok_run["task_codes"] == ["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"]
        assert ok_run["status"] == "success"
        assert ok_run["duration_ms"] == 300000
        assert ok_run["exit_code"] == 0

        assert failed_run["status"] == "failed"
        assert failed_run["exit_code"] == 1

        conn.close.assert_called_once()

    @patch(_MOCK_APP_CONN)
    def test_empty_runs(self, mock_get_conn):
        conn, _cur = _make_mock_conn([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch(_MOCK_APP_CONN)
    def test_null_optional_fields(self, mock_get_conn):
        """finished_at / duration_ms / exit_code stay None while running."""
        conn, _cur = _make_mock_conn([
            (
                "a1b2c3d4-0000-0000-0000-000000000003",
                ["ODS_FETCH_MEMBERS"],
                "running",
                "2024-06-15 11:00:00+08",
                None,
                None,
                None,
            ),
        ])
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200

        run = resp.json()[0]
        assert run["status"] == "running"
        assert run["finished_at"] is None
        assert run["duration_ms"] is None
        assert run["exit_code"] is None

    @patch(_MOCK_APP_CONN)
    def test_site_id_filter(self, mock_get_conn):
        """The query must be parameterised with the caller's site_id."""
        conn, cur = _make_mock_conn([])
        mock_get_conn.return_value = conn

        client.get("/api/etl-status/recent-runs")

        # Positional SQL parameters are (site_id, limit).
        params = cur.execute.call_args[0][1]
        assert params[0] == _TEST_USER.site_id
        assert params[1] == 50

    @patch(_MOCK_APP_CONN)
    def test_empty_task_codes(self, mock_get_conn):
        """A NULL task_codes column is normalised to an empty list."""
        conn, _cur = _make_mock_conn([
            (
                "a1b2c3d4-0000-0000-0000-000000000004",
                None,
                "pending",
                "2024-06-15 12:00:00+08",
                None,
                None,
                None,
            ),
        ])
        mock_get_conn.return_value = conn

        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200
        assert resp.json()[0]["task_codes"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEtlStatusAuth:
    """Both ETL status endpoints must reject unauthenticated callers."""

    def test_requires_auth(self):
        """With the auth override removed, endpoints answer 401/403."""
        original = app.dependency_overrides.pop(get_current_user, None)
        try:
            for url in ["/api/etl-status/cursors", "/api/etl-status/recent-runs"]:
                resp = client.get(url)
                assert resp.status_code in (401, 403), f"GET {url} 应需要认证"
        finally:
            if original:
                app.dependency_overrides[get_current_user] = original
|
||||
339
apps/backend/tests/test_execution_router.py
Normal file
339
apps/backend/tests/test_execution_router.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
"""Unit tests for the execution & queue router.

Covers 8 endpoints: run / queue CRUD / cancel / history / logs.
Database and service layers are mocked so only routing logic is exercised.
"""

import os

# Must be in place before any app module is imported.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from dataclasses import dataclass
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_queue import QueuedTask

# Fixed identity used by every request in this module.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Dependency override returning the fixed test user."""
    return _TEST_USER


app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Deterministic timestamp shared by the fixtures below.
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)

# Minimal valid TaskConfig payload.
_VALID_CONFIG = {
    "tasks": ["ODS_MEMBER"],
    "pipeline": "api_ods",
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunTask:
    """POST /api/execution/run"""

    @patch("app.routers.execution.task_executor")
    def test_run_returns_execution_id(self, mock_executor):
        mock_executor.execute = AsyncMock()

        resp = client.post("/api/execution/run", json=_VALID_CONFIG)
        assert resp.status_code == 200

        data = resp.json()
        assert "execution_id" in data
        assert data["message"] == "任务已提交执行"

    @patch("app.routers.execution.task_executor")
    def test_run_injects_store_id(self, mock_executor):
        """store_id comes from the JWT, overriding whatever the client sends."""
        mock_executor.execute = AsyncMock()
        payload = {**_VALID_CONFIG, "store_id": 999}  # client value must be ignored
        resp = client.post("/api/execution/run", json=payload)
        assert resp.status_code == 200

    def test_run_requires_auth(self):
        app.dependency_overrides.pop(get_current_user, None)
        try:
            resp = client.post("/api/execution/run", json=_VALID_CONFIG)
            assert resp.status_code in (401, 403)
        finally:
            app.dependency_overrides[get_current_user] = _override_auth

    def test_run_invalid_config_returns_422(self):
        """Missing the required `tasks` field → 422."""
        resp = client.post("/api/execution/run", json={"pipeline": "api_ods"})
        assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetQueue:
    """GET /api/execution/queue"""

    @patch("app.routers.execution.task_queue")
    def test_get_queue_returns_list(self, mock_queue):
        mock_queue.list_pending.return_value = [
            QueuedTask(
                id="task-1", site_id=100, config={"tasks": ["ODS_MEMBER"]},
                status="pending", position=1, created_at=_NOW,
            ),
        ]

        resp = client.get("/api/execution/queue")
        assert resp.status_code == 200

        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "task-1"
        assert data[0]["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_get_queue_empty(self, mock_queue):
        mock_queue.list_pending.return_value = []

        resp = client.get("/api/execution/queue")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch("app.routers.execution.task_queue")
    def test_get_queue_filters_by_site_id(self, mock_queue):
        """list_pending must receive the JWT's site_id."""
        mock_queue.list_pending.return_value = []
        client.get("/api/execution/queue")
        mock_queue.list_pending.assert_called_once_with(100)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnqueueTask:
    """POST /api/execution/queue

    The duplicated mock-connection scaffolding from the original tests is
    factored into `_conn_with_row` so each test states only its fixture row.
    """

    @staticmethod
    def _conn_with_row(row):
        """Build a mock connection whose cursor's fetchone returns *row*."""
        cur = MagicMock()
        cur.fetchone.return_value = row
        conn = MagicMock()
        # The router consumes the cursor as a context manager.
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cur)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        return conn

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_queue")
    def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
        mock_queue.enqueue.return_value = "new-task-id"
        # The endpoint re-reads the inserted row from the database.
        mock_get_conn.return_value = self._conn_with_row((
            "new-task-id", 100, '{"tasks": ["ODS_MEMBER"]}',
            "pending", 1, _NOW, None, None, None, None,
        ))

        resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
        assert resp.status_code == 201

        data = resp.json()
        assert data["id"] == "new-task-id"
        assert data["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_enqueue_calls_with_site_id(self, mock_queue):
        """enqueue must receive the JWT's site_id."""
        mock_queue.enqueue.return_value = "id-1"

        with patch("app.routers.execution.get_connection") as mock_get_conn:
            mock_get_conn.return_value = self._conn_with_row((
                "id-1", 100, '{"tasks": []}', "pending", 1,
                _NOW, None, None, None, None,
            ))
            client.post("/api/execution/queue", json=_VALID_CONFIG)

        # The second positional argument of enqueue is site_id.
        assert mock_queue.enqueue.call_args[0][1] == 100  # site_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/execution/queue/reorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReorderQueue:
    """PUT /api/execution/queue/reorder"""

    @patch("app.routers.execution.task_queue")
    def test_reorder_success(self, mock_queue):
        mock_queue.reorder.return_value = None

        resp = client.put(
            "/api/execution/queue/reorder",
            json={"task_id": "task-1", "new_position": 3},
        )
        assert resp.status_code == 200
        # Arguments: (task_id, new_position, site_id).
        mock_queue.reorder.assert_called_once_with("task-1", 3, 100)

    def test_reorder_missing_fields_returns_422(self):
        assert client.put("/api/execution/queue/reorder", json={}).status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/execution/queue/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteQueueTask:
    """DELETE /api/execution/queue/{id}"""

    @patch("app.routers.execution.task_queue")
    def test_delete_success(self, mock_queue):
        mock_queue.delete.return_value = True

        resp = client.delete("/api/execution/queue/task-1")
        assert resp.status_code == 200
        # Arguments: (task_id, site_id).
        mock_queue.delete.assert_called_once_with("task-1", 100)

    @patch("app.routers.execution.task_queue")
    def test_delete_nonexistent_returns_409(self, mock_queue):
        mock_queue.delete.return_value = False
        assert client.delete("/api/execution/queue/nonexistent").status_code == 409
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/{id}/cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCancelExecution:
    """POST /api/execution/{id}/cancel"""

    @patch("app.routers.execution.task_executor")
    def test_cancel_success(self, mock_executor):
        mock_executor.cancel = AsyncMock(return_value=True)

        resp = client.post("/api/execution/some-id/cancel")
        assert resp.status_code == 200
        assert "取消" in resp.json()["message"]

    @patch("app.routers.execution.task_executor")
    def test_cancel_not_found(self, mock_executor):
        mock_executor.cancel = AsyncMock(return_value=False)
        assert client.post("/api/execution/nonexistent/cancel").status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/history
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecutionHistory:
    """GET /api/execution/history

    The duplicated mock-connection scaffolding from the original tests is
    factored into `_conn_with_rows` so each test states only its fixture rows.
    """

    @staticmethod
    def _conn_with_rows(rows):
        """Build a mock connection/cursor pair whose fetchall returns *rows*."""
        cur = MagicMock()
        cur.fetchall.return_value = rows
        conn = MagicMock()
        # The router consumes the cursor as a context manager.
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cur)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        return conn, cur

    @patch("app.routers.execution.get_connection")
    def test_history_returns_list(self, mock_get_conn):
        conn, _cur = self._conn_with_rows([
            (
                "exec-1", 100, ["ODS_MEMBER"], "success",
                _NOW, _NOW, 0, 1234, "python -m cli.main", None,
            ),
        ])
        mock_get_conn.return_value = conn

        resp = client.get("/api/execution/history")
        assert resp.status_code == 200

        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "exec-1"
        assert data[0]["status"] == "success"

    @patch("app.routers.execution.get_connection")
    def test_history_respects_limit(self, mock_get_conn):
        conn, cur = self._conn_with_rows([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/execution/history?limit=10")
        assert resp.status_code == 200

        # Positional SQL parameters are (site_id, limit).
        assert cur.execute.call_args[0][1] == (100, 10)

    def test_history_limit_validation(self):
        """Out-of-range limit values are rejected with 422."""
        assert client.get("/api/execution/history?limit=0").status_code == 422
        assert client.get("/api/execution/history?limit=999").status_code == 422

    @patch("app.routers.execution.get_connection")
    def test_history_empty(self, mock_get_conn):
        conn, _cur = self._conn_with_rows([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/execution/history")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/{id}/logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecutionLogs:
    """GET /api/execution/{id}/logs

    The duplicated mock-connection scaffolding from the original tests is
    factored into `_conn_with_row` so each test states only its fixture row.
    """

    @staticmethod
    def _conn_with_row(row):
        """Build a mock connection whose cursor's fetchone returns *row*."""
        cur = MagicMock()
        cur.fetchone.return_value = row
        conn = MagicMock()
        # The router consumes the cursor as a context manager.
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cur)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        return conn

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_from_db(self, mock_executor, mock_get_conn):
        """Finished executions read their logs from the database."""
        mock_executor.is_running.return_value = False
        mock_get_conn.return_value = self._conn_with_row(
            ("stdout output", "stderr output")
        )

        resp = client.get("/api/execution/exec-1/logs")
        assert resp.status_code == 200

        data = resp.json()
        assert data["execution_id"] == "exec-1"
        assert data["output_log"] == "stdout output"
        assert data["error_log"] == "stderr output"

    @patch("app.routers.execution.task_executor")
    def test_logs_from_memory(self, mock_executor):
        """Running executions serve logs from the in-memory buffer."""
        mock_executor.is_running.return_value = True
        mock_executor.get_logs.return_value = ["line1", "line2"]

        resp = client.get("/api/execution/running-id/logs")
        assert resp.status_code == 200

        data = resp.json()
        assert data["execution_id"] == "running-id"
        assert "line1" in data["output_log"]
        assert "line2" in data["output_log"]

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_not_found(self, mock_executor, mock_get_conn):
        mock_executor.is_running.return_value = False
        mock_get_conn.return_value = self._conn_with_row(None)

        resp = client.get("/api/execution/nonexistent/logs")
        assert resp.status_code == 404
|
||||
510
apps/backend/tests/test_queue_properties.py
Normal file
510
apps/backend/tests/test_queue_properties.py
Normal file
@@ -0,0 +1,510 @@
|
||||
# -*- coding: utf-8 -*-
"""Property-based tests for queue management.

Uses hypothesis to verify general correctness properties:
- Property 8: queue CRUD invariants
- Property 9: dequeue ordering
- Property 10: reorder consistency
- Property 11: execution-history ordering and limits

Strategy:
- Properties 8-10 simulate queue state in memory and mock the database,
  exercising TaskQueue's core logic.
- Property 11 mocks database rows to check the history endpoint's
  ordering and limit behaviour.
"""

import os

# Must be in place before any app module is imported.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-queue-properties")

import json
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

from hypothesis import given, settings, assume
from hypothesis import strategies as st

from app.schemas.tasks import TaskConfigSchema
from app.services.task_queue import TaskQueue, QueuedTask


# ---------------------------------------------------------------------------
# Shared strategies
# ---------------------------------------------------------------------------

_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)

# Small fixed vocabulary of task codes.
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]

_simple_config_st = st.builds(
    TaskConfigSchema,
    tasks=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
    pipeline=st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd"]),
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# In-memory queue simulator — mocks the database interaction layer
# ---------------------------------------------------------------------------


class InMemoryQueueDB:
    """In-memory stand-in for the ``task_queue`` table.

    Each ``mock_*_connection`` method returns a MagicMock connection whose
    cursor translates the SQL issued by the corresponding ``TaskQueue``
    method into operations on ``self.rows``, so property tests can observe
    real state transitions without a database.
    """

    def __init__(self, site_id: int):
        self.site_id = site_id
        # Row storage: {task_id: {"id", "site_id", "config", "status", "position"}}
        self.rows: dict[str, dict] = {}

    @property
    def pending_tasks(self) -> list[dict]:
        """Pending rows ordered by ascending position."""
        return sorted(
            (r for r in self.rows.values() if r["status"] == "pending"),
            key=lambda r: r["position"],
        )

    @staticmethod
    def _connection_with(on_execute, **cursor_defaults):
        """Build a MagicMock connection whose cursor dispatches to *on_execute*.

        *on_execute* receives ``(cur, sql, params)`` and may set fetch results
        or mutate state. The cursor supports the context-manager protocol so it
        works both as ``conn.cursor()`` and ``with conn.cursor() as cur:``.
        Extra keyword arguments become initial cursor attributes (e.g. rowcount).
        """
        cur = MagicMock()
        cur.execute = MagicMock(
            side_effect=lambda sql, params=None: on_execute(cur, sql, params)
        )
        for attr, value in cursor_defaults.items():
            setattr(cur, attr, value)
        cur.__enter__ = MagicMock(return_value=cur)
        cur.__exit__ = MagicMock(return_value=False)
        conn = MagicMock()
        conn.cursor.return_value = cur
        return conn

    def mock_enqueue_connection(self):
        """Mock connection for ``enqueue``.

        enqueue issues two statements:
        1. SELECT COALESCE(MAX(position), 0) → current max position
        2. INSERT INTO task_queue → new row
        """
        # Snapshot the max position at connection-creation time, as before.
        max_pos = max((r["position"] for r in self.pending_tasks), default=0)

        def on_execute(cur, sql, params):
            if "MAX(position)" in sql:
                cur.fetchone.return_value = (max_pos,)
            elif "INSERT INTO task_queue" in sql:
                # Record the inserted row.
                task_id, site_id, config_json, new_pos = params
                self.rows[task_id] = {
                    "id": task_id,
                    "site_id": site_id,
                    "config": json.loads(config_json),
                    "status": "pending",
                    "position": new_pos,
                }

        return self._connection_with(on_execute)

    def mock_dequeue_connection(self):
        """Mock connection for ``dequeue``.

        dequeue issues two statements:
        1. SELECT ... ORDER BY position ASC LIMIT 1 FOR UPDATE → head task
        2. UPDATE ... SET status = 'running' → mark it running
        """
        pending = self.pending_tasks
        first = pending[0] if pending else None

        def on_execute(cur, sql, params):
            if "ORDER BY position ASC" in sql:
                if first:
                    cur.fetchone.return_value = (
                        first["id"], first["site_id"],
                        json.dumps(first["config"]),
                        first["status"], first["position"],
                        None, None, None, None, None,
                    )
                else:
                    cur.fetchone.return_value = None
            elif "SET status = 'running'" in sql:
                if first:
                    self.rows[first["id"]]["status"] = "running"

        return self._connection_with(on_execute)

    def mock_delete_connection(self, task_id: str):
        """Mock connection for ``delete``: only pending rows can be removed."""

        def on_execute(cur, sql, params):
            tid = params[0]
            if tid in self.rows and self.rows[tid]["status"] == "pending":
                del self.rows[tid]
                cur.rowcount = 1
            else:
                cur.rowcount = 0

        return self._connection_with(on_execute, rowcount=0)

    def mock_reorder_connection(self):
        """Mock connection for ``reorder``.

        reorder issues:
        1. SELECT id FROM task_queue WHERE ... ORDER BY position ASC
        2. repeated UPDATE task_queue SET position = %s WHERE id = %s
        """
        pending = self.pending_tasks

        def on_execute(cur, sql, params):
            if "SELECT id FROM task_queue" in sql:
                cur.fetchall.return_value = [(r["id"],) for r in pending]
            elif "UPDATE task_queue SET position" in sql:
                pos, tid = params
                if tid in self.rows:
                    self.rows[tid]["position"] = pos

        return self._connection_with(on_execute)

    def mock_list_pending_connection(self):
        """Mock connection for ``list_pending``: serves all pending rows."""
        pending = self.pending_tasks

        def on_execute(cur, sql, params):
            cur.fetchall.return_value = [
                (
                    r["id"], r["site_id"], json.dumps(r["config"]),
                    r["status"], r["position"],
                    None, None, None, None, None,
                )
                for r in pending
            ]

        return self._connection_with(on_execute)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 8: queue CRUD invariant
# **Validates: Requirements 4.1, 4.4**
# ---------------------------------------------------------------------------

@settings(max_examples=100)
@given(
    config=_simple_config_st,
    site_id=_site_id_st,
    initial_count=st.integers(min_value=0, max_value=5),
)
@patch("app.services.task_queue.get_connection")
def test_queue_crud_invariant(mock_get_conn, config, site_id, initial_count):
    """Property 8: queue CRUD invariant.

    Enqueuing one task grows the queue by exactly one and the new task is
    pending; deleting a pending task shrinks the queue by one and the task
    no longer appears in it.
    """
    tq = TaskQueue()
    store = InMemoryQueueDB(site_id)

    # Seed the queue with a few pre-existing pending tasks.
    for pos in range(1, initial_count + 1):
        tid = str(uuid.uuid4())
        store.rows[tid] = {
            "id": tid,
            "site_id": site_id,
            "config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
            "status": "pending",
            "position": pos,
        }

    before_count = len(store.pending_tasks)

    # --- enqueue ---
    mock_get_conn.return_value = store.mock_enqueue_connection()
    new_id = tq.enqueue(config, site_id)

    after_enqueue_count = len(store.pending_tasks)
    assert after_enqueue_count == before_count + 1, (
        f"入队后长度应 +1:期望 {before_count + 1},实际 {after_enqueue_count}"
    )
    assert new_id in store.rows, "新任务应存在于队列中"
    assert store.rows[new_id]["status"] == "pending", "新任务状态应为 pending"

    # --- delete the task that was just enqueued ---
    mock_get_conn.return_value = store.mock_delete_connection(new_id)
    deleted = tq.delete(new_id, site_id)

    after_delete_count = len(store.pending_tasks)
    assert deleted is True, "删除 pending 任务应返回 True"
    assert after_delete_count == before_count, (
        f"删除后长度应恢复:期望 {before_count},实际 {after_delete_count}"
    )
    assert new_id not in store.rows, "已删除任务不应出现在队列中"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 9: dequeue order
# **Validates: Requirements 4.2**
# ---------------------------------------------------------------------------

@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    num_tasks=st.integers(min_value=1, max_value=8),
    positions=st.data(),
)
@patch("app.services.task_queue.get_connection")
def test_queue_dequeue_order(mock_get_conn, site_id, num_tasks, positions):
    """Property 9: dequeue order.

    Given a queue with several pending tasks, dequeue must return the task
    with the smallest position value.
    """
    tq = TaskQueue()
    store = InMemoryQueueDB(site_id)

    # Draw a set of distinct position values.
    slots = positions.draw(
        st.lists(
            st.integers(min_value=1, max_value=1000),
            min_size=num_tasks,
            max_size=num_tasks,
            unique=True,
        )
    )

    # Populate the queue, cycling through the known task codes.
    for i, pos in enumerate(slots):
        tid = str(uuid.uuid4())
        store.rows[tid] = {
            "id": tid,
            "site_id": site_id,
            "config": {"tasks": [_task_codes[i % len(_task_codes)]], "pipeline": "api_ods"},
            "status": "pending",
            "position": pos,
        }

    # The task expected at the head of the queue.
    head = min(store.pending_tasks, key=lambda r: r["position"])

    mock_get_conn.return_value = store.mock_dequeue_connection()
    result = tq.dequeue(site_id)

    assert result is not None, "队列非空时 dequeue 不应返回 None"
    assert result.id == head["id"], (
        f"应返回 position 最小的任务:期望 id={head['id']} "
        f"(pos={head['position']}),实际 id={result.id}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 10: reorder consistency
# **Validates: Requirements 4.3**
# ---------------------------------------------------------------------------

@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    num_tasks=st.integers(min_value=2, max_value=6),
    data=st.data(),
)
@patch("app.services.task_queue.get_connection")
def test_queue_reorder_consistency(mock_get_conn, site_id, num_tasks, data):
    """Property 10: reorder consistency.

    After moving one task to a new position:
    - the moved task ends up at the target position (clamped to the valid range),
    - the remaining tasks keep their original relative order,
    - no task is lost from the queue.
    """
    tq = TaskQueue()
    store = InMemoryQueueDB(site_id)

    # Seed the queue with contiguous positions starting at 1.
    task_ids = [str(uuid.uuid4()) for _ in range(num_tasks)]
    for pos, tid in enumerate(task_ids, start=1):
        store.rows[tid] = {
            "id": tid,
            "site_id": site_id,
            "config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
            "status": "pending",
            "position": pos,
        }

    # Pick a task to move and a (possibly out-of-range) target position.
    move_idx = data.draw(st.integers(min_value=0, max_value=num_tasks - 1))
    move_task_id = task_ids[move_idx]
    new_position = data.draw(st.integers(min_value=1, max_value=num_tasks + 2))

    # Perform the reorder.
    mock_get_conn.return_value = store.mock_reorder_connection()
    tq.reorder(move_task_id, new_position, site_id)

    # No task may be lost.
    remaining_ids = {r["id"] for r in store.rows.values() if r["status"] == "pending"}
    assert remaining_ids == set(task_ids), "重排后不应丢失任何任务"

    # Positions must form the contiguous 1-based sequence 1..num_tasks.
    positions = sorted(r["position"] for r in store.pending_tasks)
    assert positions == list(range(1, num_tasks + 1)), (
        f"重排后 position 应为连续编号 1..{num_tasks},实际 {positions}"
    )

    # The moved task sits at the clamped target position
    # (reorder internally clamps new_position to [1, len(others) + 1]).
    clamped_pos = max(1, min(new_position, num_tasks))
    actual_pos = store.rows[move_task_id]["position"]
    assert actual_pos == clamped_pos, (
        f"被移动任务的 position 应为 {clamped_pos}(clamp 后),实际 {actual_pos}"
    )

    # The other tasks keep their original relative order.
    others_before = [tid for tid in task_ids if tid != move_task_id]
    others_after_ids = [
        r["id"]
        for r in sorted(
            (r for r in store.pending_tasks if r["id"] != move_task_id),
            key=lambda r: r["position"],
        )
    ]
    assert others_after_ids == others_before, (
        "其余任务的相对顺序应保持不变"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 11: 执行历史排序与限制
|
||||
# **Validates: Requirements 4.5, 8.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 导入 FastAPI 测试客户端
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def _make_history_rows(count: int, site_id: int) -> list[tuple]:
    """Generate *count* execution-history rows with sortable started_at values.

    Rows are one hour apart starting from 2024-01-01 UTC; each run lasts
    30 minutes and finished successfully.
    """
    origin = datetime(2024, 1, 1, tzinfo=timezone.utc)
    return [
        (
            str(uuid.uuid4()),                        # id
            site_id,                                  # site_id
            ["ODS_MEMBER"],                           # task_codes
            "success",                                # status
            origin + timedelta(hours=i),              # started_at
            origin + timedelta(hours=i, minutes=30),  # finished_at
            0,                                        # exit_code
            1800000,                                  # duration_ms
            "python -m cli.main",                     # command
            None,                                     # summary
        )
        for i in range(count)
    ]
|
||||
|
||||
|
||||
@settings(max_examples=100, deadline=None)
@given(
    site_id=_site_id_st,
    total_records=st.integers(min_value=0, max_value=30),
    limit=st.integers(min_value=1, max_value=200),
)
@patch("app.routers.execution.get_connection")
def test_execution_history_sort_and_limit(mock_get_conn, site_id, total_records, limit):
    """Property 11: history sorting and limiting.

    The API must return execution history ordered by started_at descending,
    and never more rows than the requested limit.
    """
    rows = _make_history_rows(total_records, site_id)

    # Emulate the DB: ORDER BY started_at DESC LIMIT <limit>.
    page = sorted(rows, key=lambda r: r[4], reverse=True)[:limit]

    # Mocked database connection serving that page.
    cur = MagicMock()
    cur.fetchall.return_value = page
    conn = MagicMock()
    conn.cursor.return_value.__enter__ = MagicMock(return_value=cur)
    conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
    mock_get_conn.return_value = conn

    # Override authentication.
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=site_id)

    try:
        client = TestClient(app)
        # The API constrains limit to [1, 200].
        clamped_limit = max(1, min(limit, 200))
        resp = client.get(f"/api/execution/history?limit={clamped_limit}")

        assert resp.status_code == 200
        data = resp.json()

        # 1) never more results than requested
        assert len(data) <= clamped_limit, (
            f"结果数量 {len(data)} 超过 limit {clamped_limit}"
        )

        # 2) never more results than exist
        assert len(data) <= total_records, (
            f"结果数量 {len(data)} 超过总记录数 {total_records}"
        )

        # 3) started_at is non-increasing
        if len(data) >= 2:
            for i in range(len(data) - 1):
                t1 = data[i]["started_at"]
                t2 = data[i + 1]["started_at"]
                assert t1 >= t2, (
                    f"结果未按 started_at 降序排列:data[{i}]={t1} < data[{i+1}]={t2}"
                )
    finally:
        app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
439
apps/backend/tests/test_schedule_properties.py
Normal file
439
apps/backend/tests/test_schedule_properties.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证调度管理的通用正确性属性:
|
||||
- Property 12: 调度任务 CRUD 往返
|
||||
- Property 13: 到期调度任务自动入队
|
||||
- Property 14: 调度任务启用/禁用状态
|
||||
|
||||
测试策略:
|
||||
- Property 12: 通过 mock 数据库,验证 POST 创建后 GET 返回的 schedule_config 与提交的一致
|
||||
- Property 13: 通过 mock 数据库返回到期任务,验证 check_and_enqueue 调用了 task_queue.enqueue
|
||||
- Property 14: 通过 mock 数据库,验证 toggle 端点的 next_run_at 行为
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-schedule-properties")
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.schemas.schedules import ScheduleConfigSchema
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.scheduler import Scheduler, calculate_next_run
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Shared strategies
# ---------------------------------------------------------------------------

_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)

_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]

_simple_task_config_st = st.fixed_dictionaries({
    "tasks": st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
    "pipeline": st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd", "api_full"]),
})

# Schedule-config strategies covering all five schedule types.
_schedule_type_st = st.sampled_from(["once", "interval", "daily", "weekly", "cron"])

_interval_unit_st = st.sampled_from(["minutes", "hours", "days"])

# "HH:MM" time strings.
_time_str_st = st.builds(
    lambda h, m: f"{h:02d}:{m:02d}",
    h=st.integers(min_value=0, max_value=23),
    m=st.integers(min_value=0, max_value=59),
)

# ISO weekday lists (1=Monday ... 7=Sunday).
_weekly_days_st = st.lists(
    st.integers(min_value=1, max_value=7),
    min_size=1, max_size=7, unique=True,
)

# Simple cron expressions of the form "minute hour * * *".
_cron_st = st.builds(
    lambda m, h: f"{m} {h} * * *",
    m=st.integers(min_value=0, max_value=59),
    h=st.integers(min_value=0, max_value=23),
)
|
||||
|
||||
|
||||
def _build_schedule_config(schedule_type, interval_value, interval_unit,
                           daily_time, weekly_days, weekly_time, cron_expression):
    """Assemble an enabled ScheduleConfigSchema for the given schedule_type."""
    return ScheduleConfigSchema(
        schedule_type=schedule_type,
        interval_value=interval_value,
        interval_unit=interval_unit,
        daily_time=daily_time,
        weekly_days=weekly_days,
        weekly_time=weekly_time,
        cron_expression=cron_expression,
        enabled=True,
    )


# Field strategies shared by both schedule-config strategies below.
_schedule_field_sts = dict(
    interval_value=st.integers(min_value=1, max_value=168),
    interval_unit=_interval_unit_st,
    daily_time=_time_str_st,
    weekly_days=_weekly_days_st,
    weekly_time=_time_str_st,
    cron_expression=_cron_st,
)

_schedule_config_st = st.builds(
    _build_schedule_config,
    schedule_type=_schedule_type_st,
    **_schedule_field_sts,
)

# Non-"once" schedule configs for Property 14
# (after re-enabling, next_run_at must be non-NULL).
_non_once_schedule_type_st = st.sampled_from(["interval", "daily", "weekly", "cron"])

_non_once_schedule_config_st = st.builds(
    _build_schedule_config,
    schedule_type=_non_once_schedule_type_st,
    **_schedule_field_sts,
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NOW = datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# 模拟数据库行的列顺序(与 _SELECT_COLS 对应,共 13 列)
|
||||
# id, site_id, name, task_codes, task_config, schedule_config,
|
||||
# enabled, last_run_at, next_run_at, run_count, last_status,
|
||||
# created_at, updated_at
|
||||
|
||||
|
||||
def _make_db_row(
|
||||
schedule_id: str,
|
||||
site_id: int,
|
||||
name: str,
|
||||
task_codes: list[str],
|
||||
task_config: dict,
|
||||
schedule_config: dict,
|
||||
enabled: bool = True,
|
||||
next_run_at: datetime | None = None,
|
||||
) -> tuple:
|
||||
"""构造模拟数据库行。"""
|
||||
return (
|
||||
schedule_id, site_id, name, task_codes,
|
||||
json.dumps(task_config) if isinstance(task_config, dict) else task_config,
|
||||
json.dumps(schedule_config) if isinstance(schedule_config, dict) else schedule_config,
|
||||
enabled, None, next_run_at, 0, None, _NOW, _NOW,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 12: schedule CRUD round trip
# **Validates: Requirements 5.1, 5.4**
# ---------------------------------------------------------------------------

@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    schedule_config=_schedule_config_st,
    task_config=_simple_task_config_st,
    name=st.text(min_size=1, max_size=50, alphabet=st.characters(
        whitelist_categories=("L", "N"), whitelist_characters="_- "
    )),
    task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_crud_round_trip(
    mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
):
    """Property 12: schedule CRUD round trip.

    For any valid ScheduleConfigSchema, creating a schedule and then listing
    schedules must return a configuration equivalent to the one submitted.
    """
    cfg_dict = schedule_config.model_dump()

    # The row the mocked database hands back after creation.
    created_row = _make_db_row(
        schedule_id="test-sched-id",
        site_id=site_id,
        name=name,
        task_codes=task_codes,
        task_config=task_config,
        schedule_config=cfg_dict,
        enabled=schedule_config.enabled,
        next_run_at=calculate_next_run(schedule_config, _NOW),
    )

    def _conn(**cursor_results):
        # Build a mock connection whose cursor returns the given fetch results.
        cur = MagicMock()
        for attr, value in cursor_results.items():
            getattr(cur, attr).return_value = value
        conn = MagicMock()
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cur)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        return conn

    # POST uses INSERT ... RETURNING (fetchone); GET lists rows (fetchall).
    mock_get_conn.side_effect = [
        _conn(fetchone=created_row),
        _conn(fetchall=[created_row]),
    ]

    # Override authentication.
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=site_id)

    try:
        client = TestClient(app)

        # Create the schedule.
        create_resp = client.post("/api/schedules", json={
            "name": name,
            "task_codes": task_codes,
            "task_config": task_config,
            "schedule_config": cfg_dict,
        })
        assert create_resp.status_code == 201, (
            f"创建应返回 201,实际 {create_resp.status_code}: {create_resp.text}"
        )
        created_data = create_resp.json()

        # List schedules.
        list_resp = client.get("/api/schedules")
        assert list_resp.status_code == 200
        list_data = list_resp.json()

        assert len(list_data) >= 1, "查询结果应至少包含刚创建的任务"

        # Locate the schedule we just created.
        found = next((s for s in list_data if s["id"] == created_data["id"]), None)
        assert found is not None, "查询结果应包含刚创建的任务"

        # Core check: schedule_config survives the round trip.
        returned_config = found["schedule_config"]
        for key in cfg_dict:
            assert returned_config[key] == cfg_dict[key], (
                f"schedule_config.{key} 不一致:"
                f"提交={cfg_dict[key]},返回={returned_config[key]}"
            )

        # task_config survives the round trip too.
        returned_task_config = found["task_config"]
        for key in task_config:
            assert returned_task_config[key] == task_config[key], (
                f"task_config.{key} 不一致:提交={task_config[key]},返回={returned_task_config[key]}"
            )

        # Basic fields.
        assert found["name"] == name
        assert found["task_codes"] == task_codes
        assert found["site_id"] == site_id
    finally:
        app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 13: due schedules auto-enqueue
# **Validates: Requirements 5.2**
# ---------------------------------------------------------------------------

@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    schedule_config=_schedule_config_st,
    task_config=_simple_task_config_st,
)
@patch("app.services.scheduler.task_queue")
@patch("app.services.scheduler.get_connection")
def test_due_schedule_auto_enqueue(
    mock_get_conn, mock_tq, site_id, schedule_config, task_config,
):
    """Property 13: due schedules are automatically enqueued.

    For a schedule with enabled=true and next_run_at in the past,
    check_and_enqueue must push its TaskConfig onto the execution queue.
    """
    sched = Scheduler()

    def _cm_cursor():
        # Cursor usable as a context manager.
        cur = MagicMock()
        cur.__enter__ = MagicMock(return_value=cur)
        cur.__exit__ = MagicMock(return_value=False)
        return cur

    # First cursor serves the SELECT of due schedules, second the UPDATE.
    select_cur = _cm_cursor()
    select_cur.fetchall.return_value = [
        ("due-task-001", site_id, json.dumps(task_config),
         json.dumps(schedule_config.model_dump())),
    ]
    update_cur = _cm_cursor()

    conn = MagicMock()
    conn.cursor.side_effect = [select_cur, update_cur]
    mock_get_conn.return_value = conn

    mock_tq.enqueue.return_value = "queue-id-123"

    # Run the due-schedule check.
    count = sched.check_and_enqueue()

    # Exactly one task must have been enqueued.
    assert count == 1, f"应有 1 个任务入队,实际 {count}"
    mock_tq.enqueue.assert_called_once()

    # Inspect the enqueue call.
    enqueued_config, enqueued_site_id = mock_tq.enqueue.call_args[0][:2]

    assert enqueued_site_id == site_id, (
        f"入队的 site_id 应为 {site_id},实际 {enqueued_site_id}"
    )

    # The TaskConfig must match the original configuration.
    assert isinstance(enqueued_config, TaskConfigSchema)
    assert enqueued_config.tasks == task_config["tasks"], (
        f"入队的 tasks 应为 {task_config['tasks']},实际 {enqueued_config.tasks}"
    )
    assert enqueued_config.pipeline == task_config["pipeline"], (
        f"入队的 pipeline 应为 {task_config['pipeline']},实际 {enqueued_config.pipeline}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 14: schedule enable/disable state
# **Validates: Requirements 5.3**
# ---------------------------------------------------------------------------

@settings(max_examples=100, deadline=None)
@given(
    site_id=_site_id_st,
    schedule_config=_non_once_schedule_config_st,
    task_config=_simple_task_config_st,
    name=st.text(min_size=1, max_size=30, alphabet=st.characters(
        whitelist_categories=("L", "N"), whitelist_characters="_- "
    )),
    task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_toggle_next_run(
    mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
):
    """Property 14: schedule enable/disable state.

    After disabling, next_run_at must be NULL; after re-enabling, it must be
    recomputed to a non-NULL value (for non-one-shot schedules).
    """
    cfg_dict = schedule_config.model_dump()
    cfg_json = json.dumps(cfg_dict)

    def _row(enabled, next_run_at):
        # DB row returned by UPDATE ... RETURNING for the toggled schedule.
        return _make_db_row(
            schedule_id="sched-toggle-1",
            site_id=site_id,
            name=name,
            task_codes=task_codes,
            task_config=task_config,
            schedule_config=cfg_dict,
            enabled=enabled,
            next_run_at=next_run_at,
        )

    def _toggle_conn(current_enabled, returned_row):
        # toggle first SELECTs the current state, then runs UPDATE ... RETURNING.
        cur = MagicMock()
        cur.fetchone.side_effect = [(current_enabled, cfg_json), returned_row]
        conn = MagicMock()
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cur)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        return conn

    # Step 1 disables (True → False, next_run_at NULL);
    # step 2 re-enables (False → True, next_run_at recomputed).
    mock_get_conn.side_effect = [
        _toggle_conn(True, _row(False, None)),
        _toggle_conn(False, _row(True, calculate_next_run(schedule_config, _NOW))),
    ]

    # Override authentication.
    app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=site_id)

    try:
        client = TestClient(app)

        # Disable.
        disable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
        assert disable_resp.status_code == 200, (
            f"禁用应返回 200,实际 {disable_resp.status_code}: {disable_resp.text}"
        )
        disable_data = disable_resp.json()

        assert disable_data["enabled"] is False, "禁用后 enabled 应为 False"
        assert disable_data["next_run_at"] is None, "禁用后 next_run_at 应为 NULL"

        # Enable.
        enable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
        assert enable_resp.status_code == 200, (
            f"启用应返回 200,实际 {enable_resp.status_code}: {enable_resp.text}"
        )
        enable_data = enable_resp.json()

        assert enable_data["enabled"] is True, "启用后 enabled 应为 True"
        assert enable_data["next_run_at"] is not None, (
            "非一次性调度启用后 next_run_at 应被重新计算为非 NULL 值"
        )
    finally:
        app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
384
apps/backend/tests/test_scheduler.py
Normal file
384
apps/backend/tests/test_scheduler.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Scheduler 单元测试
|
||||
|
||||
覆盖:
|
||||
- calculate_next_run:各种调度类型的下次执行时间计算
|
||||
- _parse_simple_cron:简单 cron 表达式解析
|
||||
- check_and_enqueue:到期检查与入队逻辑
|
||||
- start / stop:后台循环生命周期
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.schedules import ScheduleConfigSchema
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.scheduler import (
|
||||
Scheduler,
|
||||
calculate_next_run,
|
||||
_parse_simple_cron,
|
||||
_parse_time,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def sched() -> Scheduler:
    """A fresh Scheduler instance per test."""
    return Scheduler()


@pytest.fixture
def now() -> datetime:
    """Fixed reference instant: 2025-06-10 10:00:00 UTC (a Tuesday)."""
    return datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
    """MagicMock cursor with canned fetch results.

    Supports the context-manager protocol so it can be used either directly
    or via ``with conn.cursor() as cur:``.
    """
    cur = MagicMock()
    cur.fetchone.return_value = fetchone_val
    cur.fetchall.return_value = fetchall_val or []
    cur.rowcount = rowcount
    cur.__enter__ = MagicMock(return_value=cur)
    cur.__exit__ = MagicMock(return_value=False)
    return cur


def _mock_conn(cursor):
    """MagicMock connection whose ``cursor()`` yields *cursor*."""
    conn = MagicMock()
    conn.cursor.return_value = cursor
    return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# _parse_time
# ---------------------------------------------------------------------------

class TestParseTime:
    """_parse_time splits an "HH:MM" string into an (hour, minute) tuple."""

    def test_standard_format(self):
        assert _parse_time("04:00") == (4, 0)

    def test_with_minutes(self):
        assert _parse_time("23:45") == (23, 45)

    def test_midnight(self):
        assert _parse_time("00:00") == (0, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunOnce:
    """A 'once' schedule has no follow-up run."""

    def test_once_returns_none(self, now):
        config = ScheduleConfigSchema(schedule_type="once")
        assert calculate_next_run(config, now) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — interval
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunInterval:
    """Interval schedules add a fixed timedelta to the reference time."""

    @staticmethod
    def _interval_cfg(value, unit):
        # Shared builder so each test only states what varies.
        return ScheduleConfigSchema(
            schedule_type="interval", interval_value=value, interval_unit=unit,
        )

    def test_interval_minutes(self, now):
        cfg = self._interval_cfg(15, "minutes")
        assert calculate_next_run(cfg, now) == now + timedelta(minutes=15)

    def test_interval_hours(self, now):
        cfg = self._interval_cfg(2, "hours")
        assert calculate_next_run(cfg, now) == now + timedelta(hours=2)

    def test_interval_days(self, now):
        cfg = self._interval_cfg(3, "days")
        assert calculate_next_run(cfg, now) == now + timedelta(days=3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — daily
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunDaily:
    """Daily schedules: expected next run is the configured time on the
    following day (reference `now` is 2025-06-10 10:00 UTC)."""

    def test_daily_next_day(self, now):
        cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="04:00")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc
        )

    def test_daily_custom_time(self, now):
        cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="18:30")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 11, 18, 30, 0, tzinfo=timezone.utc
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — weekly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunWeekly:
    """Weekly schedules pick the nearest configured weekday.

    The expectations below imply a 1=Monday .. 7=Sunday numbering for
    `weekly_days`; `now` is Tuesday 2025-06-10 10:00 UTC.
    """

    def test_weekly_later_this_week(self, now):
        # weekly_days=[5] (Friday) -> three days ahead, same week.
        cfg = ScheduleConfigSchema(
            schedule_type="weekly", weekly_days=[5], weekly_time="08:00",
        )
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc
        )

    def test_weekly_next_week(self, now):
        # weekly_days=[1] (Monday) has already passed -> wraps to next Monday.
        cfg = ScheduleConfigSchema(
            schedule_type="weekly", weekly_days=[1], weekly_time="04:00",
        )
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 16, 4, 0, 0, tzinfo=timezone.utc
        )

    def test_weekly_multiple_days_picks_next(self, now):
        # Among [1, 4, 6] the soonest upcoming day is Thursday (4), two days ahead.
        cfg = ScheduleConfigSchema(
            schedule_type="weekly", weekly_days=[1, 4, 6], weekly_time="09:00",
        )
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 12, 9, 0, 0, tzinfo=timezone.utc
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — cron
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunCron:
    """Cron-type schedules delegate to the simple cron parser."""

    def test_cron_daily(self, now):
        # "30 4 * * *" -> 04:30 the next day (04:30 today is already past).
        cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="30 4 * * *")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 11, 4, 30, 0, tzinfo=timezone.utc
        )

    def test_cron_with_dow(self, now):
        # "0 8 * * 5" -> Fridays 08:00; now is Tuesday -> three days ahead.
        cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="0 8 * * 5")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_simple_cron
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseSimpleCron:
    """_parse_simple_cron: minute/hour/day-of-week handling and fallbacks."""

    def test_daily_cron(self, now):
        assert _parse_simple_cron("0 4 * * *", now) == datetime(
            2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc
        )

    def test_invalid_field_count_fallback(self, now):
        # Only three fields -> falls back to 04:00 on the next day.
        assert _parse_simple_cron("0 4 *", now) == datetime(
            2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc
        )

    def test_wildcard_hour_minute(self, now):
        # Wildcards degrade to hour=0 / minute=0 -> midnight tomorrow.
        assert _parse_simple_cron("* * * * *", now) == datetime(
            2025, 6, 11, 0, 0, 0, tzinfo=timezone.utc
        )

    def test_dow_sunday(self, now):
        # "0 6 * * 0" = Sundays 06:00; now is Tuesday -> five days ahead.
        assert _parse_simple_cron("0 6 * * 0", now) == datetime(
            2025, 6, 15, 6, 0, 0, tzinfo=timezone.utc
        )

    def test_dow_same_day_future_time(self):
        # Tuesday 08:00 with a Tuesday-12:00 cron -> fires later the same day.
        reference = datetime(2025, 6, 10, 8, 0, 0, tzinfo=timezone.utc)
        assert _parse_simple_cron("0 12 * * 2", reference) == datetime(
            2025, 6, 10, 12, 0, 0, tzinfo=timezone.utc
        )

    def test_dow_same_day_past_time(self):
        # Tuesday 14:00 with a Tuesday-12:00 cron -> pushed to next Tuesday.
        reference = datetime(2025, 6, 10, 14, 0, 0, tzinfo=timezone.utc)
        assert _parse_simple_cron("0 12 * * 2", reference) == datetime(
            2025, 6, 17, 12, 0, 0, tzinfo=timezone.utc
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_and_enqueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckAndEnqueue:
    """check_and_enqueue: due-task detection, enqueueing and bookkeeping.

    NOTE: these tests rely on `conn.cursor.side_effect` ordering — the
    scheduler is expected to open one cursor for the SELECT of due tasks
    and a second one for the UPDATE. Patch decorators are applied
    bottom-up, so `mock_tq` is the inner patch and `mock_get_conn` the outer.
    """

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_enqueues_due_tasks(self, mock_tq, mock_get_conn, sched):
        """A due task is enqueued and last_run_at / run_count / next_run_at are updated."""
        task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
        schedule_config = {
            "schedule_type": "interval",
            "interval_value": 1,
            "interval_unit": "hours",
        }

        # First cursor: SELECT of due tasks.
        select_cur = _mock_cursor(
            fetchall_val=[
                ("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
            ]
        )
        # Second cursor: UPDATE.
        update_cur = _mock_cursor()

        conn = MagicMock()
        # cursor() yields select_cur first, then update_cur.
        conn.cursor.side_effect = [select_cur, update_cur]
        mock_get_conn.return_value = conn

        mock_tq.enqueue.return_value = "queue-id-1"

        count = sched.check_and_enqueue()

        assert count == 1
        mock_tq.enqueue.assert_called_once()
        # Inspect the positional arguments passed to enqueue.
        call_args = mock_tq.enqueue.call_args
        assert call_args[0][1] == 42  # site_id
        assert isinstance(call_args[0][0], TaskConfigSchema)

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_no_due_tasks(self, mock_tq, mock_get_conn, sched):
        """With no due tasks, nothing is enqueued."""
        cur = _mock_cursor(fetchall_val=[])
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        count = sched.check_and_enqueue()

        assert count == 0
        mock_tq.enqueue.assert_not_called()

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_skips_invalid_config(self, mock_tq, mock_get_conn, sched):
        """Tasks whose config fails deserialization are skipped."""
        # task_config is missing the required `tasks` field.
        bad_config = {"pipeline": "api_ods_dwd"}
        schedule_config = {"schedule_type": "once"}

        cur = _mock_cursor(
            fetchall_val=[
                ("task-uuid-bad", 42, json.dumps(bad_config), json.dumps(schedule_config)),
            ]
        )
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        count = sched.check_and_enqueue()

        assert count == 0
        mock_tq.enqueue.assert_not_called()

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_enqueue_failure_continues(self, mock_tq, mock_get_conn, sched):
        """An enqueue failure skips that task and continues with the rest."""
        task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
        schedule_config = {"schedule_type": "once"}

        cur = _mock_cursor(
            fetchall_val=[
                ("task-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
                ("task-2", 42, json.dumps(task_config), json.dumps(schedule_config)),
            ]
        )
        # An extra cursor is needed for the UPDATE.
        update_cur = _mock_cursor()
        conn = MagicMock()
        conn.cursor.side_effect = [cur, update_cur]
        mock_get_conn.return_value = conn

        # First enqueue fails, second succeeds.
        mock_tq.enqueue.side_effect = [Exception("DB error"), "queue-id-2"]

        count = sched.check_and_enqueue()

        assert count == 1
        assert mock_tq.enqueue.call_count == 2

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_once_type_sets_next_run_none(self, mock_tq, mock_get_conn, sched):
        """After a 'once' task is enqueued, next_run_at is set to NULL."""
        task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
        schedule_config = {"schedule_type": "once"}

        select_cur = _mock_cursor(
            fetchall_val=[
                ("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
            ]
        )
        update_cur = _mock_cursor()
        conn = MagicMock()
        conn.cursor.side_effect = [select_cur, update_cur]
        mock_get_conn.return_value = conn
        mock_tq.enqueue.return_value = "queue-id-1"

        sched.check_and_enqueue()

        # Verify the next_run_at parameter of the UPDATE statement is None.
        update_call = update_cur.__enter__().execute.call_args
        # The first element of the parameter tuple is next_run_at.
        assert update_call[0][1][0] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start / stop 生命周期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLifecycle:
    """start()/stop() lifecycle of the background scheduler loop."""

    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self, sched):
        # stop() must clear both the running flag and the loop task handle.
        sched._running = True
        await sched.stop()
        assert sched._running is False
        assert sched._loop_task is None

    def test_start_creates_task(self, sched):
        # A dedicated event loop is created so the test controls its lifetime.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # Start inside a running event loop — start() needs one to
            # schedule its background task.
            async def _run():
                sched.start()
                assert sched._loop_task is not None
                assert not sched._loop_task.done()
                await sched.stop()

            loop.run_until_complete(_run())
        finally:
            loop.close()

    @pytest.mark.asyncio
    async def test_start_stop_idempotent(self, sched):
        """Calling stop() multiple times must not raise."""
        await sched.stop()
        await sched.stop()
        assert sched._loop_task is None
|
||||
310
apps/backend/tests/test_schedules_router.py
Normal file
310
apps/backend/tests/test_schedules_router.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度路由单元测试
|
||||
|
||||
覆盖 5 个端点:list / create / update / delete / toggle
|
||||
通过 mock 绕过数据库,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
    # FastAPI dependency override: every request authenticates as the
    # fixed test user (user_id=1, site_id=100).
    return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
_NEXT = datetime(2024, 6, 2, 4, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
_SCHEDULE_CONFIG = {
|
||||
"schedule_type": "daily",
|
||||
"daily_time": "04:00",
|
||||
}
|
||||
|
||||
_VALID_CREATE = {
|
||||
"name": "每日全量同步",
|
||||
"task_codes": ["ODS_MEMBER", "ODS_ORDER"],
|
||||
"task_config": {"tasks": ["ODS_MEMBER", "ODS_ORDER"], "pipeline": "api_ods"},
|
||||
"schedule_config": _SCHEDULE_CONFIG,
|
||||
}
|
||||
|
||||
# 模拟数据库返回的完整行(13 列,与 _SELECT_COLS 对应)
|
||||
_DB_ROW = (
|
||||
"sched-1", 100, "每日全量同步", ["ODS_MEMBER", "ODS_ORDER"],
|
||||
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}),
|
||||
json.dumps(_SCHEDULE_CONFIG),
|
||||
True, None, _NEXT, 0, None, _NOW, _NOW,
|
||||
)
|
||||
|
||||
|
||||
def _mock_conn_with_fetchall(rows):
|
||||
"""构造返回 fetchall 的 mock 连接。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cursor
|
||||
|
||||
|
||||
def _mock_conn_with_fetchone(row):
|
||||
"""构造返回 fetchone 的 mock 连接。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cursor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListSchedules:
    """GET /api/schedules — listing and site_id scoping."""

    @patch("app.routers.schedules.get_connection")
    def test_list_returns_schedules(self, mock_get_conn):
        # One DB row should map to one serialized schedule in the response.
        mock_conn, _ = _mock_conn_with_fetchall([_DB_ROW])
        mock_get_conn.return_value = mock_conn

        resp = client.get("/api/schedules")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "sched-1"
        assert data[0]["name"] == "每日全量同步"
        assert data[0]["site_id"] == 100
        assert data[0]["enabled"] is True

    @patch("app.routers.schedules.get_connection")
    def test_list_empty(self, mock_get_conn):
        # No rows -> empty JSON array, still 200.
        mock_conn, _ = _mock_conn_with_fetchall([])
        mock_get_conn.return_value = mock_conn

        resp = client.get("/api/schedules")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch("app.routers.schedules.get_connection")
    def test_list_filters_by_site_id(self, mock_get_conn):
        # The SELECT must be parameterized with the authenticated site_id (100).
        mock_conn, mock_cursor = _mock_conn_with_fetchall([])
        mock_get_conn.return_value = mock_conn

        client.get("/api/schedules")
        call_args = mock_cursor.execute.call_args
        assert call_args[0][1] == (100,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCreateSchedule:
    """POST /api/schedules — creation, site_id injection, validation."""

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_create_returns_201(self, mock_get_conn, mock_calc):
        # Successful creation returns 201 with the persisted row serialized.
        mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
        mock_get_conn.return_value = mock_conn

        resp = client.post("/api/schedules", json=_VALID_CREATE)
        assert resp.status_code == 201
        data = resp.json()
        assert data["id"] == "sched-1"
        assert data["name"] == "每日全量同步"

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_create_injects_site_id(self, mock_get_conn, mock_calc):
        mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
        mock_get_conn.return_value = mock_conn

        client.post("/api/schedules", json=_VALID_CREATE)
        # The first INSERT parameter must be the authenticated site_id=100,
        # not anything taken from the request body.
        insert_params = mock_cursor.execute.call_args[0][1]
        assert insert_params[0] == 100

    def test_create_missing_name_returns_422(self):
        # Required field missing -> request validation error, no DB access.
        body = {**_VALID_CREATE}
        del body["name"]
        resp = client.post("/api/schedules", json=body)
        assert resp.status_code == 422

    def test_create_invalid_schedule_type_returns_422(self):
        # Unknown schedule_type is rejected by the schema.
        body = {**_VALID_CREATE, "schedule_config": {"schedule_type": "invalid"}}
        resp = client.post("/api/schedules", json=body)
        assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/schedules/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateSchedule:
    """PUT /api/schedules/{id} — partial updates and error cases."""

    @patch("app.routers.schedules.get_connection")
    def test_update_name(self, mock_get_conn):
        # Renaming only: the updated row is returned with the new name.
        updated_row = list(_DB_ROW)
        updated_row[2] = "新名称"
        mock_conn, _ = _mock_conn_with_fetchone(tuple(updated_row))
        mock_get_conn.return_value = mock_conn

        resp = client.put("/api/schedules/sched-1", json={"name": "新名称"})
        assert resp.status_code == 200
        assert resp.json()["name"] == "新名称"

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_update_schedule_config_recalculates_next_run(self, mock_get_conn, mock_calc):
        # Changing schedule_config must trigger a next_run_at recalculation.
        mock_conn, _ = _mock_conn_with_fetchone(_DB_ROW)
        mock_get_conn.return_value = mock_conn

        resp = client.put("/api/schedules/sched-1", json={
            "schedule_config": {"schedule_type": "interval", "interval_value": 2, "interval_unit": "hours"},
        })
        assert resp.status_code == 200
        mock_calc.assert_called_once()

    @patch("app.routers.schedules.get_connection")
    def test_update_not_found(self, mock_get_conn):
        # fetchone() returning None models a missing/foreign row -> 404.
        mock_conn, _ = _mock_conn_with_fetchone(None)
        mock_get_conn.return_value = mock_conn

        resp = client.put("/api/schedules/nonexistent", json={"name": "x"})
        assert resp.status_code == 404

    def test_update_empty_body_returns_422(self):
        # An update with no fields is rejected by validation.
        resp = client.put("/api/schedules/sched-1", json={})
        assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/schedules/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteSchedule:
    """DELETE /api/schedules/{id} — rowcount drives 200 vs 404."""

    @patch("app.routers.schedules.get_connection")
    def test_delete_success(self, mock_get_conn):
        # rowcount=1 models a successful DELETE.
        mock_cursor = MagicMock()
        mock_cursor.rowcount = 1
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn

        resp = client.delete("/api/schedules/sched-1")
        assert resp.status_code == 200
        assert "已删除" in resp.json()["message"]

    @patch("app.routers.schedules.get_connection")
    def test_delete_not_found(self, mock_get_conn):
        # rowcount=0 models "no row matched" -> 404.
        mock_cursor = MagicMock()
        mock_cursor.rowcount = 0
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn

        resp = client.delete("/api/schedules/nonexistent")
        assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /api/schedules/{id}/toggle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToggleSchedule:
    """PATCH /api/schedules/{id}/toggle — enable/disable flips next_run_at.

    NOTE: these tests depend on `fetchone.side_effect` ordering — the route
    is expected to first SELECT the current state, then UPDATE ... RETURNING.
    """

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_toggle_disable(self, mock_get_conn, mock_calc):
        """Enabled -> disabled: next_run_at must be set to NULL."""
        # First fetchone returns the current state (enabled=True);
        # second fetchone returns the updated row.
        disabled_row = list(_DB_ROW)
        disabled_row[6] = False  # enabled
        disabled_row[8] = None  # next_run_at

        mock_cursor = MagicMock()
        mock_cursor.fetchone.side_effect = [
            (True, json.dumps(_SCHEDULE_CONFIG)),  # SELECT of current state
            tuple(disabled_row),  # UPDATE RETURNING
        ]
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn

        resp = client.patch("/api/schedules/sched-1/toggle")
        assert resp.status_code == 200
        data = resp.json()
        assert data["enabled"] is False
        assert data["next_run_at"] is None

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_toggle_enable(self, mock_get_conn, mock_calc):
        """Disabled -> enabled: next_run_at must be recalculated."""
        enabled_row = list(_DB_ROW)
        enabled_row[6] = True
        enabled_row[8] = _NEXT

        mock_cursor = MagicMock()
        mock_cursor.fetchone.side_effect = [
            (False, json.dumps(_SCHEDULE_CONFIG)),  # SELECT of current state (disabled)
            tuple(enabled_row),  # UPDATE RETURNING
        ]
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn

        resp = client.patch("/api/schedules/sched-1/toggle")
        assert resp.status_code == 200
        data = resp.json()
        assert data["enabled"] is True
        assert data["next_run_at"] is not None
        mock_calc.assert_called_once()

    @patch("app.routers.schedules.get_connection")
    def test_toggle_not_found(self, mock_get_conn):
        # SELECT returning None -> unknown schedule id -> 404.
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = None
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn

        resp = client.patch("/api/schedules/nonexistent/toggle")
        assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSchedulesAuth:
    """Authentication gate for all /api/schedules endpoints."""

    def test_requires_auth(self):
        """With the auth override removed, every endpoint must return 401/403."""
        # Temporarily drop the module-level dependency override, and restore
        # it in `finally` so later tests keep their authenticated client.
        app.dependency_overrides.pop(get_current_user, None)
        try:
            assert client.get("/api/schedules").status_code in (401, 403)
            assert client.post("/api/schedules", json=_VALID_CREATE).status_code in (401, 403)
            assert client.put("/api/schedules/x", json={"name": "x"}).status_code in (401, 403)
            assert client.delete("/api/schedules/x").status_code in (401, 403)
            assert client.patch("/api/schedules/x/toggle").status_code in (401, 403)
        finally:
            app.dependency_overrides[get_current_user] = _override_auth
|
||||
336
apps/backend/tests/test_site_isolation_properties.py
Normal file
336
apps/backend/tests/test_site_isolation_properties.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""门店隔离属性测试(Property-Based Testing)。
|
||||
|
||||
Property 20: 对于任意两个不同 site_id 的 Operator,一个 Operator 查询
|
||||
队列/调度/执行历史时,结果中不应包含另一个 site_id 的数据。
|
||||
|
||||
Validates: Requirements 1.3
|
||||
|
||||
测试策略:
|
||||
- 通过 mock 数据库交互,验证 API 路由在不同 site_id 下的数据隔离
|
||||
- 队列隔离:为 site_id_a 入队任务,用 site_id_b 的 JWT 查询队列,结果应为空
|
||||
- 调度隔离:为 site_id_a 创建调度任务,用 site_id_b 的 JWT 查询调度列表,结果应为空
|
||||
- 执行历史隔离:site_id_a 的执行历史,用 site_id_b 的 JWT 查询不到
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-isolation")
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mock_user(site_id: int) -> CurrentUser:
    """Build a mock authenticated user bound to *site_id* (user_id fixed at 1)."""
    return CurrentUser(user_id=1, site_id=site_id)
|
||||
|
||||
|
||||
def _make_queue_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的队列行。"""
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}), # config
|
||||
"pending", # status
|
||||
i + 1, # position
|
||||
datetime(2024, 1, 1, tzinfo=timezone.utc), # created_at
|
||||
None, # started_at
|
||||
None, # finished_at
|
||||
None, # exit_code
|
||||
None, # error_message
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _make_schedule_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的调度行。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
f"调度任务_{i}", # name
|
||||
["ODS_MEMBER"], # task_codes
|
||||
{"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}, # task_config
|
||||
{"schedule_type": "daily", "daily_time": "04:00", # schedule_config
|
||||
"interval_value": 1, "interval_unit": "hours",
|
||||
"weekly_days": [1], "weekly_time": "04:00",
|
||||
"cron_expression": "0 4 * * *", "enabled": True,
|
||||
"start_date": None, "end_date": None},
|
||||
True, # enabled
|
||||
None, # last_run_at
|
||||
now + timedelta(hours=1), # next_run_at
|
||||
0, # run_count
|
||||
None, # last_status
|
||||
now, # created_at
|
||||
now, # updated_at
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _make_history_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的执行历史行。"""
|
||||
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
["ODS_MEMBER"], # task_codes
|
||||
"success", # status
|
||||
base_time + timedelta(hours=i), # started_at
|
||||
base_time + timedelta(hours=i, minutes=30), # finished_at
|
||||
0, # exit_code
|
||||
1800000, # duration_ms
|
||||
"python -m cli.main", # command
|
||||
None, # summary
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _mock_conn_returning(rows: list[tuple]) -> MagicMock:
|
||||
"""构造一个 mock connection,其 cursor.fetchall 返回指定行。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.1: 队列隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
@given(
    site_id_a=_site_id_st,
    site_id_b=_site_id_st,
    queue_count=st.integers(min_value=1, max_value=5),
)
@patch("app.services.task_queue.get_connection")
def test_queue_isolation(mock_get_conn, site_id_a, site_id_b, queue_count):
    """Property 20.1: queue isolation.

    After enqueueing tasks for site_id_a, querying the queue as site_id_b
    must return nothing — queue data is invisible across sites.
    """
    assume(site_id_a != site_id_b)

    # Queue rows owned by site_id_a.
    rows_a = _make_queue_rows(site_id_a, queue_count)

    # Core isolation logic: filter by the site_id passed at query time.
    # list_pending runs: WHERE site_id = %s AND status = 'pending'
    def conn_for_site(querying_site_id):
        """Simulate DB behavior: only return rows matching site_id."""
        if querying_site_id == site_id_a:
            return rows_a
        return []  # site_id_b cannot see site_id_a's data

    # NOTE(review): captured_params is written but never asserted on —
    # consider asserting captured_params["site_id"] == site_id_b or removing it.
    captured_params = {}

    def make_mock_conn():
        mock_cursor = MagicMock()

        def execute_side_effect(sql, params=None):
            if params:
                captured_params["site_id"] = params[0]
                # Return data keyed on the site_id parameter of the query.
                mock_cursor.fetchall.return_value = conn_for_site(params[0])

        mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
        mock_cursor.__exit__ = MagicMock(return_value=False)

        mock_conn = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        return mock_conn

    mock_get_conn.return_value = make_mock_conn()

    # Query the queue as site_id_b.
    app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
    try:
        client = TestClient(app)
        resp = client.get("/api/execution/queue")

        assert resp.status_code == 200
        data = resp.json()

        # Verify: site_id_b sees none of site_id_a's data.
        assert len(data) == 0, (
            f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的队列数据,"
            f"但返回了 {len(data)} 条记录"
        )

        # Extra check: even if rows came back, none may belong to site_id_a.
        for item in data:
            assert item.get("site_id") != site_id_a, (
                f"结果中不应包含 site_id_a={site_id_a} 的数据"
            )
    finally:
        app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.2: 调度隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
@given(
    site_id_a=_site_id_st,
    site_id_b=_site_id_st,
    schedule_count=st.integers(min_value=1, max_value=5),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_isolation(mock_get_conn, site_id_a, site_id_b, schedule_count):
    """Property 20.2: schedule isolation.

    After creating schedules for site_id_a, querying the schedule list
    as site_id_b must return an empty result -- schedules of different
    sites are mutually invisible.

    Fix: added ``deadline=None`` to ``@settings`` for consistency with
    Property 20.3 below -- TestClient request latency is variable and can
    trip Hypothesis' default per-example deadline, causing flaky
    DeadlineExceeded failures.
    """
    assume(site_id_a != site_id_b)

    # Schedule rows owned by site_id_a.
    rows_a = _make_schedule_rows(site_id_a, schedule_count)

    def make_mock_conn():
        mock_cursor = MagicMock()

        def execute_side_effect(sql, params=None):
            if params:
                querying_site_id = params[0]
                # Emulate the WHERE site_id = %s filter: only matching rows return.
                if querying_site_id == site_id_a:
                    mock_cursor.fetchall.return_value = rows_a
                else:
                    mock_cursor.fetchall.return_value = []

        mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
        mock_cursor.__exit__ = MagicMock(return_value=False)

        mock_conn = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        return mock_conn

    mock_get_conn.return_value = make_mock_conn()

    # Query the schedule list authenticated as site_id_b.
    app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
    try:
        client = TestClient(app)
        resp = client.get("/api/schedules")

        assert resp.status_code == 200
        data = resp.json()

        # site_id_b must not see any of site_id_a's schedules.
        assert len(data) == 0, (
            f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的调度数据,"
            f"但返回了 {len(data)} 条记录"
        )

        # Defensive check: even if rows leaked through, none may belong to site_id_a.
        for item in data:
            assert item.get("site_id") != site_id_a, (
                f"结果中不应包含 site_id_a={site_id_a} 的调度数据"
            )
    finally:
        app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.3: 执行历史隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
@given(
    site_id_a=_site_id_st,
    site_id_b=_site_id_st,
    history_count=st.integers(min_value=1, max_value=10),
)
@patch("app.routers.execution.get_connection")
def test_execution_history_isolation(mock_get_conn, site_id_a, site_id_b, history_count):
    """Property 20.3: execution-history isolation.

    site_id_a owns a number of execution-history rows; querying the
    history endpoint as site_id_b must yield nothing -- execution
    histories of different sites are mutually invisible.
    """
    assume(site_id_a != site_id_b)

    # History rows owned by site_id_a.
    owned_rows = _make_history_rows(site_id_a, history_count)

    def build_connection():
        cursor = MagicMock()

        def on_execute(sql, params=None):
            if not params:
                return
            # Emulate the SQL site_id filter: the owning site gets its
            # rows, every other site gets an empty result set.
            matches = params[0] == site_id_a
            cursor.fetchall.return_value = owned_rows if matches else []

        cursor.execute = MagicMock(side_effect=on_execute)
        cursor.__enter__ = MagicMock(return_value=cursor)
        cursor.__exit__ = MagicMock(return_value=False)

        connection = MagicMock()
        connection.cursor.return_value = cursor
        return connection

    mock_get_conn.return_value = build_connection()

    # Query execution history authenticated as site_id_b.
    app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
    try:
        resp = TestClient(app).get("/api/execution/history")

        assert resp.status_code == 200
        data = resp.json()

        # site_id_b must not see any of site_id_a's history.
        assert len(data) == 0, (
            f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的执行历史,"
            f"但返回了 {len(data)} 条记录"
        )

        # Defensive check: no returned row may belong to site_id_a.
        for item in data:
            assert item.get("site_id") != site_id_a, (
                f"结果中不应包含 site_id_a={site_id_a} 的执行历史"
            )
    finally:
        app.dependency_overrides.clear()
|
||||
275
apps/backend/tests/test_task_config_properties.py
Normal file
275
apps/backend/tests/test_task_config_properties.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskConfig 属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证 TaskConfig 相关的通用正确性属性:
|
||||
- Property 1: TaskConfig 序列化往返一致性
|
||||
- Property 6: 时间窗口验证
|
||||
- Property 7: TaskConfig 到 CLI 命令转换完整性
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
|
||||
from app.services.task_registry import ALL_TASKS
|
||||
|
||||
# ---------------------------------------------------------------------------
# Strategies
# ---------------------------------------------------------------------------

# Sample task codes from the real task registry so generated configs
# always reference tasks the application actually knows about.
_task_codes = [t.code for t in ALL_TASKS]

_tasks_st = st.lists(
    st.sampled_from(_task_codes),
    min_size=1,
    max_size=5,
    unique=True,
)

# sorted() makes the sampling order deterministic across runs.
_pipeline_st = st.sampled_from(sorted(VALID_FLOWS))
_processing_mode_st = st.sampled_from(sorted(VALID_PROCESSING_MODES))
_window_mode_st = st.sampled_from(["lookback", "custom"])

# Date strategy: YYYY-MM-DD formatted strings.
_date_st = st.dates(
    min_value=datetime.date(2020, 1, 1),
    max_value=datetime.date(2030, 12, 31),
).map(lambda d: d.isoformat())

_window_split_st = st.sampled_from([None, "none", "day"])
_window_split_days_st = st.one_of(st.none(), st.sampled_from([1, 10, 30]))
_lookback_hours_st = st.integers(min_value=1, max_value=720)
_overlap_seconds_st = st.integers(min_value=0, max_value=7200)
_store_id_st = st.one_of(st.none(), st.integers(min_value=1, max_value=2**31 - 1))

# Sample DWD table names.
_dwd_table_names = [
    "dwd.dim_site",
    "dwd.dim_member",
    "dwd.dwd_settlement_head",
]
_dwd_only_tables_st = st.one_of(
    st.none(),
    st.lists(st.sampled_from(_dwd_table_names), min_size=1, max_size=3, unique=True),
)
|
||||
|
||||
|
||||
def _valid_task_config_st():
    """Composite strategy that produces valid TaskConfigSchema instances.

    Guarantees window_end >= window_start whenever window_mode == "custom",
    so generated configs never trigger Pydantic validation errors.
    """

    @st.composite
    def _build(draw):
        tasks = draw(_tasks_st)
        pipeline = draw(_pipeline_st)
        processing_mode = draw(_processing_mode_st)
        dry_run = draw(st.booleans())
        window_mode = draw(_window_mode_st)
        store_id = draw(_store_id_st)
        dwd_only_tables = draw(_dwd_only_tables_st)
        window_split = draw(_window_split_st)
        window_split_days = draw(_window_split_days_st)
        fetch_before_verify = draw(st.booleans())
        skip_ods = draw(st.booleans())
        ods_local = draw(st.booleans())

        if window_mode == "custom":
            d1 = draw(st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ))
            d2 = draw(st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ))
            # Guarantee end >= start by ordering the two drawn dates.
            window_start = min(d1, d2).isoformat()
            window_end = max(d1, d2).isoformat()
            lookback_hours = 24
            overlap_seconds = 600
        else:
            # lookback mode: no explicit window, draw lookback parameters.
            window_start = None
            window_end = None
            lookback_hours = draw(_lookback_hours_st)
            overlap_seconds = draw(_overlap_seconds_st)

        return TaskConfigSchema(
            tasks=tasks,
            pipeline=pipeline,
            processing_mode=processing_mode,
            dry_run=dry_run,
            window_mode=window_mode,
            window_start=window_start,
            window_end=window_end,
            window_split=window_split,
            window_split_days=window_split_days,
            lookback_hours=lookback_hours,
            overlap_seconds=overlap_seconds,
            fetch_before_verify=fetch_before_verify,
            skip_ods_when_fetch_before_verify=skip_ods,
            ods_use_local_json=ods_local,
            store_id=store_id,
            dwd_only_tables=dwd_only_tables,
            extra_args={},
        )

    return _build()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 1: TaskConfig 序列化往返一致性
|
||||
# **Validates: Requirements 11.1, 11.2, 11.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=200)
@given(config=_valid_task_config_st())
def test_task_config_round_trip(config: TaskConfigSchema):
    """Property 1: serialising to JSON and back yields an equal object."""
    restored = TaskConfigSchema.model_validate_json(config.model_dump_json())
    assert restored == config, (
        f"往返不一致:\n原始={config}\n还原={restored}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 6: 时间窗口验证
|
||||
# **Validates: Requirements 2.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=200)
@given(
    d1=st.dates(
        min_value=datetime.date(2020, 1, 1),
        max_value=datetime.date(2030, 12, 31),
    ),
    d2=st.dates(
        min_value=datetime.date(2020, 1, 1),
        max_value=datetime.date(2030, 12, 31),
    ),
)
def test_time_window_validation(d1: datetime.date, d2: datetime.date):
    """Property 6: validation fails iff window_end < window_start."""
    start_str = d1.isoformat()
    end_str = d2.isoformat()

    def build_config():
        return TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            window_mode="custom",
            window_start=start_str,
            window_end=end_str,
        )

    # ISO date strings compare lexicographically == chronologically.
    if end_str < start_str:
        # window_end earlier than window_start -> validation must fail.
        try:
            build_config()
        except ValidationError:
            pass  # expected behaviour
        else:
            raise AssertionError(
                f"期望 ValidationError,但验证通过了:start={start_str}, end={end_str}"
            )
    else:
        # window_end >= window_start -> validation must succeed.
        config = build_config()
        assert config.window_start == start_str
        assert config.window_end == end_str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 7: TaskConfig 到 CLI 命令转换完整性
|
||||
# **Validates: Requirements 2.5, 2.6**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Shared builder instance and a fake ETL project path for command generation.
_builder = CLIBuilder()
_ETL_PATH = "/fake/etl/project"


@settings(max_examples=200)
@given(config=_valid_task_config_st())
def test_task_config_to_cli_completeness(config: TaskConfigSchema):
    """Property 7: the generated CLI command carries a flag for every
    non-empty TaskConfig field, and no flags from the inactive window mode."""
    cmd = _builder.build_command(config, _ETL_PATH)

    # 1) --pipeline is always present with the correct value.
    assert "--pipeline" in cmd
    idx = cmd.index("--pipeline")
    assert cmd[idx + 1] == config.pipeline

    # 2) --processing-mode is always present with the correct value.
    assert "--processing-mode" in cmd
    idx = cmd.index("--processing-mode")
    assert cmd[idx + 1] == config.processing_mode

    # 3) Non-empty task list -> --tasks present (order-insensitive compare).
    if config.tasks:
        assert "--tasks" in cmd
        idx = cmd.index("--tasks")
        assert set(cmd[idx + 1].split(",")) == set(config.tasks)

    # 4) Time-window flags.
    if config.window_mode == "lookback":
        # lookback mode -> --lookback-hours and --overlap-seconds.
        if config.lookback_hours is not None:
            assert "--lookback-hours" in cmd
            idx = cmd.index("--lookback-hours")
            assert cmd[idx + 1] == str(config.lookback_hours)
        if config.overlap_seconds is not None:
            assert "--overlap-seconds" in cmd
            idx = cmd.index("--overlap-seconds")
            assert cmd[idx + 1] == str(config.overlap_seconds)
        # lookback mode must not emit custom-window flags.
        assert "--window-start" not in cmd
        assert "--window-end" not in cmd
    else:
        # custom mode -> --window-start / --window-end.
        if config.window_start:
            assert "--window-start" in cmd
        if config.window_end:
            assert "--window-end" in cmd
        # custom mode must not emit lookback flags.
        assert "--lookback-hours" not in cmd
        assert "--overlap-seconds" not in cmd

    # 5) dry_run -> --dry-run.
    if config.dry_run:
        assert "--dry-run" in cmd
    else:
        assert "--dry-run" not in cmd

    # 6) store_id -> --store-id.
    if config.store_id is not None:
        assert "--store-id" in cmd
        idx = cmd.index("--store-id")
        assert cmd[idx + 1] == str(config.store_id)
    else:
        assert "--store-id" not in cmd

    # 7) fetch_before_verify -> emitted only in verify_only mode.
    if config.fetch_before_verify and config.processing_mode == "verify_only":
        assert "--fetch-before-verify" in cmd
    else:
        assert "--fetch-before-verify" not in cmd

    # 8) window_split (not None and not "none") -> --window-split.
    if config.window_split and config.window_split != "none":
        assert "--window-split" in cmd
        idx = cmd.index("--window-split")
        assert cmd[idx + 1] == config.window_split
        if config.window_split_days is not None:
            assert "--window-split-days" in cmd
    else:
        assert "--window-split" not in cmd
|
||||
373
apps/backend/tests/test_task_executor.py
Normal file
373
apps/backend/tests/test_task_executor.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskExecutor 单元测试
|
||||
|
||||
覆盖:子进程启动、stdout/stderr 读取、日志广播、取消、数据库记录。
|
||||
使用 asyncio 测试,mock 子进程和数据库连接避免外部依赖。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_executor import TaskExecutor
|
||||
|
||||
|
||||
@pytest.fixture
def executor() -> TaskExecutor:
    """Fresh TaskExecutor instance for each test."""
    return TaskExecutor()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Minimal valid TaskConfigSchema shared by the executor tests."""
    return TaskConfigSchema(
        tasks=["ODS_MEMBER", "ODS_PAYMENT"],
        pipeline="api_ods_dwd",
        store_id=42,
    )
|
||||
|
||||
|
||||
def _make_stream(lines: list[bytes]) -> AsyncMock:
|
||||
"""构造一个模拟的 asyncio.StreamReader,按行返回数据。"""
|
||||
stream = AsyncMock()
|
||||
# readline 依次返回每行,最后返回 b"" 表示 EOF
|
||||
stream.readline = AsyncMock(side_effect=[*lines, b""])
|
||||
return stream
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 订阅 / 取消订阅
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubscription:
    """subscribe()/unsubscribe() bookkeeping in TaskExecutor._subscribers."""

    def test_subscribe_returns_queue(self, executor: TaskExecutor):
        # Each subscription hands back an asyncio.Queue for log delivery.
        q = executor.subscribe("exec-1")
        assert isinstance(q, asyncio.Queue)

    def test_subscribe_multiple(self, executor: TaskExecutor):
        # Each subscriber gets its own queue; both are tracked per execution id.
        q1 = executor.subscribe("exec-1")
        q2 = executor.subscribe("exec-1")
        assert q1 is not q2
        assert len(executor._subscribers["exec-1"]) == 2

    def test_unsubscribe_removes_queue(self, executor: TaskExecutor):
        q = executor.subscribe("exec-1")
        executor.unsubscribe("exec-1", q)
        # Removing the last subscriber also cleans up the execution-id key.
        assert "exec-1" not in executor._subscribers

    def test_unsubscribe_nonexistent_is_safe(self, executor: TaskExecutor):
        """Unsubscribing from an unknown execution_id must not raise."""
        q: asyncio.Queue = asyncio.Queue()
        executor.unsubscribe("nonexistent", q)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 广播
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBroadcast:
    """_broadcast()/_broadcast_end() fan-out to subscriber queues."""

    def test_broadcast_to_subscribers(self, executor: TaskExecutor):
        # Every subscriber queue receives each broadcast message.
        q1 = executor.subscribe("exec-1")
        q2 = executor.subscribe("exec-1")
        executor._broadcast("exec-1", "hello")
        assert q1.get_nowait() == "hello"
        assert q2.get_nowait() == "hello"

    def test_broadcast_no_subscribers_is_safe(self, executor: TaskExecutor):
        """Broadcasting with no subscribers must not raise."""
        executor._broadcast("nonexistent", "hello")

    def test_broadcast_end_sends_none(self, executor: TaskExecutor):
        # None is the end-of-stream sentinel delivered to subscribers.
        q = executor.subscribe("exec-1")
        executor._broadcast_end("exec-1")
        assert q.get_nowait() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 日志缓冲区
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLogBuffer:
    """get_logs() reads from the per-execution log buffer."""

    def test_get_logs_empty(self, executor: TaskExecutor):
        # Unknown execution ids yield an empty list, not an error.
        assert executor.get_logs("nonexistent") == []

    def test_get_logs_returns_copy(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = ["line1", "line2"]
        logs = executor.get_logs("exec-1")
        assert logs == ["line1", "line2"]
        # Mutating the returned list must not affect the internal buffer.
        logs.append("line3")
        assert len(executor._log_buffers["exec-1"]) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 执行状态查询
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunningState:
    """is_running()/get_running_ids() reflect tracked subprocess state."""

    def test_is_running_false_when_no_process(self, executor: TaskExecutor):
        assert executor.is_running("nonexistent") is False

    def test_is_running_true_when_process_active(self, executor: TaskExecutor):
        # returncode None means the process has not exited yet.
        proc = MagicMock()
        proc.returncode = None
        executor._processes["exec-1"] = proc
        assert executor.is_running("exec-1") is True

    def test_is_running_false_when_process_exited(self, executor: TaskExecutor):
        proc = MagicMock()
        proc.returncode = 0
        executor._processes["exec-1"] = proc
        assert executor.is_running("exec-1") is False

    def test_get_running_ids(self, executor: TaskExecutor):
        # Only processes that have not exited are reported.
        running = MagicMock()
        running.returncode = None
        exited = MagicMock()
        exited.returncode = 0
        executor._processes["a"] = running
        executor._processes["b"] = exited
        assert executor.get_running_ids() == ["a"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadStream:
    """_read_stream(): line handling, prefixing, buffering, broadcasting."""

    @pytest.mark.asyncio
    async def test_read_stdout_lines(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        stream = _make_stream([b"line1\n", b"line2\n"])
        collector: list[str] = []
        await executor._read_stream("exec-1", stream, "stdout", collector)
        # Raw lines go to the collector; prefixed lines go to the buffer.
        assert collector == ["line1", "line2"]
        assert executor._log_buffers["exec-1"] == [
            "[stdout] line1",
            "[stdout] line2",
        ]

    @pytest.mark.asyncio
    async def test_read_stderr_lines(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        stream = _make_stream([b"err1\n"])
        collector: list[str] = []
        await executor._read_stream("exec-1", stream, "stderr", collector)
        assert collector == ["err1"]
        assert executor._log_buffers["exec-1"] == ["[stderr] err1"]

    @pytest.mark.asyncio
    async def test_read_stream_none_is_safe(self, executor: TaskExecutor):
        """A None stream must be a no-op, not an error."""
        collector: list[str] = []
        await executor._read_stream("exec-1", None, "stdout", collector)
        assert collector == []

    @pytest.mark.asyncio
    async def test_broadcast_during_read(self, executor: TaskExecutor):
        # Subscribers receive each prefixed line as it is read.
        executor._log_buffers["exec-1"] = []
        q = executor.subscribe("exec-1")
        stream = _make_stream([b"hello\n"])
        collector: list[str] = []
        await executor._read_stream("exec-1", stream, "stdout", collector)
        assert q.get_nowait() == "[stdout] hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# execute(集成级,mock 子进程和数据库)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecute:
    """Integration-level tests of execute(); subprocess and DB writes are mocked."""

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_successful_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        # Fake subprocess that streams two stdout lines and exits 0.
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([b"processing...\n", b"done\n"])
        proc.stderr = _make_stream([])

        # wait() flips returncode on completion, mimicking a real process.
        # (Fix: removed the dead `proc.wait = AsyncMock(return_value=0)`
        # assignment that was immediately overwritten by `_wait` below.)
        async def _wait():
            proc.returncode = 0
            return 0
        proc.wait = _wait
        mock_create.return_value = proc

        await executor.execute(sample_config, "exec-1", site_id=42)

        # A "running" row was written first.
        mock_write_log.assert_called_once()
        call_kwargs = mock_write_log.call_args[1]
        assert call_kwargs["status"] == "running"
        assert call_kwargs["execution_id"] == "exec-1"

        # ...then updated to "success" with the captured output.
        mock_update_log.assert_called_once()
        update_kwargs = mock_update_log.call_args[1]
        assert update_kwargs["status"] == "success"
        assert update_kwargs["exit_code"] == 0
        assert "processing..." in update_kwargs["output_log"]
        assert "done" in update_kwargs["output_log"]

        # The finished process was removed from the tracking table.
        assert "exec-1" not in executor._processes

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_failed_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        # Subprocess emits stderr and exits 1 -> status "failed".
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([])
        proc.stderr = _make_stream([b"error occurred\n"])
        async def _wait():
            proc.returncode = 1
            return 1
        proc.wait = _wait
        mock_create.return_value = proc

        await executor.execute(sample_config, "exec-2", site_id=42)

        update_kwargs = mock_update_log.call_args[1]
        assert update_kwargs["status"] == "failed"
        assert update_kwargs["exit_code"] == 1
        assert "error occurred" in update_kwargs["error_log"]

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_exception_during_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        """A subprocess spawn failure must be recorded as failed."""
        mock_create.side_effect = OSError("command not found")

        await executor.execute(sample_config, "exec-3", site_id=42)

        update_kwargs = mock_update_log.call_args[1]
        assert update_kwargs["status"] == "failed"

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_subscribers_notified_on_completion(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([b"line\n"])
        proc.stderr = _make_stream([])
        async def _wait():
            proc.returncode = 0
            return 0
        proc.wait = _wait
        mock_create.return_value = proc

        q = executor.subscribe("exec-4")
        await executor.execute(sample_config, "exec-4", site_id=42)

        # The subscriber sees the log line followed by the None sentinel.
        messages = []
        while not q.empty():
            messages.append(q.get_nowait())
        assert "[stdout] line" in messages
        assert None in messages

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_duration_ms_recorded(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([])
        proc.stderr = _make_stream([])
        async def _wait():
            proc.returncode = 0
            return 0
        proc.wait = _wait
        mock_create.return_value = proc

        await executor.execute(sample_config, "exec-5", site_id=42)

        # duration_ms is recorded as a non-negative integer.
        update_kwargs = mock_update_log.call_args[1]
        assert isinstance(update_kwargs["duration_ms"], int)
        assert update_kwargs["duration_ms"] >= 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCancel:
    """cancel() terminates tracked processes and reports success/failure."""

    @pytest.mark.asyncio
    async def test_cancel_running_process(self, executor: TaskExecutor):
        # A live process (returncode None) gets terminate() and returns True.
        proc = MagicMock()
        proc.returncode = None
        proc.terminate = MagicMock()
        executor._processes["exec-1"] = proc

        result = await executor.cancel("exec-1")
        assert result is True
        proc.terminate.assert_called_once()

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_returns_false(self, executor: TaskExecutor):
        result = await executor.cancel("nonexistent")
        assert result is False

    @pytest.mark.asyncio
    async def test_cancel_already_exited_returns_false(self, executor: TaskExecutor):
        proc = MagicMock()
        proc.returncode = 0
        executor._processes["exec-1"] = proc

        result = await executor.cancel("exec-1")
        assert result is False

    @pytest.mark.asyncio
    async def test_cancel_process_lookup_error(self, executor: TaskExecutor):
        """terminate() raises ProcessLookupError when the process already vanished."""
        proc = MagicMock()
        proc.returncode = None
        proc.terminate = MagicMock(side_effect=ProcessLookupError)
        executor._processes["exec-1"] = proc

        result = await executor.cancel("exec-1")
        assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanup:
    """cleanup() drops per-execution log buffers and subscriber queues."""

    def test_cleanup_removes_buffers_and_subscribers(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = ["line"]
        executor.subscribe("exec-1")
        executor.cleanup("exec-1")
        assert "exec-1" not in executor._log_buffers
        assert "exec-1" not in executor._subscribers

    def test_cleanup_nonexistent_is_safe(self, executor: TaskExecutor):
        # Cleaning an unknown execution id must not raise.
        executor.cleanup("nonexistent")
|
||||
482
apps/backend/tests/test_task_queue.py
Normal file
482
apps/backend/tests/test_task_queue.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskQueue 单元测试
|
||||
|
||||
覆盖:enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
|
||||
使用 mock 数据库操作,专注于业务逻辑验证。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_queue import TaskQueue, QueuedTask
|
||||
|
||||
|
||||
@pytest.fixture
def queue() -> TaskQueue:
    """Fresh TaskQueue instance for each test."""
    return TaskQueue()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Minimal valid TaskConfigSchema shared by the queue tests."""
    return TaskConfigSchema(
        tasks=["ODS_MEMBER", "ODS_PAYMENT"],
        pipeline="api_ods_dwd",
        store_id=42,
    )
|
||||
|
||||
|
||||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
|
||||
"""构造 mock cursor,支持 context manager 协议。"""
|
||||
cur = MagicMock()
|
||||
cur.fetchone.return_value = fetchone_val
|
||||
cur.fetchall.return_value = fetchall_val or []
|
||||
cur.rowcount = rowcount
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
|
||||
def _mock_conn(cursor):
|
||||
"""构造 mock connection,支持 cursor() context manager。"""
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnqueue:
    """enqueue(): id generation, position assignment, config serialisation."""

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
        cur = _mock_cursor(fetchone_val=(0,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        task_id = queue.enqueue(sample_config, site_id=42)

        # Returns a valid UUID string (uuid.UUID raises otherwise).
        uuid.UUID(task_id)
        conn.commit.assert_called_once()
        conn.close.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
        """A new task gets position = current max position + 1."""
        cur = _mock_cursor(fetchone_val=(5,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.enqueue(sample_config, site_id=42)

        # Inspect the position argument of the INSERT (second execute call).
        insert_call = cur.execute.call_args_list[1]
        args = insert_call[0][1]
        # args = (task_id, site_id, config_json, new_pos)
        assert args[3] == 6  # 5 + 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
        """An empty queue yields position 1."""
        cur = _mock_cursor(fetchone_val=(0,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.enqueue(sample_config, site_id=42)

        insert_call = cur.execute.call_args_list[1]
        args = insert_call[0][1]
        assert args[3] == 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
        """The config is persisted as a JSON string."""
        cur = _mock_cursor(fetchone_val=(0,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.enqueue(sample_config, site_id=42)

        insert_call = cur.execute.call_args_list[1]
        config_json_str = insert_call[0][1][2]
        parsed = json.loads(config_json_str)
        assert parsed["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
        assert parsed["pipeline"] == "api_ods_dwd"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dequeue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDequeue:
    """Unit tests for queue.dequeue() against a mocked DB connection."""

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
        cursor = _mock_cursor(fetchone_val=None)
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection

        assert queue.dequeue(site_id=42) is None
        connection.commit.assert_called()

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_task(self, mock_get_conn, queue):
        task_id = str(uuid.uuid4())
        payload = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
        db_row = (
            task_id, 42, json.dumps(payload), "pending", 1,
            None, None, None, None, None,
        )
        mock_get_conn.return_value = _mock_conn(_mock_cursor(fetchone_val=db_row))

        task = queue.dequeue(site_id=42)

        assert task is not None
        assert task.id == task_id
        assert task.site_id == 42
        # dequeue() flips the status from pending to running.
        assert task.status == "running"
        assert task.config["tasks"] == ["ODS_MEMBER"]

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
        task_id = str(uuid.uuid4())
        payload = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
        db_row = (
            task_id, 42, json.dumps(payload), "pending", 1,
            None, None, None, None, None,
        )
        cursor = _mock_cursor(fetchone_val=db_row)
        mock_get_conn.return_value = _mock_conn(cursor)

        queue.dequeue(site_id=42)

        # The second execute() call should be UPDATE ... status = 'running'.
        update_sql, update_params = cursor.execute.call_args_list[1][0]
        assert "running" in update_sql
        assert task_id in update_params
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReorder:
    """Unit tests for queue.reorder()."""

    @staticmethod
    def _positions_from(update_calls):
        """Build a task_id -> position map from (pos, tid) UPDATE params."""
        return {tid: pos for pos, tid in (call[0][1] for call in update_calls)}

    @patch("app.services.task_queue.get_connection")
    def test_reorder_moves_task(self, mock_get_conn, queue):
        """Moving the 3rd task to slot 1 shifts the others down."""
        ids = [str(uuid.uuid4()) for _ in range(3)]
        cursor = _mock_cursor(fetchall_val=[(task_id,) for task_id in ids])
        mock_get_conn.return_value = _mock_conn(cursor)

        queue.reorder(ids[2], new_position=1, site_id=42)

        # Expected order after the move: [ids[2], ids[0], ids[1]].
        # Skip the initial SELECT call when collecting UPDATE params.
        positions = self._positions_from(cursor.execute.call_args_list[1:])
        assert positions[ids[2]] == 1
        assert positions[ids[0]] == 2
        assert positions[ids[1]] == 3

    @patch("app.services.task_queue.get_connection")
    def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
        """Reordering an unknown task id must not raise and must not UPDATE."""
        cursor = _mock_cursor(fetchall_val=[(str(uuid.uuid4()),)])
        mock_get_conn.return_value = _mock_conn(cursor)

        queue.reorder("nonexistent-id", new_position=1, site_id=42)

        # Only the initial SELECT, no UPDATE statements.
        assert cursor.execute.call_count == 1

    @patch("app.services.task_queue.get_connection")
    def test_reorder_clamps_position(self, mock_get_conn, queue):
        """An out-of-range position is clamped to the end of the queue."""
        ids = [str(uuid.uuid4()) for _ in range(2)]
        cursor = _mock_cursor(fetchall_val=[(task_id,) for task_id in ids])
        mock_get_conn.return_value = _mock_conn(cursor)

        # new_position=100 exceeds the queue length -> clamp to the last slot.
        queue.reorder(ids[0], new_position=100, site_id=42)

        positions = self._positions_from(cursor.execute.call_args_list[1:])
        assert positions[ids[1]] == 1
        assert positions[ids[0]] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDelete:
    """Unit tests for queue.delete()."""

    @patch("app.services.task_queue.get_connection")
    def test_delete_pending_task(self, mock_get_conn, queue):
        connection = _mock_conn(_mock_cursor(rowcount=1))
        mock_get_conn.return_value = connection

        assert queue.delete("task-1", site_id=42) is True
        connection.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
        mock_get_conn.return_value = _mock_conn(_mock_cursor(rowcount=0))

        assert queue.delete("nonexistent", site_id=42) is False

    @patch("app.services.task_queue.get_connection")
    def test_delete_only_affects_pending(self, mock_get_conn, queue):
        """The DELETE statement must filter on status = 'pending'."""
        cursor = _mock_cursor(rowcount=0)
        mock_get_conn.return_value = _mock_conn(cursor)

        queue.delete("task-1", site_id=42)

        executed_sql = cursor.execute.call_args[0][0]
        assert "pending" in executed_sql
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_pending / has_running
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuery:
    """Unit tests for list_pending() and has_running()."""

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_empty(self, mock_get_conn, queue):
        mock_get_conn.return_value = _mock_conn(_mock_cursor(fetchall_val=[]))

        assert queue.list_pending(site_id=42) == []

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_returns_tasks(self, mock_get_conn, queue):
        task_id = str(uuid.uuid4())
        config_json = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
        db_rows = [(task_id, 42, config_json, "pending", 1,
                    None, None, None, None, None)]
        mock_get_conn.return_value = _mock_conn(_mock_cursor(fetchall_val=db_rows))

        tasks = queue.list_pending(site_id=42)

        assert len(tasks) == 1
        assert tasks[0].id == task_id

    @patch("app.services.task_queue.get_connection")
    def test_has_running_true(self, mock_get_conn, queue):
        mock_get_conn.return_value = _mock_conn(_mock_cursor(fetchone_val=(True,)))

        assert queue.has_running(site_id=42) is True

    @patch("app.services.task_queue.get_connection")
    def test_has_running_false(self, mock_get_conn, queue):
        mock_get_conn.return_value = _mock_conn(_mock_cursor(fetchone_val=(False,)))

        assert queue.has_running(site_id=42) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_loop / _process_once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessLoop:
    """Unit tests for the internal scheduling step _process_once()."""

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_skips_when_running(self, mock_get_conn, queue):
        """When a site already has a running task, nothing is dequeued."""
        # Connection #1: _get_pending_site_ids -> [42]
        # Connection #2+: has_running(42) -> True
        call_count = 0

        def side_effect_conn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # _get_pending_site_ids
                cur = _mock_cursor(fetchall_val=[(42,)])
                return _mock_conn(cur)
            else:
                # has_running
                cur = _mock_cursor(fetchone_val=(True,))
                return _mock_conn(cur)

        mock_get_conn.side_effect = side_effect_conn

        mock_executor = MagicMock()
        await queue._process_once(mock_executor)

        # The executor must not be invoked while a task is running.
        mock_executor.execute.assert_not_called()

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
        """With no running task, the next pending task is dequeued and executed."""
        task_id = str(uuid.uuid4())
        config_dict = {
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods_dwd",
            "processing_mode": "increment_only",
            "dry_run": False,
            "window_mode": "lookback",
            "lookback_hours": 24,
            "overlap_seconds": 600,
            "fetch_before_verify": False,
            "skip_ods_when_fetch_before_verify": False,
            "ods_use_local_json": False,
            "extra_args": {},
        }
        config_json = json.dumps(config_dict)

        call_count = 0

        def side_effect_conn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # _get_pending_site_ids
                cur = _mock_cursor(fetchall_val=[(42,)])
                return _mock_conn(cur)
            elif call_count == 2:
                # has_running -> False
                cur = _mock_cursor(fetchone_val=(False,))
                return _mock_conn(cur)
            else:
                # dequeue -> returns the pending task row
                row = (
                    task_id, 42, config_json, "pending", 1,
                    None, None, None, None, None,
                )
                cur = _mock_cursor(fetchone_val=row)
                return _mock_conn(cur)

        mock_get_conn.side_effect = side_effect_conn

        mock_executor = MagicMock()
        mock_executor.execute = AsyncMock()

        await queue._process_once(mock_executor)

        # Give the background task spawned via create_task time to start.
        await asyncio.sleep(0.1)

        # BUGFIX: this test previously ended at the sleep and asserted
        # nothing, so it could never fail. Verify the dequeued task was
        # actually handed to the executor.
        mock_executor.execute.assert_awaited()

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_no_pending(self, mock_get_conn, queue):
        """With no pending sites the loop body is a no-op."""
        cur = _mock_cursor(fetchall_val=[])
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        mock_executor = MagicMock()
        await queue._process_once(mock_executor)

        mock_executor.execute.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLifecycle:
    """Start/stop lifecycle behaviour of the queue."""

    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self, queue):
        queue._running = True
        queue._loop_task = None

        await queue.stop()

        assert queue._running is False

    def test_start_creates_task(self, queue):
        """start() should create an asyncio.Task (needs an event loop);
        here we only verify the initial idle state of the queue.
        """
        assert queue._running is False
        assert queue._loop_task is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _mark_failed / _update_queue_status_from_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInternalHelpers:
    """Tests for _mark_failed() and _update_queue_status_from_log()."""

    @patch("app.services.task_queue.get_connection")
    def test_mark_failed(self, mock_get_conn, queue):
        cursor = _mock_cursor()
        mock_get_conn.return_value = _mock_conn(cursor)

        queue._mark_failed("queue-1", "测试错误")

        executed_sql, executed_params = cursor.execute.call_args[0]
        assert "failed" in executed_sql
        assert executed_params[0] == "测试错误"
        assert executed_params[1] == "queue-1"

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_from_log(self, mock_get_conn, queue):
        """The task_queue row is synced from the execution_log row."""
        from datetime import datetime, timezone

        finished_at = datetime.now(timezone.utc)
        # The first fetchone() returns the execution_log row.
        cursor = _mock_cursor(fetchone_val=("success", finished_at, 0, None))
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection

        queue._update_queue_status_from_log("queue-1")

        # Exactly one SELECT plus one UPDATE.
        assert cursor.execute.call_count == 2
        connection.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_no_log(self, mock_get_conn, queue):
        """No execution_log row -> no UPDATE is issued."""
        cursor = _mock_cursor(fetchone_val=None)
        mock_get_conn.return_value = _mock_conn(cursor)

        queue._update_queue_status_from_log("queue-1")

        # SELECT only, no UPDATE.
        assert cursor.execute.call_count == 1
|
||||
299
apps/backend/tests/test_task_registry_properties.py
Normal file
299
apps/backend/tests/test_task_registry_properties.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表分组属性测试(Property-Based Testing)。
|
||||
|
||||
Property 4: 对于 Task_Registry 中的任务集合,分组结果中每个任务应出现在
|
||||
且仅出现在其所属业务域的分组中。
|
||||
|
||||
Validates: Requirements 2.1
|
||||
|
||||
测试策略:
|
||||
1. 直接测试 get_tasks_grouped_by_domain 函数:
|
||||
- 每个任务出现在且仅出现在其 domain 对应的分组中
|
||||
- 分组中的任务总数等于全部任务数(不多不少)
|
||||
- 每个分组的 key 等于该分组内所有任务的 domain
|
||||
2. 通过 API 端点测试(TestClient + mock auth):
|
||||
- 返回的 groups 中每个任务的 domain 与其所在分组 key 一致
|
||||
- 所有任务都出现在结果中
|
||||
3. 随机子集验证:
|
||||
- 随机选取任务子集,验证分组逻辑的一致性
|
||||
- 随机选取 domain,验证该 domain 下的任务都正确
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-registry")
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.services.task_registry import (
|
||||
get_all_tasks,
|
||||
get_tasks_grouped_by_domain,
|
||||
TaskDefinition,
|
||||
)
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app
|
||||
from app.auth.dependencies import get_current_user, CurrentUser
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ALL_TASKS = get_all_tasks()
|
||||
ALL_CODES = [t.code for t in ALL_TASKS]
|
||||
ALL_DOMAINS = list({t.domain for t in ALL_TASKS})
|
||||
|
||||
|
||||
def _mock_user() -> CurrentUser:
    """Return a fixed authenticated user, used to override get_current_user."""
    return CurrentUser(user_id=1, site_id=1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.1: 分组完整性 — 每个任务出现在且仅出现在其 domain 分组中
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(data=st.data())
def test_every_task_in_exactly_its_domain_group(data):
    """Property 4.1: each task appears in exactly the group of its own domain.

    Draw one random task from the registry and check it is present in its
    domain's group and absent from every other group.
    """
    grouped = get_tasks_grouped_by_domain()
    # Draw a random task from the full registry.
    task = data.draw(st.sampled_from(ALL_TASKS))

    # The task must appear in its own domain's group.
    assert task.domain in grouped, (
        f"任务 {task.code} 的 domain '{task.domain}' 不在分组 keys 中"
    )
    own_group_codes = [t.code for t in grouped[task.domain]]
    assert task.code in own_group_codes, (
        f"任务 {task.code} 未出现在其 domain '{task.domain}' 的分组中"
    )

    # ... and in no other group.
    for other_domain, other_tasks in grouped.items():
        if other_domain == task.domain:
            continue
        other_codes = [t.code for t in other_tasks]
        assert task.code not in other_codes, (
            f"任务 {task.code}(domain={task.domain})错误地出现在 "
            f"domain '{other_domain}' 的分组中"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.2: 分组总数守恒 — 分组中的任务总数等于全部任务数
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(data=st.data())
def test_grouped_total_equals_all_tasks(data):
    """Property 4.2: group sizes sum exactly to the total task count.

    Verifies global conservation, then spot-checks one randomly drawn
    domain against a direct count over the full registry.
    """
    all_tasks = get_all_tasks()
    grouped = get_tasks_grouped_by_domain()

    # Global conservation: sum of group sizes == total number of tasks.
    grouped_total = sum(len(tasks) for tasks in grouped.values())
    assert grouped_total == len(all_tasks), (
        f"分组总数 {grouped_total} != 全量任务数 {len(all_tasks)}"
    )

    # Draw one domain and verify its group size matches a direct count.
    domain = data.draw(st.sampled_from(ALL_DOMAINS))
    expected_count = len([t for t in all_tasks if t.domain == domain])
    actual_count = len(grouped[domain])
    assert actual_count == expected_count, (
        f"domain '{domain}' 分组内任务数 {actual_count} != 预期 {expected_count}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.3: 分组 key 一致性 — 每个分组的 key 等于组内所有任务的 domain
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(data=st.data())
def test_group_key_matches_task_domains(data):
    """Property 4.3: every task inside a group carries the group's domain key.

    Draw one domain group at random and check each member's domain field.
    """
    grouped = get_tasks_grouped_by_domain()
    domain = data.draw(st.sampled_from(list(grouped.keys())))

    for task in grouped[domain]:
        assert task.domain == domain, (
            f"分组 '{domain}' 中的任务 {task.code} 的 domain 为 "
            f"'{task.domain}',与分组 key 不一致"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.4: 任务 code 全局唯一 — 分组后不应出现重复 code
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(data=st.data())
def test_no_duplicate_codes_across_groups(data):
    """Property 4.4: task codes are globally unique across all groups.

    Also draws two distinct domains and checks their code sets are disjoint.
    """
    grouped = get_tasks_grouped_by_domain()

    # Flatten every group's codes into a single list.
    all_codes_in_groups = [t.code for tasks in grouped.values() for t in tasks]

    assert len(all_codes_in_groups) == len(set(all_codes_in_groups)), (
        "分组中存在重复的任务 code"
    )

    # Draw two distinct domains and verify no code appears in both.
    if len(ALL_DOMAINS) >= 2:
        domains = data.draw(
            st.lists(st.sampled_from(ALL_DOMAINS), min_size=2, max_size=2, unique=True)
        )
        codes_a = {t.code for t in grouped[domains[0]]}
        codes_b = {t.code for t in grouped[domains[1]]}
        overlap = codes_a & codes_b
        assert not overlap, (
            f"domain '{domains[0]}' 和 '{domains[1]}' 存在重叠任务 code: {overlap}"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.5: 随机子集分组一致性 — 子集中的任务分组结果与全量一致
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(
    indices=st.lists(
        st.integers(min_value=0, max_value=len(ALL_TASKS) - 1),
        min_size=1,
        max_size=min(20, len(ALL_TASKS)),
        unique=True,
    )
)
def test_subset_grouping_consistency(indices):
    """Property 4.5: every task of a random subset is found in its domain group.

    For a randomly selected subset of tasks, each one must appear inside the
    group keyed by its own domain in get_tasks_grouped_by_domain().
    """
    grouped = get_tasks_grouped_by_domain()

    for task in (ALL_TASKS[i] for i in indices):
        # The task's domain must be one of the group keys ...
        assert task.domain in grouped
        # ... and its code must be a member of that group.
        group_codes = {t.code for t in grouped[task.domain]}
        assert task.code in group_codes, (
            f"任务 {task.code} 未出现在 domain '{task.domain}' 的分组中"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.6: API 端点分组正确性 — GET /api/tasks/registry 返回一致的分组
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(data=st.data())
def test_api_registry_grouping_correctness(data):
    """Property 4.6: GET /api/tasks/registry returns a grouping where every
    task's domain matches its group key, and every registered task appears.
    """
    app.dependency_overrides[get_current_user] = _mock_user
    try:
        client = TestClient(app)
        resp = client.get("/api/tasks/registry")

        assert resp.status_code == 200
        groups = resp.json()["groups"]

        # Collect every code the API returned, checking domain keys as we go.
        api_codes: set[str] = set()
        for domain_key, task_list in groups.items():
            for task_item in task_list:
                # Each item's domain must equal the key of its group.
                assert task_item["domain"] == domain_key, (
                    f"API 返回的任务 {task_item['code']}(domain={task_item['domain']})"
                    f"出现在分组 '{domain_key}' 中,不一致"
                )
                api_codes.add(task_item["code"])

        # Every registered task must be present in the API response.
        all_codes_set = {t.code for t in get_all_tasks()}
        assert api_codes == all_codes_set, (
            f"API 返回的任务集合与全量任务不一致。"
            f"缺失: {all_codes_set - api_codes},"
            f"多余: {api_codes - all_codes_set}"
        )

        # Draw one domain and compare its group size against the service layer.
        if groups:
            domain = data.draw(st.sampled_from(list(groups.keys())))
            expected = get_tasks_grouped_by_domain()
            assert len(groups[domain]) == len(expected[domain]), (
                f"API 返回的 domain '{domain}' 任务数 {len(groups[domain])} "
                f"!= 服务层 {len(expected[domain])}"
            )
    finally:
        app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.7: 随机 domain 过滤验证
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
@given(domain=st.sampled_from(ALL_DOMAINS))
def test_random_domain_tasks_all_correct(domain):
    """Property 4.7: for a randomly drawn domain, the group contains exactly
    the registry tasks of that domain, each tagged with the right domain.
    """
    grouped = get_tasks_grouped_by_domain()
    all_tasks = get_all_tasks()

    # Tasks the grouping places under this domain.
    group_tasks = grouped.get(domain, [])

    # Tasks the full registry says belong to this domain.
    expected_tasks = [t for t in all_tasks if t.domain == domain]

    # Same cardinality ...
    assert len(group_tasks) == len(expected_tasks), (
        f"domain '{domain}': 分组中 {len(group_tasks)} 个任务,"
        f"预期 {len(expected_tasks)} 个"
    )

    # ... same code sets ...
    group_codes = {t.code for t in group_tasks}
    expected_codes = {t.code for t in expected_tasks}
    assert group_codes == expected_codes, (
        f"domain '{domain}': 分组 codes {group_codes} != 预期 {expected_codes}"
    )

    # ... and every member carries the drawn domain.
    for task in group_tasks:
        assert task.domain == domain
|
||||
274
apps/backend/tests/test_tasks_router.py
Normal file
274
apps/backend/tests/test_tasks_router.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表 API 单元测试
|
||||
|
||||
覆盖 4 个端点:registry / dwd-tables / flows / validate
|
||||
通过 JWT mock 绕过认证依赖。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.services.task_registry import (
|
||||
ALL_TASKS,
|
||||
DWD_TABLES,
|
||||
FLOW_LAYER_MAP,
|
||||
get_tasks_grouped_by_domain,
|
||||
)
|
||||
|
||||
# 固定测试用户
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
    """FastAPI dependency override that always returns the fixed test user."""
    return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTaskRegistry:
    """GET /api/tasks/registry"""

    def setup_method(self):
        """Re-install the auth override before every test, so other test
        files that clear/pop dependency_overrides cannot leak state here."""
        app.dependency_overrides[get_current_user] = _override_auth

    def test_registry_returns_grouped_tasks(self):
        resp = client.get("/api/tasks/registry")
        assert resp.status_code == 200
        data = resp.json()
        assert "groups" in data

        # Every task must show up in exactly the group of its own domain.
        codes_seen = set()
        for domain, tasks in data["groups"].items():
            for task in tasks:
                codes_seen.add(task["code"])
                assert task["domain"] == domain

        assert codes_seen == {t.code for t in ALL_TASKS}

    def test_registry_task_fields_complete(self):
        """Every task item carries all required fields."""
        resp = client.get("/api/tasks/registry")
        data = resp.json()
        required_fields = {"code", "name", "description", "domain", "layer",
                           "requires_window", "is_ods", "is_dimension", "default_enabled"}
        for tasks in data["groups"].values():
            for task in tasks:
                assert required_fields.issubset(task.keys())

    def test_registry_requires_auth(self):
        """Without authentication the endpoint returns 401."""
        app.dependency_overrides.pop(get_current_user, None)
        try:
            resp = client.get("/api/tasks/registry")
            assert resp.status_code == 401
        finally:
            app.dependency_overrides[get_current_user] = _override_auth
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/dwd-tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDwdTables:
    """GET /api/tasks/dwd-tables"""

    def test_dwd_tables_returns_grouped(self):
        resp = client.get("/api/tasks/dwd-tables")
        assert resp.status_code == 200
        data = resp.json()
        assert "groups" in data

        # Each table must sit inside the group matching its domain.
        names_seen = set()
        for domain, tables in data["groups"].items():
            for table in tables:
                names_seen.add(table["table_name"])
                assert table["domain"] == domain

        assert names_seen == {t.table_name for t in DWD_TABLES}

    def test_dwd_tables_fields_complete(self):
        resp = client.get("/api/tasks/dwd-tables")
        data = resp.json()
        required_fields = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
        for tables in data["groups"].values():
            for table in tables:
                assert required_fields.issubset(table.keys())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/flows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlows:
    """GET /api/tasks/flows"""

    def test_flows_returns_seven_flows(self):
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        assert len(resp.json()["flows"]) == 7

    def test_flows_returns_three_processing_modes(self):
        resp = client.get("/api/tasks/flows")
        assert len(resp.json()["processing_modes"]) == 3

    def test_flow_ids_match_registry(self):
        """The flow ids must equal the keys of FLOW_LAYER_MAP."""
        resp = client.get("/api/tasks/flows")
        flow_ids = {flow["id"] for flow in resp.json()["flows"]}
        assert flow_ids == set(FLOW_LAYER_MAP.keys())

    def test_flow_layers_non_empty(self):
        resp = client.get("/api/tasks/flows")
        for flow in resp.json()["flows"]:
            assert len(flow["layers"]) > 0

    def test_processing_mode_ids(self):
        resp = client.get("/api/tasks/flows")
        mode_ids = {mode["id"] for mode in resp.json()["processing_modes"]}
        assert mode_ids == {"increment_only", "verify_only", "increment_verify"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/tasks/validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidate:
    """POST /api/tasks/validate"""

    @staticmethod
    def _validate(config):
        """POST the given task config to the validate endpoint."""
        return client.post("/api/tasks/validate", json={"config": config})

    def test_validate_success(self):
        resp = self._validate({
            "tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
            "pipeline": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert data["errors"] == []
        assert len(data["command_args"]) > 0
        assert "--store-id" in data["command"]
        # store_id is injected from the JWT (the test user has site_id=100).
        assert "100" in data["command"]

    def test_validate_injects_store_id(self):
        """Even when the client sends store_id, the JWT value overrides it."""
        resp = self._validate({
            "tasks": ["DWD_LOAD_FROM_ODS"],
            "pipeline": "ods_dwd",
            "store_id": 999,
        })
        assert resp.status_code == 200
        data = resp.json()
        # The command must carry the JWT's site_id=100, not the client's 999.
        assert "--store-id" in data["command"]
        idx = data["command_args"].index("--store-id")
        assert data["command_args"][idx + 1] == "100"

    def test_validate_invalid_flow(self):
        resp = self._validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "nonexistent_flow",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("无效的执行流程" in e for e in data["errors"])

    def test_validate_empty_tasks(self):
        resp = self._validate({
            "tasks": [],
            "pipeline": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("任务列表不能为空" in e for e in data["errors"])

    def test_validate_custom_window(self):
        resp = self._validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-01-01",
            "window_end": "2024-01-31",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert "--window-start" in data["command"]
        assert "--window-end" in data["command"]

    def test_validate_window_end_before_start_rejected(self):
        """window_end earlier than window_start fails Pydantic validation -> 422."""
        resp = self._validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-12-31",
            "window_end": "2024-01-01",
        })
        assert resp.status_code == 422

    def test_validate_dry_run_flag(self):
        resp = self._validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "dry_run": True,
        })
        assert resp.status_code == 200
        assert "--dry-run" in resp.json()["command"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# task_registry 服务层测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTaskRegistryService:
    """Unit tests for the task_registry service layer."""

    def test_all_tasks_have_unique_codes(self):
        """Task codes must be globally unique across the registry."""
        seen = set()
        for task in ALL_TASKS:
            assert task.code not in seen
            seen.add(task.code)

    def test_grouped_tasks_cover_all(self):
        """Grouping by domain must neither drop nor invent tasks."""
        grouped = get_tasks_grouped_by_domain()
        grouped_codes = {t.code for tasks in grouped.values() for t in tasks}
        assert grouped_codes == {t.code for t in ALL_TASKS}

    def test_ods_tasks_marked_is_ods(self):
        """Every task in the ODS layer must carry the is_ods flag."""
        for task in ALL_TASKS:
            if task.layer == "ODS":
                assert task.is_ods is True

    def test_flow_layer_map_covers_all_flows(self):
        """FLOW_LAYER_MAP must define exactly the known execution flows."""
        expected = {
            "api_ods", "api_ods_dwd", "api_full", "ods_dwd",
            "dwd_dws", "dwd_dws_index", "dwd_index",
        }
        assert set(FLOW_LAYER_MAP.keys()) == expected
||||
186
apps/backend/tests/test_ws_logs.py
Normal file
186
apps/backend/tests/test_ws_logs.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""WebSocket 日志推送端点测试
|
||||
|
||||
测试 /ws/logs/{execution_id} 端点的连接、日志回放、实时推送和断开行为。
|
||||
利用 TaskExecutor 已有的 subscribe/broadcast 机制进行验证。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
from app.main import app
|
||||
from app.services.task_executor import task_executor
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _cleanup_executor():
    """Reset TaskExecutor's internal state after every test."""
    yield
    # Drop any buffers/subscribers a test left behind so state never
    # leaks across test cases.
    for execution_id in list(task_executor._log_buffers):
        task_executor.cleanup(execution_id)
    task_executor._subscribers.clear()
    task_executor._log_buffers.clear()
||||
|
||||
|
||||
class TestWebSocketConnection:
    """Basic WebSocket connect/disconnect behaviour."""

    def test_connect_and_disconnect(self):
        """A client can open and cleanly close a log WebSocket."""
        client = TestClient(app)
        with client.websocket_connect("/ws/logs/test-exec-001"):
            pass  # the context manager's __exit__ closes the socket

    def test_connect_registers_subscriber(self):
        """While connected, TaskExecutor must list the socket as a subscriber."""
        client = TestClient(app)
        # Pre-create the buffer to simulate a task that is already running.
        task_executor._log_buffers["test-exec-002"] = []

        with client.websocket_connect("/ws/logs/test-exec-002"):
            subscribers = task_executor._subscribers
            assert "test-exec-002" in subscribers
            assert len(subscribers["test-exec-002"]) >= 1
|
||||
|
||||
|
||||
class TestLogReplay:
    """Replay of buffered (historical) log lines on connect."""

    def test_replay_existing_logs(self):
        """Lines already in the in-memory buffer are replayed in order."""
        eid = "test-exec-replay"
        history = [
            "[stdout] 第一行",
            "[stdout] 第二行",
            "[stderr] 警告信息",
        ]
        task_executor._log_buffers[eid] = list(history)

        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            received = [ws.receive_text() for _ in history]
            assert received == history

    def test_no_logs_no_replay(self):
        """With an empty buffer the server sends no replay messages."""
        eid = "test-exec-empty"
        task_executor._log_buffers[eid] = []

        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}"):
            # Signal end-of-stream so the connection shuts down cleanly;
            # no replay message should arrive before it.
            task_executor._broadcast_end(eid)
|
||||
|
||||
|
||||
class TestBroadcastReceive:
    """Receiving live log broadcasts over the WebSocket."""

    def test_receive_broadcast_messages(self):
        """Messages broadcast by TaskExecutor reach a connected client in order."""
        eid = "test-exec-broadcast"
        task_executor._log_buffers[eid] = []

        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            # Simulate TaskExecutor pushing live log lines.
            task_executor._broadcast(eid, "[stdout] 实时日志行1")
            task_executor._broadcast(eid, "[stderr] 实时错误行")

            assert ws.receive_text() == "[stdout] 实时日志行1"
            assert ws.receive_text() == "[stderr] 实时错误行"

    def test_end_signal_closes_connection(self):
        """After the None end signal the server must close the WebSocket.

        FIX: the original test only read the last log line and never
        verified the close, even though WebSocketDisconnect was imported
        at module level for exactly that purpose.  We now assert that the
        next receive raises it.
        """
        eid = "test-exec-end"
        task_executor._log_buffers[eid] = []

        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            # Broadcast one line, then the end-of-stream sentinel.
            task_executor._broadcast(eid, "[stdout] 最后一行")
            task_executor._broadcast_end(eid)

            assert ws.receive_text() == "[stdout] 最后一行"
            # The end signal should make the server close the socket;
            # TestClient surfaces that as WebSocketDisconnect.
            with pytest.raises(WebSocketDisconnect):
                ws.receive_text()

    def test_replay_then_broadcast(self):
        """Historical lines are replayed first, then live broadcasts follow."""
        eid = "test-exec-mixed"
        task_executor._log_buffers[eid] = ["[stdout] 历史行"]

        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            # Replay comes first...
            assert ws.receive_text() == "[stdout] 历史行"

            # ...then anything broadcast afterwards.
            task_executor._broadcast(eid, "[stdout] 新行")
            task_executor._broadcast_end(eid)

            assert ws.receive_text() == "[stdout] 新行"
|
||||
|
||||
|
||||
class TestLogBroadcasterUnit:
    """Direct unit tests of TaskExecutor's subscribe/unsubscribe/broadcast
    methods (the LogBroadcaster functionality)."""

    def test_subscribe_creates_queue(self):
        """subscribe() hands back an asyncio.Queue and registers it."""
        eid = "unit-sub"
        queue = task_executor.subscribe(eid)
        assert isinstance(queue, asyncio.Queue)
        assert eid in task_executor._subscribers
        task_executor.unsubscribe(eid, queue)

    def test_unsubscribe_removes_queue(self):
        """Removing the last subscriber also removes the execution-id key."""
        eid = "unit-unsub"
        queue = task_executor.subscribe(eid)
        task_executor.unsubscribe(eid, queue)
        assert eid not in task_executor._subscribers

    def test_broadcast_delivers_to_all_subscribers(self):
        """Every subscriber queue receives each broadcast message."""
        eid = "unit-multi"
        queues = [task_executor.subscribe(eid) for _ in range(2)]

        task_executor._broadcast(eid, "测试消息")

        for queue in queues:
            assert queue.get_nowait() == "测试消息"
        for queue in queues:
            task_executor.unsubscribe(eid, queue)

    def test_broadcast_end_sends_none(self):
        """_broadcast_end pushes the None sentinel to every subscriber."""
        eid = "unit-end"
        queue = task_executor.subscribe(eid)

        task_executor._broadcast_end(eid)

        assert queue.get_nowait() is None
        task_executor.unsubscribe(eid, queue)

    def test_broadcast_no_subscribers_is_safe(self):
        """Broadcasting with no subscribers must be a no-op, not an error."""
        task_executor._broadcast("nonexistent", "无人接收")
        task_executor._broadcast_end("nonexistent")
|
||||
Reference in New Issue
Block a user