在准备环境前提交此次全部更改。

This commit is contained in:
Neo
2026-02-19 08:35:13 +08:00
parent ded6dfb9d8
commit 4eac07da47
1387 changed files with 6107191 additions and 33002 deletions

48
apps/backend/.env.local Normal file
View File

@@ -0,0 +1,48 @@
# ==============================================================================
# NeoZQYY 后端 .env.local — 私有覆盖层
# ==============================================================================
# 后端 config.py 以 override=True 加载此文件,优先级高于根 .env
# 敏感值禁止提交;本文件已在 .gitignore 中排除
# ------------------------------------------------------------------------------
# 业务数据库(zqyy_app)
# ------------------------------------------------------------------------------
# DB_HOST / DB_PORT / DB_USER / DB_PASSWORD 继承自根 .env无需重复
# CHANGE 2026-02-15 | 默认指向测试库,生产环境切换为 zqyy_app
APP_DB_NAME=test_zqyy_app
# ------------------------------------------------------------------------------
# ETL 数据库(后端只读访问,用于数据库查看器)
# ------------------------------------------------------------------------------
# 与 zqyy_app 同实例时可省略 ETL_DB_HOST/PORT/USER/PASSWORD自动复用
# CHANGE 2026-02-15 | 默认指向测试库,生产环境切换为 etl_feiqiu
ETL_DB_NAME=test_etl_feiqiu
# ------------------------------------------------------------------------------
# JWT 认证
# ------------------------------------------------------------------------------
JWT_SECRET_KEY=change-me-in-production
JWT_ALGORITHM=HS256
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
# ------------------------------------------------------------------------------
# CORS(逗号分隔)
# ------------------------------------------------------------------------------
CORS_ORIGINS=http://localhost:5173
# ------------------------------------------------------------------------------
# 微信消息推送(与微信后台填写的 Token 一致)
# CHANGE 2026-02-19 | 新增微信消息推送回调 Token
# ------------------------------------------------------------------------------
WX_CALLBACK_TOKEN=LLZQwx2026push
# ------------------------------------------------------------------------------
# 通用
# ------------------------------------------------------------------------------
LOG_LEVEL=INFO
# ------------------------------------------------------------------------------
# ETL 项目路径(子进程 cwd缺省按 monorepo 相对路径推算)
# ------------------------------------------------------------------------------
# ETL_PROJECT_PATH=C:/NeoZQYY/apps/etl/connectors/feiqiu

View File

@@ -4,7 +4,7 @@
## 内部结构
`
```
apps/backend/
├── app/
│ ├── main.py # FastAPI 入口,启用 OpenAPI 文档
@@ -16,21 +16,22 @@ apps/backend/
├── tests/ # 后端测试
├── pyproject.toml # 依赖声明
└── README.md
`
```
## 启动
`ash
```bash
# 确保已在根目录执行 uv sync --all-packages
cd apps/backend
uvicorn app.main:app --reload
`
uv run uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload
```
API 文档自动生成于 http://localhost:8000/docs
## 依赖
- fastapi>=0.100, uvicorn>=0.23
- psycopg2-binary>=2.9.0
- fastapi>=0.115, uvicorn[standard]>=0.34
- psycopg2-binary>=2.9, python-dotenv>=1.0
- neozqyy-sharedworkspace 引用)
## Roadmap

View File

@@ -0,0 +1 @@
"""认证模块JWT 令牌管理与 FastAPI 依赖注入。"""

View File

@@ -0,0 +1,67 @@
"""
FastAPI 依赖注入:从 JWT 提取当前用户信息。
用法:
@router.get("/protected")
async def protected_endpoint(user: CurrentUser = Depends(get_current_user)):
print(user.user_id, user.site_id)
"""
from dataclasses import dataclass
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError
from app.auth.jwt import decode_access_token
# Bearer token 提取器
_bearer_scheme = HTTPBearer(auto_error=True)
@dataclass(frozen=True)
class CurrentUser:
    """Authenticated user context parsed from a verified JWT (frozen/immutable)."""
    # Primary key of the admin_users row ("sub" claim, coerced to int).
    user_id: int
    # Store the user belongs to ("site_id" claim); used for per-store data isolation.
    site_id: int
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme),
) -> CurrentUser:
    """
    FastAPI dependency: extract the JWT from the Authorization header,
    verify it and return the user context.

    Raises:
        HTTPException: 401 when the token is invalid, expired, missing
            required claims, or carries malformed claim values.
    """
    token = credentials.credentials
    try:
        payload = decode_access_token(token)
    except JWTError:
        # `from None`: the JWT library's internal traceback adds no value
        # to a clean 401 response and would clutter server logs.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的令牌",
            headers={"WWW-Authenticate": "Bearer"},
        ) from None
    user_id_raw = payload.get("sub")
    site_id_raw = payload.get("site_id")
    if user_id_raw is None or site_id_raw is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="令牌缺少必要字段",
            headers={"WWW-Authenticate": "Bearer"},
        )
    try:
        user_id = int(user_id_raw)
        # Coerce site_id too, so CurrentUser.site_id always matches its
        # declared int annotation (previously passed through unchecked).
        site_id = int(site_id_raw)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="令牌中 user_id 格式无效",
            headers={"WWW-Authenticate": "Bearer"},
        ) from None
    return CurrentUser(user_id=user_id, site_id=site_id)

View File

@@ -0,0 +1,112 @@
"""
JWT 令牌生成、验证与解码。
- access_token短期有效默认 30 分钟),用于 API 请求认证
- refresh_token长期有效默认 7 天),用于刷新 access_token
- payload 包含 user_id、site_id、令牌类型access / refresh
- 密码哈希直接使用 bcrypt 库passlib 与 bcrypt>=4.1 存在兼容性问题)
"""
from datetime import datetime, timedelta, timezone
import bcrypt
from jose import JWTError, jwt
from app import config
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored bcrypt hash."""
    candidate = plain_password.encode("utf-8")
    stored = hashed_password.encode("utf-8")
    return bcrypt.checkpw(candidate, stored)
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt."""
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")
def create_access_token(user_id: int, site_id: int) -> str:
    """
    Issue a signed access token.

    payload: sub=user_id, site_id, type=access, exp
    """
    lifetime = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "sub": str(user_id),
        "site_id": site_id,
        "type": "access",
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
def create_refresh_token(user_id: int, site_id: int) -> str:
    """
    Issue a signed refresh token.

    payload: sub=user_id, site_id, type=refresh, exp
    """
    lifetime = timedelta(days=config.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
    claims = {
        "sub": str(user_id),
        "site_id": site_id,
        "type": "refresh",
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
def create_token_pair(user_id: int, site_id: int) -> dict[str, str]:
    """Issue both tokens for one user and wrap them in a bearer-token dict."""
    access = create_access_token(user_id, site_id)
    refresh = create_refresh_token(user_id, site_id)
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
    }
def decode_token(token: str) -> dict:
    """
    Decode and verify a JWT.

    Returns:
        The payload dict (sub, site_id, type, exp).

    Raises:
        JWTError: when the token is invalid or expired.
    """
    # jwt.decode already raises JWTError on any failure — the previous
    # `try/except JWTError: raise` wrapper was a no-op and has been removed.
    return jwt.decode(
        token, config.JWT_SECRET_KEY, algorithms=[config.JWT_ALGORITHM]
    )
def decode_access_token(token: str) -> dict:
    """Decode *token* and require the "access" token type; raise JWTError otherwise."""
    claims = decode_token(token)
    if claims.get("type") != "access":
        raise JWTError("令牌类型不是 access")
    return claims
def decode_refresh_token(token: str) -> dict:
    """Decode *token* and require the "refresh" token type; raise JWTError otherwise."""
    claims = decode_token(token)
    if claims.get("type") != "refresh":
        raise JWTError("令牌类型不是 refresh")
    return claims

View File

@@ -29,7 +29,37 @@ DB_HOST: str = get("DB_HOST", "localhost")
DB_PORT: str = get("DB_PORT", "5432")
DB_USER: str = get("DB_USER", "")
DB_PASSWORD: str = get("DB_PASSWORD", "")
# NOTE(review): this default is immediately overridden by the re-assignment
# two lines below — looks like a leftover from the pre-change default; confirm
# whether this line should be deleted.
APP_DB_NAME: str = get("APP_DB_NAME", "zqyy_app")
# CHANGE 2026-02-15 | defaults to the test database; production overrides via .env
APP_DB_NAME: str = get("APP_DB_NAME", "test_zqyy_app")
# ---- JWT authentication ----
JWT_SECRET_KEY: str = get("JWT_SECRET_KEY", "")  # must be set in production
JWT_ALGORITHM: str = get("JWT_ALGORITHM", "HS256")
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int(get("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = int(get("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# ---- ETL database connection (each value falls back to the zqyy_app setting) ----
ETL_DB_HOST: str = get("ETL_DB_HOST") or DB_HOST
ETL_DB_PORT: str = get("ETL_DB_PORT") or DB_PORT
ETL_DB_USER: str = get("ETL_DB_USER") or DB_USER
ETL_DB_PASSWORD: str = get("ETL_DB_PASSWORD") or DB_PASSWORD
# CHANGE 2026-02-15 | defaults to the test database; production overrides via .env
ETL_DB_NAME: str = get("ETL_DB_NAME", "test_etl_feiqiu")
# ---- CORS ----
# Comma-separated allow-list; defaults to the Vite dev server origin.
CORS_ORIGINS: list[str] = [
    o.strip()
    for o in get("CORS_ORIGINS", "http://localhost:5173").split(",")
    if o.strip()
]
# ---- ETL project path ----
# Working directory (cwd) for ETL CLI subprocesses; when unset it is derived
# from the monorepo layout relative to this file.
ETL_PROJECT_PATH: str = get(
    "ETL_PROJECT_PATH",
    str(Path(__file__).resolve().parents[3] / "apps" / "etl" / "connectors" / "feiqiu"),
)
# ---- General ----
TIMEZONE: str = get("TIMEZONE", "Asia/Shanghai")

View File

@@ -1,14 +1,30 @@
"""
zqyy_app 数据库连接
数据库连接
使用 psycopg2 直连 PostgreSQL不引入 ORM。
连接参数从环境变量读取(经 config 模块加载)。
提供两类连接:
- get_connection()zqyy_app 读写连接(用户/队列/调度等业务数据)
- get_etl_readonly_connection(site_id)etl_feiqiu 只读连接(数据库查看器),
自动设置 RLS site_id 隔离
"""
import psycopg2
from psycopg2.extensions import connection as PgConnection
from app.config import APP_DB_NAME, DB_HOST, DB_PASSWORD, DB_PORT, DB_USER
from app.config import (
APP_DB_NAME,
DB_HOST,
DB_PASSWORD,
DB_PORT,
DB_USER,
ETL_DB_HOST,
ETL_DB_NAME,
ETL_DB_PASSWORD,
ETL_DB_PORT,
ETL_DB_USER,
)
def get_connection() -> PgConnection:
@@ -24,3 +40,43 @@ def get_connection() -> PgConnection:
password=DB_PASSWORD,
dbname=APP_DB_NAME,
)
def get_etl_readonly_connection(site_id: int | str) -> PgConnection:
    """
    Open a read-only connection to the ETL database (etl_feiqiu).

    After connecting, this function configures the session:
      1. SET default_transaction_read_only = on — every later transaction
         on this connection is read-only, blocking write statements
      2. set_config('app.current_site_id', <site_id>, false) — session-scoped
         store id used by RLS policies

    The caller owns the connection and must close it. Typical usage::

        conn = get_etl_readonly_connection(site_id)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT ...")
        finally:
            conn.close()

    Raises:
        psycopg2.Error: propagated from connect/setup; the connection is
            closed before re-raising when setup fails.
    """
    conn = psycopg2.connect(
        host=ETL_DB_HOST,
        port=ETL_DB_PORT,
        user=ETL_DB_USER,
        password=ETL_DB_PASSWORD,
        dbname=ETL_DB_NAME,
    )
    try:
        conn.autocommit = False
        with conn.cursor() as cur:
            # Session-level read-only guard: blocks any write operation.
            cur.execute("SET default_transaction_read_only = on")
            # BUGFIX: the previous `SET LOCAL app.current_site_id = %s` was
            # transaction-scoped, so the conn.commit() below wiped it and RLS
            # saw an empty site id for every subsequent query on this
            # connection. set_config(..., is_local => false) applies the
            # setting for the whole session instead.
            cur.execute(
                "SELECT set_config('app.current_site_id', %s, false)",
                (str(site_id),),
            )
        conn.commit()
    except Exception:
        conn.close()
        raise
    return conn

View File

@@ -1,20 +1,66 @@
"""
NeoZQYY 后端 API 入口
基于 FastAPI 构建,为微信小程序提供 RESTful API。
基于 FastAPI 构建,为管理后台和微信小程序提供 RESTful API。
OpenAPI 文档自动生成于 /docsSwagger UI和 /redocReDoc
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app import config
# CHANGE 2026-02-19 | 新增 xcx_test 路由MVP 验证)+ wx_callback 路由(微信消息推送)
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback
from app.services.scheduler import scheduler
from app.services.task_queue import task_queue
from app.ws.logs import ws_router
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start background services on boot, stop them gracefully on shutdown."""
    # Startup: bring up the task queue worker and the scheduler.
    task_queue.start()
    scheduler.start()
    yield
    # Shutdown: stop the scheduler and queue (awaited so in-flight work can finish).
    await scheduler.stop()
    await task_queue.stop()
app = FastAPI(
title="NeoZQYY API",
description="台球门店运营助手 — 微信小程序后端 API",
description="台球门店运营助手 — 后端 API(管理后台 + 微信小程序)",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
# ---- CORS 中间件 ----
# 允许来源从环境变量 CORS_ORIGINS 读取,缺省允许 Vite 开发服务器 (localhost:5173)
app.add_middleware(
CORSMiddleware,
allow_origins=config.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---- 路由注册 ----
app.include_router(auth.router)
app.include_router(tasks.router)
app.include_router(execution.router)
app.include_router(schedules.router)
app.include_router(env_config.router)
app.include_router(db_viewer.router)
app.include_router(etl_status.router)
app.include_router(ws_router)
app.include_router(xcx_test.router)
app.include_router(wx_callback.router)
@app.get("/health", tags=["系统"])
async def health_check():

View File

@@ -0,0 +1,97 @@
"""
认证路由:登录与令牌刷新。
- POST /api/auth/login — 验证用户名密码,返回 JWT 令牌对
- POST /api/auth/refresh — 用刷新令牌换取新的访问令牌
"""
import logging
from fastapi import APIRouter, HTTPException, status
from jose import JWTError
from app.auth.jwt import (
create_access_token,
create_token_pair,
decode_refresh_token,
verify_password,
)
from app.database import get_connection
from app.schemas.auth import LoginRequest, RefreshRequest, TokenResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/auth", tags=["认证"])
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest):
"""
用户登录。
查询 admin_users 表验证用户名密码,成功后返回 JWT 令牌对。
- 用户不存在或密码错误401
- 账号已禁用is_active=false401
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT id, password_hash, site_id, is_active "
"FROM admin_users WHERE username = %s",
(body.username,),
)
row = cur.fetchone()
finally:
conn.close()
if row is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
user_id, password_hash, site_id, is_active = row
if not is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="账号已被禁用",
)
if not verify_password(body.password, password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
tokens = create_token_pair(user_id, site_id)
return TokenResponse(**tokens)
@router.post("/refresh", response_model=TokenResponse)
async def refresh(body: RefreshRequest):
"""
刷新访问令牌。
验证 refresh_token 有效性,成功后仅返回新的 access_token
refresh_token 保持不变,由客户端继续持有)。
"""
try:
payload = decode_refresh_token(body.refresh_token)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的刷新令牌",
)
user_id = int(payload["sub"])
site_id = payload["site_id"]
# 生成新的 access_tokenrefresh_token 原样返回
new_access = create_access_token(user_id, site_id)
return TokenResponse(
access_token=new_access,
refresh_token=body.refresh_token,
token_type="bearer",
)

View File

@@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-
"""数据库查看器 API
提供 4 个端点:
- GET /api/db/schemas — 返回 Schema 列表
- GET /api/db/schemas/{name}/tables — 返回表列表和行数
- GET /api/db/tables/{schema}/{table}/columns — 返回列定义
- POST /api/db/query — 只读 SQL 执行
所有端点需要 JWT 认证。
使用 get_etl_readonly_connection(site_id) 确保 RLS 隔离。
"""
from __future__ import annotations
import logging
import re
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg2 import errors as pg_errors, OperationalError
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_etl_readonly_connection
from app.schemas.db_viewer import (
ColumnInfo,
QueryRequest,
QueryResponse,
SchemaInfo,
TableInfo,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/db", tags=["数据库查看器"])
# Write-operation keywords to reject (case-insensitive). Defence in depth:
# the ETL connection is opened read-only (see get_etl_readonly_connection),
# this check just gives users a clearer error message up front.
# NOTE(review): DDL such as ALTER/CREATE is not matched here and relies
# solely on the read-only connection — confirm that is intended.
_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE)\b",
    re.IGNORECASE,
)
# Upper bound on the number of rows returned by one query
_MAX_ROWS = 1000
# Per-query statement timeout (seconds)
_QUERY_TIMEOUT_SEC = 30
# ── GET /api/db/schemas ──────────────────────────────────────
@router.get("/schemas", response_model=list[SchemaInfo])
async def list_schemas(
user: CurrentUser = Depends(get_current_user),
) -> list[SchemaInfo]:
"""返回 ETL 数据库中的 Schema 列表。"""
conn = get_etl_readonly_connection(user.site_id)
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
ORDER BY schema_name
"""
)
rows = cur.fetchall()
return [SchemaInfo(name=row[0]) for row in rows]
finally:
conn.close()
# ── GET /api/db/schemas/{name}/tables ────────────────────────
@router.get("/schemas/{name}/tables", response_model=list[TableInfo])
async def list_tables(
name: str,
user: CurrentUser = Depends(get_current_user),
) -> list[TableInfo]:
"""返回指定 Schema 下所有表的名称和行数统计。"""
conn = get_etl_readonly_connection(user.site_id)
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT
t.table_name,
s.n_live_tup
FROM information_schema.tables t
LEFT JOIN pg_stat_user_tables s
ON s.schemaname = t.table_schema
AND s.relname = t.table_name
WHERE t.table_schema = %s
AND t.table_type = 'BASE TABLE'
ORDER BY t.table_name
""",
(name,),
)
rows = cur.fetchall()
return [
TableInfo(name=row[0], row_count=row[1])
for row in rows
]
finally:
conn.close()
# ── GET /api/db/tables/{schema}/{table}/columns ──────────────
@router.get(
    "/tables/{schema}/{table}/columns",
    response_model=list[ColumnInfo],
)
async def list_columns(
    schema: str,
    table: str,
    user: CurrentUser = Depends(get_current_user),
) -> list[ColumnInfo]:
    """Return the column definitions of the given table."""
    conn = get_etl_readonly_connection(user.site_id)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (schema, table),
            )
            column_rows = cur.fetchall()
    finally:
        conn.close()
    return [
        ColumnInfo(
            name=col_name,
            data_type=col_type,
            is_nullable=(nullable == "YES"),
            column_default=default,
        )
        for col_name, col_type, nullable, default in column_rows
    ]
# ── POST /api/db/query ───────────────────────────────────────
@router.post("/query", response_model=QueryResponse)
async def execute_query(
body: QueryRequest,
user: CurrentUser = Depends(get_current_user),
) -> QueryResponse:
"""只读 SQL 执行。
安全措施:
1. 拦截写操作关键词INSERT / UPDATE / DELETE / DROP / TRUNCATE
2. 限制返回行数上限 1000 行
3. 设置查询超时 30 秒
"""
sql = body.sql.strip()
if not sql:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="SQL 语句不能为空",
)
# 拦截写操作
if _WRITE_KEYWORDS.search(sql):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="只允许只读查询,禁止 INSERT / UPDATE / DELETE / DROP / TRUNCATE 操作",
)
conn = get_etl_readonly_connection(user.site_id)
try:
with conn.cursor() as cur:
# 设置查询超时
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{_QUERY_TIMEOUT_SEC}s",),
)
try:
cur.execute(sql)
except pg_errors.QueryCanceled:
raise HTTPException(
status_code=status.HTTP_408_REQUEST_TIMEOUT,
detail=f"查询超时(超过 {_QUERY_TIMEOUT_SEC} 秒)",
)
except Exception as exc:
# SQL 语法错误或其他执行错误
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"SQL 执行错误: {exc}",
)
# 提取列名
columns = (
[desc[0] for desc in cur.description]
if cur.description
else []
)
# 限制返回行数
rows = cur.fetchmany(_MAX_ROWS)
# 将元组转为列表,便于 JSON 序列化
rows_list = [list(row) for row in rows]
return QueryResponse(
columns=columns,
rows=rows_list,
row_count=len(rows_list),
)
except HTTPException:
raise
except OperationalError as exc:
# 连接级错误
logger.error("数据库查看器连接错误: %s", exc)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="数据库连接错误",
)
finally:
conn.close()

View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
"""环境配置 API
提供 3 个端点:
- GET /api/env-config — 读取 .env敏感值掩码
- PUT /api/env-config — 验证并写入 .env
- GET /api/env-config/export — 导出去敏感值的配置文件
所有端点需要 JWT 认证。
敏感键判定:键名中包含 PASSWORD、TOKEN、SECRET、DSN不区分大小写
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from app.auth.dependencies import CurrentUser, get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/env-config", tags=["环境配置"])
# .env 文件路径:项目根目录
_ENV_PATH = Path(__file__).resolve().parents[3] / ".env"
# 敏感键关键词(不区分大小写)
_SENSITIVE_KEYWORDS = ("PASSWORD", "TOKEN", "SECRET", "DSN")
_MASK = "****"
# ── Pydantic 模型 ────────────────────────────────────────────
class EnvEntry(BaseModel):
    """A single environment-variable key/value pair."""
    key: str
    value: str
class EnvConfigResponse(BaseModel):
    """GET response: the list of key/value entries."""
    entries: list[EnvEntry]
class EnvConfigUpdateRequest(BaseModel):
    """PUT request body: the list of key/value entries."""
    entries: list[EnvEntry]
# ── 工具函数 ─────────────────────────────────────────────────
def _is_sensitive(key: str) -> bool:
    """Return True when the key name looks like it holds a secret."""
    folded = key.upper()
    for keyword in _SENSITIVE_KEYWORDS:
        if keyword in folded:
            return True
    return False
def _parse_env(content: str) -> list[dict]:
    """Parse .env text into a line-level structure.

    Each line becomes one dict of two kinds:
    - comment: a comment or blank line (kept verbatim in "raw")
    - entry:   a KEY=VALUE pair ("key", "value", plus the original "raw")
    """
    parsed: list[dict] = []
    for raw in content.splitlines():
        if not raw.strip() or raw.strip().startswith("#"):
            parsed.append({"type": "comment", "raw": raw})
            continue
        # Accept KEY=VALUE and KEY="VALUE" forms.
        m = re.match(r'^([A-Za-z_][A-Za-z0-9_]*)=(.*)', raw)
        if m is None:
            # Lines that do not parse are preserved as comments.
            parsed.append({"type": "comment", "raw": raw})
            continue
        key = m.group(1)
        value = m.group(2).strip()
        # Strip one matching pair of surrounding quotes.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        parsed.append({"type": "entry", "key": key, "value": value, "raw": raw})
    return parsed
def _read_env_file(path: Path) -> str:
    """Read the .env file; raise 404 when it does not exist."""
    if path.exists():
        return path.read_text(encoding="utf-8")
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=".env 文件不存在",
    )
def _write_env_file(path: Path, content: str) -> None:
    """Write *content* to the .env file; raise 500 on any OS-level failure."""
    try:
        path.write_text(content, encoding="utf-8")
    except OSError as exc:
        logger.error("写入 .env 文件失败: %s", exc)
        # Chain the original OSError so the root cause survives in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="写入 .env 文件失败",
        ) from exc
def _validate_entries(entries: list[EnvEntry]) -> None:
    """Validate key names; raise 422 with a row-numbered message on failure."""
    key_re = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
    for row, entry in enumerate(entries, start=1):
        if not entry.key:
            raise HTTPException(
                status_code=422,
                detail=f"{row} 行:键名不能为空",
            )
        if key_re.match(entry.key) is None:
            raise HTTPException(
                status_code=422,
                detail=f"{row} 行:键名 '{entry.key}' 格式无效(仅允许字母、数字、下划线,且不能以数字开头)",
            )
# ── GET /api/env-config — 读取 ───────────────────────────────
@router.get("", response_model=EnvConfigResponse)
async def get_env_config(
user: CurrentUser = Depends(get_current_user),
) -> EnvConfigResponse:
"""读取 .env 文件,敏感值以掩码展示。"""
content = _read_env_file(_ENV_PATH)
parsed = _parse_env(content)
entries = []
for line in parsed:
if line["type"] == "entry":
value = _MASK if _is_sensitive(line["key"]) else line["value"]
entries.append(EnvEntry(key=line["key"], value=value))
return EnvConfigResponse(entries=entries)
# ── PUT /api/env-config — 写入 ───────────────────────────────
@router.put("", response_model=EnvConfigResponse)
async def update_env_config(
body: EnvConfigUpdateRequest,
user: CurrentUser = Depends(get_current_user),
) -> EnvConfigResponse:
"""验证并写入 .env 文件。
保留原文件中的注释行和空行。对于已有键,更新值;
对于新键,追加到文件末尾。掩码值(****)的键跳过更新,保留原值。
"""
_validate_entries(body.entries)
# 读取原文件(如果存在)
if _ENV_PATH.exists():
original_content = _ENV_PATH.read_text(encoding="utf-8")
parsed = _parse_env(original_content)
else:
parsed = []
# 构建新值映射(跳过掩码值)
new_values: dict[str, str] = {}
for entry in body.entries:
if entry.value != _MASK:
new_values[entry.key] = entry.value
# 更新已有行
seen_keys: set[str] = set()
output_lines: list[str] = []
for line in parsed:
if line["type"] == "comment":
output_lines.append(line["raw"])
elif line["type"] == "entry":
key = line["key"]
seen_keys.add(key)
if key in new_values:
output_lines.append(f"{key}={new_values[key]}")
else:
# 保留原值(包括掩码跳过的敏感键)
output_lines.append(line["raw"])
# 追加新键
for entry in body.entries:
if entry.key not in seen_keys and entry.value != _MASK:
output_lines.append(f"{entry.key}={entry.value}")
new_content = "\n".join(output_lines)
if output_lines:
new_content += "\n"
_write_env_file(_ENV_PATH, new_content)
# 返回更新后的配置(敏感值掩码)
result_parsed = _parse_env(new_content)
entries = []
for line in result_parsed:
if line["type"] == "entry":
value = _MASK if _is_sensitive(line["key"]) else line["value"]
entries.append(EnvEntry(key=line["key"], value=value))
return EnvConfigResponse(entries=entries)
# ── GET /api/env-config/export — 导出 ────────────────────────
@router.get("/export")
async def export_env_config(
user: CurrentUser = Depends(get_current_user),
) -> PlainTextResponse:
"""导出去除敏感值的配置文件(作为文件下载)。"""
content = _read_env_file(_ENV_PATH)
parsed = _parse_env(content)
output_lines: list[str] = []
for line in parsed:
if line["type"] == "comment":
output_lines.append(line["raw"])
elif line["type"] == "entry":
if _is_sensitive(line["key"]):
output_lines.append(f"{line['key']}={_MASK}")
else:
output_lines.append(line["raw"])
export_content = "\n".join(output_lines)
if output_lines:
export_content += "\n"
return PlainTextResponse(
content=export_content,
media_type="text/plain",
headers={"Content-Disposition": "attachment; filename=env-config.txt"},
)

View File

@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
"""ETL 状态监控 API
提供 2 个端点:
- GET /api/etl-status/cursors — 返回各任务的数据游标(最后抓取时间、记录数)
- GET /api/etl-status/recent-runs — 返回最近 50 条任务执行记录
所有端点需要 JWT 认证。
游标端点查询 ETL 数据库meta.etl_cursor
执行记录端点查询 zqyy_app 数据库task_execution_log
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg2 import OperationalError
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection, get_etl_readonly_connection
from app.schemas.etl_status import CursorInfo, RecentRun
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/etl-status", tags=["ETL 状态"])
# 最近执行记录条数上限
_RECENT_RUNS_LIMIT = 50
# ── GET /api/etl-status/cursors ──────────────────────────────
@router.get("/cursors", response_model=list[CursorInfo])
async def list_cursors(
user: CurrentUser = Depends(get_current_user),
) -> list[CursorInfo]:
"""返回各 ODS 表的最新数据游标。
查询 ETL 数据库中的 meta.etl_cursor 表。
如果该表不存在,返回空列表而非报错。
"""
conn = get_etl_readonly_connection(user.site_id)
try:
with conn.cursor() as cur:
# CHANGE 2026-02-15 | 对齐新库 etl_feiqiu 六层架构etl_admin → meta
cur.execute(
"""
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'meta'
AND table_name = 'etl_cursor'
)
"""
)
exists = cur.fetchone()[0]
if not exists:
return []
cur.execute(
"""
SELECT task_code, last_fetch_time, record_count
FROM meta.etl_cursor
ORDER BY task_code
"""
)
rows = cur.fetchall()
return [
CursorInfo(
task_code=row[0],
last_fetch_time=str(row[1]) if row[1] is not None else None,
record_count=row[2],
)
for row in rows
]
except OperationalError as exc:
logger.error("ETL 游标查询连接错误: %s", exc)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="ETL 数据库连接错误",
)
finally:
conn.close()
# ── GET /api/etl-status/recent-runs ──────────────────────────
@router.get("/recent-runs", response_model=list[RecentRun])
async def list_recent_runs(
user: CurrentUser = Depends(get_current_user),
) -> list[RecentRun]:
"""返回最近 50 条任务执行记录。
查询 zqyy_app 数据库中的 task_execution_log 表,
按 site_id 过滤,按 started_at DESC 排序。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, task_codes, status, started_at,
finished_at, duration_ms, exit_code
FROM task_execution_log
WHERE site_id = %s
ORDER BY started_at DESC
LIMIT %s
""",
(user.site_id, _RECENT_RUNS_LIMIT),
)
rows = cur.fetchall()
return [
RecentRun(
id=str(row[0]),
task_codes=list(row[1]) if row[1] else [],
status=row[2],
started_at=str(row[3]),
finished_at=str(row[4]) if row[4] is not None else None,
duration_ms=row[5],
exit_code=row[6],
)
for row in rows
]
except OperationalError as exc:
logger.error("执行记录查询连接错误: %s", exc)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="数据库连接错误",
)
finally:
conn.close()

View File

@@ -0,0 +1,281 @@
# -*- coding: utf-8 -*-
"""执行与队列 API
提供 8 个端点:
- POST /api/execution/run — 直接执行任务
- GET /api/execution/queue — 获取当前队列(按 site_id 过滤)
- POST /api/execution/queue — 添加到队列
- PUT /api/execution/queue/reorder — 重排队列
- DELETE /api/execution/queue/{id} — 删除队列任务
- POST /api/execution/{id}/cancel — 取消执行中的任务
- GET /api/execution/history — 执行历史(按 site_id 过滤)
- GET /api/execution/{id}/logs — 获取历史日志
所有端点需要 JWT 认证site_id 从 JWT 提取。
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
from app.schemas.execution import (
ExecutionHistoryItem,
ExecutionLogsResponse,
ExecutionRunResponse,
QueueTaskResponse,
ReorderRequest,
)
from app.schemas.tasks import TaskConfigSchema
from app.services.task_executor import task_executor
from app.services.task_queue import task_queue
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/execution", tags=["任务执行"])
# ── POST /api/execution/run — 直接执行任务 ────────────────────
@router.post("/run", response_model=ExecutionRunResponse)
async def run_task(
config: TaskConfigSchema,
user: CurrentUser = Depends(get_current_user),
) -> ExecutionRunResponse:
"""直接执行任务(不经过队列)。
从 JWT 注入 store_id创建 execution_id 后异步启动子进程。
"""
config = config.model_copy(update={"store_id": user.site_id})
execution_id = str(uuid.uuid4())
# 异步启动执行,不阻塞响应
asyncio.create_task(
task_executor.execute(
config=config,
execution_id=execution_id,
site_id=user.site_id,
)
)
return ExecutionRunResponse(
execution_id=execution_id,
message="任务已提交执行",
)
# ── GET /api/execution/queue — 获取当前队列 ───────────────────
@router.get("/queue", response_model=list[QueueTaskResponse])
async def get_queue(
user: CurrentUser = Depends(get_current_user),
) -> list[QueueTaskResponse]:
"""获取当前门店的待执行队列。"""
tasks = task_queue.list_pending(user.site_id)
return [
QueueTaskResponse(
id=t.id,
site_id=t.site_id,
config=t.config,
status=t.status,
position=t.position,
created_at=t.created_at,
started_at=t.started_at,
finished_at=t.finished_at,
exit_code=t.exit_code,
error_message=t.error_message,
)
for t in tasks
]
# ── POST /api/execution/queue — 添加到队列 ───────────────────
@router.post("/queue", response_model=QueueTaskResponse, status_code=status.HTTP_201_CREATED)
async def enqueue_task(
config: TaskConfigSchema,
user: CurrentUser = Depends(get_current_user),
) -> QueueTaskResponse:
"""将任务配置添加到执行队列。"""
config = config.model_copy(update={"store_id": user.site_id})
task_id = task_queue.enqueue(config, user.site_id)
# 查询刚创建的任务返回完整信息
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, site_id, config, status, position,
created_at, started_at, finished_at,
exit_code, error_message
FROM task_queue WHERE id = %s
""",
(task_id,),
)
row = cur.fetchone()
conn.commit()
finally:
conn.close()
if row is None:
raise HTTPException(status_code=500, detail="入队后查询失败")
config_data = row[2] if isinstance(row[2], dict) else json.loads(row[2])
return QueueTaskResponse(
id=str(row[0]),
site_id=row[1],
config=config_data,
status=row[3],
position=row[4],
created_at=row[5],
started_at=row[6],
finished_at=row[7],
exit_code=row[8],
error_message=row[9],
)
# ── PUT /api/execution/queue/reorder — 重排队列 ──────────────
@router.put("/queue/reorder")
async def reorder_queue(
body: ReorderRequest,
user: CurrentUser = Depends(get_current_user),
) -> dict:
"""调整队列中任务的执行顺序。"""
task_queue.reorder(body.task_id, body.new_position, user.site_id)
return {"message": "队列已重排"}
# ── DELETE /api/execution/queue/{id} — 删除队列任务 ──────────
@router.delete("/queue/{task_id}")
async def delete_queue_task(
task_id: str,
user: CurrentUser = Depends(get_current_user),
) -> dict:
"""从队列中删除待执行任务。仅允许删除 pending 状态的任务。"""
deleted = task_queue.delete(task_id, user.site_id)
if not deleted:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="任务不存在或非待执行状态,无法删除",
)
return {"message": "任务已从队列中删除"}
# ── POST /api/execution/{id}/cancel — 取消执行 ──────────────
@router.post("/{execution_id}/cancel")
async def cancel_execution(
execution_id: str,
user: CurrentUser = Depends(get_current_user),
) -> dict:
"""取消正在执行的任务。"""
cancelled = await task_executor.cancel(execution_id)
if not cancelled:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="执行任务不存在或已完成",
)
return {"message": "已发送取消信号"}
# ── GET /api/execution/history — 执行历史 ────────────────────
@router.get("/history", response_model=list[ExecutionHistoryItem])
async def get_execution_history(
limit: int = Query(default=50, ge=1, le=200),
user: CurrentUser = Depends(get_current_user),
) -> list[ExecutionHistoryItem]:
"""获取执行历史记录,按 started_at 降序排列。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, site_id, task_codes, status, started_at,
finished_at, exit_code, duration_ms, command, summary
FROM task_execution_log
WHERE site_id = %s
ORDER BY started_at DESC
LIMIT %s
""",
(user.site_id, limit),
)
rows = cur.fetchall()
conn.commit()
finally:
conn.close()
return [
ExecutionHistoryItem(
id=str(row[0]),
site_id=row[1],
task_codes=row[2] or [],
status=row[3],
started_at=row[4],
finished_at=row[5],
exit_code=row[6],
duration_ms=row[7],
command=row[8],
summary=row[9],
)
for row in rows
]
# ── GET /api/execution/{id}/logs — 获取历史日志 ──────────────
@router.get("/{execution_id}/logs", response_model=ExecutionLogsResponse)
async def get_execution_logs(
    execution_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> ExecutionLogsResponse:
    """Return the full log of one execution.

    A still-running execution is served from the in-memory buffer;
    a finished one is read back from task_execution_log.
    """
    # Running: the stream buffer already holds every line emitted so far.
    if task_executor.is_running(execution_id):
        buffered = task_executor.get_logs(execution_id)
        joined = "\n".join(buffered) if buffered else None
        return ExecutionLogsResponse(execution_id=execution_id, output_log=joined)
    # Finished: fetch the persisted logs, scoped to the caller's site.
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT output_log, error_log
                FROM task_execution_log
                WHERE id = %s AND site_id = %s
                """,
                (execution_id, user.site_id),
            )
            record = cur.fetchone()
            conn.commit()
    finally:
        conn.close()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="执行记录不存在",
        )
    return ExecutionLogsResponse(
        execution_id=execution_id,
        output_log=record[0],
        error_log=record[1],
    )

View File

@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
"""调度任务 CRUD API
提供 5 个端点:
- GET /api/schedules — 列表(按 site_id 过滤)
- POST /api/schedules — 创建
- PUT /api/schedules/{id} — 更新
- DELETE /api/schedules/{id} — 删除
- PATCH /api/schedules/{id}/toggle — 启用/禁用
所有端点需要 JWT 认证site_id 从 JWT 提取。
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
from app.schemas.schedules import (
CreateScheduleRequest,
ScheduleResponse,
UpdateScheduleRequest,
)
from app.services.scheduler import calculate_next_run
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/schedules", tags=["调度管理"])
def _row_to_response(row) -> ScheduleResponse:
    """Map one scheduled_tasks row (in _SELECT_COLS order) to a ScheduleResponse."""
    def _as_dict(value):
        # JSONB columns may arrive pre-decoded (dict) or as raw JSON text.
        return value if isinstance(value, dict) else json.loads(value)

    return ScheduleResponse(
        id=str(row[0]),
        site_id=row[1],
        name=row[2],
        task_codes=row[3] or [],
        task_config=_as_dict(row[4]),
        schedule_config=_as_dict(row[5]),
        enabled=row[6],
        last_run_at=row[7],
        next_run_at=row[8],
        run_count=row[9],
        last_status=row[10],
        created_at=row[11],
        updated_at=row[12],
    )
# 查询列列表,复用于多个端点
_SELECT_COLS = """
id, site_id, name, task_codes, task_config, schedule_config,
enabled, last_run_at, next_run_at, run_count, last_status,
created_at, updated_at
"""
# ── GET /api/schedules — 列表 ────────────────────────────────
@router.get("", response_model=list[ScheduleResponse])
async def list_schedules(
    user: CurrentUser = Depends(get_current_user),
) -> list[ScheduleResponse]:
    """List every scheduled task of the caller's site, newest first."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT {_SELECT_COLS} FROM scheduled_tasks WHERE site_id = %s ORDER BY created_at DESC",
                (user.site_id,),
            )
            records = cur.fetchall()
            conn.commit()
    finally:
        conn.close()
    responses: list[ScheduleResponse] = []
    for record in records:
        responses.append(_row_to_response(record))
    return responses
