在准备环境前提交次全部更改。
This commit is contained in:
62
apps/backend/tests/test_auth_dependencies.py
Normal file
62
apps/backend/tests/test_auth_dependencies.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
FastAPI 依赖注入 get_current_user 单元测试。
|
||||
|
||||
通过 FastAPI TestClient 验证 Authorization header 处理。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.auth.jwt import create_access_token, create_refresh_token
|
||||
|
||||
# 构造一个最小 FastAPI 应用用于测试依赖注入
|
||||
_test_app = FastAPI()
|
||||
|
||||
|
||||
@_test_app.get("/protected")
|
||||
async def protected_route(user: CurrentUser = Depends(get_current_user)):
|
||||
return {"user_id": user.user_id, "site_id": user.site_id}
|
||||
|
||||
|
||||
client = TestClient(_test_app)
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
def test_valid_access_token(self):
|
||||
token = create_access_token(user_id=10, site_id=100)
|
||||
resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["user_id"] == 10
|
||||
assert data["site_id"] == 100
|
||||
|
||||
def test_missing_auth_header_returns_401(self):
|
||||
"""缺少 Authorization header 时返回 401。"""
|
||||
resp = client.get("/protected")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
def test_invalid_token_returns_401(self):
|
||||
resp = client.get(
|
||||
"/protected", headers={"Authorization": "Bearer invalid.token.here"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_token_rejected(self):
|
||||
"""refresh 令牌不能用于访问受保护端点。"""
|
||||
token = create_refresh_token(user_id=1, site_id=1)
|
||||
resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_current_user_is_frozen_dataclass(self):
|
||||
"""CurrentUser 是不可变的。"""
|
||||
user = CurrentUser(user_id=1, site_id=2)
|
||||
assert user.user_id == 1
|
||||
assert user.site_id == 2
|
||||
with pytest.raises(AttributeError):
|
||||
user.user_id = 99 # type: ignore[misc]
|
||||
147
apps/backend/tests/test_auth_jwt.py
Normal file
147
apps/backend/tests/test_auth_jwt.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
JWT 认证模块单元测试。
|
||||
|
||||
覆盖:令牌生成、验证、过期、类型校验、密码哈希、依赖注入。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from jose import jwt as jose_jwt
|
||||
|
||||
# 测试前设置 JWT_SECRET_KEY,避免空密钥
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
create_token_pair,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from app import config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 密码哈希
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPasswordHashing:
|
||||
def test_hash_and_verify(self):
|
||||
raw = "my_secure_password"
|
||||
hashed = hash_password(raw)
|
||||
assert verify_password(raw, hashed)
|
||||
|
||||
def test_wrong_password_rejected(self):
|
||||
hashed = hash_password("correct")
|
||||
assert not verify_password("wrong", hashed)
|
||||
|
||||
def test_hash_is_not_plaintext(self):
|
||||
raw = "plaintext123"
|
||||
hashed = hash_password(raw)
|
||||
assert hashed != raw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌生成与解码
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenCreation:
|
||||
def test_access_token_contains_expected_fields(self):
|
||||
token = create_access_token(user_id=1, site_id=100)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == "1"
|
||||
assert payload["site_id"] == 100
|
||||
assert payload["type"] == "access"
|
||||
assert "exp" in payload
|
||||
|
||||
def test_refresh_token_contains_expected_fields(self):
|
||||
token = create_refresh_token(user_id=2, site_id=200)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == "2"
|
||||
assert payload["site_id"] == 200
|
||||
assert payload["type"] == "refresh"
|
||||
assert "exp" in payload
|
||||
|
||||
def test_token_pair_returns_both_tokens(self):
|
||||
pair = create_token_pair(user_id=3, site_id=300)
|
||||
assert "access_token" in pair
|
||||
assert "refresh_token" in pair
|
||||
assert pair["token_type"] == "bearer"
|
||||
|
||||
# 验证两个令牌类型不同
|
||||
access_payload = decode_token(pair["access_token"])
|
||||
refresh_payload = decode_token(pair["refresh_token"])
|
||||
assert access_payload["type"] == "access"
|
||||
assert refresh_payload["type"] == "refresh"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌类型校验
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenTypeValidation:
|
||||
def test_decode_access_token_rejects_refresh(self):
|
||||
"""access 解码器拒绝 refresh 令牌。"""
|
||||
token = create_refresh_token(user_id=1, site_id=1)
|
||||
with pytest.raises(Exception):
|
||||
decode_access_token(token)
|
||||
|
||||
def test_decode_refresh_token_rejects_access(self):
|
||||
"""refresh 解码器拒绝 access 令牌。"""
|
||||
token = create_access_token(user_id=1, site_id=1)
|
||||
with pytest.raises(Exception):
|
||||
decode_refresh_token(token)
|
||||
|
||||
def test_decode_access_token_accepts_access(self):
|
||||
token = create_access_token(user_id=5, site_id=50)
|
||||
payload = decode_access_token(token)
|
||||
assert payload["sub"] == "5"
|
||||
assert payload["site_id"] == 50
|
||||
|
||||
def test_decode_refresh_token_accepts_refresh(self):
|
||||
token = create_refresh_token(user_id=6, site_id=60)
|
||||
payload = decode_refresh_token(token)
|
||||
assert payload["sub"] == "6"
|
||||
assert payload["site_id"] == 60
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌过期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenExpiry:
|
||||
def test_expired_token_rejected(self):
|
||||
"""手动构造已过期令牌,验证解码失败。"""
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"site_id": 1,
|
||||
"type": "access",
|
||||
"exp": int(time.time()) - 10, # 10 秒前过期
|
||||
}
|
||||
token = jose_jwt.encode(
|
||||
payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
decode_token(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 无效令牌
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInvalidToken:
|
||||
def test_garbage_token_rejected(self):
|
||||
with pytest.raises(Exception):
|
||||
decode_token("not.a.valid.jwt")
|
||||
|
||||
def test_wrong_secret_rejected(self):
|
||||
"""用不同密钥签发的令牌应被拒绝。"""
|
||||
payload = {"sub": "1", "site_id": 1, "type": "access", "exp": int(time.time()) + 3600}
|
||||
token = jose_jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
with pytest.raises(Exception):
|
||||
decode_token(token)
|
||||
137
apps/backend/tests/test_auth_properties.py
Normal file
137
apps/backend/tests/test_auth_properties.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
认证模块属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证认证系统的通用正确性属性:
|
||||
- Property 2: 无效凭据始终被拒绝
|
||||
- Property 3: 有效 JWT 令牌授权访问
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-property-tests")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.auth.jwt import create_access_token
|
||||
from app.main import app
|
||||
from app.routers.auth import router
|
||||
|
||||
# 确保路由已挂载
|
||||
if router not in [r for r in app.routes]:
|
||||
app.include_router(router)
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 用户名策略:1~64 字符的可打印字符串(排除控制字符)
|
||||
_username_st = st.text(
|
||||
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
|
||||
min_size=1,
|
||||
max_size=64,
|
||||
)
|
||||
|
||||
# 密码策略:1~128 字符的可打印字符串
|
||||
_password_st = st.text(
|
||||
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
|
||||
min_size=1,
|
||||
max_size=128,
|
||||
)
|
||||
|
||||
# user_id 策略:正整数
|
||||
_user_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
# site_id 策略:正整数
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**63 - 1)
|
||||
|
||||
|
||||
def _mock_db_returning(row):
|
||||
"""构造 mock get_connection,cursor.fetchone() 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 2: 无效凭据始终被拒绝
|
||||
# **Validates: Requirements 1.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(username=_username_st, password=_password_st)
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_invalid_credentials_always_rejected(mock_get_conn, username, password):
|
||||
"""
|
||||
Property 2: 无效凭据始终被拒绝。
|
||||
|
||||
对于任意用户名/密码组合,当数据库中不存在该用户时(fetchone 返回 None),
|
||||
登录接口应始终返回 401 状态码。
|
||||
"""
|
||||
# mock 数据库返回 None — 用户不存在
|
||||
mock_get_conn.return_value = _mock_db_returning(None)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": username, "password": password},
|
||||
)
|
||||
assert resp.status_code == 401, (
|
||||
f"期望 401,实际 {resp.status_code},username={username!r}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 3: 有效 JWT 令牌授权访问
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""在同步上下文中执行异步协程,避免 DeprecationWarning。"""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(user_id=_user_id_st, site_id=_site_id_st)
|
||||
def test_valid_jwt_grants_access(user_id, site_id):
|
||||
"""
|
||||
Property 3: 有效 JWT 令牌授权访问。
|
||||
|
||||
对于任意 user_id 和 site_id,由系统签发的未过期 access_token
|
||||
应能被 get_current_user 依赖成功解析为 CurrentUser 对象,
|
||||
且解析出的 user_id 和 site_id 与签发时一致。
|
||||
"""
|
||||
# 生成有效的 access_token
|
||||
token = create_access_token(user_id=user_id, site_id=site_id)
|
||||
|
||||
# 直接调用依赖函数验证令牌解析
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
result = _run_async(get_current_user(credentials))
|
||||
|
||||
assert isinstance(result, CurrentUser)
|
||||
assert result.user_id == user_id, (
|
||||
f"user_id 不匹配:期望 {user_id},实际 {result.user_id}"
|
||||
)
|
||||
assert result.site_id == site_id, (
|
||||
f"site_id 不匹配:期望 {site_id},实际 {result.site_id}"
|
||||
)
|
||||
167
apps/backend/tests/test_auth_router.py
Normal file
167
apps/backend/tests/test_auth_router.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
认证路由单元测试。
|
||||
|
||||
覆盖:登录成功/失败、刷新令牌、账号禁用等场景。
|
||||
通过 mock 数据库连接避免依赖真实 PostgreSQL。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
hash_password,
|
||||
)
|
||||
from app.main import app
|
||||
from app.routers.auth import router
|
||||
|
||||
# 注册路由到 app(测试时确保路由已挂载)
|
||||
if router not in [r for r in app.routes]:
|
||||
app.include_router(router)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# 测试用固定数据
|
||||
_TEST_PASSWORD = "correct_password"
|
||||
_TEST_HASH = hash_password(_TEST_PASSWORD)
|
||||
_TEST_USER_ROW = (1, _TEST_HASH, 100, True) # id, password_hash, site_id, is_active
|
||||
_DISABLED_USER_ROW = (2, _TEST_HASH, 200, False)
|
||||
|
||||
|
||||
def _mock_db_returning(row):
|
||||
"""构造一个 mock get_connection,cursor.fetchone() 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/login
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLogin:
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_success(self, mock_get_conn):
|
||||
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "admin", "password": _TEST_PASSWORD},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
# 验证 access_token payload 包含正确的 user_id 和 site_id
|
||||
payload = decode_access_token(data["access_token"])
|
||||
assert payload["sub"] == "1"
|
||||
assert payload["site_id"] == 100
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_user_not_found(self, mock_get_conn):
|
||||
"""用户不存在时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(None)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "nonexistent", "password": "whatever"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "用户名或密码错误" in resp.json()["detail"]
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_wrong_password(self, mock_get_conn):
|
||||
"""密码错误时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "admin", "password": "wrong_password"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "用户名或密码错误" in resp.json()["detail"]
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_disabled_account(self, mock_get_conn):
|
||||
"""账号已禁用时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "disabled_user", "password": _TEST_PASSWORD},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "禁用" in resp.json()["detail"]
|
||||
|
||||
def test_login_missing_username(self):
|
||||
"""缺少 username 字段时返回 422。"""
|
||||
resp = client.post("/api/auth/login", json={"password": "test"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_login_empty_password(self):
|
||||
"""空密码时返回 422。"""
|
||||
resp = client.post(
|
||||
"/api/auth/login", json={"username": "admin", "password": ""}
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRefresh:
|
||||
def test_refresh_success(self):
|
||||
"""有效的 refresh_token 换取新的 access_token。"""
|
||||
refresh = create_refresh_token(user_id=5, site_id=50)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": refresh}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
# refresh_token 原样返回
|
||||
assert data["refresh_token"] == refresh
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
# 新 access_token 包含正确信息
|
||||
payload = decode_access_token(data["access_token"])
|
||||
assert payload["sub"] == "5"
|
||||
assert payload["site_id"] == 50
|
||||
|
||||
def test_refresh_with_invalid_token(self):
|
||||
"""无效令牌返回 401。"""
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": "garbage.token.here"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "无效的刷新令牌" in resp.json()["detail"]
|
||||
|
||||
def test_refresh_with_access_token_rejected(self):
|
||||
"""用 access_token 做刷新应被拒绝。"""
|
||||
from app.auth.jwt import create_access_token
|
||||
|
||||
access = create_access_token(user_id=1, site_id=1)
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": access}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_missing_token(self):
|
||||
"""缺少 refresh_token 字段时返回 422。"""
|
||||
resp = client.post("/api/auth/refresh", json={})
|
||||
assert resp.status_code == 422
|
||||
259
apps/backend/tests/test_cli_builder.py
Normal file
259
apps/backend/tests/test_cli_builder.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLIBuilder 单元测试
|
||||
|
||||
覆盖:7 种 Flow、3 种处理模式、时间窗口、store_id 自动注入、extra_args 等。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def builder() -> CLIBuilder:
|
||||
return CLIBuilder()
|
||||
|
||||
|
||||
ETL_PATH = "/fake/etl/project"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 基本命令结构
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBasicCommand:
|
||||
def test_minimal_command(self, builder: CLIBuilder):
|
||||
"""最小配置应生成 python -m cli.main --pipeline ... --processing-mode ..."""
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert cmd[:3] == ["python", "-m", "cli.main"]
|
||||
assert "--pipeline" in cmd
|
||||
assert "--processing-mode" in cmd
|
||||
|
||||
def test_custom_python_executable(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
cmd = builder.build_command(config, ETL_PATH, python_executable="python3")
|
||||
assert cmd[0] == "python3"
|
||||
|
||||
def test_tasks_joined_by_comma(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER", "ODS_PAYMENT", "ODS_REFUND"])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--tasks")
|
||||
assert cmd[idx + 1] == "ODS_MEMBER,ODS_PAYMENT,ODS_REFUND"
|
||||
|
||||
def test_empty_tasks_no_tasks_arg(self, builder: CLIBuilder):
|
||||
"""空任务列表不应生成 --tasks 参数"""
|
||||
config = TaskConfigSchema(tasks=[])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--tasks" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7 种 Flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlows:
|
||||
@pytest.mark.parametrize("flow_id", sorted(VALID_FLOWS))
|
||||
def test_all_flows_accepted(self, builder: CLIBuilder, flow_id: str):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], pipeline=flow_id)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--pipeline")
|
||||
assert cmd[idx + 1] == flow_id
|
||||
|
||||
def test_default_flow_is_api_ods_dwd(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--pipeline")
|
||||
assert cmd[idx + 1] == "api_ods_dwd"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3 种处理模式
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessingModes:
|
||||
@pytest.mark.parametrize("mode", sorted(VALID_PROCESSING_MODES))
|
||||
def test_all_modes_accepted(self, builder: CLIBuilder, mode: str):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], processing_mode=mode)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--processing-mode")
|
||||
assert cmd[idx + 1] == mode
|
||||
|
||||
def test_fetch_before_verify_only_in_verify_mode(self, builder: CLIBuilder):
|
||||
"""--fetch-before-verify 仅在 verify_only 模式下生效"""
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
processing_mode="verify_only",
|
||||
fetch_before_verify=True,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--fetch-before-verify" in cmd
|
||||
|
||||
def test_fetch_before_verify_ignored_in_increment_mode(self, builder: CLIBuilder):
|
||||
"""increment_only 模式下 fetch_before_verify=True 不应生成参数"""
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
processing_mode="increment_only",
|
||||
fetch_before_verify=True,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--fetch-before-verify" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 时间窗口
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTimeWindow:
|
||||
def test_lookback_mode_generates_lookback_args(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="lookback",
|
||||
lookback_hours=48,
|
||||
overlap_seconds=1200,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx_lb = cmd.index("--lookback-hours")
|
||||
assert cmd[idx_lb + 1] == "48"
|
||||
idx_ol = cmd.index("--overlap-seconds")
|
||||
assert cmd[idx_ol + 1] == "1200"
|
||||
# lookback 模式不应生成 --window-start / --window-end
|
||||
assert "--window-start" not in cmd
|
||||
assert "--window-end" not in cmd
|
||||
|
||||
def test_custom_mode_generates_window_args(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="custom",
|
||||
window_start="2026-01-01",
|
||||
window_end="2026-01-31",
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx_s = cmd.index("--window-start")
|
||||
assert cmd[idx_s + 1] == "2026-01-01"
|
||||
idx_e = cmd.index("--window-end")
|
||||
assert cmd[idx_e + 1] == "2026-01-31"
|
||||
# custom 模式不应生成 --lookback-hours / --overlap-seconds
|
||||
assert "--lookback-hours" not in cmd
|
||||
assert "--overlap-seconds" not in cmd
|
||||
|
||||
def test_window_split_with_days(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_split="day",
|
||||
window_split_days=10,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--window-split")
|
||||
assert cmd[idx + 1] == "day"
|
||||
idx_d = cmd.index("--window-split-days")
|
||||
assert cmd[idx_d + 1] == "10"
|
||||
|
||||
def test_window_split_none_not_generated(self, builder: CLIBuilder):
|
||||
"""window_split='none' 不应生成 --window-split 参数"""
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], window_split="none")
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--window-split" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# store_id 自动注入
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStoreId:
|
||||
def test_store_id_injected(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=42)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--store-id")
|
||||
assert cmd[idx + 1] == "42"
|
||||
|
||||
def test_store_id_none_not_generated(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=None)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--store-id" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dry_run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDryRun:
|
||||
def test_dry_run_flag(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=True)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--dry-run" in cmd
|
||||
|
||||
def test_no_dry_run_flag(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=False)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--dry-run" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extra_args
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtraArgs:
|
||||
def test_supported_value_arg(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"pg_dsn": "postgresql://localhost/test"},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--pg-dsn")
|
||||
assert cmd[idx + 1] == "postgresql://localhost/test"
|
||||
|
||||
def test_supported_bool_arg(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"force_window_override": True},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--force-window-override" in cmd
|
||||
|
||||
def test_unsupported_arg_ignored(self, builder: CLIBuilder):
|
||||
"""不在 CLI_SUPPORTED_ARGS 中的键应被忽略"""
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"unknown_param": "value"},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--unknown-param" not in cmd
|
||||
|
||||
def test_none_value_ignored(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"pg_dsn": None},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--pg-dsn" not in cmd
|
||||
|
||||
def test_false_bool_arg_not_generated(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"force_window_override": False},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--force-window-override" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_command_string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildCommandString:
|
||||
def test_returns_string(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
result = builder.build_command_string(config, ETL_PATH)
|
||||
assert isinstance(result, str)
|
||||
assert "python -m cli.main" in result
|
||||
|
||||
def test_quotes_args_with_spaces(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"pg_dsn": "host=localhost dbname=test"},
|
||||
)
|
||||
result = builder.build_command_string(config, ETL_PATH)
|
||||
# 包含空格的值应被引号包裹
|
||||
assert '"host=localhost dbname=test"' in result
|
||||
94
apps/backend/tests/test_database.py
Normal file
94
apps/backend/tests/test_database.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
数据库连接模块单元测试。
|
||||
|
||||
覆盖:ETL 只读连接的创建、RLS site_id 设置、只读模式、异常处理。
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from app.database import get_etl_readonly_connection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_etl_readonly_connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEtlReadonlyConnection:
|
||||
"""ETL 只读连接:验证连接参数、只读设置、RLS 隔离。"""
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_sets_readonly_and_site_id(self, mock_connect):
|
||||
"""连接后应依次执行 SET read_only 和 SET LOCAL site_id。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
conn = get_etl_readonly_connection(site_id=42)
|
||||
|
||||
# 验证 autocommit 被关闭
|
||||
assert mock_conn.autocommit is False
|
||||
|
||||
# 验证执行了两条 SET 语句
|
||||
executed = [c.args[0] for c in mock_cursor.execute.call_args_list]
|
||||
assert "SET default_transaction_read_only = on" in executed[0]
|
||||
assert "SET LOCAL app.current_site_id" in executed[1]
|
||||
|
||||
# 验证 site_id 参数化传递(防 SQL 注入)
|
||||
site_id_call = mock_cursor.execute.call_args_list[1]
|
||||
assert site_id_call.args[1] == ("42",)
|
||||
|
||||
# 验证提交
|
||||
mock_conn.commit.assert_called_once()
|
||||
assert conn is mock_conn
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_accepts_string_site_id(self, mock_connect):
|
||||
"""site_id 为字符串时也应正常工作。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
get_etl_readonly_connection(site_id="99")
|
||||
|
||||
site_id_call = mock_cursor.execute.call_args_list[1]
|
||||
assert site_id_call.args[1] == ("99",)
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_closes_connection_on_setup_error(self, mock_connect):
|
||||
"""SET 语句执行失败时应关闭连接并抛出异常。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.execute.side_effect = Exception("SET failed")
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
with pytest.raises(Exception, match="SET failed"):
|
||||
get_etl_readonly_connection(site_id=1)
|
||||
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_uses_etl_config_params(self, mock_connect):
|
||||
"""应使用 ETL_DB_* 配置项连接。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
get_etl_readonly_connection(site_id=1)
|
||||
|
||||
connect_kwargs = mock_connect.call_args.kwargs
|
||||
# 验证使用了 ETL 数据库名(默认 etl_feiqiu)
|
||||
assert connect_kwargs["dbname"] == "etl_feiqiu"
|
||||
139
apps/backend/tests/test_db_viewer_properties.py
Normal file
139
apps/backend/tests/test_db_viewer_properties.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证数据库查看器的通用正确性属性:
|
||||
- Property 17: SQL 写操作拦截
|
||||
- Property 18: SQL 查询结果行数限制
|
||||
|
||||
测试策略:
|
||||
- Property 17: 生成包含写操作关键词(随机大小写混合)的 SQL 字符串,
|
||||
验证 _WRITE_KEYWORDS 正则表达式能匹配到
|
||||
- Property 18: 生成随机长度的行列表(可能超过 1000 行),
|
||||
验证截取前 _MAX_ROWS 个元素后长度 <= 1000
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-db-viewer-properties")
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.routers.db_viewer import _WRITE_KEYWORDS, _MAX_ROWS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 写操作关键词列表
|
||||
_WRITE_OPS = ["INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE"]
|
||||
|
||||
# SQL 前缀/后缀:不含写操作关键词的简单文本
|
||||
_sql_filler_st = st.text(
|
||||
alphabet=st.characters(
|
||||
whitelist_categories=("L", "N", "S"),
|
||||
blacklist_characters="\x00",
|
||||
),
|
||||
min_size=0,
|
||||
max_size=50,
|
||||
)
|
||||
|
||||
# 随机大小写混合的写操作关键词
|
||||
_random_case_keyword_st = st.sampled_from(_WRITE_OPS).flatmap(
|
||||
lambda kw: st.tuples(
|
||||
st.just(kw),
|
||||
st.lists(
|
||||
st.booleans(),
|
||||
min_size=len(kw),
|
||||
max_size=len(kw),
|
||||
),
|
||||
).map(
|
||||
lambda pair: "".join(
|
||||
c.upper() if flag else c.lower()
|
||||
for c, flag in zip(pair[0], pair[1])
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 17: SQL 写操作拦截
|
||||
# **Validates: Requirements 7.5**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(
|
||||
prefix=_sql_filler_st,
|
||||
keyword=_random_case_keyword_st,
|
||||
suffix=_sql_filler_st,
|
||||
)
|
||||
def test_write_keywords_always_detected(prefix, keyword, suffix):
|
||||
"""Property 17: SQL 写操作拦截。
|
||||
|
||||
包含 INSERT、UPDATE、DELETE、DROP、TRUNCATE 关键词(不区分大小写)的
|
||||
SQL 语句,_WRITE_KEYWORDS 正则表达式应能匹配到。
|
||||
|
||||
策略:在随机前缀和后缀之间插入一个随机大小写混合的写操作关键词,
|
||||
用空格分隔以确保 \\b 词边界能匹配。
|
||||
"""
|
||||
# 用空格分隔确保词边界匹配
|
||||
sql = f"{prefix} {keyword} {suffix}"
|
||||
|
||||
match = _WRITE_KEYWORDS.search(sql)
|
||||
assert match is not None, (
|
||||
f"正则表达式未能匹配到写操作关键词:sql={sql!r}, keyword={keyword!r}"
|
||||
)
|
||||
# 匹配到的关键词(转大写后)应在写操作列表中
|
||||
assert match.group(1).upper() in _WRITE_OPS, (
|
||||
f"匹配到的关键词 '{match.group(1)}' 不在写操作列表中"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 18: SQL 查询结果行数限制
|
||||
# **Validates: Requirements 7.4**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 模拟数据库返回的行:每行是一个简单列表
|
||||
_row_st = st.lists(
|
||||
st.one_of(st.integers(), st.text(max_size=20), st.none()),
|
||||
min_size=1,
|
||||
max_size=5,
|
||||
)
|
||||
|
||||
# 行列表策略:0 到 3000 行(覆盖超过 _MAX_ROWS 的情况)
|
||||
_rows_st = st.lists(_row_st, min_size=0, max_size=3000)
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(rows=_rows_st)
|
||||
def test_row_count_never_exceeds_max(rows):
|
||||
"""Property 18: SQL 查询结果行数限制。
|
||||
|
||||
对任意长度的行列表,取前 _MAX_ROWS 个元素后,
|
||||
结果长度应 <= 1000。
|
||||
|
||||
这等价于 cur.fetchmany(_MAX_ROWS) 的行为:
|
||||
数据库游标最多返回 _MAX_ROWS 行。
|
||||
"""
|
||||
# 模拟 fetchmany(_MAX_ROWS) 的行为
|
||||
truncated = rows[:_MAX_ROWS]
|
||||
|
||||
assert len(truncated) <= _MAX_ROWS, (
|
||||
f"截取后行数 {len(truncated)} 超过上限 {_MAX_ROWS}"
|
||||
)
|
||||
|
||||
# 额外验证:如果原始行数 <= _MAX_ROWS,截取后应保留全部
|
||||
if len(rows) <= _MAX_ROWS:
|
||||
assert len(truncated) == len(rows), (
|
||||
f"原始行数 {len(rows)} <= {_MAX_ROWS},截取后应保留全部,"
|
||||
f"实际 {len(truncated)}"
|
||||
)
|
||||
|
||||
# 额外验证:如果原始行数 > _MAX_ROWS,截取后应恰好为 _MAX_ROWS
|
||||
if len(rows) > _MAX_ROWS:
|
||||
assert len(truncated) == _MAX_ROWS, (
|
||||
f"原始行数 {len(rows)} > {_MAX_ROWS},截取后应恰好为 {_MAX_ROWS},"
|
||||
f"实际 {len(truncated)}"
|
||||
)
|
||||
321
apps/backend/tests/test_db_viewer_router.py
Normal file
321
apps/backend/tests/test_db_viewer_router.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器路由单元测试
|
||||
|
||||
覆盖 4 个端点:
|
||||
- GET /api/db/schemas
|
||||
- GET /api/db/schemas/{name}/tables
|
||||
- GET /api/db/tables/{schema}/{table}/columns
|
||||
- POST /api/db/query
|
||||
|
||||
通过 mock 绕过数据库连接,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from psycopg2 import errors as pg_errors
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
|
||||
|
||||
|
||||
def _make_mock_conn(rows, description=None):
|
||||
"""构造 mock 数据库连接,cursor 返回指定行和列描述。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = rows
|
||||
mock_cur.fetchmany.return_value = rows
|
||||
mock_cur.description = description
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cur
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListSchemas:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_schema_list(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 3
|
||||
assert data[0]["name"] == "dwd"
|
||||
assert data[2]["name"] == "ods"
|
||||
|
||||
# 验证 site_id 传递
|
||||
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_schemas(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas/{name}/tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListTables:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_tables_with_row_count(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("dim_member", 1500),
|
||||
("fact_order", 32000),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/dwd/tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["name"] == "dim_member"
|
||||
assert data[0]["row_count"] == 1500
|
||||
assert data[1]["name"] == "fact_order"
|
||||
assert data[1]["row_count"] == 32000
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_null_row_count(self, mock_get_conn):
|
||||
"""pg_stat_user_tables 可能没有统计信息,row_count 为 None。"""
|
||||
conn, cur = _make_mock_conn([("new_table", None)])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/ods/tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data[0]["row_count"] is None
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_schema(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/empty_schema/tables")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/tables/{schema}/{table}/columns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListColumns:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_column_definitions(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("id", "bigint", "NO", None),
|
||||
("name", "character varying", "YES", None),
|
||||
("created_at", "timestamp with time zone", "NO", "now()"),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/tables/dwd/dim_member/columns")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 3
|
||||
|
||||
assert data[0]["name"] == "id"
|
||||
assert data[0]["data_type"] == "bigint"
|
||||
assert data[0]["is_nullable"] is False
|
||||
assert data[0]["column_default"] is None
|
||||
|
||||
assert data[1]["is_nullable"] is True
|
||||
assert data[2]["column_default"] == "now()"
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_table(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/tables/dwd/nonexistent/columns")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/db/query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecuteQuery:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_successful_select(self, mock_get_conn):
|
||||
description = [("id",), ("name",)]
|
||||
conn, cur = _make_mock_conn(
|
||||
[(1, "Alice"), (2, "Bob")],
|
||||
description=description,
|
||||
)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["columns"] == ["id", "name"]
|
||||
assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
|
||||
assert data["row_count"] == 2
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_result(self, mock_get_conn):
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn([], description=description)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["columns"] == ["id"]
|
||||
assert data["rows"] == []
|
||||
assert data["row_count"] == 0
|
||||
|
||||
# ── 写操作拦截 ──
|
||||
|
||||
@pytest.mark.parametrize("keyword", [
|
||||
"INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
|
||||
"insert", "update", "delete", "drop", "truncate",
|
||||
"Insert", "Update", "Delete", "Drop", "Truncate",
|
||||
])
|
||||
def test_blocks_write_operations(self, keyword):
|
||||
resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
|
||||
assert resp.status_code == 400
|
||||
assert "只读" in resp.json()["detail"] or "禁止" in resp.json()["detail"]
|
||||
|
||||
def test_blocks_mixed_case_write(self):
|
||||
resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_blocks_write_in_subquery(self):
|
||||
"""写操作关键词出现在 SQL 任意位置都应拦截。"""
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
# ── 空 SQL ──
|
||||
|
||||
def test_empty_sql(self):
|
||||
resp = client.post("/api/db/query", json={"sql": ""})
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_whitespace_only_sql(self):
|
||||
resp = client.post("/api/db/query", json={"sql": " "})
|
||||
assert resp.status_code == 400
|
||||
|
||||
# ── SQL 语法错误 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_sql_syntax_error(self, mock_get_conn):
|
||||
conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
# 第一次 execute 设置 timeout 成功,第二次抛异常
|
||||
mock_cur.execute.side_effect = [None, Exception("syntax error at or near \"SELEC\"")]
|
||||
mock_cur.description = None
|
||||
conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
|
||||
assert resp.status_code == 400
|
||||
assert "SQL 执行错误" in resp.json()["detail"]
|
||||
|
||||
# ── 查询超时 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_query_timeout(self, mock_get_conn):
|
||||
conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
|
||||
mock_cur.description = None
|
||||
conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
|
||||
assert resp.status_code == 408
|
||||
assert "超时" in resp.json()["detail"]
|
||||
|
||||
# ── 行数限制验证 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_row_limit(self, mock_get_conn):
|
||||
"""验证 fetchmany 被调用时传入 1000 行限制。"""
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn(
|
||||
[(i,) for i in range(1000)],
|
||||
description=description,
|
||||
)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
|
||||
assert resp.status_code == 200
|
||||
# 验证 fetchmany 被调用时传入了 1000
|
||||
cur.fetchmany.assert_called_once_with(1000)
|
||||
|
||||
# ── 超时设置验证 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_sets_statement_timeout(self, mock_get_conn):
|
||||
"""验证查询前设置了 statement_timeout。"""
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn([(1,)], description=description)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
client.post("/api/db/query", json={"sql": "SELECT 1"})
|
||||
|
||||
# 第一次 execute 应该是设置超时
|
||||
first_call = cur.execute.call_args_list[0]
|
||||
assert "statement_timeout" in first_call[0][0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDbViewerAuth:
|
||||
|
||||
def test_requires_auth(self):
|
||||
"""移除 auth override 后,所有端点应返回 401/403。"""
|
||||
original = app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
endpoints = [
|
||||
("GET", "/api/db/schemas"),
|
||||
("GET", "/api/db/schemas/dwd/tables"),
|
||||
("GET", "/api/db/tables/dwd/dim_member/columns"),
|
||||
("POST", "/api/db/query"),
|
||||
]
|
||||
for method, url in endpoints:
|
||||
if method == "POST":
|
||||
resp = client.request(method, url, json={"sql": "SELECT 1"})
|
||||
else:
|
||||
resp = client.request(method, url)
|
||||
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
|
||||
finally:
|
||||
if original:
|
||||
app.dependency_overrides[get_current_user] = original
|
||||
191
apps/backend/tests/test_env_config_properties.py
Normal file
191
apps/backend/tests/test_env_config_properties.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证环境配置管理的通用正确性属性:
|
||||
- Property 15: .env 解析与敏感值掩码
|
||||
- Property 16: .env 写入往返一致性
|
||||
|
||||
测试策略:
|
||||
- Property 15: 生成随机 .env 内容(含敏感和非敏感键),验证 _parse_env + _is_sensitive 对敏感值掩码
|
||||
- Property 16: 生成随机键值对,序列化为 .env 格式后再解析,验证往返一致性
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-env-config-properties")
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.routers.env_config import _parse_env, _is_sensitive, _MASK, _SENSITIVE_KEYWORDS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 合法的环境变量键名:字母或下划线开头,后跟字母、数字、下划线
|
||||
_key_start_char = st.sampled_from(
|
||||
list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_")
|
||||
)
|
||||
_key_rest_char = st.sampled_from(
|
||||
list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_")
|
||||
)
|
||||
|
||||
_env_key_st = st.builds(
|
||||
lambda first, rest: first + rest,
|
||||
first=_key_start_char,
|
||||
rest=st.text(alphabet=_key_rest_char, min_size=0, max_size=30),
|
||||
)
|
||||
|
||||
# 值:不含换行符的可打印字符串(排除引号以避免解析歧义)
|
||||
_env_value_st = st.text(
|
||||
alphabet=st.characters(
|
||||
whitelist_categories=("L", "N", "P", "S"),
|
||||
blacklist_characters='\n\r"\'#',
|
||||
),
|
||||
min_size=0,
|
||||
max_size=50,
|
||||
)
|
||||
|
||||
# 敏感键:在随机键名中嵌入敏感关键词
|
||||
_sensitive_keyword_st = st.sampled_from(list(_SENSITIVE_KEYWORDS))
|
||||
|
||||
_sensitive_key_st = st.builds(
|
||||
lambda prefix, kw, suffix: prefix + kw + suffix,
|
||||
prefix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
|
||||
kw=_sensitive_keyword_st,
|
||||
suffix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
|
||||
).filter(lambda k: len(k) > 0 and k[0].isalpha() or k[0] == "_")
|
||||
|
||||
# 确保敏感键以字母或下划线开头
|
||||
_safe_sensitive_key_st = st.builds(
|
||||
lambda prefix, kw: prefix + "_" + kw,
|
||||
prefix=st.sampled_from(["DB", "API", "ETL", "APP", "MY"]),
|
||||
kw=_sensitive_keyword_st,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 15: .env 解析与敏感值掩码
|
||||
# **Validates: Requirements 6.1, 6.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
sensitive_keys=st.lists(_safe_sensitive_key_st, min_size=1, max_size=5, unique=True),
|
||||
sensitive_values=st.lists(
|
||||
st.text(min_size=1, max_size=30, alphabet=st.characters(
|
||||
whitelist_categories=("L", "N"),
|
||||
)),
|
||||
min_size=1, max_size=5,
|
||||
),
|
||||
normal_keys=st.lists(_env_key_st, min_size=1, max_size=5, unique=True),
|
||||
normal_values=st.lists(_env_value_st, min_size=1, max_size=5),
|
||||
)
|
||||
def test_sensitive_values_masked(sensitive_keys, sensitive_values, normal_keys, normal_values):
|
||||
"""Property 15: .env 解析与敏感值掩码。
|
||||
|
||||
包含敏感键(PASSWORD、TOKEN、SECRET、DSN)的 .env 文件内容,
|
||||
API 返回的键值对列表中这些键的值应被掩码替换,不包含原始敏感值。
|
||||
"""
|
||||
# 确保敏感键和普通键不重叠
|
||||
normal_keys_filtered = [k for k in normal_keys if k not in sensitive_keys]
|
||||
assume(len(normal_keys_filtered) >= 1)
|
||||
|
||||
# 对齐列表长度
|
||||
s_vals = (sensitive_values * ((len(sensitive_keys) // len(sensitive_values)) + 1))[:len(sensitive_keys)]
|
||||
n_vals = (normal_values * ((len(normal_keys_filtered) // len(normal_values)) + 1))[:len(normal_keys_filtered)]
|
||||
|
||||
# 构造 .env 内容
|
||||
lines = []
|
||||
for k, v in zip(sensitive_keys, s_vals):
|
||||
lines.append(f"{k}={v}")
|
||||
for k, v in zip(normal_keys_filtered, n_vals):
|
||||
lines.append(f"{k}={v}")
|
||||
env_content = "\n".join(lines) + "\n"
|
||||
|
||||
# 解析
|
||||
parsed = _parse_env(env_content)
|
||||
entries = [line for line in parsed if line["type"] == "entry"]
|
||||
|
||||
# 模拟 GET 端点的掩码逻辑
|
||||
masked_entries = {}
|
||||
for entry in entries:
|
||||
if _is_sensitive(entry["key"]):
|
||||
masked_entries[entry["key"]] = _MASK
|
||||
else:
|
||||
masked_entries[entry["key"]] = entry["value"]
|
||||
|
||||
# 验证:敏感键的值应被掩码
|
||||
for k, v in zip(sensitive_keys, s_vals):
|
||||
assert k in masked_entries, f"敏感键 {k} 应出现在解析结果中"
|
||||
assert masked_entries[k] == _MASK, (
|
||||
f"敏感键 {k} 的值应为掩码 '{_MASK}',实际为 '{masked_entries[k]}'"
|
||||
)
|
||||
# 原始敏感值不应出现在掩码后的结果中
|
||||
assert masked_entries[k] != v, (
|
||||
f"敏感键 {k} 的原始值 '{v}' 不应出现在掩码结果中"
|
||||
)
|
||||
|
||||
# 验证:非敏感键的值应保持原样
|
||||
for k, v in zip(normal_keys_filtered, n_vals):
|
||||
if not _is_sensitive(k):
|
||||
assert k in masked_entries, f"普通键 {k} 应出现在解析结果中"
|
||||
assert masked_entries[k] == v, (
|
||||
f"普通键 {k} 的值应为 '{v}',实际为 '{masked_entries[k]}'"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 16: .env 写入往返一致性
|
||||
# **Validates: Requirements 6.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
entries=st.lists(
|
||||
st.tuples(_env_key_st, _env_value_st),
|
||||
min_size=1,
|
||||
max_size=10,
|
||||
unique_by=lambda t: t[0], # 键唯一
|
||||
),
|
||||
)
|
||||
def test_env_write_read_round_trip(entries):
|
||||
"""Property 16: .env 写入往返一致性。
|
||||
|
||||
有效的键值对集合(不含注释和空行),写入 .env 文件后再读取解析,
|
||||
应得到与原始集合等价的键值对。
|
||||
"""
|
||||
# 过滤掉值中可能导致解析歧义的情况(值前后空白会被 strip)
|
||||
clean_entries = [(k, v.strip()) for k, v in entries]
|
||||
# 排除空键(策略已保证非空,但防御性检查)
|
||||
clean_entries = [(k, v) for k, v in clean_entries if k]
|
||||
|
||||
assume(len(clean_entries) >= 1)
|
||||
|
||||
# 模拟写入:构造 .env 文件内容
|
||||
lines = [f"{k}={v}" for k, v in clean_entries]
|
||||
env_content = "\n".join(lines) + "\n"
|
||||
|
||||
# 解析
|
||||
parsed = _parse_env(env_content)
|
||||
parsed_entries = {
|
||||
line["key"]: line["value"]
|
||||
for line in parsed
|
||||
if line["type"] == "entry"
|
||||
}
|
||||
|
||||
# 验证往返一致性:每个写入的键值对都应在解析结果中
|
||||
for k, v in clean_entries:
|
||||
assert k in parsed_entries, (
|
||||
f"键 '{k}' 应出现在解析结果中,实际键集合: {list(parsed_entries.keys())}"
|
||||
)
|
||||
assert parsed_entries[k] == v, (
|
||||
f"键 '{k}' 的值不一致:写入='{v}',解析='{parsed_entries[k]}'"
|
||||
)
|
||||
|
||||
# 验证:解析结果的键数量应与写入的一致
|
||||
assert len(parsed_entries) == len(clean_entries), (
|
||||
f"解析结果键数量 {len(parsed_entries)} 应等于写入数量 {len(clean_entries)}"
|
||||
)
|
||||
291
apps/backend/tests/test_env_config_router.py
Normal file
291
apps/backend/tests/test_env_config_router.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置路由单元测试
|
||||
|
||||
覆盖 3 个端点:GET / PUT / GET /export
|
||||
通过 mock 绕过文件 I/O,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
# 模拟 .env 文件内容
|
||||
_SAMPLE_ENV = """\
|
||||
# 数据库配置
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_PASSWORD=super_secret_123
|
||||
JWT_SECRET_KEY=my-jwt-secret
|
||||
|
||||
# ETL 配置
|
||||
ETL_DB_DSN=postgresql://user:pass@host/db
|
||||
TIMEZONE=Asia/Shanghai
|
||||
"""
|
||||
|
||||
_MOCK_ENV_PATH = "app.routers.env_config._ENV_PATH"
|
||||
|
||||
|
||||
def _mock_path(content: str | None = _SAMPLE_ENV, exists: bool = True):
|
||||
"""构造 mock Path 对象。"""
|
||||
mock = MagicMock()
|
||||
mock.exists.return_value = exists
|
||||
if content is not None:
|
||||
mock.read_text.return_value = content
|
||||
return mock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/env-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetEnvConfig:
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_returns_entries_with_masked_sensitive(self, mock_path_obj):
|
||||
mock_path_obj.__class__ = type(MagicMock())
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = _SAMPLE_ENV
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
entries = {e["key"]: e["value"] for e in data["entries"]}
|
||||
|
||||
# 非敏感值原样返回
|
||||
assert entries["DB_HOST"] == "localhost"
|
||||
assert entries["DB_PORT"] == "5432"
|
||||
assert entries["TIMEZONE"] == "Asia/Shanghai"
|
||||
|
||||
# 敏感值掩码
|
||||
assert entries["DB_PASSWORD"] == "****"
|
||||
assert entries["JWT_SECRET_KEY"] == "****"
|
||||
assert entries["ETL_DB_DSN"] == "****"
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_file_not_found(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = False
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_empty_file(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = ""
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries"] == []
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_comments_and_blank_lines_excluded(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "# comment\n\nKEY=val\n"
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 200
|
||||
entries = resp.json()["entries"]
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["key"] == "KEY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/env-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateEnvConfig:
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_update_existing_key(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_HOST=localhost\nDB_PORT=5432\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_HOST", "value": "192.168.1.1"},
|
||||
{"key": "DB_PORT", "value": "5433"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# 验证写入内容
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "DB_HOST=192.168.1.1" in written
|
||||
assert "DB_PORT=5433" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_add_new_key(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_HOST=localhost\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_HOST", "value": "localhost"},
|
||||
{"key": "NEW_KEY", "value": "new_value"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "NEW_KEY=new_value" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_masked_value_preserves_original(self, mock_path_obj):
|
||||
"""掩码值(****)不应覆盖原始敏感值。"""
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_PASSWORD=real_secret\nDB_HOST=localhost\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_PASSWORD", "value": "****"},
|
||||
{"key": "DB_HOST", "value": "newhost"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
# 原始密码应保留
|
||||
assert "DB_PASSWORD=real_secret" in written
|
||||
assert "DB_HOST=newhost" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_preserves_comments(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "# 注释行\nDB_HOST=localhost\n\n# 另一个注释\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "DB_HOST", "value": "newhost"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "# 注释行" in written
|
||||
assert "# 另一个注释" in written
|
||||
|
||||
def test_invalid_key_format(self):
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "123BAD", "value": "val"}]
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_empty_key(self):
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "", "value": "val"}]
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_file_not_exists_creates_new(self, mock_path_obj):
|
||||
"""文件不存在时,应创建新文件。"""
|
||||
mock_path_obj.exists.return_value = False
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "NEW_KEY", "value": "value"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "NEW_KEY=value" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_update_sensitive_with_new_value(self, mock_path_obj):
|
||||
"""显式提供新密码时应更新。"""
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_PASSWORD=old_secret\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "DB_PASSWORD", "value": "new_secret"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "DB_PASSWORD=new_secret" in written
|
||||
|
||||
# 返回值中敏感键仍然掩码
|
||||
entries = {e["key"]: e["value"] for e in resp.json()["entries"]}
|
||||
assert entries["DB_PASSWORD"] == "****"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/env-config/export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExportEnvConfig:
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_export_masks_sensitive(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = _SAMPLE_ENV
|
||||
|
||||
resp = client.get("/api/env-config/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/plain")
|
||||
assert "attachment" in resp.headers.get("content-disposition", "")
|
||||
|
||||
content = resp.text
|
||||
# 非敏感值保留
|
||||
assert "DB_HOST=localhost" in content
|
||||
assert "TIMEZONE=Asia/Shanghai" in content
|
||||
|
||||
# 敏感值掩码
|
||||
assert "super_secret_123" not in content
|
||||
assert "my-jwt-secret" not in content
|
||||
assert "DB_PASSWORD=****" in content
|
||||
assert "JWT_SECRET_KEY=****" in content
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_export_preserves_comments(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = _SAMPLE_ENV
|
||||
|
||||
content = client.get("/api/env-config/export").text
|
||||
assert "# 数据库配置" in content
|
||||
assert "# ETL 配置" in content
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_export_file_not_found(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = False
|
||||
|
||||
resp = client.get("/api/env-config/export")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnvConfigAuth:
|
||||
|
||||
def test_requires_auth(self):
|
||||
"""移除 auth override 后,所有端点应返回 401/403。"""
|
||||
# 临时移除 override
|
||||
original = app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
for method, url in [
|
||||
("GET", "/api/env-config"),
|
||||
("PUT", "/api/env-config"),
|
||||
("GET", "/api/env-config/export"),
|
||||
]:
|
||||
resp = client.request(method, url)
|
||||
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
|
||||
finally:
|
||||
if original:
|
||||
app.dependency_overrides[get_current_user] = original
|
||||
246
apps/backend/tests/test_etl_status_router.py
Normal file
246
apps/backend/tests/test_etl_status_router.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 状态路由单元测试
|
||||
|
||||
覆盖 2 个端点:
|
||||
- GET /api/etl-status/cursors
|
||||
- GET /api/etl-status/recent-runs
|
||||
|
||||
通过 mock 绕过数据库连接,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_MOCK_ETL_CONN = "app.routers.etl_status.get_etl_readonly_connection"
|
||||
_MOCK_APP_CONN = "app.routers.etl_status.get_connection"
|
||||
|
||||
|
||||
def _make_mock_conn(rows):
|
||||
"""构造 mock 数据库连接,cursor 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = rows
|
||||
mock_cur.fetchone.return_value = None
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cur
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/etl-status/cursors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListCursors:
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_returns_cursor_list(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("ODS_FETCH_ORDERS", "2024-06-15 10:30:00+08", 1500),
|
||||
("ODS_FETCH_MEMBERS", "2024-06-15 09:00:00+08", 800),
|
||||
])
|
||||
# fetchone 用于 EXISTS 检查
|
||||
cur.fetchone.return_value = (True,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["task_code"] == "ODS_FETCH_ORDERS"
|
||||
assert data[0]["last_fetch_time"] == "2024-06-15 10:30:00+08"
|
||||
assert data[0]["record_count"] == 1500
|
||||
assert data[1]["task_code"] == "ODS_FETCH_MEMBERS"
|
||||
|
||||
# 验证 site_id 传递
|
||||
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_table_not_exists_returns_empty(self, mock_get_conn):
|
||||
"""etl_admin.etl_cursor 表不存在时返回空列表。"""
|
||||
conn, cur = _make_mock_conn([])
|
||||
cur.fetchone.return_value = (False,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_null_fields(self, mock_get_conn):
|
||||
"""游标字段可能为 None(任务从未执行过)。"""
|
||||
conn, cur = _make_mock_conn([
|
||||
("ODS_FETCH_INVENTORY", None, None),
|
||||
])
|
||||
cur.fetchone.return_value = (True,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data[0]["task_code"] == "ODS_FETCH_INVENTORY"
|
||||
assert data[0]["last_fetch_time"] is None
|
||||
assert data[0]["record_count"] is None
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_empty_cursors(self, mock_get_conn):
|
||||
"""表存在但无数据。"""
|
||||
conn, cur = _make_mock_conn([])
|
||||
cur.fetchone.return_value = (True,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/etl-status/recent-runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListRecentRuns:
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_returns_recent_runs(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000001",
|
||||
["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"],
|
||||
"success",
|
||||
"2024-06-15 10:30:00+08",
|
||||
"2024-06-15 10:35:00+08",
|
||||
300000,
|
||||
0,
|
||||
),
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000002",
|
||||
["DWS_AGGREGATE"],
|
||||
"failed",
|
||||
"2024-06-15 09:00:00+08",
|
||||
"2024-06-15 09:01:00+08",
|
||||
60000,
|
||||
1,
|
||||
),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
|
||||
run0 = data[0]
|
||||
assert run0["id"] == "a1b2c3d4-0000-0000-0000-000000000001"
|
||||
assert run0["task_codes"] == ["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"]
|
||||
assert run0["status"] == "success"
|
||||
assert run0["duration_ms"] == 300000
|
||||
assert run0["exit_code"] == 0
|
||||
|
||||
run1 = data[1]
|
||||
assert run1["status"] == "failed"
|
||||
assert run1["exit_code"] == 1
|
||||
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_empty_runs(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_null_optional_fields(self, mock_get_conn):
|
||||
"""正在执行的任务 finished_at / duration_ms / exit_code 为 None。"""
|
||||
conn, cur = _make_mock_conn([
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000003",
|
||||
["ODS_FETCH_MEMBERS"],
|
||||
"running",
|
||||
"2024-06-15 11:00:00+08",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data[0]["status"] == "running"
|
||||
assert data[0]["finished_at"] is None
|
||||
assert data[0]["duration_ms"] is None
|
||||
assert data[0]["exit_code"] is None
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_site_id_filter(self, mock_get_conn):
|
||||
"""验证查询时传入了正确的 site_id 参数。"""
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
client.get("/api/etl-status/recent-runs")
|
||||
|
||||
# 验证 SQL 中传入了 site_id 和 limit
|
||||
call_args = cur.execute.call_args
|
||||
params = call_args[0][1]
|
||||
assert params[0] == _TEST_USER.site_id
|
||||
assert params[1] == 50
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_empty_task_codes(self, mock_get_conn):
|
||||
"""task_codes 为 None 时应返回空列表。"""
|
||||
conn, cur = _make_mock_conn([
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000004",
|
||||
None,
|
||||
"pending",
|
||||
"2024-06-15 12:00:00+08",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()[0]["task_codes"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEtlStatusAuth:
|
||||
|
||||
def test_requires_auth(self):
|
||||
"""移除 auth override 后,所有端点应返回 401/403。"""
|
||||
original = app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
for url in ["/api/etl-status/cursors", "/api/etl-status/recent-runs"]:
|
||||
resp = client.get(url)
|
||||
assert resp.status_code in (401, 403), f"GET {url} 应需要认证"
|
||||
finally:
|
||||
if original:
|
||||
app.dependency_overrides[get_current_user] = original
|
||||
339
apps/backend/tests/test_execution_router.py
Normal file
339
apps/backend/tests/test_execution_router.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""执行与队列路由单元测试
|
||||
|
||||
覆盖 8 个端点:run / queue CRUD / cancel / history / logs
|
||||
通过 mock 绕过数据库和服务层,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.services.task_queue import QueuedTask
|
||||
|
||||
# 固定测试用户
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# 构造测试用的 TaskConfig payload
|
||||
_VALID_CONFIG = {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunTask:
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_run_returns_execution_id(self, mock_executor):
|
||||
mock_executor.execute = AsyncMock()
|
||||
resp = client.post("/api/execution/run", json=_VALID_CONFIG)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "execution_id" in data
|
||||
assert data["message"] == "任务已提交执行"
|
||||
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_run_injects_store_id(self, mock_executor):
|
||||
"""store_id 应从 JWT 注入"""
|
||||
mock_executor.execute = AsyncMock()
|
||||
resp = client.post("/api/execution/run", json={
|
||||
**_VALID_CONFIG,
|
||||
"store_id": 999, # 前端传的值应被覆盖
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_run_requires_auth(self):
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
resp = client.post("/api/execution/run", json=_VALID_CONFIG)
|
||||
assert resp.status_code in (401, 403)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
|
||||
def test_run_invalid_config_returns_422(self):
|
||||
"""缺少必填字段 tasks 时返回 422"""
|
||||
resp = client.post("/api/execution/run", json={"pipeline": "api_ods"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetQueue:
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_get_queue_returns_list(self, mock_queue):
|
||||
mock_queue.list_pending.return_value = [
|
||||
QueuedTask(
|
||||
id="task-1", site_id=100, config={"tasks": ["ODS_MEMBER"]},
|
||||
status="pending", position=1, created_at=_NOW,
|
||||
),
|
||||
]
|
||||
resp = client.get("/api/execution/queue")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "task-1"
|
||||
assert data[0]["status"] == "pending"
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_get_queue_empty(self, mock_queue):
|
||||
mock_queue.list_pending.return_value = []
|
||||
resp = client.get("/api/execution/queue")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_get_queue_filters_by_site_id(self, mock_queue):
|
||||
"""确认调用 list_pending 时传入了正确的 site_id"""
|
||||
mock_queue.list_pending.return_value = []
|
||||
client.get("/api/execution/queue")
|
||||
mock_queue.list_pending.assert_called_once_with(100)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnqueueTask:
|
||||
@patch("app.routers.execution.get_connection")
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
|
||||
mock_queue.enqueue.return_value = "new-task-id"
|
||||
|
||||
# mock 数据库查询返回
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = (
|
||||
"new-task-id", 100, '{"tasks": ["ODS_MEMBER"]}',
|
||||
"pending", 1, _NOW, None, None, None, None,
|
||||
)
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "new-task-id"
|
||||
assert data["status"] == "pending"
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_enqueue_calls_with_site_id(self, mock_queue):
|
||||
"""确认 enqueue 时传入了 JWT 的 site_id"""
|
||||
mock_queue.enqueue.return_value = "id-1"
|
||||
|
||||
# 让后续的 DB 查询抛异常来快速结束(enqueue 本身已验证)
|
||||
with patch("app.routers.execution.get_connection") as mock_conn:
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = (
|
||||
"id-1", 100, '{"tasks": []}', "pending", 1,
|
||||
_NOW, None, None, None, None,
|
||||
)
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.return_value = conn
|
||||
|
||||
client.post("/api/execution/queue", json=_VALID_CONFIG)
|
||||
|
||||
# 验证 enqueue 的第二个参数是 site_id=100
|
||||
call_args = mock_queue.enqueue.call_args
|
||||
assert call_args[0][1] == 100 # site_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/execution/queue/reorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReorderQueue:
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_reorder_success(self, mock_queue):
|
||||
mock_queue.reorder.return_value = None
|
||||
resp = client.put("/api/execution/queue/reorder", json={
|
||||
"task_id": "task-1",
|
||||
"new_position": 3,
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
mock_queue.reorder.assert_called_once_with("task-1", 3, 100)
|
||||
|
||||
def test_reorder_missing_fields_returns_422(self):
|
||||
resp = client.put("/api/execution/queue/reorder", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/execution/queue/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteQueueTask:
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_delete_success(self, mock_queue):
|
||||
mock_queue.delete.return_value = True
|
||||
resp = client.delete("/api/execution/queue/task-1")
|
||||
assert resp.status_code == 200
|
||||
mock_queue.delete.assert_called_once_with("task-1", 100)
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_delete_nonexistent_returns_409(self, mock_queue):
|
||||
mock_queue.delete.return_value = False
|
||||
resp = client.delete("/api/execution/queue/nonexistent")
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/{id}/cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCancelExecution:
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_cancel_success(self, mock_executor):
|
||||
mock_executor.cancel = AsyncMock(return_value=True)
|
||||
resp = client.post("/api/execution/some-id/cancel")
|
||||
assert resp.status_code == 200
|
||||
assert "取消" in resp.json()["message"]
|
||||
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_cancel_not_found(self, mock_executor):
|
||||
mock_executor.cancel = AsyncMock(return_value=False)
|
||||
resp = client.post("/api/execution/nonexistent/cancel")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/history
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecutionHistory:
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_history_returns_list(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = [
|
||||
(
|
||||
"exec-1", 100, ["ODS_MEMBER"], "success",
|
||||
_NOW, _NOW, 0, 1234, "python -m cli.main", None,
|
||||
),
|
||||
]
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/history")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "exec-1"
|
||||
assert data[0]["status"] == "success"
|
||||
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_history_respects_limit(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = []
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/history?limit=10")
|
||||
assert resp.status_code == 200
|
||||
|
||||
# 验证 SQL 中传入了 limit=10
|
||||
call_args = mock_cursor.execute.call_args
|
||||
assert call_args[0][1] == (100, 10) # (site_id, limit)
|
||||
|
||||
def test_history_limit_validation(self):
|
||||
"""limit 超出范围时返回 422"""
|
||||
resp = client.get("/api/execution/history?limit=0")
|
||||
assert resp.status_code == 422
|
||||
resp = client.get("/api/execution/history?limit=999")
|
||||
assert resp.status_code == 422
|
||||
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_history_empty(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = []
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/history")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/{id}/logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecutionLogs:
|
||||
@patch("app.routers.execution.get_connection")
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_logs_from_db(self, mock_executor, mock_get_conn):
|
||||
"""已完成任务从数据库读取日志"""
|
||||
mock_executor.is_running.return_value = False
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = ("stdout output", "stderr output")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/exec-1/logs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["execution_id"] == "exec-1"
|
||||
assert data["output_log"] == "stdout output"
|
||||
assert data["error_log"] == "stderr output"
|
||||
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_logs_from_memory(self, mock_executor):
|
||||
"""执行中的任务从内存缓冲区读取"""
|
||||
mock_executor.is_running.return_value = True
|
||||
mock_executor.get_logs.return_value = ["line1", "line2"]
|
||||
|
||||
resp = client.get("/api/execution/running-id/logs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["execution_id"] == "running-id"
|
||||
assert "line1" in data["output_log"]
|
||||
assert "line2" in data["output_log"]
|
||||
|
||||
@patch("app.routers.execution.get_connection")
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_logs_not_found(self, mock_executor, mock_get_conn):
|
||||
mock_executor.is_running.return_value = False
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = None
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/nonexistent/logs")
|
||||
assert resp.status_code == 404
|
||||
510
apps/backend/tests/test_queue_properties.py
Normal file
510
apps/backend/tests/test_queue_properties.py
Normal file
@@ -0,0 +1,510 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""队列属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证队列管理的通用正确性属性:
|
||||
- Property 8: 队列 CRUD 不变量
|
||||
- Property 9: 队列出队顺序
|
||||
- Property 10: 队列重排一致性
|
||||
- Property 11: 执行历史排序与限制
|
||||
|
||||
测试策略:
|
||||
- Property 8-10 通过内存模拟队列状态,mock 数据库操作,验证 TaskQueue 的核心逻辑
|
||||
- Property 11 通过 mock 数据库返回,验证执行历史端点的排序与限制逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-queue-properties")
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_queue import TaskQueue, QueuedTask
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
# 简单的任务代码列表
|
||||
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
|
||||
|
||||
_simple_config_st = st.builds(
|
||||
TaskConfigSchema,
|
||||
tasks=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
pipeline=st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd"]),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 内存队列模拟器 — 用于 mock 数据库交互
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InMemoryQueueDB:
|
||||
"""模拟 task_queue 表的内存存储,为 TaskQueue 方法提供 mock 数据库行为。"""
|
||||
|
||||
def __init__(self, site_id: int):
|
||||
self.site_id = site_id
|
||||
# 存储格式:{task_id: {config, status, position, ...}}
|
||||
self.rows: dict[str, dict] = {}
|
||||
|
||||
@property
|
||||
def pending_tasks(self) -> list[dict]:
|
||||
"""按 position 排序的 pending 任务列表。"""
|
||||
return sorted(
|
||||
[r for r in self.rows.values() if r["status"] == "pending"],
|
||||
key=lambda r: r["position"],
|
||||
)
|
||||
|
||||
def mock_enqueue_connection(self):
|
||||
"""为 enqueue 方法构造 mock connection。
|
||||
|
||||
enqueue 执行两条 SQL:
|
||||
1. SELECT COALESCE(MAX(position), 0) → 返回当前最大 position
|
||||
2. INSERT INTO task_queue → 插入新行
|
||||
"""
|
||||
pending = self.pending_tasks
|
||||
max_pos = max((r["position"] for r in pending), default=0)
|
||||
|
||||
call_count = [0]
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
executed_sqls = []
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
executed_sqls.append((sql, params))
|
||||
call_count[0] += 1
|
||||
if "MAX(position)" in sql:
|
||||
cur.fetchone.return_value = (max_pos,)
|
||||
elif "INSERT INTO task_queue" in sql:
|
||||
# 记录插入的行
|
||||
task_id, site_id, config_json, new_pos = params
|
||||
db.rows[task_id] = {
|
||||
"id": task_id,
|
||||
"site_id": site_id,
|
||||
"config": json.loads(config_json),
|
||||
"status": "pending",
|
||||
"position": new_pos,
|
||||
}
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_dequeue_connection(self):
|
||||
"""为 dequeue 方法构造 mock connection。
|
||||
|
||||
dequeue 执行两条 SQL:
|
||||
1. SELECT ... ORDER BY position ASC LIMIT 1 FOR UPDATE → 返回队首任务
|
||||
2. UPDATE ... SET status = 'running' → 更新状态
|
||||
"""
|
||||
pending = self.pending_tasks
|
||||
first = pending[0] if pending else None
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if "ORDER BY position ASC" in sql:
|
||||
if first:
|
||||
cur.fetchone.return_value = (
|
||||
first["id"], first["site_id"],
|
||||
json.dumps(first["config"]),
|
||||
first["status"], first["position"],
|
||||
None, None, None, None, None,
|
||||
)
|
||||
else:
|
||||
cur.fetchone.return_value = None
|
||||
elif "SET status = 'running'" in sql:
|
||||
if first:
|
||||
db.rows[first["id"]]["status"] = "running"
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_delete_connection(self, task_id: str):
|
||||
"""为 delete 方法构造 mock connection。"""
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
tid = params[0]
|
||||
if tid in db.rows and db.rows[tid]["status"] == "pending":
|
||||
del db.rows[tid]
|
||||
cur.rowcount = 1
|
||||
else:
|
||||
cur.rowcount = 0
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.rowcount = 0
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_reorder_connection(self):
|
||||
"""为 reorder 方法构造 mock connection。
|
||||
|
||||
reorder 执行:
|
||||
1. SELECT id FROM task_queue WHERE ... ORDER BY position ASC
|
||||
2. 多次 UPDATE task_queue SET position = %s WHERE id = %s
|
||||
"""
|
||||
pending = self.pending_tasks
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
call_idx = [0]
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if "SELECT id FROM task_queue" in sql:
|
||||
cur.fetchall.return_value = [(r["id"],) for r in pending]
|
||||
elif "UPDATE task_queue SET position" in sql:
|
||||
pos, tid = params
|
||||
if tid in db.rows:
|
||||
db.rows[tid]["position"] = pos
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_list_pending_connection(self):
|
||||
"""为 list_pending 方法构造 mock connection。"""
|
||||
pending = self.pending_tasks
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
cur.fetchall.return_value = [
|
||||
(
|
||||
r["id"], r["site_id"], json.dumps(r["config"]),
|
||||
r["status"], r["position"],
|
||||
None, None, None, None, None,
|
||||
)
|
||||
for r in pending
|
||||
]
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 8: 队列 CRUD 不变量
|
||||
# **Validates: Requirements 4.1, 4.4**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
config=_simple_config_st,
|
||||
site_id=_site_id_st,
|
||||
initial_count=st.integers(min_value=0, max_value=5),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_crud_invariant(mock_get_conn, config, site_id, initial_count):
|
||||
"""Property 8: 队列 CRUD 不变量。
|
||||
|
||||
入队一个任务后队列长度增加 1 且新任务状态为 pending;
|
||||
删除一个 pending 任务后队列长度减少 1 且该任务不再出现在队列中。
|
||||
"""
|
||||
queue = TaskQueue()
|
||||
db = InMemoryQueueDB(site_id)
|
||||
|
||||
# 预填充若干任务
|
||||
for i in range(initial_count):
|
||||
tid = str(uuid.uuid4())
|
||||
db.rows[tid] = {
|
||||
"id": tid,
|
||||
"site_id": site_id,
|
||||
"config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
|
||||
"status": "pending",
|
||||
"position": i + 1,
|
||||
}
|
||||
|
||||
before_count = len(db.pending_tasks)
|
||||
|
||||
# --- 入队 ---
|
||||
mock_get_conn.return_value = db.mock_enqueue_connection()
|
||||
new_id = queue.enqueue(config, site_id)
|
||||
|
||||
after_enqueue_count = len(db.pending_tasks)
|
||||
assert after_enqueue_count == before_count + 1, (
|
||||
f"入队后长度应 +1:期望 {before_count + 1},实际 {after_enqueue_count}"
|
||||
)
|
||||
assert new_id in db.rows, "新任务应存在于队列中"
|
||||
assert db.rows[new_id]["status"] == "pending", "新任务状态应为 pending"
|
||||
|
||||
# --- 删除刚入队的任务 ---
|
||||
mock_get_conn.return_value = db.mock_delete_connection(new_id)
|
||||
deleted = queue.delete(new_id, site_id)
|
||||
|
||||
after_delete_count = len(db.pending_tasks)
|
||||
assert deleted is True, "删除 pending 任务应返回 True"
|
||||
assert after_delete_count == before_count, (
|
||||
f"删除后长度应恢复:期望 {before_count},实际 {after_delete_count}"
|
||||
)
|
||||
assert new_id not in db.rows, "已删除任务不应出现在队列中"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 9: 队列出队顺序
|
||||
# **Validates: Requirements 4.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
num_tasks=st.integers(min_value=1, max_value=8),
|
||||
positions=st.data(),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_dequeue_order(mock_get_conn, site_id, num_tasks, positions):
|
||||
"""Property 9: 队列出队顺序。
|
||||
|
||||
包含多个 pending 任务的队列,dequeue 操作应返回 position 值最小的任务。
|
||||
"""
|
||||
queue = TaskQueue()
|
||||
db = InMemoryQueueDB(site_id)
|
||||
|
||||
# 生成不重复的 position 值
|
||||
pos_list = positions.draw(
|
||||
st.lists(
|
||||
st.integers(min_value=1, max_value=1000),
|
||||
min_size=num_tasks,
|
||||
max_size=num_tasks,
|
||||
unique=True,
|
||||
)
|
||||
)
|
||||
|
||||
# 填充队列
|
||||
task_ids = []
|
||||
for i, pos in enumerate(pos_list):
|
||||
tid = str(uuid.uuid4())
|
||||
task_ids.append(tid)
|
||||
db.rows[tid] = {
|
||||
"id": tid,
|
||||
"site_id": site_id,
|
||||
"config": {"tasks": [_task_codes[i % len(_task_codes)]], "pipeline": "api_ods"},
|
||||
"status": "pending",
|
||||
"position": pos,
|
||||
}
|
||||
|
||||
# 找出 position 最小的任务
|
||||
expected_first = min(db.pending_tasks, key=lambda r: r["position"])
|
||||
|
||||
# dequeue
|
||||
mock_get_conn.return_value = db.mock_dequeue_connection()
|
||||
result = queue.dequeue(site_id)
|
||||
|
||||
assert result is not None, "队列非空时 dequeue 不应返回 None"
|
||||
assert result.id == expected_first["id"], (
|
||||
f"应返回 position 最小的任务:期望 id={expected_first['id']} "
|
||||
f"(pos={expected_first['position']}),实际 id={result.id}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 10: 队列重排一致性
|
||||
# **Validates: Requirements 4.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
num_tasks=st.integers(min_value=2, max_value=6),
|
||||
data=st.data(),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_reorder_consistency(mock_get_conn, site_id, num_tasks, data):
|
||||
"""Property 10: 队列重排一致性。
|
||||
|
||||
重排操作(将任务移动到新位置)后,队列中任务的相对顺序应与请求一致:
|
||||
- 被移动的任务应出现在目标位置(clamp 到有效范围)
|
||||
- 其余任务保持原有相对顺序
|
||||
- 所有任务仍在队列中(不丢失)
|
||||
"""
|
||||
queue = TaskQueue()
|
||||
db = InMemoryQueueDB(site_id)
|
||||
|
||||
# 填充队列,position 从 1 开始连续编号
|
||||
task_ids = []
|
||||
for i in range(num_tasks):
|
||||
tid = str(uuid.uuid4())
|
||||
task_ids.append(tid)
|
||||
db.rows[tid] = {
|
||||
"id": tid,
|
||||
"site_id": site_id,
|
||||
"config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
|
||||
"status": "pending",
|
||||
"position": i + 1,
|
||||
}
|
||||
|
||||
# 随机选择要移动的任务和目标位置
|
||||
move_idx = data.draw(st.integers(min_value=0, max_value=num_tasks - 1))
|
||||
move_task_id = task_ids[move_idx]
|
||||
new_position = data.draw(st.integers(min_value=1, max_value=num_tasks + 2))
|
||||
|
||||
# 执行 reorder
|
||||
mock_get_conn.return_value = db.mock_reorder_connection()
|
||||
queue.reorder(move_task_id, new_position, site_id)
|
||||
|
||||
# 验证:所有任务仍在队列中
|
||||
remaining_ids = {r["id"] for r in db.rows.values() if r["status"] == "pending"}
|
||||
assert remaining_ids == set(task_ids), "重排后不应丢失任何任务"
|
||||
|
||||
# 验证:position 值连续且唯一(1-based)
|
||||
positions = sorted(r["position"] for r in db.pending_tasks)
|
||||
assert positions == list(range(1, num_tasks + 1)), (
|
||||
f"重排后 position 应为连续编号 1..{num_tasks},实际 {positions}"
|
||||
)
|
||||
|
||||
# 验证:被移动的任务在正确位置
|
||||
# reorder 内部逻辑:clamp new_position 到 [1, len(others)+1]
|
||||
clamped_pos = max(1, min(new_position, num_tasks))
|
||||
actual_pos = db.rows[move_task_id]["position"]
|
||||
assert actual_pos == clamped_pos, (
|
||||
f"被移动任务的 position 应为 {clamped_pos}(clamp 后),实际 {actual_pos}"
|
||||
)
|
||||
|
||||
# 验证:其余任务保持原有相对顺序
|
||||
others_before = [tid for tid in task_ids if tid != move_task_id]
|
||||
others_after = sorted(
|
||||
[r for r in db.pending_tasks if r["id"] != move_task_id],
|
||||
key=lambda r: r["position"],
|
||||
)
|
||||
others_after_ids = [r["id"] for r in others_after]
|
||||
assert others_after_ids == others_before, (
|
||||
"其余任务的相对顺序应保持不变"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 11: 执行历史排序与限制
|
||||
# **Validates: Requirements 4.5, 8.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 导入 FastAPI 测试客户端
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def _make_history_rows(count: int, site_id: int) -> list[tuple]:
|
||||
"""生成 count 条执行历史记录,started_at 随机但可排序。"""
|
||||
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
["ODS_MEMBER"], # task_codes
|
||||
"success", # status
|
||||
base_time + timedelta(hours=i), # started_at
|
||||
base_time + timedelta(hours=i, minutes=30), # finished_at
|
||||
0, # exit_code
|
||||
1800000, # duration_ms
|
||||
"python -m cli.main", # command
|
||||
None, # summary
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
total_records=st.integers(min_value=0, max_value=30),
|
||||
limit=st.integers(min_value=1, max_value=200),
|
||||
)
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_execution_history_sort_and_limit(mock_get_conn, site_id, total_records, limit):
|
||||
"""Property 11: 执行历史排序与限制。
|
||||
|
||||
执行历史记录集合,API 返回的结果应按 started_at 降序排列,
|
||||
且结果数量不超过请求的 limit 值。
|
||||
"""
|
||||
# 生成测试数据
|
||||
all_rows = _make_history_rows(total_records, site_id)
|
||||
|
||||
# 模拟数据库:按 started_at DESC 排序后取 limit 条
|
||||
sorted_rows = sorted(all_rows, key=lambda r: r[4], reverse=True)
|
||||
returned_rows = sorted_rows[:limit]
|
||||
|
||||
# mock 数据库连接
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = returned_rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
# 覆盖认证依赖
|
||||
test_user = CurrentUser(user_id=1, site_id=site_id)
|
||||
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||
|
||||
try:
|
||||
client = TestClient(app)
|
||||
# limit 必须在 [1, 200] 范围内(API 约束)
|
||||
clamped_limit = max(1, min(limit, 200))
|
||||
resp = client.get(f"/api/execution/history?limit={clamped_limit}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证 1:结果数量不超过 limit
|
||||
assert len(data) <= clamped_limit, (
|
||||
f"结果数量 {len(data)} 超过 limit {clamped_limit}"
|
||||
)
|
||||
|
||||
# 验证 2:结果数量不超过总记录数
|
||||
assert len(data) <= total_records, (
|
||||
f"结果数量 {len(data)} 超过总记录数 {total_records}"
|
||||
)
|
||||
|
||||
# 验证 3:按 started_at 降序排列
|
||||
if len(data) >= 2:
|
||||
for i in range(len(data) - 1):
|
||||
t1 = data[i]["started_at"]
|
||||
t2 = data[i + 1]["started_at"]
|
||||
assert t1 >= t2, (
|
||||
f"结果未按 started_at 降序排列:data[{i}]={t1} < data[{i+1}]={t2}"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
439
apps/backend/tests/test_schedule_properties.py
Normal file
439
apps/backend/tests/test_schedule_properties.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证调度管理的通用正确性属性:
|
||||
- Property 12: 调度任务 CRUD 往返
|
||||
- Property 13: 到期调度任务自动入队
|
||||
- Property 14: 调度任务启用/禁用状态
|
||||
|
||||
测试策略:
|
||||
- Property 12: 通过 mock 数据库,验证 POST 创建后 GET 返回的 schedule_config 与提交的一致
|
||||
- Property 13: 通过 mock 数据库返回到期任务,验证 check_and_enqueue 调用了 task_queue.enqueue
|
||||
- Property 14: 通过 mock 数据库,验证 toggle 端点的 next_run_at 行为
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-schedule-properties")
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.schemas.schedules import ScheduleConfigSchema
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.scheduler import Scheduler, calculate_next_run
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
|
||||
|
||||
_simple_task_config_st = st.fixed_dictionaries({
|
||||
"tasks": st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
"pipeline": st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd", "api_full"]),
|
||||
})
|
||||
|
||||
# 调度配置策略:覆盖 5 种调度类型
|
||||
_schedule_type_st = st.sampled_from(["once", "interval", "daily", "weekly", "cron"])
|
||||
|
||||
_interval_unit_st = st.sampled_from(["minutes", "hours", "days"])
|
||||
|
||||
# HH:MM 格式的时间字符串
|
||||
_time_str_st = st.builds(
|
||||
lambda h, m: f"{h:02d}:{m:02d}",
|
||||
h=st.integers(min_value=0, max_value=23),
|
||||
m=st.integers(min_value=0, max_value=59),
|
||||
)
|
||||
|
||||
# ISO weekday 列表(1=Monday ... 7=Sunday)
|
||||
_weekly_days_st = st.lists(
|
||||
st.integers(min_value=1, max_value=7),
|
||||
min_size=1, max_size=7, unique=True,
|
||||
)
|
||||
|
||||
# 简单 cron 表达式(minute hour * * *)
|
||||
_cron_st = st.builds(
|
||||
lambda m, h: f"{m} {h} * * *",
|
||||
m=st.integers(min_value=0, max_value=59),
|
||||
h=st.integers(min_value=0, max_value=23),
|
||||
)
|
||||
|
||||
|
||||
def _build_schedule_config(schedule_type, interval_value, interval_unit,
|
||||
daily_time, weekly_days, weekly_time, cron_expression):
|
||||
"""根据 schedule_type 构建 ScheduleConfigSchema。"""
|
||||
return ScheduleConfigSchema(
|
||||
schedule_type=schedule_type,
|
||||
interval_value=interval_value,
|
||||
interval_unit=interval_unit,
|
||||
daily_time=daily_time,
|
||||
weekly_days=weekly_days,
|
||||
weekly_time=weekly_time,
|
||||
cron_expression=cron_expression,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
_schedule_config_st = st.builds(
|
||||
_build_schedule_config,
|
||||
schedule_type=_schedule_type_st,
|
||||
interval_value=st.integers(min_value=1, max_value=168),
|
||||
interval_unit=_interval_unit_st,
|
||||
daily_time=_time_str_st,
|
||||
weekly_days=_weekly_days_st,
|
||||
weekly_time=_time_str_st,
|
||||
cron_expression=_cron_st,
|
||||
)
|
||||
|
||||
# 用于 Property 14 的非 once 调度配置(启用后 next_run_at 应非 NULL)
|
||||
_non_once_schedule_type_st = st.sampled_from(["interval", "daily", "weekly", "cron"])
|
||||
|
||||
_non_once_schedule_config_st = st.builds(
|
||||
_build_schedule_config,
|
||||
schedule_type=_non_once_schedule_type_st,
|
||||
interval_value=st.integers(min_value=1, max_value=168),
|
||||
interval_unit=_interval_unit_st,
|
||||
daily_time=_time_str_st,
|
||||
weekly_days=_weekly_days_st,
|
||||
weekly_time=_time_str_st,
|
||||
cron_expression=_cron_st,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NOW = datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# 模拟数据库行的列顺序(与 _SELECT_COLS 对应,共 13 列)
|
||||
# id, site_id, name, task_codes, task_config, schedule_config,
|
||||
# enabled, last_run_at, next_run_at, run_count, last_status,
|
||||
# created_at, updated_at
|
||||
|
||||
|
||||
def _make_db_row(
|
||||
schedule_id: str,
|
||||
site_id: int,
|
||||
name: str,
|
||||
task_codes: list[str],
|
||||
task_config: dict,
|
||||
schedule_config: dict,
|
||||
enabled: bool = True,
|
||||
next_run_at: datetime | None = None,
|
||||
) -> tuple:
|
||||
"""构造模拟数据库行。"""
|
||||
return (
|
||||
schedule_id, site_id, name, task_codes,
|
||||
json.dumps(task_config) if isinstance(task_config, dict) else task_config,
|
||||
json.dumps(schedule_config) if isinstance(schedule_config, dict) else schedule_config,
|
||||
enabled, None, next_run_at, 0, None, _NOW, _NOW,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 12: 调度任务 CRUD 往返
|
||||
# **Validates: Requirements 5.1, 5.4**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
schedule_config=_schedule_config_st,
|
||||
task_config=_simple_task_config_st,
|
||||
name=st.text(min_size=1, max_size=50, alphabet=st.characters(
|
||||
whitelist_categories=("L", "N"), whitelist_characters="_- "
|
||||
)),
|
||||
task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_schedule_crud_round_trip(
|
||||
mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
|
||||
):
|
||||
"""Property 12: 调度任务 CRUD 往返。
|
||||
|
||||
有效的 ScheduleConfigSchema,创建调度任务后再查询该任务,
|
||||
返回的调度配置应与创建时提交的配置等价。
|
||||
"""
|
||||
schedule_config_dict = schedule_config.model_dump()
|
||||
next_run = calculate_next_run(schedule_config, _NOW)
|
||||
|
||||
# 构造创建后数据库返回的行
|
||||
created_row = _make_db_row(
|
||||
schedule_id="test-sched-id",
|
||||
site_id=site_id,
|
||||
name=name,
|
||||
task_codes=task_codes,
|
||||
task_config=task_config,
|
||||
schedule_config=schedule_config_dict,
|
||||
enabled=schedule_config.enabled,
|
||||
next_run_at=next_run,
|
||||
)
|
||||
|
||||
# --- 创建阶段 ---
|
||||
# mock POST 的数据库连接(INSERT ... RETURNING)
|
||||
create_cursor = MagicMock()
|
||||
create_cursor.fetchone.return_value = created_row
|
||||
create_conn = MagicMock()
|
||||
create_conn.cursor.return_value.__enter__ = MagicMock(return_value=create_cursor)
|
||||
create_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# --- 查询阶段 ---
|
||||
# mock GET 的数据库连接(SELECT ... fetchall)
|
||||
list_cursor = MagicMock()
|
||||
list_cursor.fetchall.return_value = [created_row]
|
||||
list_conn = MagicMock()
|
||||
list_conn.cursor.return_value.__enter__ = MagicMock(return_value=list_cursor)
|
||||
list_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# 依次返回 create_conn 和 list_conn
|
||||
mock_get_conn.side_effect = [create_conn, list_conn]
|
||||
|
||||
# 覆盖认证
|
||||
test_user = CurrentUser(user_id=1, site_id=site_id)
|
||||
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||
|
||||
try:
|
||||
client = TestClient(app)
|
||||
|
||||
# 创建调度任务
|
||||
create_body = {
|
||||
"name": name,
|
||||
"task_codes": task_codes,
|
||||
"task_config": task_config,
|
||||
"schedule_config": schedule_config_dict,
|
||||
}
|
||||
create_resp = client.post("/api/schedules", json=create_body)
|
||||
assert create_resp.status_code == 201, (
|
||||
f"创建应返回 201,实际 {create_resp.status_code}: {create_resp.text}"
|
||||
)
|
||||
created_data = create_resp.json()
|
||||
|
||||
# 查询调度任务列表
|
||||
list_resp = client.get("/api/schedules")
|
||||
assert list_resp.status_code == 200
|
||||
list_data = list_resp.json()
|
||||
|
||||
assert len(list_data) >= 1, "查询结果应至少包含刚创建的任务"
|
||||
|
||||
# 找到刚创建的任务
|
||||
found = next((s for s in list_data if s["id"] == created_data["id"]), None)
|
||||
assert found is not None, "查询结果应包含刚创建的任务"
|
||||
|
||||
# 核心验证:schedule_config 往返一致
|
||||
returned_config = found["schedule_config"]
|
||||
for key in schedule_config_dict:
|
||||
assert returned_config[key] == schedule_config_dict[key], (
|
||||
f"schedule_config.{key} 不一致:"
|
||||
f"提交={schedule_config_dict[key]},返回={returned_config[key]}"
|
||||
)
|
||||
|
||||
# 验证 task_config 往返一致
|
||||
returned_task_config = found["task_config"]
|
||||
for key in task_config:
|
||||
assert returned_task_config[key] == task_config[key], (
|
||||
f"task_config.{key} 不一致:提交={task_config[key]},返回={returned_task_config[key]}"
|
||||
)
|
||||
|
||||
# 验证基本字段
|
||||
assert found["name"] == name
|
||||
assert found["task_codes"] == task_codes
|
||||
assert found["site_id"] == site_id
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 13: 到期调度任务自动入队
|
||||
# **Validates: Requirements 5.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
schedule_config=_schedule_config_st,
|
||||
task_config=_simple_task_config_st,
|
||||
)
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
def test_due_schedule_auto_enqueue(
|
||||
mock_get_conn, mock_tq, site_id, schedule_config, task_config,
|
||||
):
|
||||
"""Property 13: 到期调度任务自动入队。
|
||||
|
||||
enabled 为 true 且 next_run_at 早于当前时间的调度任务,
|
||||
check_and_enqueue 执行后该任务的 TaskConfig 应出现在执行队列中。
|
||||
"""
|
||||
sched = Scheduler()
|
||||
schedule_config_dict = schedule_config.model_dump()
|
||||
|
||||
# 构造到期任务:next_run_at 在过去(比 now 早 5 分钟)
|
||||
task_id = "due-task-001"
|
||||
|
||||
# --- mock SELECT 到期任务 ---
|
||||
select_cursor = MagicMock()
|
||||
select_cursor.fetchall.return_value = [
|
||||
(task_id, site_id, json.dumps(task_config), json.dumps(schedule_config_dict)),
|
||||
]
|
||||
select_cursor.__enter__ = MagicMock(return_value=select_cursor)
|
||||
select_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# --- mock UPDATE 调度状态 ---
|
||||
update_cursor = MagicMock()
|
||||
update_cursor.__enter__ = MagicMock(return_value=update_cursor)
|
||||
update_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.side_effect = [select_cursor, update_cursor]
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
mock_tq.enqueue.return_value = "queue-id-123"
|
||||
|
||||
# 执行
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
# 验证:到期任务被入队
|
||||
assert count == 1, f"应有 1 个任务入队,实际 {count}"
|
||||
mock_tq.enqueue.assert_called_once()
|
||||
|
||||
# 验证入队参数
|
||||
call_args = mock_tq.enqueue.call_args
|
||||
enqueued_config = call_args[0][0]
|
||||
enqueued_site_id = call_args[0][1]
|
||||
|
||||
# site_id 应匹配
|
||||
assert enqueued_site_id == site_id, (
|
||||
f"入队的 site_id 应为 {site_id},实际 {enqueued_site_id}"
|
||||
)
|
||||
|
||||
# TaskConfig 应与原始配置一致
|
||||
assert isinstance(enqueued_config, TaskConfigSchema)
|
||||
assert enqueued_config.tasks == task_config["tasks"], (
|
||||
f"入队的 tasks 应为 {task_config['tasks']},实际 {enqueued_config.tasks}"
|
||||
)
|
||||
assert enqueued_config.pipeline == task_config["pipeline"], (
|
||||
f"入队的 pipeline 应为 {task_config['pipeline']},实际 {enqueued_config.pipeline}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 14: 调度任务启用/禁用状态
|
||||
# **Validates: Requirements 5.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
schedule_config=_non_once_schedule_config_st,
|
||||
task_config=_simple_task_config_st,
|
||||
name=st.text(min_size=1, max_size=30, alphabet=st.characters(
|
||||
whitelist_categories=("L", "N"), whitelist_characters="_- "
|
||||
)),
|
||||
task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_schedule_toggle_next_run(
|
||||
mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
|
||||
):
|
||||
"""Property 14: 调度任务启用/禁用状态。
|
||||
|
||||
禁用后 next_run_at 应为 NULL;
|
||||
重新启用后 next_run_at 应被重新计算为非 NULL 值(对于非一次性调度)。
|
||||
"""
|
||||
schedule_config_dict = schedule_config.model_dump()
|
||||
next_run_enabled = calculate_next_run(schedule_config, _NOW)
|
||||
|
||||
# --- 第一步:禁用(enabled=True → False)---
|
||||
# toggle 端点先 SELECT 当前状态,再 UPDATE RETURNING
|
||||
|
||||
# 禁用后的数据库行
|
||||
disabled_row = _make_db_row(
|
||||
schedule_id="sched-toggle-1",
|
||||
site_id=site_id,
|
||||
name=name,
|
||||
task_codes=task_codes,
|
||||
task_config=task_config,
|
||||
schedule_config=schedule_config_dict,
|
||||
enabled=False,
|
||||
next_run_at=None, # 禁用后 next_run_at 为 NULL
|
||||
)
|
||||
|
||||
# mock 禁用操作的数据库连接
|
||||
disable_cursor = MagicMock()
|
||||
disable_cursor.fetchone.side_effect = [
|
||||
(True, json.dumps(schedule_config_dict)), # SELECT 当前状态(enabled=True)
|
||||
disabled_row, # UPDATE RETURNING
|
||||
]
|
||||
disable_conn = MagicMock()
|
||||
disable_conn.cursor.return_value.__enter__ = MagicMock(return_value=disable_cursor)
|
||||
disable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# --- 第二步:启用(enabled=False → True)---
|
||||
enabled_row = _make_db_row(
|
||||
schedule_id="sched-toggle-1",
|
||||
site_id=site_id,
|
||||
name=name,
|
||||
task_codes=task_codes,
|
||||
task_config=task_config,
|
||||
schedule_config=schedule_config_dict,
|
||||
enabled=True,
|
||||
next_run_at=next_run_enabled, # 启用后 next_run_at 被重新计算
|
||||
)
|
||||
|
||||
enable_cursor = MagicMock()
|
||||
enable_cursor.fetchone.side_effect = [
|
||||
(False, json.dumps(schedule_config_dict)), # SELECT 当前状态(enabled=False)
|
||||
enabled_row, # UPDATE RETURNING
|
||||
]
|
||||
enable_conn = MagicMock()
|
||||
enable_conn.cursor.return_value.__enter__ = MagicMock(return_value=enable_cursor)
|
||||
enable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# 依次返回两个连接
|
||||
mock_get_conn.side_effect = [disable_conn, enable_conn]
|
||||
|
||||
# 覆盖认证
|
||||
test_user = CurrentUser(user_id=1, site_id=site_id)
|
||||
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||
|
||||
try:
|
||||
client = TestClient(app)
|
||||
|
||||
# 禁用
|
||||
disable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
|
||||
assert disable_resp.status_code == 200, (
|
||||
f"禁用应返回 200,实际 {disable_resp.status_code}: {disable_resp.text}"
|
||||
)
|
||||
disable_data = disable_resp.json()
|
||||
|
||||
# 验证:禁用后 enabled=False,next_run_at=NULL
|
||||
assert disable_data["enabled"] is False, "禁用后 enabled 应为 False"
|
||||
assert disable_data["next_run_at"] is None, "禁用后 next_run_at 应为 NULL"
|
||||
|
||||
# 启用
|
||||
enable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
|
||||
assert enable_resp.status_code == 200, (
|
||||
f"启用应返回 200,实际 {enable_resp.status_code}: {enable_resp.text}"
|
||||
)
|
||||
enable_data = enable_resp.json()
|
||||
|
||||
# 验证:启用后 enabled=True,next_run_at 非 NULL(非一次性调度)
|
||||
assert enable_data["enabled"] is True, "启用后 enabled 应为 True"
|
||||
assert enable_data["next_run_at"] is not None, (
|
||||
"非一次性调度启用后 next_run_at 应被重新计算为非 NULL 值"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
384
apps/backend/tests/test_scheduler.py
Normal file
384
apps/backend/tests/test_scheduler.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Scheduler 单元测试
|
||||
|
||||
覆盖:
|
||||
- calculate_next_run:各种调度类型的下次执行时间计算
|
||||
- _parse_simple_cron:简单 cron 表达式解析
|
||||
- check_and_enqueue:到期检查与入队逻辑
|
||||
- start / stop:后台循环生命周期
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.schedules import ScheduleConfigSchema
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.scheduler import (
|
||||
Scheduler,
|
||||
calculate_next_run,
|
||||
_parse_simple_cron,
|
||||
_parse_time,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def sched() -> Scheduler:
|
||||
return Scheduler()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def now() -> datetime:
|
||||
"""固定时间点:2025-06-10 10:00:00 UTC(周二)"""
|
||||
return datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
|
||||
cur = MagicMock()
|
||||
cur.fetchone.return_value = fetchone_val
|
||||
cur.fetchall.return_value = fetchall_val or []
|
||||
cur.rowcount = rowcount
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
|
||||
def _mock_conn(cursor):
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_time
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseTime:
|
||||
def test_standard_format(self):
|
||||
assert _parse_time("04:00") == (4, 0)
|
||||
|
||||
def test_with_minutes(self):
|
||||
assert _parse_time("23:45") == (23, 45)
|
||||
|
||||
def test_midnight(self):
|
||||
assert _parse_time("00:00") == (0, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunOnce:
|
||||
def test_once_returns_none(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="once")
|
||||
assert calculate_next_run(cfg, now) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — interval
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunInterval:
|
||||
def test_interval_minutes(self, now):
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="interval", interval_value=15, interval_unit="minutes",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
assert result == now + timedelta(minutes=15)
|
||||
|
||||
def test_interval_hours(self, now):
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="interval", interval_value=2, interval_unit="hours",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
assert result == now + timedelta(hours=2)
|
||||
|
||||
def test_interval_days(self, now):
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="interval", interval_value=3, interval_unit="days",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
assert result == now + timedelta(days=3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — daily
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunDaily:
|
||||
def test_daily_next_day(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="04:00")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_daily_custom_time(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="18:30")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 11, 18, 30, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — weekly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunWeekly:
|
||||
def test_weekly_later_this_week(self, now):
|
||||
# now 是周二(2),weekly_days=[5] 周五 → 3 天后
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="weekly", weekly_days=[5], weekly_time="08:00",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_weekly_next_week(self, now):
|
||||
# now 是周二(2),weekly_days=[1] 周一 → 下周一(6天后)
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="weekly", weekly_days=[1], weekly_time="04:00",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 16, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_weekly_multiple_days_picks_next(self, now):
|
||||
# now 是周二(2),weekly_days=[1, 4, 6] → 周四(4),2 天后
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="weekly", weekly_days=[1, 4, 6], weekly_time="09:00",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 12, 9, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — cron
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunCron:
|
||||
def test_cron_daily(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="30 4 * * *")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 11, 4, 30, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_cron_with_dow(self, now):
|
||||
# "0 8 * * 5" → 每周五 08:00,now 是周二 → 周五(3天后)
|
||||
cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="0 8 * * 5")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_simple_cron
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseSimpleCron:
|
||||
def test_daily_cron(self, now):
|
||||
result = _parse_simple_cron("0 4 * * *", now)
|
||||
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_invalid_field_count_fallback(self, now):
|
||||
# 字段数不对,回退到明天 04:00
|
||||
result = _parse_simple_cron("0 4 *", now)
|
||||
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_wildcard_hour_minute(self, now):
|
||||
# "* * * * *" → hour=0, minute=0,明天 00:00
|
||||
result = _parse_simple_cron("* * * * *", now)
|
||||
expected = datetime(2025, 6, 11, 0, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_dow_sunday(self, now):
|
||||
# "0 6 * * 0" → 每周日 06:00,now 是周二 → 周日(5天后)
|
||||
result = _parse_simple_cron("0 6 * * 0", now)
|
||||
expected = datetime(2025, 6, 15, 6, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_dow_same_day_future_time(self):
|
||||
# 周二 08:00,cron 指定周二 12:00 → 当天
|
||||
now = datetime(2025, 6, 10, 8, 0, 0, tzinfo=timezone.utc)
|
||||
result = _parse_simple_cron("0 12 * * 2", now)
|
||||
expected = datetime(2025, 6, 10, 12, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_dow_same_day_past_time(self):
|
||||
# 周二 14:00,cron 指定周二 12:00 → 下周二
|
||||
now = datetime(2025, 6, 10, 14, 0, 0, tzinfo=timezone.utc)
|
||||
result = _parse_simple_cron("0 12 * * 2", now)
|
||||
expected = datetime(2025, 6, 17, 12, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_and_enqueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckAndEnqueue:
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_enqueues_due_tasks(self, mock_tq, mock_get_conn, sched):
|
||||
"""到期任务应被入队,且更新 last_run_at / run_count / next_run_at"""
|
||||
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
|
||||
schedule_config = {
|
||||
"schedule_type": "interval",
|
||||
"interval_value": 1,
|
||||
"interval_unit": "hours",
|
||||
}
|
||||
|
||||
# 第一次 cursor:SELECT 到期任务
|
||||
select_cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
# 第二次 cursor:UPDATE
|
||||
update_cur = _mock_cursor()
|
||||
|
||||
conn = MagicMock()
|
||||
# cursor() 依次返回 select_cur 和 update_cur
|
||||
conn.cursor.side_effect = [select_cur, update_cur]
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
mock_tq.enqueue.return_value = "queue-id-1"
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 1
|
||||
mock_tq.enqueue.assert_called_once()
|
||||
# 验证 enqueue 的参数
|
||||
call_args = mock_tq.enqueue.call_args
|
||||
assert call_args[0][1] == 42 # site_id
|
||||
assert isinstance(call_args[0][0], TaskConfigSchema)
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_no_due_tasks(self, mock_tq, mock_get_conn, sched):
|
||||
"""没有到期任务时,不入队"""
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 0
|
||||
mock_tq.enqueue.assert_not_called()
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_skips_invalid_config(self, mock_tq, mock_get_conn, sched):
|
||||
"""配置反序列化失败的任务应被跳过"""
|
||||
# task_config 缺少必填字段 tasks
|
||||
bad_config = {"pipeline": "api_ods_dwd"}
|
||||
schedule_config = {"schedule_type": "once"}
|
||||
|
||||
cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-uuid-bad", 42, json.dumps(bad_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 0
|
||||
mock_tq.enqueue.assert_not_called()
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_enqueue_failure_continues(self, mock_tq, mock_get_conn, sched):
|
||||
"""入队失败时应跳过该任务,继续处理后续任务"""
|
||||
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
|
||||
schedule_config = {"schedule_type": "once"}
|
||||
|
||||
cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
("task-2", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
# 需要额外的 cursor 给 UPDATE 用
|
||||
update_cur = _mock_cursor()
|
||||
conn = MagicMock()
|
||||
conn.cursor.side_effect = [cur, update_cur]
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
# 第一次入队失败,第二次成功
|
||||
mock_tq.enqueue.side_effect = [Exception("DB error"), "queue-id-2"]
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 1
|
||||
assert mock_tq.enqueue.call_count == 2
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_once_type_sets_next_run_none(self, mock_tq, mock_get_conn, sched):
|
||||
"""once 类型任务入队后,next_run_at 应被设为 NULL"""
|
||||
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
|
||||
schedule_config = {"schedule_type": "once"}
|
||||
|
||||
select_cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
update_cur = _mock_cursor()
|
||||
conn = MagicMock()
|
||||
conn.cursor.side_effect = [select_cur, update_cur]
|
||||
mock_get_conn.return_value = conn
|
||||
mock_tq.enqueue.return_value = "queue-id-1"
|
||||
|
||||
sched.check_and_enqueue()
|
||||
|
||||
# 验证 UPDATE 语句中 next_run_at 参数为 None
|
||||
update_call = update_cur.__enter__().execute.call_args
|
||||
# 参数元组的第一个元素是 next_run_at
|
||||
assert update_call[0][1][0] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start / stop 生命周期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sets_running_false(self, sched):
|
||||
sched._running = True
|
||||
await sched.stop()
|
||||
assert sched._running is False
|
||||
assert sched._loop_task is None
|
||||
|
||||
def test_start_creates_task(self, sched):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
# 在事件循环中启动
|
||||
async def _run():
|
||||
sched.start()
|
||||
assert sched._loop_task is not None
|
||||
assert not sched._loop_task.done()
|
||||
await sched.stop()
|
||||
|
||||
loop.run_until_complete(_run())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_stop_idempotent(self, sched):
|
||||
"""多次 stop 不应报错"""
|
||||
await sched.stop()
|
||||
await sched.stop()
|
||||
assert sched._loop_task is None
|
||||
310
apps/backend/tests/test_schedules_router.py
Normal file
310
apps/backend/tests/test_schedules_router.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度路由单元测试
|
||||
|
||||
覆盖 5 个端点:list / create / update / delete / toggle
|
||||
通过 mock 绕过数据库,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
_NEXT = datetime(2024, 6, 2, 4, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
_SCHEDULE_CONFIG = {
|
||||
"schedule_type": "daily",
|
||||
"daily_time": "04:00",
|
||||
}
|
||||
|
||||
_VALID_CREATE = {
|
||||
"name": "每日全量同步",
|
||||
"task_codes": ["ODS_MEMBER", "ODS_ORDER"],
|
||||
"task_config": {"tasks": ["ODS_MEMBER", "ODS_ORDER"], "pipeline": "api_ods"},
|
||||
"schedule_config": _SCHEDULE_CONFIG,
|
||||
}
|
||||
|
||||
# 模拟数据库返回的完整行(13 列,与 _SELECT_COLS 对应)
|
||||
_DB_ROW = (
|
||||
"sched-1", 100, "每日全量同步", ["ODS_MEMBER", "ODS_ORDER"],
|
||||
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}),
|
||||
json.dumps(_SCHEDULE_CONFIG),
|
||||
True, None, _NEXT, 0, None, _NOW, _NOW,
|
||||
)
|
||||
|
||||
|
||||
def _mock_conn_with_fetchall(rows):
|
||||
"""构造返回 fetchall 的 mock 连接。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cursor
|
||||
|
||||
|
||||
def _mock_conn_with_fetchone(row):
|
||||
"""构造返回 fetchone 的 mock 连接。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cursor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListSchedules:
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_list_returns_schedules(self, mock_get_conn):
|
||||
mock_conn, _ = _mock_conn_with_fetchall([_DB_ROW])
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/schedules")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "sched-1"
|
||||
assert data[0]["name"] == "每日全量同步"
|
||||
assert data[0]["site_id"] == 100
|
||||
assert data[0]["enabled"] is True
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_list_empty(self, mock_get_conn):
|
||||
mock_conn, _ = _mock_conn_with_fetchall([])
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/schedules")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_list_filters_by_site_id(self, mock_get_conn):
|
||||
mock_conn, mock_cursor = _mock_conn_with_fetchall([])
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
client.get("/api/schedules")
|
||||
call_args = mock_cursor.execute.call_args
|
||||
assert call_args[0][1] == (100,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCreateSchedule:
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_create_returns_201(self, mock_get_conn, mock_calc):
|
||||
mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.post("/api/schedules", json=_VALID_CREATE)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "sched-1"
|
||||
assert data["name"] == "每日全量同步"
|
||||
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_create_injects_site_id(self, mock_get_conn, mock_calc):
|
||||
mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
client.post("/api/schedules", json=_VALID_CREATE)
|
||||
# INSERT 的第一个参数应为 site_id=100
|
||||
insert_params = mock_cursor.execute.call_args[0][1]
|
||||
assert insert_params[0] == 100
|
||||
|
||||
def test_create_missing_name_returns_422(self):
|
||||
body = {**_VALID_CREATE}
|
||||
del body["name"]
|
||||
resp = client.post("/api/schedules", json=body)
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_create_invalid_schedule_type_returns_422(self):
|
||||
body = {**_VALID_CREATE, "schedule_config": {"schedule_type": "invalid"}}
|
||||
resp = client.post("/api/schedules", json=body)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/schedules/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateSchedule:
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_update_name(self, mock_get_conn):
|
||||
updated_row = list(_DB_ROW)
|
||||
updated_row[2] = "新名称"
|
||||
mock_conn, _ = _mock_conn_with_fetchone(tuple(updated_row))
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.put("/api/schedules/sched-1", json={"name": "新名称"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "新名称"
|
||||
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_update_schedule_config_recalculates_next_run(self, mock_get_conn, mock_calc):
|
||||
mock_conn, _ = _mock_conn_with_fetchone(_DB_ROW)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.put("/api/schedules/sched-1", json={
|
||||
"schedule_config": {"schedule_type": "interval", "interval_value": 2, "interval_unit": "hours"},
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
mock_calc.assert_called_once()
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_update_not_found(self, mock_get_conn):
|
||||
mock_conn, _ = _mock_conn_with_fetchone(None)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.put("/api/schedules/nonexistent", json={"name": "x"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_update_empty_body_returns_422(self):
|
||||
resp = client.put("/api/schedules/sched-1", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/schedules/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteSchedule:
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_delete_success(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.rowcount = 1
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.delete("/api/schedules/sched-1")
|
||||
assert resp.status_code == 200
|
||||
assert "已删除" in resp.json()["message"]
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_delete_not_found(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.rowcount = 0
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.delete("/api/schedules/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /api/schedules/{id}/toggle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToggleSchedule:
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_toggle_disable(self, mock_get_conn, mock_calc):
|
||||
"""启用 → 禁用:next_run_at 应置 NULL"""
|
||||
# 第一次 fetchone 返回当前状态(enabled=True)
|
||||
# 第二次 fetchone 返回更新后的行
|
||||
disabled_row = list(_DB_ROW)
|
||||
disabled_row[6] = False # enabled
|
||||
disabled_row[8] = None # next_run_at
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
(True, json.dumps(_SCHEDULE_CONFIG)), # SELECT 当前状态
|
||||
tuple(disabled_row), # UPDATE RETURNING
|
||||
]
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.patch("/api/schedules/sched-1/toggle")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["enabled"] is False
|
||||
assert data["next_run_at"] is None
|
||||
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_toggle_enable(self, mock_get_conn, mock_calc):
|
||||
"""禁用 → 启用:next_run_at 应被重新计算"""
|
||||
enabled_row = list(_DB_ROW)
|
||||
enabled_row[6] = True
|
||||
enabled_row[8] = _NEXT
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
(False, json.dumps(_SCHEDULE_CONFIG)), # SELECT 当前状态(disabled)
|
||||
tuple(enabled_row), # UPDATE RETURNING
|
||||
]
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.patch("/api/schedules/sched-1/toggle")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["enabled"] is True
|
||||
assert data["next_run_at"] is not None
|
||||
mock_calc.assert_called_once()
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_toggle_not_found(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = None
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.patch("/api/schedules/nonexistent/toggle")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSchedulesAuth:
|
||||
def test_requires_auth(self):
|
||||
"""移除认证覆盖后,所有端点应返回 401/403"""
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
assert client.get("/api/schedules").status_code in (401, 403)
|
||||
assert client.post("/api/schedules", json=_VALID_CREATE).status_code in (401, 403)
|
||||
assert client.put("/api/schedules/x", json={"name": "x"}).status_code in (401, 403)
|
||||
assert client.delete("/api/schedules/x").status_code in (401, 403)
|
||||
assert client.patch("/api/schedules/x/toggle").status_code in (401, 403)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
336
apps/backend/tests/test_site_isolation_properties.py
Normal file
336
apps/backend/tests/test_site_isolation_properties.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""门店隔离属性测试(Property-Based Testing)。
|
||||
|
||||
Property 20: 对于任意两个不同 site_id 的 Operator,一个 Operator 查询
|
||||
队列/调度/执行历史时,结果中不应包含另一个 site_id 的数据。
|
||||
|
||||
Validates: Requirements 1.3
|
||||
|
||||
测试策略:
|
||||
- 通过 mock 数据库交互,验证 API 路由在不同 site_id 下的数据隔离
|
||||
- 队列隔离:为 site_id_a 入队任务,用 site_id_b 的 JWT 查询队列,结果应为空
|
||||
- 调度隔离:为 site_id_a 创建调度任务,用 site_id_b 的 JWT 查询调度列表,结果应为空
|
||||
- 执行历史隔离:site_id_a 的执行历史,用 site_id_b 的 JWT 查询不到
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-isolation")
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mock_user(site_id: int) -> CurrentUser:
|
||||
"""构造指定 site_id 的 mock 用户。"""
|
||||
return CurrentUser(user_id=1, site_id=site_id)
|
||||
|
||||
|
||||
def _make_queue_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的队列行。"""
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}), # config
|
||||
"pending", # status
|
||||
i + 1, # position
|
||||
datetime(2024, 1, 1, tzinfo=timezone.utc), # created_at
|
||||
None, # started_at
|
||||
None, # finished_at
|
||||
None, # exit_code
|
||||
None, # error_message
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _make_schedule_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的调度行。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
f"调度任务_{i}", # name
|
||||
["ODS_MEMBER"], # task_codes
|
||||
{"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}, # task_config
|
||||
{"schedule_type": "daily", "daily_time": "04:00", # schedule_config
|
||||
"interval_value": 1, "interval_unit": "hours",
|
||||
"weekly_days": [1], "weekly_time": "04:00",
|
||||
"cron_expression": "0 4 * * *", "enabled": True,
|
||||
"start_date": None, "end_date": None},
|
||||
True, # enabled
|
||||
None, # last_run_at
|
||||
now + timedelta(hours=1), # next_run_at
|
||||
0, # run_count
|
||||
None, # last_status
|
||||
now, # created_at
|
||||
now, # updated_at
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _make_history_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的执行历史行。"""
|
||||
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
["ODS_MEMBER"], # task_codes
|
||||
"success", # status
|
||||
base_time + timedelta(hours=i), # started_at
|
||||
base_time + timedelta(hours=i, minutes=30), # finished_at
|
||||
0, # exit_code
|
||||
1800000, # duration_ms
|
||||
"python -m cli.main", # command
|
||||
None, # summary
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _mock_conn_returning(rows: list[tuple]) -> MagicMock:
|
||||
"""构造一个 mock connection,其 cursor.fetchall 返回指定行。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.1: 队列隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id_a=_site_id_st,
|
||||
site_id_b=_site_id_st,
|
||||
queue_count=st.integers(min_value=1, max_value=5),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_isolation(mock_get_conn, site_id_a, site_id_b, queue_count):
|
||||
"""Property 20.1: 队列隔离。
|
||||
|
||||
为 site_id_a 入队若干任务后,用 site_id_b 的身份查询队列,
|
||||
结果应为空——不同门店的队列数据互不可见。
|
||||
"""
|
||||
assume(site_id_a != site_id_b)
|
||||
|
||||
# site_id_a 的队列数据
|
||||
rows_a = _make_queue_rows(site_id_a, queue_count)
|
||||
|
||||
# 核心隔离逻辑:根据查询时传入的 site_id 过滤
|
||||
# list_pending 内部 SQL: WHERE site_id = %s AND status = 'pending'
|
||||
def conn_for_site(querying_site_id):
|
||||
"""模拟数据库行为:只返回匹配 site_id 的行。"""
|
||||
if querying_site_id == site_id_a:
|
||||
return rows_a
|
||||
return [] # site_id_b 查不到 site_id_a 的数据
|
||||
|
||||
captured_params = {}
|
||||
|
||||
def make_mock_conn():
|
||||
mock_cursor = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if params:
|
||||
captured_params["site_id"] = params[0]
|
||||
# 根据 SQL 中的 site_id 参数返回对应数据
|
||||
mock_cursor.fetchall.return_value = conn_for_site(params[0])
|
||||
|
||||
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
return mock_conn
|
||||
|
||||
mock_get_conn.return_value = make_mock_conn()
|
||||
|
||||
# 用 site_id_b 的身份查询队列
|
||||
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/execution/queue")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证:site_id_b 查不到 site_id_a 的任何数据
|
||||
assert len(data) == 0, (
|
||||
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的队列数据,"
|
||||
f"但返回了 {len(data)} 条记录"
|
||||
)
|
||||
|
||||
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
|
||||
for item in data:
|
||||
assert item.get("site_id") != site_id_a, (
|
||||
f"结果中不应包含 site_id_a={site_id_a} 的数据"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.2: 调度隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id_a=_site_id_st,
|
||||
site_id_b=_site_id_st,
|
||||
schedule_count=st.integers(min_value=1, max_value=5),
|
||||
)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_schedule_isolation(mock_get_conn, site_id_a, site_id_b, schedule_count):
|
||||
"""Property 20.2: 调度隔离。
|
||||
|
||||
为 site_id_a 创建若干调度任务后,用 site_id_b 的身份查询调度列表,
|
||||
结果应为空——不同门店的调度数据互不可见。
|
||||
"""
|
||||
assume(site_id_a != site_id_b)
|
||||
|
||||
# site_id_a 的调度数据
|
||||
rows_a = _make_schedule_rows(site_id_a, schedule_count)
|
||||
|
||||
def make_mock_conn():
|
||||
mock_cursor = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if params:
|
||||
querying_site_id = params[0]
|
||||
# 只返回匹配 site_id 的行
|
||||
if querying_site_id == site_id_a:
|
||||
mock_cursor.fetchall.return_value = rows_a
|
||||
else:
|
||||
mock_cursor.fetchall.return_value = []
|
||||
|
||||
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
return mock_conn
|
||||
|
||||
mock_get_conn.return_value = make_mock_conn()
|
||||
|
||||
# 用 site_id_b 的身份查询调度列表
|
||||
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/schedules")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证:site_id_b 查不到 site_id_a 的任何调度数据
|
||||
assert len(data) == 0, (
|
||||
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的调度数据,"
|
||||
f"但返回了 {len(data)} 条记录"
|
||||
)
|
||||
|
||||
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
|
||||
for item in data:
|
||||
assert item.get("site_id") != site_id_a, (
|
||||
f"结果中不应包含 site_id_a={site_id_a} 的调度数据"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.3: 执行历史隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id_a=_site_id_st,
|
||||
site_id_b=_site_id_st,
|
||||
history_count=st.integers(min_value=1, max_value=10),
|
||||
)
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_execution_history_isolation(mock_get_conn, site_id_a, site_id_b, history_count):
|
||||
"""Property 20.3: 执行历史隔离。
|
||||
|
||||
site_id_a 有若干执行历史记录,用 site_id_b 的身份查询执行历史,
|
||||
结果应为空——不同门店的执行历史互不可见。
|
||||
"""
|
||||
assume(site_id_a != site_id_b)
|
||||
|
||||
# site_id_a 的执行历史数据
|
||||
rows_a = _make_history_rows(site_id_a, history_count)
|
||||
|
||||
def make_mock_conn():
|
||||
mock_cursor = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if params:
|
||||
querying_site_id = params[0]
|
||||
# 只返回匹配 site_id 的行
|
||||
if querying_site_id == site_id_a:
|
||||
mock_cursor.fetchall.return_value = rows_a
|
||||
else:
|
||||
mock_cursor.fetchall.return_value = []
|
||||
|
||||
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
return mock_conn
|
||||
|
||||
mock_get_conn.return_value = make_mock_conn()
|
||||
|
||||
# 用 site_id_b 的身份查询执行历史
|
||||
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/execution/history")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证:site_id_b 查不到 site_id_a 的任何执行历史
|
||||
assert len(data) == 0, (
|
||||
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的执行历史,"
|
||||
f"但返回了 {len(data)} 条记录"
|
||||
)
|
||||
|
||||
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
|
||||
for item in data:
|
||||
assert item.get("site_id") != site_id_a, (
|
||||
f"结果中不应包含 site_id_a={site_id_a} 的执行历史"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
275
apps/backend/tests/test_task_config_properties.py
Normal file
275
apps/backend/tests/test_task_config_properties.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskConfig 属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证 TaskConfig 相关的通用正确性属性:
|
||||
- Property 1: TaskConfig 序列化往返一致性
|
||||
- Property 6: 时间窗口验证
|
||||
- Property 7: TaskConfig 到 CLI 命令转换完整性
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
|
||||
from app.services.task_registry import ALL_TASKS
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 从真实任务注册表中采样任务代码
|
||||
_task_codes = [t.code for t in ALL_TASKS]
|
||||
|
||||
_tasks_st = st.lists(
|
||||
st.sampled_from(_task_codes),
|
||||
min_size=1,
|
||||
max_size=5,
|
||||
unique=True,
|
||||
)
|
||||
|
||||
_pipeline_st = st.sampled_from(sorted(VALID_FLOWS))
|
||||
_processing_mode_st = st.sampled_from(sorted(VALID_PROCESSING_MODES))
|
||||
_window_mode_st = st.sampled_from(["lookback", "custom"])
|
||||
|
||||
# 日期策略:生成 YYYY-MM-DD 格式字符串
|
||||
_date_st = st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
).map(lambda d: d.isoformat())
|
||||
|
||||
_window_split_st = st.sampled_from([None, "none", "day"])
|
||||
_window_split_days_st = st.one_of(st.none(), st.sampled_from([1, 10, 30]))
|
||||
_lookback_hours_st = st.integers(min_value=1, max_value=720)
|
||||
_overlap_seconds_st = st.integers(min_value=0, max_value=7200)
|
||||
_store_id_st = st.one_of(st.none(), st.integers(min_value=1, max_value=2**31 - 1))
|
||||
|
||||
# DWD 表名采样
|
||||
_dwd_table_names = [
|
||||
"dwd.dim_site",
|
||||
"dwd.dim_member",
|
||||
"dwd.dwd_settlement_head",
|
||||
]
|
||||
_dwd_only_tables_st = st.one_of(
|
||||
st.none(),
|
||||
st.lists(st.sampled_from(_dwd_table_names), min_size=1, max_size=3, unique=True),
|
||||
)
|
||||
|
||||
|
||||
def _valid_task_config_st():
|
||||
"""生成有效的 TaskConfigSchema 的复合策略。
|
||||
|
||||
确保 window_mode=custom 时 window_end >= window_start,
|
||||
避免触发 Pydantic 验证错误。
|
||||
"""
|
||||
|
||||
@st.composite
|
||||
def _build(draw):
|
||||
tasks = draw(_tasks_st)
|
||||
pipeline = draw(_pipeline_st)
|
||||
processing_mode = draw(_processing_mode_st)
|
||||
dry_run = draw(st.booleans())
|
||||
window_mode = draw(_window_mode_st)
|
||||
store_id = draw(_store_id_st)
|
||||
dwd_only_tables = draw(_dwd_only_tables_st)
|
||||
window_split = draw(_window_split_st)
|
||||
window_split_days = draw(_window_split_days_st)
|
||||
fetch_before_verify = draw(st.booleans())
|
||||
skip_ods = draw(st.booleans())
|
||||
ods_local = draw(st.booleans())
|
||||
|
||||
if window_mode == "custom":
|
||||
d1 = draw(st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
))
|
||||
d2 = draw(st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
))
|
||||
# 保证 end >= start
|
||||
window_start = min(d1, d2).isoformat()
|
||||
window_end = max(d1, d2).isoformat()
|
||||
lookback_hours = 24
|
||||
overlap_seconds = 600
|
||||
else:
|
||||
window_start = None
|
||||
window_end = None
|
||||
lookback_hours = draw(_lookback_hours_st)
|
||||
overlap_seconds = draw(_overlap_seconds_st)
|
||||
|
||||
return TaskConfigSchema(
|
||||
tasks=tasks,
|
||||
pipeline=pipeline,
|
||||
processing_mode=processing_mode,
|
||||
dry_run=dry_run,
|
||||
window_mode=window_mode,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
window_split=window_split,
|
||||
window_split_days=window_split_days,
|
||||
lookback_hours=lookback_hours,
|
||||
overlap_seconds=overlap_seconds,
|
||||
fetch_before_verify=fetch_before_verify,
|
||||
skip_ods_when_fetch_before_verify=skip_ods,
|
||||
ods_use_local_json=ods_local,
|
||||
store_id=store_id,
|
||||
dwd_only_tables=dwd_only_tables,
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
return _build()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 1: TaskConfig 序列化往返一致性
|
||||
# **Validates: Requirements 11.1, 11.2, 11.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(config=_valid_task_config_st())
|
||||
def test_task_config_round_trip(config: TaskConfigSchema):
|
||||
"""Property 1: 序列化为 JSON 后再反序列化,应产生与原始对象等价的结果。"""
|
||||
json_str = config.model_dump_json()
|
||||
restored = TaskConfigSchema.model_validate_json(json_str)
|
||||
assert restored == config, (
|
||||
f"往返不一致:\n原始={config}\n还原={restored}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 6: 时间窗口验证
|
||||
# **Validates: Requirements 2.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(
|
||||
d1=st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
),
|
||||
d2=st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
),
|
||||
)
|
||||
def test_time_window_validation(d1: datetime.date, d2: datetime.date):
|
||||
"""Property 6: window_end < window_start 时验证应失败,否则应通过。"""
|
||||
start_str = d1.isoformat()
|
||||
end_str = d2.isoformat()
|
||||
|
||||
if end_str < start_str:
|
||||
# window_end 早于 window_start → 验证应失败
|
||||
try:
|
||||
TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="custom",
|
||||
window_start=start_str,
|
||||
window_end=end_str,
|
||||
)
|
||||
raise AssertionError(
|
||||
f"期望 ValidationError,但验证通过了:start={start_str}, end={end_str}"
|
||||
)
|
||||
except ValidationError:
|
||||
pass # 预期行为
|
||||
else:
|
||||
# window_end >= window_start → 验证应通过
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="custom",
|
||||
window_start=start_str,
|
||||
window_end=end_str,
|
||||
)
|
||||
assert config.window_start == start_str
|
||||
assert config.window_end == end_str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 7: TaskConfig 到 CLI 命令转换完整性
|
||||
# **Validates: Requirements 2.5, 2.6**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_builder = CLIBuilder()
|
||||
_ETL_PATH = "/fake/etl/project"
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(config=_valid_task_config_st())
|
||||
def test_task_config_to_cli_completeness(config: TaskConfigSchema):
|
||||
"""Property 7: CLIBuilder 生成的命令应包含 TaskConfig 中所有非空字段对应的 CLI 参数。"""
|
||||
cmd = _builder.build_command(config, _ETL_PATH)
|
||||
|
||||
# 1) --pipeline 始终存在且值正确
|
||||
assert "--pipeline" in cmd
|
||||
idx = cmd.index("--pipeline")
|
||||
assert cmd[idx + 1] == config.pipeline
|
||||
|
||||
# 2) --processing-mode 始终存在且值正确
|
||||
assert "--processing-mode" in cmd
|
||||
idx = cmd.index("--processing-mode")
|
||||
assert cmd[idx + 1] == config.processing_mode
|
||||
|
||||
# 3) 非空任务列表 → --tasks 存在
|
||||
if config.tasks:
|
||||
assert "--tasks" in cmd
|
||||
idx = cmd.index("--tasks")
|
||||
assert set(cmd[idx + 1].split(",")) == set(config.tasks)
|
||||
|
||||
# 4) 时间窗口参数
|
||||
if config.window_mode == "lookback":
|
||||
# lookback 模式 → --lookback-hours 和 --overlap-seconds
|
||||
if config.lookback_hours is not None:
|
||||
assert "--lookback-hours" in cmd
|
||||
idx = cmd.index("--lookback-hours")
|
||||
assert cmd[idx + 1] == str(config.lookback_hours)
|
||||
if config.overlap_seconds is not None:
|
||||
assert "--overlap-seconds" in cmd
|
||||
idx = cmd.index("--overlap-seconds")
|
||||
assert cmd[idx + 1] == str(config.overlap_seconds)
|
||||
# lookback 模式不应出现 custom 参数
|
||||
assert "--window-start" not in cmd
|
||||
assert "--window-end" not in cmd
|
||||
else:
|
||||
# custom 模式 → --window-start / --window-end
|
||||
if config.window_start:
|
||||
assert "--window-start" in cmd
|
||||
if config.window_end:
|
||||
assert "--window-end" in cmd
|
||||
# custom 模式不应出现 lookback 参数
|
||||
assert "--lookback-hours" not in cmd
|
||||
assert "--overlap-seconds" not in cmd
|
||||
|
||||
# 5) dry_run → --dry-run
|
||||
if config.dry_run:
|
||||
assert "--dry-run" in cmd
|
||||
else:
|
||||
assert "--dry-run" not in cmd
|
||||
|
||||
# 6) store_id → --store-id
|
||||
if config.store_id is not None:
|
||||
assert "--store-id" in cmd
|
||||
idx = cmd.index("--store-id")
|
||||
assert cmd[idx + 1] == str(config.store_id)
|
||||
else:
|
||||
assert "--store-id" not in cmd
|
||||
|
||||
# 7) fetch_before_verify → 仅 verify_only 模式下生成
|
||||
if config.fetch_before_verify and config.processing_mode == "verify_only":
|
||||
assert "--fetch-before-verify" in cmd
|
||||
else:
|
||||
assert "--fetch-before-verify" not in cmd
|
||||
|
||||
# 8) window_split(非 None 且非 "none")→ --window-split
|
||||
if config.window_split and config.window_split != "none":
|
||||
assert "--window-split" in cmd
|
||||
idx = cmd.index("--window-split")
|
||||
assert cmd[idx + 1] == config.window_split
|
||||
if config.window_split_days is not None:
|
||||
assert "--window-split-days" in cmd
|
||||
else:
|
||||
assert "--window-split" not in cmd
|
||||
373
apps/backend/tests/test_task_executor.py
Normal file
373
apps/backend/tests/test_task_executor.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskExecutor 单元测试
|
||||
|
||||
覆盖:子进程启动、stdout/stderr 读取、日志广播、取消、数据库记录。
|
||||
使用 asyncio 测试,mock 子进程和数据库连接避免外部依赖。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_executor import TaskExecutor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def executor() -> TaskExecutor:
|
||||
return TaskExecutor()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config() -> TaskConfigSchema:
|
||||
return TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER", "ODS_PAYMENT"],
|
||||
pipeline="api_ods_dwd",
|
||||
store_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _make_stream(lines: list[bytes]) -> AsyncMock:
|
||||
"""构造一个模拟的 asyncio.StreamReader,按行返回数据。"""
|
||||
stream = AsyncMock()
|
||||
# readline 依次返回每行,最后返回 b"" 表示 EOF
|
||||
stream.readline = AsyncMock(side_effect=[*lines, b""])
|
||||
return stream
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 订阅 / 取消订阅
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubscription:
|
||||
def test_subscribe_returns_queue(self, executor: TaskExecutor):
|
||||
q = executor.subscribe("exec-1")
|
||||
assert isinstance(q, asyncio.Queue)
|
||||
|
||||
def test_subscribe_multiple(self, executor: TaskExecutor):
|
||||
q1 = executor.subscribe("exec-1")
|
||||
q2 = executor.subscribe("exec-1")
|
||||
assert q1 is not q2
|
||||
assert len(executor._subscribers["exec-1"]) == 2
|
||||
|
||||
def test_unsubscribe_removes_queue(self, executor: TaskExecutor):
|
||||
q = executor.subscribe("exec-1")
|
||||
executor.unsubscribe("exec-1", q)
|
||||
# 最后一个订阅者移除后,键也被清理
|
||||
assert "exec-1" not in executor._subscribers
|
||||
|
||||
def test_unsubscribe_nonexistent_is_safe(self, executor: TaskExecutor):
|
||||
"""对不存在的 execution_id 取消订阅不应报错"""
|
||||
q: asyncio.Queue = asyncio.Queue()
|
||||
executor.unsubscribe("nonexistent", q)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 广播
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBroadcast:
|
||||
def test_broadcast_to_subscribers(self, executor: TaskExecutor):
|
||||
q1 = executor.subscribe("exec-1")
|
||||
q2 = executor.subscribe("exec-1")
|
||||
executor._broadcast("exec-1", "hello")
|
||||
assert q1.get_nowait() == "hello"
|
||||
assert q2.get_nowait() == "hello"
|
||||
|
||||
def test_broadcast_no_subscribers_is_safe(self, executor: TaskExecutor):
|
||||
"""无订阅者时广播不应报错"""
|
||||
executor._broadcast("nonexistent", "hello")
|
||||
|
||||
def test_broadcast_end_sends_none(self, executor: TaskExecutor):
|
||||
q = executor.subscribe("exec-1")
|
||||
executor._broadcast_end("exec-1")
|
||||
assert q.get_nowait() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 日志缓冲区
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLogBuffer:
|
||||
def test_get_logs_empty(self, executor: TaskExecutor):
|
||||
assert executor.get_logs("nonexistent") == []
|
||||
|
||||
def test_get_logs_returns_copy(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = ["line1", "line2"]
|
||||
logs = executor.get_logs("exec-1")
|
||||
assert logs == ["line1", "line2"]
|
||||
# 修改副本不影响原始
|
||||
logs.append("line3")
|
||||
assert len(executor._log_buffers["exec-1"]) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 执行状态查询
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunningState:
|
||||
def test_is_running_false_when_no_process(self, executor: TaskExecutor):
|
||||
assert executor.is_running("nonexistent") is False
|
||||
|
||||
def test_is_running_true_when_process_active(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = None
|
||||
executor._processes["exec-1"] = proc
|
||||
assert executor.is_running("exec-1") is True
|
||||
|
||||
def test_is_running_false_when_process_exited(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = 0
|
||||
executor._processes["exec-1"] = proc
|
||||
assert executor.is_running("exec-1") is False
|
||||
|
||||
def test_get_running_ids(self, executor: TaskExecutor):
|
||||
running = MagicMock()
|
||||
running.returncode = None
|
||||
exited = MagicMock()
|
||||
exited.returncode = 0
|
||||
executor._processes["a"] = running
|
||||
executor._processes["b"] = exited
|
||||
assert executor.get_running_ids() == ["a"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadStream:
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_stdout_lines(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = []
|
||||
stream = _make_stream([b"line1\n", b"line2\n"])
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", stream, "stdout", collector)
|
||||
assert collector == ["line1", "line2"]
|
||||
assert executor._log_buffers["exec-1"] == [
|
||||
"[stdout] line1",
|
||||
"[stdout] line2",
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_stderr_lines(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = []
|
||||
stream = _make_stream([b"err1\n"])
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", stream, "stderr", collector)
|
||||
assert collector == ["err1"]
|
||||
assert executor._log_buffers["exec-1"] == ["[stderr] err1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_stream_none_is_safe(self, executor: TaskExecutor):
|
||||
"""stream 为 None 时不应报错"""
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", None, "stdout", collector)
|
||||
assert collector == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_during_read(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = []
|
||||
q = executor.subscribe("exec-1")
|
||||
stream = _make_stream([b"hello\n"])
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", stream, "stdout", collector)
|
||||
assert q.get_nowait() == "[stdout] hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# execute(集成级,mock 子进程和数据库)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecute:
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_successful_execution(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
# 模拟子进程
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([b"processing...\n", b"done\n"])
|
||||
proc.stderr = _make_stream([])
|
||||
proc.wait = AsyncMock(return_value=0)
|
||||
# wait 调用后设置 returncode
|
||||
async def _wait():
|
||||
proc.returncode = 0
|
||||
return 0
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
await executor.execute(sample_config, "exec-1", site_id=42)
|
||||
|
||||
# 验证写入了 running 状态
|
||||
mock_write_log.assert_called_once()
|
||||
call_kwargs = mock_write_log.call_args[1]
|
||||
assert call_kwargs["status"] == "running"
|
||||
assert call_kwargs["execution_id"] == "exec-1"
|
||||
|
||||
# 验证更新了 success 状态
|
||||
mock_update_log.assert_called_once()
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert update_kwargs["status"] == "success"
|
||||
assert update_kwargs["exit_code"] == 0
|
||||
assert "processing..." in update_kwargs["output_log"]
|
||||
assert "done" in update_kwargs["output_log"]
|
||||
|
||||
# 进程已从跟踪表移除
|
||||
assert "exec-1" not in executor._processes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_failed_execution(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([])
|
||||
proc.stderr = _make_stream([b"error occurred\n"])
|
||||
async def _wait():
|
||||
proc.returncode = 1
|
||||
return 1
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
await executor.execute(sample_config, "exec-2", site_id=42)
|
||||
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert update_kwargs["status"] == "failed"
|
||||
assert update_kwargs["exit_code"] == 1
|
||||
assert "error occurred" in update_kwargs["error_log"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_exception_during_execution(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
"""子进程创建失败时应记录 failed 状态"""
|
||||
mock_create.side_effect = OSError("command not found")
|
||||
|
||||
await executor.execute(sample_config, "exec-3", site_id=42)
|
||||
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert update_kwargs["status"] == "failed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_subscribers_notified_on_completion(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([b"line\n"])
|
||||
proc.stderr = _make_stream([])
|
||||
async def _wait():
|
||||
proc.returncode = 0
|
||||
return 0
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
q = executor.subscribe("exec-4")
|
||||
await executor.execute(sample_config, "exec-4", site_id=42)
|
||||
|
||||
# 应收到日志行 + None 哨兵
|
||||
messages = []
|
||||
while not q.empty():
|
||||
messages.append(q.get_nowait())
|
||||
assert "[stdout] line" in messages
|
||||
assert None in messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_duration_ms_recorded(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([])
|
||||
proc.stderr = _make_stream([])
|
||||
async def _wait():
|
||||
proc.returncode = 0
|
||||
return 0
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
await executor.execute(sample_config, "exec-5", site_id=42)
|
||||
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert isinstance(update_kwargs["duration_ms"], int)
|
||||
assert update_kwargs["duration_ms"] >= 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCancel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_running_process(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = None
|
||||
proc.terminate = MagicMock()
|
||||
executor._processes["exec-1"] = proc
|
||||
|
||||
result = await executor.cancel("exec-1")
|
||||
assert result is True
|
||||
proc.terminate.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_returns_false(self, executor: TaskExecutor):
|
||||
result = await executor.cancel("nonexistent")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_already_exited_returns_false(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = 0
|
||||
executor._processes["exec-1"] = proc
|
||||
|
||||
result = await executor.cancel("exec-1")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_process_lookup_error(self, executor: TaskExecutor):
|
||||
"""进程已消失时 terminate 抛出 ProcessLookupError"""
|
||||
proc = MagicMock()
|
||||
proc.returncode = None
|
||||
proc.terminate = MagicMock(side_effect=ProcessLookupError)
|
||||
executor._processes["exec-1"] = proc
|
||||
|
||||
result = await executor.cancel("exec-1")
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanup:
|
||||
def test_cleanup_removes_buffers_and_subscribers(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = ["line"]
|
||||
executor.subscribe("exec-1")
|
||||
executor.cleanup("exec-1")
|
||||
assert "exec-1" not in executor._log_buffers
|
||||
assert "exec-1" not in executor._subscribers
|
||||
|
||||
def test_cleanup_nonexistent_is_safe(self, executor: TaskExecutor):
|
||||
executor.cleanup("nonexistent")
|
||||
482
apps/backend/tests/test_task_queue.py
Normal file
482
apps/backend/tests/test_task_queue.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskQueue 单元测试
|
||||
|
||||
覆盖:enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
|
||||
使用 mock 数据库操作,专注于业务逻辑验证。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_queue import TaskQueue, QueuedTask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue() -> TaskQueue:
|
||||
return TaskQueue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config() -> TaskConfigSchema:
|
||||
return TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER", "ODS_PAYMENT"],
|
||||
pipeline="api_ods_dwd",
|
||||
store_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
|
||||
"""构造 mock cursor,支持 context manager 协议。"""
|
||||
cur = MagicMock()
|
||||
cur.fetchone.return_value = fetchone_val
|
||||
cur.fetchall.return_value = fetchall_val or []
|
||||
cur.rowcount = rowcount
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
|
||||
def _mock_conn(cursor):
|
||||
"""构造 mock connection,支持 cursor() context manager。"""
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnqueue:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
task_id = queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
# 返回有效 UUID
|
||||
uuid.UUID(task_id)
|
||||
conn.commit.assert_called_once()
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
|
||||
"""新任务 position = 当前最大 position + 1"""
|
||||
cur = _mock_cursor(fetchone_val=(5,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
# 检查 INSERT 调用中的 position 参数
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
args = insert_call[0][1]
|
||||
# args = (task_id, site_id, config_json, new_pos)
|
||||
assert args[3] == 6 # 5 + 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
|
||||
"""空队列时 position = 1"""
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
args = insert_call[0][1]
|
||||
assert args[3] == 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
|
||||
"""config 被序列化为 JSON 字符串"""
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
config_json_str = insert_call[0][1][2]
|
||||
parsed = json.loads(config_json_str)
|
||||
assert parsed["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
|
||||
assert parsed["pipeline"] == "api_ods_dwd"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dequeue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDequeue:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=None)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.dequeue(site_id=42)
|
||||
|
||||
assert result is None
|
||||
conn.commit.assert_called()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_returns_task(self, mock_get_conn, queue):
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
|
||||
row = (
|
||||
task_id, 42, json.dumps(config_dict), "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.dequeue(site_id=42)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == task_id
|
||||
assert result.site_id == 42
|
||||
assert result.status == "running" # dequeue 后状态变为 running
|
||||
assert result.config["tasks"] == ["ODS_MEMBER"]
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
|
||||
row = (
|
||||
task_id, 42, json.dumps(config_dict), "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.dequeue(site_id=42)
|
||||
|
||||
# 第二次 execute 调用应该是 UPDATE status = 'running'
|
||||
update_call = cur.execute.call_args_list[1]
|
||||
sql = update_call[0][0]
|
||||
assert "running" in sql
|
||||
assert task_id in update_call[0][1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReorder:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_moves_task(self, mock_get_conn, queue):
|
||||
"""将第 3 个任务移到第 1 位"""
|
||||
ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
rows = [(i,) for i in ids]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.reorder(ids[2], new_position=1, site_id=42)
|
||||
|
||||
# 重排后顺序应为 [ids[2], ids[0], ids[1]]
|
||||
update_calls = cur.execute.call_args_list[1:] # 跳过 SELECT
|
||||
positions = {}
|
||||
for c in update_calls:
|
||||
pos, tid = c[0][1]
|
||||
positions[tid] = pos
|
||||
assert positions[ids[2]] == 1
|
||||
assert positions[ids[0]] == 2
|
||||
assert positions[ids[1]] == 3
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
|
||||
"""重排不存在的任务不报错"""
|
||||
rows = [(str(uuid.uuid4()),)]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.reorder("nonexistent-id", new_position=1, site_id=42)
|
||||
|
||||
# 只有 SELECT,没有 UPDATE
|
||||
assert cur.execute.call_count == 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_clamps_position(self, mock_get_conn, queue):
|
||||
"""position 超出范围时 clamp 到有效范围"""
|
||||
ids = [str(uuid.uuid4()) for _ in range(2)]
|
||||
rows = [(i,) for i in ids]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
# new_position=100 超出范围,应 clamp 到末尾
|
||||
queue.reorder(ids[0], new_position=100, site_id=42)
|
||||
|
||||
update_calls = cur.execute.call_args_list[1:]
|
||||
positions = {}
|
||||
for c in update_calls:
|
||||
pos, tid = c[0][1]
|
||||
positions[tid] = pos
|
||||
# ids[0] 移到末尾
|
||||
assert positions[ids[1]] == 1
|
||||
assert positions[ids[0]] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDelete:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_pending_task(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(rowcount=1)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.delete("task-1", site_id=42)
|
||||
|
||||
assert result is True
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(rowcount=0)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.delete("nonexistent", site_id=42)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_only_affects_pending(self, mock_get_conn, queue):
|
||||
"""DELETE SQL 包含 status = 'pending' 条件"""
|
||||
cur = _mock_cursor(rowcount=0)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.delete("task-1", site_id=42)
|
||||
|
||||
sql = cur.execute.call_args[0][0]
|
||||
assert "pending" in sql
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_pending / has_running
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuery:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_list_pending_empty(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.list_pending(site_id=42)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_list_pending_returns_tasks(self, mock_get_conn, queue):
|
||||
tid = str(uuid.uuid4())
|
||||
config = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
|
||||
rows = [(tid, 42, config, "pending", 1, None, None, None, None, None)]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.list_pending(site_id=42)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == tid
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_has_running_true(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=(True,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
assert queue.has_running(site_id=42) is True
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_has_running_false(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=(False,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
assert queue.has_running(site_id=42) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_loop / _process_once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessLoop:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_skips_when_running(self, mock_get_conn, queue):
|
||||
"""有 running 任务时不 dequeue"""
|
||||
# _get_pending_site_ids 返回 [42]
|
||||
# has_running(42) 返回 True
|
||||
call_count = 0
|
||||
|
||||
def side_effect_conn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# _get_pending_site_ids
|
||||
cur = _mock_cursor(fetchall_val=[(42,)])
|
||||
return _mock_conn(cur)
|
||||
else:
|
||||
# has_running
|
||||
cur = _mock_cursor(fetchone_val=(True,))
|
||||
return _mock_conn(cur)
|
||||
|
||||
mock_get_conn.side_effect = side_effect_conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
# 不应调用 execute
|
||||
mock_executor.execute.assert_not_called()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
|
||||
"""无 running 任务时 dequeue 并执行"""
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods_dwd",
|
||||
"processing_mode": "increment_only",
|
||||
"dry_run": False,
|
||||
"window_mode": "lookback",
|
||||
"lookback_hours": 24,
|
||||
"overlap_seconds": 600,
|
||||
"fetch_before_verify": False,
|
||||
"skip_ods_when_fetch_before_verify": False,
|
||||
"ods_use_local_json": False,
|
||||
"extra_args": {},
|
||||
}
|
||||
config_json = json.dumps(config_dict)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def side_effect_conn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# _get_pending_site_ids
|
||||
cur = _mock_cursor(fetchall_val=[(42,)])
|
||||
return _mock_conn(cur)
|
||||
elif call_count == 2:
|
||||
# has_running → False
|
||||
cur = _mock_cursor(fetchone_val=(False,))
|
||||
return _mock_conn(cur)
|
||||
else:
|
||||
# dequeue → 返回任务
|
||||
row = (
|
||||
task_id, 42, config_json, "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
return _mock_conn(cur)
|
||||
|
||||
mock_get_conn.side_effect = side_effect_conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
mock_executor.execute = AsyncMock()
|
||||
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
# 给 create_task 一点时间启动
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_no_pending(self, mock_get_conn, queue):
|
||||
"""无 pending 任务时什么都不做"""
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
mock_executor.execute.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sets_running_false(self, queue):
|
||||
queue._running = True
|
||||
queue._loop_task = None
|
||||
|
||||
await queue.stop()
|
||||
|
||||
assert queue._running is False
|
||||
|
||||
def test_start_creates_task(self, queue):
|
||||
"""start() 应创建 asyncio.Task(需要事件循环)"""
|
||||
# 仅验证 _running 初始状态
|
||||
assert queue._running is False
|
||||
assert queue._loop_task is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _mark_failed / _update_queue_status_from_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInternalHelpers:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_mark_failed(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor()
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._mark_failed("queue-1", "测试错误")
|
||||
|
||||
sql = cur.execute.call_args[0][0]
|
||||
assert "failed" in sql
|
||||
args = cur.execute.call_args[0][1]
|
||||
assert args[0] == "测试错误"
|
||||
assert args[1] == "queue-1"
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_update_queue_status_from_log(self, mock_get_conn, queue):
|
||||
"""从 execution_log 同步状态到 task_queue"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
finished = datetime.now(timezone.utc)
|
||||
# 第一次 fetchone 返回 execution_log 行
|
||||
cur = _mock_cursor(fetchone_val=("success", finished, 0, None))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._update_queue_status_from_log("queue-1")
|
||||
|
||||
# 应有 SELECT + UPDATE 两次 execute
|
||||
assert cur.execute.call_count == 2
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_update_queue_status_no_log(self, mock_get_conn, queue):
|
||||
"""execution_log 无记录时不更新"""
|
||||
cur = _mock_cursor(fetchone_val=None)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._update_queue_status_from_log("queue-1")
|
||||
|
||||
# 只有 SELECT,没有 UPDATE
|
||||
assert cur.execute.call_count == 1
|
||||
299
apps/backend/tests/test_task_registry_properties.py
Normal file
299
apps/backend/tests/test_task_registry_properties.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表分组属性测试(Property-Based Testing)。
|
||||
|
||||
Property 4: 对于 Task_Registry 中的任务集合,分组结果中每个任务应出现在
|
||||
且仅出现在其所属业务域的分组中。
|
||||
|
||||
Validates: Requirements 2.1
|
||||
|
||||
测试策略:
|
||||
1. 直接测试 get_tasks_grouped_by_domain 函数:
|
||||
- 每个任务出现在且仅出现在其 domain 对应的分组中
|
||||
- 分组中的任务总数等于全部任务数(不多不少)
|
||||
- 每个分组的 key 等于该分组内所有任务的 domain
|
||||
2. 通过 API 端点测试(TestClient + mock auth):
|
||||
- 返回的 groups 中每个任务的 domain 与其所在分组 key 一致
|
||||
- 所有任务都出现在结果中
|
||||
3. 随机子集验证:
|
||||
- 随机选取任务子集,验证分组逻辑的一致性
|
||||
- 随机选取 domain,验证该 domain 下的任务都正确
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-registry")
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.services.task_registry import (
|
||||
get_all_tasks,
|
||||
get_tasks_grouped_by_domain,
|
||||
TaskDefinition,
|
||||
)
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app
|
||||
from app.auth.dependencies import get_current_user, CurrentUser
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ALL_TASKS = get_all_tasks()
|
||||
ALL_CODES = [t.code for t in ALL_TASKS]
|
||||
ALL_DOMAINS = list({t.domain for t in ALL_TASKS})
|
||||
|
||||
|
||||
def _mock_user() -> CurrentUser:
|
||||
return CurrentUser(user_id=1, site_id=1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.1: 分组完整性 — 每个任务出现在且仅出现在其 domain 分组中
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_every_task_in_exactly_its_domain_group(data):
|
||||
"""Property 4.1: 每个任务出现在且仅出现在其所属业务域的分组中。
|
||||
|
||||
从全量任务中随机选取一个任务,验证它只出现在对应 domain 的分组里,
|
||||
且不出现在其他任何分组中。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
# 随机选取一个任务
|
||||
task = data.draw(st.sampled_from(ALL_TASKS))
|
||||
|
||||
# 该任务必须出现在其 domain 分组中
|
||||
assert task.domain in grouped, (
|
||||
f"任务 {task.code} 的 domain '{task.domain}' 不在分组 keys 中"
|
||||
)
|
||||
domain_codes = [t.code for t in grouped[task.domain]]
|
||||
assert task.code in domain_codes, (
|
||||
f"任务 {task.code} 未出现在其 domain '{task.domain}' 的分组中"
|
||||
)
|
||||
|
||||
# 该任务不应出现在其他任何分组中
|
||||
for other_domain, other_tasks in grouped.items():
|
||||
if other_domain == task.domain:
|
||||
continue
|
||||
other_codes = [t.code for t in other_tasks]
|
||||
assert task.code not in other_codes, (
|
||||
f"任务 {task.code}(domain={task.domain})错误地出现在 "
|
||||
f"domain '{other_domain}' 的分组中"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.2: 分组总数守恒 — 分组中的任务总数等于全部任务数
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_grouped_total_equals_all_tasks(data):
|
||||
"""Property 4.2: 分组中的任务总数等于全部任务数(不多不少)。
|
||||
|
||||
随机选取若干 domain 进行局部验证,同时验证全局总数守恒。
|
||||
"""
|
||||
all_tasks = get_all_tasks()
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
|
||||
# 全局守恒:分组内任务总数 == 全量任务数
|
||||
grouped_total = sum(len(tasks) for tasks in grouped.values())
|
||||
assert grouped_total == len(all_tasks), (
|
||||
f"分组总数 {grouped_total} != 全量任务数 {len(all_tasks)}"
|
||||
)
|
||||
|
||||
# 随机选取一个 domain,验证该 domain 下的任务数量正确
|
||||
domain = data.draw(st.sampled_from(ALL_DOMAINS))
|
||||
expected_count = sum(1 for t in all_tasks if t.domain == domain)
|
||||
actual_count = len(grouped[domain])
|
||||
assert actual_count == expected_count, (
|
||||
f"domain '{domain}' 分组内任务数 {actual_count} != 预期 {expected_count}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.3: 分组 key 一致性 — 每个分组的 key 等于组内所有任务的 domain
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_group_key_matches_task_domains(data):
|
||||
"""Property 4.3: 每个分组的 key 等于该分组内所有任务的 domain。
|
||||
|
||||
随机选取一个 domain 分组,验证组内每个任务的 domain 字段都等于分组 key。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
domain = data.draw(st.sampled_from(list(grouped.keys())))
|
||||
|
||||
for task in grouped[domain]:
|
||||
assert task.domain == domain, (
|
||||
f"分组 '{domain}' 中的任务 {task.code} 的 domain 为 "
|
||||
f"'{task.domain}',与分组 key 不一致"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.4: 任务 code 全局唯一 — 分组后不应出现重复 code
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_no_duplicate_codes_across_groups(data):
|
||||
"""Property 4.4: 分组后所有任务的 code 全局唯一,无重复。
|
||||
|
||||
随机选取若干 domain 的任务合并,验证 code 不重复。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
|
||||
# 收集所有分组中的 code
|
||||
all_codes_in_groups = []
|
||||
for tasks in grouped.values():
|
||||
all_codes_in_groups.extend(t.code for t in tasks)
|
||||
|
||||
assert len(all_codes_in_groups) == len(set(all_codes_in_groups)), (
|
||||
"分组中存在重复的任务 code"
|
||||
)
|
||||
|
||||
# 随机选取两个不同 domain,验证它们的任务 code 无交集
|
||||
if len(ALL_DOMAINS) >= 2:
|
||||
domains = data.draw(
|
||||
st.lists(st.sampled_from(ALL_DOMAINS), min_size=2, max_size=2, unique=True)
|
||||
)
|
||||
codes_a = {t.code for t in grouped[domains[0]]}
|
||||
codes_b = {t.code for t in grouped[domains[1]]}
|
||||
overlap = codes_a & codes_b
|
||||
assert not overlap, (
|
||||
f"domain '{domains[0]}' 和 '{domains[1]}' 存在重叠任务 code: {overlap}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.5: 随机子集分组一致性 — 子集中的任务分组结果与全量一致
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
indices=st.lists(
|
||||
st.integers(min_value=0, max_value=len(ALL_TASKS) - 1),
|
||||
min_size=1,
|
||||
max_size=min(20, len(ALL_TASKS)),
|
||||
unique=True,
|
||||
)
|
||||
)
|
||||
def test_subset_grouping_consistency(indices):
|
||||
"""Property 4.5: 随机选取任务子集,验证每个任务在全量分组中的归属正确。
|
||||
|
||||
对于随机选取的任务子集,每个任务在 get_tasks_grouped_by_domain() 的结果中
|
||||
都应出现在其 domain 对应的分组里。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
subset = [ALL_TASKS[i] for i in indices]
|
||||
|
||||
for task in subset:
|
||||
# 任务的 domain 必须是分组的 key 之一
|
||||
assert task.domain in grouped
|
||||
# 任务必须在对应分组中
|
||||
group_codes = {t.code for t in grouped[task.domain]}
|
||||
assert task.code in group_codes, (
|
||||
f"任务 {task.code} 未出现在 domain '{task.domain}' 的分组中"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.6: API 端点分组正确性 — GET /api/tasks/registry 返回一致的分组
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_api_registry_grouping_correctness(data):
|
||||
"""Property 4.6: API 端点返回的分组中,每个任务的 domain 与分组 key 一致,
|
||||
且所有任务都出现在结果中。
|
||||
"""
|
||||
app.dependency_overrides[get_current_user] = _mock_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/tasks/registry")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
groups = body["groups"]
|
||||
|
||||
# 收集 API 返回的所有任务 code
|
||||
api_codes: set[str] = set()
|
||||
for domain_key, task_list in groups.items():
|
||||
for task_item in task_list:
|
||||
# 每个任务的 domain 必须等于分组 key
|
||||
assert task_item["domain"] == domain_key, (
|
||||
f"API 返回的任务 {task_item['code']}(domain={task_item['domain']})"
|
||||
f"出现在分组 '{domain_key}' 中,不一致"
|
||||
)
|
||||
api_codes.add(task_item["code"])
|
||||
|
||||
# 所有任务都应出现在 API 结果中
|
||||
all_codes_set = {t.code for t in get_all_tasks()}
|
||||
assert api_codes == all_codes_set, (
|
||||
f"API 返回的任务集合与全量任务不一致。"
|
||||
f"缺失: {all_codes_set - api_codes},"
|
||||
f"多余: {api_codes - all_codes_set}"
|
||||
)
|
||||
|
||||
# 随机选取一个 domain,验证该 domain 下的任务数量与服务层一致
|
||||
if groups:
|
||||
domain = data.draw(st.sampled_from(list(groups.keys())))
|
||||
expected = get_tasks_grouped_by_domain()
|
||||
assert len(groups[domain]) == len(expected[domain]), (
|
||||
f"API 返回的 domain '{domain}' 任务数 {len(groups[domain])} "
|
||||
f"!= 服务层 {len(expected[domain])}"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.7: 随机 domain 过滤验证
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(domain=st.sampled_from(ALL_DOMAINS))
|
||||
def test_random_domain_tasks_all_correct(domain):
|
||||
"""Property 4.7: 随机选取一个 domain,验证该 domain 下的所有任务都正确归属。
|
||||
|
||||
对于选定的 domain:
|
||||
- 分组中的每个任务的 domain 字段都等于选定的 domain
|
||||
- 全量任务中所有属于该 domain 的任务都出现在分组中
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
all_tasks = get_all_tasks()
|
||||
|
||||
# 分组中该 domain 的任务
|
||||
group_tasks = grouped.get(domain, [])
|
||||
|
||||
# 全量任务中属于该 domain 的任务
|
||||
expected_tasks = [t for t in all_tasks if t.domain == domain]
|
||||
|
||||
# 数量一致
|
||||
assert len(group_tasks) == len(expected_tasks), (
|
||||
f"domain '{domain}': 分组中 {len(group_tasks)} 个任务,"
|
||||
f"预期 {len(expected_tasks)} 个"
|
||||
)
|
||||
|
||||
# code 集合一致
|
||||
group_codes = {t.code for t in group_tasks}
|
||||
expected_codes = {t.code for t in expected_tasks}
|
||||
assert group_codes == expected_codes, (
|
||||
f"domain '{domain}': 分组 codes {group_codes} != 预期 {expected_codes}"
|
||||
)
|
||||
|
||||
# 每个任务的 domain 字段都正确
|
||||
for task in group_tasks:
|
||||
assert task.domain == domain
|
||||
274
apps/backend/tests/test_tasks_router.py
Normal file
274
apps/backend/tests/test_tasks_router.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表 API 单元测试
|
||||
|
||||
覆盖 4 个端点:registry / dwd-tables / flows / validate
|
||||
通过 JWT mock 绕过认证依赖。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.services.task_registry import (
|
||||
ALL_TASKS,
|
||||
DWD_TABLES,
|
||||
FLOW_LAYER_MAP,
|
||||
get_tasks_grouped_by_domain,
|
||||
)
|
||||
|
||||
# 固定测试用户
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTaskRegistry:
|
||||
def setup_method(self):
|
||||
"""每个测试方法前重新设置 auth 覆盖,防止其他测试文件的 clear/pop 导致状态泄漏"""
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
|
||||
def test_registry_returns_grouped_tasks(self):
|
||||
resp = client.get("/api/tasks/registry")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "groups" in data
|
||||
|
||||
# 所有任务都应出现在某个分组中
|
||||
all_codes_in_response = set()
|
||||
for domain, tasks in data["groups"].items():
|
||||
for t in tasks:
|
||||
all_codes_in_response.add(t["code"])
|
||||
assert t["domain"] == domain
|
||||
|
||||
expected_codes = {t.code for t in ALL_TASKS}
|
||||
assert all_codes_in_response == expected_codes
|
||||
|
||||
def test_registry_task_fields_complete(self):
|
||||
"""每个任务项包含所有必要字段"""
|
||||
resp = client.get("/api/tasks/registry")
|
||||
data = resp.json()
|
||||
required_fields = {"code", "name", "description", "domain", "layer",
|
||||
"requires_window", "is_ods", "is_dimension", "default_enabled"}
|
||||
for tasks in data["groups"].values():
|
||||
for t in tasks:
|
||||
assert required_fields.issubset(t.keys())
|
||||
|
||||
def test_registry_requires_auth(self):
|
||||
"""未认证时返回 401"""
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
resp = client.get("/api/tasks/registry")
|
||||
assert resp.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/dwd-tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDwdTables:
|
||||
def test_dwd_tables_returns_grouped(self):
|
||||
resp = client.get("/api/tasks/dwd-tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "groups" in data
|
||||
|
||||
all_tables_in_response = set()
|
||||
for domain, tables in data["groups"].items():
|
||||
for t in tables:
|
||||
all_tables_in_response.add(t["table_name"])
|
||||
assert t["domain"] == domain
|
||||
|
||||
expected_tables = {t.table_name for t in DWD_TABLES}
|
||||
assert all_tables_in_response == expected_tables
|
||||
|
||||
def test_dwd_tables_fields_complete(self):
|
||||
resp = client.get("/api/tasks/dwd-tables")
|
||||
data = resp.json()
|
||||
required_fields = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
|
||||
for tables in data["groups"].values():
|
||||
for t in tables:
|
||||
assert required_fields.issubset(t.keys())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/flows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlows:
|
||||
def test_flows_returns_seven_flows(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["flows"]) == 7
|
||||
|
||||
def test_flows_returns_three_processing_modes(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
assert len(data["processing_modes"]) == 3
|
||||
|
||||
def test_flow_ids_match_registry(self):
|
||||
"""Flow ID 与 FLOW_LAYER_MAP 一致"""
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
flow_ids = {f["id"] for f in data["flows"]}
|
||||
assert flow_ids == set(FLOW_LAYER_MAP.keys())
|
||||
|
||||
def test_flow_layers_non_empty(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
for f in data["flows"]:
|
||||
assert len(f["layers"]) > 0
|
||||
|
||||
def test_processing_mode_ids(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
mode_ids = {m["id"] for m in data["processing_modes"]}
|
||||
assert mode_ids == {"increment_only", "verify_only", "increment_verify"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/tasks/validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidate:
|
||||
def test_validate_success(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
|
||||
"pipeline": "api_ods",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is True
|
||||
assert data["errors"] == []
|
||||
assert len(data["command_args"]) > 0
|
||||
assert "--store-id" in data["command"]
|
||||
# store_id 应从 JWT 注入(测试用户 site_id=100)
|
||||
assert "100" in data["command"]
|
||||
|
||||
def test_validate_injects_store_id(self):
|
||||
"""即使前端传了 store_id,后端也用 JWT 中的值覆盖"""
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["DWD_LOAD_FROM_ODS"],
|
||||
"pipeline": "ods_dwd",
|
||||
"store_id": 999,
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# 命令中应包含 JWT 的 site_id=100,而非前端传的 999
|
||||
assert "--store-id" in data["command"]
|
||||
idx = data["command_args"].index("--store-id")
|
||||
assert data["command_args"][idx + 1] == "100"
|
||||
|
||||
def test_validate_invalid_flow(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "nonexistent_flow",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is False
|
||||
assert any("无效的执行流程" in e for e in data["errors"])
|
||||
|
||||
def test_validate_empty_tasks(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": [],
|
||||
"pipeline": "api_ods",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is False
|
||||
assert any("任务列表不能为空" in e for e in data["errors"])
|
||||
|
||||
def test_validate_custom_window(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
"window_mode": "custom",
|
||||
"window_start": "2024-01-01",
|
||||
"window_end": "2024-01-31",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is True
|
||||
assert "--window-start" in data["command"]
|
||||
assert "--window-end" in data["command"]
|
||||
|
||||
def test_validate_window_end_before_start_rejected(self):
|
||||
"""window_end 早于 window_start 时 Pydantic 验证失败 → 422"""
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
"window_mode": "custom",
|
||||
"window_start": "2024-12-31",
|
||||
"window_end": "2024-01-01",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_validate_dry_run_flag(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
"dry_run": True,
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "--dry-run" in data["command"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# task_registry 服务层测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTaskRegistryService:
|
||||
def test_all_tasks_have_unique_codes(self):
|
||||
codes = [t.code for t in ALL_TASKS]
|
||||
assert len(codes) == len(set(codes))
|
||||
|
||||
def test_grouped_tasks_cover_all(self):
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
all_codes = set()
|
||||
for tasks in grouped.values():
|
||||
for t in tasks:
|
||||
all_codes.add(t.code)
|
||||
assert all_codes == {t.code for t in ALL_TASKS}
|
||||
|
||||
def test_ods_tasks_marked_is_ods(self):
|
||||
for t in ALL_TASKS:
|
||||
if t.layer == "ODS":
|
||||
assert t.is_ods is True
|
||||
|
||||
def test_flow_layer_map_covers_all_flows(self):
|
||||
expected_flows = {"api_ods", "api_ods_dwd", "api_full", "ods_dwd",
|
||||
"dwd_dws", "dwd_dws_index", "dwd_index"}
|
||||
assert set(FLOW_LAYER_MAP.keys()) == expected_flows
|
||||
186
apps/backend/tests/test_ws_logs.py
Normal file
186
apps/backend/tests/test_ws_logs.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""WebSocket 日志推送端点测试
|
||||
|
||||
测试 /ws/logs/{execution_id} 端点的连接、日志回放、实时推送和断开行为。
|
||||
利用 TaskExecutor 已有的 subscribe/broadcast 机制进行验证。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
from app.main import app
|
||||
from app.services.task_executor import task_executor
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_executor():
|
||||
"""每个测试前后清理 TaskExecutor 内部状态。"""
|
||||
yield
|
||||
# 清理所有残留的缓冲区和订阅者
|
||||
for eid in list(task_executor._log_buffers.keys()):
|
||||
task_executor.cleanup(eid)
|
||||
task_executor._subscribers.clear()
|
||||
task_executor._log_buffers.clear()
|
||||
|
||||
|
||||
class TestWebSocketConnection:
|
||||
"""WebSocket 连接/断开基本行为"""
|
||||
|
||||
def test_connect_and_disconnect(self):
|
||||
"""客户端能成功建立和关闭 WebSocket 连接。"""
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws/logs/test-exec-001") as ws:
|
||||
# 连接成功,直接关闭
|
||||
pass # __exit__ 会关闭连接
|
||||
|
||||
def test_connect_registers_subscriber(self):
|
||||
"""连接后 TaskExecutor 应注册订阅者。"""
|
||||
client = TestClient(app)
|
||||
# 预先初始化缓冲区(模拟有任务在运行)
|
||||
task_executor._log_buffers["test-exec-002"] = []
|
||||
|
||||
with client.websocket_connect("/ws/logs/test-exec-002"):
|
||||
# 连接期间应有订阅者
|
||||
assert "test-exec-002" in task_executor._subscribers
|
||||
assert len(task_executor._subscribers["test-exec-002"]) >= 1
|
||||
|
||||
|
||||
class TestLogReplay:
|
||||
"""历史日志回放"""
|
||||
|
||||
def test_replay_existing_logs(self):
|
||||
"""连接时应先收到内存缓冲区中已有的日志行。"""
|
||||
eid = "test-exec-replay"
|
||||
# 预填充日志缓冲区
|
||||
task_executor._log_buffers[eid] = [
|
||||
"[stdout] 第一行",
|
||||
"[stdout] 第二行",
|
||||
"[stderr] 警告信息",
|
||||
]
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 应按顺序收到 3 条历史日志
|
||||
msg1 = ws.receive_text()
|
||||
msg2 = ws.receive_text()
|
||||
msg3 = ws.receive_text()
|
||||
|
||||
assert msg1 == "[stdout] 第一行"
|
||||
assert msg2 == "[stdout] 第二行"
|
||||
assert msg3 == "[stderr] 警告信息"
|
||||
|
||||
def test_no_logs_no_replay(self):
|
||||
"""没有历史日志时不应收到回放消息。"""
|
||||
eid = "test-exec-empty"
|
||||
task_executor._log_buffers[eid] = []
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 发送结束信号让连接正常关闭
|
||||
task_executor._broadcast_end(eid)
|
||||
# 不应有回放消息,直接收到的是后续的结束信号处理
|
||||
|
||||
|
||||
class TestBroadcastReceive:
|
||||
"""实时日志广播接收"""
|
||||
|
||||
def test_receive_broadcast_messages(self):
|
||||
"""连接后应能收到 TaskExecutor 广播的实时日志。"""
|
||||
eid = "test-exec-broadcast"
|
||||
task_executor._log_buffers[eid] = []
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 模拟 TaskExecutor 广播日志
|
||||
task_executor._broadcast(eid, "[stdout] 实时日志行1")
|
||||
task_executor._broadcast(eid, "[stderr] 实时错误行")
|
||||
|
||||
msg1 = ws.receive_text()
|
||||
msg2 = ws.receive_text()
|
||||
|
||||
assert msg1 == "[stdout] 实时日志行1"
|
||||
assert msg2 == "[stderr] 实时错误行"
|
||||
|
||||
def test_end_signal_closes_connection(self):
|
||||
"""收到 None 结束信号后 WebSocket 应正常关闭。"""
|
||||
eid = "test-exec-end"
|
||||
task_executor._log_buffers[eid] = []
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 广播一条日志后发送结束信号
|
||||
task_executor._broadcast(eid, "[stdout] 最后一行")
|
||||
task_executor._broadcast_end(eid)
|
||||
|
||||
msg = ws.receive_text()
|
||||
assert msg == "[stdout] 最后一行"
|
||||
|
||||
def test_replay_then_broadcast(self):
|
||||
"""先回放历史日志,再接收实时广播。"""
|
||||
eid = "test-exec-mixed"
|
||||
task_executor._log_buffers[eid] = ["[stdout] 历史行"]
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 先收到历史回放
|
||||
replay = ws.receive_text()
|
||||
assert replay == "[stdout] 历史行"
|
||||
|
||||
# 再收到实时广播
|
||||
task_executor._broadcast(eid, "[stdout] 新行")
|
||||
task_executor._broadcast_end(eid)
|
||||
|
||||
live = ws.receive_text()
|
||||
assert live == "[stdout] 新行"
|
||||
|
||||
|
||||
class TestLogBroadcasterUnit:
|
||||
"""直接测试 TaskExecutor 的 subscribe/unsubscribe/broadcast 方法
|
||||
(作为 LogBroadcaster 功能的单元测试)。
|
||||
"""
|
||||
|
||||
def test_subscribe_creates_queue(self):
|
||||
eid = "unit-sub"
|
||||
q = task_executor.subscribe(eid)
|
||||
assert isinstance(q, asyncio.Queue)
|
||||
assert eid in task_executor._subscribers
|
||||
task_executor.unsubscribe(eid, q)
|
||||
|
||||
def test_unsubscribe_removes_queue(self):
|
||||
eid = "unit-unsub"
|
||||
q = task_executor.subscribe(eid)
|
||||
task_executor.unsubscribe(eid, q)
|
||||
# 最后一个订阅者移除后,key 也应被清理
|
||||
assert eid not in task_executor._subscribers
|
||||
|
||||
def test_broadcast_delivers_to_all_subscribers(self):
|
||||
eid = "unit-multi"
|
||||
q1 = task_executor.subscribe(eid)
|
||||
q2 = task_executor.subscribe(eid)
|
||||
|
||||
task_executor._broadcast(eid, "测试消息")
|
||||
|
||||
assert q1.get_nowait() == "测试消息"
|
||||
assert q2.get_nowait() == "测试消息"
|
||||
|
||||
task_executor.unsubscribe(eid, q1)
|
||||
task_executor.unsubscribe(eid, q2)
|
||||
|
||||
def test_broadcast_end_sends_none(self):
|
||||
eid = "unit-end"
|
||||
q = task_executor.subscribe(eid)
|
||||
|
||||
task_executor._broadcast_end(eid)
|
||||
|
||||
assert q.get_nowait() is None
|
||||
task_executor.unsubscribe(eid, q)
|
||||
|
||||
def test_broadcast_no_subscribers_is_safe(self):
|
||||
"""没有订阅者时广播不应报错。"""
|
||||
task_executor._broadcast("nonexistent", "无人接收")
|
||||
task_executor._broadcast_end("nonexistent")
|
||||
Reference in New Issue
Block a user