在准备环境前提交次全部更改。

This commit is contained in:
Neo
2026-02-19 08:35:13 +08:00
parent ded6dfb9d8
commit 4eac07da47
1387 changed files with 6107191 additions and 33002 deletions

View File

@@ -0,0 +1,62 @@
"""
FastAPI 依赖注入 get_current_user 单元测试。
通过 FastAPI TestClient 验证 Authorization header 处理。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.auth.jwt import create_access_token, create_refresh_token
# 构造一个最小 FastAPI 应用用于测试依赖注入
_test_app = FastAPI()
@_test_app.get("/protected")
async def protected_route(user: CurrentUser = Depends(get_current_user)):
return {"user_id": user.user_id, "site_id": user.site_id}
client = TestClient(_test_app)
class TestGetCurrentUser:
def test_valid_access_token(self):
token = create_access_token(user_id=10, site_id=100)
resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 200
data = resp.json()
assert data["user_id"] == 10
assert data["site_id"] == 100
def test_missing_auth_header_returns_401(self):
"""缺少 Authorization header 时返回 401。"""
resp = client.get("/protected")
assert resp.status_code in (401, 403)
def test_invalid_token_returns_401(self):
resp = client.get(
"/protected", headers={"Authorization": "Bearer invalid.token.here"}
)
assert resp.status_code == 401
def test_refresh_token_rejected(self):
"""refresh 令牌不能用于访问受保护端点。"""
token = create_refresh_token(user_id=1, site_id=1)
resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 401
def test_current_user_is_frozen_dataclass(self):
"""CurrentUser 是不可变的。"""
user = CurrentUser(user_id=1, site_id=2)
assert user.user_id == 1
assert user.site_id == 2
with pytest.raises(AttributeError):
user.user_id = 99 # type: ignore[misc]

View File

@@ -0,0 +1,147 @@
"""
JWT 认证模块单元测试。
覆盖:令牌生成、验证、过期、类型校验、密码哈希、依赖注入。
"""
import os
import time
import pytest
from jose import jwt as jose_jwt
# 测试前设置 JWT_SECRET_KEY避免空密钥
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from app.auth.jwt import (
create_access_token,
create_refresh_token,
create_token_pair,
decode_access_token,
decode_refresh_token,
decode_token,
hash_password,
verify_password,
)
from app import config
# ---------------------------------------------------------------------------
# 密码哈希
# ---------------------------------------------------------------------------
class TestPasswordHashing:
def test_hash_and_verify(self):
raw = "my_secure_password"
hashed = hash_password(raw)
assert verify_password(raw, hashed)
def test_wrong_password_rejected(self):
hashed = hash_password("correct")
assert not verify_password("wrong", hashed)
def test_hash_is_not_plaintext(self):
raw = "plaintext123"
hashed = hash_password(raw)
assert hashed != raw
# ---------------------------------------------------------------------------
# 令牌生成与解码
# ---------------------------------------------------------------------------
class TestTokenCreation:
def test_access_token_contains_expected_fields(self):
token = create_access_token(user_id=1, site_id=100)
payload = decode_token(token)
assert payload["sub"] == "1"
assert payload["site_id"] == 100
assert payload["type"] == "access"
assert "exp" in payload
def test_refresh_token_contains_expected_fields(self):
token = create_refresh_token(user_id=2, site_id=200)
payload = decode_token(token)
assert payload["sub"] == "2"
assert payload["site_id"] == 200
assert payload["type"] == "refresh"
assert "exp" in payload
def test_token_pair_returns_both_tokens(self):
pair = create_token_pair(user_id=3, site_id=300)
assert "access_token" in pair
assert "refresh_token" in pair
assert pair["token_type"] == "bearer"
# 验证两个令牌类型不同
access_payload = decode_token(pair["access_token"])
refresh_payload = decode_token(pair["refresh_token"])
assert access_payload["type"] == "access"
assert refresh_payload["type"] == "refresh"
# ---------------------------------------------------------------------------
# 令牌类型校验
# ---------------------------------------------------------------------------
class TestTokenTypeValidation:
def test_decode_access_token_rejects_refresh(self):
"""access 解码器拒绝 refresh 令牌。"""
token = create_refresh_token(user_id=1, site_id=1)
with pytest.raises(Exception):
decode_access_token(token)
def test_decode_refresh_token_rejects_access(self):
"""refresh 解码器拒绝 access 令牌。"""
token = create_access_token(user_id=1, site_id=1)
with pytest.raises(Exception):
decode_refresh_token(token)
def test_decode_access_token_accepts_access(self):
token = create_access_token(user_id=5, site_id=50)
payload = decode_access_token(token)
assert payload["sub"] == "5"
assert payload["site_id"] == 50
def test_decode_refresh_token_accepts_refresh(self):
token = create_refresh_token(user_id=6, site_id=60)
payload = decode_refresh_token(token)
assert payload["sub"] == "6"
assert payload["site_id"] == 60
# ---------------------------------------------------------------------------
# 令牌过期
# ---------------------------------------------------------------------------
class TestTokenExpiry:
def test_expired_token_rejected(self):
"""手动构造已过期令牌,验证解码失败。"""
payload = {
"sub": "1",
"site_id": 1,
"type": "access",
"exp": int(time.time()) - 10, # 10 秒前过期
}
token = jose_jwt.encode(
payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM
)
with pytest.raises(Exception):
decode_token(token)
# ---------------------------------------------------------------------------
# 无效令牌
# ---------------------------------------------------------------------------
class TestInvalidToken:
def test_garbage_token_rejected(self):
with pytest.raises(Exception):
decode_token("not.a.valid.jwt")
def test_wrong_secret_rejected(self):
"""用不同密钥签发的令牌应被拒绝。"""
payload = {"sub": "1", "site_id": 1, "type": "access", "exp": int(time.time()) + 3600}
token = jose_jwt.encode(payload, "wrong-secret", algorithm="HS256")
with pytest.raises(Exception):
decode_token(token)

View File

@@ -0,0 +1,137 @@
"""
认证模块属性测试Property-Based Testing
使用 hypothesis 验证认证系统的通用正确性属性:
- Property 2: 无效凭据始终被拒绝
- Property 3: 有效 JWT 令牌授权访问
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-property-tests")
from unittest.mock import MagicMock, patch
from hypothesis import given, settings
from hypothesis import strategies as st
from app.auth.dependencies import CurrentUser, get_current_user
from app.auth.jwt import create_access_token
from app.main import app
from app.routers.auth import router
# 确保路由已挂载
if router not in [r for r in app.routes]:
app.include_router(router)
from fastapi.testclient import TestClient
client = TestClient(app)
# ---------------------------------------------------------------------------
# 策略Strategies
# ---------------------------------------------------------------------------
# 用户名策略1~64 字符的可打印字符串(排除控制字符)
_username_st = st.text(
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
min_size=1,
max_size=64,
)
# 密码策略1~128 字符的可打印字符串
_password_st = st.text(
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
min_size=1,
max_size=128,
)
# user_id 策略:正整数
_user_id_st = st.integers(min_value=1, max_value=2**31 - 1)
# site_id 策略:正整数
_site_id_st = st.integers(min_value=1, max_value=2**63 - 1)
def _mock_db_returning(row):
"""构造 mock get_connectioncursor.fetchone() 返回指定行。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = row
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 2: 无效凭据始终被拒绝
# **Validates: Requirements 1.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(username=_username_st, password=_password_st)
@patch("app.routers.auth.get_connection")
def test_invalid_credentials_always_rejected(mock_get_conn, username, password):
"""
Property 2: 无效凭据始终被拒绝。
对于任意用户名/密码组合当数据库中不存在该用户时fetchone 返回 None
登录接口应始终返回 401 状态码。
"""
# mock 数据库返回 None — 用户不存在
mock_get_conn.return_value = _mock_db_returning(None)
resp = client.post(
"/api/auth/login",
json={"username": username, "password": password},
)
assert resp.status_code == 401, (
f"期望 401实际 {resp.status_code}username={username!r}"
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 3: 有效 JWT 令牌授权访问
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
import asyncio
from fastapi.security import HTTPAuthorizationCredentials
def _run_async(coro):
"""在同步上下文中执行异步协程,避免 DeprecationWarning。"""
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
@settings(max_examples=100)
@given(user_id=_user_id_st, site_id=_site_id_st)
def test_valid_jwt_grants_access(user_id, site_id):
"""
Property 3: 有效 JWT 令牌授权访问。
对于任意 user_id 和 site_id由系统签发的未过期 access_token
应能被 get_current_user 依赖成功解析为 CurrentUser 对象,
且解析出的 user_id 和 site_id 与签发时一致。
"""
# 生成有效的 access_token
token = create_access_token(user_id=user_id, site_id=site_id)
# 直接调用依赖函数验证令牌解析
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
result = _run_async(get_current_user(credentials))
assert isinstance(result, CurrentUser)
assert result.user_id == user_id, (
f"user_id 不匹配:期望 {user_id},实际 {result.user_id}"
)
assert result.site_id == site_id, (
f"site_id 不匹配:期望 {site_id},实际 {result.site_id}"
)

View File

@@ -0,0 +1,167 @@
"""
认证路由单元测试。
覆盖:登录成功/失败、刷新令牌、账号禁用等场景。
通过 mock 数据库连接避免依赖真实 PostgreSQL。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.jwt import (
create_refresh_token,
decode_access_token,
decode_refresh_token,
hash_password,
)
from app.main import app
from app.routers.auth import router
# 注册路由到 app测试时确保路由已挂载
if router not in [r for r in app.routes]:
app.include_router(router)
client = TestClient(app)
# 测试用固定数据
_TEST_PASSWORD = "correct_password"
_TEST_HASH = hash_password(_TEST_PASSWORD)
_TEST_USER_ROW = (1, _TEST_HASH, 100, True) # id, password_hash, site_id, is_active
_DISABLED_USER_ROW = (2, _TEST_HASH, 200, False)
def _mock_db_returning(row):
"""构造一个 mock get_connectioncursor.fetchone() 返回指定行。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = row
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
# ---------------------------------------------------------------------------
# POST /api/auth/login
# ---------------------------------------------------------------------------
class TestLogin:
@patch("app.routers.auth.get_connection")
def test_login_success(self, mock_get_conn):
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
resp = client.post(
"/api/auth/login",
json={"username": "admin", "password": _TEST_PASSWORD},
)
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
# 验证 access_token payload 包含正确的 user_id 和 site_id
payload = decode_access_token(data["access_token"])
assert payload["sub"] == "1"
assert payload["site_id"] == 100
@patch("app.routers.auth.get_connection")
def test_login_user_not_found(self, mock_get_conn):
"""用户不存在时返回 401。"""
mock_get_conn.return_value = _mock_db_returning(None)
resp = client.post(
"/api/auth/login",
json={"username": "nonexistent", "password": "whatever"},
)
assert resp.status_code == 401
assert "用户名或密码错误" in resp.json()["detail"]
@patch("app.routers.auth.get_connection")
def test_login_wrong_password(self, mock_get_conn):
"""密码错误时返回 401。"""
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
resp = client.post(
"/api/auth/login",
json={"username": "admin", "password": "wrong_password"},
)
assert resp.status_code == 401
assert "用户名或密码错误" in resp.json()["detail"]
@patch("app.routers.auth.get_connection")
def test_login_disabled_account(self, mock_get_conn):
"""账号已禁用时返回 401。"""
mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)
resp = client.post(
"/api/auth/login",
json={"username": "disabled_user", "password": _TEST_PASSWORD},
)
assert resp.status_code == 401
assert "禁用" in resp.json()["detail"]
def test_login_missing_username(self):
"""缺少 username 字段时返回 422。"""
resp = client.post("/api/auth/login", json={"password": "test"})
assert resp.status_code == 422
def test_login_empty_password(self):
"""空密码时返回 422。"""
resp = client.post(
"/api/auth/login", json={"username": "admin", "password": ""}
)
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /api/auth/refresh
# ---------------------------------------------------------------------------
class TestRefresh:
def test_refresh_success(self):
"""有效的 refresh_token 换取新的 access_token。"""
refresh = create_refresh_token(user_id=5, site_id=50)
resp = client.post(
"/api/auth/refresh", json={"refresh_token": refresh}
)
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
# refresh_token 原样返回
assert data["refresh_token"] == refresh
assert data["token_type"] == "bearer"
# 新 access_token 包含正确信息
payload = decode_access_token(data["access_token"])
assert payload["sub"] == "5"
assert payload["site_id"] == 50
def test_refresh_with_invalid_token(self):
"""无效令牌返回 401。"""
resp = client.post(
"/api/auth/refresh", json={"refresh_token": "garbage.token.here"}
)
assert resp.status_code == 401
assert "无效的刷新令牌" in resp.json()["detail"]
def test_refresh_with_access_token_rejected(self):
"""用 access_token 做刷新应被拒绝。"""
from app.auth.jwt import create_access_token
access = create_access_token(user_id=1, site_id=1)
resp = client.post(
"/api/auth/refresh", json={"refresh_token": access}
)
assert resp.status_code == 401
def test_refresh_missing_token(self):
"""缺少 refresh_token 字段时返回 422。"""
resp = client.post("/api/auth/refresh", json={})
assert resp.status_code == 422

View File

@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
"""CLIBuilder 单元测试
覆盖7 种 Flow、3 种处理模式、时间窗口、store_id 自动注入、extra_args 等。
"""
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
@pytest.fixture
def builder() -> CLIBuilder:
return CLIBuilder()
ETL_PATH = "/fake/etl/project"
# ---------------------------------------------------------------------------
# 基本命令结构
# ---------------------------------------------------------------------------
class TestBasicCommand:
def test_minimal_command(self, builder: CLIBuilder):
"""最小配置应生成 python -m cli.main --pipeline ... --processing-mode ..."""
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
cmd = builder.build_command(config, ETL_PATH)
assert cmd[:3] == ["python", "-m", "cli.main"]
assert "--pipeline" in cmd
assert "--processing-mode" in cmd
def test_custom_python_executable(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
cmd = builder.build_command(config, ETL_PATH, python_executable="python3")
assert cmd[0] == "python3"
def test_tasks_joined_by_comma(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER", "ODS_PAYMENT", "ODS_REFUND"])
cmd = builder.build_command(config, ETL_PATH)
idx = cmd.index("--tasks")
assert cmd[idx + 1] == "ODS_MEMBER,ODS_PAYMENT,ODS_REFUND"
def test_empty_tasks_no_tasks_arg(self, builder: CLIBuilder):
"""空任务列表不应生成 --tasks 参数"""
config = TaskConfigSchema(tasks=[])
cmd = builder.build_command(config, ETL_PATH)
assert "--tasks" not in cmd
# ---------------------------------------------------------------------------
# 7 种 Flow
# ---------------------------------------------------------------------------
class TestFlows:
@pytest.mark.parametrize("flow_id", sorted(VALID_FLOWS))
def test_all_flows_accepted(self, builder: CLIBuilder, flow_id: str):
config = TaskConfigSchema(tasks=["ODS_MEMBER"], pipeline=flow_id)
cmd = builder.build_command(config, ETL_PATH)
idx = cmd.index("--pipeline")
assert cmd[idx + 1] == flow_id
def test_default_flow_is_api_ods_dwd(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
cmd = builder.build_command(config, ETL_PATH)
idx = cmd.index("--pipeline")
assert cmd[idx + 1] == "api_ods_dwd"
# ---------------------------------------------------------------------------
# 3 种处理模式
# ---------------------------------------------------------------------------
class TestProcessingModes:
@pytest.mark.parametrize("mode", sorted(VALID_PROCESSING_MODES))
def test_all_modes_accepted(self, builder: CLIBuilder, mode: str):
config = TaskConfigSchema(tasks=["ODS_MEMBER"], processing_mode=mode)
cmd = builder.build_command(config, ETL_PATH)
idx = cmd.index("--processing-mode")
assert cmd[idx + 1] == mode
def test_fetch_before_verify_only_in_verify_mode(self, builder: CLIBuilder):
"""--fetch-before-verify 仅在 verify_only 模式下生效"""
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
processing_mode="verify_only",
fetch_before_verify=True,
)
cmd = builder.build_command(config, ETL_PATH)
assert "--fetch-before-verify" in cmd
def test_fetch_before_verify_ignored_in_increment_mode(self, builder: CLIBuilder):
"""increment_only 模式下 fetch_before_verify=True 不应生成参数"""
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
processing_mode="increment_only",
fetch_before_verify=True,
)
cmd = builder.build_command(config, ETL_PATH)
assert "--fetch-before-verify" not in cmd
# ---------------------------------------------------------------------------
# 时间窗口
# ---------------------------------------------------------------------------
class TestTimeWindow:
def test_lookback_mode_generates_lookback_args(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
window_mode="lookback",
lookback_hours=48,
overlap_seconds=1200,
)
cmd = builder.build_command(config, ETL_PATH)
idx_lb = cmd.index("--lookback-hours")
assert cmd[idx_lb + 1] == "48"
idx_ol = cmd.index("--overlap-seconds")
assert cmd[idx_ol + 1] == "1200"
# lookback 模式不应生成 --window-start / --window-end
assert "--window-start" not in cmd
assert "--window-end" not in cmd
def test_custom_mode_generates_window_args(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
window_mode="custom",
window_start="2026-01-01",
window_end="2026-01-31",
)
cmd = builder.build_command(config, ETL_PATH)
idx_s = cmd.index("--window-start")
assert cmd[idx_s + 1] == "2026-01-01"
idx_e = cmd.index("--window-end")
assert cmd[idx_e + 1] == "2026-01-31"
# custom 模式不应生成 --lookback-hours / --overlap-seconds
assert "--lookback-hours" not in cmd
assert "--overlap-seconds" not in cmd
def test_window_split_with_days(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
window_split="day",
window_split_days=10,
)
cmd = builder.build_command(config, ETL_PATH)
idx = cmd.index("--window-split")
assert cmd[idx + 1] == "day"
idx_d = cmd.index("--window-split-days")
assert cmd[idx_d + 1] == "10"
def test_window_split_none_not_generated(self, builder: CLIBuilder):
"""window_split='none' 不应生成 --window-split 参数"""
config = TaskConfigSchema(tasks=["ODS_MEMBER"], window_split="none")
cmd = builder.build_command(config, ETL_PATH)
assert "--window-split" not in cmd
# ---------------------------------------------------------------------------
# store_id 自动注入
# ---------------------------------------------------------------------------
class TestStoreId:
def test_store_id_injected(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=42)
cmd = builder.build_command(config, ETL_PATH)
idx = cmd.index("--store-id")
assert cmd[idx + 1] == "42"
def test_store_id_none_not_generated(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=None)
cmd = builder.build_command(config, ETL_PATH)
assert "--store-id" not in cmd
# ---------------------------------------------------------------------------
# dry_run
# ---------------------------------------------------------------------------
class TestDryRun:
def test_dry_run_flag(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=True)
cmd = builder.build_command(config, ETL_PATH)
assert "--dry-run" in cmd
def test_no_dry_run_flag(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=False)
cmd = builder.build_command(config, ETL_PATH)
assert "--dry-run" not in cmd
# ---------------------------------------------------------------------------
# extra_args
# ---------------------------------------------------------------------------
class TestExtraArgs:
def test_supported_value_arg(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
extra_args={"pg_dsn": "postgresql://localhost/test"},
)
cmd = builder.build_command(config, ETL_PATH)
idx = cmd.index("--pg-dsn")
assert cmd[idx + 1] == "postgresql://localhost/test"
def test_supported_bool_arg(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
extra_args={"force_window_override": True},
)
cmd = builder.build_command(config, ETL_PATH)
assert "--force-window-override" in cmd
def test_unsupported_arg_ignored(self, builder: CLIBuilder):
"""不在 CLI_SUPPORTED_ARGS 中的键应被忽略"""
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
extra_args={"unknown_param": "value"},
)
cmd = builder.build_command(config, ETL_PATH)
assert "--unknown-param" not in cmd
def test_none_value_ignored(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
extra_args={"pg_dsn": None},
)
cmd = builder.build_command(config, ETL_PATH)
assert "--pg-dsn" not in cmd
def test_false_bool_arg_not_generated(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
extra_args={"force_window_override": False},
)
cmd = builder.build_command(config, ETL_PATH)
assert "--force-window-override" not in cmd
# ---------------------------------------------------------------------------
# build_command_string
# ---------------------------------------------------------------------------
class TestBuildCommandString:
def test_returns_string(self, builder: CLIBuilder):
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
result = builder.build_command_string(config, ETL_PATH)
assert isinstance(result, str)
assert "python -m cli.main" in result
def test_quotes_args_with_spaces(self, builder: CLIBuilder):
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
extra_args={"pg_dsn": "host=localhost dbname=test"},
)
result = builder.build_command_string(config, ETL_PATH)
# 包含空格的值应被引号包裹
assert '"host=localhost dbname=test"' in result

View File

@@ -0,0 +1,94 @@
"""
数据库连接模块单元测试。
覆盖ETL 只读连接的创建、RLS site_id 设置、只读模式、异常处理。
"""
import os
from unittest.mock import MagicMock, call, patch
import pytest
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from app.database import get_etl_readonly_connection
# ---------------------------------------------------------------------------
# get_etl_readonly_connection
# ---------------------------------------------------------------------------
class TestGetEtlReadonlyConnection:
"""ETL 只读连接验证连接参数、只读设置、RLS 隔离。"""
@patch("app.database.psycopg2.connect")
def test_sets_readonly_and_site_id(self, mock_connect):
"""连接后应依次执行 SET read_only 和 SET LOCAL site_id。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_connect.return_value = mock_conn
conn = get_etl_readonly_connection(site_id=42)
# 验证 autocommit 被关闭
assert mock_conn.autocommit is False
# 验证执行了两条 SET 语句
executed = [c.args[0] for c in mock_cursor.execute.call_args_list]
assert "SET default_transaction_read_only = on" in executed[0]
assert "SET LOCAL app.current_site_id" in executed[1]
# 验证 site_id 参数化传递(防 SQL 注入)
site_id_call = mock_cursor.execute.call_args_list[1]
assert site_id_call.args[1] == ("42",)
# 验证提交
mock_conn.commit.assert_called_once()
assert conn is mock_conn
@patch("app.database.psycopg2.connect")
def test_accepts_string_site_id(self, mock_connect):
"""site_id 为字符串时也应正常工作。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_connect.return_value = mock_conn
get_etl_readonly_connection(site_id="99")
site_id_call = mock_cursor.execute.call_args_list[1]
assert site_id_call.args[1] == ("99",)
@patch("app.database.psycopg2.connect")
def test_closes_connection_on_setup_error(self, mock_connect):
"""SET 语句执行失败时应关闭连接并抛出异常。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.execute.side_effect = Exception("SET failed")
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_connect.return_value = mock_conn
with pytest.raises(Exception, match="SET failed"):
get_etl_readonly_connection(site_id=1)
mock_conn.close.assert_called_once()
@patch("app.database.psycopg2.connect")
def test_uses_etl_config_params(self, mock_connect):
"""应使用 ETL_DB_* 配置项连接。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_connect.return_value = mock_conn
get_etl_readonly_connection(site_id=1)
connect_kwargs = mock_connect.call_args.kwargs
# 验证使用了 ETL 数据库名(默认 etl_feiqiu
assert connect_kwargs["dbname"] == "etl_feiqiu"

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""数据库查看器属性测试Property-Based Testing
使用 hypothesis 验证数据库查看器的通用正确性属性:
- Property 17: SQL 写操作拦截
- Property 18: SQL 查询结果行数限制
测试策略:
- Property 17: 生成包含写操作关键词(随机大小写混合)的 SQL 字符串,
验证 _WRITE_KEYWORDS 正则表达式能匹配到
- Property 18: 生成随机长度的行列表(可能超过 1000 行),
验证截取前 _MAX_ROWS 个元素后长度 <= 1000
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-db-viewer-properties")
from hypothesis import given, settings
from hypothesis import strategies as st
from app.routers.db_viewer import _WRITE_KEYWORDS, _MAX_ROWS
# ---------------------------------------------------------------------------
# 通用策略Strategies
# ---------------------------------------------------------------------------
# 写操作关键词列表
_WRITE_OPS = ["INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE"]
# SQL 前缀/后缀:不含写操作关键词的简单文本
_sql_filler_st = st.text(
alphabet=st.characters(
whitelist_categories=("L", "N", "S"),
blacklist_characters="\x00",
),
min_size=0,
max_size=50,
)
# 随机大小写混合的写操作关键词
_random_case_keyword_st = st.sampled_from(_WRITE_OPS).flatmap(
lambda kw: st.tuples(
st.just(kw),
st.lists(
st.booleans(),
min_size=len(kw),
max_size=len(kw),
),
).map(
lambda pair: "".join(
c.upper() if flag else c.lower()
for c, flag in zip(pair[0], pair[1])
)
)
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 17: SQL 写操作拦截
# **Validates: Requirements 7.5**
# ---------------------------------------------------------------------------
@settings(max_examples=200)
@given(
prefix=_sql_filler_st,
keyword=_random_case_keyword_st,
suffix=_sql_filler_st,
)
def test_write_keywords_always_detected(prefix, keyword, suffix):
"""Property 17: SQL 写操作拦截。
包含 INSERT、UPDATE、DELETE、DROP、TRUNCATE 关键词(不区分大小写)的
SQL 语句_WRITE_KEYWORDS 正则表达式应能匹配到。
策略:在随机前缀和后缀之间插入一个随机大小写混合的写操作关键词,
用空格分隔以确保 \\b 词边界能匹配。
"""
# 用空格分隔确保词边界匹配
sql = f"{prefix} {keyword} {suffix}"
match = _WRITE_KEYWORDS.search(sql)
assert match is not None, (
f"正则表达式未能匹配到写操作关键词sql={sql!r}, keyword={keyword!r}"
)
# 匹配到的关键词(转大写后)应在写操作列表中
assert match.group(1).upper() in _WRITE_OPS, (
f"匹配到的关键词 '{match.group(1)}' 不在写操作列表中"
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 18: SQL 查询结果行数限制
# **Validates: Requirements 7.4**
# ---------------------------------------------------------------------------
# 模拟数据库返回的行:每行是一个简单列表
_row_st = st.lists(
st.one_of(st.integers(), st.text(max_size=20), st.none()),
min_size=1,
max_size=5,
)
# 行列表策略0 到 3000 行(覆盖超过 _MAX_ROWS 的情况)
_rows_st = st.lists(_row_st, min_size=0, max_size=3000)
@settings(max_examples=200)
@given(rows=_rows_st)
def test_row_count_never_exceeds_max(rows):
"""Property 18: SQL 查询结果行数限制。
对任意长度的行列表,取前 _MAX_ROWS 个元素后,
结果长度应 <= 1000。
这等价于 cur.fetchmany(_MAX_ROWS) 的行为:
数据库游标最多返回 _MAX_ROWS 行。
"""
# 模拟 fetchmany(_MAX_ROWS) 的行为
truncated = rows[:_MAX_ROWS]
assert len(truncated) <= _MAX_ROWS, (
f"截取后行数 {len(truncated)} 超过上限 {_MAX_ROWS}"
)
# 额外验证:如果原始行数 <= _MAX_ROWS截取后应保留全部
if len(rows) <= _MAX_ROWS:
assert len(truncated) == len(rows), (
f"原始行数 {len(rows)} <= {_MAX_ROWS},截取后应保留全部,"
f"实际 {len(truncated)}"
)
# 额外验证:如果原始行数 > _MAX_ROWS截取后应恰好为 _MAX_ROWS
if len(rows) > _MAX_ROWS:
assert len(truncated) == _MAX_ROWS, (
f"原始行数 {len(rows)} > {_MAX_ROWS},截取后应恰好为 {_MAX_ROWS}"
f"实际 {len(truncated)}"
)

View File

@@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
"""数据库查看器路由单元测试
覆盖 4 个端点:
- GET /api/db/schemas
- GET /api/db/schemas/{name}/tables
- GET /api/db/tables/{schema}/{table}/columns
- POST /api/db/query
通过 mock 绕过数据库连接,专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
from psycopg2 import errors as pg_errors
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
def _make_mock_conn(rows, description=None):
"""构造 mock 数据库连接cursor 返回指定行和列描述。"""
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.fetchall.return_value = rows
mock_cur.fetchmany.return_value = rows
mock_cur.description = description
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cur
# ---------------------------------------------------------------------------
# GET /api/db/schemas
# ---------------------------------------------------------------------------
class TestListSchemas:
@patch(_MOCK_CONN)
def test_returns_schema_list(self, mock_get_conn):
conn, cur = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 3
assert data[0]["name"] == "dwd"
assert data[2]["name"] == "ods"
# 验证 site_id 传递
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
conn.close.assert_called_once()
@patch(_MOCK_CONN)
def test_empty_schemas(self, mock_get_conn):
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/schemas/{name}/tables
# ---------------------------------------------------------------------------
class TestListTables:
@patch(_MOCK_CONN)
def test_returns_tables_with_row_count(self, mock_get_conn):
conn, cur = _make_mock_conn([
("dim_member", 1500),
("fact_order", 32000),
])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas/dwd/tables")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
assert data[0]["name"] == "dim_member"
assert data[0]["row_count"] == 1500
assert data[1]["name"] == "fact_order"
assert data[1]["row_count"] == 32000
@patch(_MOCK_CONN)
def test_null_row_count(self, mock_get_conn):
"""pg_stat_user_tables 可能没有统计信息row_count 为 None。"""
conn, cur = _make_mock_conn([("new_table", None)])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas/ods/tables")
assert resp.status_code == 200
data = resp.json()
assert data[0]["row_count"] is None
@patch(_MOCK_CONN)
def test_empty_schema(self, mock_get_conn):
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas/empty_schema/tables")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/tables/{schema}/{table}/columns
# ---------------------------------------------------------------------------
class TestListColumns:
@patch(_MOCK_CONN)
def test_returns_column_definitions(self, mock_get_conn):
conn, cur = _make_mock_conn([
("id", "bigint", "NO", None),
("name", "character varying", "YES", None),
("created_at", "timestamp with time zone", "NO", "now()"),
])
mock_get_conn.return_value = conn
resp = client.get("/api/db/tables/dwd/dim_member/columns")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 3
assert data[0]["name"] == "id"
assert data[0]["data_type"] == "bigint"
assert data[0]["is_nullable"] is False
assert data[0]["column_default"] is None
assert data[1]["is_nullable"] is True
assert data[2]["column_default"] == "now()"
@patch(_MOCK_CONN)
def test_empty_table(self, mock_get_conn):
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
resp = client.get("/api/db/tables/dwd/nonexistent/columns")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# POST /api/db/query
# ---------------------------------------------------------------------------
class TestExecuteQuery:
@patch(_MOCK_CONN)
def test_successful_select(self, mock_get_conn):
description = [("id",), ("name",)]
conn, cur = _make_mock_conn(
[(1, "Alice"), (2, "Bob")],
description=description,
)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
assert resp.status_code == 200
data = resp.json()
assert data["columns"] == ["id", "name"]
assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
assert data["row_count"] == 2
@patch(_MOCK_CONN)
def test_empty_result(self, mock_get_conn):
description = [("id",)]
conn, cur = _make_mock_conn([], description=description)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
assert resp.status_code == 200
data = resp.json()
assert data["columns"] == ["id"]
assert data["rows"] == []
assert data["row_count"] == 0
# ── 写操作拦截 ──
@pytest.mark.parametrize("keyword", [
"INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
"insert", "update", "delete", "drop", "truncate",
"Insert", "Update", "Delete", "Drop", "Truncate",
])
def test_blocks_write_operations(self, keyword):
resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
assert resp.status_code == 400
assert "只读" in resp.json()["detail"] or "禁止" in resp.json()["detail"]
def test_blocks_mixed_case_write(self):
resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
assert resp.status_code == 400
def test_blocks_write_in_subquery(self):
"""写操作关键词出现在 SQL 任意位置都应拦截。"""
resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
assert resp.status_code == 400
# ── 空 SQL ──
def test_empty_sql(self):
resp = client.post("/api/db/query", json={"sql": ""})
assert resp.status_code == 400
def test_whitespace_only_sql(self):
resp = client.post("/api/db/query", json={"sql": " "})
assert resp.status_code == 400
# ── SQL 语法错误 ──
@patch(_MOCK_CONN)
def test_sql_syntax_error(self, mock_get_conn):
conn = MagicMock()
mock_cur = MagicMock()
# 第一次 execute 设置 timeout 成功,第二次抛异常
mock_cur.execute.side_effect = [None, Exception("syntax error at or near \"SELEC\"")]
mock_cur.description = None
conn.cursor.return_value.__enter__ = lambda s: mock_cur
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
assert resp.status_code == 400
assert "SQL 执行错误" in resp.json()["detail"]
# ── 查询超时 ──
@patch(_MOCK_CONN)
def test_query_timeout(self, mock_get_conn):
conn = MagicMock()
mock_cur = MagicMock()
mock_cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
mock_cur.description = None
conn.cursor.return_value.__enter__ = lambda s: mock_cur
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
assert resp.status_code == 408
assert "超时" in resp.json()["detail"]
# ── 行数限制验证 ──
@patch(_MOCK_CONN)
def test_row_limit(self, mock_get_conn):
"""验证 fetchmany 被调用时传入 1000 行限制。"""
description = [("id",)]
conn, cur = _make_mock_conn(
[(i,) for i in range(1000)],
description=description,
)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
assert resp.status_code == 200
# 验证 fetchmany 被调用时传入了 1000
cur.fetchmany.assert_called_once_with(1000)
# ── 超时设置验证 ──
@patch(_MOCK_CONN)
def test_sets_statement_timeout(self, mock_get_conn):
"""验证查询前设置了 statement_timeout。"""
description = [("id",)]
conn, cur = _make_mock_conn([(1,)], description=description)
mock_get_conn.return_value = conn
client.post("/api/db/query", json={"sql": "SELECT 1"})
# 第一次 execute 应该是设置超时
first_call = cur.execute.call_args_list[0]
assert "statement_timeout" in first_call[0][0]
# ---------------------------------------------------------------------------
# 认证测试
# ---------------------------------------------------------------------------
class TestDbViewerAuth:
def test_requires_auth(self):
"""移除 auth override 后,所有端点应返回 401/403。"""
original = app.dependency_overrides.pop(get_current_user, None)
try:
endpoints = [
("GET", "/api/db/schemas"),
("GET", "/api/db/schemas/dwd/tables"),
("GET", "/api/db/tables/dwd/dim_member/columns"),
("POST", "/api/db/query"),
]
for method, url in endpoints:
if method == "POST":
resp = client.request(method, url, json={"sql": "SELECT 1"})
else:
resp = client.request(method, url)
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
finally:
if original:
app.dependency_overrides[get_current_user] = original