# ── POST /api/schedules — 创建 ──────────────────────────────
@router.post("", response_model=ScheduleResponse, status_code=status.HTTP_201_CREATED)
async def create_schedule(
    body: CreateScheduleRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Create a scheduled task; next_run_at is computed server-side."""
    next_run = calculate_next_run(body.schedule_config, datetime.now(timezone.utc))
    insert_params = (
        user.site_id,
        body.name,
        body.task_codes,
        json.dumps(body.task_config),
        body.schedule_config.model_dump_json(),
        body.schedule_config.enabled,
        next_run,
    )
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                f"""
                INSERT INTO scheduled_tasks
                (site_id, name, task_codes, task_config, schedule_config, enabled, next_run_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                RETURNING {_SELECT_COLS}
                """,
                insert_params,
            )
            created = cur.fetchone()
            conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    return _row_to_response(created)
# ── PUT /api/schedules/{id} — 更新 ──────────────────────────
@router.put("/{schedule_id}", response_model=ScheduleResponse)
async def update_schedule(
    schedule_id: str,
    body: UpdateScheduleRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Update a scheduled task; only the fields present in the request change.

    Raises:
        HTTPException 422: no updatable field was supplied.
        HTTPException 404: schedule does not exist for this site.
    """
    # Build the dynamic SET clause. set_parts and params must stay in
    # lockstep: each "%s" placeholder consumes the next params entry.
    set_parts: list[str] = []
    params: list = []
    if body.name is not None:
        set_parts.append("name = %s")
        params.append(body.name)
    if body.task_codes is not None:
        set_parts.append("task_codes = %s")
        params.append(body.task_codes)
    if body.task_config is not None:
        set_parts.append("task_config = %s")
        params.append(json.dumps(body.task_config))
    if body.schedule_config is not None:
        set_parts.append("schedule_config = %s")
        params.append(body.schedule_config.model_dump_json())
        # Changing the schedule invalidates the old next_run_at, so it is
        # recomputed from the new configuration.
        now = datetime.now(timezone.utc)
        next_run = calculate_next_run(body.schedule_config, now)
        set_parts.append("next_run_at = %s")
        params.append(next_run)
    if not set_parts:
        raise HTTPException(
            status_code=422,
            detail="至少需要提供一个更新字段",
        )
    set_parts.append("updated_at = NOW()")
    set_clause = ", ".join(set_parts)
    # WHERE placeholders come after all SET placeholders.
    params.extend([schedule_id, user.site_id])
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                f"""
                UPDATE scheduled_tasks
                SET {set_clause}
                WHERE id = %s AND site_id = %s
                RETURNING {_SELECT_COLS}
                """,
                params,
            )
            row = cur.fetchone()
            conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="调度任务不存在",
        )
    return _row_to_response(row)
# ── DELETE /api/schedules/{id} — 删除 ────────────────────────
@router.delete("/{schedule_id}")
async def delete_schedule(
    schedule_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Delete a scheduled task owned by the caller's site."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "DELETE FROM scheduled_tasks WHERE id = %s AND site_id = %s",
                (schedule_id, user.site_id),
            )
            removed_rows = cur.rowcount
            conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    if not removed_rows:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="调度任务不存在",
        )
    return {"message": "调度任务已删除"}
# ── PATCH /api/schedules/{id}/toggle — 启用/禁用 ─────────────
@router.patch("/{schedule_id}/toggle", response_model=ScheduleResponse)
async def toggle_schedule(
    schedule_id: str,
    user: CurrentUser = Depends(get_current_user),
) -> ScheduleResponse:
    """Flip a schedule between enabled and disabled.

    Disabling clears next_run_at; enabling recomputes it from the stored
    schedule_config.

    Raises:
        HTTPException 404: schedule does not exist (or was deleted between
            the read and the update).
    """
    conn = get_connection()
    try:
        # Read the current state and schedule configuration first.
        with conn.cursor() as cur:
            cur.execute(
                "SELECT enabled, schedule_config FROM scheduled_tasks WHERE id = %s AND site_id = %s",
                (schedule_id, user.site_id),
            )
            row = cur.fetchone()
        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="调度任务不存在",
            )
        current_enabled = row[0]
        new_enabled = not current_enabled
        if new_enabled:
            # Enabling: recompute next_run_at from the stored config.
            schedule_config_raw = row[1] if isinstance(row[1], dict) else json.loads(row[1])
            from app.schemas.schedules import ScheduleConfigSchema
            schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
            now = datetime.now(timezone.utc)
            next_run = calculate_next_run(schedule_cfg, now)
        else:
            # Disabling: no future run.
            next_run = None
        with conn.cursor() as cur:
            cur.execute(
                f"""
                UPDATE scheduled_tasks
                SET enabled = %s, next_run_at = %s, updated_at = NOW()
                WHERE id = %s AND site_id = %s
                RETURNING {_SELECT_COLS}
                """,
                (new_enabled, next_run, schedule_id, user.site_id),
            )
            updated_row = cur.fetchone()
            conn.commit()
        # Guard against the row disappearing between SELECT and UPDATE;
        # previously this fell through to _row_to_response(None) and crashed.
        if updated_row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="调度任务不存在",
            )
    except HTTPException:
        raise
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    return _row_to_response(updated_row)

View File

@@ -0,0 +1,209 @@
# -*- coding: utf-8 -*-
"""任务注册表 & 配置 API
提供 4 个端点:
- GET /api/tasks/registry — 按业务域分组的任务列表
- GET /api/tasks/dwd-tables — 按业务域分组的 DWD 表定义
- GET /api/tasks/flows — 7 种 Flow + 3 种处理模式
- POST /api/tasks/validate — 验证 TaskConfig 并返回 CLI 命令预览
所有端点需要 JWT 认证。validate 端点从 JWT 注入 store_id。
"""
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.auth.dependencies import CurrentUser, get_current_user
from app.config import ETL_PROJECT_PATH
from app.schemas.tasks import (
FlowDefinition,
ProcessingModeDefinition,
TaskConfigSchema,
)
from app.services.cli_builder import cli_builder
from app.services.task_registry import (
DWD_TABLES,
FLOW_LAYER_MAP,
get_dwd_tables_grouped_by_domain,
get_tasks_grouped_by_domain,
)
router = APIRouter(prefix="/api/tasks", tags=["任务配置"])
# ── 响应模型 ──────────────────────────────────────────────────
class TaskItem(BaseModel):
    """One registry entry describing a single ETL task."""
    code: str              # unique task code (used in task_codes lists)
    name: str              # human-readable name
    description: str
    domain: str            # business domain used for grouping
    layer: str             # warehouse layer (e.g. ODS/DWD/DWS) — confirm against registry
    requires_window: bool  # whether the task needs a time window
    is_ods: bool
    is_dimension: bool
    default_enabled: bool
    is_common: bool
class DwdTableItem(BaseModel):
    """One DWD table definition."""
    table_name: str    # physical table name
    display_name: str  # UI label
    domain: str        # business domain used for grouping
    ods_source: str    # source ODS table feeding this DWD table
    is_dimension: bool
class TaskRegistryResponse(BaseModel):
    """Task list grouped by business domain (domain → tasks)."""
    groups: dict[str, list[TaskItem]]
class DwdTablesResponse(BaseModel):
    """DWD table definitions grouped by business domain (domain → tables)."""
    groups: dict[str, list[DwdTableItem]]
class FlowsResponse(BaseModel):
    """Flow definitions plus processing-mode definitions."""
    flows: list[FlowDefinition]
    processing_modes: list[ProcessingModeDefinition]
class ValidateRequest(BaseModel):
    """Validation request — reuses TaskConfigSchema; store_id is injected
    by the backend from the JWT, never taken from the client."""
    config: TaskConfigSchema
class ValidateResponse(BaseModel):
    """Validation result plus a CLI command preview."""
    valid: bool
    command: str             # human-readable command string ("" when invalid)
    command_args: list[str]  # argv-style list ([] when invalid)
    errors: list[str]        # empty when valid
# ── Flow definitions (static) ────────────────────────────────
FLOW_DEFINITIONS: list[FlowDefinition] = [
    FlowDefinition(id="api_ods", name="API → ODS", layers=["ODS"]),
    FlowDefinition(id="api_ods_dwd", name="API → ODS → DWD", layers=["ODS", "DWD"]),
    FlowDefinition(id="api_full", name="API → ODS → DWD → DWS汇总 → DWS指数", layers=["ODS", "DWD", "DWS", "INDEX"]),
    FlowDefinition(id="ods_dwd", name="ODS → DWD", layers=["DWD"]),
    FlowDefinition(id="dwd_dws", name="DWD → DWS汇总", layers=["DWS"]),
    FlowDefinition(id="dwd_dws_index", name="DWD → DWS汇总 → DWS指数", layers=["DWS", "INDEX"]),
    FlowDefinition(id="dwd_index", name="DWD → DWS指数", layers=["INDEX"]),
]
PROCESSING_MODE_DEFINITIONS: list[ProcessingModeDefinition] = [
    ProcessingModeDefinition(id="increment_only", name="仅增量处理", description="只处理新增和变更的数据"),
    # BUG FIX: the description below was missing its closing ")", which
    # rendered as broken text in the UI.
    ProcessingModeDefinition(id="verify_only", name="仅校验修复", description="校验现有数据并修复不一致(可选'校验前从 API 获取')"),
    ProcessingModeDefinition(id="increment_verify", name="增量 + 校验修复", description="先增量处理,再校验并修复"),
]
# ── 端点 ──────────────────────────────────────────────────────
@router.get("/registry", response_model=TaskRegistryResponse)
async def get_task_registry(
    user: CurrentUser = Depends(get_current_user),
) -> TaskRegistryResponse:
    """Return every registered task, grouped by business domain."""
    groups: dict[str, list[TaskItem]] = {}
    for domain, tasks in get_tasks_grouped_by_domain().items():
        items: list[TaskItem] = []
        for t in tasks:
            items.append(
                TaskItem(
                    code=t.code,
                    name=t.name,
                    description=t.description,
                    domain=t.domain,
                    layer=t.layer,
                    requires_window=t.requires_window,
                    is_ods=t.is_ods,
                    is_dimension=t.is_dimension,
                    default_enabled=t.default_enabled,
                    is_common=t.is_common,
                )
            )
        groups[domain] = items
    return TaskRegistryResponse(groups=groups)
@router.get("/dwd-tables", response_model=DwdTablesResponse)
async def get_dwd_tables(
    user: CurrentUser = Depends(get_current_user),
) -> DwdTablesResponse:
    """Return the DWD table definitions, grouped by business domain."""
    groups: dict[str, list[DwdTableItem]] = {}
    for domain, tables in get_dwd_tables_grouped_by_domain().items():
        groups[domain] = [
            DwdTableItem(
                table_name=t.table_name,
                display_name=t.display_name,
                domain=t.domain,
                ods_source=t.ods_source,
                is_dimension=t.is_dimension,
            )
            for t in tables
        ]
    return DwdTablesResponse(groups=groups)
@router.get("/flows", response_model=FlowsResponse)
async def get_flows(
    user: CurrentUser = Depends(get_current_user),
) -> FlowsResponse:
    """Expose the 7 static flow definitions and the 3 processing modes."""
    response = FlowsResponse(
        flows=FLOW_DEFINITIONS,
        processing_modes=PROCESSING_MODE_DEFINITIONS,
    )
    return response
@router.post("/validate", response_model=ValidateResponse)
async def validate_task_config(
    body: ValidateRequest,
    user: CurrentUser = Depends(get_current_user),
) -> ValidateResponse:
    """Validate a TaskConfig and preview the CLI command it would produce.

    store_id always comes from the JWT; any client-sent value is overwritten.
    """
    config = body.config.model_copy(update={"store_id": user.site_id})
    problems: list[str] = []
    if config.pipeline not in FLOW_LAYER_MAP:
        problems.append(f"无效的执行流程: {config.pipeline}")
    if not config.tasks:
        problems.append("任务列表不能为空")
    if problems:
        return ValidateResponse(valid=False, command="", command_args=[], errors=problems)
    return ValidateResponse(
        valid=True,
        command=cli_builder.build_command_string(config, ETL_PROJECT_PATH),
        command_args=cli_builder.build_command(config, ETL_PROJECT_PATH),
        errors=[],
    )

View File

@@ -0,0 +1,104 @@
# AI_CHANGELOG
# - 2026-02-19 | Prompt: 配置微信消息推送 | 新增微信消息推送回调接口,支持 GET 验签 + POST 消息接收
"""
微信消息推送回调接口
处理两类请求:
1. GET — 微信服务器验证(配置时触发一次)
2. POST — 接收微信推送的消息/事件
安全模式下需要解密消息体,当前先用明文模式跑通,后续切安全模式。
"""
import hashlib
import logging
from fastapi import APIRouter, Query, Request, Response
from app.config import get
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/wx", tags=["微信回调"])
# Token read from the environment; must match the value configured in the
# WeChat admin console. Set in apps/backend/.env.local:
#   WX_CALLBACK_TOKEN=<your token>
WX_CALLBACK_TOKEN: str = get("WX_CALLBACK_TOKEN", "")
def _check_signature(signature: str, timestamp: str, nonce: str) -> bool:
"""
验证请求是否来自微信服务器。
将 Token、timestamp、nonce 字典序排序后拼接,做 SHA1
与 signature 比对。
"""
if not WX_CALLBACK_TOKEN:
logger.error("WX_CALLBACK_TOKEN 未配置")
return False
items = sorted([WX_CALLBACK_TOKEN, timestamp, nonce])
hash_str = hashlib.sha1("".join(items).encode("utf-8")).hexdigest()
return hash_str == signature
@router.get("/callback")
async def verify(
    signature: str = Query(...),
    timestamp: str = Query(...),
    nonce: str = Query(...),
    echostr: str = Query(...),
):
    """Handle the one-off WeChat server verification GET.

    On a valid signature the echostr must be echoed back as plain text —
    any JSON wrapping makes WeChat reject the callback URL.
    """
    if not _check_signature(signature, timestamp, nonce):
        logger.warning("微信回调验签失败: signature=%s", signature)
        return Response(content="signature mismatch", status_code=403)
    logger.info("微信回调验证通过")
    return Response(content=echostr, media_type="text/plain")
@router.post("/callback")
async def receive_message(
    request: Request,
    signature: str = Query(""),
    timestamp: str = Query(""),
    nonce: str = Query(""),
):
    """Receive a message/event pushed by WeChat (plaintext mode).

    AES decryption must be added once the account switches to safe mode.
    """
    # POST requests carry the same signature query params; verify them too.
    if not _check_signature(signature, timestamp, nonce):
        logger.warning("消息推送验签失败")
        return Response(content="signature mismatch", status_code=403)
    raw = await request.body()
    content_type = request.headers.get("content-type", "")
    if "json" in content_type:
        import json
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            data = {"raw": raw.decode("utf-8", errors="replace")}
    else:
        # XML payloads are not parsed yet; keep the original text.
        data = {"raw_xml": raw.decode("utf-8", errors="replace")}
    logger.info("收到微信推送: MsgType=%s, Event=%s",
                data.get("MsgType", "?"), data.get("Event", "?"))
    # TODO: dispatch on MsgType/Event (customer messages, subscribe events, ...)
    return Response(content="success", media_type="text/plain")

View File

@@ -0,0 +1,37 @@
# AI_CHANGELOG
# - 2026-02-19 | Prompt: 小程序 MVP 全链路验证 | 新增 /api/xcx-test 接口,从 test."xcx-test" 表读取 ti 列第一行
"""
小程序 MVP 验证接口
从 test_zqyy_app 库的 test."xcx-test" 表读取数据,
用于验证小程序 → 后端 → 数据库全链路连通性。
"""
from fastapi import APIRouter, HTTPException
from app.database import get_connection
router = APIRouter(prefix="/api/xcx-test", tags=["小程序MVP"])
@router.get("")
async def get_xcx_test():
    """End-to-end smoke check: mini-program → API → DB → response.

    Reads the first value of column ``ti`` from test."xcx-test".
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # The table name contains a hyphen, so it must stay double-quoted.
            cur.execute('SELECT ti FROM test."xcx-test" LIMIT 1')
            first = cur.fetchone()
    finally:
        conn.close()
    if first is None:
        raise HTTPException(status_code=404, detail="无数据")
    return {"ti": first[0]}

View File

@@ -0,0 +1,30 @@
"""
认证相关 Pydantic 模型。
- LoginRequest登录请求体
- TokenResponse令牌响应体
- RefreshRequest刷新令牌请求体
"""
from pydantic import BaseModel, Field
class LoginRequest(BaseModel):
    """Login request body."""
    username: str = Field(..., min_length=1, max_length=64, description="用户名")
    password: str = Field(..., min_length=1, description="密码")
class RefreshRequest(BaseModel):
    """Token refresh request body."""
    refresh_token: str = Field(..., min_length=1, description="刷新令牌")
class TokenResponse(BaseModel):
    """Token pair response."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"  # OAuth2-style bearer token type

View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""数据库查看器 Pydantic 模型
定义 Schema 浏览、表结构查看、SQL 查询的请求/响应模型。
"""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel
class SchemaInfo(BaseModel):
    """Database schema (namespace) descriptor."""
    name: str
class TableInfo(BaseModel):
    """Table descriptor with an optional row-count statistic."""
    name: str
    row_count: int | None = None  # None when the count was not computed
class ColumnInfo(BaseModel):
    """Column definition."""
    name: str
    data_type: str
    is_nullable: bool
    column_default: str | None = None  # None when the column has no default
class QueryRequest(BaseModel):
    """SQL query request body."""
    sql: str  # raw SQL statement to execute
class QueryResponse(BaseModel):
    """SQL query result."""
    columns: list[str]     # result column names, in order
    rows: list[list[Any]]  # row values, aligned with `columns`
    row_count: int

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""ETL 状态监控 Pydantic 模型
定义游标信息和最近执行记录的响应模型。
"""
from __future__ import annotations
from pydantic import BaseModel
class CursorInfo(BaseModel):
    """ETL cursor info (last-fetch state of one task)."""
    task_code: str
    last_fetch_time: str | None = None
    record_count: int | None = None
class RecentRun(BaseModel):
    """One recent execution record."""
    id: str
    task_codes: list[str]
    status: str
    started_at: str
    finished_at: str | None = None  # None while still running
    duration_ms: int | None = None
    exit_code: int | None = None

View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""执行与队列相关的 Pydantic 模型
用于 execution 路由的请求/响应序列化。
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel
class ReorderRequest(BaseModel):
    """Queue reorder request."""
    task_id: str       # id of the queued task to move
    new_position: int  # target position within the queue
class QueueTaskResponse(BaseModel):
    """A queued task as returned to the frontend."""
    id: str
    site_id: int
    config: dict[str, Any]  # serialized TaskConfig
    status: str             # queue state — values defined by task_queue; confirm there
    position: int
    created_at: datetime | None = None
    started_at: datetime | None = None
    finished_at: datetime | None = None
    exit_code: int | None = None
    error_message: str | None = None
class ExecutionRunResponse(BaseModel):
    """Response for directly launched executions."""
    execution_id: str
    message: str
class ExecutionHistoryItem(BaseModel):
    """One row of execution history."""
    id: str
    site_id: int
    task_codes: list[str]
    status: str
    started_at: datetime
    finished_at: datetime | None = None  # None while still running
    exit_code: int | None = None
    duration_ms: int | None = None
    command: str | None = None           # CLI command string as executed
    summary: dict[str, Any] | None = None
class ExecutionLogsResponse(BaseModel):
    """Full log payload for one execution."""
    execution_id: str
    output_log: str | None = None  # stdout (may include merged stderr for running tasks)
    error_log: str | None = None   # stderr, only available for finished tasks

View File

@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""调度配置 Pydantic 模型
定义 ScheduleConfigSchema 及相关模型,供调度服务和路由使用。
"""
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel
class ScheduleConfigSchema(BaseModel):
    """Schedule configuration — supports 5 schedule types.

    Only the fields relevant to the chosen schedule_type are consulted;
    the rest keep their defaults.
    """
    schedule_type: Literal["once", "interval", "daily", "weekly", "cron"]
    interval_value: int = 1                                   # interval: every N units
    interval_unit: Literal["minutes", "hours", "days"] = "hours"
    daily_time: str = "04:00"                                 # daily: HH:MM
    weekly_days: list[int] = [1]                              # weekly: ISO weekdays (1=Mon ... 7=Sun)
    weekly_time: str = "04:00"                                # weekly: HH:MM
    cron_expression: str = "0 4 * * *"                        # cron: 5-field expression
    enabled: bool = True
    start_date: str | None = None  # presumably ISO date bounds — confirm against scheduler usage
    end_date: str | None = None
class CreateScheduleRequest(BaseModel):
    """Create-schedule request body."""
    name: str
    task_codes: list[str]
    task_config: dict[str, Any]  # serialized TaskConfig
    schedule_config: ScheduleConfigSchema
class UpdateScheduleRequest(BaseModel):
    """Update-schedule request body; every field is optional (partial update)."""
    name: str | None = None
    task_codes: list[str] | None = None
    task_config: dict[str, Any] | None = None
    schedule_config: ScheduleConfigSchema | None = None
class ScheduleResponse(BaseModel):
    """Scheduled-task response row."""
    id: str
    site_id: int
    name: str
    task_codes: list[str]
    task_config: dict[str, Any]
    schedule_config: dict[str, Any]  # stored JSON, returned as-is
    enabled: bool
    last_run_at: datetime | None = None
    next_run_at: datetime | None = None  # None when disabled or "once" already ran
    run_count: int
    last_status: str | None = None
    created_at: datetime
    updated_at: datetime

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
"""任务配置 Pydantic 模型
定义 TaskConfigSchema 及相关模型,用于前后端传输和 CLIBuilder 消费。
"""
from typing import Any
from pydantic import BaseModel, model_validator
class TaskConfigSchema(BaseModel):
    """Task configuration — frontend/backend transfer format.

    Field → CLI parameter mapping (as consumed by CLIBuilder):
    - pipeline → --flow (flow ID, one of 7)
    - processing_mode → --processing-mode (one of 3 modes)
    - tasks → --tasks (comma separated)
    - dry_run → --dry-run (boolean flag)
    - window_mode → selects lookback vs custom window (frontend logic only;
      no direct CLI flag)
    - window_start → --window-start
    - window_end → --window-end
    - window_split → --window-split
    - window_split_days → --window-split-days
    - lookback_hours → --lookback-hours
    - overlap_seconds → --overlap-seconds
    - fetch_before_verify → --fetch-before-verify (boolean flag)
    - store_id → --store-id (injected from the JWT by the backend; the
      frontend never sends it)
    - dwd_only_tables → reserved for extra_args / future extension
    """
    tasks: list[str]
    pipeline: str = "api_ods_dwd"
    processing_mode: str = "increment_only"
    dry_run: bool = False
    window_mode: str = "lookback"
    window_start: str | None = None
    window_end: str | None = None
    window_split: str | None = None
    window_split_days: int | None = None
    lookback_hours: int = 24
    overlap_seconds: int = 600
    fetch_before_verify: bool = False
    skip_ods_when_fetch_before_verify: bool = False
    ods_use_local_json: bool = False
    store_id: int | None = None
    dwd_only_tables: list[str] | None = None
    force_full: bool = False
    extra_args: dict[str, Any] = {}

    @model_validator(mode="after")
    def validate_window(self) -> "TaskConfigSchema":
        """Window sanity check: the end must not precede the start."""
        if self.window_start and self.window_end:
            if self.window_end < self.window_start:
                raise ValueError("window_end 不能早于 window_start")
        return self
class FlowDefinition(BaseModel):
    """Execution flow (pipeline) definition."""
    id: str
    name: str
    layers: list[str]  # warehouse layers touched by this flow
class ProcessingModeDefinition(BaseModel):
    """Processing mode definition."""
    id: str
    name: str
    description: str

View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
"""CLI 命令构建器
从 gui/utils/cli_builder.py 迁移,适配后端 TaskConfigSchema。
将 TaskConfigSchema 转换为 ETL CLI 命令行参数列表。
支持:
- 7 种 Flowapi_ods / api_ods_dwd / api_full / ods_dwd / dwd_dws / dwd_dws_index / dwd_index
- 3 种处理模式increment_only / verify_only / increment_verify
- 自动注入 --store-id 参数
"""
from typing import Any
from ..schemas.tasks import TaskConfigSchema
# Valid flow IDs.
VALID_FLOWS: set[str] = {
    "api_ods",
    "api_ods_dwd",
    "api_full",
    "ods_dwd",
    "dwd_dws",
    "dwd_dws_index",
    "dwd_index",
}
# Valid processing modes.
VALID_PROCESSING_MODES: set[str] = {
    "increment_only",
    "verify_only",
    "increment_verify",
}
# extra_args keys the CLI accepts (value-style and boolean-style).
CLI_SUPPORTED_ARGS: set[str] = {
    # Value-style parameters
    "pg_dsn", "pg_host", "pg_port", "pg_name",
    "pg_user", "pg_password", "api_base", "api_token", "api_timeout",
    "api_page_size", "api_retry_max",
    "export_root", "log_root", "fetch_root",
    "ingest_source", "idle_start", "idle_end",
    "data_source", "pipeline_flow",
    "window_split_unit",
    # Boolean flags
    "force_window_override", "write_pretty_json", "allow_empty_advance",
}
class CLIBuilder:
"""将 TaskConfigSchema 转换为 ETL CLI 命令行参数列表"""
def build_command(
self,
config: TaskConfigSchema,
etl_project_path: str,
python_executable: str = "python",
) -> list[str]:
"""构建完整的 CLI 命令参数列表。
生成格式:
[python, -m, cli.main, --flow, {flow_id}, --tasks, ..., --store-id, {site_id}, ...]
Args:
config: 任务配置对象Pydantic 模型)
etl_project_path: ETL 项目根目录路径(用于 cwd不拼入命令
python_executable: Python 可执行文件路径,默认 "python"
Returns:
命令行参数列表
"""
cmd: list[str] = [python_executable, "-m", "cli.main"]
# -- Flow执行流程 --
cmd.extend(["--flow", config.pipeline])
# -- 处理模式 --
if config.processing_mode:
cmd.extend(["--processing-mode", config.processing_mode])
# -- 任务列表 --
if config.tasks:
cmd.extend(["--tasks", ",".join(config.tasks)])
# -- 校验前从 API 获取数据(仅 verify_only 模式有效) --
if config.fetch_before_verify and config.processing_mode == "verify_only":
cmd.append("--fetch-before-verify")
# -- 时间窗口 --
if config.window_mode == "lookback":
# 回溯模式
if config.lookback_hours is not None:
cmd.extend(["--lookback-hours", str(config.lookback_hours)])
if config.overlap_seconds is not None:
cmd.extend(["--overlap-seconds", str(config.overlap_seconds)])
else:
# 自定义时间窗口
if config.window_start:
cmd.extend(["--window-start", config.window_start])
if config.window_end:
cmd.extend(["--window-end", config.window_end])
# -- 时间窗口切分 --
if config.window_split and config.window_split != "none":
cmd.extend(["--window-split", config.window_split])
if config.window_split_days is not None:
cmd.extend(["--window-split-days", str(config.window_split_days)])
# -- Dry-run --
if config.dry_run:
cmd.append("--dry-run")
# -- 强制全量处理 --
if config.force_full:
cmd.append("--force-full")
# -- 本地 JSON 模式 → --data-source offline --
if config.ods_use_local_json:
cmd.extend(["--data-source", "offline"])
# -- 门店 ID自动注入 --
if config.store_id is not None:
cmd.extend(["--store-id", str(config.store_id)])
# -- 额外参数(只传递 CLI 支持的参数) --
for key, value in config.extra_args.items():
if value is not None and key in CLI_SUPPORTED_ARGS:
arg_name = f"--{key.replace('_', '-')}"
if isinstance(value, bool):
if value:
cmd.append(arg_name)
else:
cmd.extend([arg_name, str(value)])
return cmd
def build_command_string(
self,
config: TaskConfigSchema,
etl_project_path: str,
python_executable: str = "python",
) -> str:
"""构建命令行字符串(用于显示/日志记录)。
对包含空格的参数自动添加引号。
"""
cmd = self.build_command(config, etl_project_path, python_executable)
quoted: list[str] = []
for arg in cmd:
if " " in arg or '"' in arg:
quoted.append(f'"{arg}"')
else:
quoted.append(arg)
return " ".join(quoted)
# 全局单例
cli_builder = CLIBuilder()

View File

@@ -0,0 +1,303 @@
# -*- coding: utf-8 -*-
"""调度器服务
后台 asyncio 循环,每 30 秒检查一次到期的调度任务,
将其 TaskConfig 入队到 TaskQueue。
核心逻辑:
- check_and_enqueue():查询 enabled=true 且 next_run_at <= now 的调度任务
- start() / stop():管理后台循环生命周期
- _calculate_next_run():根据 ScheduleConfig 计算下次执行时间
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
from ..database import get_connection
from ..schemas.schedules import ScheduleConfigSchema
from ..schemas.tasks import TaskConfigSchema
from .task_queue import task_queue
logger = logging.getLogger(__name__)
# Scheduler polling interval in seconds.
SCHEDULER_POLL_INTERVAL = 30
def _parse_time(time_str: str) -> tuple[int, int]:
"""解析 HH:MM 格式的时间字符串,返回 (hour, minute)。"""
parts = time_str.split(":")
return int(parts[0]), int(parts[1])
def calculate_next_run(
schedule_config: ScheduleConfigSchema,
now: datetime | None = None,
) -> datetime | None:
"""根据调度配置计算下次执行时间。
Args:
schedule_config: 调度配置
now: 当前时间(默认 UTC now方便测试注入
Returns:
下次执行时间UTConce 类型返回 None 表示不再执行
"""
if now is None:
now = datetime.now(timezone.utc)
stype = schedule_config.schedule_type
if stype == "once":
# 一次性任务执行后不再调度
return None
if stype == "interval":
unit_map = {
"minutes": timedelta(minutes=schedule_config.interval_value),
"hours": timedelta(hours=schedule_config.interval_value),
"days": timedelta(days=schedule_config.interval_value),
}
delta = unit_map.get(schedule_config.interval_unit)
if delta is None:
logger.warning("未知的 interval_unit: %s", schedule_config.interval_unit)
return None
return now + delta
if stype == "daily":
hour, minute = _parse_time(schedule_config.daily_time)
# 计算明天的 daily_time
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
if stype == "weekly":
hour, minute = _parse_time(schedule_config.weekly_time)
days = sorted(schedule_config.weekly_days) if schedule_config.weekly_days else [1]
# ISO weekday: 1=Monday ... 7=Sunday
current_weekday = now.isoweekday()
# 找到下一个匹配的 weekday
for day in days:
if day > current_weekday:
delta_days = day - current_weekday
next_dt = now + timedelta(days=delta_days)
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
# 本周没有更晚的 weekday跳到下周第一个
first_day = days[0]
delta_days = 7 - current_weekday + first_day
next_dt = now + timedelta(days=delta_days)
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
if stype == "cron":
# 简单 cron 解析:仅支持 "minute hour * * *" 格式(每日定时)
# 复杂 cron 表达式可后续引入 croniter 库
return _parse_simple_cron(schedule_config.cron_expression, now)
logger.warning("未知的 schedule_type: %s", stype)
return None
def _parse_simple_cron(expression: str, now: datetime) -> datetime | None:
"""简单 cron 解析器,支持基本的 5 字段格式。
支持的格式:
- "M H * * *" → 每天 H:M
- "M H * * D" → 每周 D 的 H:MD 为 0-60=Sunday
- 其他格式回退到每天 04:00
不支持范围、列表、步进等高级语法。如需完整 cron 支持,
可在 pyproject.toml 中添加 croniter 依赖。
"""
parts = expression.strip().split()
if len(parts) != 5:
logger.warning("无法解析 cron 表达式: %s,回退到明天 04:00", expression)
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
minute_str, hour_str, dom, month, dow = parts
try:
minute = int(minute_str) if minute_str != "*" else 0
hour = int(hour_str) if hour_str != "*" else 0
except ValueError:
logger.warning("cron 表达式时间字段无法解析: %s,回退到明天 04:00", expression)
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
# 如果指定了 day-of-week非 *
if dow != "*":
try:
cron_dow = int(dow) # 0=Sunday, 1=Monday, ..., 6=Saturday
except ValueError:
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
# 转换为 ISO weekday1=Monday, 7=Sunday
iso_dow = 7 if cron_dow == 0 else cron_dow
current_iso = now.isoweekday()
if iso_dow > current_iso:
delta_days = iso_dow - current_iso
elif iso_dow < current_iso:
delta_days = 7 - current_iso + iso_dow
else:
# 同一天,看时间是否已过
target_today = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
if now < target_today:
delta_days = 0
else:
delta_days = 7
next_dt = now + timedelta(days=delta_days)
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
# 每天定时dom=* month=* dow=*
tomorrow = now + timedelta(days=1)
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
class Scheduler:
"""基于 PostgreSQL 的定时调度器
后台 asyncio 循环每 SCHEDULER_POLL_INTERVAL 秒检查一次到期任务,
将其 TaskConfig 入队到 TaskQueue。
"""
def __init__(self) -> None:
self._running = False
self._loop_task: asyncio.Task | None = None
# ------------------------------------------------------------------
# 核心:检查到期任务并入队
# ------------------------------------------------------------------
def check_and_enqueue(self) -> int:
"""查询 enabled=true 且 next_run_at <= now 的调度任务,将其入队。
Returns:
本次入队的任务数量
"""
conn = get_connection()
enqueued = 0
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, site_id, task_config, schedule_config
FROM scheduled_tasks
WHERE enabled = TRUE
AND next_run_at IS NOT NULL
AND next_run_at <= NOW()
ORDER BY next_run_at ASC
"""
)
rows = cur.fetchall()
for row in rows:
task_id = str(row[0])
site_id = row[1]
task_config_raw = row[2] if isinstance(row[2], dict) else json.loads(row[2])
schedule_config_raw = row[3] if isinstance(row[3], dict) else json.loads(row[3])
try:
config = TaskConfigSchema(**task_config_raw)
schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
except Exception:
logger.exception("调度任务 [%s] 配置反序列化失败,跳过", task_id)
continue
# 入队
try:
queue_id = task_queue.enqueue(config, site_id)
logger.info(
"调度任务 [%s] 入队成功 → queue_id=%s site_id=%s",
task_id, queue_id, site_id,
)
enqueued += 1
except Exception:
logger.exception("调度任务 [%s] 入队失败", task_id)
continue
# 更新调度任务状态
now = datetime.now(timezone.utc)
next_run = calculate_next_run(schedule_cfg, now)
with conn.cursor() as cur:
cur.execute(
"""
UPDATE scheduled_tasks
SET last_run_at = NOW(),
run_count = run_count + 1,
next_run_at = %s,
last_status = 'enqueued',
updated_at = NOW()
WHERE id = %s
""",
(next_run, task_id),
)
conn.commit()
except Exception:
logger.exception("check_and_enqueue 执行异常")
try:
conn.rollback()
except Exception:
pass
finally:
conn.close()
if enqueued > 0:
logger.info("本轮调度检查:%d 个任务入队", enqueued)
return enqueued
# ------------------------------------------------------------------
# 后台循环
# ------------------------------------------------------------------
async def _loop(self) -> None:
"""后台 asyncio 循环,每 SCHEDULER_POLL_INTERVAL 秒检查一次。"""
self._running = True
logger.info("Scheduler 后台循环启动(间隔 %ds", SCHEDULER_POLL_INTERVAL)
while self._running:
try:
# 在线程池中执行同步数据库操作,避免阻塞事件循环
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self.check_and_enqueue)
except Exception:
logger.exception("Scheduler 循环迭代异常")
await asyncio.sleep(SCHEDULER_POLL_INTERVAL)
logger.info("Scheduler 后台循环停止")
# ------------------------------------------------------------------
# 生命周期
# ------------------------------------------------------------------
def start(self) -> None:
    """Start the background scheduling loop (called from FastAPI lifespan)."""
    current = self._loop_task
    if current is not None and not current.done():
        # A live loop task already exists; starting twice would double-poll.
        return
    self._loop_task = asyncio.create_task(self._loop())
    logger.info("Scheduler 已启动")
async def stop(self) -> None:
    """Stop the background scheduling loop and wait for it to unwind."""
    self._running = False
    task = self._loop_task
    if task is not None and not task.done():
        task.cancel()
        try:
            # Await the task so cancellation fully propagates before returning.
            await task
        except asyncio.CancelledError:
            pass
    self._loop_task = None
    logger.info("Scheduler 已停止")
# Module-level singleton shared across the application.
scheduler = Scheduler()

View File

@@ -0,0 +1,391 @@
# -*- coding: utf-8 -*-
"""ETL 任务执行器
通过 asyncio.create_subprocess_exec 启动 ETL CLI 子进程,
逐行读取 stdout/stderr 并广播到 WebSocket 订阅者,
执行完成后将结果写入 task_execution_log 表。
设计要点:
- 每个 execution_id 对应一个子进程,存储在 _processes 字典中
- 日志行存储在内存缓冲区 _log_buffers 中
- WebSocket 订阅者通过 asyncio.Queue 接收实时日志
- Windows 兼容:取消时使用 process.terminate() 而非 SIGTERM
"""
from __future__ import annotations
import asyncio
import logging
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from typing import Any
from ..config import ETL_PROJECT_PATH
from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema
from ..services.cli_builder import cli_builder
logger = logging.getLogger(__name__)
class TaskExecutor:
    """Manage the lifecycle of ETL CLI subprocesses.

    Design notes:
    - each ``execution_id`` maps to one subprocess, stored in ``_processes``
    - output lines are retained in the in-memory buffer ``_log_buffers``
    - WebSocket subscribers receive live lines via per-subscriber
      ``asyncio.Queue`` objects
    - Windows compatible: cancellation uses ``process.terminate()`` rather
      than POSIX signals
    """

    def __init__(self) -> None:
        # execution_id → subprocess.Popen
        self._processes: dict[str, subprocess.Popen] = {}
        # execution_id → list[str] (stdout + stderr interleaved, tagged)
        self._log_buffers: dict[str, list[str]] = {}
        # execution_id → set[asyncio.Queue] (one queue per WebSocket subscriber)
        self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}

    # ------------------------------------------------------------------
    # WebSocket subscription management
    # ------------------------------------------------------------------
    def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
        """Register a WebSocket subscriber; return the queue to read lines from.

        The queue yields ``str`` items for log lines and ``None`` as an
        end-of-execution sentinel.
        """
        if execution_id not in self._subscribers:
            self._subscribers[execution_id] = set()
        queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._subscribers[execution_id].add(queue)
        return queue

    def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
        """Remove a WebSocket subscriber, dropping the entry when none remain."""
        subs = self._subscribers.get(execution_id)
        if subs:
            subs.discard(queue)
            if not subs:
                del self._subscribers[execution_id]

    def _broadcast(self, execution_id: str, line: str) -> None:
        """Push one log line to every subscriber of this execution.

        NOTE(review): this is invoked from the reader threads in
        ``_run_subprocess``; ``asyncio.Queue.put_nowait`` is not documented
        as thread-safe, so a waiting ``await queue.get()`` may not wake
        promptly — confirm and consider ``loop.call_soon_threadsafe``.
        """
        subs = self._subscribers.get(execution_id)
        if subs:
            for q in subs:
                q.put_nowait(line)

    def _broadcast_end(self, execution_id: str) -> None:
        """Notify all subscribers the execution finished (``None`` sentinel)."""
        subs = self._subscribers.get(execution_id)
        if subs:
            for q in subs:
                q.put_nowait(None)

    # ------------------------------------------------------------------
    # Log buffer
    # ------------------------------------------------------------------
    def get_logs(self, execution_id: str) -> list[str]:
        """Return a copy of the in-memory log buffer for the execution."""
        return list(self._log_buffers.get(execution_id, []))

    # ------------------------------------------------------------------
    # Execution state queries
    # ------------------------------------------------------------------
    def is_running(self, execution_id: str) -> bool:
        """Return True while the subprocess for ``execution_id`` is alive."""
        proc = self._processes.get(execution_id)
        if proc is None:
            return False
        return proc.poll() is None

    def get_running_ids(self) -> list[str]:
        """Return the ids of all currently running executions.

        FIX: use ``poll()`` (as ``is_running`` does) instead of reading
        ``Popen.returncode`` directly — ``returncode`` stays ``None`` until
        ``poll()``/``wait()`` is called, so already-exited processes could be
        reported as still running.
        """
        return [eid for eid, p in self._processes.items() if p.poll() is None]

    # ------------------------------------------------------------------
    # Core execution
    # ------------------------------------------------------------------
    async def execute(
        self,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str | None = None,
        site_id: int | None = None,
    ) -> None:
        """Run the ETL CLI as a subprocess for the given configuration.

        Uses ``subprocess.Popen`` plus reader threads for Windows
        compatibility (``asyncio.create_subprocess_exec`` raises
        ``NotImplementedError`` under the default Windows event loop policy).

        Args:
            config: task configuration translated into a CLI command.
            execution_id: unique id for this run (process / log buffer key).
            queue_id: originating task_queue row id, if any.
            site_id: store id; falls back to ``config.store_id`` when omitted.
        """
        cmd = cli_builder.build_command(
            config, ETL_PROJECT_PATH, python_executable=sys.executable
        )
        command_str = " ".join(cmd)
        effective_site_id = site_id or config.store_id
        logger.info(
            "启动 ETL 子进程 [%s]: %s (cwd=%s)",
            execution_id, command_str, ETL_PROJECT_PATH,
        )
        self._log_buffers[execution_id] = []
        started_at = datetime.now(timezone.utc)
        t0 = time.monotonic()
        self._write_execution_log(
            execution_id=execution_id,
            queue_id=queue_id,
            site_id=effective_site_id,
            task_codes=config.tasks,
            status="running",
            started_at=started_at,
            command=command_str,
        )
        exit_code: int | None = None
        status = "running"
        stdout_lines: list[str] = []
        stderr_lines: list[str] = []
        try:
            # Extra environment variables (DWD table filter is injected via env).
            extra_env: dict[str, str] = {}
            if config.dwd_only_tables:
                extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)
            # Run the subprocess in the default thread pool (Windows compatible).
            # FIX: get_running_loop() instead of the (deprecated-in-coroutine)
            # get_event_loop(); also matches the scheduler's usage.
            exit_code = await asyncio.get_running_loop().run_in_executor(
                None,
                self._run_subprocess,
                cmd,
                execution_id,
                stdout_lines,
                stderr_lines,
                extra_env or None,
            )
            status = "success" if exit_code == 0 else "failed"
            logger.info(
                "ETL 子进程 [%s] 退出exit_code=%s, status=%s",
                execution_id, exit_code, status,
            )
        except asyncio.CancelledError:
            # NOTE(review): cancellation is deliberately swallowed so the
            # coroutine finishes its cleanup and returns normally — confirm
            # callers depend on this rather than on CancelledError propagating.
            status = "cancelled"
            logger.info("ETL 子进程 [%s] 已取消", execution_id)
            # Best-effort terminate of the child process.
            proc = self._processes.get(execution_id)
            if proc and proc.poll() is None:
                proc.terminate()
        except Exception as exc:
            status = "failed"
            import traceback
            tb = traceback.format_exc()
            stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
            stderr_lines.append(tb)
            logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
        finally:
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            finished_at = datetime.now(timezone.utc)
            self._broadcast_end(execution_id)
            self._processes.pop(execution_id, None)
            self._update_execution_log(
                execution_id=execution_id,
                status=status,
                finished_at=finished_at,
                exit_code=exit_code,
                duration_ms=elapsed_ms,
                output_log="\n".join(stdout_lines),
                error_log="\n".join(stderr_lines),
            )

    def _run_subprocess(
        self,
        cmd: list[str],
        execution_id: str,
        stdout_lines: list[str],
        stderr_lines: list[str],
        extra_env: dict[str, str] | None = None,
    ) -> int:
        """Run the subprocess in a worker thread, streaming its output.

        stdout/stderr are consumed line by line on daemon threads; each line
        is tagged, appended to the in-memory buffer, collected into the given
        lists and broadcast to subscribers.

        Returns:
            The subprocess exit code.
        """
        import os
        env = os.environ.copy()
        # Force UTF-8 output from the child to avoid GBK mojibake on Windows.
        env["PYTHONIOENCODING"] = "utf-8"
        if extra_env:
            env.update(extra_env)
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=ETL_PROJECT_PATH,
            env=env,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
        self._processes[execution_id] = proc

        def read_stream(
            stream, stream_name: str, collector: list[str],
        ) -> None:
            """Read a stream line by line, buffer and broadcast each line."""
            for raw_line in stream:
                line = raw_line.rstrip("\n").rstrip("\r")
                tagged = f"[{stream_name}] {line}"
                buf = self._log_buffers.get(execution_id)
                if buf is not None:
                    buf.append(tagged)
                collector.append(line)
                self._broadcast(execution_id, tagged)

        t_out = threading.Thread(
            target=read_stream, args=(proc.stdout, "stdout", stdout_lines),
            daemon=True,
        )
        t_err = threading.Thread(
            target=read_stream, args=(proc.stderr, "stderr", stderr_lines),
            daemon=True,
        )
        t_out.start()
        t_err.start()
        proc.wait()
        t_out.join(timeout=5)
        t_err.join(timeout=5)
        return proc.returncode

    # ------------------------------------------------------------------
    # Cancellation
    # ------------------------------------------------------------------
    async def cancel(self, execution_id: str) -> bool:
        """Send a termination signal to the subprocess.

        Returns:
            True if a termination signal was sent; False if the process does
            not exist or has already exited.
        """
        proc = self._processes.get(execution_id)
        if proc is None:
            return False
        # subprocess.Popen: poll() returning None means still running.
        if proc.poll() is not None:
            return False
        logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
        try:
            proc.terminate()
        except ProcessLookupError:
            return False
        return True

    # ------------------------------------------------------------------
    # Database operations (synchronous; simple direct connections)
    # ------------------------------------------------------------------
    @staticmethod
    def _write_execution_log(
        *,
        execution_id: str,
        queue_id: str | None,
        site_id: int | None,
        task_codes: list[str],
        status: str,
        started_at: datetime,
        command: str,
    ) -> None:
        """Insert one execution-log row (in 'running' state).

        Swallows all errors (logged) — log persistence must never break
        the execution itself.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        INSERT INTO task_execution_log
                            (id, queue_id, site_id, task_codes, status,
                             started_at, command)
                        VALUES (%s, %s, %s, %s, %s, %s, %s)
                        """,
                        (
                            execution_id,
                            queue_id,
                            site_id or 0,
                            task_codes,
                            status,
                            started_at,
                            command,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("写入 execution_log 失败 [%s]", execution_id)

    @staticmethod
    def _update_execution_log(
        *,
        execution_id: str,
        status: str,
        finished_at: datetime,
        exit_code: int | None,
        duration_ms: int,
        output_log: str,
        error_log: str,
    ) -> None:
        """Update the execution-log row with the final outcome.

        Swallows all errors (logged) for the same reason as
        ``_write_execution_log``.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        UPDATE task_execution_log
                        SET status = %s,
                            finished_at = %s,
                            exit_code = %s,
                            duration_ms = %s,
                            output_log = %s,
                            error_log = %s
                        WHERE id = %s
                        """,
                        (
                            status,
                            finished_at,
                            exit_code,
                            duration_ms,
                            output_log,
                            error_log,
                            execution_id,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("更新 execution_log 失败 [%s]", execution_id)

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------
    def cleanup(self, execution_id: str) -> None:
        """Drop in-memory resources (log buffer and subscribers) for a run.

        Normally called once the logs are known to be persisted.
        """
        self._log_buffers.pop(execution_id, None)
        self._subscribers.pop(execution_id, None)
# Module-level singleton shared across the application.
task_executor = TaskExecutor()

View File

@@ -0,0 +1,486 @@
# -*- coding: utf-8 -*-
"""任务队列服务
基于 PostgreSQL task_queue 表实现 FIFO 队列,支持:
- enqueue入队自动分配 position当前最大 + 1
- dequeue取出 position 最小的 pending 任务
- reorder调整任务在队列中的位置
- delete删除 pending 任务
- process_loop后台协程队列非空且无运行中任务时自动取出执行
所有操作按 site_id 过滤,实现门店隔离。
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any
from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema
logger = logging.getLogger(__name__)
# Polling interval of the background processing loop (seconds).
POLL_INTERVAL_SECONDS = 2
@dataclass
class QueuedTask:
    """Data object mirroring one row of the ``task_queue`` table."""
    id: str  # queue task id (UUID string)
    site_id: int  # store id (per-site isolation key)
    config: dict[str, Any]  # task configuration (decoded JSON)
    status: str  # e.g. 'pending' / 'running' / 'failed'
    position: int  # FIFO position among pending tasks (1-based)
    created_at: Any = None
    started_at: Any = None
    finished_at: Any = None
    exit_code: int | None = None
    error_message: str | None = None
class TaskQueue:
    """PostgreSQL-backed FIFO task queue (per-site isolation).

    Implemented on the ``task_queue`` table; supports:
    - enqueue: append with position = current pending max + 1
    - dequeue: take the pending task with the smallest position
    - reorder: move a pending task within its site's queue
    - delete: remove a pending task
    - process_loop: background coroutine that auto-executes the queue head
      whenever a site has pending work and nothing running

    Every operation filters by ``site_id``.
    """

    def __init__(self) -> None:
        self._running = False
        self._loop_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Enqueue
    # ------------------------------------------------------------------
    def enqueue(self, config: TaskConfigSchema, site_id: int) -> str:
        """Append a task configuration to the queue tail.

        Args:
            config: task configuration.
            site_id: store id (isolation key).

        Returns:
            The new queue task id (UUID string).
        """
        task_id = str(uuid.uuid4())
        config_json = config.model_dump(mode="json")
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Current max position among this site's pending tasks;
                # the new task goes after it.
                cur.execute(
                    """
                    SELECT COALESCE(MAX(position), 0)
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    """,
                    (site_id,),
                )
                max_pos = cur.fetchone()[0]
                new_pos = max_pos + 1
                cur.execute(
                    """
                    INSERT INTO task_queue (id, site_id, config, status, position)
                    VALUES (%s, %s, %s, 'pending', %s)
                    """,
                    (task_id, site_id, json.dumps(config_json), new_pos),
                )
            conn.commit()
        finally:
            conn.close()
        logger.info("任务入队 [%s] site_id=%s position=%s", task_id, site_id, new_pos)
        return task_id

    # ------------------------------------------------------------------
    # Dequeue
    # ------------------------------------------------------------------
    def dequeue(self, site_id: int) -> QueuedTask | None:
        """Take the lowest-position pending task, flipping it to running.

        Args:
            site_id: store id.

        Returns:
            The dequeued task, or None when the queue is empty.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Lock the lowest-position pending row; SKIP LOCKED lets a
                # concurrent worker pass over rows another worker holds.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                    """,
                    (site_id,),
                )
                row = cur.fetchone()
                if row is None:
                    conn.commit()
                    return None
                task = QueuedTask(
                    id=str(row[0]),
                    site_id=row[1],
                    config=row[2] if isinstance(row[2], dict) else json.loads(row[2]),
                    status=row[3],
                    position=row[4],
                    created_at=row[5],
                    started_at=row[6],
                    finished_at=row[7],
                    exit_code=row[8],
                    error_message=row[9],
                )
                # Flip the row to running inside the same transaction.
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'running', started_at = NOW()
                    WHERE id = %s
                    """,
                    (task.id,),
                )
            conn.commit()
        finally:
            conn.close()
        task.status = "running"
        logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
        return task

    # ------------------------------------------------------------------
    # Reorder
    # ------------------------------------------------------------------
    def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
        """Move a pending task to ``new_position`` within its site's queue.

        Only pending tasks are considered; the other pending tasks keep
        their relative order and everything is renumbered 1..n.

        Args:
            task_id: id of the task to move.
            new_position: target position (1-based, clamped to the queue).
            site_id: store id.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Snapshot this site's pending queue ordered by position.
                cur.execute(
                    """
                    SELECT id FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
                task_ids = [str(r[0]) for r in rows]
                if task_id not in task_ids:
                    conn.commit()
                    return
                # Remove the target, then re-insert at the requested slot.
                task_ids.remove(task_id)
                # new_position is 1-based; convert to a clamped 0-based index.
                insert_idx = max(0, min(new_position - 1, len(task_ids)))
                task_ids.insert(insert_idx, task_id)
                # Renumber contiguously (1-based) in the new order.
                for idx, tid in enumerate(task_ids, start=1):
                    cur.execute(
                        "UPDATE task_queue SET position = %s WHERE id = %s",
                        (idx, tid),
                    )
            conn.commit()
        finally:
            conn.close()
        logger.info(
            "任务重排 [%s] → position=%s site_id=%s",
            task_id, new_position, site_id,
        )

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------
    def delete(self, task_id: str, site_id: int) -> bool:
        """Delete a pending task.

        Args:
            task_id: task id.
            site_id: store id.

        Returns:
            True if the row was deleted; False if it does not exist or is
            not in 'pending' state.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM task_queue
                    WHERE id = %s AND site_id = %s AND status = 'pending'
                    """,
                    (task_id, site_id),
                )
                deleted = cur.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        if deleted:
            logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
        else:
            logger.warning(
                "任务删除失败 [%s] site_id=%s(不存在或非 pending",
                task_id, site_id,
            )
        return deleted

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def list_pending(self, site_id: int) -> list[QueuedTask]:
        """List all pending tasks of a site, ordered by position ascending."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [
            QueuedTask(
                id=str(r[0]),
                site_id=r[1],
                config=r[2] if isinstance(r[2], dict) else json.loads(r[2]),
                status=r[3],
                position=r[4],
                created_at=r[5],
                started_at=r[6],
                finished_at=r[7],
                exit_code=r[8],
                error_message=r[9],
            )
            for r in rows
        ]

    def has_running(self, site_id: int) -> bool:
        """Return True if the site has a task in 'running' state."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM task_queue
                        WHERE site_id = %s AND status = 'running'
                    )
                    """,
                    (site_id,),
                )
                result = cur.fetchone()[0]
            conn.commit()
        finally:
            conn.close()
        return result

    # ------------------------------------------------------------------
    # Background processing loop
    # ------------------------------------------------------------------
    async def process_loop(self) -> None:
        """Background coroutine: auto-execute queued work.

        Each iteration:
        1. find every site_id that has pending tasks
        2. for each such site with nothing running, dequeue and execute
        3. sleep POLL_INTERVAL_SECONDS and repeat
        """
        # Imported lazily to avoid a circular import.
        from .task_executor import task_executor
        self._running = True
        logger.info("TaskQueue process_loop 启动")
        while self._running:
            try:
                await self._process_once(task_executor)
            except Exception:
                logger.exception("process_loop 迭代异常")
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
        logger.info("TaskQueue process_loop 停止")

    async def _process_once(self, executor: Any) -> None:
        """Single pass: scan every site's pending queue and start work."""
        site_ids = self._get_pending_site_ids()
        for site_id in site_ids:
            if self.has_running(site_id):
                continue
            task = self.dequeue(site_id)
            if task is None:
                continue
            # FIX: validate the config defensively. dequeue() has already
            # moved the row to 'running'; letting a validation error escape
            # here would leave it stuck in 'running' forever and block the
            # whole site (has_running would stay True).
            try:
                config = TaskConfigSchema(**task.config)
            except Exception:
                logger.exception("队列任务执行异常 [%s]", task.id)
                self._mark_failed(task.id, "任务配置反序列化失败")
                continue
            execution_id = str(uuid.uuid4())
            logger.info(
                "process_loop 自动执行 [%s] queue_id=%s site_id=%s",
                execution_id, task.id, site_id,
            )
            # Fire-and-forget so the scan loop is not blocked.
            asyncio.create_task(
                self._execute_and_update(
                    executor, config, execution_id, task.id, site_id,
                )
            )

    async def _execute_and_update(
        self,
        executor: Any,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str,
        site_id: int,
    ) -> None:
        """Execute the task, then sync the outcome back to the queue row."""
        try:
            await executor.execute(
                config=config,
                execution_id=execution_id,
                queue_id=queue_id,
                site_id=site_id,
            )
            # Propagate the executor's recorded result to task_queue.
            self._update_queue_status_from_log(queue_id)
        except Exception:
            logger.exception("队列任务执行异常 [%s]", queue_id)
            self._mark_failed(queue_id, "执行过程中发生未捕获异常")

    def _get_pending_site_ids(self) -> list[int]:
        """Return every site_id that currently has pending tasks."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT DISTINCT site_id FROM task_queue
                    WHERE status = 'pending'
                    """
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [r[0] for r in rows]

    def _update_queue_status_from_log(self, queue_id: str) -> None:
        """Copy the latest task_execution_log outcome onto the queue row."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT status, finished_at, exit_code, error_log
                    FROM task_execution_log
                    WHERE queue_id = %s
                    ORDER BY started_at DESC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                row = cur.fetchone()
                if row:
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = %s, finished_at = %s,
                            exit_code = %s, error_message = %s
                        WHERE id = %s
                        """,
                        (row[0], row[1], row[2], row[3], queue_id),
                    )
            conn.commit()
        finally:
            conn.close()

    def _mark_failed(self, queue_id: str, error_message: str) -> None:
        """Mark a queue row as failed with the given message."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE id = %s
                    """,
                    (error_message, queue_id),
                )
            conn.commit()
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self) -> None:
        """Start the background processing loop (called from FastAPI lifespan)."""
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self.process_loop())
            logger.info("TaskQueue 后台循环已启动")

    async def stop(self) -> None:
        """Stop the background processing loop and wait for it to unwind."""
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass
        self._loop_task = None
        logger.info("TaskQueue 后台循环已停止")
# Module-level singleton shared across the application.
task_queue = TaskQueue()

View File

@@ -0,0 +1,221 @@
# -*- coding: utf-8 -*-
"""静态任务注册表
从 ETL orchestration/task_registry.py 提取的任务元数据硬编码副本。
后端不直接导入 ETL 代码,避免引入重量级依赖链。
业务域分组逻辑:按任务代码前缀 / 目标表语义归类,与 GUI 保持一致。
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass(frozen=True)
class TaskDefinition:
    """Metadata for a single ETL task."""
    code: str  # unique task code, upper-case
    name: str  # short display name
    description: str  # human-readable description
    domain: str  # business domain: 会员 / 结算 / 助教 / 商品 / 台桌 / 团购 / 库存 / 财务 / 指数 / 工具
    layer: str  # ODS / DWD / DWS / INDEX / UTILITY
    requires_window: bool = True  # whether the task needs a time-window argument
    is_ods: bool = False
    is_dimension: bool = False
    default_enabled: bool = True
    is_common: bool = True  # common-task flag; False marks utility/manual tasks
@dataclass(frozen=True)
class DwdTableDefinition:
    """Metadata for one DWD-layer table."""
    table_name: str  # full table name (schema-qualified)
    display_name: str
    domain: str
    ods_source: str  # the ODS source table it is loaded from
    is_dimension: bool = False
# ── ODS task definitions ──────────────────────────────────────
ODS_TASKS: list[TaskDefinition] = [
    TaskDefinition("ODS_ASSISTANT_ACCOUNT", "助教账号", "抽取助教账号主数据", "助教", "ODS", is_ods=True),
    TaskDefinition("ODS_ASSISTANT_LEDGER", "助教服务记录", "抽取助教服务流水", "助教", "ODS", is_ods=True),
    TaskDefinition("ODS_ASSISTANT_ABOLISH", "助教取消记录", "抽取助教取消/作废记录", "助教", "ODS", is_ods=True),
    TaskDefinition("ODS_SETTLEMENT_RECORDS", "结算记录", "抽取订单结算记录", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_SETTLEMENT_TICKET", "结账小票", "抽取结账小票明细", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_TABLE_USE", "台费流水", "抽取台费使用流水", "台桌", "ODS", is_ods=True),
    TaskDefinition("ODS_TABLE_FEE_DISCOUNT", "台费折扣", "抽取台费折扣记录", "台桌", "ODS", is_ods=True),
    TaskDefinition("ODS_TABLES", "台桌主数据", "抽取门店台桌信息", "台桌", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_PAYMENT", "支付流水", "抽取支付交易记录", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_REFUND", "退款流水", "抽取退款交易记录", "结算", "ODS", is_ods=True),
    TaskDefinition("ODS_PLATFORM_COUPON", "平台券核销", "抽取平台优惠券核销记录", "团购", "ODS", is_ods=True),
    TaskDefinition("ODS_MEMBER", "会员主数据", "抽取会员档案", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_MEMBER_CARD", "会员储值卡", "抽取会员储值卡信息", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_MEMBER_BALANCE", "会员余额变动", "抽取会员余额变动记录", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_RECHARGE_SETTLE", "充值结算", "抽取充值结算记录", "会员", "ODS", is_ods=True),
    TaskDefinition("ODS_GROUP_PACKAGE", "团购套餐", "抽取团购套餐定义", "团购", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_GROUP_BUY_REDEMPTION", "团购核销", "抽取团购核销记录", "团购", "ODS", is_ods=True),
    TaskDefinition("ODS_INVENTORY_STOCK", "库存快照", "抽取商品库存汇总", "库存", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_INVENTORY_CHANGE", "库存变动", "抽取库存出入库记录", "库存", "ODS", is_ods=True),
    TaskDefinition("ODS_GOODS_CATEGORY", "商品分类", "抽取商品分类树", "商品", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_STORE_GOODS", "门店商品", "抽取门店商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
    TaskDefinition("ODS_STORE_GOODS_SALES", "商品销售", "抽取门店商品销售记录", "商品", "ODS", is_ods=True),
    TaskDefinition("ODS_TENANT_GOODS", "租户商品", "抽取租户级商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
]
# ── DWD task definitions ──────────────────────────────────────
DWD_TASKS: list[TaskDefinition] = [
    TaskDefinition("DWD_LOAD_FROM_ODS", "DWD 装载", "从 ODS 装载至 DWD维度 SCD2 + 事实增量)", "通用", "DWD", requires_window=False),
    TaskDefinition("DWD_QUALITY_CHECK", "DWD 质量检查", "对 DWD 层数据执行质量校验", "通用", "DWD", requires_window=False, is_common=False),
]
# ── DWS task definitions ──────────────────────────────────────
DWS_TASKS: list[TaskDefinition] = [
    TaskDefinition("DWS_BUILD_ORDER_SUMMARY", "订单汇总构建", "构建订单汇总宽表", "结算", "DWS"),
    TaskDefinition("DWS_ASSISTANT_DAILY", "助教日报", "汇总助教每日业绩", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_MONTHLY", "助教月报", "汇总助教月度业绩", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_CUSTOMER", "助教客户分析", "汇总助教-客户关系", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_SALARY", "助教工资计算", "计算助教工资", "助教", "DWS"),
    TaskDefinition("DWS_ASSISTANT_FINANCE", "助教财务汇总", "汇总助教财务数据", "助教", "DWS"),
    TaskDefinition("DWS_MEMBER_CONSUMPTION", "会员消费分析", "汇总会员消费数据", "会员", "DWS"),
    TaskDefinition("DWS_MEMBER_VISIT", "会员到店分析", "汇总会员到店频次", "会员", "DWS"),
    TaskDefinition("DWS_FINANCE_DAILY", "财务日报", "汇总每日财务数据", "财务", "DWS"),
    TaskDefinition("DWS_FINANCE_RECHARGE", "充值汇总", "汇总充值数据", "财务", "DWS"),
    TaskDefinition("DWS_FINANCE_INCOME_STRUCTURE", "收入结构", "分析收入结构", "财务", "DWS"),
    TaskDefinition("DWS_FINANCE_DISCOUNT_DETAIL", "折扣明细", "汇总折扣明细", "财务", "DWS"),
    # CHANGE [2026-02-19] intent: sync with ETL-side merge — the former
    # DWS_RETENTION_CLEANUP / DWS_MV_REFRESH_* tasks were merged into DWS_MAINTENANCE.
    TaskDefinition("DWS_MAINTENANCE", "DWS 维护", "刷新物化视图 + 清理过期留存数据", "通用", "DWS", requires_window=False, is_common=False),
]
# ── INDEX task definitions ────────────────────────────────────
INDEX_TASKS: list[TaskDefinition] = [
    TaskDefinition("DWS_WINBACK_INDEX", "回流指数 (WBI)", "计算会员回流指数", "指数", "INDEX"),
    TaskDefinition("DWS_NEWCONV_INDEX", "新客转化指数 (NCI)", "计算新客转化指数", "指数", "INDEX"),
    TaskDefinition("DWS_ML_MANUAL_IMPORT", "手动导入 (ML)", "手动导入机器学习数据", "指数", "INDEX", requires_window=False, is_common=False),
    TaskDefinition("DWS_RELATION_INDEX", "关系指数 (RS)", "计算助教-客户关系指数", "指数", "INDEX"),
]
# ── Utility task definitions ──────────────────────────────────
UTILITY_TASKS: list[TaskDefinition] = [
    TaskDefinition("MANUAL_INGEST", "手动导入", "从本地 JSON 文件手动导入数据", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("INIT_ODS_SCHEMA", "初始化 ODS Schema", "创建 ODS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("INIT_DWD_SCHEMA", "初始化 DWD Schema", "创建 DWD 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("INIT_DWS_SCHEMA", "初始化 DWS Schema", "创建 DWS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("ODS_JSON_ARCHIVE", "ODS JSON 归档", "归档 ODS 原始 JSON 文件", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("CHECK_CUTOFF", "游标检查", "检查各任务数据游标截止点", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("SEED_DWS_CONFIG", "DWS 配置种子", "初始化 DWS 配置数据", "工具", "UTILITY", requires_window=False, is_common=False),
    TaskDefinition("DATA_INTEGRITY_CHECK", "数据完整性校验", "校验跨层数据完整性", "工具", "UTILITY", requires_window=False, is_common=False),
]
# ── Full task list ────────────────────────────────────────────
ALL_TASKS: list[TaskDefinition] = ODS_TASKS + DWD_TASKS + DWS_TASKS + INDEX_TASKS + UTILITY_TASKS
# Indexed by code for O(1) lookup.
_TASK_BY_CODE: dict[str, TaskDefinition] = {t.code: t for t in ALL_TASKS}
def get_all_tasks() -> list[TaskDefinition]:
    """Return the full task catalogue (every layer, registry order)."""
    return ALL_TASKS
def get_task_by_code(code: str) -> TaskDefinition | None:
    """Look up a task by code (case-insensitive); None when unknown."""
    return _TASK_BY_CODE.get(code.upper())
def get_tasks_grouped_by_domain() -> dict[str, list[TaskDefinition]]:
    """Group the task catalogue by business domain, preserving registry order."""
    grouped: dict[str, list[TaskDefinition]] = {}
    for task in ALL_TASKS:
        bucket = grouped.setdefault(task.domain, [])
        bucket.append(task)
    return grouped
def get_tasks_by_layer(layer: str) -> list[TaskDefinition]:
    """Return every task belonging to *layer* (matched case-insensitively)."""
    wanted = layer.upper()
    return [task for task in ALL_TASKS if task.layer == wanted]
# ── Flow → layer mapping ─────────────────────────────────────
# Layers included by each flow; the frontend uses this to filter selectable tasks.
FLOW_LAYER_MAP: dict[str, list[str]] = {
    "api_ods": ["ODS"],
    "api_ods_dwd": ["ODS", "DWD"],
    "api_full": ["ODS", "DWD", "DWS", "INDEX"],
    "ods_dwd": ["DWD"],
    "dwd_dws": ["DWS"],
    "dwd_dws_index": ["DWS", "INDEX"],
    "dwd_index": ["INDEX"],
}
def get_compatible_tasks(flow_id: str) -> list[TaskDefinition]:
    """Return the tasks whose layer is covered by the given flow."""
    allowed_layers = set(FLOW_LAYER_MAP.get(flow_id, []))
    return [task for task in ALL_TASKS if task.layer in allowed_layers]
# ── DWD table definitions ─────────────────────────────────────
DWD_TABLES: list[DwdTableDefinition] = [
    # Dimension tables
    DwdTableDefinition("dwd.dim_site", "门店维度", "台桌", "ods.table_fee_transactions", is_dimension=True),
    DwdTableDefinition("dwd.dim_site_ex", "门店维度(扩展)", "台桌", "ods.table_fee_transactions", is_dimension=True),
    DwdTableDefinition("dwd.dim_table", "台桌维度", "台桌", "ods.site_tables_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_table_ex", "台桌维度(扩展)", "台桌", "ods.site_tables_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_assistant", "助教维度", "助教", "ods.assistant_accounts_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_assistant_ex", "助教维度(扩展)", "助教", "ods.assistant_accounts_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_member", "会员维度", "会员", "ods.member_profiles", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_ex", "会员维度(扩展)", "会员", "ods.member_profiles", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_card_account", "会员储值卡维度", "会员", "ods.member_stored_value_cards", is_dimension=True),
    DwdTableDefinition("dwd.dim_member_card_account_ex", "会员储值卡维度(扩展)", "会员", "ods.member_stored_value_cards", is_dimension=True),
    DwdTableDefinition("dwd.dim_tenant_goods", "租户商品维度", "商品", "ods.tenant_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_tenant_goods_ex", "租户商品维度(扩展)", "商品", "ods.tenant_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_store_goods", "门店商品维度", "商品", "ods.store_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_store_goods_ex", "门店商品维度(扩展)", "商品", "ods.store_goods_master", is_dimension=True),
    DwdTableDefinition("dwd.dim_goods_category", "商品分类维度", "商品", "ods.stock_goods_category_tree", is_dimension=True),
    DwdTableDefinition("dwd.dim_groupbuy_package", "团购套餐维度", "团购", "ods.group_buy_packages", is_dimension=True),
    DwdTableDefinition("dwd.dim_groupbuy_package_ex", "团购套餐维度(扩展)", "团购", "ods.group_buy_packages", is_dimension=True),
    # Fact tables
    DwdTableDefinition("dwd.dwd_settlement_head", "结算主表", "结算", "ods.settlement_records"),
    DwdTableDefinition("dwd.dwd_settlement_head_ex", "结算主表(扩展)", "结算", "ods.settlement_records"),
    DwdTableDefinition("dwd.dwd_table_fee_log", "台费流水", "台桌", "ods.table_fee_transactions"),
    DwdTableDefinition("dwd.dwd_table_fee_log_ex", "台费流水(扩展)", "台桌", "ods.table_fee_transactions"),
    DwdTableDefinition("dwd.dwd_table_fee_adjust", "台费折扣", "台桌", "ods.table_fee_discount_records"),
    DwdTableDefinition("dwd.dwd_table_fee_adjust_ex", "台费折扣(扩展)", "台桌", "ods.table_fee_discount_records"),
    DwdTableDefinition("dwd.dwd_store_goods_sale", "商品销售", "商品", "ods.store_goods_sales_records"),
    DwdTableDefinition("dwd.dwd_store_goods_sale_ex", "商品销售(扩展)", "商品", "ods.store_goods_sales_records"),
    DwdTableDefinition("dwd.dwd_assistant_service_log", "助教服务流水", "助教", "ods.assistant_service_records"),
    DwdTableDefinition("dwd.dwd_assistant_service_log_ex", "助教服务流水(扩展)", "助教", "ods.assistant_service_records"),
    DwdTableDefinition("dwd.dwd_assistant_trash_event", "助教取消事件", "助教", "ods.assistant_cancellation_records"),
    DwdTableDefinition("dwd.dwd_assistant_trash_event_ex", "助教取消事件(扩展)", "助教", "ods.assistant_cancellation_records"),
    DwdTableDefinition("dwd.dwd_member_balance_change", "会员余额变动", "会员", "ods.member_balance_changes"),
    DwdTableDefinition("dwd.dwd_member_balance_change_ex", "会员余额变动(扩展)", "会员", "ods.member_balance_changes"),
    DwdTableDefinition("dwd.dwd_groupbuy_redemption", "团购核销", "团购", "ods.group_buy_redemption_records"),
    DwdTableDefinition("dwd.dwd_groupbuy_redemption_ex", "团购核销(扩展)", "团购", "ods.group_buy_redemption_records"),
    DwdTableDefinition("dwd.dwd_platform_coupon_redemption", "平台券核销", "团购", "ods.platform_coupon_redemption_records"),
    DwdTableDefinition("dwd.dwd_platform_coupon_redemption_ex", "平台券核销(扩展)", "团购", "ods.platform_coupon_redemption_records"),
    DwdTableDefinition("dwd.dwd_recharge_order", "充值订单", "会员", "ods.recharge_settlements"),
    DwdTableDefinition("dwd.dwd_recharge_order_ex", "充值订单(扩展)", "会员", "ods.recharge_settlements"),
    DwdTableDefinition("dwd.dwd_payment", "支付流水", "结算", "ods.payment_transactions"),
    DwdTableDefinition("dwd.dwd_refund", "退款流水", "结算", "ods.refund_transactions"),
    DwdTableDefinition("dwd.dwd_refund_ex", "退款流水(扩展)", "结算", "ods.refund_transactions"),
]
def get_dwd_tables_grouped_by_domain() -> dict[str, list[DwdTableDefinition]]:
    """Group the DWD table definitions by business domain, preserving order."""
    grouped: dict[str, list[DwdTableDefinition]] = {}
    for table in DWD_TABLES:
        grouped.setdefault(table.domain, []).append(table)
    return grouped

View File

View File

@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""WebSocket 日志推送端点
提供 WS /ws/logs/{execution_id} 端点,实时推送 ETL 任务执行日志。
客户端连接后,先发送已有的历史日志行,再实时推送新日志,
直到执行结束(收到 None 哨兵)或客户端断开。
设计要点:
- 利用 TaskExecutor 已有的 subscribe/unsubscribe 机制
- 连接时先回放内存缓冲区中的历史日志,避免丢失已产生的行
- 通过 asyncio.Queue 接收实时日志None 表示执行结束
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from ..services.task_executor import task_executor
logger = logging.getLogger(__name__)
ws_router = APIRouter()
@ws_router.websocket("/ws/logs/{execution_id}")
async def ws_logs(websocket: WebSocket, execution_id: str) -> None:
    """Stream execution logs for ``execution_id`` over WebSocket.

    Flow:
    1. Accept the WebSocket connection.
    2. Replay the log lines already held in the in-memory buffer.
    3. Subscribe to the TaskExecutor and keep pushing new lines.
    4. Close on the ``None`` sentinel (execution finished) or client disconnect.
    """
    await websocket.accept()
    logger.info("WebSocket 连接已建立: execution_id=%s", execution_id)
    # Subscribe to the live log stream before the replay so no line is lost.
    # NOTE(review): a line broadcast between subscribing and finishing the
    # replay below appears both in the buffer copy and in the queue, so it
    # may be sent twice — confirm clients tolerate duplicates.
    queue = task_executor.subscribe(execution_id)
    try:
        # Replay history already captured in the in-memory buffer.
        for line in task_executor.get_logs(execution_id):
            await websocket.send_text(line)
        # Even if the task is no longer running, keep the connection open:
        # the queue may still hold unconsumed messages from a just-finished run.
        while True:
            msg = await queue.get()
            if msg is None:
                # End-of-execution sentinel.
                break
            await websocket.send_text(msg)
    except WebSocketDisconnect:
        logger.info("WebSocket 客户端断开: execution_id=%s", execution_id)
    except Exception:
        logger.exception("WebSocket 异常: execution_id=%s", execution_id)
    finally:
        task_executor.unsubscribe(execution_id, queue)
        # Close defensively — the client may already be gone; ignore errors.
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket 连接已清理: execution_id=%s", execution_id)
View File

@@ -0,0 +1,24 @@
-----BEGIN CERTIFICATE-----
MIID9DCCAtygAwIBAgIUaB2siLoT1Nb+u9K0aL18avodFRYwDQYJKoZIhvcNAQEL
BQAwbTELMAkGA1UEBhMCQ04xEjAQBgNVBAgMCUd1YW5nRG9uZzERMA8GA1UEBwwI
U2hlblpoZW4xEDAOBgNVBAoMB1RlbmNlbnQxEDAOBgNVBAsMB1RlbmNlbnQxEzAR
BgNVBAMMCnRzbTAwMDAwMDYwHhcNMjUxMTA0MTA1NzQ0WhcNMzUxMTAyMTA1NzQ0
WjCBmDEeMBwGCSqGSIb3DQEJARYPd2VpeGlubXBAcXEuY29tMRswGQYDVQQDDBJ3
eDdjMDc3OTNkODI3MzI5MjExFTATBgNVBAoMDFRlbmNlbnQgSW5jLjEOMAwGA1UE
CwwFV3hnTXAxCzAJBgNVBAYTAkNOMRIwEAYDVQQIDAlHdWFuZ0RvbmcxETAPBgNV
BAcMCFNoZW5aaGVuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA8Ig7
m6qvcw+0ncNZBUg9Xfk+6WTKSj7cCwcuD66JgYPQ4pVreqCTzusc//E+EGapXyR3
fJZl+AL2sVWRIzLa1+dt9+45YBGm3lSbvarlbsdVMrYwFRW+d/vZgxcXfcS1VmVP
NPEw7DaAkWlvUkFmUGdNzH+QLsuXTZKWEtHtmSo7us9HTJYV/aEH2uJpsHE4A3fP
Vbc/wy1EwLt48o0ZDzpiPZiqn+nSrSXEqBPOEwzICxnHCJRpEH01RxBJGdTSouzF
pfncMeEpGfFGH8GW8IQYzzvrvYDbprsnVMfHNAo0MMGK+iyWyAOFMqrkvSb962x7
KoXLDH9OfmFRNse9WQIDAQABo2AwXjAdBgNVHQ4EFgQUccFm3WWmKA/m+uXeW8Xe
jfNsX+owHwYDVR0jBBgwFoAURM1183H4z2eJGmP5z/zWI7tjRZAwDAYDVR0TAQH/
BAIwADAOBgNVHQ8BAf8EBAMCBsAwDQYJKoZIhvcNAQELBQADggEBAIXfGfQARyxC
Ptut+rOccdq8TawasZT7o7TnAGCCTPsAWCd5RTAXse65mSGM6oxjQsppZxtYz4Kx
TLySl91Vok2nMH1jBoWPx9WoFyU6zCkmOkq7zWvEU23FR1Quq0QB0fmHrVMNQqxA
LKkUuUFTa1wmVuYaKtcz5LAaj+GmgrY3kTIWg81ybPF/Hibkz0zWh54SLBc64Ha6
zfNXDffq3vVVo04DKZW8Erd9nZL0F/w2u6+MpTl5CrAYzSZyDcNiIGbSrYpYYRt9
JagGAn/ZZD93SnOiMcRCsNfNq4LisSf6AUMSA3F9Rw8iuxas5lDBf073pEy2vWjG
VSp+Vio/oEY=
-----END CERTIFICATE-----

View File

@@ -0,0 +1,38 @@
## 开放平台证书
开放平台证书编号06e9682660ce742bb45ef278ae941af0
证书已下载。
编号 密钥类型 密钥明文
901b24f9af7b1421b80ebd5df9094141 对称密钥 D1fK6Zib6UOG10bM4WWhjsbMImCNXz7Mxq/0oRREGmw=
59347d0e3cd661af9e90f7def5b6ca00 非对称密钥 -----BEGIN PUBLIC KEY----- MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzMrQ8iGGop0JEvcS/dZu sVUgjHeHqtds506Ftq/0WylTGKG9yByrY6eSKHttkBs/IyqG6JU9TtmskvVBam9B BmwOVyHkXATlXwEIkhyShu459p0dQzKwFaiygtj/fvxzWvurXVf1UcbIYVP1u7T0 E6yqatUkhmeaIyBsuw7vp7yYxpJoxsR2t70cDOTnfmHzl47FqSmq8xQRl/Cyw4FJ +1NS3i3cQBYkxjxGJ8Q3lhd5xd8IjwzDf+v/UgEhmoduZDAr/HkOwcrh3ihCCkLC DsfqV6qJrtqHhq5PHog3TqaaF6I9vk04FQTHn07XzIMciBDauYdD/Mq3B+ZBOKX4 gQIDAQAB -----END PUBLIC KEY-----
## 安全管理配置
https://developers.weixin.qq.com/miniprogram/dev/framework/share_signature.html
## 消息推送配置
(暂未处理)
填写的URL需要正确响应微信发送的Token验证填写说明请阅读消息推送服务器配置指南。https://developers.weixin.qq.com/miniprogram/dev/framework/server-ability/message-push.html
URL(服务器地址)
https://push.langlangzhuoqiu.cn
Token(令牌)
nCmGmzINYfaqf5jDGKdkDO2AJ9C0VWNe
EncodingAESKey(消息加密密钥):
2SZwWe90vG121o1l7RRMbtGt8GNvA1Juf727a3m7nZX
消息加密方式:
安全模式 (消息包为纯密文,需要加密和解密。)
数据格式:
JSON
## 业务域名
https://api.langlangzhuoqiu.cn

View File

@@ -1,10 +1,34 @@
# AI_CHANGELOG
# - 2026-02-15 | Prompt: 让 FastAPI 成功启动 | 补全运行依赖fastapi/uvicorn/psycopg2-binary/python-dotenv使后端可通过 uv run uvicorn 启动
# - 风险:依赖版本变更可能影响其他 workspace 成员验证uv sync --all-packages && uv run uvicorn app.main:app
[project]
name = "zqyy-backend"
version = "0.1.0"
requires-python = ">=3.10"
# CHANGE 2026-02-15 | intent: 补全后端运行依赖,原先仅声明 neozqyy-shared 导致 uvicorn/fastapi 缺失无法启动
# assumptions: 版本下限与 tech.md 记录的核心依赖一致uvicorn[standard] 包含 uvloop/httptools 等性能依赖
dependencies = [
"neozqyy-shared",
"fastapi>=0.115",
"uvicorn[standard]>=0.34",
"psycopg2-binary>=2.9",
"python-dotenv>=1.0",
"python-jose[cryptography]>=3.3",
"bcrypt>=4.0",
]
[tool.uv.sources]
neozqyy-shared = { workspace = true }
[dependency-groups]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"hypothesis>=6.100",
"httpx>=0.27",
]
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]

View File

@@ -0,0 +1,62 @@
"""
FastAPI 依赖注入 get_current_user 单元测试。
通过 FastAPI TestClient 验证 Authorization header 处理。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.auth.jwt import create_access_token, create_refresh_token
# Minimal FastAPI app used only to exercise the get_current_user dependency.
_test_app = FastAPI()