View File

@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""环境配置属性测试Property-Based Testing
使用 hypothesis 验证环境配置管理的通用正确性属性:
- Property 15: .env 解析与敏感值掩码
- Property 16: .env 写入往返一致性
测试策略:
- Property 15: 生成随机 .env 内容(含敏感和非敏感键),验证 _parse_env + _is_sensitive 对敏感值掩码
- Property 16: 生成随机键值对,序列化为 .env 格式后再解析,验证往返一致性
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-env-config-properties")
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from app.routers.env_config import _parse_env, _is_sensitive, _MASK, _SENSITIVE_KEYWORDS
# ---------------------------------------------------------------------------
# 通用策略Strategies
# ---------------------------------------------------------------------------
# 合法的环境变量键名:字母或下划线开头,后跟字母、数字、下划线
_key_start_char = st.sampled_from(
list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_")
)
_key_rest_char = st.sampled_from(
list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_")
)
_env_key_st = st.builds(
lambda first, rest: first + rest,
first=_key_start_char,
rest=st.text(alphabet=_key_rest_char, min_size=0, max_size=30),
)
# 值:不含换行符的可打印字符串(排除引号以避免解析歧义)
_env_value_st = st.text(
alphabet=st.characters(
whitelist_categories=("L", "N", "P", "S"),
blacklist_characters='\n\r"\'#',
),
min_size=0,
max_size=50,
)
# 敏感键:在随机键名中嵌入敏感关键词
_sensitive_keyword_st = st.sampled_from(list(_SENSITIVE_KEYWORDS))
_sensitive_key_st = st.builds(
lambda prefix, kw, suffix: prefix + kw + suffix,
prefix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
kw=_sensitive_keyword_st,
suffix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
).filter(lambda k: len(k) > 0 and k[0].isalpha() or k[0] == "_")
# 确保敏感键以字母或下划线开头
_safe_sensitive_key_st = st.builds(
lambda prefix, kw: prefix + "_" + kw,
prefix=st.sampled_from(["DB", "API", "ETL", "APP", "MY"]),
kw=_sensitive_keyword_st,
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 15: .env 解析与敏感值掩码
# **Validates: Requirements 6.1, 6.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
sensitive_keys=st.lists(_safe_sensitive_key_st, min_size=1, max_size=5, unique=True),
sensitive_values=st.lists(
st.text(min_size=1, max_size=30, alphabet=st.characters(
whitelist_categories=("L", "N"),
)),
min_size=1, max_size=5,
),
normal_keys=st.lists(_env_key_st, min_size=1, max_size=5, unique=True),
normal_values=st.lists(_env_value_st, min_size=1, max_size=5),
)
def test_sensitive_values_masked(sensitive_keys, sensitive_values, normal_keys, normal_values):
"""Property 15: .env 解析与敏感值掩码。
包含敏感键PASSWORD、TOKEN、SECRET、DSN的 .env 文件内容,
API 返回的键值对列表中这些键的值应被掩码替换,不包含原始敏感值。
"""
# 确保敏感键和普通键不重叠
normal_keys_filtered = [k for k in normal_keys if k not in sensitive_keys]
assume(len(normal_keys_filtered) >= 1)
# 对齐列表长度
s_vals = (sensitive_values * ((len(sensitive_keys) // len(sensitive_values)) + 1))[:len(sensitive_keys)]
n_vals = (normal_values * ((len(normal_keys_filtered) // len(normal_values)) + 1))[:len(normal_keys_filtered)]
# 构造 .env 内容
lines = []
for k, v in zip(sensitive_keys, s_vals):
lines.append(f"{k}={v}")
for k, v in zip(normal_keys_filtered, n_vals):
lines.append(f"{k}={v}")
env_content = "\n".join(lines) + "\n"
# 解析
parsed = _parse_env(env_content)
entries = [line for line in parsed if line["type"] == "entry"]
# 模拟 GET 端点的掩码逻辑
masked_entries = {}
for entry in entries:
if _is_sensitive(entry["key"]):
masked_entries[entry["key"]] = _MASK
else:
masked_entries[entry["key"]] = entry["value"]
# 验证:敏感键的值应被掩码
for k, v in zip(sensitive_keys, s_vals):
assert k in masked_entries, f"敏感键 {k} 应出现在解析结果中"
assert masked_entries[k] == _MASK, (
f"敏感键 {k} 的值应为掩码 '{_MASK}',实际为 '{masked_entries[k]}'"
)
# 原始敏感值不应出现在掩码后的结果中
assert masked_entries[k] != v, (
f"敏感键 {k} 的原始值 '{v}' 不应出现在掩码结果中"
)
# 验证:非敏感键的值应保持原样
for k, v in zip(normal_keys_filtered, n_vals):
if not _is_sensitive(k):
assert k in masked_entries, f"普通键 {k} 应出现在解析结果中"
assert masked_entries[k] == v, (
f"普通键 {k} 的值应为 '{v}',实际为 '{masked_entries[k]}'"
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 16: .env 写入往返一致性
# **Validates: Requirements 6.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
entries=st.lists(
st.tuples(_env_key_st, _env_value_st),
min_size=1,
max_size=10,
unique_by=lambda t: t[0], # 键唯一
),
)
def test_env_write_read_round_trip(entries):
"""Property 16: .env 写入往返一致性。
有效的键值对集合(不含注释和空行),写入 .env 文件后再读取解析,
应得到与原始集合等价的键值对。
"""
# 过滤掉值中可能导致解析歧义的情况(值前后空白会被 strip
clean_entries = [(k, v.strip()) for k, v in entries]
# 排除空键(策略已保证非空,但防御性检查)
clean_entries = [(k, v) for k, v in clean_entries if k]
assume(len(clean_entries) >= 1)
# 模拟写入:构造 .env 文件内容
lines = [f"{k}={v}" for k, v in clean_entries]
env_content = "\n".join(lines) + "\n"
# 解析
parsed = _parse_env(env_content)
parsed_entries = {
line["key"]: line["value"]
for line in parsed
if line["type"] == "entry"
}
# 验证往返一致性:每个写入的键值对都应在解析结果中
for k, v in clean_entries:
assert k in parsed_entries, (
f"'{k}' 应出现在解析结果中,实际键集合: {list(parsed_entries.keys())}"
)
assert parsed_entries[k] == v, (
f"'{k}' 的值不一致:写入='{v}',解析='{parsed_entries[k]}'"
)
# 验证:解析结果的键数量应与写入的一致
assert len(parsed_entries) == len(clean_entries), (
f"解析结果键数量 {len(parsed_entries)} 应等于写入数量 {len(clean_entries)}"
)

View File

@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
"""环境配置路由单元测试
覆盖 3 个端点GET / PUT / GET /export
通过 mock 绕过文件 I/O专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
# 模拟 .env 文件内容
_SAMPLE_ENV = """\
# 数据库配置
DB_HOST=localhost
DB_PORT=5432
DB_PASSWORD=super_secret_123
JWT_SECRET_KEY=my-jwt-secret
# ETL 配置
ETL_DB_DSN=postgresql://user:pass@host/db
TIMEZONE=Asia/Shanghai
"""
_MOCK_ENV_PATH = "app.routers.env_config._ENV_PATH"
def _mock_path(content: str | None = _SAMPLE_ENV, exists: bool = True):
"""构造 mock Path 对象。"""
mock = MagicMock()
mock.exists.return_value = exists
if content is not None:
mock.read_text.return_value = content
return mock
# ---------------------------------------------------------------------------
# GET /api/env-config
# ---------------------------------------------------------------------------
class TestGetEnvConfig:
@patch(_MOCK_ENV_PATH)
def test_returns_entries_with_masked_sensitive(self, mock_path_obj):
mock_path_obj.__class__ = type(MagicMock())
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = _SAMPLE_ENV
resp = client.get("/api/env-config")
assert resp.status_code == 200
data = resp.json()
entries = {e["key"]: e["value"] for e in data["entries"]}
# 非敏感值原样返回
assert entries["DB_HOST"] == "localhost"
assert entries["DB_PORT"] == "5432"
assert entries["TIMEZONE"] == "Asia/Shanghai"
# 敏感值掩码
assert entries["DB_PASSWORD"] == "****"
assert entries["JWT_SECRET_KEY"] == "****"
assert entries["ETL_DB_DSN"] == "****"
@patch(_MOCK_ENV_PATH)
def test_file_not_found(self, mock_path_obj):
mock_path_obj.exists.return_value = False
resp = client.get("/api/env-config")
assert resp.status_code == 404
@patch(_MOCK_ENV_PATH)
def test_empty_file(self, mock_path_obj):
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = ""
resp = client.get("/api/env-config")
assert resp.status_code == 200
assert resp.json()["entries"] == []
@patch(_MOCK_ENV_PATH)
def test_comments_and_blank_lines_excluded(self, mock_path_obj):
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = "# comment\n\nKEY=val\n"
resp = client.get("/api/env-config")
assert resp.status_code == 200
entries = resp.json()["entries"]
assert len(entries) == 1
assert entries[0]["key"] == "KEY"
# ---------------------------------------------------------------------------
# PUT /api/env-config
# ---------------------------------------------------------------------------
class TestUpdateEnvConfig:
@patch(_MOCK_ENV_PATH)
def test_update_existing_key(self, mock_path_obj):
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = "DB_HOST=localhost\nDB_PORT=5432\n"
resp = client.put("/api/env-config", json={
"entries": [
{"key": "DB_HOST", "value": "192.168.1.1"},
{"key": "DB_PORT", "value": "5433"},
]
})
assert resp.status_code == 200
# 验证写入内容
written = mock_path_obj.write_text.call_args[0][0]
assert "DB_HOST=192.168.1.1" in written
assert "DB_PORT=5433" in written
@patch(_MOCK_ENV_PATH)
def test_add_new_key(self, mock_path_obj):
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = "DB_HOST=localhost\n"
resp = client.put("/api/env-config", json={
"entries": [
{"key": "DB_HOST", "value": "localhost"},
{"key": "NEW_KEY", "value": "new_value"},
]
})
assert resp.status_code == 200
written = mock_path_obj.write_text.call_args[0][0]
assert "NEW_KEY=new_value" in written
@patch(_MOCK_ENV_PATH)
def test_masked_value_preserves_original(self, mock_path_obj):
"""掩码值(****)不应覆盖原始敏感值。"""
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = "DB_PASSWORD=real_secret\nDB_HOST=localhost\n"
resp = client.put("/api/env-config", json={
"entries": [
{"key": "DB_PASSWORD", "value": "****"},
{"key": "DB_HOST", "value": "newhost"},
]
})
assert resp.status_code == 200
written = mock_path_obj.write_text.call_args[0][0]
# 原始密码应保留
assert "DB_PASSWORD=real_secret" in written
assert "DB_HOST=newhost" in written
@patch(_MOCK_ENV_PATH)
def test_preserves_comments(self, mock_path_obj):
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = "# 注释行\nDB_HOST=localhost\n\n# 另一个注释\n"
resp = client.put("/api/env-config", json={
"entries": [{"key": "DB_HOST", "value": "newhost"}]
})
assert resp.status_code == 200
written = mock_path_obj.write_text.call_args[0][0]
assert "# 注释行" in written
assert "# 另一个注释" in written
def test_invalid_key_format(self):
resp = client.put("/api/env-config", json={
"entries": [{"key": "123BAD", "value": "val"}]
})
assert resp.status_code == 422
def test_empty_key(self):
resp = client.put("/api/env-config", json={
"entries": [{"key": "", "value": "val"}]
})
assert resp.status_code == 422
@patch(_MOCK_ENV_PATH)
def test_file_not_exists_creates_new(self, mock_path_obj):
"""文件不存在时,应创建新文件。"""
mock_path_obj.exists.return_value = False
resp = client.put("/api/env-config", json={
"entries": [{"key": "NEW_KEY", "value": "value"}]
})
assert resp.status_code == 200
written = mock_path_obj.write_text.call_args[0][0]
assert "NEW_KEY=value" in written
@patch(_MOCK_ENV_PATH)
def test_update_sensitive_with_new_value(self, mock_path_obj):
"""显式提供新密码时应更新。"""
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = "DB_PASSWORD=old_secret\n"
resp = client.put("/api/env-config", json={
"entries": [{"key": "DB_PASSWORD", "value": "new_secret"}]
})
assert resp.status_code == 200
written = mock_path_obj.write_text.call_args[0][0]
assert "DB_PASSWORD=new_secret" in written
# 返回值中敏感键仍然掩码
entries = {e["key"]: e["value"] for e in resp.json()["entries"]}
assert entries["DB_PASSWORD"] == "****"
# ---------------------------------------------------------------------------
# GET /api/env-config/export
# ---------------------------------------------------------------------------
class TestExportEnvConfig:
@patch(_MOCK_ENV_PATH)
def test_export_masks_sensitive(self, mock_path_obj):
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = _SAMPLE_ENV
resp = client.get("/api/env-config/export")
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/plain")
assert "attachment" in resp.headers.get("content-disposition", "")
content = resp.text
# 非敏感值保留
assert "DB_HOST=localhost" in content
assert "TIMEZONE=Asia/Shanghai" in content
# 敏感值掩码
assert "super_secret_123" not in content
assert "my-jwt-secret" not in content
assert "DB_PASSWORD=****" in content
assert "JWT_SECRET_KEY=****" in content
@patch(_MOCK_ENV_PATH)
def test_export_preserves_comments(self, mock_path_obj):
mock_path_obj.exists.return_value = True
mock_path_obj.read_text.return_value = _SAMPLE_ENV
content = client.get("/api/env-config/export").text
assert "# 数据库配置" in content
assert "# ETL 配置" in content
@patch(_MOCK_ENV_PATH)
def test_export_file_not_found(self, mock_path_obj):
mock_path_obj.exists.return_value = False
resp = client.get("/api/env-config/export")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# 认证测试
# ---------------------------------------------------------------------------
class TestEnvConfigAuth:
def test_requires_auth(self):
"""移除 auth override 后,所有端点应返回 401/403。"""
# 临时移除 override
original = app.dependency_overrides.pop(get_current_user, None)
try:
for method, url in [
("GET", "/api/env-config"),
("PUT", "/api/env-config"),
("GET", "/api/env-config/export"),
]:
resp = client.request(method, url)
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
finally:
if original:
app.dependency_overrides[get_current_user] = original

View File

@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
"""ETL 状态路由单元测试
覆盖 2 个端点:
- GET /api/etl-status/cursors
- GET /api/etl-status/recent-runs
通过 mock 绕过数据库连接,专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
_MOCK_ETL_CONN = "app.routers.etl_status.get_etl_readonly_connection"
_MOCK_APP_CONN = "app.routers.etl_status.get_connection"
def _make_mock_conn(rows):
"""构造 mock 数据库连接cursor 返回指定行。"""
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.fetchall.return_value = rows
mock_cur.fetchone.return_value = None
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cur
# ---------------------------------------------------------------------------
# GET /api/etl-status/cursors
# ---------------------------------------------------------------------------
class TestListCursors:
@patch(_MOCK_ETL_CONN)
def test_returns_cursor_list(self, mock_get_conn):
conn, cur = _make_mock_conn([
("ODS_FETCH_ORDERS", "2024-06-15 10:30:00+08", 1500),
("ODS_FETCH_MEMBERS", "2024-06-15 09:00:00+08", 800),
])
# fetchone 用于 EXISTS 检查
cur.fetchone.return_value = (True,)
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/cursors")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
assert data[0]["task_code"] == "ODS_FETCH_ORDERS"
assert data[0]["last_fetch_time"] == "2024-06-15 10:30:00+08"
assert data[0]["record_count"] == 1500
assert data[1]["task_code"] == "ODS_FETCH_MEMBERS"
# 验证 site_id 传递
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
conn.close.assert_called_once()
@patch(_MOCK_ETL_CONN)
def test_table_not_exists_returns_empty(self, mock_get_conn):
"""etl_admin.etl_cursor 表不存在时返回空列表。"""
conn, cur = _make_mock_conn([])
cur.fetchone.return_value = (False,)
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/cursors")
assert resp.status_code == 200
assert resp.json() == []
@patch(_MOCK_ETL_CONN)
def test_null_fields(self, mock_get_conn):
"""游标字段可能为 None任务从未执行过"""
conn, cur = _make_mock_conn([
("ODS_FETCH_INVENTORY", None, None),
])
cur.fetchone.return_value = (True,)
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/cursors")
assert resp.status_code == 200
data = resp.json()
assert data[0]["task_code"] == "ODS_FETCH_INVENTORY"
assert data[0]["last_fetch_time"] is None
assert data[0]["record_count"] is None
@patch(_MOCK_ETL_CONN)
def test_empty_cursors(self, mock_get_conn):
"""表存在但无数据。"""
conn, cur = _make_mock_conn([])
cur.fetchone.return_value = (True,)
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/cursors")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/etl-status/recent-runs
# ---------------------------------------------------------------------------
class TestListRecentRuns:
@patch(_MOCK_APP_CONN)
def test_returns_recent_runs(self, mock_get_conn):
conn, cur = _make_mock_conn([
(
"a1b2c3d4-0000-0000-0000-000000000001",
["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"],
"success",
"2024-06-15 10:30:00+08",
"2024-06-15 10:35:00+08",
300000,
0,
),
(
"a1b2c3d4-0000-0000-0000-000000000002",
["DWS_AGGREGATE"],
"failed",
"2024-06-15 09:00:00+08",
"2024-06-15 09:01:00+08",
60000,
1,
),
])
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/recent-runs")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
run0 = data[0]
assert run0["id"] == "a1b2c3d4-0000-0000-0000-000000000001"
assert run0["task_codes"] == ["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"]
assert run0["status"] == "success"
assert run0["duration_ms"] == 300000
assert run0["exit_code"] == 0
run1 = data[1]
assert run1["status"] == "failed"
assert run1["exit_code"] == 1
conn.close.assert_called_once()
@patch(_MOCK_APP_CONN)
def test_empty_runs(self, mock_get_conn):
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/recent-runs")
assert resp.status_code == 200
assert resp.json() == []
@patch(_MOCK_APP_CONN)
def test_null_optional_fields(self, mock_get_conn):
"""正在执行的任务 finished_at / duration_ms / exit_code 为 None。"""
conn, cur = _make_mock_conn([
(
"a1b2c3d4-0000-0000-0000-000000000003",
["ODS_FETCH_MEMBERS"],
"running",
"2024-06-15 11:00:00+08",
None,
None,
None,
),
])
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/recent-runs")
assert resp.status_code == 200
data = resp.json()
assert data[0]["status"] == "running"
assert data[0]["finished_at"] is None
assert data[0]["duration_ms"] is None
assert data[0]["exit_code"] is None
@patch(_MOCK_APP_CONN)
def test_site_id_filter(self, mock_get_conn):
"""验证查询时传入了正确的 site_id 参数。"""
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
client.get("/api/etl-status/recent-runs")
# 验证 SQL 中传入了 site_id 和 limit
call_args = cur.execute.call_args
params = call_args[0][1]
assert params[0] == _TEST_USER.site_id
assert params[1] == 50
@patch(_MOCK_APP_CONN)
def test_empty_task_codes(self, mock_get_conn):
"""task_codes 为 None 时应返回空列表。"""
conn, cur = _make_mock_conn([
(
"a1b2c3d4-0000-0000-0000-000000000004",
None,
"pending",
"2024-06-15 12:00:00+08",
None,
None,
None,
),
])
mock_get_conn.return_value = conn
resp = client.get("/api/etl-status/recent-runs")
assert resp.status_code == 200
assert resp.json()[0]["task_codes"] == []
# ---------------------------------------------------------------------------
# 认证测试
# ---------------------------------------------------------------------------
class TestEtlStatusAuth:
def test_requires_auth(self):
"""移除 auth override 后,所有端点应返回 401/403。"""
original = app.dependency_overrides.pop(get_current_user, None)
try:
for url in ["/api/etl-status/cursors", "/api/etl-status/recent-runs"]:
resp = client.get(url)
assert resp.status_code in (401, 403), f"GET {url} 应需要认证"
finally:
if original:
app.dependency_overrides[get_current_user] = original

View File

@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""执行与队列路由单元测试
覆盖 8 个端点run / queue CRUD / cancel / history / logs
通过 mock 绕过数据库和服务层,专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from dataclasses import dataclass
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_queue import QueuedTask
# 固定测试用户
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
# 构造测试用的 TaskConfig payload
_VALID_CONFIG = {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods",
}
# ---------------------------------------------------------------------------
# POST /api/execution/run
# ---------------------------------------------------------------------------
class TestRunTask:
@patch("app.routers.execution.task_executor")
def test_run_returns_execution_id(self, mock_executor):
mock_executor.execute = AsyncMock()
resp = client.post("/api/execution/run", json=_VALID_CONFIG)
assert resp.status_code == 200
data = resp.json()
assert "execution_id" in data
assert data["message"] == "任务已提交执行"
@patch("app.routers.execution.task_executor")
def test_run_injects_store_id(self, mock_executor):
"""store_id 应从 JWT 注入"""
mock_executor.execute = AsyncMock()
resp = client.post("/api/execution/run", json={
**_VALID_CONFIG,
"store_id": 999, # 前端传的值应被覆盖
})
assert resp.status_code == 200
def test_run_requires_auth(self):
app.dependency_overrides.pop(get_current_user, None)
try:
resp = client.post("/api/execution/run", json=_VALID_CONFIG)
assert resp.status_code in (401, 403)
finally:
app.dependency_overrides[get_current_user] = _override_auth
def test_run_invalid_config_returns_422(self):
"""缺少必填字段 tasks 时返回 422"""
resp = client.post("/api/execution/run", json={"pipeline": "api_ods"})
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# GET /api/execution/queue
# ---------------------------------------------------------------------------
class TestGetQueue:
@patch("app.routers.execution.task_queue")
def test_get_queue_returns_list(self, mock_queue):
mock_queue.list_pending.return_value = [
QueuedTask(
id="task-1", site_id=100, config={"tasks": ["ODS_MEMBER"]},
status="pending", position=1, created_at=_NOW,
),
]
resp = client.get("/api/execution/queue")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["id"] == "task-1"
assert data[0]["status"] == "pending"
@patch("app.routers.execution.task_queue")
def test_get_queue_empty(self, mock_queue):
mock_queue.list_pending.return_value = []
resp = client.get("/api/execution/queue")
assert resp.status_code == 200
assert resp.json() == []
@patch("app.routers.execution.task_queue")
def test_get_queue_filters_by_site_id(self, mock_queue):
"""确认调用 list_pending 时传入了正确的 site_id"""
mock_queue.list_pending.return_value = []
client.get("/api/execution/queue")
mock_queue.list_pending.assert_called_once_with(100)
# ---------------------------------------------------------------------------
# POST /api/execution/queue
# ---------------------------------------------------------------------------
class TestEnqueueTask:
@patch("app.routers.execution.get_connection")
@patch("app.routers.execution.task_queue")
def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
mock_queue.enqueue.return_value = "new-task-id"
# mock 数据库查询返回
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = (
"new-task-id", 100, '{"tasks": ["ODS_MEMBER"]}',
"pending", 1, _NOW, None, None, None, None,
)
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
assert resp.status_code == 201
data = resp.json()
assert data["id"] == "new-task-id"
assert data["status"] == "pending"
@patch("app.routers.execution.task_queue")
def test_enqueue_calls_with_site_id(self, mock_queue):
"""确认 enqueue 时传入了 JWT 的 site_id"""
mock_queue.enqueue.return_value = "id-1"
# 让后续的 DB 查询抛异常来快速结束enqueue 本身已验证)
with patch("app.routers.execution.get_connection") as mock_conn:
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = (
"id-1", 100, '{"tasks": []}', "pending", 1,
_NOW, None, None, None, None,
)
conn = MagicMock()
conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_conn.return_value = conn
client.post("/api/execution/queue", json=_VALID_CONFIG)
# 验证 enqueue 的第二个参数是 site_id=100
call_args = mock_queue.enqueue.call_args
assert call_args[0][1] == 100 # site_id
# ---------------------------------------------------------------------------
# PUT /api/execution/queue/reorder
# ---------------------------------------------------------------------------
class TestReorderQueue:
@patch("app.routers.execution.task_queue")
def test_reorder_success(self, mock_queue):
mock_queue.reorder.return_value = None
resp = client.put("/api/execution/queue/reorder", json={
"task_id": "task-1",
"new_position": 3,
})
assert resp.status_code == 200
mock_queue.reorder.assert_called_once_with("task-1", 3, 100)
def test_reorder_missing_fields_returns_422(self):
resp = client.put("/api/execution/queue/reorder", json={})
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# DELETE /api/execution/queue/{id}
# ---------------------------------------------------------------------------
class TestDeleteQueueTask:
@patch("app.routers.execution.task_queue")
def test_delete_success(self, mock_queue):
mock_queue.delete.return_value = True
resp = client.delete("/api/execution/queue/task-1")
assert resp.status_code == 200
mock_queue.delete.assert_called_once_with("task-1", 100)
@patch("app.routers.execution.task_queue")
def test_delete_nonexistent_returns_409(self, mock_queue):
mock_queue.delete.return_value = False
resp = client.delete("/api/execution/queue/nonexistent")
assert resp.status_code == 409
# ---------------------------------------------------------------------------
# POST /api/execution/{id}/cancel
# ---------------------------------------------------------------------------
class TestCancelExecution:
@patch("app.routers.execution.task_executor")
def test_cancel_success(self, mock_executor):
mock_executor.cancel = AsyncMock(return_value=True)
resp = client.post("/api/execution/some-id/cancel")
assert resp.status_code == 200
assert "取消" in resp.json()["message"]
@patch("app.routers.execution.task_executor")
def test_cancel_not_found(self, mock_executor):
mock_executor.cancel = AsyncMock(return_value=False)
resp = client.post("/api/execution/nonexistent/cancel")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# GET /api/execution/history
# ---------------------------------------------------------------------------
class TestExecutionHistory:
@patch("app.routers.execution.get_connection")
def test_history_returns_list(self, mock_get_conn):
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = [
(
"exec-1", 100, ["ODS_MEMBER"], "success",
_NOW, _NOW, 0, 1234, "python -m cli.main", None,
),
]
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.get("/api/execution/history")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["id"] == "exec-1"
assert data[0]["status"] == "success"
@patch("app.routers.execution.get_connection")
def test_history_respects_limit(self, mock_get_conn):
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = []
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.get("/api/execution/history?limit=10")
assert resp.status_code == 200
# 验证 SQL 中传入了 limit=10
call_args = mock_cursor.execute.call_args
assert call_args[0][1] == (100, 10) # (site_id, limit)
def test_history_limit_validation(self):
"""limit 超出范围时返回 422"""
resp = client.get("/api/execution/history?limit=0")
assert resp.status_code == 422
resp = client.get("/api/execution/history?limit=999")
assert resp.status_code == 422
@patch("app.routers.execution.get_connection")
def test_history_empty(self, mock_get_conn):
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = []
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.get("/api/execution/history")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/execution/{id}/logs
# ---------------------------------------------------------------------------
class TestExecutionLogs:
@patch("app.routers.execution.get_connection")
@patch("app.routers.execution.task_executor")
def test_logs_from_db(self, mock_executor, mock_get_conn):
"""已完成任务从数据库读取日志"""
mock_executor.is_running.return_value = False
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = ("stdout output", "stderr output")
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.get("/api/execution/exec-1/logs")
assert resp.status_code == 200
data = resp.json()
assert data["execution_id"] == "exec-1"
assert data["output_log"] == "stdout output"
assert data["error_log"] == "stderr output"
@patch("app.routers.execution.task_executor")
def test_logs_from_memory(self, mock_executor):
"""执行中的任务从内存缓冲区读取"""
mock_executor.is_running.return_value = True
mock_executor.get_logs.return_value = ["line1", "line2"]
resp = client.get("/api/execution/running-id/logs")
assert resp.status_code == 200
data = resp.json()
assert data["execution_id"] == "running-id"
assert "line1" in data["output_log"]
assert "line2" in data["output_log"]
@patch("app.routers.execution.get_connection")
@patch("app.routers.execution.task_executor")
def test_logs_not_found(self, mock_executor, mock_get_conn):
mock_executor.is_running.return_value = False
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = None
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.get("/api/execution/nonexistent/logs")
assert resp.status_code == 404

View File

@@ -0,0 +1,510 @@
# -*- coding: utf-8 -*-
"""队列属性测试Property-Based Testing
使用 hypothesis 验证队列管理的通用正确性属性:
- Property 8: 队列 CRUD 不变量
- Property 9: 队列出队顺序
- Property 10: 队列重排一致性
- Property 11: 执行历史排序与限制
测试策略:
- Property 8-10 通过内存模拟队列状态mock 数据库操作,验证 TaskQueue 的核心逻辑
- Property 11 通过 mock 数据库返回,验证执行历史端点的排序与限制逻辑
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-queue-properties")
import json
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from app.schemas.tasks import TaskConfigSchema
from app.services.task_queue import TaskQueue, QueuedTask
# ---------------------------------------------------------------------------
# 通用策略Strategies
# ---------------------------------------------------------------------------
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
# 简单的任务代码列表
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
_simple_config_st = st.builds(
TaskConfigSchema,
tasks=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
pipeline=st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd"]),
)
# ---------------------------------------------------------------------------
# 内存队列模拟器 — 用于 mock 数据库交互
# ---------------------------------------------------------------------------
class InMemoryQueueDB:
"""模拟 task_queue 表的内存存储,为 TaskQueue 方法提供 mock 数据库行为。"""
def __init__(self, site_id: int):
self.site_id = site_id
# 存储格式:{task_id: {config, status, position, ...}}
self.rows: dict[str, dict] = {}
@property
def pending_tasks(self) -> list[dict]:
"""按 position 排序的 pending 任务列表。"""
return sorted(
[r for r in self.rows.values() if r["status"] == "pending"],
key=lambda r: r["position"],
)
def mock_enqueue_connection(self):
"""为 enqueue 方法构造 mock connection。
enqueue 执行两条 SQL
1. SELECT COALESCE(MAX(position), 0) → 返回当前最大 position
2. INSERT INTO task_queue → 插入新行
"""
pending = self.pending_tasks
max_pos = max((r["position"] for r in pending), default=0)
call_count = [0]
db = self
def make_cursor():
cur = MagicMock()
executed_sqls = []
def execute_side_effect(sql, params=None):
executed_sqls.append((sql, params))
call_count[0] += 1
if "MAX(position)" in sql:
cur.fetchone.return_value = (max_pos,)
elif "INSERT INTO task_queue" in sql:
# 记录插入的行
task_id, site_id, config_json, new_pos = params
db.rows[task_id] = {
"id": task_id,
"site_id": site_id,
"config": json.loads(config_json),
"status": "pending",
"position": new_pos,
}
cur.execute = MagicMock(side_effect=execute_side_effect)
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
conn = MagicMock()
conn.cursor.return_value = make_cursor()
return conn
def mock_dequeue_connection(self):
"""为 dequeue 方法构造 mock connection。
dequeue 执行两条 SQL
1. SELECT ... ORDER BY position ASC LIMIT 1 FOR UPDATE → 返回队首任务
2. UPDATE ... SET status = 'running' → 更新状态
"""
pending = self.pending_tasks
first = pending[0] if pending else None
db = self
def make_cursor():
cur = MagicMock()
def execute_side_effect(sql, params=None):
if "ORDER BY position ASC" in sql:
if first:
cur.fetchone.return_value = (
first["id"], first["site_id"],
json.dumps(first["config"]),
first["status"], first["position"],
None, None, None, None, None,
)
else:
cur.fetchone.return_value = None
elif "SET status = 'running'" in sql:
if first:
db.rows[first["id"]]["status"] = "running"
cur.execute = MagicMock(side_effect=execute_side_effect)
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
conn = MagicMock()
conn.cursor.return_value = make_cursor()
return conn
def mock_delete_connection(self, task_id: str):
"""为 delete 方法构造 mock connection。"""
db = self
def make_cursor():
cur = MagicMock()
def execute_side_effect(sql, params=None):
tid = params[0]
if tid in db.rows and db.rows[tid]["status"] == "pending":
del db.rows[tid]
cur.rowcount = 1
else:
cur.rowcount = 0
cur.execute = MagicMock(side_effect=execute_side_effect)
cur.rowcount = 0
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
conn = MagicMock()
conn.cursor.return_value = make_cursor()
return conn
def mock_reorder_connection(self):
"""为 reorder 方法构造 mock connection。
reorder 执行:
1. SELECT id FROM task_queue WHERE ... ORDER BY position ASC
2. 多次 UPDATE task_queue SET position = %s WHERE id = %s
"""
pending = self.pending_tasks
db = self
def make_cursor():
cur = MagicMock()
call_idx = [0]
def execute_side_effect(sql, params=None):
if "SELECT id FROM task_queue" in sql:
cur.fetchall.return_value = [(r["id"],) for r in pending]
elif "UPDATE task_queue SET position" in sql:
pos, tid = params
if tid in db.rows:
db.rows[tid]["position"] = pos
cur.execute = MagicMock(side_effect=execute_side_effect)
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
conn = MagicMock()
conn.cursor.return_value = make_cursor()
return conn
def mock_list_pending_connection(self):
"""为 list_pending 方法构造 mock connection。"""
pending = self.pending_tasks
def make_cursor():
cur = MagicMock()
def execute_side_effect(sql, params=None):
cur.fetchall.return_value = [
(
r["id"], r["site_id"], json.dumps(r["config"]),
r["status"], r["position"],
None, None, None, None, None,
)
for r in pending
]
cur.execute = MagicMock(side_effect=execute_side_effect)
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
conn = MagicMock()
conn.cursor.return_value = make_cursor()
return conn
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 8: 队列 CRUD 不变量
# **Validates: Requirements 4.1, 4.4**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
config=_simple_config_st,
site_id=_site_id_st,
initial_count=st.integers(min_value=0, max_value=5),
)
@patch("app.services.task_queue.get_connection")
def test_queue_crud_invariant(mock_get_conn, config, site_id, initial_count):
"""Property 8: 队列 CRUD 不变量。
入队一个任务后队列长度增加 1 且新任务状态为 pending
删除一个 pending 任务后队列长度减少 1 且该任务不再出现在队列中。
"""
queue = TaskQueue()
db = InMemoryQueueDB(site_id)
# 预填充若干任务
for i in range(initial_count):
tid = str(uuid.uuid4())
db.rows[tid] = {
"id": tid,
"site_id": site_id,
"config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
"status": "pending",
"position": i + 1,
}
before_count = len(db.pending_tasks)
# --- 入队 ---
mock_get_conn.return_value = db.mock_enqueue_connection()
new_id = queue.enqueue(config, site_id)
after_enqueue_count = len(db.pending_tasks)
assert after_enqueue_count == before_count + 1, (
f"入队后长度应 +1期望 {before_count + 1},实际 {after_enqueue_count}"
)
assert new_id in db.rows, "新任务应存在于队列中"
assert db.rows[new_id]["status"] == "pending", "新任务状态应为 pending"
# --- 删除刚入队的任务 ---
mock_get_conn.return_value = db.mock_delete_connection(new_id)
deleted = queue.delete(new_id, site_id)
after_delete_count = len(db.pending_tasks)
assert deleted is True, "删除 pending 任务应返回 True"
assert after_delete_count == before_count, (
f"删除后长度应恢复:期望 {before_count},实际 {after_delete_count}"
)
assert new_id not in db.rows, "已删除任务不应出现在队列中"
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 9: 队列出队顺序
# **Validates: Requirements 4.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
site_id=_site_id_st,
num_tasks=st.integers(min_value=1, max_value=8),
positions=st.data(),
)
@patch("app.services.task_queue.get_connection")
def test_queue_dequeue_order(mock_get_conn, site_id, num_tasks, positions):
"""Property 9: 队列出队顺序。
包含多个 pending 任务的队列dequeue 操作应返回 position 值最小的任务。
"""
queue = TaskQueue()
db = InMemoryQueueDB(site_id)
# 生成不重复的 position 值
pos_list = positions.draw(
st.lists(
st.integers(min_value=1, max_value=1000),
min_size=num_tasks,
max_size=num_tasks,
unique=True,
)
)
# 填充队列
task_ids = []
for i, pos in enumerate(pos_list):
tid = str(uuid.uuid4())
task_ids.append(tid)
db.rows[tid] = {
"id": tid,
"site_id": site_id,
"config": {"tasks": [_task_codes[i % len(_task_codes)]], "pipeline": "api_ods"},
"status": "pending",
"position": pos,
}
# 找出 position 最小的任务
expected_first = min(db.pending_tasks, key=lambda r: r["position"])
# dequeue
mock_get_conn.return_value = db.mock_dequeue_connection()
result = queue.dequeue(site_id)
assert result is not None, "队列非空时 dequeue 不应返回 None"
assert result.id == expected_first["id"], (
f"应返回 position 最小的任务:期望 id={expected_first['id']} "
f"(pos={expected_first['position']}),实际 id={result.id}"
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 10: 队列重排一致性
# **Validates: Requirements 4.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
site_id=_site_id_st,
num_tasks=st.integers(min_value=2, max_value=6),
data=st.data(),
)
@patch("app.services.task_queue.get_connection")
def test_queue_reorder_consistency(mock_get_conn, site_id, num_tasks, data):
"""Property 10: 队列重排一致性。
重排操作(将任务移动到新位置)后,队列中任务的相对顺序应与请求一致:
- 被移动的任务应出现在目标位置clamp 到有效范围)
- 其余任务保持原有相对顺序
- 所有任务仍在队列中(不丢失)
"""
queue = TaskQueue()
db = InMemoryQueueDB(site_id)
# 填充队列position 从 1 开始连续编号
task_ids = []
for i in range(num_tasks):
tid = str(uuid.uuid4())
task_ids.append(tid)
db.rows[tid] = {
"id": tid,
"site_id": site_id,
"config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
"status": "pending",
"position": i + 1,
}
# 随机选择要移动的任务和目标位置
move_idx = data.draw(st.integers(min_value=0, max_value=num_tasks - 1))
move_task_id = task_ids[move_idx]
new_position = data.draw(st.integers(min_value=1, max_value=num_tasks + 2))
# 执行 reorder
mock_get_conn.return_value = db.mock_reorder_connection()
queue.reorder(move_task_id, new_position, site_id)
# 验证:所有任务仍在队列中
remaining_ids = {r["id"] for r in db.rows.values() if r["status"] == "pending"}
assert remaining_ids == set(task_ids), "重排后不应丢失任何任务"
# 验证position 值连续且唯一1-based
positions = sorted(r["position"] for r in db.pending_tasks)
assert positions == list(range(1, num_tasks + 1)), (
f"重排后 position 应为连续编号 1..{num_tasks},实际 {positions}"
)
# 验证:被移动的任务在正确位置
# reorder 内部逻辑clamp new_position 到 [1, len(others)+1]
clamped_pos = max(1, min(new_position, num_tasks))
actual_pos = db.rows[move_task_id]["position"]
assert actual_pos == clamped_pos, (
f"被移动任务的 position 应为 {clamped_pos}clamp 后),实际 {actual_pos}"
)
# 验证:其余任务保持原有相对顺序
others_before = [tid for tid in task_ids if tid != move_task_id]
others_after = sorted(
[r for r in db.pending_tasks if r["id"] != move_task_id],
key=lambda r: r["position"],
)
others_after_ids = [r["id"] for r in others_after]
assert others_after_ids == others_before, (
"其余任务的相对顺序应保持不变"
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 11: 执行历史排序与限制
# **Validates: Requirements 4.5, 8.2**
# ---------------------------------------------------------------------------
# 导入 FastAPI 测试客户端
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from fastapi.testclient import TestClient
def _make_history_rows(count: int, site_id: int) -> list[tuple]:
"""生成 count 条执行历史记录started_at 随机但可排序。"""
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
["ODS_MEMBER"], # task_codes
"success", # status
base_time + timedelta(hours=i), # started_at
base_time + timedelta(hours=i, minutes=30), # finished_at
0, # exit_code
1800000, # duration_ms
"python -m cli.main", # command
None, # summary
))
return rows
@settings(max_examples=100, deadline=None)
@given(
site_id=_site_id_st,
total_records=st.integers(min_value=0, max_value=30),
limit=st.integers(min_value=1, max_value=200),
)
@patch("app.routers.execution.get_connection")
def test_execution_history_sort_and_limit(mock_get_conn, site_id, total_records, limit):
"""Property 11: 执行历史排序与限制。
执行历史记录集合API 返回的结果应按 started_at 降序排列,
且结果数量不超过请求的 limit 值。
"""
# 生成测试数据
all_rows = _make_history_rows(total_records, site_id)
# 模拟数据库:按 started_at DESC 排序后取 limit 条
sorted_rows = sorted(all_rows, key=lambda r: r[4], reverse=True)
returned_rows = sorted_rows[:limit]
# mock 数据库连接
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = returned_rows
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
# 覆盖认证依赖
test_user = CurrentUser(user_id=1, site_id=site_id)
app.dependency_overrides[get_current_user] = lambda: test_user
try:
client = TestClient(app)
# limit 必须在 [1, 200] 范围内API 约束)
clamped_limit = max(1, min(limit, 200))
resp = client.get(f"/api/execution/history?limit={clamped_limit}")
assert resp.status_code == 200
data = resp.json()
# 验证 1结果数量不超过 limit
assert len(data) <= clamped_limit, (
f"结果数量 {len(data)} 超过 limit {clamped_limit}"
)
# 验证 2结果数量不超过总记录数
assert len(data) <= total_records, (
f"结果数量 {len(data)} 超过总记录数 {total_records}"
)
# 验证 3按 started_at 降序排列
if len(data) >= 2:
for i in range(len(data) - 1):
t1 = data[i]["started_at"]
t2 = data[i + 1]["started_at"]
assert t1 >= t2, (
f"结果未按 started_at 降序排列data[{i}]={t1} < data[{i+1}]={t2}"
)
finally:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)

View File

@@ -0,0 +1,439 @@
# -*- coding: utf-8 -*-
"""调度属性测试Property-Based Testing
使用 hypothesis 验证调度管理的通用正确性属性:
- Property 12: 调度任务 CRUD 往返
- Property 13: 到期调度任务自动入队
- Property 14: 调度任务启用/禁用状态
测试策略:
- Property 12: 通过 mock 数据库,验证 POST 创建后 GET 返回的 schedule_config 与提交的一致
- Property 13: 通过 mock 数据库返回到期任务,验证 check_and_enqueue 调用了 task_queue.enqueue
- Property 14: 通过 mock 数据库,验证 toggle 端点的 next_run_at 行为
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-schedule-properties")
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.schemas.schedules import ScheduleConfigSchema
from app.schemas.tasks import TaskConfigSchema
from app.services.scheduler import Scheduler, calculate_next_run
from fastapi.testclient import TestClient
# ---------------------------------------------------------------------------
# 通用策略Strategies
# ---------------------------------------------------------------------------
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
_simple_task_config_st = st.fixed_dictionaries({
"tasks": st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
"pipeline": st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd", "api_full"]),
})
# 调度配置策略:覆盖 5 种调度类型
_schedule_type_st = st.sampled_from(["once", "interval", "daily", "weekly", "cron"])
_interval_unit_st = st.sampled_from(["minutes", "hours", "days"])
# HH:MM 格式的时间字符串
_time_str_st = st.builds(
lambda h, m: f"{h:02d}:{m:02d}",
h=st.integers(min_value=0, max_value=23),
m=st.integers(min_value=0, max_value=59),
)
# ISO weekday 列表1=Monday ... 7=Sunday
_weekly_days_st = st.lists(
st.integers(min_value=1, max_value=7),
min_size=1, max_size=7, unique=True,
)
# 简单 cron 表达式minute hour * * *
_cron_st = st.builds(
lambda m, h: f"{m} {h} * * *",
m=st.integers(min_value=0, max_value=59),
h=st.integers(min_value=0, max_value=23),
)
def _build_schedule_config(schedule_type, interval_value, interval_unit,
daily_time, weekly_days, weekly_time, cron_expression):
"""根据 schedule_type 构建 ScheduleConfigSchema。"""
return ScheduleConfigSchema(
schedule_type=schedule_type,
interval_value=interval_value,
interval_unit=interval_unit,
daily_time=daily_time,
weekly_days=weekly_days,
weekly_time=weekly_time,
cron_expression=cron_expression,
enabled=True,
)
_schedule_config_st = st.builds(
_build_schedule_config,
schedule_type=_schedule_type_st,
interval_value=st.integers(min_value=1, max_value=168),
interval_unit=_interval_unit_st,
daily_time=_time_str_st,
weekly_days=_weekly_days_st,
weekly_time=_time_str_st,
cron_expression=_cron_st,
)
# 用于 Property 14 的非 once 调度配置(启用后 next_run_at 应非 NULL
_non_once_schedule_type_st = st.sampled_from(["interval", "daily", "weekly", "cron"])
_non_once_schedule_config_st = st.builds(
_build_schedule_config,
schedule_type=_non_once_schedule_type_st,
interval_value=st.integers(min_value=1, max_value=168),
interval_unit=_interval_unit_st,
daily_time=_time_str_st,
weekly_days=_weekly_days_st,
weekly_time=_time_str_st,
cron_expression=_cron_st,
)
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
_NOW = datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
# 模拟数据库行的列顺序(与 _SELECT_COLS 对应,共 13 列)
# id, site_id, name, task_codes, task_config, schedule_config,
# enabled, last_run_at, next_run_at, run_count, last_status,
# created_at, updated_at
def _make_db_row(
schedule_id: str,
site_id: int,
name: str,
task_codes: list[str],
task_config: dict,
schedule_config: dict,
enabled: bool = True,
next_run_at: datetime | None = None,
) -> tuple:
"""构造模拟数据库行。"""
return (
schedule_id, site_id, name, task_codes,
json.dumps(task_config) if isinstance(task_config, dict) else task_config,
json.dumps(schedule_config) if isinstance(schedule_config, dict) else schedule_config,
enabled, None, next_run_at, 0, None, _NOW, _NOW,
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 12: 调度任务 CRUD 往返
# **Validates: Requirements 5.1, 5.4**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
site_id=_site_id_st,
schedule_config=_schedule_config_st,
task_config=_simple_task_config_st,
name=st.text(min_size=1, max_size=50, alphabet=st.characters(
whitelist_categories=("L", "N"), whitelist_characters="_- "
)),
task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_crud_round_trip(
mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
):
"""Property 12: 调度任务 CRUD 往返。
有效的 ScheduleConfigSchema创建调度任务后再查询该任务
返回的调度配置应与创建时提交的配置等价。
"""
schedule_config_dict = schedule_config.model_dump()
next_run = calculate_next_run(schedule_config, _NOW)
# 构造创建后数据库返回的行
created_row = _make_db_row(
schedule_id="test-sched-id",
site_id=site_id,
name=name,
task_codes=task_codes,
task_config=task_config,
schedule_config=schedule_config_dict,
enabled=schedule_config.enabled,
next_run_at=next_run,
)
# --- 创建阶段 ---
# mock POST 的数据库连接INSERT ... RETURNING
create_cursor = MagicMock()
create_cursor.fetchone.return_value = created_row
create_conn = MagicMock()
create_conn.cursor.return_value.__enter__ = MagicMock(return_value=create_cursor)
create_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
# --- 查询阶段 ---
# mock GET 的数据库连接SELECT ... fetchall
list_cursor = MagicMock()
list_cursor.fetchall.return_value = [created_row]
list_conn = MagicMock()
list_conn.cursor.return_value.__enter__ = MagicMock(return_value=list_cursor)
list_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
# 依次返回 create_conn 和 list_conn
mock_get_conn.side_effect = [create_conn, list_conn]
# 覆盖认证
test_user = CurrentUser(user_id=1, site_id=site_id)
app.dependency_overrides[get_current_user] = lambda: test_user
try:
client = TestClient(app)
# 创建调度任务
create_body = {
"name": name,
"task_codes": task_codes,
"task_config": task_config,
"schedule_config": schedule_config_dict,
}
create_resp = client.post("/api/schedules", json=create_body)
assert create_resp.status_code == 201, (
f"创建应返回 201实际 {create_resp.status_code}: {create_resp.text}"
)
created_data = create_resp.json()
# 查询调度任务列表
list_resp = client.get("/api/schedules")
assert list_resp.status_code == 200
list_data = list_resp.json()
assert len(list_data) >= 1, "查询结果应至少包含刚创建的任务"
# 找到刚创建的任务
found = next((s for s in list_data if s["id"] == created_data["id"]), None)
assert found is not None, "查询结果应包含刚创建的任务"
# 核心验证schedule_config 往返一致
returned_config = found["schedule_config"]
for key in schedule_config_dict:
assert returned_config[key] == schedule_config_dict[key], (
f"schedule_config.{key} 不一致:"
f"提交={schedule_config_dict[key]},返回={returned_config[key]}"
)
# 验证 task_config 往返一致
returned_task_config = found["task_config"]
for key in task_config:
assert returned_task_config[key] == task_config[key], (
f"task_config.{key} 不一致:提交={task_config[key]},返回={returned_task_config[key]}"
)
# 验证基本字段
assert found["name"] == name
assert found["task_codes"] == task_codes
assert found["site_id"] == site_id
finally:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 13: 到期调度任务自动入队
# **Validates: Requirements 5.2**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
site_id=_site_id_st,
schedule_config=_schedule_config_st,
task_config=_simple_task_config_st,
)
@patch("app.services.scheduler.task_queue")
@patch("app.services.scheduler.get_connection")
def test_due_schedule_auto_enqueue(
mock_get_conn, mock_tq, site_id, schedule_config, task_config,
):
"""Property 13: 到期调度任务自动入队。
enabled 为 true 且 next_run_at 早于当前时间的调度任务,
check_and_enqueue 执行后该任务的 TaskConfig 应出现在执行队列中。
"""
sched = Scheduler()
schedule_config_dict = schedule_config.model_dump()
# 构造到期任务next_run_at 在过去(比 now 早 5 分钟)
task_id = "due-task-001"
# --- mock SELECT 到期任务 ---
select_cursor = MagicMock()
select_cursor.fetchall.return_value = [
(task_id, site_id, json.dumps(task_config), json.dumps(schedule_config_dict)),
]
select_cursor.__enter__ = MagicMock(return_value=select_cursor)
select_cursor.__exit__ = MagicMock(return_value=False)
# --- mock UPDATE 调度状态 ---
update_cursor = MagicMock()
update_cursor.__enter__ = MagicMock(return_value=update_cursor)
update_cursor.__exit__ = MagicMock(return_value=False)
conn = MagicMock()
conn.cursor.side_effect = [select_cursor, update_cursor]
mock_get_conn.return_value = conn
mock_tq.enqueue.return_value = "queue-id-123"
# 执行
count = sched.check_and_enqueue()
# 验证:到期任务被入队
assert count == 1, f"应有 1 个任务入队,实际 {count}"
mock_tq.enqueue.assert_called_once()
# 验证入队参数
call_args = mock_tq.enqueue.call_args
enqueued_config = call_args[0][0]
enqueued_site_id = call_args[0][1]
# site_id 应匹配
assert enqueued_site_id == site_id, (
f"入队的 site_id 应为 {site_id},实际 {enqueued_site_id}"
)
# TaskConfig 应与原始配置一致
assert isinstance(enqueued_config, TaskConfigSchema)
assert enqueued_config.tasks == task_config["tasks"], (
f"入队的 tasks 应为 {task_config['tasks']},实际 {enqueued_config.tasks}"
)
assert enqueued_config.pipeline == task_config["pipeline"], (
f"入队的 pipeline 应为 {task_config['pipeline']},实际 {enqueued_config.pipeline}"
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 14: 调度任务启用/禁用状态
# **Validates: Requirements 5.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100, deadline=None)
@given(
site_id=_site_id_st,
schedule_config=_non_once_schedule_config_st,
task_config=_simple_task_config_st,
name=st.text(min_size=1, max_size=30, alphabet=st.characters(
whitelist_categories=("L", "N"), whitelist_characters="_- "
)),
task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_toggle_next_run(
mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
):
"""Property 14: 调度任务启用/禁用状态。
禁用后 next_run_at 应为 NULL
重新启用后 next_run_at 应被重新计算为非 NULL 值(对于非一次性调度)。
"""
schedule_config_dict = schedule_config.model_dump()
next_run_enabled = calculate_next_run(schedule_config, _NOW)
# --- 第一步禁用enabled=True → False---
# toggle 端点先 SELECT 当前状态,再 UPDATE RETURNING
# 禁用后的数据库行
disabled_row = _make_db_row(
schedule_id="sched-toggle-1",
site_id=site_id,
name=name,
task_codes=task_codes,
task_config=task_config,
schedule_config=schedule_config_dict,
enabled=False,
next_run_at=None, # 禁用后 next_run_at 为 NULL
)
# mock 禁用操作的数据库连接
disable_cursor = MagicMock()
disable_cursor.fetchone.side_effect = [
(True, json.dumps(schedule_config_dict)), # SELECT 当前状态enabled=True
disabled_row, # UPDATE RETURNING
]
disable_conn = MagicMock()
disable_conn.cursor.return_value.__enter__ = MagicMock(return_value=disable_cursor)
disable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
# --- 第二步启用enabled=False → True---
enabled_row = _make_db_row(
schedule_id="sched-toggle-1",
site_id=site_id,
name=name,
task_codes=task_codes,
task_config=task_config,
schedule_config=schedule_config_dict,
enabled=True,
next_run_at=next_run_enabled, # 启用后 next_run_at 被重新计算
)
enable_cursor = MagicMock()
enable_cursor.fetchone.side_effect = [
(False, json.dumps(schedule_config_dict)), # SELECT 当前状态enabled=False
enabled_row, # UPDATE RETURNING
]
enable_conn = MagicMock()
enable_conn.cursor.return_value.__enter__ = MagicMock(return_value=enable_cursor)
enable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
# 依次返回两个连接
mock_get_conn.side_effect = [disable_conn, enable_conn]
# 覆盖认证
test_user = CurrentUser(user_id=1, site_id=site_id)
app.dependency_overrides[get_current_user] = lambda: test_user
try:
client = TestClient(app)
# 禁用
disable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
assert disable_resp.status_code == 200, (
f"禁用应返回 200实际 {disable_resp.status_code}: {disable_resp.text}"
)
disable_data = disable_resp.json()
# 验证:禁用后 enabled=Falsenext_run_at=NULL
assert disable_data["enabled"] is False, "禁用后 enabled 应为 False"
assert disable_data["next_run_at"] is None, "禁用后 next_run_at 应为 NULL"
# 启用
enable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
assert enable_resp.status_code == 200, (
f"启用应返回 200实际 {enable_resp.status_code}: {enable_resp.text}"
)
enable_data = enable_resp.json()
# 验证:启用后 enabled=Truenext_run_at 非 NULL非一次性调度
assert enable_data["enabled"] is True, "启用后 enabled 应为 True"
assert enable_data["next_run_at"] is not None, (
"非一次性调度启用后 next_run_at 应被重新计算为非 NULL 值"
)
finally:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)

View File

@@ -0,0 +1,384 @@
# -*- coding: utf-8 -*-
"""Scheduler 单元测试
覆盖:
- calculate_next_run各种调度类型的下次执行时间计算
- _parse_simple_cron简单 cron 表达式解析
- check_and_enqueue到期检查与入队逻辑
- start / stop后台循环生命周期
"""
import asyncio
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
import pytest
from app.schemas.schedules import ScheduleConfigSchema
from app.schemas.tasks import TaskConfigSchema
from app.services.scheduler import (
Scheduler,
calculate_next_run,
_parse_simple_cron,
_parse_time,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def sched() -> Scheduler:
return Scheduler()
@pytest.fixture
def now() -> datetime:
"""固定时间点2025-06-10 10:00:00 UTC周二"""
return datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
cur = MagicMock()
cur.fetchone.return_value = fetchone_val
cur.fetchall.return_value = fetchall_val or []
cur.rowcount = rowcount
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
def _mock_conn(cursor):
conn = MagicMock()
conn.cursor.return_value = cursor
return conn
# ---------------------------------------------------------------------------
# _parse_time
# ---------------------------------------------------------------------------
class TestParseTime:
def test_standard_format(self):
assert _parse_time("04:00") == (4, 0)
def test_with_minutes(self):
assert _parse_time("23:45") == (23, 45)
def test_midnight(self):
assert _parse_time("00:00") == (0, 0)
# ---------------------------------------------------------------------------
# calculate_next_run — once
# ---------------------------------------------------------------------------
class TestNextRunOnce:
def test_once_returns_none(self, now):
cfg = ScheduleConfigSchema(schedule_type="once")
assert calculate_next_run(cfg, now) is None
# ---------------------------------------------------------------------------
# calculate_next_run — interval
# ---------------------------------------------------------------------------
class TestNextRunInterval:
def test_interval_minutes(self, now):
cfg = ScheduleConfigSchema(
schedule_type="interval", interval_value=15, interval_unit="minutes",
)
result = calculate_next_run(cfg, now)
assert result == now + timedelta(minutes=15)
def test_interval_hours(self, now):
cfg = ScheduleConfigSchema(
schedule_type="interval", interval_value=2, interval_unit="hours",
)
result = calculate_next_run(cfg, now)
assert result == now + timedelta(hours=2)
def test_interval_days(self, now):
cfg = ScheduleConfigSchema(
schedule_type="interval", interval_value=3, interval_unit="days",
)
result = calculate_next_run(cfg, now)
assert result == now + timedelta(days=3)
# ---------------------------------------------------------------------------
# calculate_next_run — daily
# ---------------------------------------------------------------------------
class TestNextRunDaily:
def test_daily_next_day(self, now):
cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="04:00")
result = calculate_next_run(cfg, now)
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_daily_custom_time(self, now):
cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="18:30")
result = calculate_next_run(cfg, now)
expected = datetime(2025, 6, 11, 18, 30, 0, tzinfo=timezone.utc)
assert result == expected
# ---------------------------------------------------------------------------
# calculate_next_run — weekly
# ---------------------------------------------------------------------------
class TestNextRunWeekly:
def test_weekly_later_this_week(self, now):
# now 是周二(2)weekly_days=[5] 周五 → 3 天后
cfg = ScheduleConfigSchema(
schedule_type="weekly", weekly_days=[5], weekly_time="08:00",
)
result = calculate_next_run(cfg, now)
expected = datetime(2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_weekly_next_week(self, now):
# now 是周二(2)weekly_days=[1] 周一 → 下周一6天后
cfg = ScheduleConfigSchema(
schedule_type="weekly", weekly_days=[1], weekly_time="04:00",
)
result = calculate_next_run(cfg, now)
expected = datetime(2025, 6, 16, 4, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_weekly_multiple_days_picks_next(self, now):
# now 是周二(2)weekly_days=[1, 4, 6] → 周四(4)2 天后
cfg = ScheduleConfigSchema(
schedule_type="weekly", weekly_days=[1, 4, 6], weekly_time="09:00",
)
result = calculate_next_run(cfg, now)
expected = datetime(2025, 6, 12, 9, 0, 0, tzinfo=timezone.utc)
assert result == expected
# ---------------------------------------------------------------------------
# calculate_next_run — cron
# ---------------------------------------------------------------------------
class TestNextRunCron:
def test_cron_daily(self, now):
cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="30 4 * * *")
result = calculate_next_run(cfg, now)
expected = datetime(2025, 6, 11, 4, 30, 0, tzinfo=timezone.utc)
assert result == expected
def test_cron_with_dow(self, now):
# "0 8 * * 5" → 每周五 08:00now 是周二 → 周五3天后
cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="0 8 * * 5")
result = calculate_next_run(cfg, now)
expected = datetime(2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc)
assert result == expected
# ---------------------------------------------------------------------------
# _parse_simple_cron
# ---------------------------------------------------------------------------
class TestParseSimpleCron:
def test_daily_cron(self, now):
result = _parse_simple_cron("0 4 * * *", now)
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_invalid_field_count_fallback(self, now):
# 字段数不对,回退到明天 04:00
result = _parse_simple_cron("0 4 *", now)
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_wildcard_hour_minute(self, now):
# "* * * * *" → hour=0, minute=0明天 00:00
result = _parse_simple_cron("* * * * *", now)
expected = datetime(2025, 6, 11, 0, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_dow_sunday(self, now):
# "0 6 * * 0" → 每周日 06:00now 是周二 → 周日5天后
result = _parse_simple_cron("0 6 * * 0", now)
expected = datetime(2025, 6, 15, 6, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_dow_same_day_future_time(self):
# 周二 08:00cron 指定周二 12:00 → 当天
now = datetime(2025, 6, 10, 8, 0, 0, tzinfo=timezone.utc)
result = _parse_simple_cron("0 12 * * 2", now)
expected = datetime(2025, 6, 10, 12, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_dow_same_day_past_time(self):
# 周二 14:00cron 指定周二 12:00 → 下周二
now = datetime(2025, 6, 10, 14, 0, 0, tzinfo=timezone.utc)
result = _parse_simple_cron("0 12 * * 2", now)
expected = datetime(2025, 6, 17, 12, 0, 0, tzinfo=timezone.utc)
assert result == expected
# ---------------------------------------------------------------------------
# check_and_enqueue
# ---------------------------------------------------------------------------
class TestCheckAndEnqueue:
@patch("app.services.scheduler.get_connection")
@patch("app.services.scheduler.task_queue")
def test_enqueues_due_tasks(self, mock_tq, mock_get_conn, sched):
"""到期任务应被入队,且更新 last_run_at / run_count / next_run_at"""
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
schedule_config = {
"schedule_type": "interval",
"interval_value": 1,
"interval_unit": "hours",
}
# 第一次 cursorSELECT 到期任务
select_cur = _mock_cursor(
fetchall_val=[
("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
]
)
# 第二次 cursorUPDATE
update_cur = _mock_cursor()
conn = MagicMock()
# cursor() 依次返回 select_cur 和 update_cur
conn.cursor.side_effect = [select_cur, update_cur]
mock_get_conn.return_value = conn
mock_tq.enqueue.return_value = "queue-id-1"
count = sched.check_and_enqueue()
assert count == 1
mock_tq.enqueue.assert_called_once()
# 验证 enqueue 的参数
call_args = mock_tq.enqueue.call_args
assert call_args[0][1] == 42 # site_id
assert isinstance(call_args[0][0], TaskConfigSchema)
@patch("app.services.scheduler.get_connection")
@patch("app.services.scheduler.task_queue")
def test_no_due_tasks(self, mock_tq, mock_get_conn, sched):
"""没有到期任务时,不入队"""
cur = _mock_cursor(fetchall_val=[])
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
count = sched.check_and_enqueue()
assert count == 0
mock_tq.enqueue.assert_not_called()
@patch("app.services.scheduler.get_connection")
@patch("app.services.scheduler.task_queue")
def test_skips_invalid_config(self, mock_tq, mock_get_conn, sched):
"""配置反序列化失败的任务应被跳过"""
# task_config 缺少必填字段 tasks
bad_config = {"pipeline": "api_ods_dwd"}
schedule_config = {"schedule_type": "once"}
cur = _mock_cursor(
fetchall_val=[
("task-uuid-bad", 42, json.dumps(bad_config), json.dumps(schedule_config)),
]
)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
count = sched.check_and_enqueue()
assert count == 0
mock_tq.enqueue.assert_not_called()
@patch("app.services.scheduler.get_connection")
@patch("app.services.scheduler.task_queue")
def test_enqueue_failure_continues(self, mock_tq, mock_get_conn, sched):
"""入队失败时应跳过该任务,继续处理后续任务"""
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
schedule_config = {"schedule_type": "once"}
cur = _mock_cursor(
fetchall_val=[
("task-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
("task-2", 42, json.dumps(task_config), json.dumps(schedule_config)),
]
)
# 需要额外的 cursor 给 UPDATE 用
update_cur = _mock_cursor()
conn = MagicMock()
conn.cursor.side_effect = [cur, update_cur]
mock_get_conn.return_value = conn
# 第一次入队失败,第二次成功
mock_tq.enqueue.side_effect = [Exception("DB error"), "queue-id-2"]
count = sched.check_and_enqueue()
assert count == 1
assert mock_tq.enqueue.call_count == 2
@patch("app.services.scheduler.get_connection")
@patch("app.services.scheduler.task_queue")
def test_once_type_sets_next_run_none(self, mock_tq, mock_get_conn, sched):
"""once 类型任务入队后next_run_at 应被设为 NULL"""
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
schedule_config = {"schedule_type": "once"}
select_cur = _mock_cursor(
fetchall_val=[
("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
]
)
update_cur = _mock_cursor()
conn = MagicMock()
conn.cursor.side_effect = [select_cur, update_cur]
mock_get_conn.return_value = conn
mock_tq.enqueue.return_value = "queue-id-1"
sched.check_and_enqueue()
# 验证 UPDATE 语句中 next_run_at 参数为 None
update_call = update_cur.__enter__().execute.call_args
# 参数元组的第一个元素是 next_run_at
assert update_call[0][1][0] is None
# ---------------------------------------------------------------------------
# start / stop 生命周期
# ---------------------------------------------------------------------------
class TestLifecycle:
@pytest.mark.asyncio
async def test_stop_sets_running_false(self, sched):
sched._running = True
await sched.stop()
assert sched._running is False
assert sched._loop_task is None
def test_start_creates_task(self, sched):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 在事件循环中启动
async def _run():
sched.start()
assert sched._loop_task is not None
assert not sched._loop_task.done()
await sched.stop()
loop.run_until_complete(_run())
finally:
loop.close()
@pytest.mark.asyncio
async def test_start_stop_idempotent(self, sched):
"""多次 stop 不应报错"""
await sched.stop()
await sched.stop()
assert sched._loop_task is None

View File

@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-
"""调度路由单元测试
覆盖 5 个端点list / create / update / delete / toggle
通过 mock 绕过数据库,专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import json
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
_NEXT = datetime(2024, 6, 2, 4, 0, 0, tzinfo=timezone.utc)
_SCHEDULE_CONFIG = {
"schedule_type": "daily",
"daily_time": "04:00",
}
_VALID_CREATE = {
"name": "每日全量同步",
"task_codes": ["ODS_MEMBER", "ODS_ORDER"],
"task_config": {"tasks": ["ODS_MEMBER", "ODS_ORDER"], "pipeline": "api_ods"},
"schedule_config": _SCHEDULE_CONFIG,
}
# 模拟数据库返回的完整行13 列,与 _SELECT_COLS 对应)
_DB_ROW = (
"sched-1", 100, "每日全量同步", ["ODS_MEMBER", "ODS_ORDER"],
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}),
json.dumps(_SCHEDULE_CONFIG),
True, None, _NEXT, 0, None, _NOW, _NOW,
)
def _mock_conn_with_fetchall(rows):
"""构造返回 fetchall 的 mock 连接。"""
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = rows
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cursor
def _mock_conn_with_fetchone(row):
"""构造返回 fetchone 的 mock 连接。"""
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = row
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cursor
# ---------------------------------------------------------------------------
# GET /api/schedules
# ---------------------------------------------------------------------------
class TestListSchedules:
@patch("app.routers.schedules.get_connection")
def test_list_returns_schedules(self, mock_get_conn):
mock_conn, _ = _mock_conn_with_fetchall([_DB_ROW])
mock_get_conn.return_value = mock_conn
resp = client.get("/api/schedules")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["id"] == "sched-1"
assert data[0]["name"] == "每日全量同步"
assert data[0]["site_id"] == 100
assert data[0]["enabled"] is True
@patch("app.routers.schedules.get_connection")
def test_list_empty(self, mock_get_conn):
mock_conn, _ = _mock_conn_with_fetchall([])
mock_get_conn.return_value = mock_conn
resp = client.get("/api/schedules")
assert resp.status_code == 200
assert resp.json() == []
@patch("app.routers.schedules.get_connection")
def test_list_filters_by_site_id(self, mock_get_conn):
mock_conn, mock_cursor = _mock_conn_with_fetchall([])
mock_get_conn.return_value = mock_conn
client.get("/api/schedules")
call_args = mock_cursor.execute.call_args
assert call_args[0][1] == (100,)
# ---------------------------------------------------------------------------
# POST /api/schedules
# ---------------------------------------------------------------------------
class TestCreateSchedule:
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
@patch("app.routers.schedules.get_connection")
def test_create_returns_201(self, mock_get_conn, mock_calc):
mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
mock_get_conn.return_value = mock_conn
resp = client.post("/api/schedules", json=_VALID_CREATE)
assert resp.status_code == 201
data = resp.json()
assert data["id"] == "sched-1"
assert data["name"] == "每日全量同步"
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
@patch("app.routers.schedules.get_connection")
def test_create_injects_site_id(self, mock_get_conn, mock_calc):
mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
mock_get_conn.return_value = mock_conn
client.post("/api/schedules", json=_VALID_CREATE)
# INSERT 的第一个参数应为 site_id=100
insert_params = mock_cursor.execute.call_args[0][1]
assert insert_params[0] == 100
def test_create_missing_name_returns_422(self):
body = {**_VALID_CREATE}
del body["name"]
resp = client.post("/api/schedules", json=body)
assert resp.status_code == 422
def test_create_invalid_schedule_type_returns_422(self):
body = {**_VALID_CREATE, "schedule_config": {"schedule_type": "invalid"}}
resp = client.post("/api/schedules", json=body)
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# PUT /api/schedules/{id}
# ---------------------------------------------------------------------------
class TestUpdateSchedule:
@patch("app.routers.schedules.get_connection")
def test_update_name(self, mock_get_conn):
updated_row = list(_DB_ROW)
updated_row[2] = "新名称"
mock_conn, _ = _mock_conn_with_fetchone(tuple(updated_row))
mock_get_conn.return_value = mock_conn
resp = client.put("/api/schedules/sched-1", json={"name": "新名称"})
assert resp.status_code == 200
assert resp.json()["name"] == "新名称"
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
@patch("app.routers.schedules.get_connection")
def test_update_schedule_config_recalculates_next_run(self, mock_get_conn, mock_calc):
mock_conn, _ = _mock_conn_with_fetchone(_DB_ROW)
mock_get_conn.return_value = mock_conn
resp = client.put("/api/schedules/sched-1", json={
"schedule_config": {"schedule_type": "interval", "interval_value": 2, "interval_unit": "hours"},
})
assert resp.status_code == 200
mock_calc.assert_called_once()
@patch("app.routers.schedules.get_connection")
def test_update_not_found(self, mock_get_conn):
mock_conn, _ = _mock_conn_with_fetchone(None)
mock_get_conn.return_value = mock_conn
resp = client.put("/api/schedules/nonexistent", json={"name": "x"})
assert resp.status_code == 404
def test_update_empty_body_returns_422(self):
resp = client.put("/api/schedules/sched-1", json={})
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# DELETE /api/schedules/{id}
# ---------------------------------------------------------------------------
class TestDeleteSchedule:
@patch("app.routers.schedules.get_connection")
def test_delete_success(self, mock_get_conn):
mock_cursor = MagicMock()
mock_cursor.rowcount = 1
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.delete("/api/schedules/sched-1")
assert resp.status_code == 200
assert "已删除" in resp.json()["message"]
@patch("app.routers.schedules.get_connection")
def test_delete_not_found(self, mock_get_conn):
mock_cursor = MagicMock()
mock_cursor.rowcount = 0
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.delete("/api/schedules/nonexistent")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# PATCH /api/schedules/{id}/toggle
# ---------------------------------------------------------------------------
class TestToggleSchedule:
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
@patch("app.routers.schedules.get_connection")
def test_toggle_disable(self, mock_get_conn, mock_calc):
"""启用 → 禁用next_run_at 应置 NULL"""
# 第一次 fetchone 返回当前状态enabled=True
# 第二次 fetchone 返回更新后的行
disabled_row = list(_DB_ROW)
disabled_row[6] = False # enabled
disabled_row[8] = None # next_run_at
mock_cursor = MagicMock()
mock_cursor.fetchone.side_effect = [
(True, json.dumps(_SCHEDULE_CONFIG)), # SELECT 当前状态
tuple(disabled_row), # UPDATE RETURNING
]
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.patch("/api/schedules/sched-1/toggle")
assert resp.status_code == 200
data = resp.json()
assert data["enabled"] is False
assert data["next_run_at"] is None
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
@patch("app.routers.schedules.get_connection")
def test_toggle_enable(self, mock_get_conn, mock_calc):
"""禁用 → 启用next_run_at 应被重新计算"""
enabled_row = list(_DB_ROW)
enabled_row[6] = True
enabled_row[8] = _NEXT
mock_cursor = MagicMock()
mock_cursor.fetchone.side_effect = [
(False, json.dumps(_SCHEDULE_CONFIG)), # SELECT 当前状态disabled
tuple(enabled_row), # UPDATE RETURNING
]
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.patch("/api/schedules/sched-1/toggle")
assert resp.status_code == 200
data = resp.json()
assert data["enabled"] is True
assert data["next_run_at"] is not None
mock_calc.assert_called_once()
@patch("app.routers.schedules.get_connection")
def test_toggle_not_found(self, mock_get_conn):
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = None
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = mock_conn
resp = client.patch("/api/schedules/nonexistent/toggle")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# 认证测试
# ---------------------------------------------------------------------------
class TestSchedulesAuth:
def test_requires_auth(self):
"""移除认证覆盖后,所有端点应返回 401/403"""
app.dependency_overrides.pop(get_current_user, None)
try:
assert client.get("/api/schedules").status_code in (401, 403)
assert client.post("/api/schedules", json=_VALID_CREATE).status_code in (401, 403)
assert client.put("/api/schedules/x", json={"name": "x"}).status_code in (401, 403)
assert client.delete("/api/schedules/x").status_code in (401, 403)
assert client.patch("/api/schedules/x/toggle").status_code in (401, 403)
finally:
app.dependency_overrides[get_current_user] = _override_auth

View File

@@ -0,0 +1,336 @@
# -*- coding: utf-8 -*-
"""门店隔离属性测试Property-Based Testing
Property 20: 对于任意两个不同 site_id 的 Operator一个 Operator 查询
队列/调度/执行历史时,结果中不应包含另一个 site_id 的数据。
Validates: Requirements 1.3
测试策略:
- 通过 mock 数据库交互,验证 API 路由在不同 site_id 下的数据隔离
- 队列隔离:为 site_id_a 入队任务,用 site_id_b 的 JWT 查询队列,结果应为空
- 调度隔离:为 site_id_a 创建调度任务,用 site_id_b 的 JWT 查询调度列表,结果应为空
- 执行历史隔离site_id_a 的执行历史,用 site_id_b 的 JWT 查询不到
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-isolation")
import json
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
# ---------------------------------------------------------------------------
# 通用策略Strategies
# ---------------------------------------------------------------------------
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _make_mock_user(site_id: int) -> CurrentUser:
"""构造指定 site_id 的 mock 用户。"""
return CurrentUser(user_id=1, site_id=site_id)
def _make_queue_rows(site_id: int, count: int) -> list[tuple]:
"""生成 count 条属于 site_id 的队列行。"""
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}), # config
"pending", # status
i + 1, # position
datetime(2024, 1, 1, tzinfo=timezone.utc), # created_at
None, # started_at
None, # finished_at
None, # exit_code
None, # error_message
))
return rows
def _make_schedule_rows(site_id: int, count: int) -> list[tuple]:
"""生成 count 条属于 site_id 的调度行。"""
now = datetime.now(timezone.utc)
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
f"调度任务_{i}", # name
["ODS_MEMBER"], # task_codes
{"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}, # task_config
{"schedule_type": "daily", "daily_time": "04:00", # schedule_config
"interval_value": 1, "interval_unit": "hours",
"weekly_days": [1], "weekly_time": "04:00",
"cron_expression": "0 4 * * *", "enabled": True,
"start_date": None, "end_date": None},
True, # enabled
None, # last_run_at
now + timedelta(hours=1), # next_run_at
0, # run_count
None, # last_status
now, # created_at
now, # updated_at
))
return rows
def _make_history_rows(site_id: int, count: int) -> list[tuple]:
"""生成 count 条属于 site_id 的执行历史行。"""
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
rows = []
for i in range(count):
rows.append((
str(uuid.uuid4()), # id
site_id, # site_id
["ODS_MEMBER"], # task_codes
"success", # status
base_time + timedelta(hours=i), # started_at
base_time + timedelta(hours=i, minutes=30), # finished_at
0, # exit_code
1800000, # duration_ms
"python -m cli.main", # command
None, # summary
))
return rows
def _mock_conn_returning(rows: list[tuple]) -> MagicMock:
"""构造一个 mock connection其 cursor.fetchall 返回指定行。"""
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = rows
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
# ---------------------------------------------------------------------------
# Property 20.1: 队列隔离
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100, deadline=None)
@given(
site_id_a=_site_id_st,
site_id_b=_site_id_st,
queue_count=st.integers(min_value=1, max_value=5),
)
@patch("app.services.task_queue.get_connection")
def test_queue_isolation(mock_get_conn, site_id_a, site_id_b, queue_count):
"""Property 20.1: 队列隔离。
为 site_id_a 入队若干任务后,用 site_id_b 的身份查询队列,
结果应为空——不同门店的队列数据互不可见。
"""
assume(site_id_a != site_id_b)
# site_id_a 的队列数据
rows_a = _make_queue_rows(site_id_a, queue_count)
# 核心隔离逻辑:根据查询时传入的 site_id 过滤
# list_pending 内部 SQL: WHERE site_id = %s AND status = 'pending'
def conn_for_site(querying_site_id):
"""模拟数据库行为:只返回匹配 site_id 的行。"""
if querying_site_id == site_id_a:
return rows_a
return [] # site_id_b 查不到 site_id_a 的数据
captured_params = {}
def make_mock_conn():
mock_cursor = MagicMock()
def execute_side_effect(sql, params=None):
if params:
captured_params["site_id"] = params[0]
# 根据 SQL 中的 site_id 参数返回对应数据
mock_cursor.fetchall.return_value = conn_for_site(params[0])
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
mock_conn = MagicMock()
mock_conn.cursor.return_value = mock_cursor
return mock_conn
mock_get_conn.return_value = make_mock_conn()
# 用 site_id_b 的身份查询队列
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
try:
client = TestClient(app)
resp = client.get("/api/execution/queue")
assert resp.status_code == 200
data = resp.json()
# 验证site_id_b 查不到 site_id_a 的任何数据
assert len(data) == 0, (
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的队列数据,"
f"但返回了 {len(data)} 条记录"
)
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
for item in data:
assert item.get("site_id") != site_id_a, (
f"结果中不应包含 site_id_a={site_id_a} 的数据"
)
finally:
app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Property 20.2: 调度隔离
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
site_id_a=_site_id_st,
site_id_b=_site_id_st,
schedule_count=st.integers(min_value=1, max_value=5),
)
@patch("app.routers.schedules.get_connection")
def test_schedule_isolation(mock_get_conn, site_id_a, site_id_b, schedule_count):
"""Property 20.2: 调度隔离。
为 site_id_a 创建若干调度任务后,用 site_id_b 的身份查询调度列表,
结果应为空——不同门店的调度数据互不可见。
"""
assume(site_id_a != site_id_b)
# site_id_a 的调度数据
rows_a = _make_schedule_rows(site_id_a, schedule_count)
def make_mock_conn():
mock_cursor = MagicMock()
def execute_side_effect(sql, params=None):
if params:
querying_site_id = params[0]
# 只返回匹配 site_id 的行
if querying_site_id == site_id_a:
mock_cursor.fetchall.return_value = rows_a
else:
mock_cursor.fetchall.return_value = []
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
mock_conn = MagicMock()
mock_conn.cursor.return_value = mock_cursor
return mock_conn
mock_get_conn.return_value = make_mock_conn()
# 用 site_id_b 的身份查询调度列表
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
try:
client = TestClient(app)
resp = client.get("/api/schedules")
assert resp.status_code == 200
data = resp.json()
# 验证site_id_b 查不到 site_id_a 的任何调度数据
assert len(data) == 0, (
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的调度数据,"
f"但返回了 {len(data)} 条记录"
)
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
for item in data:
assert item.get("site_id") != site_id_a, (
f"结果中不应包含 site_id_a={site_id_a} 的调度数据"
)
finally:
app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Property 20.3: 执行历史隔离
# **Validates: Requirements 1.3**
# ---------------------------------------------------------------------------
@settings(max_examples=100, deadline=None)
@given(
site_id_a=_site_id_st,
site_id_b=_site_id_st,
history_count=st.integers(min_value=1, max_value=10),
)
@patch("app.routers.execution.get_connection")
def test_execution_history_isolation(mock_get_conn, site_id_a, site_id_b, history_count):
"""Property 20.3: 执行历史隔离。
site_id_a 有若干执行历史记录,用 site_id_b 的身份查询执行历史,
结果应为空——不同门店的执行历史互不可见。
"""
assume(site_id_a != site_id_b)
# site_id_a 的执行历史数据
rows_a = _make_history_rows(site_id_a, history_count)
def make_mock_conn():
mock_cursor = MagicMock()
def execute_side_effect(sql, params=None):
if params:
querying_site_id = params[0]
# 只返回匹配 site_id 的行
if querying_site_id == site_id_a:
mock_cursor.fetchall.return_value = rows_a
else:
mock_cursor.fetchall.return_value = []
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
mock_conn = MagicMock()
mock_conn.cursor.return_value = mock_cursor
return mock_conn
mock_get_conn.return_value = make_mock_conn()
# 用 site_id_b 的身份查询执行历史
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
try:
client = TestClient(app)
resp = client.get("/api/execution/history")
assert resp.status_code == 200
data = resp.json()
# 验证site_id_b 查不到 site_id_a 的任何执行历史
assert len(data) == 0, (
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的执行历史,"
f"但返回了 {len(data)} 条记录"
)
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
for item in data:
assert item.get("site_id") != site_id_a, (
f"结果中不应包含 site_id_a={site_id_a} 的执行历史"
)
finally:
app.dependency_overrides.clear()

View File

@@ -0,0 +1,275 @@
# -*- coding: utf-8 -*-
"""TaskConfig 属性测试Property-Based Testing
使用 hypothesis 验证 TaskConfig 相关的通用正确性属性:
- Property 1: TaskConfig 序列化往返一致性
- Property 6: 时间窗口验证
- Property 7: TaskConfig 到 CLI 命令转换完整性
"""
import datetime
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from pydantic import ValidationError
from app.schemas.tasks import TaskConfigSchema
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
from app.services.task_registry import ALL_TASKS
# ---------------------------------------------------------------------------
# 策略Strategies
# ---------------------------------------------------------------------------
# 从真实任务注册表中采样任务代码
_task_codes = [t.code for t in ALL_TASKS]
_tasks_st = st.lists(
st.sampled_from(_task_codes),
min_size=1,
max_size=5,
unique=True,
)
_pipeline_st = st.sampled_from(sorted(VALID_FLOWS))
_processing_mode_st = st.sampled_from(sorted(VALID_PROCESSING_MODES))
_window_mode_st = st.sampled_from(["lookback", "custom"])
# 日期策略:生成 YYYY-MM-DD 格式字符串
_date_st = st.dates(
min_value=datetime.date(2020, 1, 1),
max_value=datetime.date(2030, 12, 31),
).map(lambda d: d.isoformat())
_window_split_st = st.sampled_from([None, "none", "day"])
_window_split_days_st = st.one_of(st.none(), st.sampled_from([1, 10, 30]))
_lookback_hours_st = st.integers(min_value=1, max_value=720)
_overlap_seconds_st = st.integers(min_value=0, max_value=7200)
_store_id_st = st.one_of(st.none(), st.integers(min_value=1, max_value=2**31 - 1))
# DWD 表名采样
_dwd_table_names = [
"dwd.dim_site",
"dwd.dim_member",
"dwd.dwd_settlement_head",
]
_dwd_only_tables_st = st.one_of(
st.none(),
st.lists(st.sampled_from(_dwd_table_names), min_size=1, max_size=3, unique=True),
)
def _valid_task_config_st():
"""生成有效的 TaskConfigSchema 的复合策略。
确保 window_mode=custom 时 window_end >= window_start
避免触发 Pydantic 验证错误。
"""
@st.composite
def _build(draw):
tasks = draw(_tasks_st)
pipeline = draw(_pipeline_st)
processing_mode = draw(_processing_mode_st)
dry_run = draw(st.booleans())
window_mode = draw(_window_mode_st)
store_id = draw(_store_id_st)
dwd_only_tables = draw(_dwd_only_tables_st)
window_split = draw(_window_split_st)
window_split_days = draw(_window_split_days_st)
fetch_before_verify = draw(st.booleans())
skip_ods = draw(st.booleans())
ods_local = draw(st.booleans())
if window_mode == "custom":
d1 = draw(st.dates(
min_value=datetime.date(2020, 1, 1),
max_value=datetime.date(2030, 12, 31),
))
d2 = draw(st.dates(
min_value=datetime.date(2020, 1, 1),
max_value=datetime.date(2030, 12, 31),
))
# 保证 end >= start
window_start = min(d1, d2).isoformat()
window_end = max(d1, d2).isoformat()
lookback_hours = 24
overlap_seconds = 600
else:
window_start = None
window_end = None
lookback_hours = draw(_lookback_hours_st)
overlap_seconds = draw(_overlap_seconds_st)
return TaskConfigSchema(
tasks=tasks,
pipeline=pipeline,
processing_mode=processing_mode,
dry_run=dry_run,
window_mode=window_mode,
window_start=window_start,
window_end=window_end,
window_split=window_split,
window_split_days=window_split_days,
lookback_hours=lookback_hours,
overlap_seconds=overlap_seconds,
fetch_before_verify=fetch_before_verify,
skip_ods_when_fetch_before_verify=skip_ods,
ods_use_local_json=ods_local,
store_id=store_id,
dwd_only_tables=dwd_only_tables,
extra_args={},
)
return _build()
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 1: TaskConfig 序列化往返一致性
# **Validates: Requirements 11.1, 11.2, 11.3**
# ---------------------------------------------------------------------------
@settings(max_examples=200)
@given(config=_valid_task_config_st())
def test_task_config_round_trip(config: TaskConfigSchema):
"""Property 1: 序列化为 JSON 后再反序列化,应产生与原始对象等价的结果。"""
json_str = config.model_dump_json()
restored = TaskConfigSchema.model_validate_json(json_str)
assert restored == config, (
f"往返不一致:\n原始={config}\n还原={restored}"
)
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 6: 时间窗口验证
# **Validates: Requirements 2.3**
# ---------------------------------------------------------------------------
@settings(max_examples=200)
@given(
d1=st.dates(
min_value=datetime.date(2020, 1, 1),
max_value=datetime.date(2030, 12, 31),
),
d2=st.dates(
min_value=datetime.date(2020, 1, 1),
max_value=datetime.date(2030, 12, 31),
),
)
def test_time_window_validation(d1: datetime.date, d2: datetime.date):
"""Property 6: window_end < window_start 时验证应失败,否则应通过。"""
start_str = d1.isoformat()
end_str = d2.isoformat()
if end_str < start_str:
# window_end 早于 window_start → 验证应失败
try:
TaskConfigSchema(
tasks=["ODS_MEMBER"],
window_mode="custom",
window_start=start_str,
window_end=end_str,
)
raise AssertionError(
f"期望 ValidationError但验证通过了start={start_str}, end={end_str}"
)
except ValidationError:
pass # 预期行为
else:
# window_end >= window_start → 验证应通过
config = TaskConfigSchema(
tasks=["ODS_MEMBER"],
window_mode="custom",
window_start=start_str,
window_end=end_str,
)
assert config.window_start == start_str
assert config.window_end == end_str
# ---------------------------------------------------------------------------
# Feature: admin-web-console, Property 7: TaskConfig 到 CLI 命令转换完整性
# **Validates: Requirements 2.5, 2.6**
# ---------------------------------------------------------------------------
_builder = CLIBuilder()
_ETL_PATH = "/fake/etl/project"
@settings(max_examples=200)
@given(config=_valid_task_config_st())
def test_task_config_to_cli_completeness(config: TaskConfigSchema):
"""Property 7: CLIBuilder 生成的命令应包含 TaskConfig 中所有非空字段对应的 CLI 参数。"""
cmd = _builder.build_command(config, _ETL_PATH)
# 1) --pipeline 始终存在且值正确
assert "--pipeline" in cmd
idx = cmd.index("--pipeline")
assert cmd[idx + 1] == config.pipeline
# 2) --processing-mode 始终存在且值正确
assert "--processing-mode" in cmd
idx = cmd.index("--processing-mode")
assert cmd[idx + 1] == config.processing_mode
# 3) 非空任务列表 → --tasks 存在
if config.tasks:
assert "--tasks" in cmd
idx = cmd.index("--tasks")
assert set(cmd[idx + 1].split(",")) == set(config.tasks)
# 4) 时间窗口参数
if config.window_mode == "lookback":
# lookback 模式 → --lookback-hours 和 --overlap-seconds
if config.lookback_hours is not None:
assert "--lookback-hours" in cmd
idx = cmd.index("--lookback-hours")
assert cmd[idx + 1] == str(config.lookback_hours)
if config.overlap_seconds is not None:
assert "--overlap-seconds" in cmd
idx = cmd.index("--overlap-seconds")
assert cmd[idx + 1] == str(config.overlap_seconds)
# lookback 模式不应出现 custom 参数
assert "--window-start" not in cmd
assert "--window-end" not in cmd
else:
# custom 模式 → --window-start / --window-end
if config.window_start:
assert "--window-start" in cmd
if config.window_end:
assert "--window-end" in cmd
# custom 模式不应出现 lookback 参数
assert "--lookback-hours" not in cmd
assert "--overlap-seconds" not in cmd
# 5) dry_run → --dry-run
if config.dry_run:
assert "--dry-run" in cmd
else:
assert "--dry-run" not in cmd
# 6) store_id → --store-id
if config.store_id is not None:
assert "--store-id" in cmd
idx = cmd.index("--store-id")
assert cmd[idx + 1] == str(config.store_id)
else:
assert "--store-id" not in cmd
# 7) fetch_before_verify → 仅 verify_only 模式下生成
if config.fetch_before_verify and config.processing_mode == "verify_only":
assert "--fetch-before-verify" in cmd
else:
assert "--fetch-before-verify" not in cmd
# 8) window_split非 None 且非 "none")→ --window-split
if config.window_split and config.window_split != "none":
assert "--window-split" in cmd
idx = cmd.index("--window-split")
assert cmd[idx + 1] == config.window_split
if config.window_split_days is not None:
assert "--window-split-days" in cmd
else:
assert "--window-split" not in cmd

View File

@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
"""TaskExecutor 单元测试
覆盖子进程启动、stdout/stderr 读取、日志广播、取消、数据库记录。
使用 asyncio 测试mock 子进程和数据库连接避免外部依赖。
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.task_executor import TaskExecutor
@pytest.fixture
def executor() -> TaskExecutor:
return TaskExecutor()
@pytest.fixture
def sample_config() -> TaskConfigSchema:
return TaskConfigSchema(
tasks=["ODS_MEMBER", "ODS_PAYMENT"],
pipeline="api_ods_dwd",
store_id=42,
)
def _make_stream(lines: list[bytes]) -> AsyncMock:
"""构造一个模拟的 asyncio.StreamReader按行返回数据。"""
stream = AsyncMock()
# readline 依次返回每行,最后返回 b"" 表示 EOF
stream.readline = AsyncMock(side_effect=[*lines, b""])
return stream
# ---------------------------------------------------------------------------
# 订阅 / 取消订阅
# ---------------------------------------------------------------------------
class TestSubscription:
def test_subscribe_returns_queue(self, executor: TaskExecutor):
q = executor.subscribe("exec-1")
assert isinstance(q, asyncio.Queue)
def test_subscribe_multiple(self, executor: TaskExecutor):
q1 = executor.subscribe("exec-1")
q2 = executor.subscribe("exec-1")
assert q1 is not q2
assert len(executor._subscribers["exec-1"]) == 2
def test_unsubscribe_removes_queue(self, executor: TaskExecutor):
q = executor.subscribe("exec-1")
executor.unsubscribe("exec-1", q)
# 最后一个订阅者移除后,键也被清理
assert "exec-1" not in executor._subscribers
def test_unsubscribe_nonexistent_is_safe(self, executor: TaskExecutor):
"""对不存在的 execution_id 取消订阅不应报错"""
q: asyncio.Queue = asyncio.Queue()
executor.unsubscribe("nonexistent", q)
# ---------------------------------------------------------------------------
# 广播
# ---------------------------------------------------------------------------
class TestBroadcast:
def test_broadcast_to_subscribers(self, executor: TaskExecutor):
q1 = executor.subscribe("exec-1")
q2 = executor.subscribe("exec-1")
executor._broadcast("exec-1", "hello")
assert q1.get_nowait() == "hello"
assert q2.get_nowait() == "hello"
def test_broadcast_no_subscribers_is_safe(self, executor: TaskExecutor):
"""无订阅者时广播不应报错"""
executor._broadcast("nonexistent", "hello")
def test_broadcast_end_sends_none(self, executor: TaskExecutor):
q = executor.subscribe("exec-1")
executor._broadcast_end("exec-1")
assert q.get_nowait() is None
# ---------------------------------------------------------------------------
# 日志缓冲区
# ---------------------------------------------------------------------------
class TestLogBuffer:
def test_get_logs_empty(self, executor: TaskExecutor):
assert executor.get_logs("nonexistent") == []
def test_get_logs_returns_copy(self, executor: TaskExecutor):
executor._log_buffers["exec-1"] = ["line1", "line2"]
logs = executor.get_logs("exec-1")
assert logs == ["line1", "line2"]
# 修改副本不影响原始
logs.append("line3")
assert len(executor._log_buffers["exec-1"]) == 2
# ---------------------------------------------------------------------------
# 执行状态查询
# ---------------------------------------------------------------------------
class TestRunningState:
def test_is_running_false_when_no_process(self, executor: TaskExecutor):
assert executor.is_running("nonexistent") is False
def test_is_running_true_when_process_active(self, executor: TaskExecutor):
proc = MagicMock()
proc.returncode = None
executor._processes["exec-1"] = proc
assert executor.is_running("exec-1") is True
def test_is_running_false_when_process_exited(self, executor: TaskExecutor):
proc = MagicMock()
proc.returncode = 0
executor._processes["exec-1"] = proc
assert executor.is_running("exec-1") is False
def test_get_running_ids(self, executor: TaskExecutor):
running = MagicMock()
running.returncode = None
exited = MagicMock()
exited.returncode = 0
executor._processes["a"] = running
executor._processes["b"] = exited
assert executor.get_running_ids() == ["a"]
# ---------------------------------------------------------------------------
# _read_stream
# ---------------------------------------------------------------------------
class TestReadStream:
@pytest.mark.asyncio
async def test_read_stdout_lines(self, executor: TaskExecutor):
executor._log_buffers["exec-1"] = []
stream = _make_stream([b"line1\n", b"line2\n"])
collector: list[str] = []
await executor._read_stream("exec-1", stream, "stdout", collector)
assert collector == ["line1", "line2"]
assert executor._log_buffers["exec-1"] == [
"[stdout] line1",
"[stdout] line2",
]
@pytest.mark.asyncio
async def test_read_stderr_lines(self, executor: TaskExecutor):
executor._log_buffers["exec-1"] = []
stream = _make_stream([b"err1\n"])
collector: list[str] = []
await executor._read_stream("exec-1", stream, "stderr", collector)
assert collector == ["err1"]
assert executor._log_buffers["exec-1"] == ["[stderr] err1"]
@pytest.mark.asyncio
async def test_read_stream_none_is_safe(self, executor: TaskExecutor):
"""stream 为 None 时不应报错"""
collector: list[str] = []
await executor._read_stream("exec-1", None, "stdout", collector)
assert collector == []
@pytest.mark.asyncio
async def test_broadcast_during_read(self, executor: TaskExecutor):
executor._log_buffers["exec-1"] = []
q = executor.subscribe("exec-1")
stream = _make_stream([b"hello\n"])
collector: list[str] = []
await executor._read_stream("exec-1", stream, "stdout", collector)
assert q.get_nowait() == "[stdout] hello"
# ---------------------------------------------------------------------------
# execute集成级mock 子进程和数据库)
# ---------------------------------------------------------------------------
class TestExecute:
@pytest.mark.asyncio
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
@patch("asyncio.create_subprocess_exec")
async def test_successful_execution(
self, mock_create, mock_write_log, mock_update_log,
executor: TaskExecutor, sample_config: TaskConfigSchema,
):
# 模拟子进程
proc = AsyncMock()
proc.returncode = None
proc.stdout = _make_stream([b"processing...\n", b"done\n"])
proc.stderr = _make_stream([])
proc.wait = AsyncMock(return_value=0)
# wait 调用后设置 returncode
async def _wait():
proc.returncode = 0
return 0
proc.wait = _wait
mock_create.return_value = proc
await executor.execute(sample_config, "exec-1", site_id=42)
# 验证写入了 running 状态
mock_write_log.assert_called_once()
call_kwargs = mock_write_log.call_args[1]
assert call_kwargs["status"] == "running"
assert call_kwargs["execution_id"] == "exec-1"
# 验证更新了 success 状态
mock_update_log.assert_called_once()
update_kwargs = mock_update_log.call_args[1]
assert update_kwargs["status"] == "success"
assert update_kwargs["exit_code"] == 0
assert "processing..." in update_kwargs["output_log"]
assert "done" in update_kwargs["output_log"]
# 进程已从跟踪表移除
assert "exec-1" not in executor._processes
@pytest.mark.asyncio
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
@patch("asyncio.create_subprocess_exec")
async def test_failed_execution(
self, mock_create, mock_write_log, mock_update_log,
executor: TaskExecutor, sample_config: TaskConfigSchema,
):
proc = AsyncMock()
proc.returncode = None
proc.stdout = _make_stream([])
proc.stderr = _make_stream([b"error occurred\n"])
async def _wait():
proc.returncode = 1
return 1
proc.wait = _wait
mock_create.return_value = proc
await executor.execute(sample_config, "exec-2", site_id=42)
update_kwargs = mock_update_log.call_args[1]
assert update_kwargs["status"] == "failed"
assert update_kwargs["exit_code"] == 1
assert "error occurred" in update_kwargs["error_log"]
@pytest.mark.asyncio
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
@patch("asyncio.create_subprocess_exec")
async def test_exception_during_execution(
self, mock_create, mock_write_log, mock_update_log,
executor: TaskExecutor, sample_config: TaskConfigSchema,
):
"""子进程创建失败时应记录 failed 状态"""
mock_create.side_effect = OSError("command not found")
await executor.execute(sample_config, "exec-3", site_id=42)
update_kwargs = mock_update_log.call_args[1]
assert update_kwargs["status"] == "failed"
@pytest.mark.asyncio
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
@patch("asyncio.create_subprocess_exec")
async def test_subscribers_notified_on_completion(
self, mock_create, mock_write_log, mock_update_log,
executor: TaskExecutor, sample_config: TaskConfigSchema,
):
proc = AsyncMock()
proc.returncode = None
proc.stdout = _make_stream([b"line\n"])
proc.stderr = _make_stream([])
async def _wait():
proc.returncode = 0
return 0
proc.wait = _wait
mock_create.return_value = proc
q = executor.subscribe("exec-4")
await executor.execute(sample_config, "exec-4", site_id=42)
# 应收到日志行 + None 哨兵
messages = []
while not q.empty():
messages.append(q.get_nowait())
assert "[stdout] line" in messages
assert None in messages
@pytest.mark.asyncio
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
@patch("asyncio.create_subprocess_exec")
async def test_duration_ms_recorded(
self, mock_create, mock_write_log, mock_update_log,
executor: TaskExecutor, sample_config: TaskConfigSchema,
):
proc = AsyncMock()
proc.returncode = None
proc.stdout = _make_stream([])
proc.stderr = _make_stream([])
async def _wait():
proc.returncode = 0
return 0
proc.wait = _wait
mock_create.return_value = proc
await executor.execute(sample_config, "exec-5", site_id=42)
update_kwargs = mock_update_log.call_args[1]
assert isinstance(update_kwargs["duration_ms"], int)
assert update_kwargs["duration_ms"] >= 0
# ---------------------------------------------------------------------------
# cancel
# ---------------------------------------------------------------------------
class TestCancel:
@pytest.mark.asyncio
async def test_cancel_running_process(self, executor: TaskExecutor):
proc = MagicMock()
proc.returncode = None
proc.terminate = MagicMock()
executor._processes["exec-1"] = proc
result = await executor.cancel("exec-1")
assert result is True
proc.terminate.assert_called_once()
@pytest.mark.asyncio
async def test_cancel_nonexistent_returns_false(self, executor: TaskExecutor):
result = await executor.cancel("nonexistent")
assert result is False
@pytest.mark.asyncio
async def test_cancel_already_exited_returns_false(self, executor: TaskExecutor):
proc = MagicMock()
proc.returncode = 0
executor._processes["exec-1"] = proc
result = await executor.cancel("exec-1")
assert result is False
@pytest.mark.asyncio
async def test_cancel_process_lookup_error(self, executor: TaskExecutor):
"""进程已消失时 terminate 抛出 ProcessLookupError"""
proc = MagicMock()
proc.returncode = None
proc.terminate = MagicMock(side_effect=ProcessLookupError)
executor._processes["exec-1"] = proc
result = await executor.cancel("exec-1")
assert result is False
# ---------------------------------------------------------------------------
# cleanup
# ---------------------------------------------------------------------------
class TestCleanup:
def test_cleanup_removes_buffers_and_subscribers(self, executor: TaskExecutor):
executor._log_buffers["exec-1"] = ["line"]
executor.subscribe("exec-1")
executor.cleanup("exec-1")
assert "exec-1" not in executor._log_buffers
assert "exec-1" not in executor._subscribers
def test_cleanup_nonexistent_is_safe(self, executor: TaskExecutor):
executor.cleanup("nonexistent")

View File

@@ -0,0 +1,482 @@
# -*- coding: utf-8 -*-
"""TaskQueue 单元测试
覆盖enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
使用 mock 数据库操作,专注于业务逻辑验证。
"""
import asyncio
import json
import uuid
from unittest.mock import MagicMock, AsyncMock, patch, call
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.task_queue import TaskQueue, QueuedTask
@pytest.fixture
def queue() -> TaskQueue:
return TaskQueue()
@pytest.fixture
def sample_config() -> TaskConfigSchema:
return TaskConfigSchema(
tasks=["ODS_MEMBER", "ODS_PAYMENT"],
pipeline="api_ods_dwd",
store_id=42,
)
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
"""构造 mock cursor支持 context manager 协议。"""
cur = MagicMock()
cur.fetchone.return_value = fetchone_val
cur.fetchall.return_value = fetchall_val or []
cur.rowcount = rowcount
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
def _mock_conn(cursor):
"""构造 mock connection支持 cursor() context manager。"""
conn = MagicMock()
conn.cursor.return_value = cursor
return conn
# ---------------------------------------------------------------------------
# enqueue
# ---------------------------------------------------------------------------
class TestEnqueue:
@patch("app.services.task_queue.get_connection")
def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
cur = _mock_cursor(fetchone_val=(0,))
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
task_id = queue.enqueue(sample_config, site_id=42)
# 返回有效 UUID
uuid.UUID(task_id)
conn.commit.assert_called_once()
conn.close.assert_called_once()
@patch("app.services.task_queue.get_connection")
def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
"""新任务 position = 当前最大 position + 1"""
cur = _mock_cursor(fetchone_val=(5,))
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue.enqueue(sample_config, site_id=42)
# 检查 INSERT 调用中的 position 参数
insert_call = cur.execute.call_args_list[1]
args = insert_call[0][1]
# args = (task_id, site_id, config_json, new_pos)
assert args[3] == 6 # 5 + 1
@patch("app.services.task_queue.get_connection")
def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
"""空队列时 position = 1"""
cur = _mock_cursor(fetchone_val=(0,))
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue.enqueue(sample_config, site_id=42)
insert_call = cur.execute.call_args_list[1]
args = insert_call[0][1]
assert args[3] == 1
@patch("app.services.task_queue.get_connection")
def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
"""config 被序列化为 JSON 字符串"""
cur = _mock_cursor(fetchone_val=(0,))
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue.enqueue(sample_config, site_id=42)
insert_call = cur.execute.call_args_list[1]
config_json_str = insert_call[0][1][2]
parsed = json.loads(config_json_str)
assert parsed["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
assert parsed["pipeline"] == "api_ods_dwd"
# ---------------------------------------------------------------------------
# dequeue
# ---------------------------------------------------------------------------
class TestDequeue:
@patch("app.services.task_queue.get_connection")
def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
cur = _mock_cursor(fetchone_val=None)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
result = queue.dequeue(site_id=42)
assert result is None
conn.commit.assert_called()
@patch("app.services.task_queue.get_connection")
def test_dequeue_returns_task(self, mock_get_conn, queue):
task_id = str(uuid.uuid4())
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
row = (
task_id, 42, json.dumps(config_dict), "pending", 1,
None, None, None, None, None,
)
cur = _mock_cursor(fetchone_val=row)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
result = queue.dequeue(site_id=42)
assert result is not None
assert result.id == task_id
assert result.site_id == 42
assert result.status == "running" # dequeue 后状态变为 running
assert result.config["tasks"] == ["ODS_MEMBER"]
@patch("app.services.task_queue.get_connection")
def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
task_id = str(uuid.uuid4())
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
row = (
task_id, 42, json.dumps(config_dict), "pending", 1,
None, None, None, None, None,
)
cur = _mock_cursor(fetchone_val=row)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue.dequeue(site_id=42)
# 第二次 execute 调用应该是 UPDATE status = 'running'
update_call = cur.execute.call_args_list[1]
sql = update_call[0][0]
assert "running" in sql
assert task_id in update_call[0][1]
# ---------------------------------------------------------------------------
# reorder
# ---------------------------------------------------------------------------
class TestReorder:
@patch("app.services.task_queue.get_connection")
def test_reorder_moves_task(self, mock_get_conn, queue):
"""将第 3 个任务移到第 1 位"""
ids = [str(uuid.uuid4()) for _ in range(3)]
rows = [(i,) for i in ids]
cur = _mock_cursor(fetchall_val=rows)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue.reorder(ids[2], new_position=1, site_id=42)
# 重排后顺序应为 [ids[2], ids[0], ids[1]]
update_calls = cur.execute.call_args_list[1:] # 跳过 SELECT
positions = {}
for c in update_calls:
pos, tid = c[0][1]
positions[tid] = pos
assert positions[ids[2]] == 1
assert positions[ids[0]] == 2
assert positions[ids[1]] == 3
@patch("app.services.task_queue.get_connection")
def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
"""重排不存在的任务不报错"""
rows = [(str(uuid.uuid4()),)]
cur = _mock_cursor(fetchall_val=rows)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue.reorder("nonexistent-id", new_position=1, site_id=42)
# 只有 SELECT没有 UPDATE
assert cur.execute.call_count == 1
@patch("app.services.task_queue.get_connection")
def test_reorder_clamps_position(self, mock_get_conn, queue):
"""position 超出范围时 clamp 到有效范围"""
ids = [str(uuid.uuid4()) for _ in range(2)]
rows = [(i,) for i in ids]
cur = _mock_cursor(fetchall_val=rows)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
# new_position=100 超出范围,应 clamp 到末尾
queue.reorder(ids[0], new_position=100, site_id=42)
update_calls = cur.execute.call_args_list[1:]
positions = {}
for c in update_calls:
pos, tid = c[0][1]
positions[tid] = pos
# ids[0] 移到末尾
assert positions[ids[1]] == 1
assert positions[ids[0]] == 2
# ---------------------------------------------------------------------------
# delete
# ---------------------------------------------------------------------------
class TestDelete:
@patch("app.services.task_queue.get_connection")
def test_delete_pending_task(self, mock_get_conn, queue):
cur = _mock_cursor(rowcount=1)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
result = queue.delete("task-1", site_id=42)
assert result is True
conn.commit.assert_called_once()
@patch("app.services.task_queue.get_connection")
def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
cur = _mock_cursor(rowcount=0)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
result = queue.delete("nonexistent", site_id=42)
assert result is False
@patch("app.services.task_queue.get_connection")
def test_delete_only_affects_pending(self, mock_get_conn, queue):
"""DELETE SQL 包含 status = 'pending' 条件"""
cur = _mock_cursor(rowcount=0)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue.delete("task-1", site_id=42)
sql = cur.execute.call_args[0][0]
assert "pending" in sql
# ---------------------------------------------------------------------------
# list_pending / has_running
# ---------------------------------------------------------------------------
class TestQuery:
@patch("app.services.task_queue.get_connection")
def test_list_pending_empty(self, mock_get_conn, queue):
cur = _mock_cursor(fetchall_val=[])
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
result = queue.list_pending(site_id=42)
assert result == []
@patch("app.services.task_queue.get_connection")
def test_list_pending_returns_tasks(self, mock_get_conn, queue):
tid = str(uuid.uuid4())
config = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
rows = [(tid, 42, config, "pending", 1, None, None, None, None, None)]
cur = _mock_cursor(fetchall_val=rows)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
result = queue.list_pending(site_id=42)
assert len(result) == 1
assert result[0].id == tid
@patch("app.services.task_queue.get_connection")
def test_has_running_true(self, mock_get_conn, queue):
cur = _mock_cursor(fetchone_val=(True,))
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
assert queue.has_running(site_id=42) is True
@patch("app.services.task_queue.get_connection")
def test_has_running_false(self, mock_get_conn, queue):
cur = _mock_cursor(fetchone_val=(False,))
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
assert queue.has_running(site_id=42) is False
# ---------------------------------------------------------------------------
# process_loop / _process_once
# ---------------------------------------------------------------------------
class TestProcessLoop:
@patch("app.services.task_queue.get_connection")
@pytest.mark.asyncio
async def test_process_once_skips_when_running(self, mock_get_conn, queue):
"""有 running 任务时不 dequeue"""
# _get_pending_site_ids 返回 [42]
# has_running(42) 返回 True
call_count = 0
def side_effect_conn():
nonlocal call_count
call_count += 1
if call_count == 1:
# _get_pending_site_ids
cur = _mock_cursor(fetchall_val=[(42,)])
return _mock_conn(cur)
else:
# has_running
cur = _mock_cursor(fetchone_val=(True,))
return _mock_conn(cur)
mock_get_conn.side_effect = side_effect_conn
mock_executor = MagicMock()
await queue._process_once(mock_executor)
# 不应调用 execute
mock_executor.execute.assert_not_called()
@patch("app.services.task_queue.get_connection")
@pytest.mark.asyncio
async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
"""无 running 任务时 dequeue 并执行"""
task_id = str(uuid.uuid4())
config_dict = {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods_dwd",
"processing_mode": "increment_only",
"dry_run": False,
"window_mode": "lookback",
"lookback_hours": 24,
"overlap_seconds": 600,
"fetch_before_verify": False,
"skip_ods_when_fetch_before_verify": False,
"ods_use_local_json": False,
"extra_args": {},
}
config_json = json.dumps(config_dict)
call_count = 0
def side_effect_conn():
nonlocal call_count
call_count += 1
if call_count == 1:
# _get_pending_site_ids
cur = _mock_cursor(fetchall_val=[(42,)])
return _mock_conn(cur)
elif call_count == 2:
# has_running → False
cur = _mock_cursor(fetchone_val=(False,))
return _mock_conn(cur)
else:
# dequeue → 返回任务
row = (
task_id, 42, config_json, "pending", 1,
None, None, None, None, None,
)
cur = _mock_cursor(fetchone_val=row)
return _mock_conn(cur)
mock_get_conn.side_effect = side_effect_conn
mock_executor = MagicMock()
mock_executor.execute = AsyncMock()
await queue._process_once(mock_executor)
# 给 create_task 一点时间启动
await asyncio.sleep(0.1)
@patch("app.services.task_queue.get_connection")
@pytest.mark.asyncio
async def test_process_once_no_pending(self, mock_get_conn, queue):
"""无 pending 任务时什么都不做"""
cur = _mock_cursor(fetchall_val=[])
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
mock_executor = MagicMock()
await queue._process_once(mock_executor)
mock_executor.execute.assert_not_called()
# ---------------------------------------------------------------------------
# 生命周期
# ---------------------------------------------------------------------------
class TestLifecycle:
@pytest.mark.asyncio
async def test_stop_sets_running_false(self, queue):
queue._running = True
queue._loop_task = None
await queue.stop()
assert queue._running is False
def test_start_creates_task(self, queue):
"""start() 应创建 asyncio.Task需要事件循环"""
# 仅验证 _running 初始状态
assert queue._running is False
assert queue._loop_task is None
# ---------------------------------------------------------------------------
# _mark_failed / _update_queue_status_from_log
# ---------------------------------------------------------------------------
class TestInternalHelpers:
@patch("app.services.task_queue.get_connection")
def test_mark_failed(self, mock_get_conn, queue):
cur = _mock_cursor()
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue._mark_failed("queue-1", "测试错误")
sql = cur.execute.call_args[0][0]
assert "failed" in sql
args = cur.execute.call_args[0][1]
assert args[0] == "测试错误"
assert args[1] == "queue-1"
@patch("app.services.task_queue.get_connection")
def test_update_queue_status_from_log(self, mock_get_conn, queue):
"""从 execution_log 同步状态到 task_queue"""
from datetime import datetime, timezone
finished = datetime.now(timezone.utc)
# 第一次 fetchone 返回 execution_log 行
cur = _mock_cursor(fetchone_val=("success", finished, 0, None))
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue._update_queue_status_from_log("queue-1")
# 应有 SELECT + UPDATE 两次 execute
assert cur.execute.call_count == 2
conn.commit.assert_called_once()
@patch("app.services.task_queue.get_connection")
def test_update_queue_status_no_log(self, mock_get_conn, queue):
"""execution_log 无记录时不更新"""
cur = _mock_cursor(fetchone_val=None)
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
queue._update_queue_status_from_log("queue-1")
# 只有 SELECT没有 UPDATE
assert cur.execute.call_count == 1

View File

@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
"""任务注册表分组属性测试Property-Based Testing
Property 4: 对于 Task_Registry 中的任务集合,分组结果中每个任务应出现在
且仅出现在其所属业务域的分组中。
Validates: Requirements 2.1
测试策略:
1. 直接测试 get_tasks_grouped_by_domain 函数:
- 每个任务出现在且仅出现在其 domain 对应的分组中
- 分组中的任务总数等于全部任务数(不多不少)
- 每个分组的 key 等于该分组内所有任务的 domain
2. 通过 API 端点测试TestClient + mock auth
- 返回的 groups 中每个任务的 domain 与其所在分组 key 一致
- 所有任务都出现在结果中
3. 随机子集验证:
- 随机选取任务子集,验证分组逻辑的一致性
- 随机选取 domain验证该 domain 下的任务都正确
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-registry")
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from app.services.task_registry import (
get_all_tasks,
get_tasks_grouped_by_domain,
TaskDefinition,
)
from fastapi.testclient import TestClient
from app.main import app
from app.auth.dependencies import get_current_user, CurrentUser
# ---------------------------------------------------------------------------
# 辅助
# ---------------------------------------------------------------------------
ALL_TASKS = get_all_tasks()
ALL_CODES = [t.code for t in ALL_TASKS]
ALL_DOMAINS = list({t.domain for t in ALL_TASKS})
def _mock_user() -> CurrentUser:
return CurrentUser(user_id=1, site_id=1)
# ---------------------------------------------------------------------------
# Property 4.1: 分组完整性 — 每个任务出现在且仅出现在其 domain 分组中
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_every_task_in_exactly_its_domain_group(data):
"""Property 4.1: 每个任务出现在且仅出现在其所属业务域的分组中。
从全量任务中随机选取一个任务,验证它只出现在对应 domain 的分组里,
且不出现在其他任何分组中。
"""
grouped = get_tasks_grouped_by_domain()
# 随机选取一个任务
task = data.draw(st.sampled_from(ALL_TASKS))
# 该任务必须出现在其 domain 分组中
assert task.domain in grouped, (
f"任务 {task.code} 的 domain '{task.domain}' 不在分组 keys 中"
)
domain_codes = [t.code for t in grouped[task.domain]]
assert task.code in domain_codes, (
f"任务 {task.code} 未出现在其 domain '{task.domain}' 的分组中"
)
# 该任务不应出现在其他任何分组中
for other_domain, other_tasks in grouped.items():
if other_domain == task.domain:
continue
other_codes = [t.code for t in other_tasks]
assert task.code not in other_codes, (
f"任务 {task.code}domain={task.domain})错误地出现在 "
f"domain '{other_domain}' 的分组中"
)
# ---------------------------------------------------------------------------
# Property 4.2: 分组总数守恒 — 分组中的任务总数等于全部任务数
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_grouped_total_equals_all_tasks(data):
"""Property 4.2: 分组中的任务总数等于全部任务数(不多不少)。
随机选取若干 domain 进行局部验证,同时验证全局总数守恒。
"""
all_tasks = get_all_tasks()
grouped = get_tasks_grouped_by_domain()
# 全局守恒:分组内任务总数 == 全量任务数
grouped_total = sum(len(tasks) for tasks in grouped.values())
assert grouped_total == len(all_tasks), (
f"分组总数 {grouped_total} != 全量任务数 {len(all_tasks)}"
)
# 随机选取一个 domain验证该 domain 下的任务数量正确
domain = data.draw(st.sampled_from(ALL_DOMAINS))
expected_count = sum(1 for t in all_tasks if t.domain == domain)
actual_count = len(grouped[domain])
assert actual_count == expected_count, (
f"domain '{domain}' 分组内任务数 {actual_count} != 预期 {expected_count}"
)
# ---------------------------------------------------------------------------
# Property 4.3: 分组 key 一致性 — 每个分组的 key 等于组内所有任务的 domain
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_group_key_matches_task_domains(data):
"""Property 4.3: 每个分组的 key 等于该分组内所有任务的 domain。
随机选取一个 domain 分组,验证组内每个任务的 domain 字段都等于分组 key。
"""
grouped = get_tasks_grouped_by_domain()
domain = data.draw(st.sampled_from(list(grouped.keys())))
for task in grouped[domain]:
assert task.domain == domain, (
f"分组 '{domain}' 中的任务 {task.code} 的 domain 为 "
f"'{task.domain}',与分组 key 不一致"
)
# ---------------------------------------------------------------------------
# Property 4.4: 任务 code 全局唯一 — 分组后不应出现重复 code
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_no_duplicate_codes_across_groups(data):
"""Property 4.4: 分组后所有任务的 code 全局唯一,无重复。
随机选取若干 domain 的任务合并,验证 code 不重复。
"""
grouped = get_tasks_grouped_by_domain()
# 收集所有分组中的 code
all_codes_in_groups = []
for tasks in grouped.values():
all_codes_in_groups.extend(t.code for t in tasks)
assert len(all_codes_in_groups) == len(set(all_codes_in_groups)), (
"分组中存在重复的任务 code"
)
# 随机选取两个不同 domain验证它们的任务 code 无交集
if len(ALL_DOMAINS) >= 2:
domains = data.draw(
st.lists(st.sampled_from(ALL_DOMAINS), min_size=2, max_size=2, unique=True)
)
codes_a = {t.code for t in grouped[domains[0]]}
codes_b = {t.code for t in grouped[domains[1]]}
overlap = codes_a & codes_b
assert not overlap, (
f"domain '{domains[0]}''{domains[1]}' 存在重叠任务 code: {overlap}"
)
# ---------------------------------------------------------------------------
# Property 4.5: 随机子集分组一致性 — 子集中的任务分组结果与全量一致
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(
indices=st.lists(
st.integers(min_value=0, max_value=len(ALL_TASKS) - 1),
min_size=1,
max_size=min(20, len(ALL_TASKS)),
unique=True,
)
)
def test_subset_grouping_consistency(indices):
"""Property 4.5: 随机选取任务子集,验证每个任务在全量分组中的归属正确。
对于随机选取的任务子集,每个任务在 get_tasks_grouped_by_domain() 的结果中
都应出现在其 domain 对应的分组里。
"""
grouped = get_tasks_grouped_by_domain()
subset = [ALL_TASKS[i] for i in indices]
for task in subset:
# 任务的 domain 必须是分组的 key 之一
assert task.domain in grouped
# 任务必须在对应分组中
group_codes = {t.code for t in grouped[task.domain]}
assert task.code in group_codes, (
f"任务 {task.code} 未出现在 domain '{task.domain}' 的分组中"
)
# ---------------------------------------------------------------------------
# Property 4.6: API 端点分组正确性 — GET /api/tasks/registry 返回一致的分组
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(data=st.data())
def test_api_registry_grouping_correctness(data):
"""Property 4.6: API 端点返回的分组中,每个任务的 domain 与分组 key 一致,
且所有任务都出现在结果中。
"""
app.dependency_overrides[get_current_user] = _mock_user
try:
client = TestClient(app)
resp = client.get("/api/tasks/registry")
assert resp.status_code == 200
body = resp.json()
groups = body["groups"]
# 收集 API 返回的所有任务 code
api_codes: set[str] = set()
for domain_key, task_list in groups.items():
for task_item in task_list:
# 每个任务的 domain 必须等于分组 key
assert task_item["domain"] == domain_key, (
f"API 返回的任务 {task_item['code']}domain={task_item['domain']}"
f"出现在分组 '{domain_key}' 中,不一致"
)
api_codes.add(task_item["code"])
# 所有任务都应出现在 API 结果中
all_codes_set = {t.code for t in get_all_tasks()}
assert api_codes == all_codes_set, (
f"API 返回的任务集合与全量任务不一致。"
f"缺失: {all_codes_set - api_codes}"
f"多余: {api_codes - all_codes_set}"
)
# 随机选取一个 domain验证该 domain 下的任务数量与服务层一致
if groups:
domain = data.draw(st.sampled_from(list(groups.keys())))
expected = get_tasks_grouped_by_domain()
assert len(groups[domain]) == len(expected[domain]), (
f"API 返回的 domain '{domain}' 任务数 {len(groups[domain])} "
f"!= 服务层 {len(expected[domain])}"
)
finally:
app.dependency_overrides.pop(get_current_user, None)
# ---------------------------------------------------------------------------
# Property 4.7: 随机 domain 过滤验证
# Validates: Requirements 2.1
# ---------------------------------------------------------------------------
@settings(max_examples=100)
@given(domain=st.sampled_from(ALL_DOMAINS))
def test_random_domain_tasks_all_correct(domain):
"""Property 4.7: 随机选取一个 domain验证该 domain 下的所有任务都正确归属。
对于选定的 domain
- 分组中的每个任务的 domain 字段都等于选定的 domain
- 全量任务中所有属于该 domain 的任务都出现在分组中
"""
grouped = get_tasks_grouped_by_domain()
all_tasks = get_all_tasks()
# 分组中该 domain 的任务
group_tasks = grouped.get(domain, [])
# 全量任务中属于该 domain 的任务
expected_tasks = [t for t in all_tasks if t.domain == domain]
# 数量一致
assert len(group_tasks) == len(expected_tasks), (
f"domain '{domain}': 分组中 {len(group_tasks)} 个任务,"
f"预期 {len(expected_tasks)}"
)
# code 集合一致
group_codes = {t.code for t in group_tasks}
expected_codes = {t.code for t in expected_tasks}
assert group_codes == expected_codes, (
f"domain '{domain}': 分组 codes {group_codes} != 预期 {expected_codes}"
)
# 每个任务的 domain 字段都正确
for task in group_tasks:
assert task.domain == domain

View File

@@ -0,0 +1,274 @@
# -*- coding: utf-8 -*-
"""任务注册表 API 单元测试
覆盖 4 个端点registry / dwd-tables / flows / validate
通过 JWT mock 绕过认证依赖。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_registry import (
ALL_TASKS,
DWD_TABLES,
FLOW_LAYER_MAP,
get_tasks_grouped_by_domain,
)
# 固定测试用户
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
# ---------------------------------------------------------------------------
# GET /api/tasks/registry
# ---------------------------------------------------------------------------
class TestTaskRegistry:
def setup_method(self):
"""每个测试方法前重新设置 auth 覆盖,防止其他测试文件的 clear/pop 导致状态泄漏"""
app.dependency_overrides[get_current_user] = _override_auth
def test_registry_returns_grouped_tasks(self):
resp = client.get("/api/tasks/registry")
assert resp.status_code == 200
data = resp.json()
assert "groups" in data
# 所有任务都应出现在某个分组中
all_codes_in_response = set()
for domain, tasks in data["groups"].items():
for t in tasks:
all_codes_in_response.add(t["code"])
assert t["domain"] == domain
expected_codes = {t.code for t in ALL_TASKS}
assert all_codes_in_response == expected_codes
def test_registry_task_fields_complete(self):
"""每个任务项包含所有必要字段"""
resp = client.get("/api/tasks/registry")
data = resp.json()
required_fields = {"code", "name", "description", "domain", "layer",
"requires_window", "is_ods", "is_dimension", "default_enabled"}
for tasks in data["groups"].values():
for t in tasks:
assert required_fields.issubset(t.keys())
def test_registry_requires_auth(self):
"""未认证时返回 401"""
app.dependency_overrides.pop(get_current_user, None)
try:
resp = client.get("/api/tasks/registry")
assert resp.status_code == 401
finally:
app.dependency_overrides[get_current_user] = _override_auth
# ---------------------------------------------------------------------------
# GET /api/tasks/dwd-tables
# ---------------------------------------------------------------------------
class TestDwdTables:
def test_dwd_tables_returns_grouped(self):
resp = client.get("/api/tasks/dwd-tables")
assert resp.status_code == 200
data = resp.json()
assert "groups" in data
all_tables_in_response = set()
for domain, tables in data["groups"].items():
for t in tables:
all_tables_in_response.add(t["table_name"])
assert t["domain"] == domain
expected_tables = {t.table_name for t in DWD_TABLES}
assert all_tables_in_response == expected_tables
def test_dwd_tables_fields_complete(self):
resp = client.get("/api/tasks/dwd-tables")
data = resp.json()
required_fields = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
for tables in data["groups"].values():
for t in tables:
assert required_fields.issubset(t.keys())
# ---------------------------------------------------------------------------
# GET /api/tasks/flows
# ---------------------------------------------------------------------------
class TestFlows:
def test_flows_returns_seven_flows(self):
resp = client.get("/api/tasks/flows")
assert resp.status_code == 200
data = resp.json()
assert len(data["flows"]) == 7
def test_flows_returns_three_processing_modes(self):
resp = client.get("/api/tasks/flows")
data = resp.json()
assert len(data["processing_modes"]) == 3
def test_flow_ids_match_registry(self):
"""Flow ID 与 FLOW_LAYER_MAP 一致"""
resp = client.get("/api/tasks/flows")
data = resp.json()
flow_ids = {f["id"] for f in data["flows"]}
assert flow_ids == set(FLOW_LAYER_MAP.keys())
def test_flow_layers_non_empty(self):
resp = client.get("/api/tasks/flows")
data = resp.json()
for f in data["flows"]:
assert len(f["layers"]) > 0
def test_processing_mode_ids(self):
resp = client.get("/api/tasks/flows")
data = resp.json()
mode_ids = {m["id"] for m in data["processing_modes"]}
assert mode_ids == {"increment_only", "verify_only", "increment_verify"}
# ---------------------------------------------------------------------------
# POST /api/tasks/validate
# ---------------------------------------------------------------------------
class TestValidate:
def test_validate_success(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
"pipeline": "api_ods",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is True
assert data["errors"] == []
assert len(data["command_args"]) > 0
assert "--store-id" in data["command"]
# store_id 应从 JWT 注入(测试用户 site_id=100
assert "100" in data["command"]
def test_validate_injects_store_id(self):
"""即使前端传了 store_id后端也用 JWT 中的值覆盖"""
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["DWD_LOAD_FROM_ODS"],
"pipeline": "ods_dwd",
"store_id": 999,
}
})
assert resp.status_code == 200
data = resp.json()
# 命令中应包含 JWT 的 site_id=100而非前端传的 999
assert "--store-id" in data["command"]
idx = data["command_args"].index("--store-id")
assert data["command_args"][idx + 1] == "100"
def test_validate_invalid_flow(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "nonexistent_flow",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is False
assert any("无效的执行流程" in e for e in data["errors"])
def test_validate_empty_tasks(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": [],
"pipeline": "api_ods",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is False
assert any("任务列表不能为空" in e for e in data["errors"])
def test_validate_custom_window(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods",
"window_mode": "custom",
"window_start": "2024-01-01",
"window_end": "2024-01-31",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is True
assert "--window-start" in data["command"]
assert "--window-end" in data["command"]
def test_validate_window_end_before_start_rejected(self):
"""window_end 早于 window_start 时 Pydantic 验证失败 → 422"""
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods",
"window_mode": "custom",
"window_start": "2024-12-31",
"window_end": "2024-01-01",
}
})
assert resp.status_code == 422
def test_validate_dry_run_flag(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods",
"dry_run": True,
}
})
assert resp.status_code == 200
data = resp.json()
assert "--dry-run" in data["command"]
# ---------------------------------------------------------------------------
# task_registry 服务层测试
# ---------------------------------------------------------------------------
class TestTaskRegistryService:
def test_all_tasks_have_unique_codes(self):
codes = [t.code for t in ALL_TASKS]
assert len(codes) == len(set(codes))
def test_grouped_tasks_cover_all(self):
grouped = get_tasks_grouped_by_domain()
all_codes = set()
for tasks in grouped.values():
for t in tasks:
all_codes.add(t.code)
assert all_codes == {t.code for t in ALL_TASKS}
def test_ods_tasks_marked_is_ods(self):
for t in ALL_TASKS:
if t.layer == "ODS":
assert t.is_ods is True
def test_flow_layer_map_covers_all_flows(self):
expected_flows = {"api_ods", "api_ods_dwd", "api_full", "ods_dwd",
"dwd_dws", "dwd_dws_index", "dwd_index"}
assert set(FLOW_LAYER_MAP.keys()) == expected_flows

View File

@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
"""WebSocket 日志推送端点测试
测试 /ws/logs/{execution_id} 端点的连接、日志回放、实时推送和断开行为。
利用 TaskExecutor 已有的 subscribe/broadcast 机制进行验证。
"""
from __future__ import annotations
import asyncio
import pytest
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from app.main import app
from app.services.task_executor import task_executor
@pytest.fixture(autouse=True)
def _cleanup_executor():
"""每个测试前后清理 TaskExecutor 内部状态。"""
yield
# 清理所有残留的缓冲区和订阅者
for eid in list(task_executor._log_buffers.keys()):
task_executor.cleanup(eid)
task_executor._subscribers.clear()
task_executor._log_buffers.clear()
class TestWebSocketConnection:
"""WebSocket 连接/断开基本行为"""
def test_connect_and_disconnect(self):
"""客户端能成功建立和关闭 WebSocket 连接。"""
client = TestClient(app)
with client.websocket_connect("/ws/logs/test-exec-001") as ws:
# 连接成功,直接关闭
pass # __exit__ 会关闭连接
def test_connect_registers_subscriber(self):
"""连接后 TaskExecutor 应注册订阅者。"""
client = TestClient(app)
# 预先初始化缓冲区(模拟有任务在运行)
task_executor._log_buffers["test-exec-002"] = []
with client.websocket_connect("/ws/logs/test-exec-002"):
# 连接期间应有订阅者
assert "test-exec-002" in task_executor._subscribers
assert len(task_executor._subscribers["test-exec-002"]) >= 1
class TestLogReplay:
"""历史日志回放"""
def test_replay_existing_logs(self):
"""连接时应先收到内存缓冲区中已有的日志行。"""
eid = "test-exec-replay"
# 预填充日志缓冲区
task_executor._log_buffers[eid] = [
"[stdout] 第一行",
"[stdout] 第二行",
"[stderr] 警告信息",
]
client = TestClient(app)
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
# 应按顺序收到 3 条历史日志
msg1 = ws.receive_text()
msg2 = ws.receive_text()
msg3 = ws.receive_text()
assert msg1 == "[stdout] 第一行"
assert msg2 == "[stdout] 第二行"
assert msg3 == "[stderr] 警告信息"
def test_no_logs_no_replay(self):
"""没有历史日志时不应收到回放消息。"""
eid = "test-exec-empty"
task_executor._log_buffers[eid] = []
client = TestClient(app)
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
# 发送结束信号让连接正常关闭
task_executor._broadcast_end(eid)
# 不应有回放消息,直接收到的是后续的结束信号处理
class TestBroadcastReceive:
"""实时日志广播接收"""
def test_receive_broadcast_messages(self):
"""连接后应能收到 TaskExecutor 广播的实时日志。"""
eid = "test-exec-broadcast"
task_executor._log_buffers[eid] = []
client = TestClient(app)
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
# 模拟 TaskExecutor 广播日志
task_executor._broadcast(eid, "[stdout] 实时日志行1")
task_executor._broadcast(eid, "[stderr] 实时错误行")
msg1 = ws.receive_text()
msg2 = ws.receive_text()
assert msg1 == "[stdout] 实时日志行1"
assert msg2 == "[stderr] 实时错误行"
def test_end_signal_closes_connection(self):
"""收到 None 结束信号后 WebSocket 应正常关闭。"""
eid = "test-exec-end"
task_executor._log_buffers[eid] = []
client = TestClient(app)
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
# 广播一条日志后发送结束信号
task_executor._broadcast(eid, "[stdout] 最后一行")
task_executor._broadcast_end(eid)
msg = ws.receive_text()
assert msg == "[stdout] 最后一行"
def test_replay_then_broadcast(self):
"""先回放历史日志,再接收实时广播。"""
eid = "test-exec-mixed"
task_executor._log_buffers[eid] = ["[stdout] 历史行"]
client = TestClient(app)
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
# 先收到历史回放
replay = ws.receive_text()
assert replay == "[stdout] 历史行"
# 再收到实时广播
task_executor._broadcast(eid, "[stdout] 新行")
task_executor._broadcast_end(eid)
live = ws.receive_text()
assert live == "[stdout] 新行"
class TestLogBroadcasterUnit:
"""直接测试 TaskExecutor 的 subscribe/unsubscribe/broadcast 方法
(作为 LogBroadcaster 功能的单元测试)。
"""
def test_subscribe_creates_queue(self):
eid = "unit-sub"
q = task_executor.subscribe(eid)
assert isinstance(q, asyncio.Queue)
assert eid in task_executor._subscribers
task_executor.unsubscribe(eid, q)
def test_unsubscribe_removes_queue(self):
eid = "unit-unsub"
q = task_executor.subscribe(eid)
task_executor.unsubscribe(eid, q)
# 最后一个订阅者移除后key 也应被清理
assert eid not in task_executor._subscribers
def test_broadcast_delivers_to_all_subscribers(self):
eid = "unit-multi"
q1 = task_executor.subscribe(eid)
q2 = task_executor.subscribe(eid)
task_executor._broadcast(eid, "测试消息")
assert q1.get_nowait() == "测试消息"
assert q2.get_nowait() == "测试消息"
task_executor.unsubscribe(eid, q1)
task_executor.unsubscribe(eid, q2)
def test_broadcast_end_sends_none(self):
eid = "unit-end"
q = task_executor.subscribe(eid)
task_executor._broadcast_end(eid)
assert q.get_nowait() is None
task_executor.unsubscribe(eid, q)
def test_broadcast_no_subscribers_is_safe(self):
"""没有订阅者时广播不应报错。"""
task_executor._broadcast("nonexistent", "无人接收")
task_executor._broadcast_end("nonexistent")