@_test_app.get("/protected")
async def protected_route(user: CurrentUser = Depends(get_current_user)):
    # Echo back the identity resolved by the dependency.
    return {"user_id": user.user_id, "site_id": user.site_id}


client = TestClient(_test_app)
class TestGetCurrentUser:
    """Behaviour of the get_current_user dependency via real HTTP round trips."""

    def test_valid_access_token(self):
        """A freshly issued access token resolves to the embedded identity."""
        bearer = create_access_token(user_id=10, site_id=100)
        response = client.get(
            "/protected", headers={"Authorization": f"Bearer {bearer}"}
        )
        assert response.status_code == 200
        body = response.json()
        assert (body["user_id"], body["site_id"]) == (10, 100)

    def test_missing_auth_header_returns_401(self):
        """Requests without an Authorization header are rejected."""
        response = client.get("/protected")
        assert response.status_code in (401, 403)

    def test_invalid_token_returns_401(self):
        """A malformed bearer token is rejected with 401."""
        headers = {"Authorization": "Bearer invalid.token.here"}
        assert client.get("/protected", headers=headers).status_code == 401

    def test_refresh_token_rejected(self):
        """A refresh token must not grant access to protected endpoints."""
        bearer = create_refresh_token(user_id=1, site_id=1)
        headers = {"Authorization": f"Bearer {bearer}"}
        assert client.get("/protected", headers=headers).status_code == 401

    def test_current_user_is_frozen_dataclass(self):
        """CurrentUser instances are immutable."""
        user = CurrentUser(user_id=1, site_id=2)
        assert (user.user_id, user.site_id) == (1, 2)
        with pytest.raises(AttributeError):
            user.user_id = 99  # type: ignore[misc]

View File

@@ -0,0 +1,147 @@
"""
JWT 认证模块单元测试。
覆盖:令牌生成、验证、过期、类型校验、密码哈希、依赖注入。
"""
import os
import time
import pytest
from jose import jwt as jose_jwt
# 测试前设置 JWT_SECRET_KEY避免空密钥
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from app.auth.jwt import (
create_access_token,
create_refresh_token,
create_token_pair,
decode_access_token,
decode_refresh_token,
decode_token,
hash_password,
verify_password,
)
from app import config
# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------
class TestPasswordHashing:
    """hash_password / verify_password round-trip behaviour."""

    def test_hash_and_verify(self):
        secret = "my_secure_password"
        assert verify_password(secret, hash_password(secret))

    def test_wrong_password_rejected(self):
        digest = hash_password("correct")
        assert not verify_password("wrong", digest)

    def test_hash_is_not_plaintext(self):
        secret = "plaintext123"
        assert hash_password(secret) != secret
# ---------------------------------------------------------------------------
# Token creation and decoding
# ---------------------------------------------------------------------------
class TestTokenCreation:
    """Payload contents of access/refresh tokens and token pairs."""

    def test_access_token_contains_expected_fields(self):
        claims = decode_token(create_access_token(user_id=1, site_id=100))
        assert claims["sub"] == "1"
        assert claims["site_id"] == 100
        assert claims["type"] == "access"
        assert "exp" in claims

    def test_refresh_token_contains_expected_fields(self):
        claims = decode_token(create_refresh_token(user_id=2, site_id=200))
        assert claims["sub"] == "2"
        assert claims["site_id"] == 200
        assert claims["type"] == "refresh"
        assert "exp" in claims

    def test_token_pair_returns_both_tokens(self):
        pair = create_token_pair(user_id=3, site_id=300)
        assert "access_token" in pair
        assert "refresh_token" in pair
        assert pair["token_type"] == "bearer"
        # The two tokens must carry distinct type claims.
        assert decode_token(pair["access_token"])["type"] == "access"
        assert decode_token(pair["refresh_token"])["type"] == "refresh"
# ---------------------------------------------------------------------------
# Token type validation
# ---------------------------------------------------------------------------
class TestTokenTypeValidation:
    """Typed decoders accept their own token type and reject the other."""

    def test_decode_access_token_rejects_refresh(self):
        """The access decoder must reject a refresh token."""
        wrong = create_refresh_token(user_id=1, site_id=1)
        with pytest.raises(Exception):
            decode_access_token(wrong)

    def test_decode_refresh_token_rejects_access(self):
        """The refresh decoder must reject an access token."""
        wrong = create_access_token(user_id=1, site_id=1)
        with pytest.raises(Exception):
            decode_refresh_token(wrong)

    def test_decode_access_token_accepts_access(self):
        claims = decode_access_token(create_access_token(user_id=5, site_id=50))
        assert (claims["sub"], claims["site_id"]) == ("5", 50)

    def test_decode_refresh_token_accepts_refresh(self):
        claims = decode_refresh_token(create_refresh_token(user_id=6, site_id=60))
        assert (claims["sub"], claims["site_id"]) == ("6", 60)
# ---------------------------------------------------------------------------
# Token expiry
# ---------------------------------------------------------------------------
class TestTokenExpiry:
    """Expired tokens must fail to decode."""

    def test_expired_token_rejected(self):
        """Hand-craft a token that expired 10 seconds ago; decoding must fail."""
        expired_claims = {
            "sub": "1",
            "site_id": 1,
            "type": "access",
            "exp": int(time.time()) - 10,
        }
        stale = jose_jwt.encode(
            expired_claims, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM
        )
        with pytest.raises(Exception):
            decode_token(stale)
# ---------------------------------------------------------------------------
# Invalid tokens
# ---------------------------------------------------------------------------
class TestInvalidToken:
    """Structurally invalid or foreign-signed tokens are rejected."""

    def test_garbage_token_rejected(self):
        with pytest.raises(Exception):
            decode_token("not.a.valid.jwt")

    def test_wrong_secret_rejected(self):
        """A token signed with a different secret must not decode."""
        claims = {
            "sub": "1",
            "site_id": 1,
            "type": "access",
            "exp": int(time.time()) + 3600,
        }
        forged = jose_jwt.encode(claims, "wrong-secret", algorithm="HS256")
        with pytest.raises(Exception):
            decode_token(forged)

View File

@@ -0,0 +1,137 @@
"""
认证模块属性测试Property-Based Testing
使用 hypothesis 验证认证系统的通用正确性属性:
- Property 2: 无效凭据始终被拒绝
- Property 3: 有效 JWT 令牌授权访问
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-property-tests")
from unittest.mock import MagicMock, patch
from hypothesis import given, settings
from hypothesis import strategies as st
from app.auth.dependencies import CurrentUser, get_current_user
from app.auth.jwt import create_access_token
from app.main import app
from app.routers.auth import router
# Mount the auth routes exactly once for these tests.
# BUG FIX: the previous guard `router not in app.routes` was always True —
# app.routes contains APIRoute objects, never APIRouter instances — so the
# router could be registered a second time. Compare by route path instead.
_mounted_paths = {getattr(route, "path", None) for route in app.routes}
if not all(route.path in _mounted_paths for route in router.routes):
    app.include_router(router)

from fastapi.testclient import TestClient

client = TestClient(app)
# ---------------------------------------------------------------------------
# Strategies
# ---------------------------------------------------------------------------
# Username strategy: 1-64 characters drawn from letters, digits, punctuation
# and symbols (control characters excluded by category whitelist).
_username_st = st.text(
    alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
    min_size=1,
    max_size=64,
)
# Password strategy: 1-128 characters, same alphabet as usernames.
_password_st = st.text(
    alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
    min_size=1,
    max_size=128,
)
# user_id strategy: positive 32-bit signed integers.
_user_id_st = st.integers(min_value=1, max_value=2**31 - 1)
# site_id strategy: positive 64-bit signed integers.
_site_id_st = st.integers(min_value=1, max_value=2**63 - 1)
def _mock_db_returning(row):
"""构造 mock get_connectioncursor.fetchone() 返回指定行。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = row
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 2: invalid credentials always rejected
# **Validates: Requirements 1.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(username=_username_st, password=_password_st)
@patch("app.routers.auth.get_connection")
def test_invalid_credentials_always_rejected(mock_get_conn, username, password):
    """
    Property 2: invalid credentials are always rejected.

    For any username/password combination, when the user does not exist in
    the database (fetchone returns None) the login endpoint must respond
    with HTTP 401.

    Note: @patch is the innermost decorator, so the mock arrives as the first
    positional argument, ahead of hypothesis' generated keyword arguments.
    """
    # Mock the database returning None — user not found.
    mock_get_conn.return_value = _mock_db_returning(None)
    resp = client.post(
        "/api/auth/login",
        json={"username": username, "password": password},
    )
    assert resp.status_code == 401, (
        f"期望 401实际 {resp.status_code}username={username!r}"
    )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 3: 有效 JWT 令牌授权访问
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
import asyncio
from fastapi.security import HTTPAuthorizationCredentials
def _run_async(coro):
"""在同步上下文中执行异步协程,避免 DeprecationWarning。"""
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
@settings(max_examples=100)
@given(user_id=_user_id_st, site_id=_site_id_st)
def test_valid_jwt_grants_access(user_id, site_id):
    """
    Property 3: a valid JWT grants access.

    For any user_id and site_id, a freshly issued (hence unexpired)
    access_token must be resolved by the get_current_user dependency into a
    CurrentUser whose user_id and site_id match the values it was issued for.
    """
    # Issue a valid access token.
    token = create_access_token(user_id=user_id, site_id=site_id)
    # Call the dependency directly to validate token parsing.
    credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
    result = _run_async(get_current_user(credentials))
    assert isinstance(result, CurrentUser)
    assert result.user_id == user_id, (
        f"user_id 不匹配:期望 {user_id},实际 {result.user_id}"
    )
    assert result.site_id == site_id, (
        f"site_id 不匹配:期望 {site_id},实际 {result.site_id}"
    )

View File

@@ -0,0 +1,167 @@
"""
认证路由单元测试。
覆盖:登录成功/失败、刷新令牌、账号禁用等场景。
通过 mock 数据库连接避免依赖真实 PostgreSQL。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.jwt import (
create_refresh_token,
decode_access_token,
decode_refresh_token,
hash_password,
)
from app.main import app
from app.routers.auth import router
# Mount the auth routes exactly once for the tests.
# BUG FIX: `router not in app.routes` was always True because app.routes
# contains APIRoute objects, not APIRouter instances; compare by route path
# instead so the same routes are never registered twice.
_mounted_paths = {getattr(route, "path", None) for route in app.routes}
if not all(route.path in _mounted_paths for route in router.routes):
    app.include_router(router)

client = TestClient(app)
# Fixed test data: a known password, its hash, and two user rows mirroring
# the (id, password_hash, site_id, is_active) columns read by the login query.
_TEST_PASSWORD = "correct_password"
_TEST_HASH = hash_password(_TEST_PASSWORD)
_TEST_USER_ROW = (1, _TEST_HASH, 100, True)  # id, password_hash, site_id, is_active
_DISABLED_USER_ROW = (2, _TEST_HASH, 200, False)  # same shape, is_active=False
def _mock_db_returning(row):
"""构造一个 mock get_connectioncursor.fetchone() 返回指定行。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = row
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
# ---------------------------------------------------------------------------
# POST /api/auth/login
# ---------------------------------------------------------------------------
class TestLogin:
    """Login endpoint: success, unknown user, wrong password, disabled
    account, and request-validation failures."""

    @patch("app.routers.auth.get_connection")
    def test_login_success(self, mock_get_conn):
        """Valid credentials yield a bearer token pair with correct claims."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": _TEST_PASSWORD},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"
        # The access_token payload must carry the user_id and site_id of the
        # mocked user row.
        payload = decode_access_token(data["access_token"])
        assert payload["sub"] == "1"
        assert payload["site_id"] == 100

    @patch("app.routers.auth.get_connection")
    def test_login_user_not_found(self, mock_get_conn):
        """Unknown user -> 401 with the generic credential error."""
        mock_get_conn.return_value = _mock_db_returning(None)
        resp = client.post(
            "/api/auth/login",
            json={"username": "nonexistent", "password": "whatever"},
        )
        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_wrong_password(self, mock_get_conn):
        """Wrong password -> 401 with the same generic credential error."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": "wrong_password"},
        )
        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_disabled_account(self, mock_get_conn):
        """Disabled account -> 401 mentioning the account is disabled."""
        mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)
        resp = client.post(
            "/api/auth/login",
            json={"username": "disabled_user", "password": _TEST_PASSWORD},
        )
        assert resp.status_code == 401
        assert "禁用" in resp.json()["detail"]

    def test_login_missing_username(self):
        """Missing username field -> 422 (request validation)."""
        resp = client.post("/api/auth/login", json={"password": "test"})
        assert resp.status_code == 422

    def test_login_empty_password(self):
        """Empty password -> 422 (request validation)."""
        resp = client.post(
            "/api/auth/login", json={"username": "admin", "password": ""}
        )
        assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /api/auth/refresh
# ---------------------------------------------------------------------------
class TestRefresh:
    """Refresh endpoint: happy path, invalid/mis-typed tokens, validation."""

    def test_refresh_success(self):
        """A valid refresh_token is exchanged for a new access_token."""
        refresh = create_refresh_token(user_id=5, site_id=50)
        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": refresh}
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        # The refresh_token is echoed back unchanged.
        assert data["refresh_token"] == refresh
        assert data["token_type"] == "bearer"
        # The new access_token carries the original identity.
        payload = decode_access_token(data["access_token"])
        assert payload["sub"] == "5"
        assert payload["site_id"] == 50

    def test_refresh_with_invalid_token(self):
        """A garbage token -> 401."""
        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": "garbage.token.here"}
        )
        assert resp.status_code == 401
        assert "无效的刷新令牌" in resp.json()["detail"]

    def test_refresh_with_access_token_rejected(self):
        """An access_token must not be accepted for refresh."""
        from app.auth.jwt import create_access_token

        access = create_access_token(user_id=1, site_id=1)
        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": access}
        )
        assert resp.status_code == 401

    def test_refresh_missing_token(self):
        """Missing refresh_token field -> 422 (request validation)."""
        resp = client.post("/api/auth/refresh", json={})
        assert resp.status_code == 422

View File

@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
"""CLIBuilder 单元测试
覆盖7 种 Flow、3 种处理模式、时间窗口、store_id 自动注入、extra_args 等。
"""
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
@pytest.fixture
def builder() -> CLIBuilder:
    """Fresh CLIBuilder instance per test."""
    return CLIBuilder()


# Fake ETL project path passed to build_command; nothing is executed here —
# presumably used as the working directory, confirm in CLIBuilder.
ETL_PATH = "/fake/etl/project"
# ---------------------------------------------------------------------------
# Basic command structure
# ---------------------------------------------------------------------------
class TestBasicCommand:
    """Shape of the generated argv: interpreter, module, mandatory flags."""

    def test_minimal_command(self, builder: CLIBuilder):
        """A minimal config yields python -m cli.main --pipeline ... --processing-mode ..."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"])
        cmd = builder.build_command(config, ETL_PATH)
        assert cmd[:3] == ["python", "-m", "cli.main"]
        assert "--pipeline" in cmd
        assert "--processing-mode" in cmd

    def test_custom_python_executable(self, builder: CLIBuilder):
        """python_executable overrides the interpreter at argv[0]."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"])
        cmd = builder.build_command(config, ETL_PATH, python_executable="python3")
        assert cmd[0] == "python3"

    def test_tasks_joined_by_comma(self, builder: CLIBuilder):
        """Multiple tasks are passed as one comma-joined --tasks value."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER", "ODS_PAYMENT", "ODS_REFUND"])
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--tasks")
        assert cmd[idx + 1] == "ODS_MEMBER,ODS_PAYMENT,ODS_REFUND"

    def test_empty_tasks_no_tasks_arg(self, builder: CLIBuilder):
        """An empty task list must not emit a --tasks argument."""
        config = TaskConfigSchema(tasks=[])
        cmd = builder.build_command(config, ETL_PATH)
        assert "--tasks" not in cmd
# ---------------------------------------------------------------------------
# The 7 flows
# ---------------------------------------------------------------------------
class TestFlows:
    """Every registered flow id is accepted and forwarded via --pipeline."""

    @pytest.mark.parametrize("flow_id", sorted(VALID_FLOWS))
    def test_all_flows_accepted(self, builder: CLIBuilder, flow_id: str):
        """Each valid flow id appears verbatim after --pipeline."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], pipeline=flow_id)
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--pipeline")
        assert cmd[idx + 1] == flow_id

    def test_default_flow_is_api_ods_dwd(self, builder: CLIBuilder):
        """When pipeline is omitted, the default flow is api_ods_dwd."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"])
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--pipeline")
        assert cmd[idx + 1] == "api_ods_dwd"
# ---------------------------------------------------------------------------
# The 3 processing modes
# ---------------------------------------------------------------------------
class TestProcessingModes:
    """Processing modes and the verify_only-scoped --fetch-before-verify flag."""

    @pytest.mark.parametrize("mode", sorted(VALID_PROCESSING_MODES))
    def test_all_modes_accepted(self, builder: CLIBuilder, mode: str):
        """Each valid mode appears verbatim after --processing-mode."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], processing_mode=mode)
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--processing-mode")
        assert cmd[idx + 1] == mode

    def test_fetch_before_verify_only_in_verify_mode(self, builder: CLIBuilder):
        """--fetch-before-verify is emitted only in verify_only mode."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            processing_mode="verify_only",
            fetch_before_verify=True,
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--fetch-before-verify" in cmd

    def test_fetch_before_verify_ignored_in_increment_mode(self, builder: CLIBuilder):
        """In increment_only mode fetch_before_verify=True must be ignored."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            processing_mode="increment_only",
            fetch_before_verify=True,
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--fetch-before-verify" not in cmd
# ---------------------------------------------------------------------------
# Time window
# ---------------------------------------------------------------------------
class TestTimeWindow:
    """Window modes are mutually exclusive: lookback args vs explicit bounds."""

    def test_lookback_mode_generates_lookback_args(self, builder: CLIBuilder):
        """lookback mode emits --lookback-hours/--overlap-seconds only."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            window_mode="lookback",
            lookback_hours=48,
            overlap_seconds=1200,
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx_lb = cmd.index("--lookback-hours")
        assert cmd[idx_lb + 1] == "48"
        idx_ol = cmd.index("--overlap-seconds")
        assert cmd[idx_ol + 1] == "1200"
        # lookback mode must not emit --window-start / --window-end
        assert "--window-start" not in cmd
        assert "--window-end" not in cmd

    def test_custom_mode_generates_window_args(self, builder: CLIBuilder):
        """custom mode emits explicit --window-start/--window-end only."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            window_mode="custom",
            window_start="2026-01-01",
            window_end="2026-01-31",
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx_s = cmd.index("--window-start")
        assert cmd[idx_s + 1] == "2026-01-01"
        idx_e = cmd.index("--window-end")
        assert cmd[idx_e + 1] == "2026-01-31"
        # custom mode must not emit --lookback-hours / --overlap-seconds
        assert "--lookback-hours" not in cmd
        assert "--overlap-seconds" not in cmd

    def test_window_split_with_days(self, builder: CLIBuilder):
        """window_split with a day count emits both split flags, stringified."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            window_split="day",
            window_split_days=10,
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--window-split")
        assert cmd[idx + 1] == "day"
        idx_d = cmd.index("--window-split-days")
        assert cmd[idx_d + 1] == "10"

    def test_window_split_none_not_generated(self, builder: CLIBuilder):
        """window_split='none' must not emit a --window-split argument."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"], window_split="none")
        cmd = builder.build_command(config, ETL_PATH)
        assert "--window-split" not in cmd
# ---------------------------------------------------------------------------
# store_id auto-injection
# ---------------------------------------------------------------------------
class TestStoreId:
    """--store-id is emitted only when store_id is set."""

    def test_store_id_injected(self, builder: CLIBuilder):
        cfg = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=42)
        argv = builder.build_command(cfg, ETL_PATH)
        assert argv[argv.index("--store-id") + 1] == "42"

    def test_store_id_none_not_generated(self, builder: CLIBuilder):
        cfg = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=None)
        assert "--store-id" not in builder.build_command(cfg, ETL_PATH)
# ---------------------------------------------------------------------------
# dry_run
# ---------------------------------------------------------------------------
class TestDryRun:
    """--dry-run mirrors the dry_run config switch."""

    def test_dry_run_flag(self, builder: CLIBuilder):
        cfg = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=True)
        assert "--dry-run" in builder.build_command(cfg, ETL_PATH)

    def test_no_dry_run_flag(self, builder: CLIBuilder):
        cfg = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=False)
        assert "--dry-run" not in builder.build_command(cfg, ETL_PATH)
# ---------------------------------------------------------------------------
# extra_args
# ---------------------------------------------------------------------------
class TestExtraArgs:
    """Pass-through of whitelisted extra args; unknown/None/False are dropped."""

    def test_supported_value_arg(self, builder: CLIBuilder):
        """A whitelisted key with a value becomes --key value (snake_case -> kebab-case)."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"pg_dsn": "postgresql://localhost/test"},
        )
        cmd = builder.build_command(config, ETL_PATH)
        idx = cmd.index("--pg-dsn")
        assert cmd[idx + 1] == "postgresql://localhost/test"

    def test_supported_bool_arg(self, builder: CLIBuilder):
        """A whitelisted True boolean becomes a bare flag with no value."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"force_window_override": True},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--force-window-override" in cmd

    def test_unsupported_arg_ignored(self, builder: CLIBuilder):
        """Keys outside CLI_SUPPORTED_ARGS must be ignored."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"unknown_param": "value"},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--unknown-param" not in cmd

    def test_none_value_ignored(self, builder: CLIBuilder):
        """A None value suppresses the argument entirely."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"pg_dsn": None},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--pg-dsn" not in cmd

    def test_false_bool_arg_not_generated(self, builder: CLIBuilder):
        """A False boolean must not emit the flag."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"force_window_override": False},
        )
        cmd = builder.build_command(config, ETL_PATH)
        assert "--force-window-override" not in cmd
# ---------------------------------------------------------------------------
# build_command_string
# ---------------------------------------------------------------------------
class TestBuildCommandString:
    """String rendering of the argv for display/logging purposes."""

    def test_returns_string(self, builder: CLIBuilder):
        """The rendered command is a single string starting with the argv head."""
        config = TaskConfigSchema(tasks=["ODS_MEMBER"])
        result = builder.build_command_string(config, ETL_PATH)
        assert isinstance(result, str)
        assert "python -m cli.main" in result

    def test_quotes_args_with_spaces(self, builder: CLIBuilder):
        """Values containing spaces are wrapped in double quotes."""
        config = TaskConfigSchema(
            tasks=["ODS_MEMBER"],
            extra_args={"pg_dsn": "host=localhost dbname=test"},
        )
        result = builder.build_command_string(config, ETL_PATH)
        # A value containing spaces must be quoted in the rendered string.
        assert '"host=localhost dbname=test"' in result

View File

@@ -0,0 +1,94 @@
"""
数据库连接模块单元测试。
覆盖ETL 只读连接的创建、RLS site_id 设置、只读模式、异常处理。
"""
import os
from unittest.mock import MagicMock, call, patch
import pytest
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from app.database import get_etl_readonly_connection
# ---------------------------------------------------------------------------
# get_etl_readonly_connection
# ---------------------------------------------------------------------------
class TestGetEtlReadonlyConnection:
    """ETL read-only connection: connect params, read-only SET, RLS isolation."""

    @patch("app.database.psycopg2.connect")
    def test_sets_readonly_and_site_id(self, mock_connect):
        """After connecting, SET read-only then SET LOCAL site_id run in order."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn
        conn = get_etl_readonly_connection(site_id=42)
        # autocommit must be disabled so both SETs share one transaction
        assert mock_conn.autocommit is False
        # the two SET statements executed, in this order
        executed = [c.args[0] for c in mock_cursor.execute.call_args_list]
        assert "SET default_transaction_read_only = on" in executed[0]
        assert "SET LOCAL app.current_site_id" in executed[1]
        # site_id is passed as a bound parameter (SQL-injection safe)
        site_id_call = mock_cursor.execute.call_args_list[1]
        assert site_id_call.args[1] == ("42",)
        # the transaction holding the SETs is committed
        mock_conn.commit.assert_called_once()
        assert conn is mock_conn

    @patch("app.database.psycopg2.connect")
    def test_accepts_string_site_id(self, mock_connect):
        """A string site_id must work the same as an int."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn
        get_etl_readonly_connection(site_id="99")
        site_id_call = mock_cursor.execute.call_args_list[1]
        assert site_id_call.args[1] == ("99",)

    @patch("app.database.psycopg2.connect")
    def test_closes_connection_on_setup_error(self, mock_connect):
        """If a SET statement fails, the connection is closed and the error re-raised."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_cursor.execute.side_effect = Exception("SET failed")
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn
        with pytest.raises(Exception, match="SET failed"):
            get_etl_readonly_connection(site_id=1)
        mock_conn.close.assert_called_once()

    @patch("app.database.psycopg2.connect")
    def test_uses_etl_config_params(self, mock_connect):
        """The ETL_DB_* settings must drive the connection parameters."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_connect.return_value = mock_conn
        get_etl_readonly_connection(site_id=1)
        connect_kwargs = mock_connect.call_args.kwargs
        # the ETL database name defaults to etl_feiqiu
        assert connect_kwargs["dbname"] == "etl_feiqiu"

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""数据库查看器属性测试Property-Based Testing
使用 hypothesis 验证数据库查看器的通用正确性属性:
- Property 17: SQL 写操作拦截
- Property 18: SQL 查询结果行数限制
测试策略:
- Property 17: 生成包含写操作关键词(随机大小写混合)的 SQL 字符串,
验证 _WRITE_KEYWORDS 正则表达式能匹配到
- Property 18: 生成随机长度的行列表(可能超过 1000 行),
验证截取前 _MAX_ROWS 个元素后长度 <= 1000
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-db-viewer-properties")
from hypothesis import given, settings
from hypothesis import strategies as st
from app.routers.db_viewer import _WRITE_KEYWORDS, _MAX_ROWS
# ---------------------------------------------------------------------------
# Shared strategies
# ---------------------------------------------------------------------------
# Write-operation keywords the viewer must block.
_WRITE_OPS = ["INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE"]
# SQL prefix/suffix filler: random text of letters/digits/symbols (NUL
# excluded). NOTE(review): this text may incidentally spell a write keyword;
# the property below only requires at least one keyword to be present, so
# that does not affect correctness.
_sql_filler_st = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "S"),
        blacklist_characters="\x00",
    ),
    min_size=0,
    max_size=50,
)
# A write-operation keyword with randomized upper/lower casing per character.
_random_case_keyword_st = st.sampled_from(_WRITE_OPS).flatmap(
    lambda kw: st.tuples(
        st.just(kw),
        st.lists(
            st.booleans(),
            min_size=len(kw),
            max_size=len(kw),
        ),
    ).map(
        lambda pair: "".join(
            c.upper() if flag else c.lower()
            for c, flag in zip(pair[0], pair[1])
        )
    )
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 17: SQL write-operation interception
# **Validates: Requirements 7.5**
# ---------------------------------------------------------------------------
@settings(max_examples=200)
@given(
    prefix=_sql_filler_st,
    keyword=_random_case_keyword_st,
    suffix=_sql_filler_st,
)
def test_write_keywords_always_detected(prefix, keyword, suffix):
    """Property 17: SQL write operations are intercepted.

    Any SQL statement containing INSERT/UPDATE/DELETE/DROP/TRUNCATE
    (case-insensitively) must be matched by the _WRITE_KEYWORDS regex.

    Strategy: place one randomly-cased write keyword between a random prefix
    and suffix, separated by spaces so the \\b word boundary can match.
    """
    # Spaces guarantee word-boundary matches around the keyword.
    sql = f"{prefix} {keyword} {suffix}"
    match = _WRITE_KEYWORDS.search(sql)
    assert match is not None, (
        f"正则表达式未能匹配到写操作关键词sql={sql!r}, keyword={keyword!r}"
    )
    # The matched keyword (uppercased) must be a known write operation.
    assert match.group(1).upper() in _WRITE_OPS, (
        f"匹配到的关键词 '{match.group(1)}' 不在写操作列表中"
    )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 18: SQL result row-count limit
# **Validates: Requirements 7.4**
# ---------------------------------------------------------------------------
# A simulated database row: a short list of primitive values.
_row_st = st.lists(
    st.one_of(st.integers(), st.text(max_size=20), st.none()),
    min_size=1,
    max_size=5,
)
# Row-list strategy: 0 to 3000 rows, so exceeding _MAX_ROWS is covered.
_rows_st = st.lists(_row_st, min_size=0, max_size=3000)


@settings(max_examples=200)
@given(rows=_rows_st)
def test_row_count_never_exceeds_max(rows):
    """Property 18: query results are limited to at most _MAX_ROWS rows.

    For a row list of any length, taking the first _MAX_ROWS elements must
    yield at most 1000 rows — equivalent to cur.fetchmany(_MAX_ROWS).

    FIX: the last assertion message was missing the separator between its two
    f-string fragments, producing e.g. "……恰好为 1000实际 5"; a Chinese comma
    is added to match the sibling assertion above.
    """
    # Simulate fetchmany(_MAX_ROWS): the cursor returns at most _MAX_ROWS rows.
    truncated = rows[:_MAX_ROWS]
    assert len(truncated) <= _MAX_ROWS, (
        f"截取后行数 {len(truncated)} 超过上限 {_MAX_ROWS}"
    )
    # If the input fits within the limit, nothing may be dropped.
    if len(rows) <= _MAX_ROWS:
        assert len(truncated) == len(rows), (
            f"原始行数 {len(rows)} <= {_MAX_ROWS},截取后应保留全部,"
            f"实际 {len(truncated)}"
        )
    # If the input exceeds the limit, exactly _MAX_ROWS rows remain.
    if len(rows) > _MAX_ROWS:
        assert len(truncated) == _MAX_ROWS, (
            f"原始行数 {len(rows)} > {_MAX_ROWS},截取后应恰好为 {_MAX_ROWS}"
            f"实际 {len(truncated)}"
        )

View File

@@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
"""数据库查看器路由单元测试
覆盖 4 个端点:
- GET /api/db/schemas
- GET /api/db/schemas/{name}/tables
- GET /api/db/tables/{schema}/{table}/columns
- POST /api/db/query
通过 mock 绕过数据库连接,专注路由逻辑验证。
"""
import os

# Must be set before app.main is imported: the backend config reads the JWT
# secret at import time.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
from psycopg2 import errors as pg_errors
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app

# Fixed test identity injected in place of real JWT authentication.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """FastAPI dependency override: always authenticate as the test user."""
    return _TEST_USER


app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Dotted path of the ETL read-only connection factory patched in these tests.
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
def _make_mock_conn(rows, description=None):
"""构造 mock 数据库连接cursor 返回指定行和列描述。"""
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.fetchall.return_value = rows
mock_cur.fetchmany.return_value = rows
mock_cur.description = description
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cur
# ---------------------------------------------------------------------------
# GET /api/db/schemas
# ---------------------------------------------------------------------------
class TestListSchemas:
    """Tests for the schema-listing endpoint."""

    @patch(_MOCK_CONN)
    def test_returns_schema_list(self, mock_get_conn):
        # Three schemas returned by the mocked cursor.
        conn, cur = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
        mock_get_conn.return_value = conn
        resp = client.get("/api/db/schemas")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 3
        assert data[0]["name"] == "dwd"
        assert data[2]["name"] == "ods"
        # Verify the JWT-derived site_id is forwarded to the connection
        # factory, and that the connection is closed afterwards.
        mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
        conn.close.assert_called_once()

    @patch(_MOCK_CONN)
    def test_empty_schemas(self, mock_get_conn):
        # No schemas: endpoint must return an empty JSON array, not an error.
        conn, cur = _make_mock_conn([])
        mock_get_conn.return_value = conn
        resp = client.get("/api/db/schemas")
        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/schemas/{name}/tables
# ---------------------------------------------------------------------------
class TestListTables:
    """Tests for the per-schema table-listing endpoint."""

    @patch(_MOCK_CONN)
    def test_returns_tables_with_row_count(self, mock_get_conn):
        conn, cur = _make_mock_conn([
            ("dim_member", 1500),
            ("fact_order", 32000),
        ])
        mock_get_conn.return_value = conn
        resp = client.get("/api/db/schemas/dwd/tables")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 2
        assert data[0]["name"] == "dim_member"
        assert data[0]["row_count"] == 1500
        assert data[1]["name"] == "fact_order"
        assert data[1]["row_count"] == 32000

    @patch(_MOCK_CONN)
    def test_null_row_count(self, mock_get_conn):
        """pg_stat_user_tables may lack statistics; row_count is then None."""
        conn, cur = _make_mock_conn([("new_table", None)])
        mock_get_conn.return_value = conn
        resp = client.get("/api/db/schemas/ods/tables")
        assert resp.status_code == 200
        data = resp.json()
        assert data[0]["row_count"] is None

    @patch(_MOCK_CONN)
    def test_empty_schema(self, mock_get_conn):
        # A schema with no tables yields an empty list.
        conn, cur = _make_mock_conn([])
        mock_get_conn.return_value = conn
        resp = client.get("/api/db/schemas/empty_schema/tables")
        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/tables/{schema}/{table}/columns
# ---------------------------------------------------------------------------
class TestListColumns:
    """Tests for the column-definition endpoint."""

    @patch(_MOCK_CONN)
    def test_returns_column_definitions(self, mock_get_conn):
        # Rows mirror information_schema.columns:
        # (name, data_type, is_nullable YES/NO, column_default)
        conn, cur = _make_mock_conn([
            ("id", "bigint", "NO", None),
            ("name", "character varying", "YES", None),
            ("created_at", "timestamp with time zone", "NO", "now()"),
        ])
        mock_get_conn.return_value = conn
        resp = client.get("/api/db/tables/dwd/dim_member/columns")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 3
        assert data[0]["name"] == "id"
        assert data[0]["data_type"] == "bigint"
        # "NO"/"YES" strings are converted to real booleans by the router.
        assert data[0]["is_nullable"] is False
        assert data[0]["column_default"] is None
        assert data[1]["is_nullable"] is True
        assert data[2]["column_default"] == "now()"

    @patch(_MOCK_CONN)
    def test_empty_table(self, mock_get_conn):
        # An unknown table simply yields no columns, not a 404.
        conn, cur = _make_mock_conn([])
        mock_get_conn.return_value = conn
        resp = client.get("/api/db/tables/dwd/nonexistent/columns")
        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# POST /api/db/query
# ---------------------------------------------------------------------------
class TestExecuteQuery:
    """Tests for the read-only ad-hoc SQL endpoint: success paths, write-op
    blocking, empty/invalid SQL, error mapping, row limit and timeout."""

    @patch(_MOCK_CONN)
    def test_successful_select(self, mock_get_conn):
        # cursor.description provides the column names of the result set.
        description = [("id",), ("name",)]
        conn, cur = _make_mock_conn(
            [(1, "Alice"), (2, "Bob")],
            description=description,
        )
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id", "name"]
        # Tuples are serialized as JSON arrays.
        assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
        assert data["row_count"] == 2

    @patch(_MOCK_CONN)
    def test_empty_result(self, mock_get_conn):
        # Columns are still reported even when no rows come back.
        description = [("id",)]
        conn, cur = _make_mock_conn([], description=description)
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id"]
        assert data["rows"] == []
        assert data["row_count"] == 0

    # ── write-operation blocking ──
    @pytest.mark.parametrize("keyword", [
        "INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
        "insert", "update", "delete", "drop", "truncate",
        "Insert", "Update", "Delete", "Drop", "Truncate",
    ])
    def test_blocks_write_operations(self, keyword):
        # Every write keyword must be rejected regardless of case, before any
        # database connection is made.
        resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
        assert resp.status_code == 400
        assert "只读" in resp.json()["detail"] or "禁止" in resp.json()["detail"]

    def test_blocks_mixed_case_write(self):
        resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
        assert resp.status_code == 400

    def test_blocks_write_in_subquery(self):
        """A write keyword anywhere in the SQL must be blocked."""
        resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
        assert resp.status_code == 400

    # ── empty SQL ──
    def test_empty_sql(self):
        resp = client.post("/api/db/query", json={"sql": ""})
        assert resp.status_code == 400

    def test_whitespace_only_sql(self):
        resp = client.post("/api/db/query", json={"sql": "   "})
        assert resp.status_code == 400

    # ── SQL syntax errors ──
    @patch(_MOCK_CONN)
    def test_sql_syntax_error(self, mock_get_conn):
        conn = MagicMock()
        mock_cur = MagicMock()
        # First execute() (statement_timeout setup) succeeds, the second (the
        # user query) raises — simulating a server-side syntax error.
        mock_cur.execute.side_effect = [None, Exception("syntax error at or near \"SELEC\"")]
        mock_cur.description = None
        conn.cursor.return_value.__enter__ = lambda s: mock_cur
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
        assert resp.status_code == 400
        assert "SQL 执行错误" in resp.json()["detail"]

    # ── query timeout ──
    @patch(_MOCK_CONN)
    def test_query_timeout(self, mock_get_conn):
        conn = MagicMock()
        mock_cur = MagicMock()
        # QueryCanceled is what psycopg2 raises when statement_timeout fires;
        # the router maps it to HTTP 408.
        mock_cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
        mock_cur.description = None
        conn.cursor.return_value.__enter__ = lambda s: mock_cur
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
        assert resp.status_code == 408
        assert "超时" in resp.json()["detail"]

    # ── row-limit enforcement ──
    @patch(_MOCK_CONN)
    def test_row_limit(self, mock_get_conn):
        """fetchmany must be called with the 1000-row cap."""
        description = [("id",)]
        conn, cur = _make_mock_conn(
            [(i,) for i in range(1000)],
            description=description,
        )
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
        assert resp.status_code == 200
        # The limit is applied at the cursor, not post-hoc in Python.
        cur.fetchmany.assert_called_once_with(1000)

    # ── timeout setup verification ──
    @patch(_MOCK_CONN)
    def test_sets_statement_timeout(self, mock_get_conn):
        """A statement_timeout must be set before the user query runs."""
        description = [("id",)]
        conn, cur = _make_mock_conn([(1,)], description=description)
        mock_get_conn.return_value = conn
        client.post("/api/db/query", json={"sql": "SELECT 1"})
        # The very first execute() call is the timeout setup statement.
        first_call = cur.execute.call_args_list[0]
        assert "statement_timeout" in first_call[0][0]
# ---------------------------------------------------------------------------
# Authentication tests
# ---------------------------------------------------------------------------
class TestDbViewerAuth:
    """Verify that every db-viewer endpoint is protected by authentication."""

    def test_requires_auth(self):
        """With the auth override removed, all endpoints must return 401/403."""
        saved = app.dependency_overrides.pop(get_current_user, None)
        try:
            for method, url in (
                ("GET", "/api/db/schemas"),
                ("GET", "/api/db/schemas/dwd/tables"),
                ("GET", "/api/db/tables/dwd/dim_member/columns"),
                ("POST", "/api/db/query"),
            ):
                # POST needs a valid body so validation doesn't mask the
                # authentication failure.
                kwargs = {"json": {"sql": "SELECT 1"}} if method == "POST" else {}
                resp = client.request(method, url, **kwargs)
                assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
        finally:
            # Restore the override for subsequent tests.
            if saved:
                app.dependency_overrides[get_current_user] = saved

View File

@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""环境配置属性测试Property-Based Testing
使用 hypothesis 验证环境配置管理的通用正确性属性:
- Property 15: .env 解析与敏感值掩码
- Property 16: .env 写入往返一致性
测试策略:
- Property 15: 生成随机 .env 内容(含敏感和非敏感键),验证 _parse_env + _is_sensitive 对敏感值掩码
- Property 16: 生成随机键值对,序列化为 .env 格式后再解析,验证往返一致性
"""
import os

# Must be set before importing app modules: the backend config reads the JWT
# secret at import time.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-env-config-properties")

from hypothesis import given, settings, assume
from hypothesis import strategies as st

from app.routers.env_config import _parse_env, _is_sensitive, _MASK, _SENSITIVE_KEYWORDS
# ---------------------------------------------------------------------------
# Shared strategies
# ---------------------------------------------------------------------------
# Valid environment variable key: a letter or underscore followed by letters,
# digits or underscores.
_key_start_char = st.sampled_from(
    list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_")
)
_key_rest_char = st.sampled_from(
    list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_")
)
_env_key_st = st.builds(
    lambda first, rest: first + rest,
    first=_key_start_char,
    rest=st.text(alphabet=_key_rest_char, min_size=0, max_size=30),
)
# Value: printable text without newlines; quotes and '#' are excluded to
# avoid .env parsing ambiguity.
_env_value_st = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S"),
        blacklist_characters='\n\r"\'#',
    ),
    min_size=0,
    max_size=50,
)
# Sensitive key: a random key name with a sensitive keyword embedded in it.
_sensitive_keyword_st = st.sampled_from(list(_SENSITIVE_KEYWORDS))
# NOTE(review): _sensitive_key_st is currently unused by the tests below
# (_safe_sensitive_key_st is used instead); kept for completeness.
# Fix: the original filter `len(k) > 0 and k[0].isalpha() or k[0] == "_"`
# had an and/or precedence bug — for an empty string the `or` clause still
# evaluated `k[0]` and raised IndexError instead of filtering the value out.
# The parenthesized form below keeps the intended "starts with a letter or
# underscore" semantics and is safe for empty input.
_sensitive_key_st = st.builds(
    lambda prefix, kw, suffix: prefix + kw + suffix,
    prefix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
    kw=_sensitive_keyword_st,
    suffix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
).filter(lambda k: bool(k) and (k[0].isalpha() or k[0] == "_"))
# Sensitive keys guaranteed to start with a letter: a fixed prefix joined to
# a sensitive keyword with an underscore.
_safe_sensitive_key_st = st.builds(
    lambda prefix, kw: prefix + "_" + kw,
    prefix=st.sampled_from(["DB", "API", "ETL", "APP", "MY"]),
    kw=_sensitive_keyword_st,
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 15: .env parsing and sensitive masking
# **Validates: Requirements 6.1, 6.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    sensitive_keys=st.lists(_safe_sensitive_key_st, min_size=1, max_size=5, unique=True),
    sensitive_values=st.lists(
        st.text(min_size=1, max_size=30, alphabet=st.characters(
            whitelist_categories=("L", "N"),
        )),
        min_size=1, max_size=5,
    ),
    normal_keys=st.lists(_env_key_st, min_size=1, max_size=5, unique=True),
    normal_values=st.lists(_env_value_st, min_size=1, max_size=5),
)
def test_sensitive_values_masked(sensitive_keys, sensitive_values, normal_keys, normal_values):
    """Property 15: .env parsing and sensitive value masking.

    For .env content containing sensitive keys (PASSWORD, TOKEN, SECRET,
    DSN), the key/value list returned by the API must carry the mask in
    place of those values, and never the original sensitive value.
    """
    # Make sure sensitive and normal key sets do not overlap.
    normal_keys_filtered = [k for k in normal_keys if k not in sensitive_keys]
    assume(len(normal_keys_filtered) >= 1)
    # Align value lists to key lists by cyclic repetition.
    s_vals = (sensitive_values * ((len(sensitive_keys) // len(sensitive_values)) + 1))[:len(sensitive_keys)]
    n_vals = (normal_values * ((len(normal_keys_filtered) // len(normal_values)) + 1))[:len(normal_keys_filtered)]
    # Build the .env content.
    lines = []
    for k, v in zip(sensitive_keys, s_vals):
        lines.append(f"{k}={v}")
    for k, v in zip(normal_keys_filtered, n_vals):
        lines.append(f"{k}={v}")
    env_content = "\n".join(lines) + "\n"
    # Parse.
    parsed = _parse_env(env_content)
    entries = [line for line in parsed if line["type"] == "entry"]
    # Reproduce the GET endpoint's masking logic.
    masked_entries = {}
    for entry in entries:
        if _is_sensitive(entry["key"]):
            masked_entries[entry["key"]] = _MASK
        else:
            masked_entries[entry["key"]] = entry["value"]
    # Sensitive keys must be masked.
    for k, v in zip(sensitive_keys, s_vals):
        assert k in masked_entries, f"敏感键 {k} 应出现在解析结果中"
        assert masked_entries[k] == _MASK, (
            f"敏感键 {k} 的值应为掩码 '{_MASK}',实际为 '{masked_entries[k]}'"
        )
        # The original sensitive value must never survive masking.
        assert masked_entries[k] != v, (
            f"敏感键 {k} 的原始值 '{v}' 不应出现在掩码结果中"
        )
    # Non-sensitive keys must keep their original values.
    for k, v in zip(normal_keys_filtered, n_vals):
        if not _is_sensitive(k):
            assert k in masked_entries, f"普通键 {k} 应出现在解析结果中"
            assert masked_entries[k] == v, (
                f"普通键 {k} 的值应为 '{v}',实际为 '{masked_entries[k]}'"
            )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 16: .env write/read round-trip
# **Validates: Requirements 6.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    entries=st.lists(
        st.tuples(_env_key_st, _env_value_st),
        min_size=1,
        max_size=10,
        unique_by=lambda t: t[0],  # keys are unique
    ),
)
def test_env_write_read_round_trip(entries):
    """Property 16: .env write/read round-trip consistency.

    A set of valid key/value pairs (no comments or blank lines) written as
    .env content and then parsed must yield an equivalent set of pairs.
    """
    # Strip values up front: surrounding whitespace is stripped by the parser
    # anyway, so this removes a parsing ambiguity from generated inputs.
    clean_entries = [(k, v.strip()) for k, v in entries]
    # Drop empty keys (the strategy already guarantees non-empty, defensive).
    clean_entries = [(k, v) for k, v in clean_entries if k]
    assume(len(clean_entries) >= 1)
    # Simulate the write: build the .env file content.
    lines = [f"{k}={v}" for k, v in clean_entries]
    env_content = "\n".join(lines) + "\n"
    # Parse it back.
    parsed = _parse_env(env_content)
    parsed_entries = {
        line["key"]: line["value"]
        for line in parsed
        if line["type"] == "entry"
    }
    # Round-trip: every written pair must appear in the parsed result.
    for k, v in clean_entries:
        assert k in parsed_entries, (
            f"'{k}' 应出现在解析结果中,实际键集合: {list(parsed_entries.keys())}"
        )
        assert parsed_entries[k] == v, (
            f"'{k}' 的值不一致:写入='{v}',解析='{parsed_entries[k]}'"
        )
    # The parsed key count must equal the written count (nothing added/lost).
    assert len(parsed_entries) == len(clean_entries), (
        f"解析结果键数量 {len(parsed_entries)} 应等于写入数量 {len(clean_entries)}"
    )

View File

@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
"""环境配置路由单元测试
覆盖 3 个端点GET / PUT / GET /export
通过 mock 绕过文件 I/O专注路由逻辑验证。
"""
import os

# Must be set before app.main is imported: the backend config reads the JWT
# secret at import time.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from unittest.mock import patch, MagicMock
import pytest  # NOTE(review): appears unused in this file's visible tests
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app

# Fixed test identity injected in place of real JWT authentication.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """FastAPI dependency override: always authenticate as the test user."""
    return _TEST_USER


app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
# 模拟 .env 文件内容
_SAMPLE_ENV = """\
# 数据库配置
DB_HOST=localhost
DB_PORT=5432
DB_PASSWORD=super_secret_123
JWT_SECRET_KEY=my-jwt-secret
# ETL 配置
ETL_DB_DSN=postgresql://user:pass@host/db
TIMEZONE=Asia/Shanghai
"""
_MOCK_ENV_PATH = "app.routers.env_config._ENV_PATH"
def _mock_path(content: str | None = _SAMPLE_ENV, exists: bool = True):
"""构造 mock Path 对象。"""
mock = MagicMock()
mock.exists.return_value = exists
if content is not None:
mock.read_text.return_value = content
return mock
# ---------------------------------------------------------------------------
# GET /api/env-config
# ---------------------------------------------------------------------------
class TestGetEnvConfig:
    """Tests for reading the .env configuration with sensitive masking."""

    @patch(_MOCK_ENV_PATH)
    def test_returns_entries_with_masked_sensitive(self, mock_path_obj):
        # NOTE(review): this __class__ assignment looks redundant — the other
        # tests in this class configure the mock without it; confirm and drop.
        mock_path_obj.__class__ = type(MagicMock())
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = _SAMPLE_ENV
        resp = client.get("/api/env-config")
        assert resp.status_code == 200
        data = resp.json()
        entries = {e["key"]: e["value"] for e in data["entries"]}
        # Non-sensitive values are returned verbatim.
        assert entries["DB_HOST"] == "localhost"
        assert entries["DB_PORT"] == "5432"
        assert entries["TIMEZONE"] == "Asia/Shanghai"
        # Sensitive values are masked.
        assert entries["DB_PASSWORD"] == "****"
        assert entries["JWT_SECRET_KEY"] == "****"
        assert entries["ETL_DB_DSN"] == "****"

    @patch(_MOCK_ENV_PATH)
    def test_file_not_found(self, mock_path_obj):
        # A missing .env file maps to HTTP 404.
        mock_path_obj.exists.return_value = False
        resp = client.get("/api/env-config")
        assert resp.status_code == 404

    @patch(_MOCK_ENV_PATH)
    def test_empty_file(self, mock_path_obj):
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = ""
        resp = client.get("/api/env-config")
        assert resp.status_code == 200
        assert resp.json()["entries"] == []

    @patch(_MOCK_ENV_PATH)
    def test_comments_and_blank_lines_excluded(self, mock_path_obj):
        # Only real KEY=value entries are returned; comments/blank lines are
        # preserved in the file but not exposed as entries.
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "# comment\n\nKEY=val\n"
        resp = client.get("/api/env-config")
        assert resp.status_code == 200
        entries = resp.json()["entries"]
        assert len(entries) == 1
        assert entries[0]["key"] == "KEY"
# ---------------------------------------------------------------------------
# PUT /api/env-config
# ---------------------------------------------------------------------------
class TestUpdateEnvConfig:
    """Tests for writing .env entries: updates, additions, mask handling,
    comment preservation and key validation."""

    @patch(_MOCK_ENV_PATH)
    def test_update_existing_key(self, mock_path_obj):
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "DB_HOST=localhost\nDB_PORT=5432\n"
        resp = client.put("/api/env-config", json={
            "entries": [
                {"key": "DB_HOST", "value": "192.168.1.1"},
                {"key": "DB_PORT", "value": "5433"},
            ]
        })
        assert resp.status_code == 200
        # Inspect what was written back to disk.
        written = mock_path_obj.write_text.call_args[0][0]
        assert "DB_HOST=192.168.1.1" in written
        assert "DB_PORT=5433" in written

    @patch(_MOCK_ENV_PATH)
    def test_add_new_key(self, mock_path_obj):
        # Keys absent from the file are appended.
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "DB_HOST=localhost\n"
        resp = client.put("/api/env-config", json={
            "entries": [
                {"key": "DB_HOST", "value": "localhost"},
                {"key": "NEW_KEY", "value": "new_value"},
            ]
        })
        assert resp.status_code == 200
        written = mock_path_obj.write_text.call_args[0][0]
        assert "NEW_KEY=new_value" in written

    @patch(_MOCK_ENV_PATH)
    def test_masked_value_preserves_original(self, mock_path_obj):
        """A masked value (****) must not overwrite the original secret."""
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "DB_PASSWORD=real_secret\nDB_HOST=localhost\n"
        resp = client.put("/api/env-config", json={
            "entries": [
                {"key": "DB_PASSWORD", "value": "****"},
                {"key": "DB_HOST", "value": "newhost"},
            ]
        })
        assert resp.status_code == 200
        written = mock_path_obj.write_text.call_args[0][0]
        # The original password is kept, while the non-masked key is updated.
        assert "DB_PASSWORD=real_secret" in written
        assert "DB_HOST=newhost" in written

    @patch(_MOCK_ENV_PATH)
    def test_preserves_comments(self, mock_path_obj):
        # Comment lines in the original file survive a rewrite.
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "# 注释行\nDB_HOST=localhost\n\n# 另一个注释\n"
        resp = client.put("/api/env-config", json={
            "entries": [{"key": "DB_HOST", "value": "newhost"}]
        })
        assert resp.status_code == 200
        written = mock_path_obj.write_text.call_args[0][0]
        assert "# 注释行" in written
        assert "# 另一个注释" in written

    def test_invalid_key_format(self):
        # Keys must not start with a digit → pydantic validation 422.
        resp = client.put("/api/env-config", json={
            "entries": [{"key": "123BAD", "value": "val"}]
        })
        assert resp.status_code == 422

    def test_empty_key(self):
        resp = client.put("/api/env-config", json={
            "entries": [{"key": "", "value": "val"}]
        })
        assert resp.status_code == 422

    @patch(_MOCK_ENV_PATH)
    def test_file_not_exists_creates_new(self, mock_path_obj):
        """When the file does not exist, a new one is created."""
        mock_path_obj.exists.return_value = False
        resp = client.put("/api/env-config", json={
            "entries": [{"key": "NEW_KEY", "value": "value"}]
        })
        assert resp.status_code == 200
        written = mock_path_obj.write_text.call_args[0][0]
        assert "NEW_KEY=value" in written

    @patch(_MOCK_ENV_PATH)
    def test_update_sensitive_with_new_value(self, mock_path_obj):
        """An explicit (non-masked) new secret must be written through."""
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = "DB_PASSWORD=old_secret\n"
        resp = client.put("/api/env-config", json={
            "entries": [{"key": "DB_PASSWORD", "value": "new_secret"}]
        })
        assert resp.status_code == 200
        written = mock_path_obj.write_text.call_args[0][0]
        assert "DB_PASSWORD=new_secret" in written
        # The response still masks sensitive keys even after the update.
        entries = {e["key"]: e["value"] for e in resp.json()["entries"]}
        assert entries["DB_PASSWORD"] == "****"
# ---------------------------------------------------------------------------
# GET /api/env-config/export
# ---------------------------------------------------------------------------
class TestExportEnvConfig:
    """Tests for the masked .env download endpoint."""

    @patch(_MOCK_ENV_PATH)
    def test_export_masks_sensitive(self, mock_path_obj):
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = _SAMPLE_ENV
        resp = client.get("/api/env-config/export")
        assert resp.status_code == 200
        # Served as a plain-text file download.
        assert resp.headers["content-type"].startswith("text/plain")
        assert "attachment" in resp.headers.get("content-disposition", "")
        content = resp.text
        # Non-sensitive values are exported verbatim.
        assert "DB_HOST=localhost" in content
        assert "TIMEZONE=Asia/Shanghai" in content
        # Sensitive values are masked and never leak.
        assert "super_secret_123" not in content
        assert "my-jwt-secret" not in content
        assert "DB_PASSWORD=****" in content
        assert "JWT_SECRET_KEY=****" in content

    @patch(_MOCK_ENV_PATH)
    def test_export_preserves_comments(self, mock_path_obj):
        mock_path_obj.exists.return_value = True
        mock_path_obj.read_text.return_value = _SAMPLE_ENV
        content = client.get("/api/env-config/export").text
        assert "# 数据库配置" in content
        assert "# ETL 配置" in content

    @patch(_MOCK_ENV_PATH)
    def test_export_file_not_found(self, mock_path_obj):
        mock_path_obj.exists.return_value = False
        resp = client.get("/api/env-config/export")
        assert resp.status_code == 404
# ---------------------------------------------------------------------------
# Authentication tests
# ---------------------------------------------------------------------------
class TestEnvConfigAuth:
    """Verify that every env-config endpoint is protected by authentication."""

    def test_requires_auth(self):
        """With the auth override removed, all endpoints must return 401/403."""
        # Temporarily drop the override installed at module import time.
        saved = app.dependency_overrides.pop(get_current_user, None)
        try:
            for method, url in (
                ("GET", "/api/env-config"),
                ("PUT", "/api/env-config"),
                ("GET", "/api/env-config/export"),
            ):
                resp = client.request(method, url)
                assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
        finally:
            # Restore the override for subsequent tests.
            if saved:
                app.dependency_overrides[get_current_user] = saved

View File

@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
"""ETL 状态路由单元测试
覆盖 2 个端点:
- GET /api/etl-status/cursors
- GET /api/etl-status/recent-runs
通过 mock 绕过数据库连接,专注路由逻辑验证。
"""
import os

# Must be set before app.main is imported: the backend config reads the JWT
# secret at import time.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app

# Fixed test identity injected in place of real JWT authentication.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """FastAPI dependency override: always authenticate as the test user."""
    return _TEST_USER


app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Dotted paths of the two connection factories patched in these tests:
# the ETL read-only connection and the business-app connection.
_MOCK_ETL_CONN = "app.routers.etl_status.get_etl_readonly_connection"
_MOCK_APP_CONN = "app.routers.etl_status.get_connection"
def _make_mock_conn(rows):
"""构造 mock 数据库连接cursor 返回指定行。"""
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.fetchall.return_value = rows
mock_cur.fetchone.return_value = None
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cur
# ---------------------------------------------------------------------------
# GET /api/etl-status/cursors
# ---------------------------------------------------------------------------
class TestListCursors:
    """Tests for the ETL cursor listing endpoint."""

    @patch(_MOCK_ETL_CONN)
    def test_returns_cursor_list(self, mock_get_conn):
        conn, cur = _make_mock_conn([
            ("ODS_FETCH_ORDERS", "2024-06-15 10:30:00+08", 1500),
            ("ODS_FETCH_MEMBERS", "2024-06-15 09:00:00+08", 800),
        ])
        # fetchone serves the table-EXISTS probe that runs before the query.
        cur.fetchone.return_value = (True,)
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 2
        assert data[0]["task_code"] == "ODS_FETCH_ORDERS"
        assert data[0]["last_fetch_time"] == "2024-06-15 10:30:00+08"
        assert data[0]["record_count"] == 1500
        assert data[1]["task_code"] == "ODS_FETCH_MEMBERS"
        # Verify site_id propagation and connection cleanup.
        mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
        conn.close.assert_called_once()

    @patch(_MOCK_ETL_CONN)
    def test_table_not_exists_returns_empty(self, mock_get_conn):
        """When etl_admin.etl_cursor is missing, return an empty list."""
        conn, cur = _make_mock_conn([])
        cur.fetchone.return_value = (False,)
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch(_MOCK_ETL_CONN)
    def test_null_fields(self, mock_get_conn):
        """Cursor fields can be None for tasks that never ran."""
        conn, cur = _make_mock_conn([
            ("ODS_FETCH_INVENTORY", None, None),
        ])
        cur.fetchone.return_value = (True,)
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200
        data = resp.json()
        assert data[0]["task_code"] == "ODS_FETCH_INVENTORY"
        assert data[0]["last_fetch_time"] is None
        assert data[0]["record_count"] is None

    @patch(_MOCK_ETL_CONN)
    def test_empty_cursors(self, mock_get_conn):
        """Table exists but holds no rows."""
        conn, cur = _make_mock_conn([])
        cur.fetchone.return_value = (True,)
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/cursors")
        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/etl-status/recent-runs
# ---------------------------------------------------------------------------
class TestListRecentRuns:
    """Tests for the recent ETL execution-history endpoint."""

    @patch(_MOCK_APP_CONN)
    def test_returns_recent_runs(self, mock_get_conn):
        # Row layout: (id, task_codes, status, started_at, finished_at,
        #              duration_ms, exit_code)
        conn, cur = _make_mock_conn([
            (
                "a1b2c3d4-0000-0000-0000-000000000001",
                ["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"],
                "success",
                "2024-06-15 10:30:00+08",
                "2024-06-15 10:35:00+08",
                300000,
                0,
            ),
            (
                "a1b2c3d4-0000-0000-0000-000000000002",
                ["DWS_AGGREGATE"],
                "failed",
                "2024-06-15 09:00:00+08",
                "2024-06-15 09:01:00+08",
                60000,
                1,
            ),
        ])
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 2
        run0 = data[0]
        assert run0["id"] == "a1b2c3d4-0000-0000-0000-000000000001"
        assert run0["task_codes"] == ["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"]
        assert run0["status"] == "success"
        assert run0["duration_ms"] == 300000
        assert run0["exit_code"] == 0
        run1 = data[1]
        assert run1["status"] == "failed"
        assert run1["exit_code"] == 1
        conn.close.assert_called_once()

    @patch(_MOCK_APP_CONN)
    def test_empty_runs(self, mock_get_conn):
        conn, cur = _make_mock_conn([])
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch(_MOCK_APP_CONN)
    def test_null_optional_fields(self, mock_get_conn):
        """Still-running tasks have None finished_at/duration_ms/exit_code."""
        conn, cur = _make_mock_conn([
            (
                "a1b2c3d4-0000-0000-0000-000000000003",
                ["ODS_FETCH_MEMBERS"],
                "running",
                "2024-06-15 11:00:00+08",
                None,
                None,
                None,
            ),
        ])
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200
        data = resp.json()
        assert data[0]["status"] == "running"
        assert data[0]["finished_at"] is None
        assert data[0]["duration_ms"] is None
        assert data[0]["exit_code"] is None

    @patch(_MOCK_APP_CONN)
    def test_site_id_filter(self, mock_get_conn):
        """The query must be parameterized with the caller's site_id."""
        conn, cur = _make_mock_conn([])
        mock_get_conn.return_value = conn
        client.get("/api/etl-status/recent-runs")
        # The SQL parameters are (site_id, limit=50).
        call_args = cur.execute.call_args
        params = call_args[0][1]
        assert params[0] == _TEST_USER.site_id
        assert params[1] == 50

    @patch(_MOCK_APP_CONN)
    def test_empty_task_codes(self, mock_get_conn):
        """A NULL task_codes column is normalized to an empty list."""
        conn, cur = _make_mock_conn([
            (
                "a1b2c3d4-0000-0000-0000-000000000004",
                None,
                "pending",
                "2024-06-15 12:00:00+08",
                None,
                None,
                None,
            ),
        ])
        mock_get_conn.return_value = conn
        resp = client.get("/api/etl-status/recent-runs")
        assert resp.status_code == 200
        assert resp.json()[0]["task_codes"] == []
# ---------------------------------------------------------------------------
# Authentication tests
# ---------------------------------------------------------------------------
class TestEtlStatusAuth:
    """Verify that every etl-status endpoint is protected by authentication."""

    def test_requires_auth(self):
        """With the auth override removed, all endpoints must return 401/403."""
        saved = app.dependency_overrides.pop(get_current_user, None)
        try:
            for url in ("/api/etl-status/cursors", "/api/etl-status/recent-runs"):
                resp = client.get(url)
                assert resp.status_code in (401, 403), f"GET {url} 应需要认证"
        finally:
            # Restore the override for subsequent tests.
            if saved:
                app.dependency_overrides[get_current_user] = saved

View File

@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""执行与队列路由单元测试
覆盖 8 个端点run / queue CRUD / cancel / history / logs
通过 mock 绕过数据库和服务层,专注路由逻辑验证。
"""
import os

# Must be set before app.main is imported: the backend config reads the JWT
# secret at import time.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from dataclasses import dataclass
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_queue import QueuedTask

# Fixed test identity injected in place of real JWT authentication.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """FastAPI dependency override: always authenticate as the test user."""
    return _TEST_USER


app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Fixed timestamp used for queue fixtures (timezone-aware, UTC).
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)

# Minimal valid TaskConfig payload used across the tests below.
_VALID_CONFIG = {
    "tasks": ["ODS_MEMBER"],
    "pipeline": "api_ods",
}
# ---------------------------------------------------------------------------
# POST /api/execution/run
# ---------------------------------------------------------------------------
class TestRunTask:
    """Tests for immediate task execution."""

    @patch("app.routers.execution.task_executor")
    def test_run_returns_execution_id(self, mock_executor):
        mock_executor.execute = AsyncMock()
        resp = client.post("/api/execution/run", json=_VALID_CONFIG)
        assert resp.status_code == 200
        data = resp.json()
        assert "execution_id" in data
        assert data["message"] == "任务已提交执行"

    @patch("app.routers.execution.task_executor")
    def test_run_injects_store_id(self, mock_executor):
        """store_id must come from the JWT, not the request body."""
        mock_executor.execute = AsyncMock()
        resp = client.post("/api/execution/run", json={
            **_VALID_CONFIG,
            "store_id": 999,  # client-supplied value must be overridden
        })
        assert resp.status_code == 200

    def test_run_requires_auth(self):
        # Temporarily remove the auth override so the endpoint rejects us.
        app.dependency_overrides.pop(get_current_user, None)
        try:
            resp = client.post("/api/execution/run", json=_VALID_CONFIG)
            assert resp.status_code in (401, 403)
        finally:
            # Always restore the override for subsequent tests.
            app.dependency_overrides[get_current_user] = _override_auth

    def test_run_invalid_config_returns_422(self):
        """Missing the required 'tasks' field yields a 422 validation error."""
        resp = client.post("/api/execution/run", json={"pipeline": "api_ods"})
        assert resp.status_code == 422
# ---------------------------------------------------------------------------
# GET /api/execution/queue
# ---------------------------------------------------------------------------
class TestGetQueue:
    """GET /api/execution/queue — list pending queued tasks."""

    @patch("app.routers.execution.task_queue")
    def test_get_queue_returns_list(self, queue):
        """Queued tasks are serialized into the response list."""
        task = QueuedTask(
            id="task-1", site_id=100, config={"tasks": ["ODS_MEMBER"]},
            status="pending", position=1, created_at=_NOW,
        )
        queue.list_pending.return_value = [task]

        response = client.get("/api/execution/queue")

        assert response.status_code == 200
        body = response.json()
        assert len(body) == 1
        assert body[0]["id"] == "task-1"
        assert body[0]["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_get_queue_empty(self, queue):
        """An empty queue yields an empty JSON list."""
        queue.list_pending.return_value = []

        response = client.get("/api/execution/queue")

        assert response.status_code == 200
        assert response.json() == []

    @patch("app.routers.execution.task_queue")
    def test_get_queue_filters_by_site_id(self, queue):
        """list_pending must receive the site_id taken from the JWT."""
        queue.list_pending.return_value = []
        client.get("/api/execution/queue")
        queue.list_pending.assert_called_once_with(100)
# ---------------------------------------------------------------------------
# POST /api/execution/queue
# ---------------------------------------------------------------------------
class TestEnqueueTask:
    """POST /api/execution/queue — persist a task into the execution queue."""
    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_queue")
    def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
        """A successful enqueue responds 201 with the stored row's fields."""
        mock_queue.enqueue.return_value = "new-task-id"
        # Mock the DB read that fetches the freshly inserted row.
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = (
            "new-task-id", 100, '{"tasks": ["ODS_MEMBER"]}',
            "pending", 1, _NOW, None, None, None, None,
        )
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
        assert resp.status_code == 201
        data = resp.json()
        assert data["id"] == "new-task-id"
        assert data["status"] == "pending"
    @patch("app.routers.execution.task_queue")
    def test_enqueue_calls_with_site_id(self, mock_queue):
        """enqueue must receive the site_id taken from the JWT."""
        mock_queue.enqueue.return_value = "id-1"
        # Provide a mocked DB read so the request can complete; the enqueue
        # call itself is what this test verifies.
        with patch("app.routers.execution.get_connection") as mock_conn:
            mock_cursor = MagicMock()
            mock_cursor.fetchone.return_value = (
                "id-1", 100, '{"tasks": []}', "pending", 1,
                _NOW, None, None, None, None,
            )
            conn = MagicMock()
            conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
            conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
            mock_conn.return_value = conn
            client.post("/api/execution/queue", json=_VALID_CONFIG)
        # The second positional argument to enqueue is site_id=100.
        call_args = mock_queue.enqueue.call_args
        assert call_args[0][1] == 100  # site_id
# ---------------------------------------------------------------------------
# PUT /api/execution/queue/reorder
# ---------------------------------------------------------------------------
class TestReorderQueue:
    """PUT /api/execution/queue/reorder — move a queued task."""

    @patch("app.routers.execution.task_queue")
    def test_reorder_success(self, queue):
        """A valid payload triggers exactly one reorder call with the JWT site_id."""
        queue.reorder.return_value = None
        payload = {"task_id": "task-1", "new_position": 3}

        response = client.put("/api/execution/queue/reorder", json=payload)

        assert response.status_code == 200
        queue.reorder.assert_called_once_with("task-1", 3, 100)

    def test_reorder_missing_fields_returns_422(self):
        """An empty body fails request validation."""
        response = client.put("/api/execution/queue/reorder", json={})
        assert response.status_code == 422
# ---------------------------------------------------------------------------
# DELETE /api/execution/queue/{id}
# ---------------------------------------------------------------------------
class TestDeleteQueueTask:
    """DELETE /api/execution/queue/{id} — remove a pending queued task."""

    @patch("app.routers.execution.task_queue")
    def test_delete_success(self, queue):
        """A deletable task responds 200 and delete receives the JWT site_id."""
        queue.delete.return_value = True

        response = client.delete("/api/execution/queue/task-1")

        assert response.status_code == 200
        queue.delete.assert_called_once_with("task-1", 100)

    @patch("app.routers.execution.task_queue")
    def test_delete_nonexistent_returns_409(self, queue):
        """A task that cannot be deleted responds 409."""
        queue.delete.return_value = False

        response = client.delete("/api/execution/queue/nonexistent")

        assert response.status_code == 409
# ---------------------------------------------------------------------------
# POST /api/execution/{id}/cancel
# ---------------------------------------------------------------------------
class TestCancelExecution:
    """POST /api/execution/{id}/cancel — cancel a running execution."""

    @patch("app.routers.execution.task_executor")
    def test_cancel_success(self, executor):
        """A cancellable execution responds 200 with a confirmation message."""
        executor.cancel = AsyncMock(return_value=True)

        response = client.post("/api/execution/some-id/cancel")

        assert response.status_code == 200
        assert "取消" in response.json()["message"]

    @patch("app.routers.execution.task_executor")
    def test_cancel_not_found(self, executor):
        """Cancelling an unknown execution responds 404."""
        executor.cancel = AsyncMock(return_value=False)

        response = client.post("/api/execution/nonexistent/cancel")

        assert response.status_code == 404
# ---------------------------------------------------------------------------
# GET /api/execution/history
# ---------------------------------------------------------------------------
class TestExecutionHistory:
    """GET /api/execution/history — completed execution records."""
    @patch("app.routers.execution.get_connection")
    def test_history_returns_list(self, mock_get_conn):
        """DB rows are serialized into the response list."""
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = [
            (
                "exec-1", 100, ["ODS_MEMBER"], "success",
                _NOW, _NOW, 0, 1234, "python -m cli.main", None,
            ),
        ]
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.get("/api/execution/history")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "exec-1"
        assert data[0]["status"] == "success"
    @patch("app.routers.execution.get_connection")
    def test_history_respects_limit(self, mock_get_conn):
        """The `limit` query parameter is forwarded to the SQL query."""
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = []
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.get("/api/execution/history?limit=10")
        assert resp.status_code == 200
        # Verify the SQL parameters carried limit=10.
        call_args = mock_cursor.execute.call_args
        assert call_args[0][1] == (100, 10)  # (site_id, limit)
    def test_history_limit_validation(self):
        """An out-of-range limit (0 or 999) is rejected with 422."""
        resp = client.get("/api/execution/history?limit=0")
        assert resp.status_code == 422
        resp = client.get("/api/execution/history?limit=999")
        assert resp.status_code == 422
    @patch("app.routers.execution.get_connection")
    def test_history_empty(self, mock_get_conn):
        """No rows in the DB yields an empty JSON list."""
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = []
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.get("/api/execution/history")
        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/execution/{id}/logs
# ---------------------------------------------------------------------------
class TestExecutionLogs:
    """GET /api/execution/{id}/logs — logs from DB (finished) or memory (running)."""
    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_from_db(self, mock_executor, mock_get_conn):
        """Finished executions read their logs from the database."""
        mock_executor.is_running.return_value = False
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = ("stdout output", "stderr output")
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.get("/api/execution/exec-1/logs")
        assert resp.status_code == 200
        data = resp.json()
        assert data["execution_id"] == "exec-1"
        assert data["output_log"] == "stdout output"
        assert data["error_log"] == "stderr output"
    @patch("app.routers.execution.task_executor")
    def test_logs_from_memory(self, mock_executor):
        """Running executions read their logs from the in-memory buffer."""
        mock_executor.is_running.return_value = True
        mock_executor.get_logs.return_value = ["line1", "line2"]
        resp = client.get("/api/execution/running-id/logs")
        assert resp.status_code == 200
        data = resp.json()
        assert data["execution_id"] == "running-id"
        assert "line1" in data["output_log"]
        assert "line2" in data["output_log"]
    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_not_found(self, mock_executor, mock_get_conn):
        """An id that is neither running nor stored responds 404."""
        mock_executor.is_running.return_value = False
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = None
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.get("/api/execution/nonexistent/logs")
        assert resp.status_code == 404

View File

@@ -0,0 +1,510 @@
# -*- coding: utf-8 -*-
"""队列属性测试Property-Based Testing
使用 hypothesis 验证队列管理的通用正确性属性:
- Property 8: 队列 CRUD 不变量
- Property 9: 队列出队顺序
- Property 10: 队列重排一致性
- Property 11: 执行历史排序与限制
测试策略:
- Property 8-10 通过内存模拟队列状态mock 数据库操作,验证 TaskQueue 的核心逻辑
- Property 11 通过 mock 数据库返回,验证执行历史端点的排序与限制逻辑
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-queue-properties")
import json
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from app.schemas.tasks import TaskConfigSchema
from app.services.task_queue import TaskQueue, QueuedTask
# ---------------------------------------------------------------------------
# Shared strategies
# ---------------------------------------------------------------------------
# Valid site ids (positive 32-bit integers).
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
# Small fixed pool of task codes.
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
# Minimal valid TaskConfigSchema instances.
_simple_config_st = st.builds(
    TaskConfigSchema,
    tasks=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
    pipeline=st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd"]),
)
# ---------------------------------------------------------------------------
# 内存队列模拟器 — 用于 mock 数据库交互
# ---------------------------------------------------------------------------
class InMemoryQueueDB:
    """In-memory stand-in for the ``task_queue`` table.

    Provides mock psycopg-style connections whose cursors emulate the exact
    SQL statements issued by ``TaskQueue`` methods, so queue logic can be
    property-tested without a real database.

    Change vs. original: removed dead bookkeeping locals (``call_count`` /
    ``executed_sqls`` in :meth:`mock_enqueue_connection`, ``call_idx`` in
    :meth:`mock_reorder_connection`) that were written but never read.
    """

    def __init__(self, site_id: int):
        self.site_id = site_id
        # Row storage: {task_id: {id, site_id, config, status, position}}
        self.rows: dict[str, dict] = {}

    @property
    def pending_tasks(self) -> list[dict]:
        """Pending rows ordered by ascending position."""
        return sorted(
            (r for r in self.rows.values() if r["status"] == "pending"),
            key=lambda r: r["position"],
        )

    def mock_enqueue_connection(self):
        """Mock connection for ``enqueue``.

        ``enqueue`` issues two statements:
        1. ``SELECT COALESCE(MAX(position), 0)`` -> current max position
        2. ``INSERT INTO task_queue`` -> persist the new row
        """
        max_pos = max((r["position"] for r in self.pending_tasks), default=0)
        db = self

        def make_cursor():
            cur = MagicMock()

            def execute_side_effect(sql, params=None):
                if "MAX(position)" in sql:
                    cur.fetchone.return_value = (max_pos,)
                elif "INSERT INTO task_queue" in sql:
                    # Record the inserted row in memory.
                    task_id, site_id, config_json, new_pos = params
                    db.rows[task_id] = {
                        "id": task_id,
                        "site_id": site_id,
                        "config": json.loads(config_json),
                        "status": "pending",
                        "position": new_pos,
                    }

            cur.execute = MagicMock(side_effect=execute_side_effect)
            cur.__enter__ = MagicMock(return_value=cur)
            cur.__exit__ = MagicMock(return_value=False)
            return cur

        conn = MagicMock()
        conn.cursor.return_value = make_cursor()
        return conn

    def mock_dequeue_connection(self):
        """Mock connection for ``dequeue``.

        ``dequeue`` issues two statements:
        1. ``SELECT ... ORDER BY position ASC LIMIT 1 FOR UPDATE`` -> head row
        2. ``UPDATE ... SET status = 'running'`` -> mark it running
        """
        pending = self.pending_tasks
        first = pending[0] if pending else None
        db = self

        def make_cursor():
            cur = MagicMock()

            def execute_side_effect(sql, params=None):
                if "ORDER BY position ASC" in sql:
                    if first:
                        cur.fetchone.return_value = (
                            first["id"], first["site_id"],
                            json.dumps(first["config"]),
                            first["status"], first["position"],
                            None, None, None, None, None,
                        )
                    else:
                        cur.fetchone.return_value = None
                elif "SET status = 'running'" in sql:
                    if first:
                        db.rows[first["id"]]["status"] = "running"

            cur.execute = MagicMock(side_effect=execute_side_effect)
            cur.__enter__ = MagicMock(return_value=cur)
            cur.__exit__ = MagicMock(return_value=False)
            return cur

        conn = MagicMock()
        conn.cursor.return_value = make_cursor()
        return conn

    def mock_delete_connection(self, task_id: str):
        """Mock connection for ``delete``.

        NOTE(review): the id actually deleted comes from the executed SQL
        params, not from *task_id*; the parameter is kept for call-site
        readability and interface compatibility.
        """
        db = self

        def make_cursor():
            cur = MagicMock()

            def execute_side_effect(sql, params=None):
                tid = params[0]
                # Only pending rows are deletable, mirroring the real SQL.
                if tid in db.rows and db.rows[tid]["status"] == "pending":
                    del db.rows[tid]
                    cur.rowcount = 1
                else:
                    cur.rowcount = 0

            cur.execute = MagicMock(side_effect=execute_side_effect)
            cur.rowcount = 0
            cur.__enter__ = MagicMock(return_value=cur)
            cur.__exit__ = MagicMock(return_value=False)
            return cur

        conn = MagicMock()
        conn.cursor.return_value = make_cursor()
        return conn

    def mock_reorder_connection(self):
        """Mock connection for ``reorder``.

        ``reorder`` issues:
        1. ``SELECT id FROM task_queue WHERE ... ORDER BY position ASC``
        2. repeated ``UPDATE task_queue SET position = %s WHERE id = %s``
        """
        pending = self.pending_tasks
        db = self

        def make_cursor():
            cur = MagicMock()

            def execute_side_effect(sql, params=None):
                if "SELECT id FROM task_queue" in sql:
                    cur.fetchall.return_value = [(r["id"],) for r in pending]
                elif "UPDATE task_queue SET position" in sql:
                    pos, tid = params
                    if tid in db.rows:
                        db.rows[tid]["position"] = pos

            cur.execute = MagicMock(side_effect=execute_side_effect)
            cur.__enter__ = MagicMock(return_value=cur)
            cur.__exit__ = MagicMock(return_value=False)
            return cur

        conn = MagicMock()
        conn.cursor.return_value = make_cursor()
        return conn

    def mock_list_pending_connection(self):
        """Mock connection for ``list_pending`` (single SELECT of pending rows)."""
        pending = self.pending_tasks

        def make_cursor():
            cur = MagicMock()

            def execute_side_effect(sql, params=None):
                cur.fetchall.return_value = [
                    (
                        r["id"], r["site_id"], json.dumps(r["config"]),
                        r["status"], r["position"],
                        None, None, None, None, None,
                    )
                    for r in pending
                ]

            cur.execute = MagicMock(side_effect=execute_side_effect)
            cur.__enter__ = MagicMock(return_value=cur)
            cur.__exit__ = MagicMock(return_value=False)
            return cur

        conn = MagicMock()
        conn.cursor.return_value = make_cursor()
        return conn
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 8: 队列 CRUD 不变量
# **Validates: Requirements 4.1, 4.4**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    config=_simple_config_st,
    site_id=_site_id_st,
    initial_count=st.integers(min_value=0, max_value=5),
)
@patch("app.services.task_queue.get_connection")
def test_queue_crud_invariant(mock_get_conn, config, site_id, initial_count):
    """Property 8: queue CRUD invariant.

    Enqueueing one task grows the queue by 1 and the new task is pending;
    deleting a pending task shrinks the queue by 1 and the task no longer
    appears in it.
    """
    queue = TaskQueue()
    db = InMemoryQueueDB(site_id)
    # Pre-populate a number of tasks.
    for i in range(initial_count):
        tid = str(uuid.uuid4())
        db.rows[tid] = {
            "id": tid,
            "site_id": site_id,
            "config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
            "status": "pending",
            "position": i + 1,
        }
    before_count = len(db.pending_tasks)
    # --- enqueue ---
    mock_get_conn.return_value = db.mock_enqueue_connection()
    new_id = queue.enqueue(config, site_id)
    after_enqueue_count = len(db.pending_tasks)
    assert after_enqueue_count == before_count + 1, (
        f"入队后长度应 +1期望 {before_count + 1},实际 {after_enqueue_count}"
    )
    assert new_id in db.rows, "新任务应存在于队列中"
    assert db.rows[new_id]["status"] == "pending", "新任务状态应为 pending"
    # --- delete the just-enqueued task ---
    mock_get_conn.return_value = db.mock_delete_connection(new_id)
    deleted = queue.delete(new_id, site_id)
    after_delete_count = len(db.pending_tasks)
    assert deleted is True, "删除 pending 任务应返回 True"
    assert after_delete_count == before_count, (
        f"删除后长度应恢复:期望 {before_count},实际 {after_delete_count}"
    )
    assert new_id not in db.rows, "已删除任务不应出现在队列中"
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 9: 队列出队顺序
# **Validates: Requirements 4.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    num_tasks=st.integers(min_value=1, max_value=8),
    positions=st.data(),
)
@patch("app.services.task_queue.get_connection")
def test_queue_dequeue_order(mock_get_conn, site_id, num_tasks, positions):
    """Property 9: dequeue order.

    For a queue containing multiple pending tasks, dequeue must return the
    task with the smallest position value.
    """
    queue = TaskQueue()
    db = InMemoryQueueDB(site_id)
    # Draw distinct position values.
    pos_list = positions.draw(
        st.lists(
            st.integers(min_value=1, max_value=1000),
            min_size=num_tasks,
            max_size=num_tasks,
            unique=True,
        )
    )
    # Populate the queue.
    task_ids = []
    for i, pos in enumerate(pos_list):
        tid = str(uuid.uuid4())
        task_ids.append(tid)
        db.rows[tid] = {
            "id": tid,
            "site_id": site_id,
            "config": {"tasks": [_task_codes[i % len(_task_codes)]], "pipeline": "api_ods"},
            "status": "pending",
            "position": pos,
        }
    # Identify the task with the smallest position.
    expected_first = min(db.pending_tasks, key=lambda r: r["position"])
    # dequeue
    mock_get_conn.return_value = db.mock_dequeue_connection()
    result = queue.dequeue(site_id)
    assert result is not None, "队列非空时 dequeue 不应返回 None"
    assert result.id == expected_first["id"], (
        f"应返回 position 最小的任务:期望 id={expected_first['id']} "
        f"(pos={expected_first['position']}),实际 id={result.id}"
    )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 10: 队列重排一致性
# **Validates: Requirements 4.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    num_tasks=st.integers(min_value=2, max_value=6),
    data=st.data(),
)
@patch("app.services.task_queue.get_connection")
def test_queue_reorder_consistency(mock_get_conn, site_id, num_tasks, data):
    """Property 10: reorder consistency.

    After moving a task to a new position:
    - the moved task sits at the target position (clamped to the valid range)
    - the remaining tasks keep their original relative order
    - no task is lost from the queue
    """
    queue = TaskQueue()
    db = InMemoryQueueDB(site_id)
    # Populate the queue; positions are numbered consecutively from 1.
    task_ids = []
    for i in range(num_tasks):
        tid = str(uuid.uuid4())
        task_ids.append(tid)
        db.rows[tid] = {
            "id": tid,
            "site_id": site_id,
            "config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
            "status": "pending",
            "position": i + 1,
        }
    # Randomly pick the task to move and its target position.
    move_idx = data.draw(st.integers(min_value=0, max_value=num_tasks - 1))
    move_task_id = task_ids[move_idx]
    new_position = data.draw(st.integers(min_value=1, max_value=num_tasks + 2))
    # Perform the reorder.
    mock_get_conn.return_value = db.mock_reorder_connection()
    queue.reorder(move_task_id, new_position, site_id)
    # Check: every task is still in the queue.
    remaining_ids = {r["id"] for r in db.rows.values() if r["status"] == "pending"}
    assert remaining_ids == set(task_ids), "重排后不应丢失任何任务"
    # Check: position values are consecutive and unique (1-based).
    positions = sorted(r["position"] for r in db.pending_tasks)
    assert positions == list(range(1, num_tasks + 1)), (
        f"重排后 position 应为连续编号 1..{num_tasks},实际 {positions}"
    )
    # Check: the moved task landed at the right spot.
    # reorder internally clamps new_position to [1, len(others)+1].
    clamped_pos = max(1, min(new_position, num_tasks))
    actual_pos = db.rows[move_task_id]["position"]
    assert actual_pos == clamped_pos, (
        f"被移动任务的 position 应为 {clamped_pos}clamp 后),实际 {actual_pos}"
    )
    # Check: the remaining tasks keep their original relative order.
    others_before = [tid for tid in task_ids if tid != move_task_id]
    others_after = sorted(
        [r for r in db.pending_tasks if r["id"] != move_task_id],
        key=lambda r: r["position"],
    )
    others_after_ids = [r["id"] for r in others_after]
    assert others_after_ids == others_before, (
        "其余任务的相对顺序应保持不变"
    )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 11: 执行历史排序与限制
# **Validates: Requirements 4.5, 8.2**
# ---------------------------------------------------------------------------
# 导入 FastAPI 测试客户端
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from fastapi.testclient import TestClient
def _make_history_rows(count: int, site_id: int) -> list[tuple]:
"""生成 count 条执行历史记录started_at 随机但可排序。"""
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
["ODS_MEMBER"], # task_codes
"success", # status
base_time + timedelta(hours=i), # started_at
base_time + timedelta(hours=i, minutes=30), # finished_at
0, # exit_code
1800000, # duration_ms
"python -m cli.main", # command
None, # summary
))
return rows
@settings(max_examples=100, deadline=None)
@given(
    site_id=_site_id_st,
    total_records=st.integers(min_value=0, max_value=30),
    limit=st.integers(min_value=1, max_value=200),
)
@patch("app.routers.execution.get_connection")
def test_execution_history_sort_and_limit(mock_get_conn, site_id, total_records, limit):
    """Property 11: execution-history ordering and limit.

    For any set of execution-history records, the API result must be sorted
    by started_at in descending order and contain at most `limit` entries.
    """
    # Generate the test data.
    all_rows = _make_history_rows(total_records, site_id)
    # Simulate the DB: sort by started_at DESC, then take `limit` rows.
    sorted_rows = sorted(all_rows, key=lambda r: r[4], reverse=True)
    returned_rows = sorted_rows[:limit]
    # Mock the DB connection.
    mock_cursor = MagicMock()
    mock_cursor.fetchall.return_value = returned_rows
    mock_conn = MagicMock()
    mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
    mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
    mock_get_conn.return_value = mock_conn
    # Override the auth dependency.
    test_user = CurrentUser(user_id=1, site_id=site_id)
    app.dependency_overrides[get_current_user] = lambda: test_user
    try:
        client = TestClient(app)
        # The API constrains limit to [1, 200].
        clamped_limit = max(1, min(limit, 200))
        resp = client.get(f"/api/execution/history?limit={clamped_limit}")
        assert resp.status_code == 200
        data = resp.json()
        # Check 1: no more results than limit.
        assert len(data) <= clamped_limit, (
            f"结果数量 {len(data)} 超过 limit {clamped_limit}"
        )
        # Check 2: no more results than the total record count.
        assert len(data) <= total_records, (
            f"结果数量 {len(data)} 超过总记录数 {total_records}"
        )
        # Check 3: sorted by started_at in descending order.
        if len(data) >= 2:
            for i in range(len(data) - 1):
                t1 = data[i]["started_at"]
                t2 = data[i + 1]["started_at"]
                assert t1 >= t2, (
                    f"结果未按 started_at 降序排列data[{i}]={t1} < data[{i+1}]={t2}"
                )
    finally:
        app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)

View File

@@ -0,0 +1,439 @@
# -*- coding: utf-8 -*-
"""调度属性测试Property-Based Testing
使用 hypothesis 验证调度管理的通用正确性属性:
- Property 12: 调度任务 CRUD 往返
- Property 13: 到期调度任务自动入队
- Property 14: 调度任务启用/禁用状态
测试策略:
- Property 12: 通过 mock 数据库,验证 POST 创建后 GET 返回的 schedule_config 与提交的一致
- Property 13: 通过 mock 数据库返回到期任务,验证 check_and_enqueue 调用了 task_queue.enqueue
- Property 14: 通过 mock 数据库,验证 toggle 端点的 next_run_at 行为
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-schedule-properties")
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.schemas.schedules import ScheduleConfigSchema
from app.schemas.tasks import TaskConfigSchema
from app.services.scheduler import Scheduler, calculate_next_run
from fastapi.testclient import TestClient
# ---------------------------------------------------------------------------
# Shared strategies
# ---------------------------------------------------------------------------
# Valid site ids (positive 32-bit integers).
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
# Minimal task-config dicts accepted by the schedule endpoints.
_simple_task_config_st = st.fixed_dictionaries({
    "tasks": st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
    "pipeline": st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd", "api_full"]),
})
# Schedule-config strategies: cover all 5 schedule types.
_schedule_type_st = st.sampled_from(["once", "interval", "daily", "weekly", "cron"])
_interval_unit_st = st.sampled_from(["minutes", "hours", "days"])
# Time-of-day strings in HH:MM format.
_time_str_st = st.builds(
    lambda h, m: f"{h:02d}:{m:02d}",
    h=st.integers(min_value=0, max_value=23),
    m=st.integers(min_value=0, max_value=59),
)
# ISO weekday lists (1=Monday ... 7=Sunday).
_weekly_days_st = st.lists(
    st.integers(min_value=1, max_value=7),
    min_size=1, max_size=7, unique=True,
)
# Simple cron expressions of the form "minute hour * * *".
_cron_st = st.builds(
    lambda m, h: f"{m} {h} * * *",
    m=st.integers(min_value=0, max_value=59),
    h=st.integers(min_value=0, max_value=23),
)
def _build_schedule_config(schedule_type, interval_value, interval_unit,
                           daily_time, weekly_days, weekly_time, cron_expression):
    """Assemble an enabled ScheduleConfigSchema for the given schedule_type.

    All type-specific fields are populated unconditionally; the schema itself
    decides which ones are relevant for the chosen schedule_type.
    """
    fields = {
        "schedule_type": schedule_type,
        "interval_value": interval_value,
        "interval_unit": interval_unit,
        "daily_time": daily_time,
        "weekly_days": weekly_days,
        "weekly_time": weekly_time,
        "cron_expression": cron_expression,
        "enabled": True,
    }
    return ScheduleConfigSchema(**fields)
# Full schedule-config strategy used by Properties 12 and 13.
_schedule_config_st = st.builds(
    _build_schedule_config,
    schedule_type=_schedule_type_st,
    interval_value=st.integers(min_value=1, max_value=168),
    interval_unit=_interval_unit_st,
    daily_time=_time_str_st,
    weekly_days=_weekly_days_st,
    weekly_time=_time_str_st,
    cron_expression=_cron_st,
)
# Non-"once" schedule configs for Property 14 (after re-enabling,
# next_run_at must be non-NULL for these types).
_non_once_schedule_type_st = st.sampled_from(["interval", "daily", "weekly", "cron"])
_non_once_schedule_config_st = st.builds(
    _build_schedule_config,
    schedule_type=_non_once_schedule_type_st,
    interval_value=st.integers(min_value=1, max_value=168),
    interval_unit=_interval_unit_st,
    daily_time=_time_str_st,
    weekly_days=_weekly_days_st,
    weekly_time=_time_str_st,
    cron_expression=_cron_st,
)
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
_NOW = datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
# 模拟数据库行的列顺序(与 _SELECT_COLS 对应,共 13 列)
# id, site_id, name, task_codes, task_config, schedule_config,
# enabled, last_run_at, next_run_at, run_count, last_status,
# created_at, updated_at
def _make_db_row(
    schedule_id: str,
    site_id: int,
    name: str,
    task_codes: list[str],
    task_config: dict,
    schedule_config: dict,
    enabled: bool = True,
    next_run_at: datetime | None = None,
) -> tuple:
    """Build a fake schedule row matching the 13-column SELECT layout."""
    def _jsonify(value):
        # Plain dicts are stored as JSON text, exactly like the real table.
        return json.dumps(value) if isinstance(value, dict) else value

    return (
        schedule_id,
        site_id,
        name,
        task_codes,
        _jsonify(task_config),
        _jsonify(schedule_config),
        enabled,
        None,         # last_run_at
        next_run_at,  # next_run_at
        0,            # run_count
        None,         # last_status
        _NOW,         # created_at
        _NOW,         # updated_at
    )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 12: 调度任务 CRUD 往返
# **Validates: Requirements 5.1, 5.4**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    schedule_config=_schedule_config_st,
    task_config=_simple_task_config_st,
    name=st.text(min_size=1, max_size=50, alphabet=st.characters(
        whitelist_categories=("L", "N"), whitelist_characters="_- "
    )),
    task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_crud_round_trip(
    mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
):
    """Property 12: schedule CRUD round trip.

    For any valid ScheduleConfigSchema, creating a schedule and then querying
    it must return a schedule configuration equivalent to the one submitted.
    """
    schedule_config_dict = schedule_config.model_dump()
    next_run = calculate_next_run(schedule_config, _NOW)
    # Build the DB row returned after creation.
    created_row = _make_db_row(
        schedule_id="test-sched-id",
        site_id=site_id,
        name=name,
        task_codes=task_codes,
        task_config=task_config,
        schedule_config=schedule_config_dict,
        enabled=schedule_config.enabled,
        next_run_at=next_run,
    )
    # --- creation phase ---
    # Mock the POST connection (INSERT ... RETURNING).
    create_cursor = MagicMock()
    create_cursor.fetchone.return_value = created_row
    create_conn = MagicMock()
    create_conn.cursor.return_value.__enter__ = MagicMock(return_value=create_cursor)
    create_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
    # --- query phase ---
    # Mock the GET connection (SELECT ... fetchall).
    list_cursor = MagicMock()
    list_cursor.fetchall.return_value = [created_row]
    list_conn = MagicMock()
    list_conn.cursor.return_value.__enter__ = MagicMock(return_value=list_cursor)
    list_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
    # Hand out create_conn first, then list_conn.
    mock_get_conn.side_effect = [create_conn, list_conn]
    # Override authentication.
    test_user = CurrentUser(user_id=1, site_id=site_id)
    app.dependency_overrides[get_current_user] = lambda: test_user
    try:
        client = TestClient(app)
        # Create the schedule.
        create_body = {
            "name": name,
            "task_codes": task_codes,
            "task_config": task_config,
            "schedule_config": schedule_config_dict,
        }
        create_resp = client.post("/api/schedules", json=create_body)
        assert create_resp.status_code == 201, (
            f"创建应返回 201实际 {create_resp.status_code}: {create_resp.text}"
        )
        created_data = create_resp.json()
        # Query the schedule list.
        list_resp = client.get("/api/schedules")
        assert list_resp.status_code == 200
        list_data = list_resp.json()
        assert len(list_data) >= 1, "查询结果应至少包含刚创建的任务"
        # Locate the schedule we just created.
        found = next((s for s in list_data if s["id"] == created_data["id"]), None)
        assert found is not None, "查询结果应包含刚创建的任务"
        # Core check: schedule_config round-trips unchanged.
        returned_config = found["schedule_config"]
        for key in schedule_config_dict:
            assert returned_config[key] == schedule_config_dict[key], (
                f"schedule_config.{key} 不一致:"
                f"提交={schedule_config_dict[key]},返回={returned_config[key]}"
            )
        # task_config round-trips unchanged as well.
        returned_task_config = found["task_config"]
        for key in task_config:
            assert returned_task_config[key] == task_config[key], (
                f"task_config.{key} 不一致:提交={task_config[key]},返回={returned_task_config[key]}"
            )
        # Basic fields.
        assert found["name"] == name
        assert found["task_codes"] == task_codes
        assert found["site_id"] == site_id
    finally:
        app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 13: 到期调度任务自动入队
# **Validates: Requirements 5.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    site_id=_site_id_st,
    schedule_config=_schedule_config_st,
    task_config=_simple_task_config_st,
)
@patch("app.services.scheduler.task_queue")
@patch("app.services.scheduler.get_connection")
def test_due_schedule_auto_enqueue(
    mock_get_conn, mock_tq, site_id, schedule_config, task_config,
):
    """Property 13: due schedules are enqueued automatically.

    For a schedule with enabled=true whose next_run_at is in the past,
    check_and_enqueue must put the schedule's TaskConfig into the queue.
    """
    sched = Scheduler()
    schedule_config_dict = schedule_config.model_dump()
    # The mocked SELECT below stands in for a task that is already due
    # (i.e. its next_run_at lies in the past).
    task_id = "due-task-001"
    # --- mock the SELECT of due tasks ---
    select_cursor = MagicMock()
    select_cursor.fetchall.return_value = [
        (task_id, site_id, json.dumps(task_config), json.dumps(schedule_config_dict)),
    ]
    select_cursor.__enter__ = MagicMock(return_value=select_cursor)
    select_cursor.__exit__ = MagicMock(return_value=False)
    # --- mock the UPDATE of the schedule's state ---
    update_cursor = MagicMock()
    update_cursor.__enter__ = MagicMock(return_value=update_cursor)
    update_cursor.__exit__ = MagicMock(return_value=False)
    conn = MagicMock()
    conn.cursor.side_effect = [select_cursor, update_cursor]
    mock_get_conn.return_value = conn
    mock_tq.enqueue.return_value = "queue-id-123"
    # Run the check.
    count = sched.check_and_enqueue()
    # The due task must have been enqueued.
    assert count == 1, f"应有 1 个任务入队,实际 {count}"
    mock_tq.enqueue.assert_called_once()
    # Inspect the enqueue arguments.
    call_args = mock_tq.enqueue.call_args
    enqueued_config = call_args[0][0]
    enqueued_site_id = call_args[0][1]
    # site_id must match.
    assert enqueued_site_id == site_id, (
        f"入队的 site_id 应为 {site_id},实际 {enqueued_site_id}"
    )
    # The TaskConfig must match the original configuration.
    assert isinstance(enqueued_config, TaskConfigSchema)
    assert enqueued_config.tasks == task_config["tasks"], (
        f"入队的 tasks 应为 {task_config['tasks']},实际 {enqueued_config.tasks}"
    )
    assert enqueued_config.pipeline == task_config["pipeline"], (
        f"入队的 pipeline 应为 {task_config['pipeline']},实际 {enqueued_config.pipeline}"
    )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 14: 调度任务启用/禁用状态
# **Validates: Requirements 5.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100, deadline=None)
@given(
    site_id=_site_id_st,
    schedule_config=_non_once_schedule_config_st,
    task_config=_simple_task_config_st,
    name=st.text(min_size=1, max_size=30, alphabet=st.characters(
        whitelist_categories=("L", "N"), whitelist_characters="_- "
    )),
    task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_toggle_next_run(
    mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
):
    """Property 14: schedule enable/disable toggling.

    After disabling, next_run_at must be NULL; after re-enabling,
    next_run_at must be recomputed to a non-NULL value (the strategy
    `_non_once_schedule_config_st` supplies non-`once` configs only).
    """
    schedule_config_dict = schedule_config.model_dump()
    # Expected next_run_at after re-enabling, computed from the fixed
    # reference instant _NOW used by the mocked rows.
    next_run_enabled = calculate_next_run(schedule_config, _NOW)
    # --- Step 1: disable (enabled=True -> False) ---
    # The toggle endpoint first SELECTs the current state, then runs
    # UPDATE ... RETURNING.
    # Database row as it looks after disabling.
    disabled_row = _make_db_row(
        schedule_id="sched-toggle-1",
        site_id=site_id,
        name=name,
        task_codes=task_codes,
        task_config=task_config,
        schedule_config=schedule_config_dict,
        enabled=False,
        next_run_at=None,  # next_run_at is NULL once disabled
    )
    # Mock connection for the disable call.
    disable_cursor = MagicMock()
    disable_cursor.fetchone.side_effect = [
        (True, json.dumps(schedule_config_dict)),  # SELECT current state: enabled=True
        disabled_row,  # UPDATE RETURNING
    ]
    disable_conn = MagicMock()
    disable_conn.cursor.return_value.__enter__ = MagicMock(return_value=disable_cursor)
    disable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
    # --- Step 2: enable (enabled=False -> True) ---
    enabled_row = _make_db_row(
        schedule_id="sched-toggle-1",
        site_id=site_id,
        name=name,
        task_codes=task_codes,
        task_config=task_config,
        schedule_config=schedule_config_dict,
        enabled=True,
        next_run_at=next_run_enabled,  # recomputed when re-enabled
    )
    enable_cursor = MagicMock()
    enable_cursor.fetchone.side_effect = [
        (False, json.dumps(schedule_config_dict)),  # SELECT current state: enabled=False
        enabled_row,  # UPDATE RETURNING
    ]
    enable_conn = MagicMock()
    enable_conn.cursor.return_value.__enter__ = MagicMock(return_value=enable_cursor)
    enable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
    # get_connection yields the two connections in call order.
    mock_get_conn.side_effect = [disable_conn, enable_conn]
    # Override authentication with a user bound to the generated site_id.
    test_user = CurrentUser(user_id=1, site_id=site_id)
    app.dependency_overrides[get_current_user] = lambda: test_user
    try:
        client = TestClient(app)
        # Disable.
        disable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
        assert disable_resp.status_code == 200, (
            f"禁用应返回 200实际 {disable_resp.status_code}: {disable_resp.text}"
        )
        disable_data = disable_resp.json()
        # After disabling: enabled=False and next_run_at=NULL.
        assert disable_data["enabled"] is False, "禁用后 enabled 应为 False"
        assert disable_data["next_run_at"] is None, "禁用后 next_run_at 应为 NULL"
        # Enable.
        enable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
        assert enable_resp.status_code == 200, (
            f"启用应返回 200实际 {enable_resp.status_code}: {enable_resp.text}"
        )
        enable_data = enable_resp.json()
        # After enabling: enabled=True, next_run_at non-NULL (non-once schedule).
        assert enable_data["enabled"] is True, "启用后 enabled 应为 True"
        assert enable_data["next_run_at"] is not None, (
            "非一次性调度启用后 next_run_at 应被重新计算为非 NULL 值"
        )
    finally:
        # Restore the module-level default override (site_id=100) —
        # presumably installed at import time; confirm against the file header.
        app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)

View File

@@ -0,0 +1,384 @@
# -*- coding: utf-8 -*-
"""Scheduler 单元测试
覆盖:
- calculate_next_run各种调度类型的下次执行时间计算
- _parse_simple_cron简单 cron 表达式解析
- check_and_enqueue到期检查与入队逻辑
- start / stop后台循环生命周期
"""
import asyncio
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
import pytest
from app.schemas.schedules import ScheduleConfigSchema
from app.schemas.tasks import TaskConfigSchema
from app.services.scheduler import (
Scheduler,
calculate_next_run,
_parse_simple_cron,
_parse_time,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def sched() -> Scheduler:
    """A fresh Scheduler instance for each test."""
    return Scheduler()
@pytest.fixture
def now() -> datetime:
    """Fixed reference instant: 2025-06-10 10:00:00 UTC — a Tuesday."""
    return datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
cur = MagicMock()
cur.fetchone.return_value = fetchone_val
cur.fetchall.return_value = fetchall_val or []
cur.rowcount = rowcount
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
def _mock_conn(cursor):
conn = MagicMock()
conn.cursor.return_value = cursor
return conn
# ---------------------------------------------------------------------------
# _parse_time
# ---------------------------------------------------------------------------
class TestParseTime:
    """_parse_time splits an 'HH:MM' string into an (hour, minute) pair."""

    def test_standard_format(self):
        hour, minute = _parse_time("04:00")
        assert (hour, minute) == (4, 0)

    def test_with_minutes(self):
        hour, minute = _parse_time("23:45")
        assert (hour, minute) == (23, 45)

    def test_midnight(self):
        hour, minute = _parse_time("00:00")
        assert (hour, minute) == (0, 0)
# ---------------------------------------------------------------------------
# calculate_next_run — once
# ---------------------------------------------------------------------------
class TestNextRunOnce:
    def test_once_returns_none(self, now):
        """One-shot schedules never have a follow-up run."""
        config = ScheduleConfigSchema(schedule_type="once")
        assert calculate_next_run(config, now) is None
# ---------------------------------------------------------------------------
# calculate_next_run — interval
# ---------------------------------------------------------------------------
class TestNextRunInterval:
    """interval type: next run is simply now + interval."""

    @staticmethod
    def _calc(now, value, unit):
        # Small builder to avoid repeating the schema boilerplate.
        cfg = ScheduleConfigSchema(
            schedule_type="interval", interval_value=value, interval_unit=unit,
        )
        return calculate_next_run(cfg, now)

    def test_interval_minutes(self, now):
        assert self._calc(now, 15, "minutes") == now + timedelta(minutes=15)

    def test_interval_hours(self, now):
        assert self._calc(now, 2, "hours") == now + timedelta(hours=2)

    def test_interval_days(self, now):
        assert self._calc(now, 3, "days") == now + timedelta(days=3)
# ---------------------------------------------------------------------------
# calculate_next_run — daily
# ---------------------------------------------------------------------------
class TestNextRunDaily:
    def test_daily_next_day(self, now):
        # 04:00 today (now is 10:00) has already passed -> run tomorrow.
        cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="04:00")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc
        )

    def test_daily_custom_time(self, now):
        # NOTE(review): 18:30 is still ahead of `now` (10:00) yet the expected
        # run is on June 11 — confirm calculate_next_run intentionally targets
        # the *next* day for daily schedules.
        cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="18:30")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 11, 18, 30, 0, tzinfo=timezone.utc
        )
# ---------------------------------------------------------------------------
# calculate_next_run — weekly
# ---------------------------------------------------------------------------
class TestNextRunWeekly:
    def test_weekly_later_this_week(self, now):
        # `now` is Tuesday (2); weekly_days=[5] is Friday -> three days ahead.
        cfg = ScheduleConfigSchema(
            schedule_type="weekly", weekly_days=[5], weekly_time="08:00",
        )
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc
        )

    def test_weekly_next_week(self, now):
        # Monday (1) already passed this week -> next Monday, six days ahead.
        cfg = ScheduleConfigSchema(
            schedule_type="weekly", weekly_days=[1], weekly_time="04:00",
        )
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 16, 4, 0, 0, tzinfo=timezone.utc
        )

    def test_weekly_multiple_days_picks_next(self, now):
        # Among Monday/Thursday/Saturday the soonest is Thursday, two days out.
        cfg = ScheduleConfigSchema(
            schedule_type="weekly", weekly_days=[1, 4, 6], weekly_time="09:00",
        )
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 12, 9, 0, 0, tzinfo=timezone.utc
        )
# ---------------------------------------------------------------------------
# calculate_next_run — cron
# ---------------------------------------------------------------------------
class TestNextRunCron:
    def test_cron_daily(self, now):
        # "30 4 * * *" -> daily 04:30; today's already passed -> tomorrow.
        cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="30 4 * * *")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 11, 4, 30, 0, tzinfo=timezone.utc
        )

    def test_cron_with_dow(self, now):
        # "0 8 * * 5" -> Fridays 08:00; now is Tuesday -> Friday, 3 days ahead.
        cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="0 8 * * 5")
        assert calculate_next_run(cfg, now) == datetime(
            2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc
        )
# ---------------------------------------------------------------------------
# _parse_simple_cron
# ---------------------------------------------------------------------------
class TestParseSimpleCron:
    def test_daily_cron(self, now):
        got = _parse_simple_cron("0 4 * * *", now)
        assert got == datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)

    def test_invalid_field_count_fallback(self, now):
        # Wrong field count -> falls back to tomorrow 04:00.
        got = _parse_simple_cron("0 4 *", now)
        assert got == datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)

    def test_wildcard_hour_minute(self, now):
        # "* * * * *" is treated as hour=0, minute=0 -> tomorrow midnight.
        got = _parse_simple_cron("* * * * *", now)
        assert got == datetime(2025, 6, 11, 0, 0, 0, tzinfo=timezone.utc)

    def test_dow_sunday(self, now):
        # "0 6 * * 0" -> Sundays 06:00; now is Tuesday -> five days ahead.
        got = _parse_simple_cron("0 6 * * 0", now)
        assert got == datetime(2025, 6, 15, 6, 0, 0, tzinfo=timezone.utc)

    def test_dow_same_day_future_time(self):
        # Tuesday 08:00 with a Tuesday-12:00 cron -> later the same day.
        tue_morning = datetime(2025, 6, 10, 8, 0, 0, tzinfo=timezone.utc)
        got = _parse_simple_cron("0 12 * * 2", tue_morning)
        assert got == datetime(2025, 6, 10, 12, 0, 0, tzinfo=timezone.utc)

    def test_dow_same_day_past_time(self):
        # Tuesday 14:00 with a Tuesday-12:00 cron -> next Tuesday.
        tue_afternoon = datetime(2025, 6, 10, 14, 0, 0, tzinfo=timezone.utc)
        got = _parse_simple_cron("0 12 * * 2", tue_afternoon)
        assert got == datetime(2025, 6, 17, 12, 0, 0, tzinfo=timezone.utc)
# ---------------------------------------------------------------------------
# check_and_enqueue
# ---------------------------------------------------------------------------
class TestCheckAndEnqueue:
    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_enqueues_due_tasks(self, mock_tq, mock_get_conn, sched):
        """A due task is enqueued and last_run_at / run_count / next_run_at are updated."""
        task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
        schedule_config = {
            "schedule_type": "interval",
            "interval_value": 1,
            "interval_unit": "hours",
        }
        # First cursor: SELECT of due tasks.
        select_cur = _mock_cursor(
            fetchall_val=[
                ("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
            ]
        )
        # Second cursor: the UPDATE of the schedule state.
        update_cur = _mock_cursor()
        conn = MagicMock()
        # cursor() returns select_cur first, then update_cur — order matters.
        conn.cursor.side_effect = [select_cur, update_cur]
        mock_get_conn.return_value = conn
        mock_tq.enqueue.return_value = "queue-id-1"
        count = sched.check_and_enqueue()
        assert count == 1
        mock_tq.enqueue.assert_called_once()
        # Inspect the positional enqueue arguments: (config, site_id).
        call_args = mock_tq.enqueue.call_args
        assert call_args[0][1] == 42  # site_id
        assert isinstance(call_args[0][0], TaskConfigSchema)

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_no_due_tasks(self, mock_tq, mock_get_conn, sched):
        """Nothing is enqueued when no task is due."""
        cur = _mock_cursor(fetchall_val=[])
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn
        count = sched.check_and_enqueue()
        assert count == 0
        mock_tq.enqueue.assert_not_called()

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_skips_invalid_config(self, mock_tq, mock_get_conn, sched):
        """Tasks whose config fails deserialisation are skipped."""
        # task_config is missing the required `tasks` field.
        bad_config = {"pipeline": "api_ods_dwd"}
        schedule_config = {"schedule_type": "once"}
        cur = _mock_cursor(
            fetchall_val=[
                ("task-uuid-bad", 42, json.dumps(bad_config), json.dumps(schedule_config)),
            ]
        )
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn
        count = sched.check_and_enqueue()
        assert count == 0
        mock_tq.enqueue.assert_not_called()

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_enqueue_failure_continues(self, mock_tq, mock_get_conn, sched):
        """An enqueue failure skips that task; later tasks are still processed."""
        task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
        schedule_config = {"schedule_type": "once"}
        cur = _mock_cursor(
            fetchall_val=[
                ("task-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
                ("task-2", 42, json.dumps(task_config), json.dumps(schedule_config)),
            ]
        )
        # Extra cursor needed for the single successful UPDATE.
        update_cur = _mock_cursor()
        conn = MagicMock()
        conn.cursor.side_effect = [cur, update_cur]
        mock_get_conn.return_value = conn
        # First enqueue raises, second succeeds.
        mock_tq.enqueue.side_effect = [Exception("DB error"), "queue-id-2"]
        count = sched.check_and_enqueue()
        assert count == 1
        assert mock_tq.enqueue.call_count == 2

    @patch("app.services.scheduler.get_connection")
    @patch("app.services.scheduler.task_queue")
    def test_once_type_sets_next_run_none(self, mock_tq, mock_get_conn, sched):
        """After enqueuing a `once` task, next_run_at must be set to NULL."""
        task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
        schedule_config = {"schedule_type": "once"}
        select_cur = _mock_cursor(
            fetchall_val=[
                ("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
            ]
        )
        update_cur = _mock_cursor()
        conn = MagicMock()
        conn.cursor.side_effect = [select_cur, update_cur]
        mock_get_conn.return_value = conn
        mock_tq.enqueue.return_value = "queue-id-1"
        sched.check_and_enqueue()
        # Verify the UPDATE received next_run_at=None.
        update_call = update_cur.__enter__().execute.call_args
        # The first element of the parameter tuple is next_run_at.
        assert update_call[0][1][0] is None
# ---------------------------------------------------------------------------
# start / stop 生命周期
# ---------------------------------------------------------------------------
class TestLifecycle:
    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self, sched):
        # stop() must clear the running flag and drop the loop-task handle.
        sched._running = True
        await sched.stop()
        assert sched._running is False
        assert sched._loop_task is None

    def test_start_creates_task(self, sched):
        # NOTE(review): asyncio.set_event_loop for manual loop management is
        # discouraged in recent Python — consider an async test with
        # pytest.mark.asyncio instead; confirm against the project's Python version.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # Drive start()/stop() from inside a running event loop.
            async def _run():
                sched.start()
                assert sched._loop_task is not None
                assert not sched._loop_task.done()
                await sched.stop()
            loop.run_until_complete(_run())
        finally:
            loop.close()

    @pytest.mark.asyncio
    async def test_start_stop_idempotent(self, sched):
        """Calling stop() repeatedly must not raise."""
        await sched.stop()
        await sched.stop()
        assert sched._loop_task is None

View File

@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-
"""调度路由单元测试
覆盖 5 个端点list / create / update / delete / toggle
通过 mock 绕过数据库,专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import json
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
# Fixed operator identity used by the default auth override below.
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
    # Dependency override: authenticate every request as _TEST_USER.
    return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
# Reference timestamps reused by the mocked database rows below.
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
_NEXT = datetime(2024, 6, 2, 4, 0, 0, tzinfo=timezone.utc)
_SCHEDULE_CONFIG = {
    "schedule_type": "daily",
    "daily_time": "04:00",
}
# Valid request body for POST /api/schedules.
_VALID_CREATE = {
    "name": "每日全量同步",
    "task_codes": ["ODS_MEMBER", "ODS_ORDER"],
    "task_config": {"tasks": ["ODS_MEMBER", "ODS_ORDER"], "pipeline": "api_ods"},
    "schedule_config": _SCHEDULE_CONFIG,
}
# Complete database row as returned by the schedules query
# (13 columns, matching _SELECT_COLS in the router).
_DB_ROW = (
    "sched-1", 100, "每日全量同步", ["ODS_MEMBER", "ODS_ORDER"],
    json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}),
    json.dumps(_SCHEDULE_CONFIG),
    True, None, _NEXT, 0, None, _NOW, _NOW,
)
def _mock_conn_with_fetchall(rows):
"""构造返回 fetchall 的 mock 连接。"""
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = rows
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cursor
def _mock_conn_with_fetchone(row):
"""构造返回 fetchone 的 mock 连接。"""
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = row
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cursor
# ---------------------------------------------------------------------------
# GET /api/schedules
# ---------------------------------------------------------------------------
class TestListSchedules:
    @patch("app.routers.schedules.get_connection")
    def test_list_returns_schedules(self, mock_get_conn):
        """A stored row is rendered into the response list."""
        conn, _ = _mock_conn_with_fetchall([_DB_ROW])
        mock_get_conn.return_value = conn
        resp = client.get("/api/schedules")
        assert resp.status_code == 200
        payload = resp.json()
        assert len(payload) == 1
        item = payload[0]
        assert item["id"] == "sched-1"
        assert item["name"] == "每日全量同步"
        assert item["site_id"] == 100
        assert item["enabled"] is True

    @patch("app.routers.schedules.get_connection")
    def test_list_empty(self, mock_get_conn):
        """No rows -> an empty JSON array."""
        conn, _ = _mock_conn_with_fetchall([])
        mock_get_conn.return_value = conn
        resp = client.get("/api/schedules")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch("app.routers.schedules.get_connection")
    def test_list_filters_by_site_id(self, mock_get_conn):
        """The SELECT must be parameterised with the caller's site_id."""
        conn, cursor = _mock_conn_with_fetchall([])
        mock_get_conn.return_value = conn
        client.get("/api/schedules")
        assert cursor.execute.call_args[0][1] == (100,)
# ---------------------------------------------------------------------------
# POST /api/schedules
# ---------------------------------------------------------------------------
class TestCreateSchedule:
    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_create_returns_201(self, mock_get_conn, mock_calc):
        """A valid payload creates the schedule and echoes the stored row."""
        conn, _ = _mock_conn_with_fetchone(_DB_ROW)
        mock_get_conn.return_value = conn
        resp = client.post("/api/schedules", json=_VALID_CREATE)
        assert resp.status_code == 201
        body = resp.json()
        assert body["id"] == "sched-1"
        assert body["name"] == "每日全量同步"

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_create_injects_site_id(self, mock_get_conn, mock_calc):
        """The router — not the client — supplies site_id for the INSERT."""
        conn, cursor = _mock_conn_with_fetchone(_DB_ROW)
        mock_get_conn.return_value = conn
        client.post("/api/schedules", json=_VALID_CREATE)
        # The first INSERT parameter is the authenticated user's site_id (100).
        params = cursor.execute.call_args[0][1]
        assert params[0] == 100

    def test_create_missing_name_returns_422(self):
        payload = dict(_VALID_CREATE)
        payload.pop("name")
        resp = client.post("/api/schedules", json=payload)
        assert resp.status_code == 422

    def test_create_invalid_schedule_type_returns_422(self):
        payload = {**_VALID_CREATE, "schedule_config": {"schedule_type": "invalid"}}
        resp = client.post("/api/schedules", json=payload)
        assert resp.status_code == 422
# ---------------------------------------------------------------------------
# PUT /api/schedules/{id}
# ---------------------------------------------------------------------------
class TestUpdateSchedule:
    @patch("app.routers.schedules.get_connection")
    def test_update_name(self, mock_get_conn):
        row = list(_DB_ROW)
        row[2] = "新名称"  # column 2 is `name`
        conn, _ = _mock_conn_with_fetchone(tuple(row))
        mock_get_conn.return_value = conn
        resp = client.put("/api/schedules/sched-1", json={"name": "新名称"})
        assert resp.status_code == 200
        assert resp.json()["name"] == "新名称"

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_update_schedule_config_recalculates_next_run(self, mock_get_conn, mock_calc):
        """Changing schedule_config must trigger a next_run_at recalculation."""
        conn, _ = _mock_conn_with_fetchone(_DB_ROW)
        mock_get_conn.return_value = conn
        new_config = {"schedule_type": "interval", "interval_value": 2, "interval_unit": "hours"}
        resp = client.put("/api/schedules/sched-1", json={"schedule_config": new_config})
        assert resp.status_code == 200
        mock_calc.assert_called_once()

    @patch("app.routers.schedules.get_connection")
    def test_update_not_found(self, mock_get_conn):
        """Updating a missing id maps to 404."""
        conn, _ = _mock_conn_with_fetchone(None)
        mock_get_conn.return_value = conn
        resp = client.put("/api/schedules/nonexistent", json={"name": "x"})
        assert resp.status_code == 404

    def test_update_empty_body_returns_422(self):
        resp = client.put("/api/schedules/sched-1", json={})
        assert resp.status_code == 422
# ---------------------------------------------------------------------------
# DELETE /api/schedules/{id}
# ---------------------------------------------------------------------------
class TestDeleteSchedule:
    @staticmethod
    def _conn_with_rowcount(rowcount):
        """Mock connection whose context-managed cursor reports *rowcount*."""
        cursor = MagicMock()
        cursor.rowcount = rowcount
        conn = MagicMock()
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cursor)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        return conn

    @patch("app.routers.schedules.get_connection")
    def test_delete_success(self, mock_get_conn):
        mock_get_conn.return_value = self._conn_with_rowcount(1)
        resp = client.delete("/api/schedules/sched-1")
        assert resp.status_code == 200
        assert "已删除" in resp.json()["message"]

    @patch("app.routers.schedules.get_connection")
    def test_delete_not_found(self, mock_get_conn):
        """DELETE affecting zero rows maps to 404."""
        mock_get_conn.return_value = self._conn_with_rowcount(0)
        resp = client.delete("/api/schedules/nonexistent")
        assert resp.status_code == 404
# ---------------------------------------------------------------------------
# PATCH /api/schedules/{id}/toggle
# ---------------------------------------------------------------------------
class TestToggleSchedule:
    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_toggle_disable(self, mock_get_conn, mock_calc):
        """Enabled -> disabled: next_run_at must be set to NULL."""
        # First fetchone returns the current state (enabled=True),
        # second fetchone returns the updated row — order matters.
        disabled_row = list(_DB_ROW)
        disabled_row[6] = False  # enabled
        disabled_row[8] = None  # next_run_at
        mock_cursor = MagicMock()
        mock_cursor.fetchone.side_effect = [
            (True, json.dumps(_SCHEDULE_CONFIG)),  # SELECT of current state
            tuple(disabled_row),  # UPDATE RETURNING
        ]
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.patch("/api/schedules/sched-1/toggle")
        assert resp.status_code == 200
        data = resp.json()
        assert data["enabled"] is False
        assert data["next_run_at"] is None

    @patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
    @patch("app.routers.schedules.get_connection")
    def test_toggle_enable(self, mock_get_conn, mock_calc):
        """Disabled -> enabled: next_run_at must be recomputed."""
        enabled_row = list(_DB_ROW)
        enabled_row[6] = True
        enabled_row[8] = _NEXT
        mock_cursor = MagicMock()
        mock_cursor.fetchone.side_effect = [
            (False, json.dumps(_SCHEDULE_CONFIG)),  # SELECT of current state: disabled
            tuple(enabled_row),  # UPDATE RETURNING
        ]
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.patch("/api/schedules/sched-1/toggle")
        assert resp.status_code == 200
        data = resp.json()
        assert data["enabled"] is True
        assert data["next_run_at"] is not None
        mock_calc.assert_called_once()

    @patch("app.routers.schedules.get_connection")
    def test_toggle_not_found(self, mock_get_conn):
        # Unknown id: the initial SELECT finds nothing -> 404.
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = None
        mock_conn = MagicMock()
        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = mock_conn
        resp = client.patch("/api/schedules/nonexistent/toggle")
        assert resp.status_code == 404
# ---------------------------------------------------------------------------
# 认证测试
# ---------------------------------------------------------------------------
class TestSchedulesAuth:
    def test_requires_auth(self):
        """With the auth override removed, every endpoint must reject the request."""
        app.dependency_overrides.pop(get_current_user, None)
        try:
            attempts = [
                ("get", "/api/schedules", None),
                ("post", "/api/schedules", _VALID_CREATE),
                ("put", "/api/schedules/x", {"name": "x"}),
                ("delete", "/api/schedules/x", None),
                ("patch", "/api/schedules/x/toggle", None),
            ]
            for method, url, body in attempts:
                kwargs = {} if body is None else {"json": body}
                resp = getattr(client, method)(url, **kwargs)
                assert resp.status_code in (401, 403), f"{method} {url}"
        finally:
            # Re-install the module-level auth override for the other tests.
            app.dependency_overrides[get_current_user] = _override_auth

View File

@@ -0,0 +1,336 @@
# -*- coding: utf-8 -*-
"""门店隔离属性测试Property-Based Testing
Property 20: 对于任意两个不同 site_id 的 Operator一个 Operator 查询
队列/调度/执行历史时,结果中不应包含另一个 site_id 的数据。
Validates: Requirements 1.3
测试策略:
- 通过 mock 数据库交互,验证 API 路由在不同 site_id 下的数据隔离
- 队列隔离:为 site_id_a 入队任务,用 site_id_b 的 JWT 查询队列,结果应为空
- 调度隔离:为 site_id_a 创建调度任务,用 site_id_b 的 JWT 查询调度列表,结果应为空
- 执行历史隔离site_id_a 的执行历史,用 site_id_b 的 JWT 查询不到
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-isolation")
import json
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
# ---------------------------------------------------------------------------
# 通用策略Strategies
# ---------------------------------------------------------------------------
# Positive site ids; the upper bound presumably mirrors a signed 32-bit
# INTEGER column — confirm against the schema.
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _make_mock_user(site_id: int) -> CurrentUser:
    """Build a mock authenticated user bound to *site_id* (fixed user_id=1)."""
    return CurrentUser(user_id=1, site_id=site_id)
def _make_queue_rows(site_id: int, count: int) -> list[tuple]:
"""生成 count 条属于 site_id 的队列行。"""
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}), # config
"pending", # status
i + 1, # position
datetime(2024, 1, 1, tzinfo=timezone.utc), # created_at
None, # started_at
None, # finished_at
None, # exit_code
None, # error_message
))
return rows
def _make_schedule_rows(site_id: int, count: int) -> list[tuple]:
"""生成 count 条属于 site_id 的调度行。"""
now = datetime.now(timezone.utc)
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
f"调度任务_{i}", # name
["ODS_MEMBER"], # task_codes
{"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}, # task_config
{"schedule_type": "daily", "daily_time": "04:00", # schedule_config
"interval_value": 1, "interval_unit": "hours",
"weekly_days": [1], "weekly_time": "04:00",
"cron_expression": "0 4 * * *", "enabled": True,
"start_date": None, "end_date": None},
True, # enabled
None, # last_run_at
now + timedelta(hours=1), # next_run_at
0, # run_count
None, # last_status
now, # created_at
now, # updated_at
))
return rows
def _make_history_rows(site_id: int, count: int) -> list[tuple]:
"""生成 count 条属于 site_id 的执行历史行。"""
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
["ODS_MEMBER"], # task_codes
"success", # status
base_time + timedelta(hours=i), # started_at
base_time + timedelta(hours=i, minutes=30), # finished_at
0, # exit_code
1800000, # duration_ms
"python -m cli.main", # command
None, # summary
))
return rows
def _mock_conn_returning(rows: list[tuple]) -> MagicMock:
"""构造一个 mock connection其 cursor.fetchall 返回指定行。"""
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = rows
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
# ---------------------------------------------------------------------------
# Property 20.1: 队列隔离
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100, deadline=None)
@given(
    site_id_a=_site_id_st,
    site_id_b=_site_id_st,
    queue_count=st.integers(min_value=1, max_value=5),
)
@patch("app.services.task_queue.get_connection")
def test_queue_isolation(mock_get_conn, site_id_a, site_id_b, queue_count):
    """Property 20.1: queue isolation.

    After enqueuing tasks for site_id_a, querying the queue as site_id_b
    must return nothing — queue data is invisible across sites.
    """
    assume(site_id_a != site_id_b)
    # Rows owned by site_id_a.
    rows_a = _make_queue_rows(site_id_a, queue_count)
    # Core isolation behaviour: filter by the bound site_id parameter
    # (list_pending runs: WHERE site_id = %s AND status = 'pending').
    def conn_for_site(querying_site_id):
        """Simulate the database: return only rows matching the site_id."""
        if querying_site_id == site_id_a:
            return rows_a
        return []  # site_id_b must never see site_id_a's rows
    # (Fix: dropped the `captured_params` dict the original populated but
    # never asserted on.)
    def make_mock_conn():
        mock_cursor = MagicMock()
        def execute_side_effect(sql, params=None):
            if params:
                # Resolve the result set from the bound site_id parameter.
                mock_cursor.fetchall.return_value = conn_for_site(params[0])
        mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
        mock_cursor.__exit__ = MagicMock(return_value=False)
        mock_conn = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        return mock_conn
    mock_get_conn.return_value = make_mock_conn()
    # Query the queue as site_id_b.
    app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
    try:
        client = TestClient(app)
        resp = client.get("/api/execution/queue")
        assert resp.status_code == 200
        data = resp.json()
        # site_id_b must not see any of site_id_a's queue entries.
        assert len(data) == 0, (
            f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的队列数据,"
            f"但返回了 {len(data)} 条记录"
        )
        # Belt and braces: even if rows came back, none may belong to site_id_a.
        for item in data:
            assert item.get("site_id") != site_id_a, (
                f"结果中不应包含 site_id_a={site_id_a} 的数据"
            )
    finally:
        app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Property 20.2: 调度隔离
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100, deadline=None)  # deadline=None for consistency with 20.1/20.3: real TestClient calls can exceed hypothesis' default deadline
@given(
    site_id_a=_site_id_st,
    site_id_b=_site_id_st,
    schedule_count=st.integers(min_value=1, max_value=5),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_isolation(mock_get_conn, site_id_a, site_id_b, schedule_count):
    """Property 20.2: schedule isolation.

    After creating schedules for site_id_a, querying the schedule list as
    site_id_b must return nothing — schedule data is invisible across sites.
    """
    assume(site_id_a != site_id_b)
    # Rows owned by site_id_a.
    rows_a = _make_schedule_rows(site_id_a, schedule_count)
    def make_mock_conn():
        mock_cursor = MagicMock()
        def execute_side_effect(sql, params=None):
            if params:
                querying_site_id = params[0]
                # Only rows matching the bound site_id are returned.
                if querying_site_id == site_id_a:
                    mock_cursor.fetchall.return_value = rows_a
                else:
                    mock_cursor.fetchall.return_value = []
        mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
        mock_cursor.__exit__ = MagicMock(return_value=False)
        mock_conn = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        return mock_conn
    mock_get_conn.return_value = make_mock_conn()
    # Query the schedule list as site_id_b.
    app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
    try:
        client = TestClient(app)
        resp = client.get("/api/schedules")
        assert resp.status_code == 200
        data = resp.json()
        # site_id_b must not see any of site_id_a's schedules.
        assert len(data) == 0, (
            f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的调度数据,"
            f"但返回了 {len(data)} 条记录"
        )
        # Belt and braces: even if rows came back, none may belong to site_id_a.
        for item in data:
            assert item.get("site_id") != site_id_a, (
                f"结果中不应包含 site_id_a={site_id_a} 的调度数据"
            )
    finally:
        app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Property 20.3: execution-history isolation across sites
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100, deadline=None)
@given(
    site_id_a=_site_id_st,
    site_id_b=_site_id_st,
    history_count=st.integers(min_value=1, max_value=10),
)
@patch("app.routers.execution.get_connection")
def test_execution_history_isolation(mock_get_conn, site_id_a, site_id_b, history_count):
    """Property 20.3: execution-history isolation.

    site_id_a owns some execution-history rows; querying the history while
    authenticated as site_id_b must return nothing -- history is invisible
    across different sites.
    """
    assume(site_id_a != site_id_b)
    # History rows owned by site_id_a only.
    owned_rows = _make_history_rows(site_id_a, history_count)

    def _build_conn():
        cursor = MagicMock()

        def _on_execute(sql, params=None):
            if params:
                # Return rows only when the queried site_id matches the owner.
                cursor.fetchall.return_value = (
                    owned_rows if params[0] == site_id_a else []
                )

        cursor.execute = MagicMock(side_effect=_on_execute)
        cursor.__enter__ = MagicMock(return_value=cursor)
        cursor.__exit__ = MagicMock(return_value=False)
        connection = MagicMock()
        connection.cursor.return_value = cursor
        return connection

    mock_get_conn.return_value = _build_conn()
    # Query the execution history as site_id_b.
    app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
    try:
        client = TestClient(app)
        resp = client.get("/api/execution/history")
        assert resp.status_code == 200
        data = resp.json()
        # site_id_b must not see any history belonging to site_id_a.
        assert len(data) == 0, (
            f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的执行历史,"
            f"但返回了 {len(data)} 条记录"
        )
        # Defensive check: even if rows came back, none may belong to site_id_a.
        for item in data:
            assert item.get("site_id") != site_id_a, (
                f"结果中不应包含 site_id_a={site_id_a} 的执行历史"
            )
    finally:
        app.dependency_overrides.clear()

View File

@@ -0,0 +1,275 @@
# -*- coding: utf-8 -*-
"""TaskConfig 属性测试Property-Based Testing
使用 hypothesis 验证 TaskConfig 相关的通用正确性属性:
- Property 1: TaskConfig 序列化往返一致性
- Property 6: 时间窗口验证
- Property 7: TaskConfig 到 CLI 命令转换完整性
"""
import datetime
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from pydantic import ValidationError
from app.schemas.tasks import TaskConfigSchema
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
from app.services.task_registry import ALL_TASKS
# ---------------------------------------------------------------------------
# Strategies
# ---------------------------------------------------------------------------
# Sample task codes from the real task registry so generated configs always
# reference tasks that actually exist.
_task_codes = [t.code for t in ALL_TASKS]
# 1-5 distinct task codes per config.
_tasks_st = st.lists(
    st.sampled_from(_task_codes),
    min_size=1,
    max_size=5,
    unique=True,
)
# Valid pipeline / processing-mode values taken straight from CLIBuilder.
_pipeline_st = st.sampled_from(sorted(VALID_FLOWS))
_processing_mode_st = st.sampled_from(sorted(VALID_PROCESSING_MODES))
_window_mode_st = st.sampled_from(["lookback", "custom"])
# Date strategy: YYYY-MM-DD formatted strings.
_date_st = st.dates(
    min_value=datetime.date(2020, 1, 1),
    max_value=datetime.date(2030, 12, 31),
).map(lambda d: d.isoformat())
_window_split_st = st.sampled_from([None, "none", "day"])
_window_split_days_st = st.one_of(st.none(), st.sampled_from([1, 10, 30]))
_lookback_hours_st = st.integers(min_value=1, max_value=720)
_overlap_seconds_st = st.integers(min_value=0, max_value=7200)
# store_id is optional; when present it is a positive 32-bit integer.
_store_id_st = st.one_of(st.none(), st.integers(min_value=1, max_value=2**31 - 1))
# Sampled DWD table names.
_dwd_table_names = [
    "dwd.dim_site",
    "dwd.dim_member",
    "dwd.dwd_settlement_head",
]
_dwd_only_tables_st = st.one_of(
    st.none(),
    st.lists(st.sampled_from(_dwd_table_names), min_size=1, max_size=3, unique=True),
)
def _valid_task_config_st():
    """Composite strategy yielding valid ``TaskConfigSchema`` instances.

    When ``window_mode == "custom"`` the generated ``window_end`` is
    guaranteed to be >= ``window_start`` so Pydantic validation never
    rejects a draw.
    """
    @st.composite
    def _build(draw):
        tasks = draw(_tasks_st)
        pipeline = draw(_pipeline_st)
        processing_mode = draw(_processing_mode_st)
        dry_run = draw(st.booleans())
        window_mode = draw(_window_mode_st)
        store_id = draw(_store_id_st)
        dwd_only_tables = draw(_dwd_only_tables_st)
        window_split = draw(_window_split_st)
        window_split_days = draw(_window_split_days_st)
        fetch_before_verify = draw(st.booleans())
        skip_ods = draw(st.booleans())
        ods_local = draw(st.booleans())
        if window_mode == "custom":
            date_range = st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            )
            first = draw(date_range)
            second = draw(date_range)
            # Order the pair so end >= start.
            earlier, later = sorted((first, second))
            window_start = earlier.isoformat()
            window_end = later.isoformat()
            lookback_hours = 24
            overlap_seconds = 600
        else:
            window_start = None
            window_end = None
            lookback_hours = draw(_lookback_hours_st)
            overlap_seconds = draw(_overlap_seconds_st)
        return TaskConfigSchema(
            tasks=tasks,
            pipeline=pipeline,
            processing_mode=processing_mode,
            dry_run=dry_run,
            window_mode=window_mode,
            window_start=window_start,
            window_end=window_end,
            window_split=window_split,
            window_split_days=window_split_days,
            lookback_hours=lookback_hours,
            overlap_seconds=overlap_seconds,
            fetch_before_verify=fetch_before_verify,
            skip_ods_when_fetch_before_verify=skip_ods,
            ods_use_local_json=ods_local,
            store_id=store_id,
            dwd_only_tables=dwd_only_tables,
            extra_args={},
        )
    return _build()
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 1: TaskConfig serialization round-trip
# **Validates: Requirements 11.1, 11.2, 11.3**
# ---------------------------------------------------------------------------
@settings(max_examples=200)
@given(config=_valid_task_config_st())
def test_task_config_round_trip(config: TaskConfigSchema):
    """Property 1: dump to JSON and parse back -> object equal to the original."""
    serialized = config.model_dump_json()
    restored = TaskConfigSchema.model_validate_json(serialized)
    assert restored == config, (
        f"往返不一致:\n原始={config}\n还原={restored}"
    )
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 6: time-window validation
# **Validates: Requirements 2.3**
# ---------------------------------------------------------------------------
@settings(max_examples=200)
@given(
    d1=st.dates(
        min_value=datetime.date(2020, 1, 1),
        max_value=datetime.date(2030, 12, 31),
    ),
    d2=st.dates(
        min_value=datetime.date(2020, 1, 1),
        max_value=datetime.date(2030, 12, 31),
    ),
)
def test_time_window_validation(d1: datetime.date, d2: datetime.date):
    """Property 6: validation fails iff window_end < window_start."""
    start_str = d1.isoformat()
    end_str = d2.isoformat()
    # ISO-8601 dates compare the same as date objects, so compare dates directly.
    if d2 < d1:
        # window_end earlier than window_start -> validation must fail.
        try:
            TaskConfigSchema(
                tasks=["ODS_MEMBER"],
                window_mode="custom",
                window_start=start_str,
                window_end=end_str,
            )
        except ValidationError:
            return  # expected
        raise AssertionError(
            f"期望 ValidationError但验证通过了start={start_str}, end={end_str}"
        )
    # window_end >= window_start -> validation must succeed.
    config = TaskConfigSchema(
        tasks=["ODS_MEMBER"],
        window_mode="custom",
        window_start=start_str,
        window_end=end_str,
    )
    assert config.window_start == start_str
    assert config.window_end == end_str
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 7: TaskConfig -> CLI command completeness
# **Validates: Requirements 2.5, 2.6**
# ---------------------------------------------------------------------------
# Shared builder instance and a fake ETL project path for the CLI-conversion
# property test; build_command only embeds the path, so it need not exist.
_builder = CLIBuilder()
_ETL_PATH = "/fake/etl/project"
@settings(max_examples=200)
@given(config=_valid_task_config_st())
def test_task_config_to_cli_completeness(config: TaskConfigSchema):
    """Property 7: the command built by CLIBuilder contains a CLI argument
    for every non-empty field of the TaskConfig."""
    cmd = _builder.build_command(config, _ETL_PATH)

    def value_after(flag: str) -> str:
        """Return the token immediately following *flag* in the command."""
        return cmd[cmd.index(flag) + 1]

    # 1) --pipeline is always present with the configured value.
    assert "--pipeline" in cmd
    assert value_after("--pipeline") == config.pipeline
    # 2) --processing-mode is always present with the configured value.
    assert "--processing-mode" in cmd
    assert value_after("--processing-mode") == config.processing_mode
    # 3) A non-empty task list produces --tasks (comma-joined, order-free).
    if config.tasks:
        assert "--tasks" in cmd
        assert set(value_after("--tasks").split(",")) == set(config.tasks)
    # 4) Time-window arguments.
    if config.window_mode == "lookback":
        # lookback mode -> --lookback-hours and --overlap-seconds.
        if config.lookback_hours is not None:
            assert "--lookback-hours" in cmd
            assert value_after("--lookback-hours") == str(config.lookback_hours)
        if config.overlap_seconds is not None:
            assert "--overlap-seconds" in cmd
            assert value_after("--overlap-seconds") == str(config.overlap_seconds)
        # lookback mode must not emit custom-window arguments.
        assert "--window-start" not in cmd
        assert "--window-end" not in cmd
    else:
        # custom mode -> --window-start / --window-end.
        if config.window_start:
            assert "--window-start" in cmd
        if config.window_end:
            assert "--window-end" in cmd
        # custom mode must not emit lookback arguments.
        assert "--lookback-hours" not in cmd
        assert "--overlap-seconds" not in cmd
    # 5) dry_run -> --dry-run.
    if config.dry_run:
        assert "--dry-run" in cmd
    else:
        assert "--dry-run" not in cmd
    # 6) store_id -> --store-id.
    if config.store_id is not None:
        assert "--store-id" in cmd
        assert value_after("--store-id") == str(config.store_id)
    else:
        assert "--store-id" not in cmd
    # 7) fetch_before_verify is emitted only in verify_only mode.
    if config.fetch_before_verify and config.processing_mode == "verify_only":
        assert "--fetch-before-verify" in cmd
    else:
        assert "--fetch-before-verify" not in cmd
    # 8) window_split (neither None nor "none") -> --window-split.
    if config.window_split and config.window_split != "none":
        assert "--window-split" in cmd
        assert value_after("--window-split") == config.window_split
        if config.window_split_days is not None:
            assert "--window-split-days" in cmd
    else:
        assert "--window-split" not in cmd

View File

@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
"""TaskExecutor 单元测试
覆盖子进程启动、stdout/stderr 读取、日志广播、取消、数据库记录。
使用 asyncio 测试mock 子进程和数据库连接避免外部依赖。
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.task_executor import TaskExecutor
@pytest.fixture
def executor() -> TaskExecutor:
    """Fresh TaskExecutor instance for each test."""
    return TaskExecutor()
@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Minimal valid TaskConfigSchema used by the execute() tests."""
    return TaskConfigSchema(
        tasks=["ODS_MEMBER", "ODS_PAYMENT"],
        pipeline="api_ods_dwd",
        store_id=42,
    )
def _make_stream(lines: list[bytes]) -> AsyncMock:
    """Build a fake asyncio.StreamReader whose readline() yields *lines*.

    Once the given lines are exhausted, readline() returns b"" to signal EOF.
    """
    reader = AsyncMock()
    reader.readline = AsyncMock(side_effect=list(lines) + [b""])
    return reader
# ---------------------------------------------------------------------------
# subscribe / unsubscribe
# ---------------------------------------------------------------------------
class TestSubscription:
    def test_subscribe_returns_queue(self, executor: TaskExecutor):
        channel = executor.subscribe("exec-1")
        assert isinstance(channel, asyncio.Queue)

    def test_subscribe_multiple(self, executor: TaskExecutor):
        first = executor.subscribe("exec-1")
        second = executor.subscribe("exec-1")
        # Each subscription gets its own queue.
        assert first is not second
        assert len(executor._subscribers["exec-1"]) == 2

    def test_unsubscribe_removes_queue(self, executor: TaskExecutor):
        channel = executor.subscribe("exec-1")
        executor.unsubscribe("exec-1", channel)
        # Removing the last subscriber also drops the key entirely.
        assert "exec-1" not in executor._subscribers

    def test_unsubscribe_nonexistent_is_safe(self, executor: TaskExecutor):
        """Unsubscribing from an unknown execution_id must not raise."""
        orphan: asyncio.Queue = asyncio.Queue()
        executor.unsubscribe("nonexistent", orphan)
# ---------------------------------------------------------------------------
# broadcasting
# ---------------------------------------------------------------------------
class TestBroadcast:
    def test_broadcast_to_subscribers(self, executor: TaskExecutor):
        first = executor.subscribe("exec-1")
        second = executor.subscribe("exec-1")
        executor._broadcast("exec-1", "hello")
        # Every subscriber queue receives the message.
        assert first.get_nowait() == "hello"
        assert second.get_nowait() == "hello"

    def test_broadcast_no_subscribers_is_safe(self, executor: TaskExecutor):
        """Broadcasting with no subscribers must not raise."""
        executor._broadcast("nonexistent", "hello")

    def test_broadcast_end_sends_none(self, executor: TaskExecutor):
        channel = executor.subscribe("exec-1")
        executor._broadcast_end("exec-1")
        # None is the end-of-stream sentinel.
        assert channel.get_nowait() is None
# ---------------------------------------------------------------------------
# log buffer
# ---------------------------------------------------------------------------
class TestLogBuffer:
    def test_get_logs_empty(self, executor: TaskExecutor):
        # Unknown execution ids yield an empty list, not an error.
        assert executor.get_logs("nonexistent") == []

    def test_get_logs_returns_copy(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = ["line1", "line2"]
        snapshot = executor.get_logs("exec-1")
        assert snapshot == ["line1", "line2"]
        # Mutating the returned copy must not touch the internal buffer.
        snapshot.append("line3")
        assert len(executor._log_buffers["exec-1"]) == 2
# ---------------------------------------------------------------------------
# execution state queries
# ---------------------------------------------------------------------------
class TestRunningState:
    def test_is_running_false_when_no_process(self, executor: TaskExecutor):
        assert executor.is_running("nonexistent") is False

    def test_is_running_true_when_process_active(self, executor: TaskExecutor):
        live = MagicMock()
        live.returncode = None  # no returncode yet -> still running
        executor._processes["exec-1"] = live
        assert executor.is_running("exec-1") is True

    def test_is_running_false_when_process_exited(self, executor: TaskExecutor):
        done = MagicMock()
        done.returncode = 0  # process already exited
        executor._processes["exec-1"] = done
        assert executor.is_running("exec-1") is False

    def test_get_running_ids(self, executor: TaskExecutor):
        live = MagicMock()
        live.returncode = None
        done = MagicMock()
        done.returncode = 0
        executor._processes["a"] = live
        executor._processes["b"] = done
        # Only processes without a returncode count as running.
        assert executor.get_running_ids() == ["a"]
# ---------------------------------------------------------------------------
# _read_stream
# ---------------------------------------------------------------------------
class TestReadStream:
    @pytest.mark.asyncio
    async def test_read_stdout_lines(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        reader = _make_stream([b"line1\n", b"line2\n"])
        captured: list[str] = []
        await executor._read_stream("exec-1", reader, "stdout", captured)
        assert captured == ["line1", "line2"]
        # The buffer stores lines prefixed with the stream name.
        assert executor._log_buffers["exec-1"] == [
            "[stdout] line1",
            "[stdout] line2",
        ]

    @pytest.mark.asyncio
    async def test_read_stderr_lines(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        reader = _make_stream([b"err1\n"])
        captured: list[str] = []
        await executor._read_stream("exec-1", reader, "stderr", captured)
        assert captured == ["err1"]
        assert executor._log_buffers["exec-1"] == ["[stderr] err1"]

    @pytest.mark.asyncio
    async def test_read_stream_none_is_safe(self, executor: TaskExecutor):
        """A None stream must be a no-op rather than an error."""
        captured: list[str] = []
        await executor._read_stream("exec-1", None, "stdout", captured)
        assert captured == []

    @pytest.mark.asyncio
    async def test_broadcast_during_read(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        channel = executor.subscribe("exec-1")
        reader = _make_stream([b"hello\n"])
        captured: list[str] = []
        await executor._read_stream("exec-1", reader, "stdout", captured)
        # Subscribers see each line as it is read.
        assert channel.get_nowait() == "[stdout] hello"
# ---------------------------------------------------------------------------
# execute (integration-level: subprocess and database access are mocked)
# ---------------------------------------------------------------------------
class TestExecute:
    """End-to-end tests for TaskExecutor.execute with mocked subprocess/DB.

    The @patch decorators are applied bottom-up, so the injected mocks arrive
    in the order mock_create, mock_write_log, mock_update_log.
    """
    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_successful_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        """Exit code 0 -> status 'success' with stdout captured in output_log."""
        # Fake subprocess whose wait() flips returncode to 0 on completion.
        # (The original also assigned proc.wait = AsyncMock(return_value=0)
        # here, which was dead code immediately overwritten by _wait below.)
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([b"processing...\n", b"done\n"])
        proc.stderr = _make_stream([])
        async def _wait():
            proc.returncode = 0
            return 0
        proc.wait = _wait
        mock_create.return_value = proc
        await executor.execute(sample_config, "exec-1", site_id=42)
        # A 'running' record is written first...
        mock_write_log.assert_called_once()
        call_kwargs = mock_write_log.call_args[1]
        assert call_kwargs["status"] == "running"
        assert call_kwargs["execution_id"] == "exec-1"
        # ...then updated to 'success' when the process exits.
        mock_update_log.assert_called_once()
        update_kwargs = mock_update_log.call_args[1]
        assert update_kwargs["status"] == "success"
        assert update_kwargs["exit_code"] == 0
        assert "processing..." in update_kwargs["output_log"]
        assert "done" in update_kwargs["output_log"]
        # The finished process is removed from the tracking table.
        assert "exec-1" not in executor._processes
    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_failed_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        """Non-zero exit code -> status 'failed' with stderr in error_log."""
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([])
        proc.stderr = _make_stream([b"error occurred\n"])
        async def _wait():
            proc.returncode = 1
            return 1
        proc.wait = _wait
        mock_create.return_value = proc
        await executor.execute(sample_config, "exec-2", site_id=42)
        update_kwargs = mock_update_log.call_args[1]
        assert update_kwargs["status"] == "failed"
        assert update_kwargs["exit_code"] == 1
        assert "error occurred" in update_kwargs["error_log"]
    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_exception_during_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        """A subprocess-creation failure must be recorded as status 'failed'."""
        mock_create.side_effect = OSError("command not found")
        await executor.execute(sample_config, "exec-3", site_id=42)
        update_kwargs = mock_update_log.call_args[1]
        assert update_kwargs["status"] == "failed"
    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_subscribers_notified_on_completion(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        """Subscribers receive every log line plus the None end sentinel."""
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([b"line\n"])
        proc.stderr = _make_stream([])
        async def _wait():
            proc.returncode = 0
            return 0
        proc.wait = _wait
        mock_create.return_value = proc
        q = executor.subscribe("exec-4")
        await executor.execute(sample_config, "exec-4", site_id=42)
        # Drain the queue: log line(s) followed by the None sentinel.
        messages = []
        while not q.empty():
            messages.append(q.get_nowait())
        assert "[stdout] line" in messages
        assert None in messages
    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_duration_ms_recorded(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        """The final log update carries a non-negative integer duration_ms."""
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream([])
        proc.stderr = _make_stream([])
        async def _wait():
            proc.returncode = 0
            return 0
        proc.wait = _wait
        mock_create.return_value = proc
        await executor.execute(sample_config, "exec-5", site_id=42)
        update_kwargs = mock_update_log.call_args[1]
        assert isinstance(update_kwargs["duration_ms"], int)
        assert update_kwargs["duration_ms"] >= 0
# ---------------------------------------------------------------------------
# cancel
# ---------------------------------------------------------------------------
class TestCancel:
    @pytest.mark.asyncio
    async def test_cancel_running_process(self, executor: TaskExecutor):
        live = MagicMock()
        live.returncode = None
        live.terminate = MagicMock()
        executor._processes["exec-1"] = live
        assert await executor.cancel("exec-1") is True
        # Cancellation terminates the underlying process exactly once.
        live.terminate.assert_called_once()

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_returns_false(self, executor: TaskExecutor):
        assert await executor.cancel("nonexistent") is False

    @pytest.mark.asyncio
    async def test_cancel_already_exited_returns_false(self, executor: TaskExecutor):
        done = MagicMock()
        done.returncode = 0
        executor._processes["exec-1"] = done
        assert await executor.cancel("exec-1") is False

    @pytest.mark.asyncio
    async def test_cancel_process_lookup_error(self, executor: TaskExecutor):
        """terminate() raising ProcessLookupError (process gone) -> False."""
        ghost = MagicMock()
        ghost.returncode = None
        ghost.terminate = MagicMock(side_effect=ProcessLookupError)
        executor._processes["exec-1"] = ghost
        assert await executor.cancel("exec-1") is False
# ---------------------------------------------------------------------------
# cleanup
# ---------------------------------------------------------------------------
class TestCleanup:
    def test_cleanup_removes_buffers_and_subscribers(self, executor: TaskExecutor):
        """cleanup() drops both the log buffer and the subscriber list."""
        executor._log_buffers["exec-1"] = ["line"]
        executor.subscribe("exec-1")
        executor.cleanup("exec-1")
        assert "exec-1" not in executor._log_buffers
        assert "exec-1" not in executor._subscribers
    def test_cleanup_nonexistent_is_safe(self, executor: TaskExecutor):
        """Cleaning up an unknown execution_id must not raise."""
        executor.cleanup("nonexistent")

View File

@@ -0,0 +1,482 @@
# -*- coding: utf-8 -*-
"""TaskQueue 单元测试
覆盖enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
使用 mock 数据库操作,专注于业务逻辑验证。
"""
import asyncio
import json
import uuid
from unittest.mock import MagicMock, AsyncMock, patch, call
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.task_queue import TaskQueue, QueuedTask
@pytest.fixture
def queue() -> TaskQueue:
    """Fresh TaskQueue instance for each test."""
    return TaskQueue()
@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Minimal valid TaskConfigSchema used by the enqueue tests."""
    return TaskConfigSchema(
        tasks=["ODS_MEMBER", "ODS_PAYMENT"],
        pipeline="api_ods_dwd",
        store_id=42,
    )
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
    """Build a mock DB cursor that supports the context-manager protocol."""
    fake = MagicMock()
    fake.fetchone.return_value = fetchone_val
    # Falsy fetchall_val (None or []) coerces to an empty list.
    fake.fetchall.return_value = fetchall_val or []
    fake.rowcount = rowcount
    fake.__enter__ = MagicMock(return_value=fake)
    fake.__exit__ = MagicMock(return_value=False)
    return fake
def _mock_conn(cursor):
    """Build a mock connection whose cursor() yields *cursor* as a context manager."""
    fake = MagicMock()
    fake.cursor.return_value = cursor
    return fake
# ---------------------------------------------------------------------------
# enqueue
# ---------------------------------------------------------------------------
class TestEnqueue:
    @patch("app.services.task_queue.get_connection")
    def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
        cursor = _mock_cursor(fetchone_val=(0,))
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection
        task_id = queue.enqueue(sample_config, site_id=42)
        # The returned id must parse as a UUID.
        uuid.UUID(task_id)
        connection.commit.assert_called_once()
        connection.close.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
        """A new task gets position = current max position + 1."""
        cursor = _mock_cursor(fetchone_val=(5,))
        mock_get_conn.return_value = _mock_conn(cursor)
        queue.enqueue(sample_config, site_id=42)
        # INSERT parameters are (task_id, site_id, config_json, new_pos).
        insert_params = cursor.execute.call_args_list[1][0][1]
        assert insert_params[3] == 6  # 5 + 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
        """On an empty queue the first task gets position 1."""
        cursor = _mock_cursor(fetchone_val=(0,))
        mock_get_conn.return_value = _mock_conn(cursor)
        queue.enqueue(sample_config, site_id=42)
        insert_params = cursor.execute.call_args_list[1][0][1]
        assert insert_params[3] == 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
        """The config is persisted as a JSON string."""
        cursor = _mock_cursor(fetchone_val=(0,))
        mock_get_conn.return_value = _mock_conn(cursor)
        queue.enqueue(sample_config, site_id=42)
        raw_config = cursor.execute.call_args_list[1][0][1][2]
        parsed = json.loads(raw_config)
        assert parsed["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
        assert parsed["pipeline"] == "api_ods_dwd"
# ---------------------------------------------------------------------------
# dequeue
# ---------------------------------------------------------------------------
class TestDequeue:
    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
        cursor = _mock_cursor(fetchone_val=None)
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection
        assert queue.dequeue(site_id=42) is None
        connection.commit.assert_called()

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_task(self, mock_get_conn, queue):
        task_id = str(uuid.uuid4())
        payload = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
        db_row = (
            task_id, 42, json.dumps(payload), "pending", 1,
            None, None, None, None, None,
        )
        cursor = _mock_cursor(fetchone_val=db_row)
        mock_get_conn.return_value = _mock_conn(cursor)
        task = queue.dequeue(site_id=42)
        assert task is not None
        assert task.id == task_id
        assert task.site_id == 42
        # dequeue flips the task's status to running.
        assert task.status == "running"
        assert task.config["tasks"] == ["ODS_MEMBER"]

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
        task_id = str(uuid.uuid4())
        payload = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
        db_row = (
            task_id, 42, json.dumps(payload), "pending", 1,
            None, None, None, None, None,
        )
        cursor = _mock_cursor(fetchone_val=db_row)
        mock_get_conn.return_value = _mock_conn(cursor)
        queue.dequeue(site_id=42)
        # The second execute() must be the UPDATE ... status = 'running'.
        update_sql, update_params = cursor.execute.call_args_list[1][0]
        assert "running" in update_sql
        assert task_id in update_params
# ---------------------------------------------------------------------------
# reorder
# ---------------------------------------------------------------------------
class TestReorder:
    @patch("app.services.task_queue.get_connection")
    def test_reorder_moves_task(self, mock_get_conn, queue):
        """Move the 3rd task to position 1."""
        ids = [str(uuid.uuid4()) for _ in range(3)]
        cursor = _mock_cursor(fetchall_val=[(tid,) for tid in ids])
        mock_get_conn.return_value = _mock_conn(cursor)
        queue.reorder(ids[2], new_position=1, site_id=42)
        # Expected order afterwards: [ids[2], ids[0], ids[1]].
        final = {}
        for update in cursor.execute.call_args_list[1:]:  # skip the SELECT
            pos, tid = update[0][1]
            final[tid] = pos
        assert final[ids[2]] == 1
        assert final[ids[0]] == 2
        assert final[ids[1]] == 3

    @patch("app.services.task_queue.get_connection")
    def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
        """Reordering an unknown task id must not raise and changes nothing."""
        cursor = _mock_cursor(fetchall_val=[(str(uuid.uuid4()),)])
        mock_get_conn.return_value = _mock_conn(cursor)
        queue.reorder("nonexistent-id", new_position=1, site_id=42)
        # Only the SELECT ran; no UPDATE statements were issued.
        assert cursor.execute.call_count == 1

    @patch("app.services.task_queue.get_connection")
    def test_reorder_clamps_position(self, mock_get_conn, queue):
        """Out-of-range positions are clamped into the valid range."""
        ids = [str(uuid.uuid4()) for _ in range(2)]
        cursor = _mock_cursor(fetchall_val=[(tid,) for tid in ids])
        mock_get_conn.return_value = _mock_conn(cursor)
        # new_position=100 is out of range and should clamp to the tail.
        queue.reorder(ids[0], new_position=100, site_id=42)
        final = {}
        for update in cursor.execute.call_args_list[1:]:
            pos, tid = update[0][1]
            final[tid] = pos
        # ids[0] ends up last.
        assert final[ids[1]] == 1
        assert final[ids[0]] == 2
# ---------------------------------------------------------------------------
# delete
# ---------------------------------------------------------------------------
class TestDelete:
    @patch("app.services.task_queue.get_connection")
    def test_delete_pending_task(self, mock_get_conn, queue):
        cursor = _mock_cursor(rowcount=1)
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection
        assert queue.delete("task-1", site_id=42) is True
        connection.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
        cursor = _mock_cursor(rowcount=0)
        mock_get_conn.return_value = _mock_conn(cursor)
        assert queue.delete("nonexistent", site_id=42) is False

    @patch("app.services.task_queue.get_connection")
    def test_delete_only_affects_pending(self, mock_get_conn, queue):
        """The DELETE statement must be restricted to status = 'pending'."""
        cursor = _mock_cursor(rowcount=0)
        mock_get_conn.return_value = _mock_conn(cursor)
        queue.delete("task-1", site_id=42)
        executed_sql = cursor.execute.call_args[0][0]
        assert "pending" in executed_sql
# ---------------------------------------------------------------------------
# list_pending / has_running
# ---------------------------------------------------------------------------
class TestQuery:
    @patch("app.services.task_queue.get_connection")
    def test_list_pending_empty(self, mock_get_conn, queue):
        cursor = _mock_cursor(fetchall_val=[])
        mock_get_conn.return_value = _mock_conn(cursor)
        assert queue.list_pending(site_id=42) == []

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_returns_tasks(self, mock_get_conn, queue):
        task_id = str(uuid.uuid4())
        raw_config = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
        db_rows = [(task_id, 42, raw_config, "pending", 1, None, None, None, None, None)]
        cursor = _mock_cursor(fetchall_val=db_rows)
        mock_get_conn.return_value = _mock_conn(cursor)
        pending = queue.list_pending(site_id=42)
        assert len(pending) == 1
        assert pending[0].id == task_id

    @patch("app.services.task_queue.get_connection")
    def test_has_running_true(self, mock_get_conn, queue):
        cursor = _mock_cursor(fetchone_val=(True,))
        mock_get_conn.return_value = _mock_conn(cursor)
        assert queue.has_running(site_id=42) is True

    @patch("app.services.task_queue.get_connection")
    def test_has_running_false(self, mock_get_conn, queue):
        cursor = _mock_cursor(fetchone_val=(False,))
        mock_get_conn.return_value = _mock_conn(cursor)
        assert queue.has_running(site_id=42) is False
# ---------------------------------------------------------------------------
# process_loop / _process_once
# ---------------------------------------------------------------------------
class TestProcessLoop:
    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_skips_when_running(self, mock_get_conn, queue):
        """No dequeue may happen while a task is already running."""
        # Connection sequence: 1st call serves _get_pending_site_ids -> [42],
        # subsequent calls serve has_running(42) -> True.
        call_count = 0
        def side_effect_conn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # _get_pending_site_ids
                cur = _mock_cursor(fetchall_val=[(42,)])
                return _mock_conn(cur)
            else:
                # has_running
                cur = _mock_cursor(fetchone_val=(True,))
                return _mock_conn(cur)
        mock_get_conn.side_effect = side_effect_conn
        mock_executor = MagicMock()
        await queue._process_once(mock_executor)
        # execute must not have been called
        mock_executor.execute.assert_not_called()
    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
        """With no running task, _process_once dequeues and starts execution."""
        task_id = str(uuid.uuid4())
        # Fully-populated config dict so TaskConfigSchema validation succeeds.
        config_dict = {
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods_dwd",
            "processing_mode": "increment_only",
            "dry_run": False,
            "window_mode": "lookback",
            "lookback_hours": 24,
            "overlap_seconds": 600,
            "fetch_before_verify": False,
            "skip_ods_when_fetch_before_verify": False,
            "ods_use_local_json": False,
            "extra_args": {},
        }
        config_json = json.dumps(config_dict)
        # Connection sequence: pending site ids -> has_running=False -> dequeue row.
        call_count = 0
        def side_effect_conn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # _get_pending_site_ids
                cur = _mock_cursor(fetchall_val=[(42,)])
                return _mock_conn(cur)
            elif call_count == 2:
                # has_running -> False
                cur = _mock_cursor(fetchone_val=(False,))
                return _mock_conn(cur)
            else:
                # dequeue -> returns the task row
                row = (
                    task_id, 42, config_json, "pending", 1,
                    None, None, None, None, None,
                )
                cur = _mock_cursor(fetchone_val=row)
                return _mock_conn(cur)
        mock_get_conn.side_effect = side_effect_conn
        mock_executor = MagicMock()
        mock_executor.execute = AsyncMock()
        await queue._process_once(mock_executor)
        # Give the background task created by _process_once a moment to start.
        await asyncio.sleep(0.1)
    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_no_pending(self, mock_get_conn, queue):
        """With no pending tasks at all, nothing happens."""
        cur = _mock_cursor(fetchall_val=[])
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn
        mock_executor = MagicMock()
        await queue._process_once(mock_executor)
        mock_executor.execute.assert_not_called()
# ---------------------------------------------------------------------------
# 生命周期
# ---------------------------------------------------------------------------
class TestLifecycle:
    """Start/stop lifecycle behaviour of the queue."""

    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self, queue):
        """stop() must clear the running flag even without a loop task."""
        queue._running = True
        queue._loop_task = None
        await queue.stop()
        assert queue._running is False

    def test_start_creates_task(self, queue):
        """start() would spawn an asyncio.Task (needs an event loop);
        here we only check the pristine initial state of the queue."""
        assert queue._loop_task is None
        assert queue._running is False
# ---------------------------------------------------------------------------
# _mark_failed / _update_queue_status_from_log
# ---------------------------------------------------------------------------
class TestInternalHelpers:
    """Direct tests for the queue's private DB helper methods."""

    @patch("app.services.task_queue.get_connection")
    def test_mark_failed(self, mock_get_conn, queue):
        """_mark_failed writes status 'failed' together with the message."""
        cursor = _mock_cursor()
        mock_get_conn.return_value = _mock_conn(cursor)
        queue._mark_failed("queue-1", "测试错误")
        sql, params = cursor.execute.call_args[0]
        assert "failed" in sql
        assert params[0] == "测试错误"
        assert params[1] == "queue-1"

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_from_log(self, mock_get_conn, queue):
        """Status is copied from execution_log to task_queue: SELECT then UPDATE."""
        from datetime import datetime, timezone
        finished_at = datetime.now(timezone.utc)
        # The first fetchone yields the execution_log row.
        cursor = _mock_cursor(fetchone_val=("success", finished_at, 0, None))
        conn = _mock_conn(cursor)
        mock_get_conn.return_value = conn
        queue._update_queue_status_from_log("queue-1")
        # Exactly two statements: the SELECT plus the UPDATE.
        assert cursor.execute.call_count == 2
        conn.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_no_log(self, mock_get_conn, queue):
        """Without an execution_log row, only the SELECT runs — no UPDATE."""
        cursor = _mock_cursor(fetchone_val=None)
        mock_get_conn.return_value = _mock_conn(cursor)
        queue._update_queue_status_from_log("queue-1")
        assert cursor.execute.call_count == 1

View File

@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
"""任务注册表分组属性测试Property-Based Testing
Property 4: 对于 Task_Registry 中的任务集合,分组结果中每个任务应出现在
且仅出现在其所属业务域的分组中。
Validates: Requirements 2.1
测试策略:
1. 直接测试 get_tasks_grouped_by_domain 函数:
- 每个任务出现在且仅出现在其 domain 对应的分组中
- 分组中的任务总数等于全部任务数(不多不少)
- 每个分组的 key 等于该分组内所有任务的 domain
2. 通过 API 端点测试TestClient + mock auth
- 返回的 groups 中每个任务的 domain 与其所在分组 key 一致
- 所有任务都出现在结果中
3. 随机子集验证:
- 随机选取任务子集,验证分组逻辑的一致性
- 随机选取 domain验证该 domain 下的任务都正确
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-registry")
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from app.services.task_registry import (
get_all_tasks,
get_tasks_grouped_by_domain,
TaskDefinition,
)
from fastapi.testclient import TestClient
from app.main import app
from app.auth.dependencies import get_current_user, CurrentUser
# ---------------------------------------------------------------------------
# 辅助
# ---------------------------------------------------------------------------
# Registry snapshot taken once at import time; the hypothesis strategies
# below sample from these module-level constants.
ALL_TASKS = get_all_tasks()
ALL_CODES = [t.code for t in ALL_TASKS]
ALL_DOMAINS = list({t.domain for t in ALL_TASKS})
def _mock_user() -> CurrentUser:
    # FastAPI dependency override: a fixed authenticated user (site 1).
    return CurrentUser(user_id=1, site_id=1)
# ---------------------------------------------------------------------------
# Property 4.1: 分组完整性 — 每个任务出现在且仅出现在其 domain 分组中
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_every_task_in_exactly_its_domain_group(data):
    """Property 4.1: a task appears in exactly the group keyed by its domain.

    Draw one task at random, check it is a member of its own domain's
    group, and that it is absent from every other group.
    """
    grouped = get_tasks_grouped_by_domain()
    task = data.draw(st.sampled_from(ALL_TASKS))
    # Its domain must be one of the group keys.
    assert task.domain in grouped, (
        f"任务 {task.code} 的 domain '{task.domain}' 不在分组 keys 中"
    )
    own_codes = {t.code for t in grouped[task.domain]}
    assert task.code in own_codes, (
        f"任务 {task.code} 未出现在其 domain '{task.domain}' 的分组中"
    )
    # And it must not leak into any other domain's group.
    for other_domain, members in grouped.items():
        if other_domain == task.domain:
            continue
        assert task.code not in {t.code for t in members}, (
            f"任务 {task.code}domain={task.domain})错误地出现在 "
            f"domain '{other_domain}' 的分组中"
        )
# ---------------------------------------------------------------------------
# Property 4.2: 分组总数守恒 — 分组中的任务总数等于全部任务数
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_grouped_total_equals_all_tasks(data):
    """Property 4.2: group sizes sum to the total task count — no more, no less.

    Global conservation is checked every run; a randomly drawn domain is
    additionally checked for its exact per-domain count.
    """
    all_tasks = get_all_tasks()
    grouped = get_tasks_grouped_by_domain()
    # Global conservation.
    grouped_total = sum(map(len, grouped.values()))
    assert grouped_total == len(all_tasks), (
        f"分组总数 {grouped_total} != 全量任务数 {len(all_tasks)}"
    )
    # Spot-check one random domain's size.
    domain = data.draw(st.sampled_from(ALL_DOMAINS))
    expected_count = len([t for t in all_tasks if t.domain == domain])
    actual_count = len(grouped[domain])
    assert actual_count == expected_count, (
        f"domain '{domain}' 分组内任务数 {actual_count} != 预期 {expected_count}"
    )
# ---------------------------------------------------------------------------
# Property 4.3: 分组 key 一致性 — 每个分组的 key 等于组内所有任务的 domain
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_group_key_matches_task_domains(data):
    """Property 4.3: every task inside a group carries the group key as domain.

    Draw a random group and verify each member's ``domain`` field.
    """
    grouped = get_tasks_grouped_by_domain()
    domain = data.draw(st.sampled_from(list(grouped.keys())))
    for member in grouped[domain]:
        assert member.domain == domain, (
            f"分组 '{domain}' 中的任务 {member.code} 的 domain 为 "
            f"'{member.domain}',与分组 key 不一致"
        )
# ---------------------------------------------------------------------------
# Property 4.4: 任务 code 全局唯一 — 分组后不应出现重复 code
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_no_duplicate_codes_across_groups(data):
    """Property 4.4: task codes stay globally unique after grouping.

    All group members are flattened and checked for duplicates; two
    randomly drawn distinct domains are also checked for code overlap.
    """
    grouped = get_tasks_grouped_by_domain()
    # Flatten every group's codes into one list.
    flattened = [t.code for tasks in grouped.values() for t in tasks]
    assert len(flattened) == len(set(flattened)), (
        "分组中存在重复的任务 code"
    )
    # Two distinct random domains must have disjoint code sets.
    if len(ALL_DOMAINS) >= 2:
        domains = data.draw(
            st.lists(st.sampled_from(ALL_DOMAINS), min_size=2, max_size=2, unique=True)
        )
        codes_a = {t.code for t in grouped[domains[0]]}
        codes_b = {t.code for t in grouped[domains[1]]}
        overlap = codes_a & codes_b
        assert not overlap, (
            f"domain '{domains[0]}''{domains[1]}' 存在重叠任务 code: {overlap}"
        )
# ---------------------------------------------------------------------------
# Property 4.5: 随机子集分组一致性 — 子集中的任务分组结果与全量一致
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
    indices=st.lists(
        st.integers(min_value=0, max_value=len(ALL_TASKS) - 1),
        min_size=1,
        max_size=min(20, len(ALL_TASKS)),
        unique=True,
    )
)
def test_subset_grouping_consistency(indices):
    """Property 4.5: any random subset of tasks is grouped correctly.

    For each task in the drawn subset, its domain must be a group key and
    its code must be present in that group.
    """
    grouped = get_tasks_grouped_by_domain()
    for task in (ALL_TASKS[i] for i in indices):
        assert task.domain in grouped
        member_codes = {t.code for t in grouped[task.domain]}
        assert task.code in member_codes, (
            f"任务 {task.code} 未出现在 domain '{task.domain}' 的分组中"
        )
# ---------------------------------------------------------------------------
# Property 4.6: API 端点分组正确性 — GET /api/tasks/registry 返回一致的分组
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_api_registry_grouping_correctness(data):
    """Property 4.6: the API's grouping agrees with the service layer.

    Every task item's ``domain`` must equal its group key, the union of
    returned codes must equal the full registry, and one random group's
    size must match get_tasks_grouped_by_domain().
    """
    app.dependency_overrides[get_current_user] = _mock_user
    try:
        client = TestClient(app)
        resp = client.get("/api/tasks/registry")
        assert resp.status_code == 200
        groups = resp.json()["groups"]
        # Walk every group, checking key consistency and collecting codes.
        api_codes: set[str] = set()
        for group_key, entries in groups.items():
            for entry in entries:
                assert entry["domain"] == group_key, (
                    f"API 返回的任务 {entry['code']}domain={entry['domain']}"
                    f"出现在分组 '{group_key}' 中,不一致"
                )
                api_codes.add(entry["code"])
        # Completeness: the API must cover exactly the full registry.
        all_codes_set = {t.code for t in get_all_tasks()}
        assert api_codes == all_codes_set, (
            f"API 返回的任务集合与全量任务不一致。"
            f"缺失: {all_codes_set - api_codes}"
            f"多余: {api_codes - all_codes_set}"
        )
        # Spot-check one random group's size against the service layer.
        if groups:
            domain = data.draw(st.sampled_from(list(groups.keys())))
            expected = get_tasks_grouped_by_domain()
            assert len(groups[domain]) == len(expected[domain]), (
                f"API 返回的 domain '{domain}' 任务数 {len(groups[domain])} "
                f"!= 服务层 {len(expected[domain])}"
            )
    finally:
        app.dependency_overrides.pop(get_current_user, None)
# ---------------------------------------------------------------------------
# Property 4.7: 随机 domain 过滤验证
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(domain=st.sampled_from(ALL_DOMAINS))
def test_random_domain_tasks_all_correct(domain):
    """Property 4.7: a random domain's group matches the registry filter.

    For the drawn domain:
    - group size equals the number of registry tasks with that domain
    - the code sets coincide
    - every member's domain field equals the drawn domain
    """
    grouped = get_tasks_grouped_by_domain()
    all_tasks = get_all_tasks()
    group_tasks = grouped.get(domain, [])
    expected_tasks = [t for t in all_tasks if t.domain == domain]
    # Same cardinality.
    assert len(group_tasks) == len(expected_tasks), (
        f"domain '{domain}': 分组中 {len(group_tasks)} 个任务,"
        f"预期 {len(expected_tasks)}"
    )
    # Same code sets.
    group_codes = {t.code for t in group_tasks}
    expected_codes = {t.code for t in expected_tasks}
    assert group_codes == expected_codes, (
        f"domain '{domain}': 分组 codes {group_codes} != 预期 {expected_codes}"
    )
    # Every member claims the right domain.
    assert all(t.domain == domain for t in group_tasks)

View File

@@ -0,0 +1,274 @@
# -*- coding: utf-8 -*-
"""任务注册表 API 单元测试
覆盖 4 个端点registry / dwd-tables / flows / validate
通过 JWT mock 绕过认证依赖。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_registry import (
ALL_TASKS,
DWD_TABLES,
FLOW_LAYER_MAP,
get_tasks_grouped_by_domain,
)
# 固定测试用户
# Fixed authenticated user injected into every request in this module.
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
    # Dependency override that bypasses real JWT validation.
    return _TEST_USER
# NOTE(review): installing the override at import time leaks it into other
# test modules until removed; TestTaskRegistry.setup_method re-installs it
# to cope with the reverse leak. Consider a fixture instead — confirm.
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
# ---------------------------------------------------------------------------
# GET /api/tasks/registry
# ---------------------------------------------------------------------------
class TestTaskRegistry:
    """GET /api/tasks/registry"""

    def setup_method(self):
        """Re-install the auth override before each test; other test files
        may clear/pop app.dependency_overrides and leak that state here."""
        app.dependency_overrides[get_current_user] = _override_auth

    def test_registry_returns_grouped_tasks(self):
        """Every registered task lands in exactly the group of its domain."""
        resp = client.get("/api/tasks/registry")
        assert resp.status_code == 200
        data = resp.json()
        assert "groups" in data
        seen_codes = set()
        for domain, tasks in data["groups"].items():
            for entry in tasks:
                assert entry["domain"] == domain
                seen_codes.add(entry["code"])
        assert seen_codes == {t.code for t in ALL_TASKS}

    def test_registry_task_fields_complete(self):
        """Each task item exposes the full field set."""
        data = client.get("/api/tasks/registry").json()
        required = {"code", "name", "description", "domain", "layer",
                    "requires_window", "is_ods", "is_dimension", "default_enabled"}
        for tasks in data["groups"].values():
            for entry in tasks:
                assert required.issubset(entry.keys())

    def test_registry_requires_auth(self):
        """Without the auth override the endpoint must answer 401."""
        app.dependency_overrides.pop(get_current_user, None)
        try:
            assert client.get("/api/tasks/registry").status_code == 401
        finally:
            app.dependency_overrides[get_current_user] = _override_auth
# ---------------------------------------------------------------------------
# GET /api/tasks/dwd-tables
# ---------------------------------------------------------------------------
class TestDwdTables:
    """GET /api/tasks/dwd-tables"""

    def test_dwd_tables_returns_grouped(self):
        """Every DWD table appears under the group of its own domain."""
        resp = client.get("/api/tasks/dwd-tables")
        assert resp.status_code == 200
        data = resp.json()
        assert "groups" in data
        seen_tables = set()
        for domain, tables in data["groups"].items():
            for entry in tables:
                assert entry["domain"] == domain
                seen_tables.add(entry["table_name"])
        assert seen_tables == {t.table_name for t in DWD_TABLES}

    def test_dwd_tables_fields_complete(self):
        """Each table item exposes the full field set."""
        data = client.get("/api/tasks/dwd-tables").json()
        required = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
        for tables in data["groups"].values():
            for entry in tables:
                assert required.issubset(entry.keys())
# ---------------------------------------------------------------------------
# GET /api/tasks/flows
# ---------------------------------------------------------------------------
class TestFlows:
    """GET /api/tasks/flows"""

    def test_flows_returns_seven_flows(self):
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        assert len(resp.json()["flows"]) == 7

    def test_flows_returns_three_processing_modes(self):
        data = client.get("/api/tasks/flows").json()
        assert len(data["processing_modes"]) == 3

    def test_flow_ids_match_registry(self):
        """Flow ids must mirror FLOW_LAYER_MAP exactly."""
        data = client.get("/api/tasks/flows").json()
        assert {f["id"] for f in data["flows"]} == set(FLOW_LAYER_MAP.keys())

    def test_flow_layers_non_empty(self):
        data = client.get("/api/tasks/flows").json()
        for flow in data["flows"]:
            assert len(flow["layers"]) > 0

    def test_processing_mode_ids(self):
        data = client.get("/api/tasks/flows").json()
        mode_ids = {m["id"] for m in data["processing_modes"]}
        assert mode_ids == {"increment_only", "verify_only", "increment_verify"}
# ---------------------------------------------------------------------------
# POST /api/tasks/validate
# ---------------------------------------------------------------------------
class TestValidate:
    """POST /api/tasks/validate"""

    @staticmethod
    def _post(config):
        """Submit a validate request wrapping the given config dict."""
        return client.post("/api/tasks/validate", json={"config": config})

    def test_validate_success(self):
        """A well-formed config validates and yields a full command."""
        resp = self._post({
            "tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
            "pipeline": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert data["errors"] == []
        assert len(data["command_args"]) > 0
        assert "--store-id" in data["command"]
        # store_id must come from the JWT (test user site_id=100)
        assert "100" in data["command"]

    def test_validate_injects_store_id(self):
        """A client-supplied store_id is overridden by the JWT value."""
        resp = self._post({
            "tasks": ["DWD_LOAD_FROM_ODS"],
            "pipeline": "ods_dwd",
            "store_id": 999,
        })
        assert resp.status_code == 200
        data = resp.json()
        # The command carries the JWT's site_id=100, not the client's 999.
        assert "--store-id" in data["command"]
        pos = data["command_args"].index("--store-id")
        assert data["command_args"][pos + 1] == "100"

    def test_validate_invalid_flow(self):
        """An unknown pipeline name is reported as a validation error."""
        resp = self._post({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "nonexistent_flow",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("无效的执行流程" in e for e in data["errors"])

    def test_validate_empty_tasks(self):
        """An empty task list is rejected with a dedicated error."""
        resp = self._post({
            "tasks": [],
            "pipeline": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("任务列表不能为空" in e for e in data["errors"])

    def test_validate_custom_window(self):
        """A custom window produces --window-start/--window-end flags."""
        resp = self._post({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-01-01",
            "window_end": "2024-01-31",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert "--window-start" in data["command"]
        assert "--window-end" in data["command"]

    def test_validate_window_end_before_start_rejected(self):
        """window_end earlier than window_start fails Pydantic validation (422)."""
        resp = self._post({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-12-31",
            "window_end": "2024-01-01",
        })
        assert resp.status_code == 422

    def test_validate_dry_run_flag(self):
        """dry_run=True surfaces as a --dry-run flag in the command."""
        resp = self._post({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "dry_run": True,
        })
        assert resp.status_code == 200
        assert "--dry-run" in resp.json()["command"]
# ---------------------------------------------------------------------------
# task_registry 服务层测试
# ---------------------------------------------------------------------------
class TestTaskRegistryService:
    """Service-layer sanity checks for the task registry module."""

    def test_all_tasks_have_unique_codes(self):
        codes = [t.code for t in ALL_TASKS]
        assert len(set(codes)) == len(codes)

    def test_grouped_tasks_cover_all(self):
        grouped = get_tasks_grouped_by_domain()
        collected = {t.code for tasks in grouped.values() for t in tasks}
        assert collected == {t.code for t in ALL_TASKS}

    def test_ods_tasks_marked_is_ods(self):
        # Every ODS-layer task must carry the is_ods flag.
        for task in ALL_TASKS:
            if task.layer == "ODS":
                assert task.is_ods is True

    def test_flow_layer_map_covers_all_flows(self):
        expected_flows = {"api_ods", "api_ods_dwd", "api_full", "ods_dwd",
                          "dwd_dws", "dwd_dws_index", "dwd_index"}
        assert set(FLOW_LAYER_MAP.keys()) == expected_flows

View File

@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
"""WebSocket 日志推送端点测试
测试 /ws/logs/{execution_id} 端点的连接、日志回放、实时推送和断开行为。
利用 TaskExecutor 已有的 subscribe/broadcast 机制进行验证。
"""
from __future__ import annotations
import asyncio
import pytest
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from app.main import app
from app.services.task_executor import task_executor
@pytest.fixture(autouse=True)
def _cleanup_executor():
    """Reset TaskExecutor shared state after every test in this module."""
    yield
    # Drop any leftover per-execution buffers, then wipe both registries.
    for execution_id in list(task_executor._log_buffers):
        task_executor.cleanup(execution_id)
    task_executor._subscribers.clear()
    task_executor._log_buffers.clear()
class TestWebSocketConnection:
    """Basic connect/disconnect behaviour of /ws/logs/{execution_id}."""

    def test_connect_and_disconnect(self):
        """A client can open the socket and close it cleanly."""
        client = TestClient(app)
        with client.websocket_connect("/ws/logs/test-exec-001"):
            pass  # context-manager exit performs the close handshake

    def test_connect_registers_subscriber(self):
        """While connected, TaskExecutor must track at least one subscriber."""
        client = TestClient(app)
        # Pre-create the buffer as if a task were already running.
        task_executor._log_buffers["test-exec-002"] = []
        with client.websocket_connect("/ws/logs/test-exec-002"):
            subs = task_executor._subscribers
            assert "test-exec-002" in subs
            assert len(subs["test-exec-002"]) >= 1
class TestLogReplay:
    """Replay of buffered history on connect."""

    def test_replay_existing_logs(self):
        """Buffered lines arrive first and in their original order."""
        eid = "test-exec-replay"
        history = [
            "[stdout] 第一行",
            "[stdout] 第二行",
            "[stderr] 警告信息",
        ]
        task_executor._log_buffers[eid] = list(history)
        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            for expected in history:
                assert ws.receive_text() == expected

    def test_no_logs_no_replay(self):
        """With an empty buffer there is nothing to replay."""
        eid = "test-exec-empty"
        task_executor._log_buffers[eid] = []
        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}"):
            # End the stream so the connection shuts down normally; no
            # replayed messages are expected before that point.
            task_executor._broadcast_end(eid)
class TestBroadcastReceive:
    """Receiving live log broadcasts over the WebSocket."""

    def test_receive_broadcast_messages(self):
        """Messages broadcast after connect are delivered in order."""
        eid = "test-exec-broadcast"
        task_executor._log_buffers[eid] = []
        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            # Simulate TaskExecutor broadcasting live log lines.
            task_executor._broadcast(eid, "[stdout] 实时日志行1")
            task_executor._broadcast(eid, "[stderr] 实时错误行")
            msg1 = ws.receive_text()
            msg2 = ws.receive_text()
            assert msg1 == "[stdout] 实时日志行1"
            assert msg2 == "[stderr] 实时错误行"

    def test_end_signal_closes_connection(self):
        """After the None end signal the server closes the WebSocket.

        Fix: the original test only read the final log line and never
        verified the close it is named for; a further receive must now
        observe the server-side close as WebSocketDisconnect.
        """
        eid = "test-exec-end"
        task_executor._log_buffers[eid] = []
        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            # One last log line, then the end-of-stream signal.
            task_executor._broadcast(eid, "[stdout] 最后一行")
            task_executor._broadcast_end(eid)
            msg = ws.receive_text()
            assert msg == "[stdout] 最后一行"
            # The end signal must terminate the stream.
            with pytest.raises(WebSocketDisconnect):
                ws.receive_text()

    def test_replay_then_broadcast(self):
        """Buffered history is replayed first, then live broadcasts follow."""
        eid = "test-exec-mixed"
        task_executor._log_buffers[eid] = ["[stdout] 历史行"]
        client = TestClient(app)
        with client.websocket_connect(f"/ws/logs/{eid}") as ws:
            # History replay comes first.
            replay = ws.receive_text()
            assert replay == "[stdout] 历史行"
            # Then the live broadcast.
            task_executor._broadcast(eid, "[stdout] 新行")
            task_executor._broadcast_end(eid)
            live = ws.receive_text()
            assert live == "[stdout] 新行"
class TestLogBroadcasterUnit:
    """Direct tests for TaskExecutor's subscribe/unsubscribe/broadcast
    methods (unit-level coverage of the LogBroadcaster behaviour,
    without going through a WebSocket).
    """
    def test_subscribe_creates_queue(self):
        # subscribe() must hand back an asyncio.Queue and register it.
        eid = "unit-sub"
        q = task_executor.subscribe(eid)
        assert isinstance(q, asyncio.Queue)
        assert eid in task_executor._subscribers
        task_executor.unsubscribe(eid, q)
    def test_unsubscribe_removes_queue(self):
        eid = "unit-unsub"
        q = task_executor.subscribe(eid)
        task_executor.unsubscribe(eid, q)
        # Removing the last subscriber must also clean up the key itself.
        assert eid not in task_executor._subscribers
    def test_broadcast_delivers_to_all_subscribers(self):
        # A broadcast must reach every registered queue, not just one.
        eid = "unit-multi"
        q1 = task_executor.subscribe(eid)
        q2 = task_executor.subscribe(eid)
        task_executor._broadcast(eid, "测试消息")
        assert q1.get_nowait() == "测试消息"
        assert q2.get_nowait() == "测试消息"
        task_executor.unsubscribe(eid, q1)
        task_executor.unsubscribe(eid, q2)
    def test_broadcast_end_sends_none(self):
        # The end-of-stream sentinel is a literal None in the queue.
        eid = "unit-end"
        q = task_executor.subscribe(eid)
        task_executor._broadcast_end(eid)
        assert q.get_nowait() is None
        task_executor.unsubscribe(eid, q)
    def test_broadcast_no_subscribers_is_safe(self):
        """Broadcasting with no subscribers must not raise."""
        task_executor._broadcast("nonexistent", "无人接收")
        task_executor._broadcast_end("nonexistent")