在准备环境前提交次全部更改。
This commit is contained in:
@@ -2,16 +2,16 @@
|
||||
|
||||
## 作用说明
|
||||
|
||||
应用项目顶层目录,存放所有可独立部署/运行的子项目。当前包含 ETL 数据管线、FastAPI 后端、微信小程序前端,以及预留的管理后台。
|
||||
应用项目顶层目录,存放所有可独立部署/运行的子项目。当前包含 ETL Connector、FastAPI 后端、微信小程序前端,以及预留的管理后台。
|
||||
|
||||
## 内部结构
|
||||
|
||||
- `etl/pipelines/feiqiu/` — 飞球平台 ETL 管线(抽取→清洗→汇总全流程)
|
||||
- `etl/pipelines/feiqiu/` — 飞球 Connector(数据源连接器,抽取→清洗→汇总全流程)
|
||||
- `backend/` — FastAPI 后端(小程序 API、权限、审批)
|
||||
- `miniprogram/` — 微信小程序前端(Donut + TDesign)
|
||||
- `admin-web/` — 管理后台(预留,暂未实施)
|
||||
|
||||
## Roadmap
|
||||
|
||||
- 新增更多数据源管线时,在 `etl/pipelines/` 下按平台名创建子目录
|
||||
- 新增更多 Connector 时,在 `etl/pipelines/` 下按平台名创建子目录
|
||||
- `admin-web/` 待产品需求确认后启动
|
||||
|
||||
13
apps/admin-web/index.html
Normal file
13
apps/admin-web/index.html
Normal file
@@ -0,0 +1,13 @@
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>NeoZQYY 管理后台</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
35
apps/admin-web/package.json
Normal file
35
apps/admin-web/package.json
Normal file
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"name": "admin-web",
|
||||
"private": true,
|
||||
"version": "0.1.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "tsc -b && vite build",
|
||||
"preview": "vite preview",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"lint": "tsc --noEmit"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ant-design/icons": "^5.6.1",
|
||||
"antd": "^5.24.7",
|
||||
"axios": "^1.9.0",
|
||||
"dayjs": "^1.11.19",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-router-dom": "^7.6.1",
|
||||
"zustand": "^5.0.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@testing-library/jest-dom": "^6.6.3",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
"@types/react": "^19.1.4",
|
||||
"@types/react-dom": "^19.1.5",
|
||||
"@vitejs/plugin-react": "^4.5.2",
|
||||
"jsdom": "^26.1.0",
|
||||
"typescript": "~5.8.3",
|
||||
"vite": "^6.3.5",
|
||||
"vitest": "^3.1.4"
|
||||
}
|
||||
}
|
||||
2851
apps/admin-web/pnpm-lock.yaml
generated
Normal file
2851
apps/admin-web/pnpm-lock.yaml
generated
Normal file
File diff suppressed because it is too large
Load Diff
196
apps/admin-web/src/App.tsx
Normal file
196
apps/admin-web/src/App.tsx
Normal file
@@ -0,0 +1,196 @@
|
||||
/**
|
||||
* 主布局与路由配置。
|
||||
*
|
||||
* - Ant Design Layout:Sider + Content + Footer(状态栏)
|
||||
* - react-router-dom:6 个功能页面路由 + 登录页路由
|
||||
* - 路由守卫:未登录重定向到登录页
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useCallback } from "react";
|
||||
import { Routes, Route, Navigate, useNavigate, useLocation } from "react-router-dom";
|
||||
import { Layout, Menu, Spin, Space, Typography, Tag, Button, Tooltip } from "antd";
|
||||
import {
|
||||
SettingOutlined,
|
||||
UnorderedListOutlined,
|
||||
ToolOutlined,
|
||||
DatabaseOutlined,
|
||||
DashboardOutlined,
|
||||
FileTextOutlined,
|
||||
LogoutOutlined,
|
||||
} from "@ant-design/icons";
|
||||
import type { MenuProps } from "antd";
|
||||
import { useAuthStore } from "./store/authStore";
|
||||
import { fetchQueue } from "./api/execution";
|
||||
import type { QueuedTask } from "./types";
|
||||
import Login from "./pages/Login";
|
||||
import TaskConfig from "./pages/TaskConfig";
|
||||
import TaskManager from "./pages/TaskManager";
|
||||
import EnvConfig from "./pages/EnvConfig";
|
||||
import DBViewer from "./pages/DBViewer";
|
||||
import ETLStatus from "./pages/ETLStatus";
|
||||
import LogViewer from "./pages/LogViewer";
|
||||
|
||||
const { Sider, Content, Footer } = Layout;
|
||||
const { Text } = Typography;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 侧边栏导航配置 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const NAV_ITEMS: MenuProps["items"] = [
|
||||
{ key: "/", icon: <SettingOutlined />, label: "任务配置" },
|
||||
{ key: "/task-manager", icon: <UnorderedListOutlined />, label: "任务管理" },
|
||||
{ key: "/etl-status", icon: <DashboardOutlined />, label: "ETL 状态" },
|
||||
{ key: "/db-viewer", icon: <DatabaseOutlined />, label: "数据库" },
|
||||
{ key: "/log-viewer", icon: <FileTextOutlined />, label: "日志" },
|
||||
{ key: "/env-config", icon: <ToolOutlined />, label: "环境配置" },
|
||||
];
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 路由守卫 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const PrivateRoute: React.FC<{ children: React.ReactNode }> = ({ children }) => {
|
||||
const isAuthenticated = useAuthStore((s) => s.isAuthenticated);
|
||||
return isAuthenticated ? <>{children}</> : <Navigate to="/login" replace />;
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 主布局 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const AppLayout: React.FC = () => {
|
||||
const navigate = useNavigate();
|
||||
const location = useLocation();
|
||||
const logout = useAuthStore((s) => s.logout);
|
||||
|
||||
const [runningTask, setRunningTask] = useState<QueuedTask | null>(null);
|
||||
|
||||
const pollQueue = useCallback(async () => {
|
||||
try {
|
||||
const queue = await fetchQueue();
|
||||
const running = queue.find((t) => t.status === "running") ?? null;
|
||||
setRunningTask(running);
|
||||
} catch {
|
||||
// 网络异常时不更新状态
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
pollQueue();
|
||||
const timer = setInterval(pollQueue, 5_000);
|
||||
return () => clearInterval(timer);
|
||||
}, [pollQueue]);
|
||||
|
||||
const onMenuClick: MenuProps["onClick"] = ({ key }) => { navigate(key); };
|
||||
|
||||
const handleLogout = () => {
|
||||
logout();
|
||||
navigate("/login", { replace: true });
|
||||
};
|
||||
|
||||
return (
|
||||
<Layout style={{ minHeight: "100vh" }}>
|
||||
<Sider
|
||||
collapsible
|
||||
style={{ display: "flex", flexDirection: "column" }}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
height: 48,
|
||||
margin: "12px 16px",
|
||||
color: "#fff",
|
||||
fontWeight: 700,
|
||||
fontSize: 18,
|
||||
textAlign: "center",
|
||||
lineHeight: "48px",
|
||||
whiteSpace: "nowrap",
|
||||
overflow: "hidden",
|
||||
letterSpacing: 1,
|
||||
}}
|
||||
>
|
||||
NeoZQYY
|
||||
</div>
|
||||
<Menu
|
||||
theme="dark"
|
||||
mode="inline"
|
||||
selectedKeys={[location.pathname]}
|
||||
items={NAV_ITEMS}
|
||||
onClick={onMenuClick}
|
||||
/>
|
||||
<div style={{ flex: 1 }} />
|
||||
<div style={{ padding: "12px 16px" }}>
|
||||
<Tooltip title="退出登录">
|
||||
<Button
|
||||
type="text" icon={<LogoutOutlined />}
|
||||
onClick={handleLogout}
|
||||
style={{ color: "rgba(255,255,255,0.65)", width: "100%" }}
|
||||
>
|
||||
退出
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</Sider>
|
||||
<Layout>
|
||||
<Content style={{ margin: 16, minHeight: 280 }}>
|
||||
<Routes>
|
||||
<Route path="/" element={<TaskConfig />} />
|
||||
<Route path="/task-manager" element={<TaskManager />} />
|
||||
<Route path="/env-config" element={<EnvConfig />} />
|
||||
<Route path="/db-viewer" element={<DBViewer />} />
|
||||
<Route path="/etl-status" element={<ETLStatus />} />
|
||||
<Route path="/log-viewer" element={<LogViewer />} />
|
||||
</Routes>
|
||||
</Content>
|
||||
<Footer
|
||||
style={{
|
||||
textAlign: "center",
|
||||
padding: "6px 16px",
|
||||
background: "#fafafa",
|
||||
borderTop: "1px solid #f0f0f0",
|
||||
}}
|
||||
>
|
||||
{runningTask ? (
|
||||
<Space size={8}>
|
||||
<Spin size="small" />
|
||||
<Text>执行中</Text>
|
||||
<Tag color="processing">{runningTask.config.pipeline}</Tag>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||
{runningTask.config.tasks.slice(0, 3).join(", ")}
|
||||
{runningTask.config.tasks.length > 3 && ` +${runningTask.config.tasks.length - 3}`}
|
||||
</Text>
|
||||
</Space>
|
||||
) : (
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>无任务执行中</Text>
|
||||
)}
|
||||
</Footer>
|
||||
</Layout>
|
||||
</Layout>
|
||||
);
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 根组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const App: React.FC = () => {
|
||||
const hydrate = useAuthStore((s) => s.hydrate);
|
||||
|
||||
useEffect(() => { hydrate(); }, [hydrate]);
|
||||
|
||||
return (
|
||||
<Routes>
|
||||
<Route path="/login" element={<Login />} />
|
||||
<Route
|
||||
path="/*"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<AppLayout />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
</Routes>
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
||||
125
apps/admin-web/src/__tests__/flowLayers.test.ts
Normal file
125
apps/admin-web/src/__tests__/flowLayers.test.ts
Normal file
@@ -0,0 +1,125 @@
|
||||
/**
|
||||
* Flow 层级与任务兼容性测试
|
||||
*
|
||||
* **Validates: Requirements 2.2**
|
||||
*
|
||||
* Property 21: 对任意 Flow 类型和任务定义,当 Flow 包含的层不包含该任务所属层时,
|
||||
* 该任务不应出现在可选列表中;当 Flow 包含该任务所属层时,该任务应出现在可选列表中。
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { getFlowLayers } from "../pages/TaskConfig";
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 预期的 Flow 定义(来自设计文档) */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const EXPECTED_FLOWS: Record<string, string[]> = {
|
||||
api_ods: ["ODS"],
|
||||
api_ods_dwd: ["ODS", "DWD"],
|
||||
api_full: ["ODS", "DWD", "DWS", "INDEX"],
|
||||
ods_dwd: ["DWD"],
|
||||
dwd_dws: ["DWS"],
|
||||
dwd_dws_index: ["DWS", "INDEX"],
|
||||
dwd_index: ["INDEX"],
|
||||
};
|
||||
|
||||
describe("getFlowLayers — Flow 层级与任务兼容性", () => {
|
||||
/* ---- 1. 每个已知 Flow 返回正确的层列表 ---- */
|
||||
it.each(Object.entries(EXPECTED_FLOWS))(
|
||||
"Flow '%s' 应返回 %j",
|
||||
(flowId, expectedLayers) => {
|
||||
expect(getFlowLayers(flowId)).toEqual(expectedLayers);
|
||||
},
|
||||
);
|
||||
|
||||
/* ---- 2. 未知 Flow ID 返回空数组 ---- */
|
||||
it("未知 Flow ID 应返回空数组", () => {
|
||||
expect(getFlowLayers("unknown_flow")).toEqual([]);
|
||||
expect(getFlowLayers("")).toEqual([]);
|
||||
expect(getFlowLayers("API_FULL")).toEqual([]); // 大小写敏感
|
||||
});
|
||||
|
||||
/* ---- 3. 所有 7 种 Flow 都有定义 ---- */
|
||||
it("应定义全部 7 种 Flow", () => {
|
||||
const allFlowIds = Object.keys(EXPECTED_FLOWS);
|
||||
expect(allFlowIds).toHaveLength(7);
|
||||
for (const flowId of allFlowIds) {
|
||||
expect(getFlowLayers(flowId).length).toBeGreaterThan(0);
|
||||
}
|
||||
});
|
||||
|
||||
/* ---- 4. 层级互斥性验证 ---- */
|
||||
describe("层级互斥性", () => {
|
||||
it("api_ods 不包含 DWD / DWS / INDEX", () => {
|
||||
const layers = getFlowLayers("api_ods");
|
||||
expect(layers).not.toContain("DWD");
|
||||
expect(layers).not.toContain("DWS");
|
||||
expect(layers).not.toContain("INDEX");
|
||||
});
|
||||
|
||||
it("ods_dwd 只包含 DWD,不包含 ODS / DWS / INDEX", () => {
|
||||
const layers = getFlowLayers("ods_dwd");
|
||||
expect(layers).not.toContain("ODS");
|
||||
expect(layers).not.toContain("DWS");
|
||||
expect(layers).not.toContain("INDEX");
|
||||
});
|
||||
|
||||
it("dwd_dws 只包含 DWS,不包含 ODS / DWD / INDEX", () => {
|
||||
const layers = getFlowLayers("dwd_dws");
|
||||
expect(layers).not.toContain("ODS");
|
||||
expect(layers).not.toContain("DWD");
|
||||
expect(layers).not.toContain("INDEX");
|
||||
});
|
||||
|
||||
it("dwd_index 只包含 INDEX,不包含 ODS / DWD / DWS", () => {
|
||||
const layers = getFlowLayers("dwd_index");
|
||||
expect(layers).not.toContain("ODS");
|
||||
expect(layers).not.toContain("DWD");
|
||||
expect(layers).not.toContain("DWS");
|
||||
});
|
||||
});
|
||||
|
||||
/* ---- 5. 任务兼容性:模拟任务按层过滤 ---- */
|
||||
describe("任务兼容性过滤", () => {
|
||||
// 模拟任务定义
|
||||
const mockTasks = [
|
||||
{ code: "FETCH_ORDERS", layer: "ODS" },
|
||||
{ code: "LOAD_DWD_ORDERS", layer: "DWD" },
|
||||
{ code: "AGG_DAILY_REVENUE", layer: "DWS" },
|
||||
{ code: "CALC_WBI_INDEX", layer: "INDEX" },
|
||||
];
|
||||
|
||||
/**
|
||||
* 根据 Flow 包含的层过滤任务(与 TaskSelector 组件逻辑一致)
|
||||
*/
|
||||
function filterTasksByFlow(flowId: string) {
|
||||
const layers = getFlowLayers(flowId);
|
||||
return mockTasks.filter((t) => layers.includes(t.layer));
|
||||
}
|
||||
|
||||
it("api_ods 只显示 ODS 任务", () => {
|
||||
const visible = filterTasksByFlow("api_ods");
|
||||
expect(visible.map((t) => t.code)).toEqual(["FETCH_ORDERS"]);
|
||||
});
|
||||
|
||||
it("api_full 显示所有层的任务", () => {
|
||||
const visible = filterTasksByFlow("api_full");
|
||||
expect(visible).toHaveLength(4);
|
||||
});
|
||||
|
||||
it("dwd_dws_index 显示 DWS 和 INDEX 任务", () => {
|
||||
const visible = filterTasksByFlow("dwd_dws_index");
|
||||
const codes = visible.map((t) => t.code);
|
||||
expect(codes).toContain("AGG_DAILY_REVENUE");
|
||||
expect(codes).toContain("CALC_WBI_INDEX");
|
||||
expect(codes).not.toContain("FETCH_ORDERS");
|
||||
expect(codes).not.toContain("LOAD_DWD_ORDERS");
|
||||
});
|
||||
|
||||
it("未知 Flow 不显示任何任务", () => {
|
||||
const visible = filterTasksByFlow("nonexistent");
|
||||
expect(visible).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
169
apps/admin-web/src/__tests__/logFilter.test.ts
Normal file
169
apps/admin-web/src/__tests__/logFilter.test.ts
Normal file
@@ -0,0 +1,169 @@
|
||||
/**
|
||||
* 日志过滤正确性测试
|
||||
*
|
||||
* **Validates: Requirements 9.2**
|
||||
*
|
||||
* Property 19: 对任意日志行集合和过滤关键词,过滤后的结果应只包含
|
||||
* 含有该关键词的日志行,且不遗漏任何匹配行。
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { filterLogLines } from "../pages/LogViewer";
|
||||
|
||||
describe("filterLogLines — 日志过滤正确性", () => {
|
||||
/* ---- 1. 空关键词返回所有行 ---- */
|
||||
it("空关键词返回所有行", () => {
|
||||
const lines = ["INFO 启动", "ERROR 失败", "DEBUG 调试"];
|
||||
expect(filterLogLines(lines, "")).toEqual(lines);
|
||||
});
|
||||
|
||||
/* ---- 2. 空格关键词返回所有行 ---- */
|
||||
it("空格关键词返回所有行", () => {
|
||||
const lines = ["行1", "行2", "行3"];
|
||||
expect(filterLogLines(lines, " ")).toEqual(lines);
|
||||
expect(filterLogLines(lines, "\t")).toEqual(lines);
|
||||
});
|
||||
|
||||
/* ---- 3. 匹配的行被保留 ---- */
|
||||
it("匹配的行被保留", () => {
|
||||
const lines = ["INFO 启动成功", "ERROR 连接失败", "INFO 处理完成"];
|
||||
expect(filterLogLines(lines, "INFO")).toEqual([
|
||||
"INFO 启动成功",
|
||||
"INFO 处理完成",
|
||||
]);
|
||||
});
|
||||
|
||||
/* ---- 4. 不匹配的行被过滤掉 ---- */
|
||||
it("不匹配的行被过滤掉", () => {
|
||||
const lines = ["INFO ok", "ERROR fail", "WARN slow"];
|
||||
const result = filterLogLines(lines, "ERROR");
|
||||
expect(result).not.toContain("INFO ok");
|
||||
expect(result).not.toContain("WARN slow");
|
||||
});
|
||||
|
||||
/* ---- 5. 大小写不敏感匹配 ---- */
|
||||
it("大小写不敏感匹配", () => {
|
||||
const lines = ["Error occurred", "error found", "ERROR critical"];
|
||||
const result = filterLogLines(lines, "error");
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result).toEqual(lines);
|
||||
});
|
||||
|
||||
/* ---- 6. 空行数组返回空数组 ---- */
|
||||
it("空行数组返回空数组", () => {
|
||||
expect(filterLogLines([], "anything")).toEqual([]);
|
||||
});
|
||||
|
||||
/* ---- 7. 所有行都匹配时返回全部 ---- */
|
||||
it("所有行都匹配时返回全部", () => {
|
||||
const lines = ["log: a", "log: b", "log: c"];
|
||||
expect(filterLogLines(lines, "log")).toEqual(lines);
|
||||
});
|
||||
|
||||
/* ---- 8. 没有行匹配时返回空数组 ---- */
|
||||
it("没有行匹配时返回空数组", () => {
|
||||
const lines = ["hello", "world", "foo"];
|
||||
expect(filterLogLines(lines, "zzz")).toEqual([]);
|
||||
});
|
||||
|
||||
/* ---- 9. 关键词在行首/行中/行尾都能匹配 ---- */
|
||||
describe("关键词位置匹配", () => {
|
||||
const keyword = "target";
|
||||
|
||||
it("行首匹配", () => {
|
||||
expect(filterLogLines(["target is here"], keyword)).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("行中匹配", () => {
|
||||
expect(filterLogLines(["the target found"], keyword)).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("行尾匹配", () => {
|
||||
expect(filterLogLines(["found the target"], keyword)).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
/* ---- 10. 特殊字符关键词正常工作 ---- */
|
||||
it("特殊字符关键词正常工作", () => {
|
||||
const lines = [
|
||||
"path: /api/v1/users",
|
||||
"regex: [a-z]+",
|
||||
"price: $100.00",
|
||||
"normal line",
|
||||
];
|
||||
// 包含 '/' 的关键词
|
||||
expect(filterLogLines(lines, "/api")).toEqual(["path: /api/v1/users"]);
|
||||
// 包含 '[' 的关键词
|
||||
expect(filterLogLines(lines, "[a-z]")).toEqual(["regex: [a-z]+"]);
|
||||
// 包含 '$' 的关键词
|
||||
expect(filterLogLines(lines, "$100")).toEqual(["price: $100.00"]);
|
||||
});
|
||||
|
||||
/* ---- 11. Property: 过滤结果是原始数组的子集 ---- */
|
||||
it("过滤结果是原始数组的子集", () => {
|
||||
const lines = ["alpha", "beta", "gamma", "delta", "epsilon"];
|
||||
const keywords = ["a", "eta", "xyz", ""];
|
||||
|
||||
for (const kw of keywords) {
|
||||
const result = filterLogLines(lines, kw);
|
||||
// 结果中的每一行都必须存在于原始数组中
|
||||
for (const line of result) {
|
||||
expect(lines).toContain(line);
|
||||
}
|
||||
// 结果长度不超过原始数组
|
||||
expect(result.length).toBeLessThanOrEqual(lines.length);
|
||||
}
|
||||
});
|
||||
|
||||
/* ---- 12. Property: 过滤结果中每一行都包含关键词 ---- */
|
||||
it("过滤结果中每一行都包含关键词", () => {
|
||||
const lines = [
|
||||
"2024-01-01 INFO 启动",
|
||||
"2024-01-01 ERROR 数据库连接失败",
|
||||
"2024-01-01 WARN 内存不足",
|
||||
"2024-01-01 INFO 处理完成",
|
||||
"2024-01-01 DEBUG SQL: SELECT *",
|
||||
];
|
||||
const keywords = ["INFO", "error", "SQL", "2024", "不存在的关键词"];
|
||||
|
||||
for (const kw of keywords) {
|
||||
const result = filterLogLines(lines, kw);
|
||||
const lower = kw.toLowerCase();
|
||||
for (const line of result) {
|
||||
expect(line.toLowerCase()).toContain(lower);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
/* ---- 13. Property: 原始数组中包含关键词的行都在结果中(不遗漏) ---- */
|
||||
it("原始数组中包含关键词的行都在结果中(不遗漏)", () => {
|
||||
const lines = [
|
||||
"INFO 启动",
|
||||
"ERROR 失败",
|
||||
"INFO 完成",
|
||||
"WARN 超时",
|
||||
"INFO 关闭",
|
||||
];
|
||||
const keyword = "INFO";
|
||||
const result = filterLogLines(lines, keyword);
|
||||
const lower = keyword.toLowerCase();
|
||||
|
||||
// 手动找出所有应匹配的行
|
||||
const expected = lines.filter((l) => l.toLowerCase().includes(lower));
|
||||
expect(result).toEqual(expected);
|
||||
|
||||
// 确认没有遗漏:原始数组中每一行如果包含关键词,就必须在结果中
|
||||
for (const line of lines) {
|
||||
if (line.toLowerCase().includes(lower)) {
|
||||
expect(result).toContain(line);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
/* ---- 补充:保持原始顺序 ---- */
|
||||
it("过滤结果保持原始顺序", () => {
|
||||
const lines = ["c-match", "a-match", "b-no", "d-match"];
|
||||
const result = filterLogLines(lines, "match");
|
||||
expect(result).toEqual(["c-match", "a-match", "d-match"]);
|
||||
});
|
||||
});
|
||||
159
apps/admin-web/src/api/client.ts
Normal file
159
apps/admin-web/src/api/client.ts
Normal file
@@ -0,0 +1,159 @@
|
||||
/**
|
||||
* axios 实例 & JWT 拦截器。
|
||||
*
|
||||
* - 请求拦截器:自动从 localStorage 读取 access_token 并附加 Authorization header
|
||||
* - 响应拦截器:遇到 401 时尝试用 refresh_token 刷新,刷新失败则清除令牌并跳转 /login
|
||||
* - 并发刷新保护:多个请求同时 401 时只触发一次 refresh,其余排队等待
|
||||
*/
|
||||
|
||||
import axios, {
|
||||
type AxiosError,
|
||||
type AxiosRequestConfig,
|
||||
type InternalAxiosRequestConfig,
|
||||
} from "axios";
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 常量 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const ACCESS_TOKEN_KEY = "access_token";
|
||||
const REFRESH_TOKEN_KEY = "refresh_token";
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* axios 实例 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export const apiClient = axios.create({
|
||||
baseURL: "/api",
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 请求拦截器 — 附加 JWT */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
apiClient.interceptors.request.use((config: InternalAxiosRequestConfig) => {
|
||||
const token = localStorage.getItem(ACCESS_TOKEN_KEY);
|
||||
if (token) {
|
||||
config.headers.Authorization = `Bearer ${token}`;
|
||||
}
|
||||
return config;
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 响应拦截器 — 401 自动刷新 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
/** 是否正在刷新中 */
|
||||
let isRefreshing = false;
|
||||
|
||||
/** 等待刷新完成的请求队列 */
|
||||
let pendingQueue: {
|
||||
resolve: (token: string) => void;
|
||||
reject: (err: unknown) => void;
|
||||
}[] = [];
|
||||
|
||||
/** 刷新完成后,依次重放排队的请求 */
|
||||
function processPendingQueue(token: string | null, error: unknown) {
|
||||
pendingQueue.forEach(({ resolve, reject }) => {
|
||||
if (token) {
|
||||
resolve(token);
|
||||
} else {
|
||||
reject(error);
|
||||
}
|
||||
});
|
||||
pendingQueue = [];
|
||||
}
|
||||
|
||||
apiClient.interceptors.response.use(
|
||||
(response) => response,
|
||||
async (error: AxiosError) => {
|
||||
const originalRequest = error.config as AxiosRequestConfig & {
|
||||
_retried?: boolean;
|
||||
};
|
||||
|
||||
// 非 401、无原始请求、或已重试过 → 直接抛出
|
||||
if (
|
||||
error.response?.status !== 401 ||
|
||||
!originalRequest ||
|
||||
originalRequest._retried
|
||||
) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
// 刷新端点本身 401 → 不再递归刷新
|
||||
if (originalRequest.url === "/auth/refresh") {
|
||||
clearTokensAndRedirect();
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
// 已有刷新请求在飞 → 排队等待
|
||||
if (isRefreshing) {
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
pendingQueue.push({ resolve, reject });
|
||||
}).then((newToken) => {
|
||||
originalRequest.headers = {
|
||||
...originalRequest.headers,
|
||||
Authorization: `Bearer ${newToken}`,
|
||||
};
|
||||
originalRequest._retried = true;
|
||||
return apiClient(originalRequest);
|
||||
});
|
||||
}
|
||||
|
||||
// 发起刷新
|
||||
isRefreshing = true;
|
||||
originalRequest._retried = true;
|
||||
|
||||
const refreshToken = localStorage.getItem(REFRESH_TOKEN_KEY);
|
||||
if (!refreshToken) {
|
||||
isRefreshing = false;
|
||||
processPendingQueue(null, error);
|
||||
clearTokensAndRedirect();
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
try {
|
||||
// 用独立 axios 调用避免被自身拦截器干扰
|
||||
const { data } = await axios.post<{
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
}>("/api/auth/refresh", { refresh_token: refreshToken });
|
||||
|
||||
localStorage.setItem(ACCESS_TOKEN_KEY, data.access_token);
|
||||
localStorage.setItem(REFRESH_TOKEN_KEY, data.refresh_token);
|
||||
|
||||
processPendingQueue(data.access_token, null);
|
||||
|
||||
// 重放原始请求
|
||||
originalRequest.headers = {
|
||||
...originalRequest.headers,
|
||||
Authorization: `Bearer ${data.access_token}`,
|
||||
};
|
||||
return apiClient(originalRequest);
|
||||
} catch (refreshError) {
|
||||
processPendingQueue(null, refreshError);
|
||||
clearTokensAndRedirect();
|
||||
return Promise.reject(refreshError);
|
||||
} finally {
|
||||
isRefreshing = false;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 辅助 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function clearTokensAndRedirect() {
|
||||
localStorage.removeItem(ACCESS_TOKEN_KEY);
|
||||
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||
|
||||
// 派发自定义事件,让 authStore 监听并重置状态
|
||||
// 避免直接 import authStore 导致循环依赖
|
||||
window.dispatchEvent(new Event("auth:force-logout"));
|
||||
|
||||
// 避免在登录页反复跳转
|
||||
if (window.location.pathname !== "/login") {
|
||||
window.location.href = "/login";
|
||||
}
|
||||
}
|
||||
59
apps/admin-web/src/api/dbViewer.ts
Normal file
59
apps/admin-web/src/api/dbViewer.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
/**
|
||||
* 数据库查看器相关 API 调用。
|
||||
*
|
||||
* - fetchSchemas:获取 Schema 列表
|
||||
* - fetchTables:获取指定 Schema 下的表列表(含行数)
|
||||
* - fetchColumns:获取指定表的列定义
|
||||
* - executeQuery:执行只读 SQL 查询
|
||||
*/
|
||||
|
||||
import { apiClient } from './client';
|
||||
|
||||
/** 表信息 */
|
||||
export interface TableInfo {
|
||||
name: string;
|
||||
row_count: number;
|
||||
}
|
||||
|
||||
/** 列定义 */
|
||||
export interface ColumnInfo {
|
||||
name: string;
|
||||
data_type: string;
|
||||
is_nullable: boolean;
|
||||
default_value: string | null;
|
||||
}
|
||||
|
||||
/** 查询结果 */
|
||||
export interface QueryResult {
|
||||
columns: string[];
|
||||
rows: unknown[][];
|
||||
row_count: number;
|
||||
}
|
||||
|
||||
/** 获取所有 Schema */
|
||||
export async function fetchSchemas(): Promise<string[]> {
|
||||
const { data } = await apiClient.get<{ schemas: string[] }>('/db/schemas');
|
||||
return data.schemas;
|
||||
}
|
||||
|
||||
/** 获取指定 Schema 下的表列表 */
|
||||
export async function fetchTables(schema: string): Promise<TableInfo[]> {
|
||||
const { data } = await apiClient.get<{ tables: TableInfo[] }>(
|
||||
`/db/schemas/${encodeURIComponent(schema)}/tables`,
|
||||
);
|
||||
return data.tables;
|
||||
}
|
||||
|
||||
/** 获取指定表的列定义 */
|
||||
export async function fetchColumns(schema: string, table: string): Promise<ColumnInfo[]> {
|
||||
const { data } = await apiClient.get<{ columns: ColumnInfo[] }>(
|
||||
`/db/tables/${encodeURIComponent(schema)}/${encodeURIComponent(table)}/columns`,
|
||||
);
|
||||
return data.columns;
|
||||
}
|
||||
|
||||
/** 执行只读 SQL 查询 */
|
||||
export async function executeQuery(sql: string): Promise<QueryResult> {
|
||||
const { data } = await apiClient.post<QueryResult>('/db/query', { sql });
|
||||
return data;
|
||||
}
|
||||
44
apps/admin-web/src/api/envConfig.ts
Normal file
44
apps/admin-web/src/api/envConfig.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
/**
|
||||
* 环境配置相关 API 调用。
|
||||
*
|
||||
* - fetchEnvConfig:获取键值对列表(敏感值已掩码)
|
||||
* - updateEnvConfig:批量更新键值对
|
||||
* - exportEnvConfig:导出去敏感值的配置文件(浏览器下载)
|
||||
*/
|
||||
|
||||
import { apiClient } from './client';
|
||||
import type { EnvConfigItem } from '../types';
|
||||
|
||||
/** 获取环境配置列表 */
|
||||
export async function fetchEnvConfig(): Promise<EnvConfigItem[]> {
|
||||
const { data } = await apiClient.get<{ items: EnvConfigItem[] }>('/env-config');
|
||||
return data.items;
|
||||
}
|
||||
|
||||
/** 批量更新环境配置 */
|
||||
export async function updateEnvConfig(items: EnvConfigItem[]): Promise<void> {
|
||||
await apiClient.put('/env-config', { items });
|
||||
}
|
||||
|
||||
/** 导出配置文件(去敏感值),触发浏览器下载 */
|
||||
export async function exportEnvConfig(): Promise<void> {
|
||||
const response = await apiClient.get('/env-config/export', {
|
||||
responseType: 'blob',
|
||||
});
|
||||
// 从响应头提取文件名,回退默认值
|
||||
const disposition = response.headers['content-disposition'] as string | undefined;
|
||||
let filename = 'env-config.txt';
|
||||
if (disposition) {
|
||||
const match = disposition.match(/filename="?([^";\s]+)"?/);
|
||||
if (match) filename = match[1];
|
||||
}
|
||||
// 创建临时链接触发下载
|
||||
const url = URL.createObjectURL(response.data as Blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = filename;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
38
apps/admin-web/src/api/etlStatus.ts
Normal file
38
apps/admin-web/src/api/etlStatus.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* ETL 状态监控 API 调用。
|
||||
*
|
||||
* - fetchCursors:获取各任务的数据游标(最后抓取时间、记录数)
|
||||
* - fetchRecentRuns:获取最近执行记录
|
||||
*/
|
||||
|
||||
import { apiClient } from './client';
|
||||
|
||||
/** ETL 游标信息 */
|
||||
export interface CursorInfo {
|
||||
task_code: string;
|
||||
last_fetch_time: string | null;
|
||||
record_count: number | null;
|
||||
}
|
||||
|
||||
/** 最近执行记录 */
|
||||
export interface RecentRun {
|
||||
id: string;
|
||||
task_codes: string[];
|
||||
status: string;
|
||||
started_at: string;
|
||||
finished_at: string | null;
|
||||
duration_ms: number | null;
|
||||
exit_code: number | null;
|
||||
}
|
||||
|
||||
/** 获取各任务的数据游标 */
|
||||
export async function fetchCursors(): Promise<CursorInfo[]> {
|
||||
const { data } = await apiClient.get<CursorInfo[]>('/etl-status/cursors');
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 获取最近执行记录 */
|
||||
export async function fetchRecentRuns(): Promise<RecentRun[]> {
|
||||
const { data } = await apiClient.get<RecentRun[]>('/etl-status/recent-runs');
|
||||
return data;
|
||||
}
|
||||
47
apps/admin-web/src/api/execution.ts
Normal file
47
apps/admin-web/src/api/execution.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* 任务执行相关 API 调用。
|
||||
*
|
||||
* - submitToQueue:提交任务配置到执行队列
|
||||
* - executeDirectly:直接执行任务
|
||||
* - fetchQueue:获取当前队列
|
||||
* - fetchHistory:获取执行历史
|
||||
* - deleteFromQueue:从队列删除任务
|
||||
* - cancelExecution:取消执行中的任务
|
||||
*/
|
||||
|
||||
import { apiClient } from './client';
|
||||
import type { TaskConfig, QueuedTask, ExecutionLog } from '../types';
|
||||
|
||||
/** 提交任务配置到执行队列 */
|
||||
export async function submitToQueue(config: TaskConfig): Promise<{ id: string }> {
|
||||
const { data } = await apiClient.post<{ id: string }>('/execution/queue', config);
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 直接执行任务(不经过队列) */
|
||||
export async function executeDirectly(config: TaskConfig): Promise<{ execution_id: string }> {
|
||||
const { data } = await apiClient.post<{ execution_id: string }>('/execution/run', config);
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 获取当前任务队列 */
|
||||
export async function fetchQueue(): Promise<QueuedTask[]> {
|
||||
const { data } = await apiClient.get<QueuedTask[]>('/execution/queue');
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 获取执行历史记录 */
|
||||
export async function fetchHistory(limit = 50): Promise<ExecutionLog[]> {
|
||||
const { data } = await apiClient.get<ExecutionLog[]>('/execution/history', { params: { limit } });
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 从队列中删除待执行任务 */
|
||||
export async function deleteFromQueue(id: string): Promise<void> {
|
||||
await apiClient.delete(`/execution/queue/${id}`);
|
||||
}
|
||||
|
||||
/** 取消执行中的任务 */
|
||||
export async function cancelExecution(id: string): Promise<void> {
|
||||
await apiClient.post(`/execution/${id}/cancel`);
|
||||
}
|
||||
48
apps/admin-web/src/api/schedules.ts
Normal file
48
apps/admin-web/src/api/schedules.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
/**
|
||||
* 调度任务相关 API 调用。
|
||||
*/
|
||||
|
||||
import { apiClient } from './client';
|
||||
import type { ScheduledTask, ScheduleConfig, TaskConfig } from '../types';
|
||||
|
||||
/** 获取调度任务列表 */
|
||||
export async function fetchSchedules(): Promise<ScheduledTask[]> {
|
||||
const { data } = await apiClient.get<ScheduledTask[]>('/schedules');
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 创建调度任务 */
|
||||
export async function createSchedule(payload: {
|
||||
name: string;
|
||||
task_codes: string[];
|
||||
task_config: TaskConfig;
|
||||
schedule_config: ScheduleConfig;
|
||||
}): Promise<ScheduledTask> {
|
||||
const { data } = await apiClient.post<ScheduledTask>('/schedules', payload);
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 更新调度任务 */
|
||||
export async function updateSchedule(
|
||||
id: string,
|
||||
payload: Partial<{
|
||||
name: string;
|
||||
task_codes: string[];
|
||||
task_config: TaskConfig;
|
||||
schedule_config: ScheduleConfig;
|
||||
}>,
|
||||
): Promise<ScheduledTask> {
|
||||
const { data } = await apiClient.put<ScheduledTask>(`/schedules/${id}`, payload);
|
||||
return data;
|
||||
}
|
||||
|
||||
/** 删除调度任务 */
|
||||
export async function deleteSchedule(id: string): Promise<void> {
|
||||
await apiClient.delete(`/schedules/${id}`);
|
||||
}
|
||||
|
||||
/** 启用/禁用调度任务 */
|
||||
export async function toggleSchedule(id: string): Promise<ScheduledTask> {
|
||||
const { data } = await apiClient.patch<ScheduledTask>(`/schedules/${id}/toggle`);
|
||||
return data;
|
||||
}
|
||||
32
apps/admin-web/src/api/tasks.ts
Normal file
32
apps/admin-web/src/api/tasks.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
/**
|
||||
* 任务相关 API 调用。
|
||||
*
|
||||
* - fetchTaskRegistry:获取按业务域分组的任务注册表
|
||||
*/
|
||||
|
||||
import { apiClient } from './client';
|
||||
import type { TaskConfig, TaskDefinition } from '../types';
|
||||
|
||||
/** 获取按业务域分组的任务注册表 */
|
||||
export async function fetchTaskRegistry(): Promise<Record<string, TaskDefinition[]>> {
|
||||
// 后端返回 { groups: { 域名: [TaskItem] } },需要解包
|
||||
const { data } = await apiClient.get<{ groups: Record<string, TaskDefinition[]> }>('/tasks/registry');
|
||||
return data.groups;
|
||||
}
|
||||
|
||||
/** 获取按业务域分组的 DWD 表定义 */
|
||||
export async function fetchDwdTables(): Promise<Record<string, string[]>> {
|
||||
// 后端返回 { groups: { 域名: [DwdTableItem] } },需要解包并提取 table_name
|
||||
const { data } = await apiClient.get<{ groups: Record<string, { table_name: string }[]> }>('/tasks/dwd-tables');
|
||||
const result: Record<string, string[]> = {};
|
||||
for (const [domain, items] of Object.entries(data.groups)) {
|
||||
result[domain] = items.map((item) => item.table_name);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/** 验证任务配置并返回生成的 CLI 命令预览 */
|
||||
export async function validateTaskConfig(config: TaskConfig): Promise<{ command: string }> {
|
||||
const { data } = await apiClient.post<{ command: string }>('/tasks/validate', { config });
|
||||
return data;
|
||||
}
|
||||
187
apps/admin-web/src/components/DwdTableSelector.tsx
Normal file
187
apps/admin-web/src/components/DwdTableSelector.tsx
Normal file
@@ -0,0 +1,187 @@
|
||||
/**
|
||||
* 按业务域分组的 DWD 表选择器。
|
||||
*
|
||||
* 从 /api/tasks/dwd-tables 获取 DWD 表定义,按业务域折叠展示,
|
||||
* 支持全选/反选。仅在 Flow 包含 DWD 层时由父组件渲染。
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useMemo, useCallback } from "react";
|
||||
import {
|
||||
Collapse,
|
||||
Checkbox,
|
||||
Spin,
|
||||
Alert,
|
||||
Button,
|
||||
Space,
|
||||
Typography,
|
||||
} from "antd";
|
||||
import type { CheckboxChangeEvent } from "antd/es/checkbox";
|
||||
import { fetchDwdTables } from "../api/tasks";
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Props */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export interface DwdTableSelectorProps {
|
||||
/** 已选中的 DWD 表名列表 */
|
||||
selectedTables: string[];
|
||||
/** 选中表变化回调 */
|
||||
onTablesChange: (tables: string[]) => void;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const DwdTableSelector: React.FC<DwdTableSelectorProps> = ({
|
||||
selectedTables,
|
||||
onTablesChange,
|
||||
}) => {
|
||||
/** 按业务域分组的 DWD 表 */
|
||||
const [tableGroups, setTableGroups] = useState<Record<string, string[]>>({});
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
/* ---------- 加载 DWD 表定义 ---------- */
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
fetchDwdTables()
|
||||
.then((data) => {
|
||||
if (!cancelled) setTableGroups(data);
|
||||
})
|
||||
.catch((err) => {
|
||||
if (!cancelled) setError(err?.message ?? "获取 DWD 表列表失败");
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setLoading(false);
|
||||
});
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, []);
|
||||
|
||||
/** 所有表名的扁平列表 */
|
||||
const allTableNames = useMemo(
|
||||
() => Object.values(tableGroups).flat(),
|
||||
[tableGroups],
|
||||
);
|
||||
|
||||
/* ---------- 事件处理 ---------- */
|
||||
|
||||
/** 单个业务域的勾选变化 */
|
||||
const handleDomainChange = useCallback(
|
||||
(domain: string, checkedTables: string[]) => {
|
||||
const domainTables = new Set(tableGroups[domain] ?? []);
|
||||
const otherSelected = selectedTables.filter((t) => !domainTables.has(t));
|
||||
onTablesChange([...otherSelected, ...checkedTables]);
|
||||
},
|
||||
[selectedTables, tableGroups, onTablesChange],
|
||||
);
|
||||
|
||||
/** 全选 */
|
||||
const handleSelectAll = useCallback(() => {
|
||||
onTablesChange(allTableNames);
|
||||
}, [allTableNames, onTablesChange]);
|
||||
|
||||
/** 反选 */
|
||||
const handleInvertSelection = useCallback(() => {
|
||||
const currentSet = new Set(selectedTables);
|
||||
const inverted = allTableNames.filter((t) => !currentSet.has(t));
|
||||
onTablesChange(inverted);
|
||||
}, [allTableNames, selectedTables, onTablesChange]);
|
||||
|
||||
/* ---------- 渲染 ---------- */
|
||||
|
||||
if (loading) {
|
||||
return <Spin tip="加载 DWD 表列表…" />;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return <Alert type="error" message="加载失败" description={error} />;
|
||||
}
|
||||
|
||||
const domainEntries = Object.entries(tableGroups);
|
||||
|
||||
if (domainEntries.length === 0) {
|
||||
return <Text type="secondary">无可选 DWD 表</Text>;
|
||||
}
|
||||
|
||||
const selectedCount = selectedTables.filter((t) =>
|
||||
allTableNames.includes(t),
|
||||
).length;
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* 全选 / 反选 */}
|
||||
<Space style={{ marginBottom: 8 }}>
|
||||
<Button size="small" onClick={handleSelectAll}>
|
||||
全选
|
||||
</Button>
|
||||
<Button size="small" onClick={handleInvertSelection}>
|
||||
反选
|
||||
</Button>
|
||||
<Text type="secondary">
|
||||
已选 {selectedCount} / {allTableNames.length}
|
||||
</Text>
|
||||
</Space>
|
||||
|
||||
<Collapse
|
||||
defaultActiveKey={domainEntries.map(([d]) => d)}
|
||||
items={domainEntries.map(([domain, tables]) => {
|
||||
const domainSelected = selectedTables.filter((t) =>
|
||||
tables.includes(t),
|
||||
);
|
||||
|
||||
const allChecked = domainSelected.length === tables.length;
|
||||
const indeterminate = domainSelected.length > 0 && !allChecked;
|
||||
|
||||
const handleDomainCheckAll = (e: CheckboxChangeEvent) => {
|
||||
handleDomainChange(domain, e.target.checked ? tables : []);
|
||||
};
|
||||
|
||||
return {
|
||||
key: domain,
|
||||
label: (
|
||||
<span onClick={(e) => e.stopPropagation()}>
|
||||
<Checkbox
|
||||
indeterminate={indeterminate}
|
||||
checked={allChecked}
|
||||
onChange={handleDomainCheckAll}
|
||||
style={{ marginRight: 8 }}
|
||||
/>
|
||||
{domain}
|
||||
<Text type="secondary" style={{ marginLeft: 4 }}>
|
||||
({domainSelected.length}/{tables.length})
|
||||
</Text>
|
||||
</span>
|
||||
),
|
||||
children: (
|
||||
<Checkbox.Group
|
||||
value={domainSelected}
|
||||
onChange={(checked) =>
|
||||
handleDomainChange(domain, checked as string[])
|
||||
}
|
||||
>
|
||||
<Space direction="vertical">
|
||||
{tables.map((table) => (
|
||||
<Checkbox key={table} value={table}>
|
||||
{table}
|
||||
</Checkbox>
|
||||
))}
|
||||
</Space>
|
||||
</Checkbox.Group>
|
||||
),
|
||||
};
|
||||
})}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default DwdTableSelector;
|
||||
68
apps/admin-web/src/components/ErrorBoundary.tsx
Normal file
68
apps/admin-web/src/components/ErrorBoundary.tsx
Normal file
@@ -0,0 +1,68 @@
|
||||
/**
|
||||
* 全局错误边界 — 捕获 React 渲染异常,显示错误信息而非白屏。
|
||||
*/
|
||||
|
||||
import React from "react";
|
||||
import { Result, Button, Typography } from "antd";
|
||||
|
||||
const { Paragraph, Text } = Typography;
|
||||
|
||||
interface Props {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
interface State {
|
||||
hasError: boolean;
|
||||
error: Error | null;
|
||||
}
|
||||
|
||||
class ErrorBoundary extends React.Component<Props, State> {
|
||||
constructor(props: Props) {
|
||||
super(props);
|
||||
this.state = { hasError: false, error: null };
|
||||
}
|
||||
|
||||
static getDerivedStateFromError(error: Error): State {
|
||||
return { hasError: true, error };
|
||||
}
|
||||
|
||||
componentDidCatch(error: Error, info: React.ErrorInfo) {
|
||||
console.error("[ErrorBoundary]", error, info.componentStack);
|
||||
}
|
||||
|
||||
handleReload = () => {
|
||||
this.setState({ hasError: false, error: null });
|
||||
window.location.reload();
|
||||
};
|
||||
|
||||
render() {
|
||||
if (this.state.hasError) {
|
||||
return (
|
||||
<div style={{ padding: 48 }}>
|
||||
<Result
|
||||
status="error"
|
||||
title="页面渲染出错"
|
||||
subTitle="请尝试刷新页面,如果问题持续请联系管理员。"
|
||||
extra={
|
||||
<Button type="primary" onClick={this.handleReload}>
|
||||
刷新页面
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{this.state.error && (
|
||||
<Paragraph>
|
||||
<Text type="danger" code style={{ whiteSpace: "pre-wrap", wordBreak: "break-all" }}>
|
||||
{this.state.error.message}
|
||||
</Text>
|
||||
</Paragraph>
|
||||
)}
|
||||
</Result>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return this.props.children;
|
||||
}
|
||||
}
|
||||
|
||||
export default ErrorBoundary;
|
||||
79
apps/admin-web/src/components/LogStream.tsx
Normal file
79
apps/admin-web/src/components/LogStream.tsx
Normal file
@@ -0,0 +1,79 @@
|
||||
/**
|
||||
* 日志流展示组件。
|
||||
*
|
||||
* - 等宽字体展示日志行
|
||||
* - 自动滚动到底部(useRef + scrollIntoView)
|
||||
* - 提供"暂停自动滚动"按钮(toggle)
|
||||
*/
|
||||
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { Button } from "antd";
|
||||
import { PauseCircleOutlined, PlayCircleOutlined } from "@ant-design/icons";
|
||||
|
||||
export interface LogStreamProps {
|
||||
/** 可选的执行 ID,用于标题展示 */
|
||||
executionId?: string;
|
||||
/** 日志行数组 */
|
||||
lines: string[];
|
||||
}
|
||||
|
||||
const LogStream: React.FC<LogStreamProps> = ({ lines }) => {
|
||||
const [autoscroll, setAutoscroll] = useState(true);
|
||||
const bottomRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (autoscroll && bottomRef.current) {
|
||||
bottomRef.current.scrollIntoView({ behavior: "smooth" });
|
||||
}
|
||||
}, [lines, autoscroll]);
|
||||
|
||||
const handleToggle = () => {
|
||||
const next = !autoscroll;
|
||||
setAutoscroll(next);
|
||||
// 恢复时立即滚动到底部
|
||||
if (next && bottomRef.current) {
|
||||
bottomRef.current.scrollIntoView({ behavior: "smooth" });
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div style={{ display: "flex", flexDirection: "column", height: "100%" }}>
|
||||
<div style={{ marginBottom: 8, textAlign: "right" }}>
|
||||
<Button
|
||||
size="small"
|
||||
icon={autoscroll ? <PauseCircleOutlined /> : <PlayCircleOutlined />}
|
||||
onClick={handleToggle}
|
||||
>
|
||||
{autoscroll ? "暂停滚动" : "恢复滚动"}
|
||||
</Button>
|
||||
</div>
|
||||
<div
|
||||
style={{
|
||||
flex: 1,
|
||||
overflow: "auto",
|
||||
background: "#1e1e1e",
|
||||
color: "#d4d4d4",
|
||||
fontFamily: "'Cascadia Code', 'Fira Code', 'Consolas', monospace",
|
||||
fontSize: 13,
|
||||
lineHeight: 1.6,
|
||||
padding: 12,
|
||||
borderRadius: 4,
|
||||
minHeight: 300,
|
||||
}}
|
||||
>
|
||||
{lines.length === 0 ? (
|
||||
<div style={{ color: "#888" }}>暂无日志</div>
|
||||
) : (
|
||||
lines.map((line, i) => (
|
||||
<div key={i} style={{ whiteSpace: "pre-wrap", wordBreak: "break-all" }}>
|
||||
{line}
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
<div ref={bottomRef} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LogStream;
|
||||
407
apps/admin-web/src/components/ScheduleTab.tsx
Normal file
407
apps/admin-web/src/components/ScheduleTab.tsx
Normal file
@@ -0,0 +1,407 @@
|
||||
/**
|
||||
* 调度管理 Tab 组件。
|
||||
*
|
||||
* 功能:
|
||||
* - 调度任务列表(名称、调度类型、启用 Switch、下次执行、执行次数、最近状态、操作)
|
||||
* - 创建/编辑调度任务 Modal(名称 + 调度配置)
|
||||
* - 删除确认
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useCallback } from 'react';
|
||||
import {
|
||||
Table, Tag, Button, Switch, Popconfirm, Space, Modal, Form,
|
||||
Input, Select, InputNumber, TimePicker, Checkbox, message,
|
||||
} from 'antd';
|
||||
import { PlusOutlined, ReloadOutlined, EditOutlined, DeleteOutlined } from '@ant-design/icons';
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import dayjs from 'dayjs';
|
||||
import type { ScheduledTask, ScheduleConfig } from '../types';
|
||||
import {
|
||||
fetchSchedules,
|
||||
createSchedule,
|
||||
updateSchedule,
|
||||
deleteSchedule,
|
||||
toggleSchedule,
|
||||
} from '../api/schedules';
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 常量 & 工具 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const STATUS_COLOR: Record<string, string> = {
|
||||
success: 'success',
|
||||
failed: 'error',
|
||||
running: 'processing',
|
||||
cancelled: 'warning',
|
||||
};
|
||||
|
||||
const SCHEDULE_TYPE_LABEL: Record<string, string> = {
|
||||
once: '一次性',
|
||||
interval: '固定间隔',
|
||||
daily: '每日',
|
||||
weekly: '每周',
|
||||
cron: 'Cron',
|
||||
};
|
||||
|
||||
const INTERVAL_UNIT_LABEL: Record<string, string> = {
|
||||
minutes: '分钟',
|
||||
hours: '小时',
|
||||
days: '天',
|
||||
};
|
||||
|
||||
const WEEKDAY_OPTIONS = [
|
||||
{ label: '周一', value: 1 },
|
||||
{ label: '周二', value: 2 },
|
||||
{ label: '周三', value: 3 },
|
||||
{ label: '周四', value: 4 },
|
||||
{ label: '周五', value: 5 },
|
||||
{ label: '周六', value: 6 },
|
||||
{ label: '周日', value: 0 },
|
||||
];
|
||||
|
||||
/** 格式化时间 */
|
||||
function fmtTime(iso: string | null | undefined): string {
|
||||
if (!iso) return '—';
|
||||
return new Date(iso).toLocaleString('zh-CN');
|
||||
}
|
||||
|
||||
/** 根据调度配置生成可读描述 */
|
||||
function describeSchedule(cfg: ScheduleConfig): string {
|
||||
switch (cfg.schedule_type) {
|
||||
case 'once':
|
||||
return '一次性';
|
||||
case 'interval':
|
||||
return `每 ${cfg.interval_value} ${INTERVAL_UNIT_LABEL[cfg.interval_unit] ?? cfg.interval_unit}`;
|
||||
case 'daily':
|
||||
return `每日 ${cfg.daily_time}`;
|
||||
case 'weekly': {
|
||||
const days = (cfg.weekly_days ?? [])
|
||||
.map((d) => WEEKDAY_OPTIONS.find((o) => o.value === d)?.label ?? `${d}`)
|
||||
.join('、');
|
||||
return `每周 ${days} ${cfg.weekly_time}`;
|
||||
}
|
||||
case 'cron':
|
||||
return `Cron: ${cfg.cron_expression}`;
|
||||
default:
|
||||
return cfg.schedule_type;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 调度配置表单子组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
/** 根据调度类型动态渲染配置项 */
|
||||
const ScheduleConfigFields: React.FC<{ scheduleType: string }> = ({ scheduleType }) => {
|
||||
switch (scheduleType) {
|
||||
case 'interval':
|
||||
return (
|
||||
<Space>
|
||||
<Form.Item name={['schedule_config', 'interval_value']} noStyle rules={[{ required: true }]}>
|
||||
<InputNumber min={1} placeholder="间隔值" />
|
||||
</Form.Item>
|
||||
<Form.Item name={['schedule_config', 'interval_unit']} noStyle rules={[{ required: true }]}>
|
||||
<Select style={{ width: 100 }} options={[
|
||||
{ label: '分钟', value: 'minutes' },
|
||||
{ label: '小时', value: 'hours' },
|
||||
{ label: '天', value: 'days' },
|
||||
]} />
|
||||
</Form.Item>
|
||||
</Space>
|
||||
);
|
||||
case 'daily':
|
||||
return (
|
||||
<Form.Item name={['schedule_config', 'daily_time']} label="执行时间" rules={[{ required: true }]}>
|
||||
<TimePicker format="HH:mm" />
|
||||
</Form.Item>
|
||||
);
|
||||
case 'weekly':
|
||||
return (
|
||||
<>
|
||||
<Form.Item name={['schedule_config', 'weekly_days']} label="星期" rules={[{ required: true }]}>
|
||||
<Checkbox.Group options={WEEKDAY_OPTIONS} />
|
||||
</Form.Item>
|
||||
<Form.Item name={['schedule_config', 'weekly_time']} label="执行时间" rules={[{ required: true }]}>
|
||||
<TimePicker format="HH:mm" />
|
||||
</Form.Item>
|
||||
</>
|
||||
);
|
||||
case 'cron':
|
||||
return (
|
||||
<Form.Item name={['schedule_config', 'cron_expression']} label="Cron 表达式" rules={[{ required: true }]}>
|
||||
<Input placeholder="0 4 * * *" />
|
||||
</Form.Item>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 主组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const ScheduleTab: React.FC = () => {
|
||||
const [data, setData] = useState<ScheduledTask[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [editing, setEditing] = useState<ScheduledTask | null>(null);
|
||||
const [submitting, setSubmitting] = useState(false);
|
||||
const [scheduleType, setScheduleType] = useState<string>('daily');
|
||||
const [form] = Form.useForm();
|
||||
|
||||
/* 加载列表 */
|
||||
const load = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
setData(await fetchSchedules());
|
||||
} catch {
|
||||
message.error('加载调度任务失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => { load(); }, [load]);
|
||||
|
||||
/* 打开创建 Modal */
|
||||
const openCreate = () => {
|
||||
setEditing(null);
|
||||
form.resetFields();
|
||||
form.setFieldsValue({
|
||||
schedule_config: {
|
||||
schedule_type: 'daily',
|
||||
interval_value: 1,
|
||||
interval_unit: 'hours',
|
||||
daily_time: dayjs('04:00', 'HH:mm'),
|
||||
weekly_days: [1],
|
||||
weekly_time: dayjs('04:00', 'HH:mm'),
|
||||
cron_expression: '0 4 * * *',
|
||||
},
|
||||
});
|
||||
setScheduleType('daily');
|
||||
setModalOpen(true);
|
||||
};
|
||||
|
||||
/* 打开编辑 Modal */
|
||||
const openEdit = (record: ScheduledTask) => {
|
||||
setEditing(record);
|
||||
const cfg = record.schedule_config;
|
||||
form.setFieldsValue({
|
||||
name: record.name,
|
||||
schedule_config: {
|
||||
...cfg,
|
||||
daily_time: cfg.daily_time ? dayjs(cfg.daily_time, 'HH:mm') : undefined,
|
||||
weekly_time: cfg.weekly_time ? dayjs(cfg.weekly_time, 'HH:mm') : undefined,
|
||||
},
|
||||
});
|
||||
setScheduleType(cfg.schedule_type);
|
||||
setModalOpen(true);
|
||||
};
|
||||
|
||||
/* 提交创建/编辑 */
|
||||
const handleSubmit = async () => {
|
||||
try {
|
||||
const values = await form.validateFields();
|
||||
setSubmitting(true);
|
||||
|
||||
// 将 dayjs 对象转为字符串
|
||||
const cfg = { ...values.schedule_config };
|
||||
if (cfg.daily_time && typeof cfg.daily_time !== 'string') {
|
||||
cfg.daily_time = cfg.daily_time.format('HH:mm');
|
||||
}
|
||||
if (cfg.weekly_time && typeof cfg.weekly_time !== 'string') {
|
||||
cfg.weekly_time = cfg.weekly_time.format('HH:mm');
|
||||
}
|
||||
|
||||
const scheduleConfig: ScheduleConfig = {
|
||||
schedule_type: cfg.schedule_type ?? 'daily',
|
||||
interval_value: cfg.interval_value ?? 1,
|
||||
interval_unit: cfg.interval_unit ?? 'hours',
|
||||
daily_time: cfg.daily_time ?? '04:00',
|
||||
weekly_days: cfg.weekly_days ?? [1],
|
||||
weekly_time: cfg.weekly_time ?? '04:00',
|
||||
cron_expression: cfg.cron_expression ?? '0 4 * * *',
|
||||
enabled: true,
|
||||
start_date: null,
|
||||
end_date: null,
|
||||
};
|
||||
|
||||
if (editing) {
|
||||
await updateSchedule(editing.id, {
|
||||
name: values.name,
|
||||
schedule_config: scheduleConfig,
|
||||
});
|
||||
message.success('调度任务已更新');
|
||||
} else {
|
||||
// 创建时使用默认 task_config(简化实现)
|
||||
await createSchedule({
|
||||
name: values.name,
|
||||
task_codes: [],
|
||||
task_config: {
|
||||
tasks: [],
|
||||
pipeline: 'api_full',
|
||||
processing_mode: 'increment_only',
|
||||
pipeline_flow: 'FULL',
|
||||
dry_run: false,
|
||||
window_mode: 'lookback',
|
||||
window_start: null,
|
||||
window_end: null,
|
||||
window_split: null,
|
||||
window_split_days: null,
|
||||
lookback_hours: 24,
|
||||
overlap_seconds: 600,
|
||||
fetch_before_verify: false,
|
||||
skip_ods_when_fetch_before_verify: false,
|
||||
ods_use_local_json: false,
|
||||
store_id: null,
|
||||
dwd_only_tables: null,
|
||||
force_full: false,
|
||||
extra_args: {},
|
||||
},
|
||||
schedule_config: scheduleConfig,
|
||||
});
|
||||
message.success('调度任务已创建');
|
||||
}
|
||||
|
||||
setModalOpen(false);
|
||||
load();
|
||||
} catch {
|
||||
// 表单验证失败,不做额外处理
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
/* 删除 */
|
||||
const handleDelete = async (id: string) => {
|
||||
try {
|
||||
await deleteSchedule(id);
|
||||
message.success('已删除');
|
||||
load();
|
||||
} catch {
|
||||
message.error('删除失败');
|
||||
}
|
||||
};
|
||||
|
||||
/* 启用/禁用 */
|
||||
const handleToggle = async (id: string) => {
|
||||
try {
|
||||
await toggleSchedule(id);
|
||||
load();
|
||||
} catch {
|
||||
message.error('切换状态失败');
|
||||
}
|
||||
};
|
||||
|
||||
/* 表格列定义 */
|
||||
const columns: ColumnsType<ScheduledTask> = [
|
||||
{
|
||||
title: '名称',
|
||||
dataIndex: 'name',
|
||||
key: 'name',
|
||||
},
|
||||
{
|
||||
title: '调度类型',
|
||||
key: 'schedule_type',
|
||||
render: (_: unknown, record: ScheduledTask) => describeSchedule(record.schedule_config),
|
||||
},
|
||||
{
|
||||
title: '启用',
|
||||
dataIndex: 'enabled',
|
||||
key: 'enabled',
|
||||
width: 80,
|
||||
render: (enabled: boolean, record: ScheduledTask) => (
|
||||
<Switch checked={enabled} onChange={() => handleToggle(record.id)} size="small" />
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '下次执行',
|
||||
dataIndex: 'next_run_at',
|
||||
key: 'next_run_at',
|
||||
render: fmtTime,
|
||||
},
|
||||
{
|
||||
title: '执行次数',
|
||||
dataIndex: 'run_count',
|
||||
key: 'run_count',
|
||||
width: 90,
|
||||
},
|
||||
{
|
||||
title: '最近状态',
|
||||
dataIndex: 'last_status',
|
||||
key: 'last_status',
|
||||
width: 100,
|
||||
render: (s: string | null) =>
|
||||
s ? <Tag color={STATUS_COLOR[s] ?? 'default'}>{s}</Tag> : '—',
|
||||
},
|
||||
{
|
||||
title: '操作',
|
||||
key: 'action',
|
||||
width: 140,
|
||||
render: (_: unknown, record: ScheduledTask) => (
|
||||
<Space size="small">
|
||||
<Button type="link" icon={<EditOutlined />} size="small" onClick={() => openEdit(record)}>
|
||||
编辑
|
||||
</Button>
|
||||
<Popconfirm title="确认删除该调度任务?" onConfirm={() => handleDelete(record.id)}>
|
||||
<Button type="link" danger icon={<DeleteOutlined />} size="small">
|
||||
删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<>
|
||||
<div style={{ marginBottom: 12 }}>
|
||||
<Space>
|
||||
<Button type="primary" icon={<PlusOutlined />} onClick={openCreate}>
|
||||
新建调度
|
||||
</Button>
|
||||
<Button icon={<ReloadOutlined />} onClick={load} loading={loading}>
|
||||
刷新
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
<Table<ScheduledTask>
|
||||
rowKey="id"
|
||||
columns={columns}
|
||||
dataSource={data}
|
||||
loading={loading}
|
||||
pagination={false}
|
||||
size="middle"
|
||||
/>
|
||||
|
||||
{/* 创建/编辑 Modal */}
|
||||
<Modal
|
||||
title={editing ? '编辑调度任务' : '新建调度任务'}
|
||||
open={modalOpen}
|
||||
onOk={handleSubmit}
|
||||
onCancel={() => setModalOpen(false)}
|
||||
confirmLoading={submitting}
|
||||
destroyOnClose
|
||||
>
|
||||
<Form form={form} layout="vertical" preserve={false}>
|
||||
<Form.Item name="name" label="名称" rules={[{ required: true, message: '请输入调度任务名称' }]}>
|
||||
<Input placeholder="例如:每日全量同步" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item name={['schedule_config', 'schedule_type']} label="调度类型" rules={[{ required: true }]}>
|
||||
<Select
|
||||
options={Object.entries(SCHEDULE_TYPE_LABEL).map(([value, label]) => ({ value, label }))}
|
||||
onChange={(v: string) => setScheduleType(v)}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<ScheduleConfigFields scheduleType={scheduleType} />
|
||||
</Form>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default ScheduleTab;
|
||||
309
apps/admin-web/src/components/TaskSelector.tsx
Normal file
309
apps/admin-web/src/components/TaskSelector.tsx
Normal file
@@ -0,0 +1,309 @@
|
||||
/**
|
||||
* 按业务域分组的任务选择器。
|
||||
*
|
||||
* 从 /api/tasks/registry 获取任务注册表,按业务域折叠展示,
|
||||
* 支持全选/反选和按 Flow 层级过滤。
|
||||
* 当 Flow 包含 DWD 层时,在 DWD 任务下方内嵌表过滤子选项。
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useMemo, useCallback } from "react";
|
||||
import {
|
||||
Collapse,
|
||||
Checkbox,
|
||||
Spin,
|
||||
Alert,
|
||||
Button,
|
||||
Space,
|
||||
Typography,
|
||||
Tag,
|
||||
Divider,
|
||||
} from "antd";
|
||||
import type { CheckboxChangeEvent } from "antd/es/checkbox";
|
||||
import { fetchTaskRegistry, fetchDwdTables } from "../api/tasks";
|
||||
import type { TaskDefinition } from "../types";
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Props */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export interface TaskSelectorProps {
|
||||
/** 当前 Flow 包含的层(如 ["ODS", "DWD"]) */
|
||||
layers: string[];
|
||||
/** 已选中的任务编码列表 */
|
||||
selectedTasks: string[];
|
||||
/** 选中任务变化回调 */
|
||||
onTasksChange: (tasks: string[]) => void;
|
||||
/** DWD 表过滤:已选中的表名列表 */
|
||||
selectedDwdTables?: string[];
|
||||
/** DWD 表过滤变化回调 */
|
||||
onDwdTablesChange?: (tables: string[]) => void;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 过滤逻辑 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export function filterTasksByLayers(
|
||||
tasks: TaskDefinition[],
|
||||
layers: string[],
|
||||
): TaskDefinition[] {
|
||||
if (layers.length === 0) return [];
|
||||
return tasks;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const TaskSelector: React.FC<TaskSelectorProps> = ({
|
||||
layers,
|
||||
selectedTasks,
|
||||
onTasksChange,
|
||||
selectedDwdTables = [],
|
||||
onDwdTablesChange,
|
||||
}) => {
|
||||
const [registry, setRegistry] = useState<Record<string, TaskDefinition[]>>({});
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// DWD 表定义(按域分组)
|
||||
const [dwdTableGroups, setDwdTableGroups] = useState<Record<string, string[]>>({});
|
||||
const showDwdFilter = layers.includes("DWD") && !!onDwdTablesChange;
|
||||
|
||||
/* ---------- 加载任务注册表 ---------- */
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
const promises: Promise<void>[] = [
|
||||
fetchTaskRegistry()
|
||||
.then((data) => { if (!cancelled) setRegistry(data); })
|
||||
.catch((err) => { if (!cancelled) setError(err?.message ?? "获取任务列表失败"); }),
|
||||
];
|
||||
// 如果包含 DWD 层,同时加载 DWD 表定义
|
||||
if (layers.includes("DWD")) {
|
||||
promises.push(
|
||||
fetchDwdTables()
|
||||
.then((data) => { if (!cancelled) setDwdTableGroups(data); })
|
||||
.catch(() => { /* DWD 表加载失败不阻塞任务列表 */ }),
|
||||
);
|
||||
}
|
||||
|
||||
Promise.all(promises).finally(() => { if (!cancelled) setLoading(false); });
|
||||
return () => { cancelled = true; };
|
||||
}, [layers]);
|
||||
|
||||
/* ---------- 按 layers 过滤后的分组 ---------- */
|
||||
const filteredGroups = useMemo(() => {
|
||||
const result: Record<string, TaskDefinition[]> = {};
|
||||
for (const [domain, tasks] of Object.entries(registry)) {
|
||||
const visible = filterTasksByLayers(tasks, layers);
|
||||
if (visible.length > 0) {
|
||||
result[domain] = [...visible].sort((a, b) => {
|
||||
if (a.is_common === b.is_common) return 0;
|
||||
return a.is_common ? -1 : 1;
|
||||
});
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}, [registry, layers]);
|
||||
|
||||
const allVisibleCodes = useMemo(
|
||||
() => Object.values(filteredGroups).flatMap((t) => t.map((d) => d.code)),
|
||||
[filteredGroups],
|
||||
);
|
||||
|
||||
// DWD 表扁平列表
|
||||
const allDwdTableNames = useMemo(
|
||||
() => Object.values(dwdTableGroups).flat(),
|
||||
[dwdTableGroups],
|
||||
);
|
||||
|
||||
/* ---------- 事件处理 ---------- */
|
||||
|
||||
const handleDomainChange = useCallback(
|
||||
(domain: string, checkedCodes: string[]) => {
|
||||
const otherDomainCodes = selectedTasks.filter(
|
||||
(code) => !filteredGroups[domain]?.some((t) => t.code === code),
|
||||
);
|
||||
onTasksChange([...otherDomainCodes, ...checkedCodes]);
|
||||
},
|
||||
[selectedTasks, filteredGroups, onTasksChange],
|
||||
);
|
||||
|
||||
const handleSelectAll = useCallback(() => {
|
||||
onTasksChange(allVisibleCodes);
|
||||
}, [allVisibleCodes, onTasksChange]);
|
||||
|
||||
const handleInvertSelection = useCallback(() => {
|
||||
const currentSet = new Set(selectedTasks);
|
||||
const inverted = allVisibleCodes.filter((code) => !currentSet.has(code));
|
||||
onTasksChange(inverted);
|
||||
}, [allVisibleCodes, selectedTasks, onTasksChange]);
|
||||
|
||||
/* ---------- DWD 表过滤事件 ---------- */
|
||||
|
||||
const handleDwdDomainTableChange = useCallback(
|
||||
(domain: string, checked: string[]) => {
|
||||
if (!onDwdTablesChange) return;
|
||||
const domainTables = new Set(dwdTableGroups[domain] ?? []);
|
||||
const otherSelected = selectedDwdTables.filter((t) => !domainTables.has(t));
|
||||
onDwdTablesChange([...otherSelected, ...checked]);
|
||||
},
|
||||
[selectedDwdTables, dwdTableGroups, onDwdTablesChange],
|
||||
);
|
||||
|
||||
const handleDwdSelectAll = useCallback(() => {
|
||||
onDwdTablesChange?.(allDwdTableNames);
|
||||
}, [allDwdTableNames, onDwdTablesChange]);
|
||||
|
||||
const handleDwdClearAll = useCallback(() => {
|
||||
onDwdTablesChange?.([]);
|
||||
}, [onDwdTablesChange]);
|
||||
|
||||
/* ---------- 渲染 ---------- */
|
||||
|
||||
if (loading) return <Spin tip="加载任务列表…" />;
|
||||
if (error) return <Alert type="error" message="加载失败" description={error} />;
|
||||
|
||||
const domainEntries = Object.entries(filteredGroups);
|
||||
if (domainEntries.length === 0) return <Text type="secondary">当前 Flow 无可选任务</Text>;
|
||||
|
||||
const selectedCount = selectedTasks.filter((c) => allVisibleCodes.includes(c)).length;
|
||||
// DWD 装载任务是否被选中
|
||||
const dwdLoadSelected = selectedTasks.includes("DWD_LOAD_FROM_ODS");
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Space style={{ marginBottom: 8 }}>
|
||||
<Button size="small" onClick={handleSelectAll}>全选</Button>
|
||||
<Button size="small" onClick={handleInvertSelection}>反选</Button>
|
||||
<Text type="secondary">已选 {selectedCount} / {allVisibleCodes.length}</Text>
|
||||
</Space>
|
||||
|
||||
<Collapse
|
||||
defaultActiveKey={domainEntries.map(([d]) => d)}
|
||||
items={domainEntries.map(([domain, tasks]) => {
|
||||
const domainCodes = tasks.map((t) => t.code);
|
||||
const domainSelected = selectedTasks.filter((c) => domainCodes.includes(c));
|
||||
const allChecked = domainSelected.length === domainCodes.length;
|
||||
const indeterminate = domainSelected.length > 0 && !allChecked;
|
||||
|
||||
const handleDomainCheckAll = (e: CheckboxChangeEvent) => {
|
||||
handleDomainChange(domain, e.target.checked ? domainCodes : []);
|
||||
};
|
||||
|
||||
return {
|
||||
key: domain,
|
||||
label: (
|
||||
<span onClick={(e) => e.stopPropagation()}>
|
||||
<Checkbox
|
||||
indeterminate={indeterminate}
|
||||
checked={allChecked}
|
||||
onChange={handleDomainCheckAll}
|
||||
style={{ marginRight: 8 }}
|
||||
/>
|
||||
{domain}
|
||||
<Text type="secondary" style={{ marginLeft: 4 }}>
|
||||
({domainSelected.length}/{domainCodes.length})
|
||||
</Text>
|
||||
</span>
|
||||
),
|
||||
children: (
|
||||
<Checkbox.Group
|
||||
value={domainSelected}
|
||||
onChange={(checked) => handleDomainChange(domain, checked as string[])}
|
||||
>
|
||||
<Space direction="vertical" style={{ width: "100%" }}>
|
||||
{tasks.map((t) => (
|
||||
<Checkbox key={t.code} value={t.code}>
|
||||
<Text strong style={t.is_common === false ? { color: "#999" } : undefined}>{t.code}</Text>
|
||||
<Text type="secondary" style={{ marginLeft: 8 }}>{t.name}</Text>
|
||||
{t.is_common === false && (
|
||||
<Tag color="default" style={{ marginLeft: 6, fontSize: 11 }}>不常用</Tag>
|
||||
)}
|
||||
</Checkbox>
|
||||
))}
|
||||
</Space>
|
||||
</Checkbox.Group>
|
||||
),
|
||||
};
|
||||
})}
|
||||
/>
|
||||
|
||||
{/* DWD 表过滤:仅在 DWD 层且 DWD_LOAD_FROM_ODS 被选中时显示 */}
|
||||
{showDwdFilter && dwdLoadSelected && allDwdTableNames.length > 0 && (
|
||||
<>
|
||||
<Divider style={{ margin: "12px 0 8px" }} />
|
||||
<div style={{ padding: "0 4px" }}>
|
||||
<Space style={{ marginBottom: 6 }}>
|
||||
<Text strong style={{ fontSize: 13 }}>DWD 表过滤</Text>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||
{selectedDwdTables.length === 0
|
||||
? "(未选择 = 全部装载)"
|
||||
: `已选 ${selectedDwdTables.length} / ${allDwdTableNames.length}`}
|
||||
</Text>
|
||||
</Space>
|
||||
<div style={{ marginBottom: 6 }}>
|
||||
<Space size={4}>
|
||||
<Button size="small" type="link" style={{ padding: 0, fontSize: 12 }} onClick={handleDwdSelectAll}>
|
||||
全选
|
||||
</Button>
|
||||
<Button size="small" type="link" style={{ padding: 0, fontSize: 12 }} onClick={handleDwdClearAll}>
|
||||
清空(全部装载)
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
<Collapse
|
||||
size="small"
|
||||
items={Object.entries(dwdTableGroups).map(([domain, tables]) => {
|
||||
const domainSelected = selectedDwdTables.filter((t) => tables.includes(t));
|
||||
const allDomainChecked = domainSelected.length === tables.length;
|
||||
const domainIndeterminate = domainSelected.length > 0 && !allDomainChecked;
|
||||
|
||||
return {
|
||||
key: domain,
|
||||
label: (
|
||||
<span onClick={(e) => e.stopPropagation()}>
|
||||
<Checkbox
|
||||
indeterminate={domainIndeterminate}
|
||||
checked={allDomainChecked}
|
||||
onChange={(e: CheckboxChangeEvent) =>
|
||||
handleDwdDomainTableChange(domain, e.target.checked ? tables : [])
|
||||
}
|
||||
style={{ marginRight: 8 }}
|
||||
/>
|
||||
{domain}
|
||||
<Text type="secondary" style={{ marginLeft: 4, fontSize: 12 }}>
|
||||
({domainSelected.length}/{tables.length})
|
||||
</Text>
|
||||
</span>
|
||||
),
|
||||
children: (
|
||||
<Checkbox.Group
|
||||
value={domainSelected}
|
||||
onChange={(checked) => handleDwdDomainTableChange(domain, checked as string[])}
|
||||
>
|
||||
<Space direction="vertical">
|
||||
{tables.map((table) => (
|
||||
<Checkbox key={table} value={table}>
|
||||
<Text style={{ fontSize: 12 }}>{table}</Text>
|
||||
</Checkbox>
|
||||
))}
|
||||
</Space>
|
||||
</Checkbox.Group>
|
||||
),
|
||||
};
|
||||
})}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TaskSelector;
|
||||
22
apps/admin-web/src/main.tsx
Normal file
22
apps/admin-web/src/main.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import React from "react";
|
||||
import ReactDOM from "react-dom/client";
|
||||
import { BrowserRouter } from "react-router-dom";
|
||||
import { ConfigProvider } from "antd";
|
||||
import zhCN from "antd/locale/zh_CN";
|
||||
import App from "./App";
|
||||
import ErrorBoundary from "./components/ErrorBoundary";
|
||||
|
||||
/**
|
||||
* 入口:ErrorBoundary + BrowserRouter + antd 中文 locale + App 根组件。
|
||||
*/
|
||||
ReactDOM.createRoot(document.getElementById("root")!).render(
|
||||
<React.StrictMode>
|
||||
<ErrorBoundary>
|
||||
<BrowserRouter>
|
||||
<ConfigProvider locale={zhCN}>
|
||||
<App />
|
||||
</ConfigProvider>
|
||||
</BrowserRouter>
|
||||
</ErrorBoundary>
|
||||
</React.StrictMode>,
|
||||
);
|
||||
235
apps/admin-web/src/pages/DBViewer.tsx
Normal file
235
apps/admin-web/src/pages/DBViewer.tsx
Normal file
@@ -0,0 +1,235 @@
|
||||
/**
|
||||
* 数据库查看器页面。
|
||||
*
|
||||
* - 左侧:Schema → Table 层级树,异步加载
|
||||
* - 右侧上方:SQL 编辑器 + 执行按钮
|
||||
* - 右侧下方:列定义 / 查询结果 Table
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useCallback } from 'react';
|
||||
import { Tree, Input, Button, Table, Space, message, Spin, Tag, Card, Typography, Tooltip } from 'antd';
|
||||
import {
|
||||
PlayCircleOutlined, ReloadOutlined, TableOutlined,
|
||||
DatabaseOutlined, CopyOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import type { DataNode, EventDataNode } from 'antd/es/tree';
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import {
|
||||
fetchSchemas, fetchTables, fetchColumns, executeQuery,
|
||||
type ColumnInfo, type QueryResult,
|
||||
} from '../api/dbViewer';
|
||||
|
||||
const { TextArea } = Input;
|
||||
const { Title, Text } = Typography;
|
||||
|
||||
const schemaKey = (schema: string) => `s::${schema}`;
|
||||
const tableKey = (schema: string, table: string) => `t::${schema}::${table}`;
|
||||
|
||||
function parseTableKey(key: string): { schema: string; table: string } | null {
|
||||
if (!key.startsWith('t::')) return null;
|
||||
const parts = key.slice(3).split('::');
|
||||
if (parts.length !== 2) return null;
|
||||
return { schema: parts[0], table: parts[1] };
|
||||
}
|
||||
|
||||
const DBViewer: React.FC = () => {
|
||||
const [treeData, setTreeData] = useState<DataNode[]>([]);
|
||||
const [loadingTree, setLoadingTree] = useState(false);
|
||||
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
|
||||
const [selectedTable, setSelectedTable] = useState<{ schema: string; table: string } | null>(null);
|
||||
const [columnData, setColumnData] = useState<ColumnInfo[]>([]);
|
||||
const [loadingColumns, setLoadingColumns] = useState(false);
|
||||
const [sql, setSql] = useState('');
|
||||
const [queryResult, setQueryResult] = useState<QueryResult | null>(null);
|
||||
const [loadingQuery, setLoadingQuery] = useState(false);
|
||||
|
||||
const loadSchemas = useCallback(async () => {
|
||||
setLoadingTree(true);
|
||||
try {
|
||||
const schemas = await fetchSchemas();
|
||||
setTreeData(
|
||||
schemas.map((s) => ({
|
||||
title: s, key: schemaKey(s), icon: <DatabaseOutlined />, isLeaf: false,
|
||||
})),
|
||||
);
|
||||
} catch { message.error('加载 Schema 列表失败'); }
|
||||
finally { setLoadingTree(false); }
|
||||
}, []);
|
||||
|
||||
useEffect(() => { loadSchemas(); }, [loadSchemas]);
|
||||
|
||||
const onLoadData = async (node: EventDataNode<DataNode>) => {
|
||||
const key = node.key as string;
|
||||
if (!key.startsWith('s::')) return;
|
||||
if (node.children && node.children.length > 0) return;
|
||||
const schema = key.slice(3);
|
||||
try {
|
||||
const tables = await fetchTables(schema);
|
||||
const children: DataNode[] = tables.map((t) => ({
|
||||
title: (
|
||||
<Space size={4}>
|
||||
<span>{t.name}</span>
|
||||
<Text type="secondary" style={{ fontSize: 11 }}>({t.row_count.toLocaleString()})</Text>
|
||||
</Space>
|
||||
),
|
||||
key: tableKey(schema, t.name), icon: <TableOutlined />, isLeaf: true,
|
||||
}));
|
||||
setTreeData((prev) => prev.map((n) => n.key === key ? { ...n, children } : n));
|
||||
} catch { message.error(`加载 ${schema} 的表列表失败`); }
|
||||
};
|
||||
|
||||
const onSelectNode = async (_: React.Key[], info: { node: DataNode }) => {
|
||||
const key = info.node.key as string;
|
||||
const parsed = parseTableKey(key);
|
||||
if (!parsed) return;
|
||||
setSelectedTable(parsed);
|
||||
setLoadingColumns(true);
|
||||
setQueryResult(null);
|
||||
try {
|
||||
const cols = await fetchColumns(parsed.schema, parsed.table);
|
||||
setColumnData(cols);
|
||||
setSql(`SELECT * FROM ${parsed.schema}.${parsed.table} LIMIT 100;`);
|
||||
} catch { message.error('加载列定义失败'); setColumnData([]); }
|
||||
finally { setLoadingColumns(false); }
|
||||
};
|
||||
|
||||
const handleExecute = async () => {
|
||||
const trimmed = sql.trim();
|
||||
if (!trimmed) { message.warning('请输入 SQL 语句'); return; }
|
||||
setLoadingQuery(true);
|
||||
try {
|
||||
const result = await executeQuery(trimmed);
|
||||
setQueryResult(result);
|
||||
} catch (err: unknown) {
|
||||
const axiosErr = err as { response?: { data?: { detail?: string } } };
|
||||
const msg = axiosErr.response?.data?.detail ?? (err instanceof Error ? err.message : '查询执行失败');
|
||||
message.error(msg);
|
||||
setQueryResult(null);
|
||||
} finally { setLoadingQuery(false); }
|
||||
};
|
||||
|
||||
const handleCopySql = () => {
|
||||
navigator.clipboard.writeText(sql).then(() => message.success('已复制'));
|
||||
};
|
||||
|
||||
const columnDefColumns: ColumnsType<ColumnInfo> = [
|
||||
{ title: '列名', dataIndex: 'name', key: 'name', render: (v: string) => <code>{v}</code> },
|
||||
{ title: '数据类型', dataIndex: 'data_type', key: 'data_type' },
|
||||
{
|
||||
title: '可空', dataIndex: 'is_nullable', key: 'is_nullable', width: 70, align: 'center',
|
||||
render: (v: boolean) => v ? <Tag color="orange">YES</Tag> : <Tag color="blue">NO</Tag>,
|
||||
},
|
||||
{
|
||||
title: '默认值', dataIndex: 'default_value', key: 'default_value',
|
||||
render: (v: string | null) => v != null ? <code style={{ fontSize: 12 }}>{v}</code> : <Text type="secondary">—</Text>,
|
||||
},
|
||||
];
|
||||
|
||||
const resultColumns: ColumnsType<Record<string, unknown>> = queryResult
|
||||
? queryResult.columns.map((col, idx) => ({
|
||||
title: col, dataIndex: String(idx), key: col, ellipsis: true,
|
||||
render: (v: unknown) => {
|
||||
if (v === null || v === undefined) return <Text type="secondary">NULL</Text>;
|
||||
return String(v);
|
||||
},
|
||||
}))
|
||||
: [];
|
||||
|
||||
const resultDataSource: Record<string, unknown>[] = queryResult
|
||||
? queryResult.rows.map((row, rowIdx) => {
|
||||
const obj: Record<string, unknown> = { _key: rowIdx };
|
||||
row.forEach((cell, colIdx) => { obj[String(colIdx)] = cell; });
|
||||
return obj;
|
||||
})
|
||||
: [];
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: 'calc(100vh - 120px)' }}>
|
||||
<div style={{ marginBottom: 12, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Title level={4} style={{ margin: 0 }}>
|
||||
<DatabaseOutlined style={{ marginRight: 8 }} />
|
||||
数据库查看器
|
||||
</Title>
|
||||
{selectedTable && (
|
||||
<Tag color="blue">{selectedTable.schema}.{selectedTable.table}</Tag>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'flex', flex: 1, gap: 12, minHeight: 0 }}>
|
||||
{/* 左侧树 */}
|
||||
<Card
|
||||
size="small"
|
||||
title="Schema / 表"
|
||||
extra={<Button size="small" icon={<ReloadOutlined />} onClick={loadSchemas} loading={loadingTree} />}
|
||||
style={{ width: 260, minWidth: 260, display: 'flex', flexDirection: 'column' }}
|
||||
styles={{ body: { flex: 1, overflow: 'auto', padding: '8px 12px' } }}
|
||||
>
|
||||
<Spin spinning={loadingTree}>
|
||||
<Tree
|
||||
showIcon treeData={treeData} loadData={onLoadData}
|
||||
expandedKeys={expandedKeys}
|
||||
onExpand={(keys) => setExpandedKeys(keys)}
|
||||
onSelect={onSelectNode}
|
||||
/>
|
||||
</Spin>
|
||||
</Card>
|
||||
|
||||
{/* 右侧 */}
|
||||
<div style={{ flex: 1, display: 'flex', flexDirection: 'column', minWidth: 0 }}>
|
||||
{/* SQL 编辑器 */}
|
||||
<Card size="small" style={{ marginBottom: 12 }} styles={{ body: { padding: '8px 12px' } }}>
|
||||
<TextArea
|
||||
rows={5} value={sql} onChange={(e) => setSql(e.target.value)}
|
||||
placeholder="输入 SQL 查询语句…"
|
||||
style={{ fontFamily: "'Cascadia Code', 'Fira Code', Consolas, monospace", fontSize: 13, marginBottom: 8 }}
|
||||
onKeyDown={(e) => { if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') { e.preventDefault(); handleExecute(); } }}
|
||||
/>
|
||||
<Space>
|
||||
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleExecute} loading={loadingQuery}>
|
||||
执行 <Text type="secondary" style={{ fontSize: 11, marginLeft: 4 }}>(Ctrl+Enter)</Text>
|
||||
</Button>
|
||||
<Tooltip title="复制 SQL">
|
||||
<Button icon={<CopyOutlined />} onClick={handleCopySql} />
|
||||
</Tooltip>
|
||||
</Space>
|
||||
</Card>
|
||||
|
||||
{/* 结果区域 */}
|
||||
<Card size="small" style={{ flex: 1, display: 'flex', flexDirection: 'column' }}
|
||||
styles={{ body: { flex: 1, overflow: 'auto', padding: '8px 12px' } }}
|
||||
>
|
||||
{queryResult ? (
|
||||
<>
|
||||
<div style={{ marginBottom: 8 }}>
|
||||
<Text type="secondary">查询返回 {queryResult.row_count} 行</Text>
|
||||
</div>
|
||||
<Table<Record<string, unknown>>
|
||||
rowKey="_key" columns={resultColumns} dataSource={resultDataSource}
|
||||
pagination={{ pageSize: 50, showSizeChanger: false, showTotal: (t) => `共 ${t} 行` }}
|
||||
size="small" scroll={{ x: 'max-content' }} bordered
|
||||
/>
|
||||
</>
|
||||
) : selectedTable ? (
|
||||
<>
|
||||
<div style={{ marginBottom: 8 }}>
|
||||
<Text strong>{selectedTable.schema}.{selectedTable.table}</Text>
|
||||
<Text type="secondary" style={{ marginLeft: 8 }}>列定义</Text>
|
||||
</div>
|
||||
<Table<ColumnInfo>
|
||||
rowKey="name" columns={columnDefColumns} dataSource={columnData}
|
||||
loading={loadingColumns} pagination={false} size="small" bordered
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<div style={{ color: '#bbb', textAlign: 'center', marginTop: 60 }}>
|
||||
在左侧选择一张表查看列定义,或输入 SQL 执行查询
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default DBViewer;
|
||||
137
apps/admin-web/src/pages/ETLStatus.tsx
Normal file
137
apps/admin-web/src/pages/ETLStatus.tsx
Normal file
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* ETL 状态监控页面。
|
||||
*
|
||||
* - 游标状态 Table
|
||||
* - 最近执行记录 Table
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useCallback } from 'react';
|
||||
import { Table, Tag, Button, message, Typography, Card, Row, Col, Statistic } from 'antd';
|
||||
import { ReloadOutlined, DashboardOutlined, DatabaseOutlined, PlayCircleOutlined } from '@ant-design/icons';
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import {
|
||||
fetchCursors, fetchRecentRuns,
|
||||
type CursorInfo, type RecentRun,
|
||||
} from '../api/etlStatus';
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
|
||||
const STATUS_COLOR: Record<string, string> = {
|
||||
success: 'green', failed: 'red', running: 'blue', cancelled: 'orange',
|
||||
};
|
||||
|
||||
function formatTime(raw: string | null): string {
|
||||
if (!raw) return '—';
|
||||
const d = new Date(raw);
|
||||
return Number.isNaN(d.getTime()) ? raw : d.toLocaleString('zh-CN');
|
||||
}
|
||||
|
||||
function formatDuration(ms: number | null): string {
|
||||
if (ms == null) return '—';
|
||||
if (ms < 1000) return `${ms}ms`;
|
||||
const seconds = Math.floor(ms / 1000);
|
||||
if (seconds < 60) return `${seconds}s`;
|
||||
const minutes = Math.floor(seconds / 60);
|
||||
const remainSec = seconds % 60;
|
||||
return `${minutes}m ${remainSec}s`;
|
||||
}
|
||||
|
||||
const cursorColumns: ColumnsType<CursorInfo> = [
|
||||
{ title: '任务编码', dataIndex: 'task_code', key: 'task_code', render: (v: string) => <code>{v}</code> },
|
||||
{ title: '最后抓取时间', dataIndex: 'last_fetch_time', key: 'last_fetch_time', render: (v: string | null) => formatTime(v) },
|
||||
{
|
||||
title: '记录数', dataIndex: 'record_count', key: 'record_count', align: 'right',
|
||||
render: (v: number | null) => (v != null ? <Text strong>{v.toLocaleString()}</Text> : '—'),
|
||||
},
|
||||
];
|
||||
|
||||
const runColumns: ColumnsType<RecentRun> = [
|
||||
{ title: '任务名称', dataIndex: 'task_codes', key: 'task_codes', render: (codes: string[]) => codes.join(', ') || '—' },
|
||||
{
|
||||
title: '状态', dataIndex: 'status', key: 'status', width: 90,
|
||||
render: (status: string) => <Tag color={STATUS_COLOR[status] ?? 'default'}>{status}</Tag>,
|
||||
},
|
||||
{ title: '开始时间', dataIndex: 'started_at', key: 'started_at', width: 170, render: (v: string) => formatTime(v) },
|
||||
{ title: '执行时长', dataIndex: 'duration_ms', key: 'duration_ms', width: 100, render: (v: number | null) => formatDuration(v) },
|
||||
];
|
||||
|
||||
const ETLStatus: React.FC = () => {
|
||||
const [cursors, setCursors] = useState<CursorInfo[]>([]);
|
||||
const [runs, setRuns] = useState<RecentRun[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const load = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const [c, r] = await Promise.all([fetchCursors(), fetchRecentRuns()]);
|
||||
setCursors(c);
|
||||
setRuns(r);
|
||||
} catch { message.error('加载 ETL 状态失败'); }
|
||||
finally { setLoading(false); }
|
||||
}, []);
|
||||
|
||||
useEffect(() => { load(); }, [load]);
|
||||
|
||||
// 统计
|
||||
const successCount = runs.filter((r) => r.status === 'success').length;
|
||||
const failedCount = runs.filter((r) => r.status === 'failed').length;
|
||||
const runningCount = runs.filter((r) => r.status === 'running').length;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div style={{ marginBottom: 16, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Title level={4} style={{ margin: 0 }}>
|
||||
<DashboardOutlined style={{ marginRight: 8 }} />
|
||||
ETL 状态监控
|
||||
</Title>
|
||||
<Button icon={<ReloadOutlined />} onClick={load} loading={loading}>刷新</Button>
|
||||
</div>
|
||||
|
||||
{/* 统计卡片 */}
|
||||
<Row gutter={12} style={{ marginBottom: 16 }}>
|
||||
<Col span={6}>
|
||||
<Card size="small">
|
||||
<Statistic title="游标数" value={cursors.length} prefix={<DatabaseOutlined />} />
|
||||
</Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card size="small">
|
||||
<Statistic title="最近执行" value={runs.length} prefix={<PlayCircleOutlined />} />
|
||||
</Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card size="small">
|
||||
<Statistic title="成功" value={successCount} valueStyle={{ color: '#52c41a' }} />
|
||||
</Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card size="small">
|
||||
<Statistic
|
||||
title="失败 / 运行中"
|
||||
value={failedCount}
|
||||
suffix={runningCount > 0 ? ` / ${runningCount}` : ''}
|
||||
valueStyle={{ color: failedCount > 0 ? '#ff4d4f' : undefined }}
|
||||
/>
|
||||
</Card>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Card size="small" title="游标状态" style={{ marginBottom: 12 }}>
|
||||
<Table<CursorInfo>
|
||||
rowKey="task_code" columns={cursorColumns} dataSource={cursors}
|
||||
loading={loading} pagination={false} size="small"
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<Card size="small" title="最近执行记录">
|
||||
<Table<RecentRun>
|
||||
rowKey="id" columns={runColumns} dataSource={runs}
|
||||
loading={loading} pagination={{ pageSize: 20, showTotal: (t) => `共 ${t} 条` }}
|
||||
size="small"
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ETLStatus;
|
||||
164
apps/admin-web/src/pages/EnvConfig.tsx
Normal file
164
apps/admin-web/src/pages/EnvConfig.tsx
Normal file
@@ -0,0 +1,164 @@
|
||||
/**
|
||||
* 环境配置页面。
|
||||
*
|
||||
* - Ant Design Table 展示键值对,支持 inline 编辑
|
||||
* - 敏感值显示为 ****,编辑时可输入新值
|
||||
* - 顶部按钮栏:刷新、保存、导出
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useCallback, useRef } from 'react';
|
||||
import { Table, Button, Input, Tag, Space, message, Card, Typography, Badge } from 'antd';
|
||||
import type { InputRef } from 'antd';
|
||||
import {
|
||||
ReloadOutlined, SaveOutlined, DownloadOutlined, ToolOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import type { EnvConfigItem } from '../types';
|
||||
import { fetchEnvConfig, updateEnvConfig, exportEnvConfig } from '../api/envConfig';
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
|
||||
const MASK = '****';
|
||||
|
||||
const EnvConfig: React.FC = () => {
|
||||
const [items, setItems] = useState<EnvConfigItem[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [exporting, setExporting] = useState(false);
|
||||
const [editingKey, setEditingKey] = useState<string | null>(null);
|
||||
const [editValue, setEditValue] = useState('');
|
||||
const [dirtyMap, setDirtyMap] = useState<Record<string, string>>({});
|
||||
const inputRef = useRef<InputRef>(null);
|
||||
|
||||
const load = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const data = await fetchEnvConfig();
|
||||
setItems(data);
|
||||
setDirtyMap({});
|
||||
setEditingKey(null);
|
||||
} catch { message.error('加载环境配置失败'); }
|
||||
finally { setLoading(false); }
|
||||
}, []);
|
||||
|
||||
useEffect(() => { load(); }, [load]);
|
||||
|
||||
const startEdit = (key: string, currentValue: string, isSensitive: boolean) => {
|
||||
setEditingKey(key);
|
||||
setEditValue(isSensitive ? '' : (dirtyMap[key] ?? currentValue));
|
||||
setTimeout(() => { inputRef.current?.focus(); }, 0);
|
||||
};
|
||||
|
||||
const confirmEdit = (key: string, originalValue: string, isSensitive: boolean) => {
|
||||
const trimmed = editValue.trim();
|
||||
if (isSensitive && trimmed === '') {
|
||||
setDirtyMap((prev) => { const next = { ...prev }; delete next[key]; return next; });
|
||||
} else if (!isSensitive && trimmed === originalValue) {
|
||||
setDirtyMap((prev) => { const next = { ...prev }; delete next[key]; return next; });
|
||||
} else {
|
||||
setDirtyMap((prev) => ({ ...prev, [key]: trimmed }));
|
||||
}
|
||||
setEditingKey(null);
|
||||
};
|
||||
|
||||
const cancelEdit = () => { setEditingKey(null); };
|
||||
|
||||
const handleSave = async () => {
|
||||
if (Object.keys(dirtyMap).length === 0) { message.info('没有需要保存的修改'); return; }
|
||||
setSaving(true);
|
||||
try {
|
||||
const payload = items.map((item) => ({
|
||||
key: item.key,
|
||||
value: dirtyMap[item.key] ?? item.value,
|
||||
is_sensitive: item.is_sensitive,
|
||||
}));
|
||||
await updateEnvConfig(payload);
|
||||
message.success('保存成功');
|
||||
await load();
|
||||
} catch { message.error('保存失败'); }
|
||||
finally { setSaving(false); }
|
||||
};
|
||||
|
||||
const handleExport = async () => {
|
||||
setExporting(true);
|
||||
try { await exportEnvConfig(); message.success('导出成功'); }
|
||||
catch { message.error('导出失败'); }
|
||||
finally { setExporting(false); }
|
||||
};
|
||||
|
||||
const columns: ColumnsType<EnvConfigItem> = [
|
||||
{
|
||||
title: '键名', dataIndex: 'key', key: 'key', width: '35%',
|
||||
render: (text: string) => <code style={{ fontSize: 12 }}>{text}</code>,
|
||||
},
|
||||
{
|
||||
title: '值', dataIndex: 'value', key: 'value', width: '50%',
|
||||
render: (_: string, record: EnvConfigItem) => {
|
||||
if (editingKey === record.key) {
|
||||
return (
|
||||
<Input
|
||||
ref={inputRef} value={editValue} size="small"
|
||||
placeholder={record.is_sensitive ? '输入新值(留空则不修改)' : undefined}
|
||||
onChange={(e) => setEditValue(e.target.value)}
|
||||
onPressEnter={() => confirmEdit(record.key, record.value, record.is_sensitive)}
|
||||
onBlur={() => confirmEdit(record.key, record.value, record.is_sensitive)}
|
||||
onKeyDown={(e) => { if (e.key === 'Escape') cancelEdit(); }}
|
||||
style={{ fontFamily: 'monospace' }}
|
||||
/>
|
||||
);
|
||||
}
|
||||
const isDirty = record.key in dirtyMap;
|
||||
const displayValue = record.is_sensitive
|
||||
? (isDirty ? MASK + ' (已修改)' : MASK)
|
||||
: (isDirty ? dirtyMap[record.key] : record.value);
|
||||
return (
|
||||
<span
|
||||
style={{ cursor: 'pointer', color: isDirty ? '#1677ff' : undefined, fontFamily: 'monospace', fontSize: 12 }}
|
||||
onClick={() => startEdit(record.key, record.value, record.is_sensitive)}
|
||||
title="点击编辑"
|
||||
>
|
||||
{displayValue || <Text type="secondary">(空)</Text>}
|
||||
</span>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '类型', dataIndex: 'is_sensitive', key: 'is_sensitive', width: '15%', align: 'center',
|
||||
render: (v: boolean) => v ? <Tag color="red">敏感</Tag> : <Tag color="green">普通</Tag>,
|
||||
},
|
||||
];
|
||||
|
||||
const hasDirty = Object.keys(dirtyMap).length > 0;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div style={{ marginBottom: 16, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Title level={4} style={{ margin: 0 }}>
|
||||
<ToolOutlined style={{ marginRight: 8 }} />
|
||||
环境配置
|
||||
</Title>
|
||||
<Space>
|
||||
<Button icon={<ReloadOutlined />} onClick={load} loading={loading}>刷新</Button>
|
||||
<Badge count={hasDirty ? Object.keys(dirtyMap).length : 0} size="small">
|
||||
<Button
|
||||
type="primary" icon={<SaveOutlined />}
|
||||
onClick={handleSave} loading={saving} disabled={!hasDirty}
|
||||
>
|
||||
保存
|
||||
</Button>
|
||||
</Badge>
|
||||
<Button icon={<DownloadOutlined />} onClick={handleExport} loading={exporting}>导出</Button>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
<Card size="small">
|
||||
<Table<EnvConfigItem>
|
||||
rowKey="key" columns={columns} dataSource={items}
|
||||
loading={loading} pagination={false} size="small"
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default EnvConfig;
|
||||
138
apps/admin-web/src/pages/LogViewer.tsx
Normal file
138
apps/admin-web/src/pages/LogViewer.tsx
Normal file
@@ -0,0 +1,138 @@
|
||||
/**
|
||||
* 日志查看器页面。
|
||||
*
|
||||
* - 输入执行 ID,通过 WebSocket 实时接收日志
|
||||
* - 支持加载历史日志
|
||||
* - 关键词过滤(大小写不敏感)
|
||||
*/
|
||||
|
||||
import React, { useState, useRef, useCallback, useEffect } from "react";
|
||||
import { Input, Button, Space, message, Card, Typography, Tag, Badge } from "antd";
|
||||
import {
|
||||
LinkOutlined, DisconnectOutlined, HistoryOutlined,
|
||||
FileTextOutlined, SearchOutlined, ClearOutlined,
|
||||
} from "@ant-design/icons";
|
||||
import { apiClient } from "../api/client";
|
||||
import LogStream from "../components/LogStream";
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 纯函数:日志过滤 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export function filterLogLines(lines: string[], keyword: string): string[] {
|
||||
if (!keyword.trim()) return lines;
|
||||
const lower = keyword.toLowerCase();
|
||||
return lines.filter((line) => line.toLowerCase().includes(lower));
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 页面组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const LogViewer: React.FC = () => {
|
||||
const [executionId, setExecutionId] = useState("");
|
||||
const [lines, setLines] = useState<string[]>([]);
|
||||
const [filterKeyword, setFilterKeyword] = useState("");
|
||||
const [connected, setConnected] = useState(false);
|
||||
const wsRef = useRef<WebSocket | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
return () => { wsRef.current?.close(); };
|
||||
}, []);
|
||||
|
||||
const handleConnect = useCallback(() => {
|
||||
const id = executionId.trim();
|
||||
if (!id) { message.warning("请输入执行 ID"); return; }
|
||||
wsRef.current?.close();
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const host = window.location.host;
|
||||
const ws = new WebSocket(`${protocol}//${host}/ws/logs/${id}`);
|
||||
wsRef.current = ws;
|
||||
ws.onopen = () => { setConnected(true); message.success("WebSocket 已连接"); };
|
||||
ws.onmessage = (event) => { setLines((prev) => [...prev, event.data]); };
|
||||
ws.onclose = () => { setConnected(false); };
|
||||
ws.onerror = () => { message.error("WebSocket 连接失败"); setConnected(false); };
|
||||
}, [executionId]);
|
||||
|
||||
const handleDisconnect = useCallback(() => {
|
||||
wsRef.current?.close();
|
||||
wsRef.current = null;
|
||||
setConnected(false);
|
||||
}, []);
|
||||
|
||||
const handleLoadHistory = useCallback(async () => {
|
||||
const id = executionId.trim();
|
||||
if (!id) { message.warning("请输入执行 ID"); return; }
|
||||
try {
|
||||
const { data } = await apiClient.get<{ execution_id: string; output_log: string | null; error_log: string | null }>(
|
||||
`/execution/${id}/logs`
|
||||
);
|
||||
const parts: string[] = [];
|
||||
if (data.output_log) parts.push(data.output_log);
|
||||
if (data.error_log) parts.push(data.error_log);
|
||||
const historyLines = parts.join("\n").split("\n");
|
||||
setLines(historyLines);
|
||||
message.success("历史日志加载完成");
|
||||
} catch { message.error("加载历史日志失败"); }
|
||||
}, [executionId]);
|
||||
|
||||
const handleClear = useCallback(() => { setLines([]); }, []);
|
||||
|
||||
const filteredLines = filterLogLines(lines, filterKeyword);
|
||||
|
||||
return (
|
||||
<div style={{ display: "flex", flexDirection: "column", height: "100%" }}>
|
||||
<div style={{ marginBottom: 12, display: "flex", justifyContent: "space-between", alignItems: "center" }}>
|
||||
<Title level={4} style={{ margin: 0 }}>
|
||||
<FileTextOutlined style={{ marginRight: 8 }} />
|
||||
日志查看器
|
||||
</Title>
|
||||
<Space>
|
||||
{connected && <Badge status="processing" text={<Text type="success">已连接</Text>} />}
|
||||
<Tag>{lines.length} 行</Tag>
|
||||
{filterKeyword && <Tag color="blue">{filteredLines.length} 条匹配</Tag>}
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
{/* 操作栏 */}
|
||||
<Card size="small" style={{ marginBottom: 12 }}>
|
||||
<Space wrap style={{ width: "100%", justifyContent: "space-between" }}>
|
||||
<Space>
|
||||
<Input
|
||||
placeholder="执行 ID"
|
||||
value={executionId}
|
||||
onChange={(e) => setExecutionId(e.target.value)}
|
||||
style={{ width: 280, fontFamily: "monospace" }}
|
||||
onPressEnter={handleConnect}
|
||||
allowClear
|
||||
/>
|
||||
{connected ? (
|
||||
<Button icon={<DisconnectOutlined />} danger onClick={handleDisconnect}>断开</Button>
|
||||
) : (
|
||||
<Button type="primary" icon={<LinkOutlined />} onClick={handleConnect}>连接</Button>
|
||||
)}
|
||||
<Button icon={<HistoryOutlined />} onClick={handleLoadHistory}>加载历史</Button>
|
||||
<Button icon={<ClearOutlined />} onClick={handleClear} disabled={lines.length === 0}>清空</Button>
|
||||
</Space>
|
||||
<Input
|
||||
prefix={<SearchOutlined />}
|
||||
placeholder="过滤关键词..."
|
||||
value={filterKeyword}
|
||||
onChange={(e) => setFilterKeyword(e.target.value)}
|
||||
allowClear
|
||||
style={{ width: 220 }}
|
||||
/>
|
||||
</Space>
|
||||
</Card>
|
||||
|
||||
{/* 日志流 */}
|
||||
<div style={{ flex: 1, minHeight: 0 }}>
|
||||
<LogStream executionId={executionId} lines={filteredLines} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LogViewer;
|
||||
92
apps/admin-web/src/pages/Login.tsx
Normal file
92
apps/admin-web/src/pages/Login.tsx
Normal file
@@ -0,0 +1,92 @@
|
||||
/**
|
||||
* 登录页面 — Ant Design Form + Zustand authStore。
|
||||
*/
|
||||
|
||||
import React, { useState } from "react";
|
||||
import { Button, Card, Form, Input, message, Typography, Space } from "antd";
|
||||
import { LockOutlined, UserOutlined } from "@ant-design/icons";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
import { useAuthStore } from "../store/authStore";
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
|
||||
interface LoginFormValues {
|
||||
username: string;
|
||||
password: string;
|
||||
}
|
||||
|
||||
const Login: React.FC = () => {
|
||||
const navigate = useNavigate();
|
||||
const login = useAuthStore((s) => s.login);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const onFinish = async (values: LoginFormValues) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
await login(values.username, values.password);
|
||||
message.success("登录成功");
|
||||
navigate("/", { replace: true });
|
||||
} catch (err: unknown) {
|
||||
const detail =
|
||||
(err as { response?: { data?: { detail?: string } } })?.response?.data
|
||||
?.detail ?? "登录失败,请检查用户名和密码";
|
||||
message.error(detail);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
justifyContent: "center",
|
||||
alignItems: "center",
|
||||
minHeight: "100vh",
|
||||
background: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
}}
|
||||
>
|
||||
<Card
|
||||
style={{
|
||||
width: 400,
|
||||
borderRadius: 12,
|
||||
boxShadow: "0 8px 32px rgba(0,0,0,0.15)",
|
||||
}}
|
||||
>
|
||||
<Space direction="vertical" style={{ width: "100%", textAlign: "center", marginBottom: 24 }}>
|
||||
<Title level={3} style={{ margin: 0 }}>NeoZQYY</Title>
|
||||
<Text type="secondary">管理后台</Text>
|
||||
</Space>
|
||||
|
||||
<Form<LoginFormValues>
|
||||
name="login"
|
||||
onFinish={onFinish}
|
||||
autoComplete="off"
|
||||
size="large"
|
||||
>
|
||||
<Form.Item
|
||||
name="username"
|
||||
rules={[{ required: true, message: "请输入用户名" }]}
|
||||
>
|
||||
<Input prefix={<UserOutlined />} placeholder="用户名" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="password"
|
||||
rules={[{ required: true, message: "请输入密码" }]}
|
||||
>
|
||||
<Input.Password prefix={<LockOutlined />} placeholder="密码" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item>
|
||||
<Button type="primary" htmlType="submit" loading={loading} block>
|
||||
登录
|
||||
</Button>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Login;
|
||||
573
apps/admin-web/src/pages/TaskConfig.tsx
Normal file
573
apps/admin-web/src/pages/TaskConfig.tsx
Normal file
@@ -0,0 +1,573 @@
|
||||
/**
|
||||
* ETL 任务配置页面。
|
||||
*
|
||||
* 提供 Flow 选择、处理模式、时间窗口、高级选项等配置区域,
|
||||
* 以及连接器/Store 选择、任务选择、DWD 表选择、CLI 命令预览和任务提交功能。
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect, useMemo } from "react";
|
||||
import {
|
||||
Card,
|
||||
Radio,
|
||||
Checkbox,
|
||||
InputNumber,
|
||||
DatePicker,
|
||||
Button,
|
||||
Space,
|
||||
Typography,
|
||||
Input,
|
||||
message,
|
||||
Row,
|
||||
Col,
|
||||
Badge,
|
||||
Alert,
|
||||
TreeSelect,
|
||||
Tooltip,
|
||||
Segmented,
|
||||
} from "antd";
|
||||
import {
|
||||
SendOutlined,
|
||||
ThunderboltOutlined,
|
||||
CodeOutlined,
|
||||
SettingOutlined,
|
||||
ClockCircleOutlined,
|
||||
SyncOutlined,
|
||||
ShopOutlined,
|
||||
ApiOutlined,
|
||||
} from "@ant-design/icons";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
import TaskSelector from "../components/TaskSelector";
|
||||
import { validateTaskConfig } from "../api/tasks";
|
||||
import { submitToQueue, executeDirectly } from "../api/execution";
|
||||
import { useAuthStore } from "../store/authStore";
|
||||
import type { RadioChangeEvent } from "antd";
|
||||
import type { Dayjs } from "dayjs";
|
||||
import type { TaskConfig as TaskConfigType } from "../types";
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
const { TextArea } = Input;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Flow 定义 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const FLOW_DEFINITIONS: Record<string, { name: string; layers: string[]; desc: string }> = {
|
||||
api_ods: { name: "API → ODS", layers: ["ODS"], desc: "仅抓取原始数据" },
|
||||
api_ods_dwd: { name: "API → ODS → DWD", layers: ["ODS", "DWD"], desc: "抓取并清洗装载" },
|
||||
api_full: { name: "API → ODS → DWD → DWS → INDEX", layers: ["ODS", "DWD", "DWS", "INDEX"], desc: "全链路执行" },
|
||||
ods_dwd: { name: "ODS → DWD", layers: ["DWD"], desc: "仅清洗装载" },
|
||||
dwd_dws: { name: "DWD → DWS汇总", layers: ["DWS"], desc: "仅汇总计算" },
|
||||
dwd_dws_index: { name: "DWD → DWS → INDEX", layers: ["DWS", "INDEX"], desc: "汇总+指数" },
|
||||
dwd_index: { name: "DWD → INDEX", layers: ["INDEX"], desc: "仅指数计算" },
|
||||
};
|
||||
|
||||
export function getFlowLayers(flowId: string): string[] {
|
||||
return FLOW_DEFINITIONS[flowId]?.layers ?? [];
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 处理模式 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const PROCESSING_MODES = [
|
||||
{ value: "increment_only", label: "仅增量", desc: "按游标增量抓取和装载" },
|
||||
{ value: "verify_only", label: "校验并修复", desc: "对比源和目标,修复差异" },
|
||||
{ value: "increment_verify", label: "增量+校验", desc: "先增量再校验" },
|
||||
] as const;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 时间窗口 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
type WindowMode = "lookback" | "custom";
|
||||
|
||||
const WINDOW_SPLIT_OPTIONS = [
|
||||
{ value: 0, label: "不切分" },
|
||||
{ value: 1, label: "1天" },
|
||||
{ value: 10, label: "10天" },
|
||||
{ value: 30, label: "30天" },
|
||||
] as const;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 连接器 → 门店 树形数据结构 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
/** 连接器定义:每个连接器下挂载门店列表 */
|
||||
interface ConnectorDef {
|
||||
id: string;
|
||||
label: string;
|
||||
icon: React.ReactNode;
|
||||
}
|
||||
|
||||
const CONNECTOR_DEFS: ConnectorDef[] = [
|
||||
{ id: "feiqiu", label: "飞球", icon: <ApiOutlined /> },
|
||||
];
|
||||
|
||||
/** 构建 TreeSelect 的 treeData,连接器为父节点,门店为子节点 */
|
||||
function buildConnectorStoreTree(
|
||||
connectors: ConnectorDef[],
|
||||
siteId: number | null,
|
||||
): { treeData: { title: React.ReactNode; value: string; key: string; children?: { title: React.ReactNode; value: string; key: string }[] }[]; allValues: string[] } {
|
||||
const allValues: string[] = [];
|
||||
const treeData = connectors.map((c) => {
|
||||
// 每个连接器下挂载当前用户的门店(未来可扩展为多门店)
|
||||
const stores = siteId
|
||||
? [{ title: (<Space size={4}><ShopOutlined /><span>门店 {siteId}</span></Space>), value: `${c.id}::${siteId}`, key: `${c.id}::${siteId}` }]
|
||||
: [];
|
||||
stores.forEach((s) => allValues.push(s.value));
|
||||
return {
|
||||
title: (<Space size={4}>{c.icon}<span>{c.label}</span></Space>),
|
||||
value: c.id,
|
||||
key: c.id,
|
||||
children: stores,
|
||||
};
|
||||
});
|
||||
return { treeData, allValues };
|
||||
}
|
||||
|
||||
/** 从选中值中解析出 store_id 列表 */
|
||||
function parseSelectedStoreIds(selected: string[]): number[] {
|
||||
const ids: number[] = [];
|
||||
for (const v of selected) {
|
||||
// 格式: "connector::storeId"
|
||||
const parts = v.split("::");
|
||||
if (parts.length === 2) {
|
||||
const num = Number(parts[1]);
|
||||
if (!isNaN(num)) ids.push(num);
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 页面组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const TaskConfig: React.FC = () => {
|
||||
const navigate = useNavigate();
|
||||
const user = useAuthStore((s) => s.user);
|
||||
|
||||
/* ---------- 连接器 & Store 树形选择 ---------- */
|
||||
const { treeData: connectorTreeData, allValues: allConnectorStoreValues } = useMemo(
|
||||
() => buildConnectorStoreTree(CONNECTOR_DEFS, user?.site_id ?? null),
|
||||
[user?.site_id],
|
||||
);
|
||||
// 默认全选
|
||||
const [selectedConnectorStores, setSelectedConnectorStores] = useState<string[]>([]);
|
||||
|
||||
// 初始化时默认全选
|
||||
useEffect(() => {
|
||||
if (selectedConnectorStores.length === 0 && allConnectorStoreValues.length > 0) {
|
||||
setSelectedConnectorStores(allConnectorStoreValues);
|
||||
}
|
||||
}, [allConnectorStoreValues]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
// 从选中值解析 store_id(取第一个,当前单门店场景)
|
||||
const selectedStoreIds = useMemo(() => parseSelectedStoreIds(selectedConnectorStores), [selectedConnectorStores]);
|
||||
const effectiveStoreId = selectedStoreIds.length === 1 ? selectedStoreIds[0] : null;
|
||||
|
||||
/* ---------- Flow ---------- */
|
||||
const [flow, setFlow] = useState<string>("api_ods_dwd");
|
||||
|
||||
/* ---------- 处理模式 ---------- */
|
||||
const [processingMode, setProcessingMode] = useState<string>("increment_only");
|
||||
const [fetchBeforeVerify, setFetchBeforeVerify] = useState(false);
|
||||
|
||||
/* ---------- 时间窗口 ---------- */
|
||||
const [windowMode, setWindowMode] = useState<WindowMode>("lookback");
|
||||
const [lookbackHours, setLookbackHours] = useState<number>(24);
|
||||
const [overlapSeconds, setOverlapSeconds] = useState<number>(600);
|
||||
const [windowStart, setWindowStart] = useState<Dayjs | null>(null);
|
||||
const [windowEnd, setWindowEnd] = useState<Dayjs | null>(null);
|
||||
const [windowSplitDays, setWindowSplitDays] = useState<number>(0);
|
||||
|
||||
/* ---------- 任务选择 ---------- */
|
||||
const [selectedTasks, setSelectedTasks] = useState<string[]>([]);
|
||||
const [selectedDwdTables, setSelectedDwdTables] = useState<string[]>([]);
|
||||
|
||||
/* ---------- 高级选项 ---------- */
|
||||
const [dryRun, setDryRun] = useState(false);
|
||||
const [forceFull, setForceFull] = useState(false);
|
||||
const [useLocalJson, setUseLocalJson] = useState(false);
|
||||
|
||||
/* ---------- CLI 预览 ---------- */
|
||||
const [cliCommand, setCliCommand] = useState<string>("");
|
||||
const [cliEdited, setCliEdited] = useState(false);
|
||||
const [cliLoading, setCliLoading] = useState(false);
|
||||
|
||||
/* ---------- 提交状态 ---------- */
|
||||
const [submitting, setSubmitting] = useState(false);
|
||||
|
||||
/* ---------- 派生状态 ---------- */
|
||||
const layers = getFlowLayers(flow);
|
||||
const showVerifyOption = processingMode === "verify_only";
|
||||
|
||||
/* ---------- 构建 TaskConfig 对象 ---------- */
|
||||
const buildTaskConfig = (): TaskConfigType => ({
|
||||
tasks: selectedTasks,
|
||||
pipeline: flow,
|
||||
processing_mode: processingMode,
|
||||
pipeline_flow: "FULL",
|
||||
dry_run: dryRun,
|
||||
window_mode: windowMode,
|
||||
window_start: windowMode === "custom" && windowStart ? windowStart.format("YYYY-MM-DD") : null,
|
||||
window_end: windowMode === "custom" && windowEnd ? windowEnd.format("YYYY-MM-DD") : null,
|
||||
window_split: windowSplitDays > 0 ? "day" : null,
|
||||
window_split_days: windowSplitDays > 0 ? windowSplitDays : null,
|
||||
lookback_hours: lookbackHours,
|
||||
overlap_seconds: overlapSeconds,
|
||||
fetch_before_verify: fetchBeforeVerify,
|
||||
skip_ods_when_fetch_before_verify: false,
|
||||
ods_use_local_json: useLocalJson,
|
||||
store_id: effectiveStoreId,
|
||||
dwd_only_tables: selectedDwdTables.length > 0 ? selectedDwdTables : null,
|
||||
force_full: forceFull,
|
||||
extra_args: {},
|
||||
});
|
||||
|
||||
/* ---------- 自动刷新 CLI 预览 ---------- */
|
||||
const refreshCli = async () => {
|
||||
setCliLoading(true);
|
||||
try {
|
||||
const { command } = await validateTaskConfig(buildTaskConfig());
|
||||
setCliCommand(command);
|
||||
setCliEdited(false);
|
||||
} catch {
|
||||
// 静默失败,保留上次命令
|
||||
} finally {
|
||||
setCliLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 配置变化时自动刷新 CLI(防抖)
|
||||
useEffect(() => {
|
||||
if (cliEdited) return; // 用户手动编辑过则不自动刷新
|
||||
const timer = setTimeout(refreshCli, 500);
|
||||
return () => clearTimeout(timer);
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [flow, processingMode, fetchBeforeVerify, windowMode, lookbackHours, overlapSeconds,
|
||||
windowStart, windowEnd, windowSplitDays, selectedTasks, selectedDwdTables,
|
||||
dryRun, forceFull, useLocalJson, selectedConnectorStores]);
|
||||
|
||||
/* ---------- 事件处理 ---------- */
|
||||
const handleFlowChange = (e: RadioChangeEvent) => setFlow(e.target.value);
|
||||
|
||||
const handleSubmitToQueue = async () => {
|
||||
setSubmitting(true);
|
||||
try {
|
||||
await submitToQueue(buildTaskConfig());
|
||||
message.success("已提交到执行队列");
|
||||
navigate("/task-manager");
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : "提交失败";
|
||||
message.error(`提交到队列失败:${msg}`);
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleExecuteDirectly = async () => {
|
||||
setSubmitting(true);
|
||||
try {
|
||||
await executeDirectly(buildTaskConfig());
|
||||
message.success("任务已开始执行");
|
||||
navigate("/task-manager");
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : "执行失败";
|
||||
message.error(`直接执行失败:${msg}`);
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
/* ---------- 样式常量 ---------- */
|
||||
const cardStyle = { marginBottom: 12 };
|
||||
const sectionTitleStyle: React.CSSProperties = {
|
||||
fontSize: 13, fontWeight: 500, color: "#666", marginBottom: 8, display: "block",
|
||||
};
|
||||
|
||||
return (
|
||||
<div style={{ maxWidth: 960, margin: "0 auto" }}>
|
||||
{/* ---- 页面标题 ---- */}
|
||||
<div style={{ marginBottom: 16, display: "flex", alignItems: "center", justifyContent: "space-between" }}>
|
||||
<Title level={4} style={{ margin: 0 }}>
|
||||
<SettingOutlined style={{ marginRight: 8 }} />
|
||||
任务配置
|
||||
</Title>
|
||||
<Space>
|
||||
<Badge count={selectedTasks.length} size="small" offset={[-4, 0]}>
|
||||
<Text type="secondary">已选任务</Text>
|
||||
</Badge>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
{/* ---- 第一行:连接器/门店 + Flow ---- */}
|
||||
<Row gutter={12}>
|
||||
<Col span={8}>
|
||||
<Card size="small" title={<Space size={4}><ApiOutlined />连接器 / 门店</Space>} style={cardStyle}>
|
||||
<TreeSelect
|
||||
treeData={connectorTreeData}
|
||||
value={selectedConnectorStores}
|
||||
onChange={setSelectedConnectorStores}
|
||||
treeCheckable
|
||||
treeDefaultExpandAll
|
||||
showCheckedStrategy={TreeSelect.SHOW_CHILD}
|
||||
placeholder="选择连接器和门店"
|
||||
style={{ width: "100%" }}
|
||||
maxTagCount={3}
|
||||
maxTagPlaceholder={(omitted) => `+${omitted.length} 项`}
|
||||
treeCheckStrictly={false}
|
||||
/>
|
||||
<Text type="secondary" style={{ fontSize: 11, marginTop: 6, display: "block" }}>
|
||||
{selectedStoreIds.length === 0
|
||||
? "未选择门店,将使用 JWT 默认值"
|
||||
: `已选 ${selectedStoreIds.length} 个门店`}
|
||||
</Text>
|
||||
</Card>
|
||||
</Col>
|
||||
<Col span={16}>
|
||||
<Card size="small" title="执行流程 (Flow)" style={cardStyle}>
|
||||
<Radio.Group value={flow} onChange={handleFlowChange} style={{ width: "100%" }}>
|
||||
<Row gutter={[0, 4]}>
|
||||
{Object.entries(FLOW_DEFINITIONS).map(([id, def]) => (
|
||||
<Col span={12} key={id}>
|
||||
<Tooltip title={def.desc}>
|
||||
<Radio value={id}>
|
||||
<Text strong style={{ fontSize: 12 }}>{id}</Text>
|
||||
</Radio>
|
||||
</Tooltip>
|
||||
</Col>
|
||||
))}
|
||||
</Row>
|
||||
</Radio.Group>
|
||||
<div style={{ marginTop: 6, padding: "4px 8px", background: "#f6f8fa", borderRadius: 4 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||
{layers.join(" → ") || "—"}
|
||||
</Text>
|
||||
</div>
|
||||
</Card>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* ---- 第二行:处理模式 + 时间窗口 ---- */}
|
||||
<Row gutter={12}>
|
||||
<Col span={8}>
|
||||
<Card size="small" title="处理模式" style={cardStyle}>
|
||||
<Radio.Group
|
||||
value={processingMode}
|
||||
onChange={(e) => {
|
||||
setProcessingMode(e.target.value);
|
||||
if (e.target.value === "increment_only") setFetchBeforeVerify(false);
|
||||
}}
|
||||
>
|
||||
<Space direction="vertical" style={{ width: "100%" }}>
|
||||
{PROCESSING_MODES.map((m) => (
|
||||
<Radio key={m.value} value={m.value}>
|
||||
<Text strong>{m.label}</Text>
|
||||
<br />
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>{m.desc}</Text>
|
||||
</Radio>
|
||||
))}
|
||||
</Space>
|
||||
</Radio.Group>
|
||||
{showVerifyOption && (
|
||||
<Checkbox
|
||||
checked={fetchBeforeVerify}
|
||||
onChange={(e) => setFetchBeforeVerify(e.target.checked)}
|
||||
style={{ marginTop: 8 }}
|
||||
>
|
||||
校验前从 API 获取
|
||||
</Checkbox>
|
||||
)}
|
||||
</Card>
|
||||
</Col>
|
||||
<Col span={16}>
|
||||
<Card
|
||||
size="small"
|
||||
title={<><ClockCircleOutlined style={{ marginRight: 6 }} />时间窗口</>}
|
||||
style={cardStyle}
|
||||
>
|
||||
<Row gutter={16}>
|
||||
<Col span={24}>
|
||||
<Segmented
|
||||
value={windowMode}
|
||||
onChange={(v) => setWindowMode(v as WindowMode)}
|
||||
options={[
|
||||
{ value: "lookback", label: "回溯模式" },
|
||||
{ value: "custom", label: "自定义范围" },
|
||||
]}
|
||||
style={{ marginBottom: 12 }}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{windowMode === "lookback" ? (
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Text style={sectionTitleStyle}>回溯小时数</Text>
|
||||
<InputNumber
|
||||
min={1} max={720} value={lookbackHours}
|
||||
onChange={(v) => setLookbackHours(v ?? 24)}
|
||||
style={{ width: "100%" }}
|
||||
addonAfter="小时"
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Text style={sectionTitleStyle}>冗余秒数</Text>
|
||||
<InputNumber
|
||||
min={0} max={7200} value={overlapSeconds}
|
||||
onChange={(v) => setOverlapSeconds(v ?? 600)}
|
||||
style={{ width: "100%" }}
|
||||
addonAfter="秒"
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
) : (
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Text style={sectionTitleStyle}>开始日期</Text>
|
||||
<DatePicker
|
||||
value={windowStart} onChange={setWindowStart}
|
||||
placeholder="选择开始日期" style={{ width: "100%" }}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Text style={sectionTitleStyle}>结束日期</Text>
|
||||
<DatePicker
|
||||
value={windowEnd} onChange={setWindowEnd}
|
||||
placeholder="选择结束日期" style={{ width: "100%" }}
|
||||
status={windowStart && windowEnd && windowEnd.isBefore(windowStart) ? "error" : undefined}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
)}
|
||||
|
||||
<div style={{ marginTop: 12 }}>
|
||||
<Text style={sectionTitleStyle}>窗口切分</Text>
|
||||
<Radio.Group
|
||||
value={windowSplitDays}
|
||||
onChange={(e) => setWindowSplitDays(e.target.value)}
|
||||
>
|
||||
{WINDOW_SPLIT_OPTIONS.map((opt) => (
|
||||
<Radio.Button key={opt.value} value={opt.value}>{opt.label}</Radio.Button>
|
||||
))}
|
||||
</Radio.Group>
|
||||
</div>
|
||||
</Card>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* ---- 高级选项(带描述) ---- */}
|
||||
<Card size="small" title="高级选项" style={cardStyle}>
|
||||
<Row gutter={[24, 8]}>
|
||||
<Col span={12}>
|
||||
<Checkbox checked={dryRun} onChange={(e) => setDryRun(e.target.checked)}>
|
||||
<Text strong>dry-run</Text>
|
||||
</Checkbox>
|
||||
<div style={{ marginLeft: 24 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>模拟执行,走完整流程但不写入数据库</Text>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Checkbox checked={forceFull} onChange={(e) => setForceFull(e.target.checked)}>
|
||||
<Text strong>force-full</Text>
|
||||
</Checkbox>
|
||||
<div style={{ marginLeft: 24 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>强制全量,跳过 hash 去重和变更对比</Text>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Checkbox checked={useLocalJson} onChange={(e) => setUseLocalJson(e.target.checked)}>
|
||||
<Text strong>本地 JSON</Text>
|
||||
</Checkbox>
|
||||
<div style={{ marginLeft: 24 }}>
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>离线模式,从本地 JSON 回放(等同 --data-source offline)</Text>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
</Card>
|
||||
|
||||
{/* ---- 任务选择(含 DWD 表过滤) ---- */}
|
||||
<Card size="small" title="任务选择" style={cardStyle}>
|
||||
<TaskSelector
|
||||
layers={layers}
|
||||
selectedTasks={selectedTasks}
|
||||
onTasksChange={setSelectedTasks}
|
||||
selectedDwdTables={selectedDwdTables}
|
||||
onDwdTablesChange={setSelectedDwdTables}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* ---- CLI 命令预览(内嵌可编辑) ---- */}
|
||||
<Card
|
||||
size="small"
|
||||
title={
|
||||
<Space>
|
||||
<CodeOutlined />
|
||||
<span>CLI 命令预览</span>
|
||||
{cliEdited && <Text type="warning" style={{ fontSize: 12 }}>(已手动编辑)</Text>}
|
||||
</Space>
|
||||
}
|
||||
extra={
|
||||
<Button
|
||||
size="small"
|
||||
icon={<SyncOutlined spin={cliLoading} />}
|
||||
onClick={() => { setCliEdited(false); refreshCli(); }}
|
||||
>
|
||||
重新生成
|
||||
</Button>
|
||||
}
|
||||
style={cardStyle}
|
||||
>
|
||||
<TextArea
|
||||
value={cliCommand}
|
||||
onChange={(e) => { setCliCommand(e.target.value); setCliEdited(true); }}
|
||||
autoSize={{ minRows: 2, maxRows: 6 }}
|
||||
style={{
|
||||
fontFamily: "'Cascadia Code', 'Fira Code', Consolas, monospace",
|
||||
fontSize: 13,
|
||||
background: "#1e1e1e",
|
||||
color: "#d4d4d4",
|
||||
border: "none",
|
||||
borderRadius: 4,
|
||||
}}
|
||||
placeholder="配置变更后自动生成 CLI 命令..."
|
||||
/>
|
||||
{cliEdited && (
|
||||
<Alert
|
||||
type="info"
|
||||
showIcon
|
||||
message="已手动编辑命令,配置变更不会自动覆盖。点击「重新生成」恢复自动模式。"
|
||||
style={{ marginTop: 8 }}
|
||||
banner
|
||||
/>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
{/* ---- 操作按钮 ---- */}
|
||||
<Card size="small" style={{ marginBottom: 24 }}>
|
||||
<Space size="middle">
|
||||
<Button
|
||||
type="primary"
|
||||
size="large"
|
||||
icon={<SendOutlined />}
|
||||
loading={submitting}
|
||||
onClick={handleSubmitToQueue}
|
||||
>
|
||||
提交到队列
|
||||
</Button>
|
||||
<Button
|
||||
size="large"
|
||||
icon={<ThunderboltOutlined />}
|
||||
loading={submitting}
|
||||
onClick={handleExecuteDirectly}
|
||||
>
|
||||
直接执行
|
||||
</Button>
|
||||
</Space>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TaskConfig;
|
||||
255
apps/admin-web/src/pages/TaskManager.tsx
Normal file
255
apps/admin-web/src/pages/TaskManager.tsx
Normal file
@@ -0,0 +1,255 @@
|
||||
/**
|
||||
* 任务管理页面。
|
||||
*
|
||||
* 三个 Tab:队列、调度、历史
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useCallback } from 'react';
|
||||
import {
|
||||
Tabs, Table, Tag, Button, Popconfirm, Space, message, Drawer,
|
||||
Typography, Descriptions, Empty,
|
||||
} from 'antd';
|
||||
import {
|
||||
ReloadOutlined, DeleteOutlined, StopOutlined,
|
||||
UnorderedListOutlined, ClockCircleOutlined, HistoryOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import type { QueuedTask, ExecutionLog } from '../types';
|
||||
import {
|
||||
fetchQueue, fetchHistory, deleteFromQueue, cancelExecution,
|
||||
} from '../api/execution';
|
||||
import ScheduleTab from '../components/ScheduleTab';
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 状态颜色映射 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const STATUS_COLOR: Record<string, string> = {
|
||||
pending: 'default',
|
||||
running: 'processing',
|
||||
success: 'success',
|
||||
failed: 'error',
|
||||
cancelled: 'warning',
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 工具函数 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function fmtTime(iso: string | null | undefined): string {
|
||||
if (!iso) return '—';
|
||||
return new Date(iso).toLocaleString('zh-CN');
|
||||
}
|
||||
|
||||
function fmtDuration(ms: number | null | undefined): string {
|
||||
if (ms == null) return '—';
|
||||
if (ms < 1000) return `${ms}ms`;
|
||||
const sec = ms / 1000;
|
||||
if (sec < 60) return `${sec.toFixed(1)}s`;
|
||||
const min = Math.floor(sec / 60);
|
||||
const remainSec = Math.round(sec % 60);
|
||||
return `${min}m${remainSec}s`;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 队列 Tab */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const QueueTab: React.FC = () => {
|
||||
const [data, setData] = useState<QueuedTask[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const load = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try { setData(await fetchQueue()); }
|
||||
catch { message.error('加载队列失败'); }
|
||||
finally { setLoading(false); }
|
||||
}, []);
|
||||
|
||||
useEffect(() => { load(); }, [load]);
|
||||
|
||||
const handleDelete = async (id: string) => {
|
||||
try { await deleteFromQueue(id); message.success('已删除'); load(); }
|
||||
catch { message.error('删除失败'); }
|
||||
};
|
||||
|
||||
const handleCancel = async (id: string) => {
|
||||
try { await cancelExecution(id); message.success('已取消'); load(); }
|
||||
catch { message.error('取消失败'); }
|
||||
};
|
||||
|
||||
const columns: ColumnsType<QueuedTask> = [
|
||||
{
|
||||
title: '任务', dataIndex: ['config', 'tasks'], key: 'tasks',
|
||||
render: (tasks: string[]) => (
|
||||
<Text style={{ maxWidth: 300 }} ellipsis={{ tooltip: tasks?.join(', ') }}>
|
||||
{tasks?.join(', ') ?? '—'}
|
||||
</Text>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: 'Flow', dataIndex: ['config', 'pipeline'], key: 'pipeline', width: 120,
|
||||
render: (v: string) => <Tag>{v}</Tag>,
|
||||
},
|
||||
{
|
||||
title: '状态', dataIndex: 'status', key: 'status', width: 90,
|
||||
render: (s: string) => <Tag color={STATUS_COLOR[s] ?? 'default'}>{s}</Tag>,
|
||||
},
|
||||
{ title: '位置', dataIndex: 'position', key: 'position', width: 60, align: 'center' },
|
||||
{ title: '创建时间', dataIndex: 'created_at', key: 'created_at', width: 170, render: fmtTime },
|
||||
{
|
||||
title: '操作', key: 'action', width: 100, align: 'center',
|
||||
render: (_: unknown, record: QueuedTask) => {
|
||||
if (record.status === 'pending') {
|
||||
return (
|
||||
<Popconfirm title="确认删除?" onConfirm={() => handleDelete(record.id)}>
|
||||
<Button type="link" danger icon={<DeleteOutlined />} size="small">删除</Button>
|
||||
</Popconfirm>
|
||||
);
|
||||
}
|
||||
if (record.status === 'running') {
|
||||
return (
|
||||
<Popconfirm title="确认取消执行?" onConfirm={() => handleCancel(record.id)}>
|
||||
<Button type="link" danger icon={<StopOutlined />} size="small">取消</Button>
|
||||
</Popconfirm>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<>
|
||||
<div style={{ marginBottom: 12, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Text type="secondary">共 {data.length} 个任务</Text>
|
||||
<Button icon={<ReloadOutlined />} onClick={load} loading={loading}>刷新</Button>
|
||||
</div>
|
||||
<Table<QueuedTask>
|
||||
rowKey="id" columns={columns} dataSource={data}
|
||||
loading={loading} pagination={false} size="small"
|
||||
locale={{ emptyText: <Empty description="队列为空" /> }}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 历史 Tab */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const HistoryTab: React.FC = () => {
|
||||
const [data, setData] = useState<ExecutionLog[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [detail, setDetail] = useState<ExecutionLog | null>(null);
|
||||
|
||||
const load = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try { setData(await fetchHistory()); }
|
||||
catch { message.error('加载历史记录失败'); }
|
||||
finally { setLoading(false); }
|
||||
}, []);
|
||||
|
||||
useEffect(() => { load(); }, [load]);
|
||||
|
||||
const columns: ColumnsType<ExecutionLog> = [
|
||||
{
|
||||
title: '任务', dataIndex: 'task_codes', key: 'task_codes',
|
||||
render: (codes: string[]) => (
|
||||
<Text style={{ maxWidth: 300 }} ellipsis={{ tooltip: codes?.join(', ') }}>
|
||||
{codes?.join(', ') ?? '—'}
|
||||
</Text>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '状态', dataIndex: 'status', key: 'status', width: 90,
|
||||
render: (s: string) => <Tag color={STATUS_COLOR[s] ?? 'default'}>{s}</Tag>,
|
||||
},
|
||||
{ title: '开始时间', dataIndex: 'started_at', key: 'started_at', width: 170, render: fmtTime },
|
||||
{ title: '时长', dataIndex: 'duration_ms', key: 'duration_ms', width: 90, render: fmtDuration },
|
||||
{
|
||||
title: '退出码', dataIndex: 'exit_code', key: 'exit_code', width: 70, align: 'center',
|
||||
render: (v: number | null) => v != null ? (
|
||||
<Tag color={v === 0 ? 'success' : 'error'}>{v}</Tag>
|
||||
) : '—',
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<>
|
||||
<div style={{ marginBottom: 12, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Text type="secondary">最近 {data.length} 条记录</Text>
|
||||
<Button icon={<ReloadOutlined />} onClick={load} loading={loading}>刷新</Button>
|
||||
</div>
|
||||
<Table<ExecutionLog>
|
||||
rowKey="id" columns={columns} dataSource={data}
|
||||
loading={loading} pagination={{ pageSize: 20, showTotal: (t) => `共 ${t} 条` }}
|
||||
size="small"
|
||||
onRow={(record) => ({ onClick: () => setDetail(record), style: { cursor: 'pointer' } })}
|
||||
/>
|
||||
|
||||
<Drawer
|
||||
title="执行详情" open={!!detail} onClose={() => setDetail(null)}
|
||||
width={520}
|
||||
>
|
||||
{detail && (
|
||||
<Descriptions column={1} bordered size="small">
|
||||
<Descriptions.Item label="任务">{detail.task_codes?.join(', ')}</Descriptions.Item>
|
||||
<Descriptions.Item label="状态">
|
||||
<Tag color={STATUS_COLOR[detail.status] ?? 'default'}>{detail.status}</Tag>
|
||||
</Descriptions.Item>
|
||||
<Descriptions.Item label="开始时间">{fmtTime(detail.started_at)}</Descriptions.Item>
|
||||
<Descriptions.Item label="结束时间">{fmtTime(detail.finished_at)}</Descriptions.Item>
|
||||
<Descriptions.Item label="时长">{fmtDuration(detail.duration_ms)}</Descriptions.Item>
|
||||
<Descriptions.Item label="退出码">
|
||||
{detail.exit_code != null ? (
|
||||
<Tag color={detail.exit_code === 0 ? 'success' : 'error'}>{detail.exit_code}</Tag>
|
||||
) : '—'}
|
||||
</Descriptions.Item>
|
||||
<Descriptions.Item label="命令">
|
||||
<code style={{ wordBreak: 'break-all', fontSize: 12 }}>{detail.command || '—'}</code>
|
||||
</Descriptions.Item>
|
||||
</Descriptions>
|
||||
)}
|
||||
</Drawer>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 主组件 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const TaskManager: React.FC = () => {
|
||||
const items = [
|
||||
{
|
||||
key: 'queue',
|
||||
label: <Space><UnorderedListOutlined />队列</Space>,
|
||||
children: <QueueTab />,
|
||||
},
|
||||
{
|
||||
key: 'schedule',
|
||||
label: <Space><ClockCircleOutlined />调度</Space>,
|
||||
children: <ScheduleTab />,
|
||||
},
|
||||
{
|
||||
key: 'history',
|
||||
label: <Space><HistoryOutlined />历史</Space>,
|
||||
children: <HistoryTab />,
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Title level={4} style={{ marginBottom: 16 }}>
|
||||
<UnorderedListOutlined style={{ marginRight: 8 }} />
|
||||
任务管理
|
||||
</Title>
|
||||
<Tabs defaultActiveKey="queue" items={items} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TaskManager;
|
||||
137
apps/admin-web/src/store/authStore.ts
Normal file
137
apps/admin-web/src/store/authStore.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* 认证状态管理 — Zustand store。
|
||||
*
|
||||
* - 存储 JWT 令牌和用户信息
|
||||
* - login / logout / hydrate 三个核心方法
|
||||
* - 令牌同步到 localStorage,与 client.ts 拦截器共用同一 key
|
||||
*/
|
||||
|
||||
import { create } from "zustand";
|
||||
import { apiClient } from "../api/client";
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 类型 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
/** 当前登录用户信息(从 JWT payload 解码或登录响应中获取) */
|
||||
export interface AuthUser {
|
||||
user_id: number;
|
||||
username: string;
|
||||
display_name: string;
|
||||
site_id: number;
|
||||
}
|
||||
|
||||
/** 后端 /api/auth/login 响应体 */
|
||||
interface LoginResponse {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
token_type: string;
|
||||
}
|
||||
|
||||
export interface AuthState {
|
||||
accessToken: string | null;
|
||||
refreshToken: string | null;
|
||||
user: AuthUser | null;
|
||||
isAuthenticated: boolean;
|
||||
|
||||
/** 用户名密码登录,成功后存储令牌到 state 和 localStorage */
|
||||
login: (username: string, password: string) => Promise<void>;
|
||||
/** 登出,清除 state 和 localStorage */
|
||||
logout: () => void;
|
||||
/** 从 localStorage 恢复状态(应用启动时调用) */
|
||||
hydrate: () => void;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 常量 — 与 client.ts 保持一致 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
const ACCESS_TOKEN_KEY = "access_token";
|
||||
const REFRESH_TOKEN_KEY = "refresh_token";
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* 辅助:从 JWT payload 解析用户信息 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function parseJwtPayload(token: string): AuthUser | null {
|
||||
try {
|
||||
const base64 = token.split(".")[1];
|
||||
if (!base64) return null;
|
||||
const json = atob(base64);
|
||||
const payload = JSON.parse(json) as Record<string, unknown>;
|
||||
return {
|
||||
user_id: payload.user_id as number,
|
||||
username: payload.username as string,
|
||||
display_name: (payload.display_name as string) ?? "",
|
||||
site_id: payload.site_id as number,
|
||||
};
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Store */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export const useAuthStore = create<AuthState>((set, _get) => ({
|
||||
accessToken: null,
|
||||
refreshToken: null,
|
||||
user: null,
|
||||
isAuthenticated: false,
|
||||
|
||||
async login(username: string, password: string) {
|
||||
const { data } = await apiClient.post<LoginResponse>("/auth/login", {
|
||||
username,
|
||||
password,
|
||||
});
|
||||
|
||||
const { access_token, refresh_token } = data;
|
||||
|
||||
// 持久化到 localStorage
|
||||
localStorage.setItem(ACCESS_TOKEN_KEY, access_token);
|
||||
localStorage.setItem(REFRESH_TOKEN_KEY, refresh_token);
|
||||
|
||||
const user = parseJwtPayload(access_token);
|
||||
|
||||
set({
|
||||
accessToken: access_token,
|
||||
refreshToken: refresh_token,
|
||||
user,
|
||||
isAuthenticated: true,
|
||||
});
|
||||
},
|
||||
|
||||
logout() {
|
||||
localStorage.removeItem(ACCESS_TOKEN_KEY);
|
||||
localStorage.removeItem(REFRESH_TOKEN_KEY);
|
||||
|
||||
set({
|
||||
accessToken: null,
|
||||
refreshToken: null,
|
||||
user: null,
|
||||
isAuthenticated: false,
|
||||
});
|
||||
},
|
||||
|
||||
hydrate() {
|
||||
const accessToken = localStorage.getItem(ACCESS_TOKEN_KEY);
|
||||
const refreshToken = localStorage.getItem(REFRESH_TOKEN_KEY);
|
||||
|
||||
if (accessToken) {
|
||||
const user = parseJwtPayload(accessToken);
|
||||
set({
|
||||
accessToken,
|
||||
refreshToken,
|
||||
user,
|
||||
isAuthenticated: true,
|
||||
});
|
||||
}
|
||||
},
|
||||
}));
|
||||
|
||||
// 监听 axios 拦截器的强制登出事件,同步清除 Zustand 状态
|
||||
// 避免 localStorage 已清空但 isAuthenticated 仍为 true 导致白屏
|
||||
window.addEventListener("auth:force-logout", () => {
|
||||
useAuthStore.getState().logout();
|
||||
});
|
||||
133
apps/admin-web/src/types/index.ts
Normal file
133
apps/admin-web/src/types/index.ts
Normal file
@@ -0,0 +1,133 @@
|
||||
/**
|
||||
* 前后端共享的 TypeScript 类型定义。
|
||||
* 与设计文档中的 Pydantic 模型和数据库表结构对应。
|
||||
*/
|
||||
|
||||
/** ETL 任务执行配置 */
|
||||
export interface TaskConfig {
|
||||
tasks: string[];
|
||||
/** 执行流程 Flow ID(对应 CLI --pipeline) */
|
||||
pipeline: string;
|
||||
/** 处理模式 */
|
||||
processing_mode: string;
|
||||
/** 传统模式兼容(已弃用) */
|
||||
pipeline_flow: string;
|
||||
dry_run: boolean;
|
||||
/** lookback / custom */
|
||||
window_mode: string;
|
||||
window_start: string | null;
|
||||
window_end: string | null;
|
||||
/** none / day */
|
||||
window_split: string | null;
|
||||
/** 1 / 10 / 30 */
|
||||
window_split_days: number | null;
|
||||
lookback_hours: number;
|
||||
overlap_seconds: number;
|
||||
fetch_before_verify: boolean;
|
||||
skip_ods_when_fetch_before_verify: boolean;
|
||||
ods_use_local_json: boolean;
|
||||
/** 门店 ID(由后端从 JWT 注入) */
|
||||
store_id: number | null;
|
||||
/** DWD 表级选择 */
|
||||
dwd_only_tables: string[] | null;
|
||||
/** 强制全量处理(跳过 hash 去重和变更对比) */
|
||||
force_full: boolean;
|
||||
extra_args: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/** 执行流程(Flow)定义 */
|
||||
export interface PipelineDefinition {
|
||||
id: string;
|
||||
name: string;
|
||||
/** 包含的层:ODS / DWD / DWS / INDEX */
|
||||
layers: string[];
|
||||
}
|
||||
|
||||
/** 处理模式定义 */
|
||||
export interface ProcessingModeDefinition {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
/** 任务注册表中的任务定义 */
|
||||
export interface TaskDefinition {
|
||||
code: string;
|
||||
name: string;
|
||||
description: string;
|
||||
/** 业务域(会员、结算、助教等) */
|
||||
domain: string;
|
||||
requires_window: boolean;
|
||||
is_ods: boolean;
|
||||
is_dimension: boolean;
|
||||
default_enabled: boolean;
|
||||
/** 常用任务标记,false 表示工具类/手动类任务 */
|
||||
is_common: boolean;
|
||||
}
|
||||
|
||||
/** 调度配置 */
|
||||
export interface ScheduleConfig {
|
||||
schedule_type: "once" | "interval" | "daily" | "weekly" | "cron";
|
||||
interval_value: number;
|
||||
interval_unit: "minutes" | "hours" | "days";
|
||||
daily_time: string;
|
||||
weekly_days: number[];
|
||||
weekly_time: string;
|
||||
cron_expression: string;
|
||||
enabled: boolean;
|
||||
start_date: string | null;
|
||||
end_date: string | null;
|
||||
}
|
||||
|
||||
/** 队列中的任务 */
|
||||
export interface QueuedTask {
|
||||
id: string;
|
||||
site_id: number;
|
||||
config: TaskConfig;
|
||||
status: "pending" | "running" | "success" | "failed" | "cancelled";
|
||||
position: number;
|
||||
created_at: string;
|
||||
started_at: string | null;
|
||||
finished_at: string | null;
|
||||
exit_code: number | null;
|
||||
error_message: string | null;
|
||||
}
|
||||
|
||||
/** 执行历史记录 */
|
||||
export interface ExecutionLog {
|
||||
id: string;
|
||||
site_id: number;
|
||||
task_codes: string[];
|
||||
status: string;
|
||||
started_at: string;
|
||||
finished_at: string | null;
|
||||
exit_code: number | null;
|
||||
duration_ms: number | null;
|
||||
command: string;
|
||||
summary: Record<string, unknown> | null;
|
||||
}
|
||||
|
||||
/** 调度任务 */
|
||||
export interface ScheduledTask {
|
||||
id: string;
|
||||
site_id: number;
|
||||
name: string;
|
||||
task_codes: string[];
|
||||
task_config: TaskConfig;
|
||||
schedule_config: ScheduleConfig;
|
||||
enabled: boolean;
|
||||
last_run_at: string | null;
|
||||
next_run_at: string | null;
|
||||
run_count: number;
|
||||
last_status: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
/** 环境配置项 */
|
||||
export interface EnvConfigItem {
|
||||
key: string;
|
||||
value: string;
|
||||
is_sensitive: boolean;
|
||||
}
|
||||
|
||||
1
apps/admin-web/src/vite-env.d.ts
vendored
Normal file
1
apps/admin-web/src/vite-env.d.ts
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/// <reference types="vite/client" />
|
||||
26
apps/admin-web/tsconfig.json
Normal file
26
apps/admin-web/tsconfig.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"useDefineForClassFields": true,
|
||||
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||
"module": "ESNext",
|
||||
"skipLibCheck": true,
|
||||
|
||||
/* 打包 */
|
||||
"moduleResolution": "bundler",
|
||||
"allowImportingTsExtensions": true,
|
||||
"isolatedModules": true,
|
||||
"moduleDetection": "force",
|
||||
"noEmit": true,
|
||||
"jsx": "react-jsx",
|
||||
|
||||
/* 类型检查 */
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"noFallthroughCasesInSwitch": true,
|
||||
"forceConsistentCasingInFileNames": true
|
||||
},
|
||||
"include": ["src"],
|
||||
"references": [{ "path": "./tsconfig.node.json" }]
|
||||
}
|
||||
25
apps/admin-web/tsconfig.node.json
Normal file
25
apps/admin-web/tsconfig.node.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"lib": ["ES2023"],
|
||||
"module": "ESNext",
|
||||
"skipLibCheck": true,
|
||||
|
||||
"moduleResolution": "bundler",
|
||||
"allowImportingTsExtensions": true,
|
||||
"isolatedModules": true,
|
||||
"moduleDetection": "force",
|
||||
|
||||
"composite": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"emitDeclarationOnly": true,
|
||||
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"noFallthroughCasesInSwitch": true,
|
||||
"forceConsistentCasingInFileNames": true
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
}
|
||||
1
apps/admin-web/tsconfig.node.tsbuildinfo
Normal file
1
apps/admin-web/tsconfig.node.tsbuildinfo
Normal file
File diff suppressed because one or more lines are too long
1
apps/admin-web/tsconfig.tsbuildinfo
Normal file
1
apps/admin-web/tsconfig.tsbuildinfo
Normal file
@@ -0,0 +1 @@
|
||||
{"root":["./src/app.tsx","./src/main.tsx","./src/vite-env.d.ts","./src/__tests__/flowlayers.test.ts","./src/__tests__/logfilter.test.ts","./src/api/client.ts","./src/api/dbviewer.ts","./src/api/envconfig.ts","./src/api/etlstatus.ts","./src/api/execution.ts","./src/api/schedules.ts","./src/api/tasks.ts","./src/components/dwdtableselector.tsx","./src/components/errorboundary.tsx","./src/components/logstream.tsx","./src/components/scheduletab.tsx","./src/components/taskselector.tsx","./src/pages/dbviewer.tsx","./src/pages/etlstatus.tsx","./src/pages/envconfig.tsx","./src/pages/logviewer.tsx","./src/pages/login.tsx","./src/pages/taskconfig.tsx","./src/pages/taskmanager.tsx","./src/store/authstore.ts","./src/types/index.ts"],"version":"5.8.3"}
|
||||
3
apps/admin-web/vite.config.d.ts
vendored
Normal file
3
apps/admin-web/vite.config.d.ts
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
declare const _default: import("vite").UserConfig;
|
||||
export default _default;
|
||||
//# sourceMappingURL=vite.config.d.ts.map
|
||||
1
apps/admin-web/vite.config.d.ts.map
Normal file
1
apps/admin-web/vite.config.d.ts.map
Normal file
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"vite.config.d.ts","sourceRoot":"","sources":["vite.config.ts"],"names":[],"mappings":";AAIA,wBAqBG"}
|
||||
26
apps/admin-web/vite.config.ts
Normal file
26
apps/admin-web/vite.config.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
import { defineConfig } from "vitest/config";
|
||||
import react from "@vitejs/plugin-react";
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
proxy: {
|
||||
// API 代理到后端
|
||||
"/api": {
|
||||
target: "http://localhost:8000",
|
||||
changeOrigin: true,
|
||||
},
|
||||
// WebSocket 代理
|
||||
"/ws": {
|
||||
target: "ws://localhost:8000",
|
||||
ws: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
test: {
|
||||
globals: true,
|
||||
environment: "jsdom",
|
||||
setupFiles: [],
|
||||
},
|
||||
});
|
||||
48
apps/backend/.env.local
Normal file
48
apps/backend/.env.local
Normal file
@@ -0,0 +1,48 @@
|
||||
# ==============================================================================
|
||||
# NeoZQYY 后端 .env.local — 私有覆盖层
|
||||
# ==============================================================================
|
||||
# 后端 config.py 以 override=True 加载此文件,优先级高于根 .env
|
||||
# 敏感值禁止提交;本文件已在 .gitignore 中排除
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 业务数据库(zqyy_app)
|
||||
# ------------------------------------------------------------------------------
|
||||
# DB_HOST / DB_PORT / DB_USER / DB_PASSWORD 继承自根 .env,无需重复
|
||||
# CHANGE 2026-02-15 | 默认指向测试库,生产环境切换为 zqyy_app
|
||||
APP_DB_NAME=test_zqyy_app
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# ETL 数据库(后端只读访问,用于数据库查看器)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 与 zqyy_app 同实例时可省略 ETL_DB_HOST/PORT/USER/PASSWORD,自动复用
|
||||
# CHANGE 2026-02-15 | 默认指向测试库,生产环境切换为 etl_feiqiu
|
||||
ETL_DB_NAME=test_etl_feiqiu
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# JWT 认证
|
||||
# ------------------------------------------------------------------------------
|
||||
JWT_SECRET_KEY=change-me-in-production
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# CORS(逗号分隔)
|
||||
# ------------------------------------------------------------------------------
|
||||
CORS_ORIGINS=http://localhost:5173
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 微信消息推送(与微信后台填写的 Token 一致)
|
||||
# CHANGE 2026-02-19 | 新增微信消息推送回调 Token
|
||||
# ------------------------------------------------------------------------------
|
||||
WX_CALLBACK_TOKEN=LLZQwx2026push
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 通用
|
||||
# ------------------------------------------------------------------------------
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# ETL 项目路径(子进程 cwd,缺省按 monorepo 相对路径推算)
|
||||
# ------------------------------------------------------------------------------
|
||||
# ETL_PROJECT_PATH=C:/NeoZQYY/apps/etl/connectors/feiqiu
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
## 内部结构
|
||||
|
||||
`
|
||||
```
|
||||
apps/backend/
|
||||
├── app/
|
||||
│ ├── main.py # FastAPI 入口,启用 OpenAPI 文档
|
||||
@@ -16,21 +16,22 @@ apps/backend/
|
||||
├── tests/ # 后端测试
|
||||
├── pyproject.toml # 依赖声明
|
||||
└── README.md
|
||||
`
|
||||
```
|
||||
|
||||
## 启动
|
||||
|
||||
`ash
|
||||
```bash
|
||||
# 确保已在根目录执行 uv sync --all-packages
|
||||
cd apps/backend
|
||||
uvicorn app.main:app --reload
|
||||
`
|
||||
uv run uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload
|
||||
```
|
||||
|
||||
API 文档自动生成于 http://localhost:8000/docs
|
||||
|
||||
## 依赖
|
||||
|
||||
- fastapi>=0.100, uvicorn>=0.23
|
||||
- psycopg2-binary>=2.9.0
|
||||
- fastapi>=0.115, uvicorn[standard]>=0.34
|
||||
- psycopg2-binary>=2.9, python-dotenv>=1.0
|
||||
- neozqyy-shared(workspace 引用)
|
||||
|
||||
## Roadmap
|
||||
|
||||
1
apps/backend/app/auth/__init__.py
Normal file
1
apps/backend/app/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""认证模块:JWT 令牌管理与 FastAPI 依赖注入。"""
|
||||
67
apps/backend/app/auth/dependencies.py
Normal file
67
apps/backend/app/auth/dependencies.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
FastAPI 依赖注入:从 JWT 提取当前用户信息。
|
||||
|
||||
用法:
|
||||
@router.get("/protected")
|
||||
async def protected_endpoint(user: CurrentUser = Depends(get_current_user)):
|
||||
print(user.user_id, user.site_id)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError
|
||||
|
||||
from app.auth.jwt import decode_access_token
|
||||
|
||||
# Bearer token 提取器
|
||||
_bearer_scheme = HTTPBearer(auto_error=True)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CurrentUser:
|
||||
"""从 JWT 解析出的当前用户上下文。"""
|
||||
|
||||
user_id: int
|
||||
site_id: int
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme),
|
||||
) -> CurrentUser:
|
||||
"""
|
||||
FastAPI 依赖:从 Authorization header 提取 JWT,验证后返回用户信息。
|
||||
|
||||
失败时抛出 401。
|
||||
"""
|
||||
token = credentials.credentials
|
||||
try:
|
||||
payload = decode_access_token(token)
|
||||
except JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的令牌",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user_id_raw = payload.get("sub")
|
||||
site_id = payload.get("site_id")
|
||||
|
||||
if user_id_raw is None or site_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="令牌缺少必要字段",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
user_id = int(user_id_raw)
|
||||
except (TypeError, ValueError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="令牌中 user_id 格式无效",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return CurrentUser(user_id=user_id, site_id=site_id)
|
||||
112
apps/backend/app/auth/jwt.py
Normal file
112
apps/backend/app/auth/jwt.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
JWT 令牌生成、验证与解码。
|
||||
|
||||
- access_token:短期有效(默认 30 分钟),用于 API 请求认证
|
||||
- refresh_token:长期有效(默认 7 天),用于刷新 access_token
|
||||
- payload 包含 user_id、site_id、令牌类型(access / refresh)
|
||||
- 密码哈希直接使用 bcrypt 库(passlib 与 bcrypt>=4.1 存在兼容性问题)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app import config
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""校验明文密码与哈希是否匹配。"""
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""生成密码的 bcrypt 哈希。"""
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def create_access_token(user_id: int, site_id: int) -> str:
|
||||
"""
|
||||
生成 access_token。
|
||||
|
||||
payload: sub=user_id, site_id, type=access, exp
|
||||
"""
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"site_id": site_id,
|
||||
"type": "access",
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: int, site_id: int) -> str:
|
||||
"""
|
||||
生成 refresh_token。
|
||||
|
||||
payload: sub=user_id, site_id, type=refresh, exp
|
||||
"""
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
days=config.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"site_id": site_id,
|
||||
"type": "refresh",
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_token_pair(user_id: int, site_id: int) -> dict[str, str]:
|
||||
"""生成 access_token + refresh_token 令牌对。"""
|
||||
return {
|
||||
"access_token": create_access_token(user_id, site_id),
|
||||
"refresh_token": create_refresh_token(user_id, site_id),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
"""
|
||||
解码并验证 JWT 令牌。
|
||||
|
||||
返回 payload dict,包含 sub、site_id、type、exp。
|
||||
令牌无效或过期时抛出 JWTError。
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, config.JWT_SECRET_KEY, algorithms=[config.JWT_ALGORITHM]
|
||||
)
|
||||
return payload
|
||||
except JWTError:
|
||||
raise
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict:
|
||||
"""
|
||||
解码 access_token 并验证类型。
|
||||
|
||||
令牌类型不是 access 时抛出 JWTError。
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
if payload.get("type") != "access":
|
||||
raise JWTError("令牌类型不是 access")
|
||||
return payload
|
||||
|
||||
|
||||
def decode_refresh_token(token: str) -> dict:
|
||||
"""
|
||||
解码 refresh_token 并验证类型。
|
||||
|
||||
令牌类型不是 refresh 时抛出 JWTError。
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
if payload.get("type") != "refresh":
|
||||
raise JWTError("令牌类型不是 refresh")
|
||||
return payload
|
||||
@@ -29,7 +29,37 @@ DB_HOST: str = get("DB_HOST", "localhost")
|
||||
DB_PORT: str = get("DB_PORT", "5432")
|
||||
DB_USER: str = get("DB_USER", "")
|
||||
DB_PASSWORD: str = get("DB_PASSWORD", "")
|
||||
APP_DB_NAME: str = get("APP_DB_NAME", "zqyy_app")
|
||||
# CHANGE 2026-02-15 | 默认指向测试库,生产环境通过 .env 覆盖
|
||||
APP_DB_NAME: str = get("APP_DB_NAME", "test_zqyy_app")
|
||||
|
||||
# ---- JWT 认证 ----
|
||||
JWT_SECRET_KEY: str = get("JWT_SECRET_KEY", "") # 生产环境必须设置
|
||||
JWT_ALGORITHM: str = get("JWT_ALGORITHM", "HS256")
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int(get("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = int(get("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
# ---- ETL 数据库连接参数(可独立配置,缺省时复用 zqyy_app 的连接参数) ----
|
||||
ETL_DB_HOST: str = get("ETL_DB_HOST") or DB_HOST
|
||||
ETL_DB_PORT: str = get("ETL_DB_PORT") or DB_PORT
|
||||
ETL_DB_USER: str = get("ETL_DB_USER") or DB_USER
|
||||
ETL_DB_PASSWORD: str = get("ETL_DB_PASSWORD") or DB_PASSWORD
|
||||
# CHANGE 2026-02-15 | 默认指向测试库,生产环境通过 .env 覆盖
|
||||
ETL_DB_NAME: str = get("ETL_DB_NAME", "test_etl_feiqiu")
|
||||
|
||||
# ---- CORS ----
|
||||
# 逗号分隔的允许来源列表;缺省允许 Vite 开发服务器
|
||||
CORS_ORIGINS: list[str] = [
|
||||
o.strip()
|
||||
for o in get("CORS_ORIGINS", "http://localhost:5173").split(",")
|
||||
if o.strip()
|
||||
]
|
||||
|
||||
# ---- ETL 项目路径 ----
|
||||
# ETL CLI 的工作目录(子进程 cwd),缺省时按 monorepo 相对路径推算
|
||||
ETL_PROJECT_PATH: str = get(
|
||||
"ETL_PROJECT_PATH",
|
||||
str(Path(__file__).resolve().parents[3] / "apps" / "etl" / "connectors" / "feiqiu"),
|
||||
)
|
||||
|
||||
# ---- 通用 ----
|
||||
TIMEZONE: str = get("TIMEZONE", "Asia/Shanghai")
|
||||
|
||||
@@ -1,14 +1,30 @@
|
||||
"""
|
||||
zqyy_app 数据库连接
|
||||
数据库连接
|
||||
|
||||
使用 psycopg2 直连 PostgreSQL,不引入 ORM。
|
||||
连接参数从环境变量读取(经 config 模块加载)。
|
||||
|
||||
提供两类连接:
|
||||
- get_connection():zqyy_app 读写连接(用户/队列/调度等业务数据)
|
||||
- get_etl_readonly_connection(site_id):etl_feiqiu 只读连接(数据库查看器),
|
||||
自动设置 RLS site_id 隔离
|
||||
"""
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
|
||||
from app.config import APP_DB_NAME, DB_HOST, DB_PASSWORD, DB_PORT, DB_USER
|
||||
from app.config import (
|
||||
APP_DB_NAME,
|
||||
DB_HOST,
|
||||
DB_PASSWORD,
|
||||
DB_PORT,
|
||||
DB_USER,
|
||||
ETL_DB_HOST,
|
||||
ETL_DB_NAME,
|
||||
ETL_DB_PASSWORD,
|
||||
ETL_DB_PORT,
|
||||
ETL_DB_USER,
|
||||
)
|
||||
|
||||
|
||||
def get_connection() -> PgConnection:
|
||||
@@ -24,3 +40,43 @@ def get_connection() -> PgConnection:
|
||||
password=DB_PASSWORD,
|
||||
dbname=APP_DB_NAME,
|
||||
)
|
||||
|
||||
|
||||
def get_etl_readonly_connection(site_id: int | str) -> PgConnection:
|
||||
"""
|
||||
获取 ETL 数据库(etl_feiqiu)的只读连接。
|
||||
|
||||
连接建立后自动执行:
|
||||
1. SET default_transaction_read_only = on — 禁止写操作
|
||||
2. SET LOCAL app.current_site_id = '{site_id}' — 启用 RLS 门店隔离
|
||||
|
||||
调用方负责关闭连接。典型用法::
|
||||
|
||||
conn = get_etl_readonly_connection(site_id)
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT ...")
|
||||
finally:
|
||||
conn.close()
|
||||
"""
|
||||
conn = psycopg2.connect(
|
||||
host=ETL_DB_HOST,
|
||||
port=ETL_DB_PORT,
|
||||
user=ETL_DB_USER,
|
||||
password=ETL_DB_PASSWORD,
|
||||
dbname=ETL_DB_NAME,
|
||||
)
|
||||
try:
|
||||
conn.autocommit = False
|
||||
with conn.cursor() as cur:
|
||||
# 会话级只读:防止任何写操作
|
||||
cur.execute("SET default_transaction_read_only = on")
|
||||
# 事务级 RLS 隔离:设置当前门店 ID
|
||||
cur.execute(
|
||||
"SET LOCAL app.current_site_id = %s", (str(site_id),)
|
||||
)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.close()
|
||||
raise
|
||||
return conn
|
||||
|
||||
@@ -1,20 +1,66 @@
|
||||
"""
|
||||
NeoZQYY 后端 API 入口
|
||||
|
||||
基于 FastAPI 构建,为微信小程序提供 RESTful API。
|
||||
基于 FastAPI 构建,为管理后台和微信小程序提供 RESTful API。
|
||||
OpenAPI 文档自动生成于 /docs(Swagger UI)和 /redoc(ReDoc)。
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app import config
|
||||
# CHANGE 2026-02-19 | 新增 xcx_test 路由(MVP 验证)+ wx_callback 路由(微信消息推送)
|
||||
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback
|
||||
from app.services.scheduler import scheduler
|
||||
from app.services.task_queue import task_queue
|
||||
from app.ws.logs import ws_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期:启动时拉起后台服务,关闭时优雅停止。"""
|
||||
# 启动
|
||||
task_queue.start()
|
||||
scheduler.start()
|
||||
yield
|
||||
# 关闭
|
||||
await scheduler.stop()
|
||||
await task_queue.stop()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="NeoZQYY API",
|
||||
description="台球门店运营助手 — 微信小程序后端 API",
|
||||
description="台球门店运营助手 — 后端 API(管理后台 + 微信小程序)",
|
||||
version="0.1.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# ---- CORS 中间件 ----
|
||||
# 允许来源从环境变量 CORS_ORIGINS 读取,缺省允许 Vite 开发服务器 (localhost:5173)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ---- 路由注册 ----
|
||||
app.include_router(auth.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(execution.router)
|
||||
app.include_router(schedules.router)
|
||||
app.include_router(env_config.router)
|
||||
app.include_router(db_viewer.router)
|
||||
app.include_router(etl_status.router)
|
||||
app.include_router(ws_router)
|
||||
app.include_router(xcx_test.router)
|
||||
app.include_router(wx_callback.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["系统"])
|
||||
async def health_check():
|
||||
|
||||
97
apps/backend/app/routers/auth.py
Normal file
97
apps/backend/app/routers/auth.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
认证路由:登录与令牌刷新。
|
||||
|
||||
- POST /api/auth/login — 验证用户名密码,返回 JWT 令牌对
|
||||
- POST /api/auth/refresh — 用刷新令牌换取新的访问令牌
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from jose import JWTError
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_access_token,
|
||||
create_token_pair,
|
||||
decode_refresh_token,
|
||||
verify_password,
|
||||
)
|
||||
from app.database import get_connection
|
||||
from app.schemas.auth import LoginRequest, RefreshRequest, TokenResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["认证"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(body: LoginRequest):
|
||||
"""
|
||||
用户登录。
|
||||
|
||||
查询 admin_users 表验证用户名密码,成功后返回 JWT 令牌对。
|
||||
- 用户不存在或密码错误:401
|
||||
- 账号已禁用(is_active=false):401
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT id, password_hash, site_id, is_active "
|
||||
"FROM admin_users WHERE username = %s",
|
||||
(body.username,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户名或密码错误",
|
||||
)
|
||||
|
||||
user_id, password_hash, site_id, is_active = row
|
||||
|
||||
if not is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="账号已被禁用",
|
||||
)
|
||||
|
||||
if not verify_password(body.password, password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户名或密码错误",
|
||||
)
|
||||
|
||||
tokens = create_token_pair(user_id, site_id)
|
||||
return TokenResponse(**tokens)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(body: RefreshRequest):
|
||||
"""
|
||||
刷新访问令牌。
|
||||
|
||||
验证 refresh_token 有效性,成功后仅返回新的 access_token
|
||||
(refresh_token 保持不变,由客户端继续持有)。
|
||||
"""
|
||||
try:
|
||||
payload = decode_refresh_token(body.refresh_token)
|
||||
except JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的刷新令牌",
|
||||
)
|
||||
|
||||
user_id = int(payload["sub"])
|
||||
site_id = payload["site_id"]
|
||||
|
||||
# 生成新的 access_token,refresh_token 原样返回
|
||||
new_access = create_access_token(user_id, site_id)
|
||||
return TokenResponse(
|
||||
access_token=new_access,
|
||||
refresh_token=body.refresh_token,
|
||||
token_type="bearer",
|
||||
)
|
||||
228
apps/backend/app/routers/db_viewer.py
Normal file
228
apps/backend/app/routers/db_viewer.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器 API
|
||||
|
||||
提供 4 个端点:
|
||||
- GET /api/db/schemas — 返回 Schema 列表
|
||||
- GET /api/db/schemas/{name}/tables — 返回表列表和行数
|
||||
- GET /api/db/tables/{schema}/{table}/columns — 返回列定义
|
||||
- POST /api/db/query — 只读 SQL 执行
|
||||
|
||||
所有端点需要 JWT 认证。
|
||||
使用 get_etl_readonly_connection(site_id) 确保 RLS 隔离。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from psycopg2 import errors as pg_errors, OperationalError
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_etl_readonly_connection
|
||||
from app.schemas.db_viewer import (
|
||||
ColumnInfo,
|
||||
QueryRequest,
|
||||
QueryResponse,
|
||||
SchemaInfo,
|
||||
TableInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/db", tags=["数据库查看器"])
|
||||
|
||||
# 写操作关键词(不区分大小写)
|
||||
_WRITE_KEYWORDS = re.compile(
|
||||
r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# 查询结果行数上限
|
||||
_MAX_ROWS = 1000
|
||||
|
||||
# 查询超时(秒)
|
||||
_QUERY_TIMEOUT_SEC = 30
|
||||
|
||||
|
||||
# ── GET /api/db/schemas ──────────────────────────────────────
|
||||
|
||||
@router.get("/schemas", response_model=list[SchemaInfo])
|
||||
async def list_schemas(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[SchemaInfo]:
|
||||
"""返回 ETL 数据库中的 Schema 列表。"""
|
||||
conn = get_etl_readonly_connection(user.site_id)
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
ORDER BY schema_name
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [SchemaInfo(name=row[0]) for row in rows]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ── GET /api/db/schemas/{name}/tables ────────────────────────
|
||||
|
||||
@router.get("/schemas/{name}/tables", response_model=list[TableInfo])
|
||||
async def list_tables(
|
||||
name: str,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[TableInfo]:
|
||||
"""返回指定 Schema 下所有表的名称和行数统计。"""
|
||||
conn = get_etl_readonly_connection(user.site_id)
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
t.table_name,
|
||||
s.n_live_tup
|
||||
FROM information_schema.tables t
|
||||
LEFT JOIN pg_stat_user_tables s
|
||||
ON s.schemaname = t.table_schema
|
||||
AND s.relname = t.table_name
|
||||
WHERE t.table_schema = %s
|
||||
AND t.table_type = 'BASE TABLE'
|
||||
ORDER BY t.table_name
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [
|
||||
TableInfo(name=row[0], row_count=row[1])
|
||||
for row in rows
|
||||
]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ── GET /api/db/tables/{schema}/{table}/columns ──────────────
|
||||
|
||||
@router.get(
|
||||
"/tables/{schema}/{table}/columns",
|
||||
response_model=list[ColumnInfo],
|
||||
)
|
||||
async def list_columns(
|
||||
schema: str,
|
||||
table: str,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[ColumnInfo]:
|
||||
"""返回指定表的列定义。"""
|
||||
conn = get_etl_readonly_connection(user.site_id)
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
""",
|
||||
(schema, table),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [
|
||||
ColumnInfo(
|
||||
name=row[0],
|
||||
data_type=row[1],
|
||||
is_nullable=row[2] == "YES",
|
||||
column_default=row[3],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ── POST /api/db/query ───────────────────────────────────────
|
||||
|
||||
@router.post("/query", response_model=QueryResponse)
|
||||
async def execute_query(
|
||||
body: QueryRequest,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> QueryResponse:
|
||||
"""只读 SQL 执行。
|
||||
|
||||
安全措施:
|
||||
1. 拦截写操作关键词(INSERT / UPDATE / DELETE / DROP / TRUNCATE)
|
||||
2. 限制返回行数上限 1000 行
|
||||
3. 设置查询超时 30 秒
|
||||
"""
|
||||
sql = body.sql.strip()
|
||||
if not sql:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="SQL 语句不能为空",
|
||||
)
|
||||
|
||||
# 拦截写操作
|
||||
if _WRITE_KEYWORDS.search(sql):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="只允许只读查询,禁止 INSERT / UPDATE / DELETE / DROP / TRUNCATE 操作",
|
||||
)
|
||||
|
||||
conn = get_etl_readonly_connection(user.site_id)
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# 设置查询超时
|
||||
cur.execute(
|
||||
"SET LOCAL statement_timeout = %s",
|
||||
(f"{_QUERY_TIMEOUT_SEC}s",),
|
||||
)
|
||||
|
||||
try:
|
||||
cur.execute(sql)
|
||||
except pg_errors.QueryCanceled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_408_REQUEST_TIMEOUT,
|
||||
detail=f"查询超时(超过 {_QUERY_TIMEOUT_SEC} 秒)",
|
||||
)
|
||||
except Exception as exc:
|
||||
# SQL 语法错误或其他执行错误
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"SQL 执行错误: {exc}",
|
||||
)
|
||||
|
||||
# 提取列名
|
||||
columns = (
|
||||
[desc[0] for desc in cur.description]
|
||||
if cur.description
|
||||
else []
|
||||
)
|
||||
|
||||
# 限制返回行数
|
||||
rows = cur.fetchmany(_MAX_ROWS)
|
||||
# 将元组转为列表,便于 JSON 序列化
|
||||
rows_list = [list(row) for row in rows]
|
||||
|
||||
return QueryResponse(
|
||||
columns=columns,
|
||||
rows=rows_list,
|
||||
row_count=len(rows_list),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError as exc:
|
||||
# 连接级错误
|
||||
logger.error("数据库查看器连接错误: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="数据库连接错误",
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
240
apps/backend/app/routers/env_config.py
Normal file
240
apps/backend/app/routers/env_config.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置 API
|
||||
|
||||
提供 3 个端点:
|
||||
- GET /api/env-config — 读取 .env,敏感值掩码
|
||||
- PUT /api/env-config — 验证并写入 .env
|
||||
- GET /api/env-config/export — 导出去敏感值的配置文件
|
||||
|
||||
所有端点需要 JWT 认证。
|
||||
敏感键判定:键名中包含 PASSWORD、TOKEN、SECRET、DSN(不区分大小写)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/env-config", tags=["环境配置"])
|
||||
|
||||
# .env 文件路径:项目根目录
|
||||
_ENV_PATH = Path(__file__).resolve().parents[3] / ".env"
|
||||
|
||||
# 敏感键关键词(不区分大小写)
|
||||
_SENSITIVE_KEYWORDS = ("PASSWORD", "TOKEN", "SECRET", "DSN")
|
||||
|
||||
_MASK = "****"
|
||||
|
||||
|
||||
# ── Pydantic 模型 ────────────────────────────────────────────
|
||||
|
||||
class EnvEntry(BaseModel):
|
||||
"""单条环境变量键值对。"""
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
class EnvConfigResponse(BaseModel):
|
||||
"""GET 响应:键值对列表。"""
|
||||
entries: list[EnvEntry]
|
||||
|
||||
|
||||
class EnvConfigUpdateRequest(BaseModel):
|
||||
"""PUT 请求体:键值对列表。"""
|
||||
entries: list[EnvEntry]
|
||||
|
||||
|
||||
# ── 工具函数 ─────────────────────────────────────────────────
|
||||
|
||||
def _is_sensitive(key: str) -> bool:
|
||||
"""判断键名是否为敏感键。"""
|
||||
upper = key.upper()
|
||||
return any(kw in upper for kw in _SENSITIVE_KEYWORDS)
|
||||
|
||||
|
||||
def _parse_env(content: str) -> list[dict]:
|
||||
"""解析 .env 文件内容,返回行级结构。
|
||||
|
||||
每行分为三种类型:
|
||||
- comment: 注释行或空行(原样保留)
|
||||
- entry: 键值对
|
||||
"""
|
||||
lines: list[dict] = []
|
||||
for raw_line in content.splitlines():
|
||||
stripped = raw_line.strip()
|
||||
if not stripped or stripped.startswith("#"):
|
||||
lines.append({"type": "comment", "raw": raw_line})
|
||||
else:
|
||||
# 支持 KEY=VALUE 和 KEY="VALUE" 格式
|
||||
match = re.match(r'^([A-Za-z_][A-Za-z0-9_]*)=(.*)', raw_line)
|
||||
if match:
|
||||
key = match.group(1)
|
||||
value = match.group(2).strip()
|
||||
# 去除引号包裹
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
|
||||
value = value[1:-1]
|
||||
lines.append({"type": "entry", "key": key, "value": value, "raw": raw_line})
|
||||
else:
|
||||
# 无法解析的行当作注释保留
|
||||
lines.append({"type": "comment", "raw": raw_line})
|
||||
return lines
|
||||
|
||||
|
||||
def _read_env_file(path: Path) -> str:
|
||||
"""读取 .env 文件内容。"""
|
||||
if not path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=".env 文件不存在",
|
||||
)
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _write_env_file(path: Path, content: str) -> None:
|
||||
"""写入 .env 文件。"""
|
||||
try:
|
||||
path.write_text(content, encoding="utf-8")
|
||||
except OSError as exc:
|
||||
logger.error("写入 .env 文件失败: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="写入 .env 文件失败",
|
||||
)
|
||||
|
||||
|
||||
def _validate_entries(entries: list[EnvEntry]) -> None:
|
||||
"""验证键值对格式。"""
|
||||
for idx, entry in enumerate(entries):
|
||||
if not entry.key:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"第 {idx + 1} 行:键名不能为空",
|
||||
)
|
||||
if not re.match(r'^[A-Za-z_][A-Za-z0-9_]*$', entry.key):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"第 {idx + 1} 行:键名 '{entry.key}' 格式无效(仅允许字母、数字、下划线,且不能以数字开头)",
|
||||
)
|
||||
|
||||
|
||||
# ── GET /api/env-config — 读取 ───────────────────────────────
|
||||
|
||||
@router.get("", response_model=EnvConfigResponse)
|
||||
async def get_env_config(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> EnvConfigResponse:
|
||||
"""读取 .env 文件,敏感值以掩码展示。"""
|
||||
content = _read_env_file(_ENV_PATH)
|
||||
parsed = _parse_env(content)
|
||||
|
||||
entries = []
|
||||
for line in parsed:
|
||||
if line["type"] == "entry":
|
||||
value = _MASK if _is_sensitive(line["key"]) else line["value"]
|
||||
entries.append(EnvEntry(key=line["key"], value=value))
|
||||
|
||||
return EnvConfigResponse(entries=entries)
|
||||
|
||||
|
||||
# ── PUT /api/env-config — 写入 ───────────────────────────────
|
||||
|
||||
@router.put("", response_model=EnvConfigResponse)
|
||||
async def update_env_config(
|
||||
body: EnvConfigUpdateRequest,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> EnvConfigResponse:
|
||||
"""验证并写入 .env 文件。
|
||||
|
||||
保留原文件中的注释行和空行。对于已有键,更新值;
|
||||
对于新键,追加到文件末尾。掩码值(****)的键跳过更新,保留原值。
|
||||
"""
|
||||
_validate_entries(body.entries)
|
||||
|
||||
# 读取原文件(如果存在)
|
||||
if _ENV_PATH.exists():
|
||||
original_content = _ENV_PATH.read_text(encoding="utf-8")
|
||||
parsed = _parse_env(original_content)
|
||||
else:
|
||||
parsed = []
|
||||
|
||||
# 构建新值映射(跳过掩码值)
|
||||
new_values: dict[str, str] = {}
|
||||
for entry in body.entries:
|
||||
if entry.value != _MASK:
|
||||
new_values[entry.key] = entry.value
|
||||
|
||||
# 更新已有行
|
||||
seen_keys: set[str] = set()
|
||||
output_lines: list[str] = []
|
||||
for line in parsed:
|
||||
if line["type"] == "comment":
|
||||
output_lines.append(line["raw"])
|
||||
elif line["type"] == "entry":
|
||||
key = line["key"]
|
||||
seen_keys.add(key)
|
||||
if key in new_values:
|
||||
output_lines.append(f"{key}={new_values[key]}")
|
||||
else:
|
||||
# 保留原值(包括掩码跳过的敏感键)
|
||||
output_lines.append(line["raw"])
|
||||
|
||||
# 追加新键
|
||||
for entry in body.entries:
|
||||
if entry.key not in seen_keys and entry.value != _MASK:
|
||||
output_lines.append(f"{entry.key}={entry.value}")
|
||||
|
||||
new_content = "\n".join(output_lines)
|
||||
if output_lines:
|
||||
new_content += "\n"
|
||||
|
||||
_write_env_file(_ENV_PATH, new_content)
|
||||
|
||||
# 返回更新后的配置(敏感值掩码)
|
||||
result_parsed = _parse_env(new_content)
|
||||
entries = []
|
||||
for line in result_parsed:
|
||||
if line["type"] == "entry":
|
||||
value = _MASK if _is_sensitive(line["key"]) else line["value"]
|
||||
entries.append(EnvEntry(key=line["key"], value=value))
|
||||
|
||||
return EnvConfigResponse(entries=entries)
|
||||
|
||||
|
||||
# ── GET /api/env-config/export — 导出 ────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
async def export_env_config(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> PlainTextResponse:
|
||||
"""导出去除敏感值的配置文件(作为文件下载)。"""
|
||||
content = _read_env_file(_ENV_PATH)
|
||||
parsed = _parse_env(content)
|
||||
|
||||
output_lines: list[str] = []
|
||||
for line in parsed:
|
||||
if line["type"] == "comment":
|
||||
output_lines.append(line["raw"])
|
||||
elif line["type"] == "entry":
|
||||
if _is_sensitive(line["key"]):
|
||||
output_lines.append(f"{line['key']}={_MASK}")
|
||||
else:
|
||||
output_lines.append(line["raw"])
|
||||
|
||||
export_content = "\n".join(output_lines)
|
||||
if output_lines:
|
||||
export_content += "\n"
|
||||
|
||||
return PlainTextResponse(
|
||||
content=export_content,
|
||||
media_type="text/plain",
|
||||
headers={"Content-Disposition": "attachment; filename=env-config.txt"},
|
||||
)
|
||||
134
apps/backend/app/routers/etl_status.py
Normal file
134
apps/backend/app/routers/etl_status.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 状态监控 API
|
||||
|
||||
提供 2 个端点:
|
||||
- GET /api/etl-status/cursors — 返回各任务的数据游标(最后抓取时间、记录数)
|
||||
- GET /api/etl-status/recent-runs — 返回最近 50 条任务执行记录
|
||||
|
||||
所有端点需要 JWT 认证。
|
||||
游标端点查询 ETL 数据库(meta.etl_cursor),
|
||||
执行记录端点查询 zqyy_app 数据库(task_execution_log)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from psycopg2 import OperationalError
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_connection, get_etl_readonly_connection
|
||||
from app.schemas.etl_status import CursorInfo, RecentRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/etl-status", tags=["ETL 状态"])
|
||||
|
||||
# 最近执行记录条数上限
|
||||
_RECENT_RUNS_LIMIT = 50
|
||||
|
||||
|
||||
# ── GET /api/etl-status/cursors ──────────────────────────────
|
||||
|
||||
@router.get("/cursors", response_model=list[CursorInfo])
|
||||
async def list_cursors(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[CursorInfo]:
|
||||
"""返回各 ODS 表的最新数据游标。
|
||||
|
||||
查询 ETL 数据库中的 meta.etl_cursor 表。
|
||||
如果该表不存在,返回空列表而非报错。
|
||||
"""
|
||||
conn = get_etl_readonly_connection(user.site_id)
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# CHANGE 2026-02-15 | 对齐新库 etl_feiqiu 六层架构:etl_admin → meta
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'meta'
|
||||
AND table_name = 'etl_cursor'
|
||||
)
|
||||
"""
|
||||
)
|
||||
exists = cur.fetchone()[0]
|
||||
if not exists:
|
||||
return []
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT task_code, last_fetch_time, record_count
|
||||
FROM meta.etl_cursor
|
||||
ORDER BY task_code
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
return [
|
||||
CursorInfo(
|
||||
task_code=row[0],
|
||||
last_fetch_time=str(row[1]) if row[1] is not None else None,
|
||||
record_count=row[2],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except OperationalError as exc:
|
||||
logger.error("ETL 游标查询连接错误: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="ETL 数据库连接错误",
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ── GET /api/etl-status/recent-runs ──────────────────────────
|
||||
|
||||
@router.get("/recent-runs", response_model=list[RecentRun])
|
||||
async def list_recent_runs(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[RecentRun]:
|
||||
"""返回最近 50 条任务执行记录。
|
||||
|
||||
查询 zqyy_app 数据库中的 task_execution_log 表,
|
||||
按 site_id 过滤,按 started_at DESC 排序。
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, task_codes, status, started_at,
|
||||
finished_at, duration_ms, exit_code
|
||||
FROM task_execution_log
|
||||
WHERE site_id = %s
|
||||
ORDER BY started_at DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(user.site_id, _RECENT_RUNS_LIMIT),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
return [
|
||||
RecentRun(
|
||||
id=str(row[0]),
|
||||
task_codes=list(row[1]) if row[1] else [],
|
||||
status=row[2],
|
||||
started_at=str(row[3]),
|
||||
finished_at=str(row[4]) if row[4] is not None else None,
|
||||
duration_ms=row[5],
|
||||
exit_code=row[6],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except OperationalError as exc:
|
||||
logger.error("执行记录查询连接错误: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="数据库连接错误",
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
281
apps/backend/app/routers/execution.py
Normal file
281
apps/backend/app/routers/execution.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""执行与队列 API
|
||||
|
||||
提供 8 个端点:
|
||||
- POST /api/execution/run — 直接执行任务
|
||||
- GET /api/execution/queue — 获取当前队列(按 site_id 过滤)
|
||||
- POST /api/execution/queue — 添加到队列
|
||||
- PUT /api/execution/queue/reorder — 重排队列
|
||||
- DELETE /api/execution/queue/{id} — 删除队列任务
|
||||
- POST /api/execution/{id}/cancel — 取消执行中的任务
|
||||
- GET /api/execution/history — 执行历史(按 site_id 过滤)
|
||||
- GET /api/execution/{id}/logs — 获取历史日志
|
||||
|
||||
所有端点需要 JWT 认证,site_id 从 JWT 提取。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_connection
|
||||
from app.schemas.execution import (
|
||||
ExecutionHistoryItem,
|
||||
ExecutionLogsResponse,
|
||||
ExecutionRunResponse,
|
||||
QueueTaskResponse,
|
||||
ReorderRequest,
|
||||
)
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_executor import task_executor
|
||||
from app.services.task_queue import task_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/execution", tags=["任务执行"])
|
||||
|
||||
|
||||
# ── POST /api/execution/run — 直接执行任务 ────────────────────
|
||||
|
||||
@router.post("/run", response_model=ExecutionRunResponse)
|
||||
async def run_task(
|
||||
config: TaskConfigSchema,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> ExecutionRunResponse:
|
||||
"""直接执行任务(不经过队列)。
|
||||
|
||||
从 JWT 注入 store_id,创建 execution_id 后异步启动子进程。
|
||||
"""
|
||||
config = config.model_copy(update={"store_id": user.site_id})
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
||||
# 异步启动执行,不阻塞响应
|
||||
asyncio.create_task(
|
||||
task_executor.execute(
|
||||
config=config,
|
||||
execution_id=execution_id,
|
||||
site_id=user.site_id,
|
||||
)
|
||||
)
|
||||
|
||||
return ExecutionRunResponse(
|
||||
execution_id=execution_id,
|
||||
message="任务已提交执行",
|
||||
)
|
||||
|
||||
|
||||
# ── GET /api/execution/queue — 获取当前队列 ───────────────────
|
||||
|
||||
@router.get("/queue", response_model=list[QueueTaskResponse])
|
||||
async def get_queue(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[QueueTaskResponse]:
|
||||
"""获取当前门店的待执行队列。"""
|
||||
tasks = task_queue.list_pending(user.site_id)
|
||||
return [
|
||||
QueueTaskResponse(
|
||||
id=t.id,
|
||||
site_id=t.site_id,
|
||||
config=t.config,
|
||||
status=t.status,
|
||||
position=t.position,
|
||||
created_at=t.created_at,
|
||||
started_at=t.started_at,
|
||||
finished_at=t.finished_at,
|
||||
exit_code=t.exit_code,
|
||||
error_message=t.error_message,
|
||||
)
|
||||
for t in tasks
|
||||
]
|
||||
|
||||
|
||||
# ── POST /api/execution/queue — 添加到队列 ───────────────────
|
||||
|
||||
@router.post("/queue", response_model=QueueTaskResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def enqueue_task(
|
||||
config: TaskConfigSchema,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> QueueTaskResponse:
|
||||
"""将任务配置添加到执行队列。"""
|
||||
config = config.model_copy(update={"store_id": user.site_id})
|
||||
task_id = task_queue.enqueue(config, user.site_id)
|
||||
|
||||
# 查询刚创建的任务返回完整信息
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, site_id, config, status, position,
|
||||
created_at, started_at, finished_at,
|
||||
exit_code, error_message
|
||||
FROM task_queue WHERE id = %s
|
||||
""",
|
||||
(task_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(status_code=500, detail="入队后查询失败")
|
||||
|
||||
config_data = row[2] if isinstance(row[2], dict) else json.loads(row[2])
|
||||
return QueueTaskResponse(
|
||||
id=str(row[0]),
|
||||
site_id=row[1],
|
||||
config=config_data,
|
||||
status=row[3],
|
||||
position=row[4],
|
||||
created_at=row[5],
|
||||
started_at=row[6],
|
||||
finished_at=row[7],
|
||||
exit_code=row[8],
|
||||
error_message=row[9],
|
||||
)
|
||||
|
||||
|
||||
# ── PUT /api/execution/queue/reorder — 重排队列 ──────────────
|
||||
|
||||
@router.put("/queue/reorder")
|
||||
async def reorder_queue(
|
||||
body: ReorderRequest,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""调整队列中任务的执行顺序。"""
|
||||
task_queue.reorder(body.task_id, body.new_position, user.site_id)
|
||||
return {"message": "队列已重排"}
|
||||
|
||||
|
||||
# ── DELETE /api/execution/queue/{id} — 删除队列任务 ──────────
|
||||
|
||||
@router.delete("/queue/{task_id}")
|
||||
async def delete_queue_task(
|
||||
task_id: str,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""从队列中删除待执行任务。仅允许删除 pending 状态的任务。"""
|
||||
deleted = task_queue.delete(task_id, user.site_id)
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="任务不存在或非待执行状态,无法删除",
|
||||
)
|
||||
return {"message": "任务已从队列中删除"}
|
||||
|
||||
|
||||
# ── POST /api/execution/{id}/cancel — 取消执行 ──────────────
|
||||
|
||||
@router.post("/{execution_id}/cancel")
|
||||
async def cancel_execution(
|
||||
execution_id: str,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""取消正在执行的任务。"""
|
||||
cancelled = await task_executor.cancel(execution_id)
|
||||
if not cancelled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="执行任务不存在或已完成",
|
||||
)
|
||||
return {"message": "已发送取消信号"}
|
||||
|
||||
|
||||
# ── GET /api/execution/history — 执行历史 ────────────────────
|
||||
|
||||
@router.get("/history", response_model=list[ExecutionHistoryItem])
|
||||
async def get_execution_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[ExecutionHistoryItem]:
|
||||
"""获取执行历史记录,按 started_at 降序排列。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, site_id, task_codes, status, started_at,
|
||||
finished_at, exit_code, duration_ms, command, summary
|
||||
FROM task_execution_log
|
||||
WHERE site_id = %s
|
||||
ORDER BY started_at DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(user.site_id, limit),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return [
|
||||
ExecutionHistoryItem(
|
||||
id=str(row[0]),
|
||||
site_id=row[1],
|
||||
task_codes=row[2] or [],
|
||||
status=row[3],
|
||||
started_at=row[4],
|
||||
finished_at=row[5],
|
||||
exit_code=row[6],
|
||||
duration_ms=row[7],
|
||||
command=row[8],
|
||||
summary=row[9],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
# ── GET /api/execution/{id}/logs — 获取历史日志 ──────────────
|
||||
|
||||
@router.get("/{execution_id}/logs", response_model=ExecutionLogsResponse)
|
||||
async def get_execution_logs(
|
||||
execution_id: str,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> ExecutionLogsResponse:
|
||||
"""获取指定执行的完整日志。
|
||||
|
||||
优先从内存缓冲区读取(执行中),否则从数据库读取(已完成)。
|
||||
"""
|
||||
# 先尝试内存缓冲区(执行中的任务)
|
||||
if task_executor.is_running(execution_id):
|
||||
lines = task_executor.get_logs(execution_id)
|
||||
return ExecutionLogsResponse(
|
||||
execution_id=execution_id,
|
||||
output_log="\n".join(lines) if lines else None,
|
||||
)
|
||||
|
||||
# 从数据库读取已完成任务的日志
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT output_log, error_log
|
||||
FROM task_execution_log
|
||||
WHERE id = %s AND site_id = %s
|
||||
""",
|
||||
(execution_id, user.site_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="执行记录不存在",
|
||||
)
|
||||
|
||||
return ExecutionLogsResponse(
|
||||
execution_id=execution_id,
|
||||
output_log=row[0],
|
||||
error_log=row[1],
|
||||
)
|
||||
293
apps/backend/app/routers/schedules.py
Normal file
293
apps/backend/app/routers/schedules.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度任务 CRUD API
|
||||
|
||||
提供 5 个端点:
|
||||
- GET /api/schedules — 列表(按 site_id 过滤)
|
||||
- POST /api/schedules — 创建
|
||||
- PUT /api/schedules/{id} — 更新
|
||||
- DELETE /api/schedules/{id} — 删除
|
||||
- PATCH /api/schedules/{id}/toggle — 启用/禁用
|
||||
|
||||
所有端点需要 JWT 认证,site_id 从 JWT 提取。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_connection
|
||||
from app.schemas.schedules import (
|
||||
CreateScheduleRequest,
|
||||
ScheduleResponse,
|
||||
UpdateScheduleRequest,
|
||||
)
|
||||
from app.services.scheduler import calculate_next_run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/schedules", tags=["调度管理"])
|
||||
|
||||
|
||||
|
||||
def _row_to_response(row) -> ScheduleResponse:
|
||||
"""将数据库行转换为 ScheduleResponse。"""
|
||||
task_config = row[4] if isinstance(row[4], dict) else json.loads(row[4])
|
||||
schedule_config = row[5] if isinstance(row[5], dict) else json.loads(row[5])
|
||||
return ScheduleResponse(
|
||||
id=str(row[0]),
|
||||
site_id=row[1],
|
||||
name=row[2],
|
||||
task_codes=row[3] or [],
|
||||
task_config=task_config,
|
||||
schedule_config=schedule_config,
|
||||
enabled=row[6],
|
||||
last_run_at=row[7],
|
||||
next_run_at=row[8],
|
||||
run_count=row[9],
|
||||
last_status=row[10],
|
||||
created_at=row[11],
|
||||
updated_at=row[12],
|
||||
)
|
||||
|
||||
|
||||
# 查询列列表,复用于多个端点
|
||||
_SELECT_COLS = """
|
||||
id, site_id, name, task_codes, task_config, schedule_config,
|
||||
enabled, last_run_at, next_run_at, run_count, last_status,
|
||||
created_at, updated_at
|
||||
"""
|
||||
|
||||
|
||||
# ── GET /api/schedules — 列表 ────────────────────────────────
|
||||
|
||||
@router.get("", response_model=list[ScheduleResponse])
|
||||
async def list_schedules(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[ScheduleResponse]:
|
||||
"""获取当前门店的所有调度任务。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT {_SELECT_COLS} FROM scheduled_tasks WHERE site_id = %s ORDER BY created_at DESC",
|
||||
(user.site_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return [_row_to_response(row) for row in rows]
|
||||
|
||||
|
||||
# ── POST /api/schedules — 创建 ──────────────────────────────
|
||||
|
||||
@router.post("", response_model=ScheduleResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_schedule(
|
||||
body: CreateScheduleRequest,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> ScheduleResponse:
|
||||
"""创建调度任务,自动计算 next_run_at。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
next_run = calculate_next_run(body.schedule_config, now)
|
||||
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
INSERT INTO scheduled_tasks
|
||||
(site_id, name, task_codes, task_config, schedule_config, enabled, next_run_at)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING {_SELECT_COLS}
|
||||
""",
|
||||
(
|
||||
user.site_id,
|
||||
body.name,
|
||||
body.task_codes,
|
||||
json.dumps(body.task_config),
|
||||
body.schedule_config.model_dump_json(),
|
||||
body.schedule_config.enabled,
|
||||
next_run,
|
||||
),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return _row_to_response(row)
|
||||
|
||||
|
||||
# ── PUT /api/schedules/{id} — 更新 ──────────────────────────
|
||||
|
||||
@router.put("/{schedule_id}", response_model=ScheduleResponse)
|
||||
async def update_schedule(
|
||||
schedule_id: str,
|
||||
body: UpdateScheduleRequest,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> ScheduleResponse:
|
||||
"""更新调度任务,仅更新请求中提供的字段。"""
|
||||
# 构建动态 SET 子句
|
||||
set_parts: list[str] = []
|
||||
params: list = []
|
||||
|
||||
if body.name is not None:
|
||||
set_parts.append("name = %s")
|
||||
params.append(body.name)
|
||||
if body.task_codes is not None:
|
||||
set_parts.append("task_codes = %s")
|
||||
params.append(body.task_codes)
|
||||
if body.task_config is not None:
|
||||
set_parts.append("task_config = %s")
|
||||
params.append(json.dumps(body.task_config))
|
||||
if body.schedule_config is not None:
|
||||
set_parts.append("schedule_config = %s")
|
||||
params.append(body.schedule_config.model_dump_json())
|
||||
# 更新调度配置时重新计算 next_run_at
|
||||
now = datetime.now(timezone.utc)
|
||||
next_run = calculate_next_run(body.schedule_config, now)
|
||||
set_parts.append("next_run_at = %s")
|
||||
params.append(next_run)
|
||||
|
||||
if not set_parts:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="至少需要提供一个更新字段",
|
||||
)
|
||||
|
||||
set_parts.append("updated_at = NOW()")
|
||||
set_clause = ", ".join(set_parts)
|
||||
params.extend([schedule_id, user.site_id])
|
||||
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
UPDATE scheduled_tasks
|
||||
SET {set_clause}
|
||||
WHERE id = %s AND site_id = %s
|
||||
RETURNING {_SELECT_COLS}
|
||||
""",
|
||||
params,
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="调度任务不存在",
|
||||
)
|
||||
|
||||
return _row_to_response(row)
|
||||
|
||||
|
||||
# ── DELETE /api/schedules/{id} — 删除 ────────────────────────
|
||||
|
||||
@router.delete("/{schedule_id}")
|
||||
async def delete_schedule(
|
||||
schedule_id: str,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""删除调度任务。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM scheduled_tasks WHERE id = %s AND site_id = %s",
|
||||
(schedule_id, user.site_id),
|
||||
)
|
||||
deleted = cur.rowcount
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if deleted == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="调度任务不存在",
|
||||
)
|
||||
|
||||
return {"message": "调度任务已删除"}
|
||||
|
||||
|
||||
# ── PATCH /api/schedules/{id}/toggle — 启用/禁用 ─────────────
|
||||
|
||||
@router.patch("/{schedule_id}/toggle", response_model=ScheduleResponse)
|
||||
async def toggle_schedule(
|
||||
schedule_id: str,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> ScheduleResponse:
|
||||
"""切换调度任务的启用/禁用状态。
|
||||
|
||||
禁用时 next_run_at 置 NULL;启用时重新计算 next_run_at。
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
# 先查询当前状态和调度配置
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT enabled, schedule_config FROM scheduled_tasks WHERE id = %s AND site_id = %s",
|
||||
(schedule_id, user.site_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="调度任务不存在",
|
||||
)
|
||||
|
||||
current_enabled = row[0]
|
||||
new_enabled = not current_enabled
|
||||
|
||||
if new_enabled:
|
||||
# 启用:重新计算 next_run_at
|
||||
schedule_config_raw = row[1] if isinstance(row[1], dict) else json.loads(row[1])
|
||||
from app.schemas.schedules import ScheduleConfigSchema
|
||||
schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
|
||||
now = datetime.now(timezone.utc)
|
||||
next_run = calculate_next_run(schedule_cfg, now)
|
||||
else:
|
||||
# 禁用:next_run_at 置 NULL
|
||||
next_run = None
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
UPDATE scheduled_tasks
|
||||
SET enabled = %s, next_run_at = %s, updated_at = NOW()
|
||||
WHERE id = %s AND site_id = %s
|
||||
RETURNING {_SELECT_COLS}
|
||||
""",
|
||||
(new_enabled, next_run, schedule_id, user.site_id),
|
||||
)
|
||||
updated_row = cur.fetchone()
|
||||
conn.commit()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return _row_to_response(updated_row)
|
||||
209
apps/backend/app/routers/tasks.py
Normal file
209
apps/backend/app/routers/tasks.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表 & 配置 API
|
||||
|
||||
提供 4 个端点:
|
||||
- GET /api/tasks/registry — 按业务域分组的任务列表
|
||||
- GET /api/tasks/dwd-tables — 按业务域分组的 DWD 表定义
|
||||
- GET /api/tasks/flows — 7 种 Flow + 3 种处理模式
|
||||
- POST /api/tasks/validate — 验证 TaskConfig 并返回 CLI 命令预览
|
||||
|
||||
所有端点需要 JWT 认证。validate 端点从 JWT 注入 store_id。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.config import ETL_PROJECT_PATH
|
||||
from app.schemas.tasks import (
|
||||
FlowDefinition,
|
||||
ProcessingModeDefinition,
|
||||
TaskConfigSchema,
|
||||
)
|
||||
from app.services.cli_builder import cli_builder
|
||||
from app.services.task_registry import (
|
||||
DWD_TABLES,
|
||||
FLOW_LAYER_MAP,
|
||||
get_dwd_tables_grouped_by_domain,
|
||||
get_tasks_grouped_by_domain,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["任务配置"])
|
||||
|
||||
|
||||
# ── 响应模型 ──────────────────────────────────────────────────
|
||||
|
||||
class TaskItem(BaseModel):
|
||||
code: str
|
||||
name: str
|
||||
description: str
|
||||
domain: str
|
||||
layer: str
|
||||
requires_window: bool
|
||||
is_ods: bool
|
||||
is_dimension: bool
|
||||
default_enabled: bool
|
||||
is_common: bool
|
||||
|
||||
|
||||
class DwdTableItem(BaseModel):
|
||||
table_name: str
|
||||
display_name: str
|
||||
domain: str
|
||||
ods_source: str
|
||||
is_dimension: bool
|
||||
|
||||
|
||||
class TaskRegistryResponse(BaseModel):
|
||||
"""按业务域分组的任务列表"""
|
||||
groups: dict[str, list[TaskItem]]
|
||||
|
||||
|
||||
class DwdTablesResponse(BaseModel):
|
||||
"""按业务域分组的 DWD 表定义"""
|
||||
groups: dict[str, list[DwdTableItem]]
|
||||
|
||||
|
||||
class FlowsResponse(BaseModel):
|
||||
"""Flow 定义 + 处理模式定义"""
|
||||
flows: list[FlowDefinition]
|
||||
processing_modes: list[ProcessingModeDefinition]
|
||||
|
||||
|
||||
class ValidateRequest(BaseModel):
|
||||
"""验证请求体 — 复用 TaskConfigSchema,但 store_id 由后端注入"""
|
||||
config: TaskConfigSchema
|
||||
|
||||
|
||||
class ValidateResponse(BaseModel):
|
||||
"""验证结果 + CLI 命令预览"""
|
||||
valid: bool
|
||||
command: str
|
||||
command_args: list[str]
|
||||
errors: list[str]
|
||||
|
||||
|
||||
# ── Flow 定义(静态) ────────────────────────────────────────
|
||||
|
||||
FLOW_DEFINITIONS: list[FlowDefinition] = [
|
||||
FlowDefinition(id="api_ods", name="API → ODS", layers=["ODS"]),
|
||||
FlowDefinition(id="api_ods_dwd", name="API → ODS → DWD", layers=["ODS", "DWD"]),
|
||||
FlowDefinition(id="api_full", name="API → ODS → DWD → DWS汇总 → DWS指数", layers=["ODS", "DWD", "DWS", "INDEX"]),
|
||||
FlowDefinition(id="ods_dwd", name="ODS → DWD", layers=["DWD"]),
|
||||
FlowDefinition(id="dwd_dws", name="DWD → DWS汇总", layers=["DWS"]),
|
||||
FlowDefinition(id="dwd_dws_index", name="DWD → DWS汇总 → DWS指数", layers=["DWS", "INDEX"]),
|
||||
FlowDefinition(id="dwd_index", name="DWD → DWS指数", layers=["INDEX"]),
|
||||
]
|
||||
|
||||
PROCESSING_MODE_DEFINITIONS: list[ProcessingModeDefinition] = [
|
||||
ProcessingModeDefinition(id="increment_only", name="仅增量处理", description="只处理新增和变更的数据"),
|
||||
ProcessingModeDefinition(id="verify_only", name="仅校验修复", description="校验现有数据并修复不一致(可选'校验前从 API 获取')"),
|
||||
ProcessingModeDefinition(id="increment_verify", name="增量 + 校验修复", description="先增量处理,再校验并修复"),
|
||||
]
|
||||
|
||||
|
||||
# ── 端点 ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/registry", response_model=TaskRegistryResponse)
|
||||
async def get_task_registry(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> TaskRegistryResponse:
|
||||
"""返回按业务域分组的任务列表"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
return TaskRegistryResponse(
|
||||
groups={
|
||||
domain: [
|
||||
TaskItem(
|
||||
code=t.code,
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
domain=t.domain,
|
||||
layer=t.layer,
|
||||
requires_window=t.requires_window,
|
||||
is_ods=t.is_ods,
|
||||
is_dimension=t.is_dimension,
|
||||
default_enabled=t.default_enabled,
|
||||
is_common=t.is_common,
|
||||
)
|
||||
for t in tasks
|
||||
]
|
||||
for domain, tasks in grouped.items()
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/dwd-tables", response_model=DwdTablesResponse)
|
||||
async def get_dwd_tables(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> DwdTablesResponse:
|
||||
"""返回按业务域分组的 DWD 表定义"""
|
||||
grouped = get_dwd_tables_grouped_by_domain()
|
||||
return DwdTablesResponse(
|
||||
groups={
|
||||
domain: [
|
||||
DwdTableItem(
|
||||
table_name=t.table_name,
|
||||
display_name=t.display_name,
|
||||
domain=t.domain,
|
||||
ods_source=t.ods_source,
|
||||
is_dimension=t.is_dimension,
|
||||
)
|
||||
for t in tables
|
||||
]
|
||||
for domain, tables in grouped.items()
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/flows", response_model=FlowsResponse)
|
||||
async def get_flows(
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> FlowsResponse:
|
||||
"""返回 7 种 Flow 定义和 3 种处理模式定义"""
|
||||
return FlowsResponse(
|
||||
flows=FLOW_DEFINITIONS,
|
||||
processing_modes=PROCESSING_MODE_DEFINITIONS,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/validate", response_model=ValidateResponse)
|
||||
async def validate_task_config(
|
||||
body: ValidateRequest,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
) -> ValidateResponse:
|
||||
"""验证 TaskConfig 并返回生成的 CLI 命令预览
|
||||
|
||||
从 JWT 注入 store_id,前端无需传递。
|
||||
"""
|
||||
config = body.config.model_copy(update={"store_id": user.site_id})
|
||||
errors: list[str] = []
|
||||
|
||||
# 验证 Flow ID
|
||||
if config.pipeline not in FLOW_LAYER_MAP:
|
||||
errors.append(f"无效的执行流程: {config.pipeline}")
|
||||
|
||||
# 验证任务列表非空
|
||||
if not config.tasks:
|
||||
errors.append("任务列表不能为空")
|
||||
|
||||
if errors:
|
||||
return ValidateResponse(
|
||||
valid=False,
|
||||
command="",
|
||||
command_args=[],
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
cmd_args = cli_builder.build_command(config, ETL_PROJECT_PATH)
|
||||
cmd_str = cli_builder.build_command_string(config, ETL_PROJECT_PATH)
|
||||
|
||||
return ValidateResponse(
|
||||
valid=True,
|
||||
command=cmd_str,
|
||||
command_args=cmd_args,
|
||||
errors=[],
|
||||
)
|
||||
104
apps/backend/app/routers/wx_callback.py
Normal file
104
apps/backend/app/routers/wx_callback.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-19 | Prompt: 配置微信消息推送 | 新增微信消息推送回调接口,支持 GET 验签 + POST 消息接收
|
||||
|
||||
"""
|
||||
微信消息推送回调接口
|
||||
|
||||
处理两类请求:
|
||||
1. GET — 微信服务器验证(配置时触发一次)
|
||||
2. POST — 接收微信推送的消息/事件
|
||||
|
||||
安全模式下需要解密消息体,当前先用明文模式跑通,后续切安全模式。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Query, Request, Response
|
||||
|
||||
from app.config import get
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/wx", tags=["微信回调"])
|
||||
|
||||
# Token 从环境变量读取,与微信后台填写的一致
|
||||
# 放在 apps/backend/.env.local 中:WX_CALLBACK_TOKEN=你自定义的token
|
||||
WX_CALLBACK_TOKEN: str = get("WX_CALLBACK_TOKEN", "")
|
||||
|
||||
|
||||
def _check_signature(signature: str, timestamp: str, nonce: str) -> bool:
|
||||
"""
|
||||
验证请求是否来自微信服务器。
|
||||
|
||||
将 Token、timestamp、nonce 字典序排序后拼接,做 SHA1,
|
||||
与 signature 比对。
|
||||
"""
|
||||
if not WX_CALLBACK_TOKEN:
|
||||
logger.error("WX_CALLBACK_TOKEN 未配置")
|
||||
return False
|
||||
|
||||
items = sorted([WX_CALLBACK_TOKEN, timestamp, nonce])
|
||||
hash_str = hashlib.sha1("".join(items).encode("utf-8")).hexdigest()
|
||||
return hash_str == signature
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def verify(
|
||||
signature: str = Query(...),
|
||||
timestamp: str = Query(...),
|
||||
nonce: str = Query(...),
|
||||
echostr: str = Query(...),
|
||||
):
|
||||
"""
|
||||
微信服务器验证接口。
|
||||
|
||||
配置消息推送时微信会发 GET 请求,验签通过后原样返回 echostr。
|
||||
"""
|
||||
if _check_signature(signature, timestamp, nonce):
|
||||
logger.info("微信回调验证通过")
|
||||
# 必须原样返回 echostr(纯文本,不能包裹 JSON)
|
||||
return Response(content=echostr, media_type="text/plain")
|
||||
else:
|
||||
logger.warning("微信回调验签失败: signature=%s", signature)
|
||||
return Response(content="signature mismatch", status_code=403)
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def receive_message(
|
||||
request: Request,
|
||||
signature: str = Query(""),
|
||||
timestamp: str = Query(""),
|
||||
nonce: str = Query(""),
|
||||
):
|
||||
"""
|
||||
接收微信推送的消息/事件。
|
||||
|
||||
当前为明文模式,直接解析 JSON 包体。
|
||||
后续切安全模式时需增加 AES 解密逻辑。
|
||||
"""
|
||||
# 验签(POST 也带 signature 参数)
|
||||
if not _check_signature(signature, timestamp, nonce):
|
||||
logger.warning("消息推送验签失败")
|
||||
return Response(content="signature mismatch", status_code=403)
|
||||
|
||||
# 解析消息体
|
||||
body = await request.body()
|
||||
content_type = request.headers.get("content-type", "")
|
||||
|
||||
if "json" in content_type:
|
||||
import json
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
data = {"raw": body.decode("utf-8", errors="replace")}
|
||||
else:
|
||||
# XML 格式暂不解析,记录原文
|
||||
data = {"raw_xml": body.decode("utf-8", errors="replace")}
|
||||
|
||||
logger.info("收到微信推送: MsgType=%s, Event=%s",
|
||||
data.get("MsgType", "?"), data.get("Event", "?"))
|
||||
|
||||
# TODO: 根据 MsgType/Event 分发处理(客服消息、订阅事件等)
|
||||
# 当前统一返回 success
|
||||
return Response(content="success", media_type="text/plain")
|
||||
37
apps/backend/app/routers/xcx_test.py
Normal file
37
apps/backend/app/routers/xcx_test.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-19 | Prompt: 小程序 MVP 全链路验证 | 新增 /api/xcx-test 接口,从 test."xcx-test" 表读取 ti 列第一行
|
||||
|
||||
"""
|
||||
小程序 MVP 验证接口
|
||||
|
||||
从 test_zqyy_app 库的 test."xcx-test" 表读取数据,
|
||||
用于验证小程序 → 后端 → 数据库全链路连通性。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from app.database import get_connection
|
||||
|
||||
router = APIRouter(prefix="/api/xcx-test", tags=["小程序MVP"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_xcx_test():
|
||||
"""
|
||||
读取 test."xcx-test" 表 ti 列第一行。
|
||||
|
||||
用于小程序 MVP 全链路验证:小程序 → API → DB → 返回数据。
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# CHANGE 2026-02-19 | 读取 test schema 下的 xcx-test 表
|
||||
# 表名含连字符,必须用双引号包裹
|
||||
cur.execute('SELECT ti FROM test."xcx-test" LIMIT 1')
|
||||
row = cur.fetchone()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="无数据")
|
||||
|
||||
return {"ti": row[0]}
|
||||
30
apps/backend/app/schemas/auth.py
Normal file
30
apps/backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
认证相关 Pydantic 模型。
|
||||
|
||||
- LoginRequest:登录请求体
|
||||
- TokenResponse:令牌响应体
|
||||
- RefreshRequest:刷新令牌请求体
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求。"""
|
||||
|
||||
username: str = Field(..., min_length=1, max_length=64, description="用户名")
|
||||
password: str = Field(..., min_length=1, description="密码")
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
"""刷新令牌请求。"""
|
||||
|
||||
refresh_token: str = Field(..., min_length=1, description="刷新令牌")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""令牌响应。"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
42
apps/backend/app/schemas/db_viewer.py
Normal file
42
apps/backend/app/schemas/db_viewer.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器 Pydantic 模型
|
||||
|
||||
定义 Schema 浏览、表结构查看、SQL 查询的请求/响应模型。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SchemaInfo(BaseModel):
|
||||
"""Schema 信息。"""
|
||||
name: str
|
||||
|
||||
|
||||
class TableInfo(BaseModel):
|
||||
"""表信息(含行数统计)。"""
|
||||
name: str
|
||||
row_count: int | None = None
|
||||
|
||||
|
||||
class ColumnInfo(BaseModel):
|
||||
"""列定义。"""
|
||||
name: str
|
||||
data_type: str
|
||||
is_nullable: bool
|
||||
column_default: str | None = None
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
"""SQL 查询请求。"""
|
||||
sql: str
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
"""SQL 查询响应。"""
|
||||
columns: list[str]
|
||||
rows: list[list[Any]]
|
||||
row_count: int
|
||||
27
apps/backend/app/schemas/etl_status.py
Normal file
27
apps/backend/app/schemas/etl_status.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 状态监控 Pydantic 模型
|
||||
|
||||
定义游标信息和最近执行记录的响应模型。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CursorInfo(BaseModel):
|
||||
"""ETL 游标信息(单条任务的最后抓取状态)。"""
|
||||
task_code: str
|
||||
last_fetch_time: str | None = None
|
||||
record_count: int | None = None
|
||||
|
||||
|
||||
class RecentRun(BaseModel):
|
||||
"""最近执行记录。"""
|
||||
id: str
|
||||
task_codes: list[str]
|
||||
status: str
|
||||
started_at: str
|
||||
finished_at: str | None = None
|
||||
duration_ms: int | None = None
|
||||
exit_code: int | None = None
|
||||
59
apps/backend/app/schemas/execution.py
Normal file
59
apps/backend/app/schemas/execution.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""执行与队列相关的 Pydantic 模型
|
||||
|
||||
用于 execution 路由的请求/响应序列化。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ReorderRequest(BaseModel):
|
||||
"""队列重排请求"""
|
||||
task_id: str
|
||||
new_position: int
|
||||
|
||||
|
||||
class QueueTaskResponse(BaseModel):
|
||||
"""队列任务响应"""
|
||||
id: str
|
||||
site_id: int
|
||||
config: dict[str, Any]
|
||||
status: str
|
||||
position: int
|
||||
created_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
exit_code: int | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ExecutionRunResponse(BaseModel):
|
||||
"""直接执行任务的响应"""
|
||||
execution_id: str
|
||||
message: str
|
||||
|
||||
|
||||
class ExecutionHistoryItem(BaseModel):
|
||||
"""执行历史记录"""
|
||||
id: str
|
||||
site_id: int
|
||||
task_codes: list[str]
|
||||
status: str
|
||||
started_at: datetime
|
||||
finished_at: datetime | None = None
|
||||
exit_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
command: str | None = None
|
||||
summary: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ExecutionLogsResponse(BaseModel):
|
||||
"""执行日志响应"""
|
||||
execution_id: str
|
||||
output_log: str | None = None
|
||||
error_log: str | None = None
|
||||
61
apps/backend/app/schemas/schedules.py
Normal file
61
apps/backend/app/schemas/schedules.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度配置 Pydantic 模型
|
||||
|
||||
定义 ScheduleConfigSchema 及相关模型,供调度服务和路由使用。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ScheduleConfigSchema(BaseModel):
|
||||
"""调度配置 — 支持 5 种调度类型"""
|
||||
|
||||
schedule_type: Literal["once", "interval", "daily", "weekly", "cron"]
|
||||
interval_value: int = 1
|
||||
interval_unit: Literal["minutes", "hours", "days"] = "hours"
|
||||
daily_time: str = "04:00"
|
||||
weekly_days: list[int] = [1]
|
||||
weekly_time: str = "04:00"
|
||||
cron_expression: str = "0 4 * * *"
|
||||
enabled: bool = True
|
||||
start_date: str | None = None
|
||||
end_date: str | None = None
|
||||
|
||||
|
||||
class CreateScheduleRequest(BaseModel):
|
||||
"""创建调度任务请求"""
|
||||
|
||||
name: str
|
||||
task_codes: list[str]
|
||||
task_config: dict[str, Any]
|
||||
schedule_config: ScheduleConfigSchema
|
||||
|
||||
|
||||
class UpdateScheduleRequest(BaseModel):
|
||||
"""更新调度任务请求(所有字段可选)"""
|
||||
|
||||
name: str | None = None
|
||||
task_codes: list[str] | None = None
|
||||
task_config: dict[str, Any] | None = None
|
||||
schedule_config: ScheduleConfigSchema | None = None
|
||||
|
||||
|
||||
class ScheduleResponse(BaseModel):
|
||||
"""调度任务响应"""
|
||||
|
||||
id: str
|
||||
site_id: int
|
||||
name: str
|
||||
task_codes: list[str]
|
||||
task_config: dict[str, Any]
|
||||
schedule_config: dict[str, Any]
|
||||
enabled: bool
|
||||
last_run_at: datetime | None = None
|
||||
next_run_at: datetime | None = None
|
||||
run_count: int
|
||||
last_status: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
73
apps/backend/app/schemas/tasks.py
Normal file
73
apps/backend/app/schemas/tasks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务配置 Pydantic 模型
|
||||
|
||||
定义 TaskConfigSchema 及相关模型,用于前后端传输和 CLIBuilder 消费。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class TaskConfigSchema(BaseModel):
|
||||
"""任务配置 — 前后端传输格式
|
||||
|
||||
字段与 CLI 参数的映射关系:
|
||||
- pipeline → --pipeline(Flow ID,7 种之一)
|
||||
- processing_mode → --processing-mode(3 种处理模式)
|
||||
- tasks → --tasks(逗号分隔)
|
||||
- dry_run → --dry-run(布尔标志)
|
||||
- window_mode → 决定使用 lookback 还是 custom 时间窗口(仅前端逻辑,不直接映射 CLI 参数)
|
||||
- window_start → --window-start
|
||||
- window_end → --window-end
|
||||
- window_split → --window-split
|
||||
- window_split_days → --window-split-days
|
||||
- lookback_hours → --lookback-hours
|
||||
- overlap_seconds → --overlap-seconds
|
||||
- fetch_before_verify → --fetch-before-verify(布尔标志)
|
||||
- store_id → --store-id(由后端从 JWT 注入,前端不传)
|
||||
- dwd_only_tables → 传入 extra_args 或未来扩展
|
||||
"""
|
||||
|
||||
tasks: list[str]
|
||||
pipeline: str = "api_ods_dwd"
|
||||
processing_mode: str = "increment_only"
|
||||
dry_run: bool = False
|
||||
window_mode: str = "lookback"
|
||||
window_start: str | None = None
|
||||
window_end: str | None = None
|
||||
window_split: str | None = None
|
||||
window_split_days: int | None = None
|
||||
lookback_hours: int = 24
|
||||
overlap_seconds: int = 600
|
||||
fetch_before_verify: bool = False
|
||||
skip_ods_when_fetch_before_verify: bool = False
|
||||
ods_use_local_json: bool = False
|
||||
store_id: int | None = None
|
||||
dwd_only_tables: list[str] | None = None
|
||||
force_full: bool = False
|
||||
extra_args: dict[str, Any] = {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_window(self) -> "TaskConfigSchema":
|
||||
"""验证时间窗口:结束日期不早于开始日期"""
|
||||
if self.window_start and self.window_end:
|
||||
if self.window_end < self.window_start:
|
||||
raise ValueError("window_end 不能早于 window_start")
|
||||
return self
|
||||
|
||||
|
||||
class FlowDefinition(BaseModel):
|
||||
"""执行流程(Flow)定义"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
layers: list[str]
|
||||
|
||||
|
||||
class ProcessingModeDefinition(BaseModel):
|
||||
"""处理模式定义"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
1
apps/backend/app/services/__init__.py
Normal file
1
apps/backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
158
apps/backend/app/services/cli_builder.py
Normal file
158
apps/backend/app/services/cli_builder.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLI 命令构建器
|
||||
|
||||
从 gui/utils/cli_builder.py 迁移,适配后端 TaskConfigSchema。
|
||||
将 TaskConfigSchema 转换为 ETL CLI 命令行参数列表。
|
||||
|
||||
支持:
|
||||
- 7 种 Flow(api_ods / api_ods_dwd / api_full / ods_dwd / dwd_dws / dwd_dws_index / dwd_index)
|
||||
- 3 种处理模式(increment_only / verify_only / increment_verify)
|
||||
- 自动注入 --store-id 参数
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
|
||||
# 有效的 Flow ID 集合
|
||||
VALID_FLOWS: set[str] = {
|
||||
"api_ods",
|
||||
"api_ods_dwd",
|
||||
"api_full",
|
||||
"ods_dwd",
|
||||
"dwd_dws",
|
||||
"dwd_dws_index",
|
||||
"dwd_index",
|
||||
}
|
||||
|
||||
# 有效的处理模式集合
|
||||
VALID_PROCESSING_MODES: set[str] = {
|
||||
"increment_only",
|
||||
"verify_only",
|
||||
"increment_verify",
|
||||
}
|
||||
|
||||
# CLI 支持的 extra_args 键(值类型 + 布尔类型)
|
||||
CLI_SUPPORTED_ARGS: set[str] = {
|
||||
# 值类型参数
|
||||
"pg_dsn", "pg_host", "pg_port", "pg_name",
|
||||
"pg_user", "pg_password", "api_base", "api_token", "api_timeout",
|
||||
"api_page_size", "api_retry_max",
|
||||
"export_root", "log_root", "fetch_root",
|
||||
"ingest_source", "idle_start", "idle_end",
|
||||
"data_source", "pipeline_flow",
|
||||
"window_split_unit",
|
||||
# 布尔类型参数
|
||||
"force_window_override", "write_pretty_json", "allow_empty_advance",
|
||||
}
|
||||
|
||||
|
||||
class CLIBuilder:
|
||||
"""将 TaskConfigSchema 转换为 ETL CLI 命令行参数列表"""
|
||||
|
||||
def build_command(
|
||||
self,
|
||||
config: TaskConfigSchema,
|
||||
etl_project_path: str,
|
||||
python_executable: str = "python",
|
||||
) -> list[str]:
|
||||
"""构建完整的 CLI 命令参数列表。
|
||||
|
||||
生成格式:
|
||||
[python, -m, cli.main, --flow, {flow_id}, --tasks, ..., --store-id, {site_id}, ...]
|
||||
|
||||
Args:
|
||||
config: 任务配置对象(Pydantic 模型)
|
||||
etl_project_path: ETL 项目根目录路径(用于 cwd,不拼入命令)
|
||||
python_executable: Python 可执行文件路径,默认 "python"
|
||||
|
||||
Returns:
|
||||
命令行参数列表
|
||||
"""
|
||||
cmd: list[str] = [python_executable, "-m", "cli.main"]
|
||||
|
||||
# -- Flow(执行流程) --
|
||||
cmd.extend(["--flow", config.pipeline])
|
||||
|
||||
# -- 处理模式 --
|
||||
if config.processing_mode:
|
||||
cmd.extend(["--processing-mode", config.processing_mode])
|
||||
|
||||
# -- 任务列表 --
|
||||
if config.tasks:
|
||||
cmd.extend(["--tasks", ",".join(config.tasks)])
|
||||
|
||||
# -- 校验前从 API 获取数据(仅 verify_only 模式有效) --
|
||||
if config.fetch_before_verify and config.processing_mode == "verify_only":
|
||||
cmd.append("--fetch-before-verify")
|
||||
|
||||
# -- 时间窗口 --
|
||||
if config.window_mode == "lookback":
|
||||
# 回溯模式
|
||||
if config.lookback_hours is not None:
|
||||
cmd.extend(["--lookback-hours", str(config.lookback_hours)])
|
||||
if config.overlap_seconds is not None:
|
||||
cmd.extend(["--overlap-seconds", str(config.overlap_seconds)])
|
||||
else:
|
||||
# 自定义时间窗口
|
||||
if config.window_start:
|
||||
cmd.extend(["--window-start", config.window_start])
|
||||
if config.window_end:
|
||||
cmd.extend(["--window-end", config.window_end])
|
||||
|
||||
# -- 时间窗口切分 --
|
||||
if config.window_split and config.window_split != "none":
|
||||
cmd.extend(["--window-split", config.window_split])
|
||||
if config.window_split_days is not None:
|
||||
cmd.extend(["--window-split-days", str(config.window_split_days)])
|
||||
|
||||
# -- Dry-run --
|
||||
if config.dry_run:
|
||||
cmd.append("--dry-run")
|
||||
|
||||
# -- 强制全量处理 --
|
||||
if config.force_full:
|
||||
cmd.append("--force-full")
|
||||
|
||||
# -- 本地 JSON 模式 → --data-source offline --
|
||||
if config.ods_use_local_json:
|
||||
cmd.extend(["--data-source", "offline"])
|
||||
|
||||
# -- 门店 ID(自动注入) --
|
||||
if config.store_id is not None:
|
||||
cmd.extend(["--store-id", str(config.store_id)])
|
||||
|
||||
# -- 额外参数(只传递 CLI 支持的参数) --
|
||||
for key, value in config.extra_args.items():
|
||||
if value is not None and key in CLI_SUPPORTED_ARGS:
|
||||
arg_name = f"--{key.replace('_', '-')}"
|
||||
if isinstance(value, bool):
|
||||
if value:
|
||||
cmd.append(arg_name)
|
||||
else:
|
||||
cmd.extend([arg_name, str(value)])
|
||||
|
||||
return cmd
|
||||
|
||||
def build_command_string(
|
||||
self,
|
||||
config: TaskConfigSchema,
|
||||
etl_project_path: str,
|
||||
python_executable: str = "python",
|
||||
) -> str:
|
||||
"""构建命令行字符串(用于显示/日志记录)。
|
||||
|
||||
对包含空格的参数自动添加引号。
|
||||
"""
|
||||
cmd = self.build_command(config, etl_project_path, python_executable)
|
||||
quoted: list[str] = []
|
||||
for arg in cmd:
|
||||
if " " in arg or '"' in arg:
|
||||
quoted.append(f'"{arg}"')
|
||||
else:
|
||||
quoted.append(arg)
|
||||
return " ".join(quoted)
|
||||
|
||||
|
||||
# 全局单例
|
||||
cli_builder = CLIBuilder()
|
||||
303
apps/backend/app/services/scheduler.py
Normal file
303
apps/backend/app/services/scheduler.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度器服务
|
||||
|
||||
后台 asyncio 循环,每 30 秒检查一次到期的调度任务,
|
||||
将其 TaskConfig 入队到 TaskQueue。
|
||||
|
||||
核心逻辑:
|
||||
- check_and_enqueue():查询 enabled=true 且 next_run_at <= now 的调度任务
|
||||
- start() / stop():管理后台循环生命周期
|
||||
- _calculate_next_run():根据 ScheduleConfig 计算下次执行时间
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from ..database import get_connection
|
||||
from ..schemas.schedules import ScheduleConfigSchema
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
from .task_queue import task_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 调度器轮询间隔(秒)
|
||||
SCHEDULER_POLL_INTERVAL = 30
|
||||
|
||||
|
||||
def _parse_time(time_str: str) -> tuple[int, int]:
|
||||
"""解析 HH:MM 格式的时间字符串,返回 (hour, minute)。"""
|
||||
parts = time_str.split(":")
|
||||
return int(parts[0]), int(parts[1])
|
||||
|
||||
|
||||
def calculate_next_run(
|
||||
schedule_config: ScheduleConfigSchema,
|
||||
now: datetime | None = None,
|
||||
) -> datetime | None:
|
||||
"""根据调度配置计算下次执行时间。
|
||||
|
||||
Args:
|
||||
schedule_config: 调度配置
|
||||
now: 当前时间(默认 UTC now),方便测试注入
|
||||
|
||||
Returns:
|
||||
下次执行时间(UTC),once 类型返回 None 表示不再执行
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
stype = schedule_config.schedule_type
|
||||
|
||||
if stype == "once":
|
||||
# 一次性任务执行后不再调度
|
||||
return None
|
||||
|
||||
if stype == "interval":
|
||||
unit_map = {
|
||||
"minutes": timedelta(minutes=schedule_config.interval_value),
|
||||
"hours": timedelta(hours=schedule_config.interval_value),
|
||||
"days": timedelta(days=schedule_config.interval_value),
|
||||
}
|
||||
delta = unit_map.get(schedule_config.interval_unit)
|
||||
if delta is None:
|
||||
logger.warning("未知的 interval_unit: %s", schedule_config.interval_unit)
|
||||
return None
|
||||
return now + delta
|
||||
|
||||
if stype == "daily":
|
||||
hour, minute = _parse_time(schedule_config.daily_time)
|
||||
# 计算明天的 daily_time
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
if stype == "weekly":
|
||||
hour, minute = _parse_time(schedule_config.weekly_time)
|
||||
days = sorted(schedule_config.weekly_days) if schedule_config.weekly_days else [1]
|
||||
# ISO weekday: 1=Monday ... 7=Sunday
|
||||
current_weekday = now.isoweekday()
|
||||
|
||||
# 找到下一个匹配的 weekday
|
||||
for day in days:
|
||||
if day > current_weekday:
|
||||
delta_days = day - current_weekday
|
||||
next_dt = now + timedelta(days=delta_days)
|
||||
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# 本周没有更晚的 weekday,跳到下周第一个
|
||||
first_day = days[0]
|
||||
delta_days = 7 - current_weekday + first_day
|
||||
next_dt = now + timedelta(days=delta_days)
|
||||
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
if stype == "cron":
|
||||
# 简单 cron 解析:仅支持 "minute hour * * *" 格式(每日定时)
|
||||
# 复杂 cron 表达式可后续引入 croniter 库
|
||||
return _parse_simple_cron(schedule_config.cron_expression, now)
|
||||
|
||||
logger.warning("未知的 schedule_type: %s", stype)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_simple_cron(expression: str, now: datetime) -> datetime | None:
|
||||
"""简单 cron 解析器,支持基本的 5 字段格式。
|
||||
|
||||
支持的格式:
|
||||
- "M H * * *" → 每天 H:M
|
||||
- "M H * * D" → 每周 D 的 H:M(D 为 0-6,0=Sunday)
|
||||
- 其他格式回退到每天 04:00
|
||||
|
||||
不支持范围、列表、步进等高级语法。如需完整 cron 支持,
|
||||
可在 pyproject.toml 中添加 croniter 依赖。
|
||||
"""
|
||||
parts = expression.strip().split()
|
||||
if len(parts) != 5:
|
||||
logger.warning("无法解析 cron 表达式: %s,回退到明天 04:00", expression)
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
|
||||
|
||||
minute_str, hour_str, dom, month, dow = parts
|
||||
|
||||
try:
|
||||
minute = int(minute_str) if minute_str != "*" else 0
|
||||
hour = int(hour_str) if hour_str != "*" else 0
|
||||
except ValueError:
|
||||
logger.warning("cron 表达式时间字段无法解析: %s,回退到明天 04:00", expression)
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=4, minute=0, second=0, microsecond=0)
|
||||
|
||||
# 如果指定了 day-of-week(非 *)
|
||||
if dow != "*":
|
||||
try:
|
||||
cron_dow = int(dow) # 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||
except ValueError:
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# 转换为 ISO weekday(1=Monday, 7=Sunday)
|
||||
iso_dow = 7 if cron_dow == 0 else cron_dow
|
||||
current_iso = now.isoweekday()
|
||||
|
||||
if iso_dow > current_iso:
|
||||
delta_days = iso_dow - current_iso
|
||||
elif iso_dow < current_iso:
|
||||
delta_days = 7 - current_iso + iso_dow
|
||||
else:
|
||||
# 同一天,看时间是否已过
|
||||
target_today = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
if now < target_today:
|
||||
delta_days = 0
|
||||
else:
|
||||
delta_days = 7
|
||||
|
||||
next_dt = now + timedelta(days=delta_days)
|
||||
return next_dt.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
# 每天定时(dom=* month=* dow=*)
|
||||
tomorrow = now + timedelta(days=1)
|
||||
return tomorrow.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""基于 PostgreSQL 的定时调度器
|
||||
|
||||
后台 asyncio 循环每 SCHEDULER_POLL_INTERVAL 秒检查一次到期任务,
|
||||
将其 TaskConfig 入队到 TaskQueue。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._running = False
|
||||
self._loop_task: asyncio.Task | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 核心:检查到期任务并入队
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def check_and_enqueue(self) -> int:
|
||||
"""查询 enabled=true 且 next_run_at <= now 的调度任务,将其入队。
|
||||
|
||||
Returns:
|
||||
本次入队的任务数量
|
||||
"""
|
||||
conn = get_connection()
|
||||
enqueued = 0
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, site_id, task_config, schedule_config
|
||||
FROM scheduled_tasks
|
||||
WHERE enabled = TRUE
|
||||
AND next_run_at IS NOT NULL
|
||||
AND next_run_at <= NOW()
|
||||
ORDER BY next_run_at ASC
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
for row in rows:
|
||||
task_id = str(row[0])
|
||||
site_id = row[1]
|
||||
task_config_raw = row[2] if isinstance(row[2], dict) else json.loads(row[2])
|
||||
schedule_config_raw = row[3] if isinstance(row[3], dict) else json.loads(row[3])
|
||||
|
||||
try:
|
||||
config = TaskConfigSchema(**task_config_raw)
|
||||
schedule_cfg = ScheduleConfigSchema(**schedule_config_raw)
|
||||
except Exception:
|
||||
logger.exception("调度任务 [%s] 配置反序列化失败,跳过", task_id)
|
||||
continue
|
||||
|
||||
# 入队
|
||||
try:
|
||||
queue_id = task_queue.enqueue(config, site_id)
|
||||
logger.info(
|
||||
"调度任务 [%s] 入队成功 → queue_id=%s site_id=%s",
|
||||
task_id, queue_id, site_id,
|
||||
)
|
||||
enqueued += 1
|
||||
except Exception:
|
||||
logger.exception("调度任务 [%s] 入队失败", task_id)
|
||||
continue
|
||||
|
||||
# 更新调度任务状态
|
||||
now = datetime.now(timezone.utc)
|
||||
next_run = calculate_next_run(schedule_cfg, now)
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE scheduled_tasks
|
||||
SET last_run_at = NOW(),
|
||||
run_count = run_count + 1,
|
||||
next_run_at = %s,
|
||||
last_status = 'enqueued',
|
||||
updated_at = NOW()
|
||||
WHERE id = %s
|
||||
""",
|
||||
(next_run, task_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
except Exception:
|
||||
logger.exception("check_and_enqueue 执行异常")
|
||||
try:
|
||||
conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if enqueued > 0:
|
||||
logger.info("本轮调度检查:%d 个任务入队", enqueued)
|
||||
return enqueued
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 后台循环
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _loop(self) -> None:
|
||||
"""后台 asyncio 循环,每 SCHEDULER_POLL_INTERVAL 秒检查一次。"""
|
||||
self._running = True
|
||||
logger.info("Scheduler 后台循环启动(间隔 %ds)", SCHEDULER_POLL_INTERVAL)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 在线程池中执行同步数据库操作,避免阻塞事件循环
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self.check_and_enqueue)
|
||||
except Exception:
|
||||
logger.exception("Scheduler 循环迭代异常")
|
||||
|
||||
await asyncio.sleep(SCHEDULER_POLL_INTERVAL)
|
||||
|
||||
logger.info("Scheduler 后台循环停止")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self) -> None:
|
||||
"""启动后台调度循环(在 FastAPI lifespan 中调用)。"""
|
||||
if self._loop_task is None or self._loop_task.done():
|
||||
self._loop_task = asyncio.create_task(self._loop())
|
||||
logger.info("Scheduler 已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止后台调度循环。"""
|
||||
self._running = False
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
self._loop_task.cancel()
|
||||
try:
|
||||
await self._loop_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._loop_task = None
|
||||
logger.info("Scheduler 已停止")
|
||||
|
||||
|
||||
# 全局单例
|
||||
scheduler = Scheduler()
|
||||
391
apps/backend/app/services/task_executor.py
Normal file
391
apps/backend/app/services/task_executor.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 任务执行器
|
||||
|
||||
通过 asyncio.create_subprocess_exec 启动 ETL CLI 子进程,
|
||||
逐行读取 stdout/stderr 并广播到 WebSocket 订阅者,
|
||||
执行完成后将结果写入 task_execution_log 表。
|
||||
|
||||
设计要点:
|
||||
- 每个 execution_id 对应一个子进程,存储在 _processes 字典中
|
||||
- 日志行存储在内存缓冲区 _log_buffers 中
|
||||
- WebSocket 订阅者通过 asyncio.Queue 接收实时日志
|
||||
- Windows 兼容:取消时使用 process.terminate() 而非 SIGTERM
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..config import ETL_PROJECT_PATH
|
||||
from ..database import get_connection
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
from ..services.cli_builder import cli_builder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskExecutor:
|
||||
"""管理 ETL CLI 子进程的生命周期"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# execution_id → subprocess.Popen
|
||||
self._processes: dict[str, subprocess.Popen] = {}
|
||||
# execution_id → list[str](stdout + stderr 混合日志)
|
||||
self._log_buffers: dict[str, list[str]] = {}
|
||||
# execution_id → set[asyncio.Queue](WebSocket 订阅者)
|
||||
self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket 订阅管理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
|
||||
"""注册一个 WebSocket 订阅者,返回用于读取日志行的 Queue。
|
||||
|
||||
Queue 中推送 str 表示日志行,None 表示执行结束。
|
||||
"""
|
||||
if execution_id not in self._subscribers:
|
||||
self._subscribers[execution_id] = set()
|
||||
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._subscribers[execution_id].add(queue)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
|
||||
"""移除一个 WebSocket 订阅者。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
subs.discard(queue)
|
||||
if not subs:
|
||||
del self._subscribers[execution_id]
|
||||
|
||||
def _broadcast(self, execution_id: str, line: str) -> None:
|
||||
"""向所有订阅者广播一行日志。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
for q in subs:
|
||||
q.put_nowait(line)
|
||||
|
||||
def _broadcast_end(self, execution_id: str) -> None:
|
||||
"""通知所有订阅者执行已结束(发送 None 哨兵)。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
for q in subs:
|
||||
q.put_nowait(None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 日志缓冲区
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_logs(self, execution_id: str) -> list[str]:
|
||||
"""获取指定执行的内存日志缓冲区(副本)。"""
|
||||
return list(self._log_buffers.get(execution_id, []))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 执行状态查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_running(self, execution_id: str) -> bool:
|
||||
"""判断指定执行是否仍在运行。"""
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc is None:
|
||||
return False
|
||||
return proc.poll() is None
|
||||
|
||||
def get_running_ids(self) -> list[str]:
|
||||
"""返回当前所有运行中的 execution_id 列表。"""
|
||||
return [eid for eid, p in self._processes.items() if p.returncode is None]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 核心执行
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
config: TaskConfigSchema,
|
||||
execution_id: str,
|
||||
queue_id: str | None = None,
|
||||
site_id: int | None = None,
|
||||
) -> None:
|
||||
"""以子进程方式调用 ETL CLI。
|
||||
|
||||
使用 subprocess.Popen + 线程读取,兼容 Windows(避免
|
||||
asyncio.create_subprocess_exec 在 Windows 上的 NotImplementedError)。
|
||||
"""
|
||||
cmd = cli_builder.build_command(
|
||||
config, ETL_PROJECT_PATH, python_executable=sys.executable
|
||||
)
|
||||
command_str = " ".join(cmd)
|
||||
effective_site_id = site_id or config.store_id
|
||||
|
||||
logger.info(
|
||||
"启动 ETL 子进程 [%s]: %s (cwd=%s)",
|
||||
execution_id, command_str, ETL_PROJECT_PATH,
|
||||
)
|
||||
|
||||
self._log_buffers[execution_id] = []
|
||||
started_at = datetime.now(timezone.utc)
|
||||
t0 = time.monotonic()
|
||||
|
||||
self._write_execution_log(
|
||||
execution_id=execution_id,
|
||||
queue_id=queue_id,
|
||||
site_id=effective_site_id,
|
||||
task_codes=config.tasks,
|
||||
status="running",
|
||||
started_at=started_at,
|
||||
command=command_str,
|
||||
)
|
||||
|
||||
exit_code: int | None = None
|
||||
status = "running"
|
||||
stdout_lines: list[str] = []
|
||||
stderr_lines: list[str] = []
|
||||
|
||||
try:
|
||||
# 构建额外环境变量(DWD 表过滤通过环境变量注入)
|
||||
extra_env: dict[str, str] = {}
|
||||
if config.dwd_only_tables:
|
||||
extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)
|
||||
|
||||
# 在线程池中运行子进程,兼容 Windows
|
||||
exit_code = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self._run_subprocess,
|
||||
cmd,
|
||||
execution_id,
|
||||
stdout_lines,
|
||||
stderr_lines,
|
||||
extra_env or None,
|
||||
)
|
||||
|
||||
if exit_code == 0:
|
||||
status = "success"
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
logger.info(
|
||||
"ETL 子进程 [%s] 退出,exit_code=%s, status=%s",
|
||||
execution_id, exit_code, status,
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
status = "cancelled"
|
||||
logger.info("ETL 子进程 [%s] 已取消", execution_id)
|
||||
# 尝试终止子进程
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc and proc.poll() is None:
|
||||
proc.terminate()
|
||||
except Exception as exc:
|
||||
status = "failed"
|
||||
import traceback
|
||||
tb = traceback.format_exc()
|
||||
stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
|
||||
stderr_lines.append(tb)
|
||||
logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
|
||||
finally:
|
||||
elapsed_ms = int((time.monotonic() - t0) * 1000)
|
||||
finished_at = datetime.now(timezone.utc)
|
||||
|
||||
self._broadcast_end(execution_id)
|
||||
self._processes.pop(execution_id, None)
|
||||
|
||||
self._update_execution_log(
|
||||
execution_id=execution_id,
|
||||
status=status,
|
||||
finished_at=finished_at,
|
||||
exit_code=exit_code,
|
||||
duration_ms=elapsed_ms,
|
||||
output_log="\n".join(stdout_lines),
|
||||
error_log="\n".join(stderr_lines),
|
||||
)
|
||||
|
||||
def _run_subprocess(
|
||||
self,
|
||||
cmd: list[str],
|
||||
execution_id: str,
|
||||
stdout_lines: list[str],
|
||||
stderr_lines: list[str],
|
||||
extra_env: dict[str, str] | None = None,
|
||||
) -> int:
|
||||
"""在线程中运行子进程并逐行读取输出。"""
|
||||
import os
|
||||
env = os.environ.copy()
|
||||
# 强制子进程使用 UTF-8 输出,避免 Windows GBK 乱码
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
if extra_env:
|
||||
env.update(extra_env)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=ETL_PROJECT_PATH,
|
||||
env=env,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
)
|
||||
self._processes[execution_id] = proc
|
||||
|
||||
def read_stream(
|
||||
stream, stream_name: str, collector: list[str],
|
||||
) -> None:
|
||||
"""逐行读取流并广播。"""
|
||||
for raw_line in stream:
|
||||
line = raw_line.rstrip("\n").rstrip("\r")
|
||||
tagged = f"[{stream_name}] {line}"
|
||||
buf = self._log_buffers.get(execution_id)
|
||||
if buf is not None:
|
||||
buf.append(tagged)
|
||||
collector.append(line)
|
||||
self._broadcast(execution_id, tagged)
|
||||
|
||||
t_out = threading.Thread(
|
||||
target=read_stream, args=(proc.stdout, "stdout", stdout_lines),
|
||||
daemon=True,
|
||||
)
|
||||
t_err = threading.Thread(
|
||||
target=read_stream, args=(proc.stderr, "stderr", stderr_lines),
|
||||
daemon=True,
|
||||
)
|
||||
t_out.start()
|
||||
t_err.start()
|
||||
|
||||
proc.wait()
|
||||
t_out.join(timeout=5)
|
||||
t_err.join(timeout=5)
|
||||
|
||||
return proc.returncode
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 取消
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def cancel(self, execution_id: str) -> bool:
|
||||
"""向子进程发送终止信号。
|
||||
|
||||
Returns:
|
||||
True 表示成功发送终止信号,False 表示进程不存在或已退出。
|
||||
"""
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc is None:
|
||||
return False
|
||||
# subprocess.Popen: poll() 返回 None 表示仍在运行
|
||||
if proc.poll() is not None:
|
||||
return False
|
||||
|
||||
logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
|
||||
try:
|
||||
proc.terminate()
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 数据库操作(同步,在线程池中执行也可,此处简单直连)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _write_execution_log(
|
||||
*,
|
||||
execution_id: str,
|
||||
queue_id: str | None,
|
||||
site_id: int | None,
|
||||
task_codes: list[str],
|
||||
status: str,
|
||||
started_at: datetime,
|
||||
command: str,
|
||||
) -> None:
|
||||
"""插入一条执行日志记录(running 状态)。"""
|
||||
try:
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO task_execution_log
|
||||
(id, queue_id, site_id, task_codes, status,
|
||||
started_at, command)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
execution_id,
|
||||
queue_id,
|
||||
site_id or 0,
|
||||
task_codes,
|
||||
status,
|
||||
started_at,
|
||||
command,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.exception("写入 execution_log 失败 [%s]", execution_id)
|
||||
|
||||
@staticmethod
|
||||
def _update_execution_log(
|
||||
*,
|
||||
execution_id: str,
|
||||
status: str,
|
||||
finished_at: datetime,
|
||||
exit_code: int | None,
|
||||
duration_ms: int,
|
||||
output_log: str,
|
||||
error_log: str,
|
||||
) -> None:
|
||||
"""更新执行日志记录(完成状态)。"""
|
||||
try:
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_execution_log
|
||||
SET status = %s,
|
||||
finished_at = %s,
|
||||
exit_code = %s,
|
||||
duration_ms = %s,
|
||||
output_log = %s,
|
||||
error_log = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(
|
||||
status,
|
||||
finished_at,
|
||||
exit_code,
|
||||
duration_ms,
|
||||
output_log,
|
||||
error_log,
|
||||
execution_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.exception("更新 execution_log 失败 [%s]", execution_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 清理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def cleanup(self, execution_id: str) -> None:
|
||||
"""清理指定执行的内存资源(日志缓冲区和订阅者)。
|
||||
|
||||
通常在确认日志已持久化后调用。
|
||||
"""
|
||||
self._log_buffers.pop(execution_id, None)
|
||||
self._subscribers.pop(execution_id, None)
|
||||
|
||||
|
||||
# 全局单例
|
||||
task_executor = TaskExecutor()
|
||||
486
apps/backend/app/services/task_queue.py
Normal file
486
apps/backend/app/services/task_queue.py
Normal file
@@ -0,0 +1,486 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务队列服务
|
||||
|
||||
基于 PostgreSQL task_queue 表实现 FIFO 队列,支持:
|
||||
- enqueue:入队,自动分配 position(当前最大 + 1)
|
||||
- dequeue:取出 position 最小的 pending 任务
|
||||
- reorder:调整任务在队列中的位置
|
||||
- delete:删除 pending 任务
|
||||
- process_loop:后台协程,队列非空且无运行中任务时自动取出执行
|
||||
|
||||
所有操作按 site_id 过滤,实现门店隔离。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..database import get_connection
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 后台循环轮询间隔(秒)
|
||||
POLL_INTERVAL_SECONDS = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueuedTask:
|
||||
"""队列任务数据对象"""
|
||||
|
||||
id: str
|
||||
site_id: int
|
||||
config: dict[str, Any]
|
||||
status: str
|
||||
position: int
|
||||
created_at: Any = None
|
||||
started_at: Any = None
|
||||
finished_at: Any = None
|
||||
exit_code: int | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""基于 PostgreSQL 的任务队列"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._running = False
|
||||
self._loop_task: asyncio.Task | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 入队
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def enqueue(self, config: TaskConfigSchema, site_id: int) -> str:
|
||||
"""将任务配置入队,自动分配 position。
|
||||
|
||||
Args:
|
||||
config: 任务配置
|
||||
site_id: 门店 ID(门店隔离)
|
||||
|
||||
Returns:
|
||||
新创建的队列任务 ID(UUID 字符串)
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
config_json = config.model_dump(mode="json")
|
||||
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# 取当前该门店 pending 任务的最大 position,新任务排在末尾
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT COALESCE(MAX(position), 0)
|
||||
FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
max_pos = cur.fetchone()[0]
|
||||
new_pos = max_pos + 1
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO task_queue (id, site_id, config, status, position)
|
||||
VALUES (%s, %s, %s, 'pending', %s)
|
||||
""",
|
||||
(task_id, site_id, json.dumps(config_json), new_pos),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
logger.info("任务入队 [%s] site_id=%s position=%s", task_id, site_id, new_pos)
|
||||
return task_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 出队
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def dequeue(self, site_id: int) -> QueuedTask | None:
|
||||
"""取出 position 最小的 pending 任务,将其状态改为 running。
|
||||
|
||||
Args:
|
||||
site_id: 门店 ID
|
||||
|
||||
Returns:
|
||||
QueuedTask 或 None(队列为空时)
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# 选取 position 最小的 pending 任务并锁定
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, site_id, config, status, position,
|
||||
created_at, started_at, finished_at,
|
||||
exit_code, error_message
|
||||
FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
ORDER BY position ASC
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
conn.commit()
|
||||
return None
|
||||
|
||||
task = QueuedTask(
|
||||
id=str(row[0]),
|
||||
site_id=row[1],
|
||||
config=row[2] if isinstance(row[2], dict) else json.loads(row[2]),
|
||||
status=row[3],
|
||||
position=row[4],
|
||||
created_at=row[5],
|
||||
started_at=row[6],
|
||||
finished_at=row[7],
|
||||
exit_code=row[8],
|
||||
error_message=row[9],
|
||||
)
|
||||
|
||||
# 更新状态为 running
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_queue
|
||||
SET status = 'running', started_at = NOW()
|
||||
WHERE id = %s
|
||||
""",
|
||||
(task.id,),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
task.status = "running"
|
||||
logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
|
||||
return task
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 重排
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
|
||||
"""调整任务在队列中的位置。
|
||||
|
||||
仅允许对 pending 状态的任务重排。将目标任务移到 new_position,
|
||||
其余 pending 任务按原有相对顺序重新编号。
|
||||
|
||||
Args:
|
||||
task_id: 要移动的任务 ID
|
||||
new_position: 目标位置(1-based)
|
||||
site_id: 门店 ID
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# 获取该门店所有 pending 任务,按 position 排序
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
ORDER BY position ASC
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
task_ids = [str(r[0]) for r in rows]
|
||||
|
||||
if task_id not in task_ids:
|
||||
conn.commit()
|
||||
return
|
||||
|
||||
# 从列表中移除目标任务,再插入到新位置
|
||||
task_ids.remove(task_id)
|
||||
# new_position 是 1-based,转为 0-based 索引并 clamp
|
||||
insert_idx = max(0, min(new_position - 1, len(task_ids)))
|
||||
task_ids.insert(insert_idx, task_id)
|
||||
|
||||
# 按新顺序重新分配 position(1-based 连续编号)
|
||||
for idx, tid in enumerate(task_ids, start=1):
|
||||
cur.execute(
|
||||
"UPDATE task_queue SET position = %s WHERE id = %s",
|
||||
(idx, tid),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
logger.info(
|
||||
"任务重排 [%s] → position=%s site_id=%s",
|
||||
task_id, new_position, site_id,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 删除
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def delete(self, task_id: str, site_id: int) -> bool:
|
||||
"""删除 pending 状态的任务。
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
site_id: 门店 ID
|
||||
|
||||
Returns:
|
||||
True 表示成功删除,False 表示任务不存在或非 pending 状态。
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM task_queue
|
||||
WHERE id = %s AND site_id = %s AND status = 'pending'
|
||||
""",
|
||||
(task_id, site_id),
|
||||
)
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if deleted:
|
||||
logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
|
||||
else:
|
||||
logger.warning(
|
||||
"任务删除失败 [%s] site_id=%s(不存在或非 pending)",
|
||||
task_id, site_id,
|
||||
)
|
||||
return deleted
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_pending(self, site_id: int) -> list[QueuedTask]:
|
||||
"""列出指定门店的所有 pending 任务,按 position 升序。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, site_id, config, status, position,
|
||||
created_at, started_at, finished_at,
|
||||
exit_code, error_message
|
||||
FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
ORDER BY position ASC
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return [
|
||||
QueuedTask(
|
||||
id=str(r[0]),
|
||||
site_id=r[1],
|
||||
config=r[2] if isinstance(r[2], dict) else json.loads(r[2]),
|
||||
status=r[3],
|
||||
position=r[4],
|
||||
created_at=r[5],
|
||||
started_at=r[6],
|
||||
finished_at=r[7],
|
||||
exit_code=r[8],
|
||||
error_message=r[9],
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
def has_running(self, site_id: int) -> bool:
|
||||
"""检查指定门店是否有 running 状态的任务。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM task_queue
|
||||
WHERE site_id = %s AND status = 'running'
|
||||
)
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
result = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 后台处理循环
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def process_loop(self) -> None:
|
||||
"""后台协程:队列非空且无运行中任务时,自动取出并执行。
|
||||
|
||||
循环逻辑:
|
||||
1. 查询所有有 pending 任务的 site_id
|
||||
2. 对每个 site_id,若无 running 任务则 dequeue 并执行
|
||||
3. 等待 POLL_INTERVAL_SECONDS 后重复
|
||||
"""
|
||||
# 延迟导入避免循环依赖
|
||||
from .task_executor import task_executor
|
||||
|
||||
self._running = True
|
||||
logger.info("TaskQueue process_loop 启动")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._process_once(task_executor)
|
||||
except Exception:
|
||||
logger.exception("process_loop 迭代异常")
|
||||
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
logger.info("TaskQueue process_loop 停止")
|
||||
|
||||
async def _process_once(self, executor: Any) -> None:
|
||||
"""单次处理:扫描所有门店的 pending 队列并执行。"""
|
||||
site_ids = self._get_pending_site_ids()
|
||||
|
||||
for site_id in site_ids:
|
||||
if self.has_running(site_id):
|
||||
continue
|
||||
|
||||
task = self.dequeue(site_id)
|
||||
if task is None:
|
||||
continue
|
||||
|
||||
config = TaskConfigSchema(**task.config)
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
||||
logger.info(
|
||||
"process_loop 自动执行 [%s] queue_id=%s site_id=%s",
|
||||
execution_id, task.id, site_id,
|
||||
)
|
||||
|
||||
# 异步启动执行(不阻塞循环)
|
||||
asyncio.create_task(
|
||||
self._execute_and_update(
|
||||
executor, config, execution_id, task.id, site_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def _execute_and_update(
|
||||
self,
|
||||
executor: Any,
|
||||
config: TaskConfigSchema,
|
||||
execution_id: str,
|
||||
queue_id: str,
|
||||
site_id: int,
|
||||
) -> None:
|
||||
"""执行任务并更新队列状态。"""
|
||||
try:
|
||||
await executor.execute(
|
||||
config=config,
|
||||
execution_id=execution_id,
|
||||
queue_id=queue_id,
|
||||
site_id=site_id,
|
||||
)
|
||||
# 执行完成后根据 executor 的结果更新 task_queue 状态
|
||||
self._update_queue_status_from_log(queue_id)
|
||||
except Exception:
|
||||
logger.exception("队列任务执行异常 [%s]", queue_id)
|
||||
self._mark_failed(queue_id, "执行过程中发生未捕获异常")
|
||||
|
||||
def _get_pending_site_ids(self) -> list[int]:
|
||||
"""获取所有有 pending 任务的 site_id 列表。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT site_id FROM task_queue
|
||||
WHERE status = 'pending'
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def _update_queue_status_from_log(self, queue_id: str) -> None:
|
||||
"""从 task_execution_log 读取执行结果,同步到 task_queue 记录。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT status, finished_at, exit_code, error_log
|
||||
FROM task_execution_log
|
||||
WHERE queue_id = %s
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_queue
|
||||
SET status = %s, finished_at = %s,
|
||||
exit_code = %s, error_message = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(row[0], row[1], row[2], row[3], queue_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _mark_failed(self, queue_id: str, error_message: str) -> None:
|
||||
"""将队列任务标记为 failed。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_queue
|
||||
SET status = 'failed', finished_at = NOW(),
|
||||
error_message = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(error_message, queue_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self) -> None:
|
||||
"""启动后台处理循环(在 FastAPI lifespan 中调用)。"""
|
||||
if self._loop_task is None or self._loop_task.done():
|
||||
self._loop_task = asyncio.create_task(self.process_loop())
|
||||
logger.info("TaskQueue 后台循环已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止后台处理循环。"""
|
||||
self._running = False
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
self._loop_task.cancel()
|
||||
try:
|
||||
await self._loop_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._loop_task = None
|
||||
logger.info("TaskQueue 后台循环已停止")
|
||||
|
||||
|
||||
# 全局单例
|
||||
task_queue = TaskQueue()
|
||||
221
apps/backend/app/services/task_registry.py
Normal file
221
apps/backend/app/services/task_registry.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""静态任务注册表
|
||||
|
||||
从 ETL orchestration/task_registry.py 提取的任务元数据硬编码副本。
|
||||
后端不直接导入 ETL 代码,避免引入重量级依赖链。
|
||||
|
||||
业务域分组逻辑:按任务代码前缀 / 目标表语义归类,与 GUI 保持一致。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskDefinition:
|
||||
"""单个 ETL 任务的元数据"""
|
||||
|
||||
code: str
|
||||
name: str
|
||||
description: str
|
||||
domain: str # 业务域:会员 / 结算 / 助教 / 商品 / 台桌 / 团购 / 库存 / 财务 / 指数 / 工具
|
||||
layer: str # ODS / DWD / DWS / INDEX / UTILITY
|
||||
requires_window: bool = True
|
||||
is_ods: bool = False
|
||||
is_dimension: bool = False
|
||||
default_enabled: bool = True
|
||||
is_common: bool = True # 常用任务标记,False 表示工具类/手动类任务
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DwdTableDefinition:
|
||||
"""DWD 表元数据"""
|
||||
|
||||
table_name: str # 完整表名(含 schema)
|
||||
display_name: str
|
||||
domain: str
|
||||
ods_source: str # 对应的 ODS 源表
|
||||
is_dimension: bool = False
|
||||
|
||||
|
||||
# ── ODS 任务定义 ──────────────────────────────────────────────
|
||||
|
||||
ODS_TASKS: list[TaskDefinition] = [
|
||||
TaskDefinition("ODS_ASSISTANT_ACCOUNT", "助教账号", "抽取助教账号主数据", "助教", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_ASSISTANT_LEDGER", "助教服务记录", "抽取助教服务流水", "助教", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_ASSISTANT_ABOLISH", "助教取消记录", "抽取助教取消/作废记录", "助教", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_SETTLEMENT_RECORDS", "结算记录", "抽取订单结算记录", "结算", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_SETTLEMENT_TICKET", "结账小票", "抽取结账小票明细", "结算", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_TABLE_USE", "台费流水", "抽取台费使用流水", "台桌", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_TABLE_FEE_DISCOUNT", "台费折扣", "抽取台费折扣记录", "台桌", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_TABLES", "台桌主数据", "抽取门店台桌信息", "台桌", "ODS", is_ods=True, requires_window=False),
|
||||
TaskDefinition("ODS_PAYMENT", "支付流水", "抽取支付交易记录", "结算", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_REFUND", "退款流水", "抽取退款交易记录", "结算", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_PLATFORM_COUPON", "平台券核销", "抽取平台优惠券核销记录", "团购", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_MEMBER", "会员主数据", "抽取会员档案", "会员", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_MEMBER_CARD", "会员储值卡", "抽取会员储值卡信息", "会员", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_MEMBER_BALANCE", "会员余额变动", "抽取会员余额变动记录", "会员", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_RECHARGE_SETTLE", "充值结算", "抽取充值结算记录", "会员", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_GROUP_PACKAGE", "团购套餐", "抽取团购套餐定义", "团购", "ODS", is_ods=True, requires_window=False),
|
||||
TaskDefinition("ODS_GROUP_BUY_REDEMPTION", "团购核销", "抽取团购核销记录", "团购", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_INVENTORY_STOCK", "库存快照", "抽取商品库存汇总", "库存", "ODS", is_ods=True, requires_window=False),
|
||||
TaskDefinition("ODS_INVENTORY_CHANGE", "库存变动", "抽取库存出入库记录", "库存", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_GOODS_CATEGORY", "商品分类", "抽取商品分类树", "商品", "ODS", is_ods=True, requires_window=False),
|
||||
TaskDefinition("ODS_STORE_GOODS", "门店商品", "抽取门店商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
|
||||
TaskDefinition("ODS_STORE_GOODS_SALES", "商品销售", "抽取门店商品销售记录", "商品", "ODS", is_ods=True),
|
||||
TaskDefinition("ODS_TENANT_GOODS", "租户商品", "抽取租户级商品主数据", "商品", "ODS", is_ods=True, requires_window=False),
|
||||
]
|
||||
|
||||
# ── DWD 任务定义 ──────────────────────────────────────────────
|
||||
|
||||
DWD_TASKS: list[TaskDefinition] = [
|
||||
TaskDefinition("DWD_LOAD_FROM_ODS", "DWD 装载", "从 ODS 装载至 DWD(维度 SCD2 + 事实增量)", "通用", "DWD", requires_window=False),
|
||||
TaskDefinition("DWD_QUALITY_CHECK", "DWD 质量检查", "对 DWD 层数据执行质量校验", "通用", "DWD", requires_window=False, is_common=False),
|
||||
]
|
||||
|
||||
# ── DWS 任务定义 ──────────────────────────────────────────────
|
||||
|
||||
DWS_TASKS: list[TaskDefinition] = [
|
||||
TaskDefinition("DWS_BUILD_ORDER_SUMMARY", "订单汇总构建", "构建订单汇总宽表", "结算", "DWS"),
|
||||
TaskDefinition("DWS_ASSISTANT_DAILY", "助教日报", "汇总助教每日业绩", "助教", "DWS"),
|
||||
TaskDefinition("DWS_ASSISTANT_MONTHLY", "助教月报", "汇总助教月度业绩", "助教", "DWS"),
|
||||
TaskDefinition("DWS_ASSISTANT_CUSTOMER", "助教客户分析", "汇总助教-客户关系", "助教", "DWS"),
|
||||
TaskDefinition("DWS_ASSISTANT_SALARY", "助教工资计算", "计算助教工资", "助教", "DWS"),
|
||||
TaskDefinition("DWS_ASSISTANT_FINANCE", "助教财务汇总", "汇总助教财务数据", "助教", "DWS"),
|
||||
TaskDefinition("DWS_MEMBER_CONSUMPTION", "会员消费分析", "汇总会员消费数据", "会员", "DWS"),
|
||||
TaskDefinition("DWS_MEMBER_VISIT", "会员到店分析", "汇总会员到店频次", "会员", "DWS"),
|
||||
TaskDefinition("DWS_FINANCE_DAILY", "财务日报", "汇总每日财务数据", "财务", "DWS"),
|
||||
TaskDefinition("DWS_FINANCE_RECHARGE", "充值汇总", "汇总充值数据", "财务", "DWS"),
|
||||
TaskDefinition("DWS_FINANCE_INCOME_STRUCTURE", "收入结构", "分析收入结构", "财务", "DWS"),
|
||||
TaskDefinition("DWS_FINANCE_DISCOUNT_DETAIL", "折扣明细", "汇总折扣明细", "财务", "DWS"),
|
||||
# CHANGE [2026-02-19] intent: 同步 ETL 侧合并——原 DWS_RETENTION_CLEANUP / DWS_MV_REFRESH_* 已合并为 DWS_MAINTENANCE
|
||||
TaskDefinition("DWS_MAINTENANCE", "DWS 维护", "刷新物化视图 + 清理过期留存数据", "通用", "DWS", requires_window=False, is_common=False),
|
||||
]
|
||||
|
||||
# ── INDEX 任务定义 ────────────────────────────────────────────
|
||||
|
||||
INDEX_TASKS: list[TaskDefinition] = [
|
||||
TaskDefinition("DWS_WINBACK_INDEX", "回流指数 (WBI)", "计算会员回流指数", "指数", "INDEX"),
|
||||
TaskDefinition("DWS_NEWCONV_INDEX", "新客转化指数 (NCI)", "计算新客转化指数", "指数", "INDEX"),
|
||||
TaskDefinition("DWS_ML_MANUAL_IMPORT", "手动导入 (ML)", "手动导入机器学习数据", "指数", "INDEX", requires_window=False, is_common=False),
|
||||
TaskDefinition("DWS_RELATION_INDEX", "关系指数 (RS)", "计算助教-客户关系指数", "指数", "INDEX"),
|
||||
]
|
||||
|
||||
# ── 工具类任务定义 ────────────────────────────────────────────
|
||||
|
||||
UTILITY_TASKS: list[TaskDefinition] = [
|
||||
TaskDefinition("MANUAL_INGEST", "手动导入", "从本地 JSON 文件手动导入数据", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
TaskDefinition("INIT_ODS_SCHEMA", "初始化 ODS Schema", "创建 ODS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
TaskDefinition("INIT_DWD_SCHEMA", "初始化 DWD Schema", "创建 DWD 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
TaskDefinition("INIT_DWS_SCHEMA", "初始化 DWS Schema", "创建 DWS 层表结构", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
TaskDefinition("ODS_JSON_ARCHIVE", "ODS JSON 归档", "归档 ODS 原始 JSON 文件", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
TaskDefinition("CHECK_CUTOFF", "游标检查", "检查各任务数据游标截止点", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
TaskDefinition("SEED_DWS_CONFIG", "DWS 配置种子", "初始化 DWS 配置数据", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
TaskDefinition("DATA_INTEGRITY_CHECK", "数据完整性校验", "校验跨层数据完整性", "工具", "UTILITY", requires_window=False, is_common=False),
|
||||
]
|
||||
|
||||
# ── 全量任务列表 ──────────────────────────────────────────────
|
||||
|
||||
ALL_TASKS: list[TaskDefinition] = ODS_TASKS + DWD_TASKS + DWS_TASKS + INDEX_TASKS + UTILITY_TASKS
|
||||
|
||||
# 按 code 索引,便于快速查找
|
||||
_TASK_BY_CODE: dict[str, TaskDefinition] = {t.code: t for t in ALL_TASKS}
|
||||
|
||||
|
||||
def get_all_tasks() -> list[TaskDefinition]:
|
||||
return ALL_TASKS
|
||||
|
||||
|
||||
def get_task_by_code(code: str) -> TaskDefinition | None:
|
||||
return _TASK_BY_CODE.get(code.upper())
|
||||
|
||||
|
||||
def get_tasks_grouped_by_domain() -> dict[str, list[TaskDefinition]]:
|
||||
"""按业务域分组返回任务列表"""
|
||||
groups: dict[str, list[TaskDefinition]] = {}
|
||||
for t in ALL_TASKS:
|
||||
groups.setdefault(t.domain, []).append(t)
|
||||
return groups
|
||||
|
||||
|
||||
def get_tasks_by_layer(layer: str) -> list[TaskDefinition]:
|
||||
"""获取指定层的所有任务"""
|
||||
layer_upper = layer.upper()
|
||||
return [t for t in ALL_TASKS if t.layer == layer_upper]
|
||||
|
||||
|
||||
# ── Flow → 层映射 ────────────────────────────────────────────
|
||||
# 每种 Flow 包含的层,用于前端按 Flow 过滤可选任务
|
||||
|
||||
FLOW_LAYER_MAP: dict[str, list[str]] = {
|
||||
"api_ods": ["ODS"],
|
||||
"api_ods_dwd": ["ODS", "DWD"],
|
||||
"api_full": ["ODS", "DWD", "DWS", "INDEX"],
|
||||
"ods_dwd": ["DWD"],
|
||||
"dwd_dws": ["DWS"],
|
||||
"dwd_dws_index": ["DWS", "INDEX"],
|
||||
"dwd_index": ["INDEX"],
|
||||
}
|
||||
|
||||
|
||||
def get_compatible_tasks(flow_id: str) -> list[TaskDefinition]:
|
||||
"""根据 Flow 包含的层,返回兼容的任务列表"""
|
||||
layers = FLOW_LAYER_MAP.get(flow_id, [])
|
||||
return [t for t in ALL_TASKS if t.layer in layers]
|
||||
|
||||
|
||||
# ── DWD 表定义 ────────────────────────────────────────────────
|
||||
|
||||
DWD_TABLES: list[DwdTableDefinition] = [
|
||||
# 维度表
|
||||
DwdTableDefinition("dwd.dim_site", "门店维度", "台桌", "ods.table_fee_transactions", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_site_ex", "门店维度(扩展)", "台桌", "ods.table_fee_transactions", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_table", "台桌维度", "台桌", "ods.site_tables_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_table_ex", "台桌维度(扩展)", "台桌", "ods.site_tables_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_assistant", "助教维度", "助教", "ods.assistant_accounts_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_assistant_ex", "助教维度(扩展)", "助教", "ods.assistant_accounts_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_member", "会员维度", "会员", "ods.member_profiles", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_member_ex", "会员维度(扩展)", "会员", "ods.member_profiles", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_member_card_account", "会员储值卡维度", "会员", "ods.member_stored_value_cards", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_member_card_account_ex", "会员储值卡维度(扩展)", "会员", "ods.member_stored_value_cards", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_tenant_goods", "租户商品维度", "商品", "ods.tenant_goods_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_tenant_goods_ex", "租户商品维度(扩展)", "商品", "ods.tenant_goods_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_store_goods", "门店商品维度", "商品", "ods.store_goods_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_store_goods_ex", "门店商品维度(扩展)", "商品", "ods.store_goods_master", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_goods_category", "商品分类维度", "商品", "ods.stock_goods_category_tree", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_groupbuy_package", "团购套餐维度", "团购", "ods.group_buy_packages", is_dimension=True),
|
||||
DwdTableDefinition("dwd.dim_groupbuy_package_ex", "团购套餐维度(扩展)", "团购", "ods.group_buy_packages", is_dimension=True),
|
||||
# 事实表
|
||||
DwdTableDefinition("dwd.dwd_settlement_head", "结算主表", "结算", "ods.settlement_records"),
|
||||
DwdTableDefinition("dwd.dwd_settlement_head_ex", "结算主表(扩展)", "结算", "ods.settlement_records"),
|
||||
DwdTableDefinition("dwd.dwd_table_fee_log", "台费流水", "台桌", "ods.table_fee_transactions"),
|
||||
DwdTableDefinition("dwd.dwd_table_fee_log_ex", "台费流水(扩展)", "台桌", "ods.table_fee_transactions"),
|
||||
DwdTableDefinition("dwd.dwd_table_fee_adjust", "台费折扣", "台桌", "ods.table_fee_discount_records"),
|
||||
DwdTableDefinition("dwd.dwd_table_fee_adjust_ex", "台费折扣(扩展)", "台桌", "ods.table_fee_discount_records"),
|
||||
DwdTableDefinition("dwd.dwd_store_goods_sale", "商品销售", "商品", "ods.store_goods_sales_records"),
|
||||
DwdTableDefinition("dwd.dwd_store_goods_sale_ex", "商品销售(扩展)", "商品", "ods.store_goods_sales_records"),
|
||||
DwdTableDefinition("dwd.dwd_assistant_service_log", "助教服务流水", "助教", "ods.assistant_service_records"),
|
||||
DwdTableDefinition("dwd.dwd_assistant_service_log_ex", "助教服务流水(扩展)", "助教", "ods.assistant_service_records"),
|
||||
DwdTableDefinition("dwd.dwd_assistant_trash_event", "助教取消事件", "助教", "ods.assistant_cancellation_records"),
|
||||
DwdTableDefinition("dwd.dwd_assistant_trash_event_ex", "助教取消事件(扩展)", "助教", "ods.assistant_cancellation_records"),
|
||||
DwdTableDefinition("dwd.dwd_member_balance_change", "会员余额变动", "会员", "ods.member_balance_changes"),
|
||||
DwdTableDefinition("dwd.dwd_member_balance_change_ex", "会员余额变动(扩展)", "会员", "ods.member_balance_changes"),
|
||||
DwdTableDefinition("dwd.dwd_groupbuy_redemption", "团购核销", "团购", "ods.group_buy_redemption_records"),
|
||||
DwdTableDefinition("dwd.dwd_groupbuy_redemption_ex", "团购核销(扩展)", "团购", "ods.group_buy_redemption_records"),
|
||||
DwdTableDefinition("dwd.dwd_platform_coupon_redemption", "平台券核销", "团购", "ods.platform_coupon_redemption_records"),
|
||||
DwdTableDefinition("dwd.dwd_platform_coupon_redemption_ex", "平台券核销(扩展)", "团购", "ods.platform_coupon_redemption_records"),
|
||||
DwdTableDefinition("dwd.dwd_recharge_order", "充值订单", "会员", "ods.recharge_settlements"),
|
||||
DwdTableDefinition("dwd.dwd_recharge_order_ex", "充值订单(扩展)", "会员", "ods.recharge_settlements"),
|
||||
DwdTableDefinition("dwd.dwd_payment", "支付流水", "结算", "ods.payment_transactions"),
|
||||
DwdTableDefinition("dwd.dwd_refund", "退款流水", "结算", "ods.refund_transactions"),
|
||||
DwdTableDefinition("dwd.dwd_refund_ex", "退款流水(扩展)", "结算", "ods.refund_transactions"),
|
||||
]
|
||||
|
||||
|
||||
def get_dwd_tables_grouped_by_domain() -> dict[str, list[DwdTableDefinition]]:
|
||||
"""按业务域分组返回 DWD 表定义"""
|
||||
groups: dict[str, list[DwdTableDefinition]] = {}
|
||||
for t in DWD_TABLES:
|
||||
groups.setdefault(t.domain, []).append(t)
|
||||
return groups
|
||||
68
apps/backend/app/ws/logs.py
Normal file
68
apps/backend/app/ws/logs.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
提供 WS /ws/logs/{execution_id} 端点,实时推送 ETL 任务执行日志。
|
||||
客户端连接后,先发送已有的历史日志行,再实时推送新日志,
|
||||
直到执行结束(收到 None 哨兵)或客户端断开。
|
||||
|
||||
设计要点:
|
||||
- 利用 TaskExecutor 已有的 subscribe/unsubscribe 机制
|
||||
- 连接时先回放内存缓冲区中的历史日志,避免丢失已产生的行
|
||||
- 通过 asyncio.Queue 接收实时日志,None 表示执行结束
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from ..services.task_executor import task_executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ws_router = APIRouter()
|
||||
|
||||
|
||||
@ws_router.websocket("/ws/logs/{execution_id}")
|
||||
async def ws_logs(websocket: WebSocket, execution_id: str) -> None:
|
||||
"""实时推送指定 execution_id 的任务执行日志。
|
||||
|
||||
流程:
|
||||
1. 接受 WebSocket 连接
|
||||
2. 回放内存缓冲区中已有的日志行
|
||||
3. 订阅 TaskExecutor,持续推送新日志
|
||||
4. 收到 None(执行结束)或客户端断开时关闭
|
||||
"""
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket 连接已建立: execution_id=%s", execution_id)
|
||||
|
||||
# 订阅日志流
|
||||
queue = task_executor.subscribe(execution_id)
|
||||
|
||||
try:
|
||||
# 回放已有的历史日志行
|
||||
for line in task_executor.get_logs(execution_id):
|
||||
await websocket.send_text(line)
|
||||
|
||||
# 如果任务已经不在运行且没有订阅者队列中的数据,
|
||||
# 仍然保持连接等待——可能是任务刚结束但 queue 里还有未消费的消息
|
||||
while True:
|
||||
msg = await queue.get()
|
||||
if msg is None:
|
||||
# 执行结束哨兵
|
||||
break
|
||||
await websocket.send_text(msg)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket 客户端断开: execution_id=%s", execution_id)
|
||||
except Exception:
|
||||
logger.exception("WebSocket 异常: execution_id=%s", execution_id)
|
||||
finally:
|
||||
task_executor.unsubscribe(execution_id, queue)
|
||||
# 安全关闭连接(客户端可能已断开,忽略错误)
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("WebSocket 连接已清理: execution_id=%s", execution_id)
|
||||
24
apps/backend/doc/开放平台证书.cer
Normal file
24
apps/backend/doc/开放平台证书.cer
Normal file
@@ -0,0 +1,24 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIID9DCCAtygAwIBAgIUaB2siLoT1Nb+u9K0aL18avodFRYwDQYJKoZIhvcNAQEL
|
||||
BQAwbTELMAkGA1UEBhMCQ04xEjAQBgNVBAgMCUd1YW5nRG9uZzERMA8GA1UEBwwI
|
||||
U2hlblpoZW4xEDAOBgNVBAoMB1RlbmNlbnQxEDAOBgNVBAsMB1RlbmNlbnQxEzAR
|
||||
BgNVBAMMCnRzbTAwMDAwMDYwHhcNMjUxMTA0MTA1NzQ0WhcNMzUxMTAyMTA1NzQ0
|
||||
WjCBmDEeMBwGCSqGSIb3DQEJARYPd2VpeGlubXBAcXEuY29tMRswGQYDVQQDDBJ3
|
||||
eDdjMDc3OTNkODI3MzI5MjExFTATBgNVBAoMDFRlbmNlbnQgSW5jLjEOMAwGA1UE
|
||||
CwwFV3hnTXAxCzAJBgNVBAYTAkNOMRIwEAYDVQQIDAlHdWFuZ0RvbmcxETAPBgNV
|
||||
BAcMCFNoZW5aaGVuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA8Ig7
|
||||
m6qvcw+0ncNZBUg9Xfk+6WTKSj7cCwcuD66JgYPQ4pVreqCTzusc//E+EGapXyR3
|
||||
fJZl+AL2sVWRIzLa1+dt9+45YBGm3lSbvarlbsdVMrYwFRW+d/vZgxcXfcS1VmVP
|
||||
NPEw7DaAkWlvUkFmUGdNzH+QLsuXTZKWEtHtmSo7us9HTJYV/aEH2uJpsHE4A3fP
|
||||
Vbc/wy1EwLt48o0ZDzpiPZiqn+nSrSXEqBPOEwzICxnHCJRpEH01RxBJGdTSouzF
|
||||
pfncMeEpGfFGH8GW8IQYzzvrvYDbprsnVMfHNAo0MMGK+iyWyAOFMqrkvSb962x7
|
||||
KoXLDH9OfmFRNse9WQIDAQABo2AwXjAdBgNVHQ4EFgQUccFm3WWmKA/m+uXeW8Xe
|
||||
jfNsX+owHwYDVR0jBBgwFoAURM1183H4z2eJGmP5z/zWI7tjRZAwDAYDVR0TAQH/
|
||||
BAIwADAOBgNVHQ8BAf8EBAMCBsAwDQYJKoZIhvcNAQELBQADggEBAIXfGfQARyxC
|
||||
Ptut+rOccdq8TawasZT7o7TnAGCCTPsAWCd5RTAXse65mSGM6oxjQsppZxtYz4Kx
|
||||
TLySl91Vok2nMH1jBoWPx9WoFyU6zCkmOkq7zWvEU23FR1Quq0QB0fmHrVMNQqxA
|
||||
LKkUuUFTa1wmVuYaKtcz5LAaj+GmgrY3kTIWg81ybPF/Hibkz0zWh54SLBc64Ha6
|
||||
zfNXDffq3vVVo04DKZW8Erd9nZL0F/w2u6+MpTl5CrAYzSZyDcNiIGbSrYpYYRt9
|
||||
JagGAn/ZZD93SnOiMcRCsNfNq4LisSf6AUMSA3F9Rw8iuxas5lDBf073pEy2vWjG
|
||||
VSp+Vio/oEY=
|
||||
-----END CERTIFICATE-----
|
||||
38
apps/backend/doc/微信开放平台 小程序 配置.md
Normal file
38
apps/backend/doc/微信开放平台 小程序 配置.md
Normal file
@@ -0,0 +1,38 @@
|
||||
## 开放平台证书
|
||||
开放平台证书编号:06e9682660ce742bb45ef278ae941af0
|
||||
证书已下载。
|
||||
|
||||
编号 密钥类型 密钥明文
|
||||
901b24f9af7b1421b80ebd5df9094141 对称密钥 D1fK6Zib6UOG10bM4WWhjsbMImCNXz7Mxq/0oRREGmw=
|
||||
59347d0e3cd661af9e90f7def5b6ca00 非对称密钥 -----BEGIN PUBLIC KEY----- MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzMrQ8iGGop0JEvcS/dZu sVUgjHeHqtds506Ftq/0WylTGKG9yByrY6eSKHttkBs/IyqG6JU9TtmskvVBam9B BmwOVyHkXATlXwEIkhyShu459p0dQzKwFaiygtj/fvxzWvurXVf1UcbIYVP1u7T0 E6yqatUkhmeaIyBsuw7vp7yYxpJoxsR2t70cDOTnfmHzl47FqSmq8xQRl/Cyw4FJ +1NS3i3cQBYkxjxGJ8Q3lhd5xd8IjwzDf+v/UgEhmoduZDAr/HkOwcrh3ihCCkLC DsfqV6qJrtqHhq5PHog3TqaaF6I9vk04FQTHn07XzIMciBDauYdD/Mq3B+ZBOKX4 gQIDAQAB -----END PUBLIC KEY-----
|
||||
|
||||
|
||||
|
||||
## 安全管理配置
|
||||
https://developers.weixin.qq.com/miniprogram/dev/framework/share_signature.html
|
||||
|
||||
## 消息推送配置
|
||||
(暂未处理)
|
||||
|
||||
填写的URL需要正确响应微信发送的Token验证,填写说明请阅读消息推送服务器配置指南。https://developers.weixin.qq.com/miniprogram/dev/framework/server-ability/message-push.html
|
||||
|
||||
|
||||
|
||||
URL(服务器地址):
|
||||
https://push.langlangzhuoqiu.cn
|
||||
|
||||
Token(令牌):
|
||||
nCmGmzINYfaqf5jDGKdkDO2AJ9C0VWNe
|
||||
|
||||
EncodingAESKey(消息加密密钥):
|
||||
2SZwWe90vG121o1l7RRMbtGt8GNvA1Juf727a3m7nZX
|
||||
|
||||
消息加密方式:
|
||||
安全模式 (消息包为纯密文,需要加密和解密。)
|
||||
|
||||
数据格式:
|
||||
JSON
|
||||
|
||||
|
||||
## 业务域名
|
||||
https://api.langlangzhuoqiu.cn
|
||||
@@ -1,10 +1,34 @@
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-15 | Prompt: 让 FastAPI 成功启动 | 补全运行依赖(fastapi/uvicorn/psycopg2-binary/python-dotenv),使后端可通过 uv run uvicorn 启动
|
||||
# - 风险:依赖版本变更可能影响其他 workspace 成员;验证:uv sync --all-packages && uv run uvicorn app.main:app
|
||||
|
||||
[project]
|
||||
name = "zqyy-backend"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10"
|
||||
# CHANGE 2026-02-15 | intent: 补全后端运行依赖,原先仅声明 neozqyy-shared 导致 uvicorn/fastapi 缺失无法启动
|
||||
# assumptions: 版本下限与 tech.md 记录的核心依赖一致;uvicorn[standard] 包含 uvloop/httptools 等性能依赖
|
||||
dependencies = [
|
||||
"neozqyy-shared",
|
||||
"fastapi>=0.115",
|
||||
"uvicorn[standard]>=0.34",
|
||||
"psycopg2-binary>=2.9",
|
||||
"python-dotenv>=1.0",
|
||||
"python-jose[cryptography]>=3.3",
|
||||
"bcrypt>=4.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
neozqyy-shared = { workspace = true }
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"hypothesis>=6.100",
|
||||
"httpx>=0.27",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
|
||||
62
apps/backend/tests/test_auth_dependencies.py
Normal file
62
apps/backend/tests/test_auth_dependencies.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
FastAPI 依赖注入 get_current_user 单元测试。
|
||||
|
||||
通过 FastAPI TestClient 验证 Authorization header 处理。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.auth.jwt import create_access_token, create_refresh_token
|
||||
|
||||
# 构造一个最小 FastAPI 应用用于测试依赖注入
|
||||
_test_app = FastAPI()
|
||||
|
||||
|
||||
@_test_app.get("/protected")
|
||||
async def protected_route(user: CurrentUser = Depends(get_current_user)):
|
||||
return {"user_id": user.user_id, "site_id": user.site_id}
|
||||
|
||||
|
||||
client = TestClient(_test_app)
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
def test_valid_access_token(self):
|
||||
token = create_access_token(user_id=10, site_id=100)
|
||||
resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["user_id"] == 10
|
||||
assert data["site_id"] == 100
|
||||
|
||||
def test_missing_auth_header_returns_401(self):
|
||||
"""缺少 Authorization header 时返回 401。"""
|
||||
resp = client.get("/protected")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
def test_invalid_token_returns_401(self):
|
||||
resp = client.get(
|
||||
"/protected", headers={"Authorization": "Bearer invalid.token.here"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_token_rejected(self):
|
||||
"""refresh 令牌不能用于访问受保护端点。"""
|
||||
token = create_refresh_token(user_id=1, site_id=1)
|
||||
resp = client.get("/protected", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_current_user_is_frozen_dataclass(self):
|
||||
"""CurrentUser 是不可变的。"""
|
||||
user = CurrentUser(user_id=1, site_id=2)
|
||||
assert user.user_id == 1
|
||||
assert user.site_id == 2
|
||||
with pytest.raises(AttributeError):
|
||||
user.user_id = 99 # type: ignore[misc]
|
||||
147
apps/backend/tests/test_auth_jwt.py
Normal file
147
apps/backend/tests/test_auth_jwt.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
JWT 认证模块单元测试。
|
||||
|
||||
覆盖:令牌生成、验证、过期、类型校验、密码哈希、依赖注入。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from jose import jwt as jose_jwt
|
||||
|
||||
# 测试前设置 JWT_SECRET_KEY,避免空密钥
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
create_token_pair,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from app import config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 密码哈希
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPasswordHashing:
|
||||
def test_hash_and_verify(self):
|
||||
raw = "my_secure_password"
|
||||
hashed = hash_password(raw)
|
||||
assert verify_password(raw, hashed)
|
||||
|
||||
def test_wrong_password_rejected(self):
|
||||
hashed = hash_password("correct")
|
||||
assert not verify_password("wrong", hashed)
|
||||
|
||||
def test_hash_is_not_plaintext(self):
|
||||
raw = "plaintext123"
|
||||
hashed = hash_password(raw)
|
||||
assert hashed != raw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌生成与解码
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenCreation:
|
||||
def test_access_token_contains_expected_fields(self):
|
||||
token = create_access_token(user_id=1, site_id=100)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == "1"
|
||||
assert payload["site_id"] == 100
|
||||
assert payload["type"] == "access"
|
||||
assert "exp" in payload
|
||||
|
||||
def test_refresh_token_contains_expected_fields(self):
|
||||
token = create_refresh_token(user_id=2, site_id=200)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == "2"
|
||||
assert payload["site_id"] == 200
|
||||
assert payload["type"] == "refresh"
|
||||
assert "exp" in payload
|
||||
|
||||
def test_token_pair_returns_both_tokens(self):
|
||||
pair = create_token_pair(user_id=3, site_id=300)
|
||||
assert "access_token" in pair
|
||||
assert "refresh_token" in pair
|
||||
assert pair["token_type"] == "bearer"
|
||||
|
||||
# 验证两个令牌类型不同
|
||||
access_payload = decode_token(pair["access_token"])
|
||||
refresh_payload = decode_token(pair["refresh_token"])
|
||||
assert access_payload["type"] == "access"
|
||||
assert refresh_payload["type"] == "refresh"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌类型校验
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenTypeValidation:
|
||||
def test_decode_access_token_rejects_refresh(self):
|
||||
"""access 解码器拒绝 refresh 令牌。"""
|
||||
token = create_refresh_token(user_id=1, site_id=1)
|
||||
with pytest.raises(Exception):
|
||||
decode_access_token(token)
|
||||
|
||||
def test_decode_refresh_token_rejects_access(self):
|
||||
"""refresh 解码器拒绝 access 令牌。"""
|
||||
token = create_access_token(user_id=1, site_id=1)
|
||||
with pytest.raises(Exception):
|
||||
decode_refresh_token(token)
|
||||
|
||||
def test_decode_access_token_accepts_access(self):
|
||||
token = create_access_token(user_id=5, site_id=50)
|
||||
payload = decode_access_token(token)
|
||||
assert payload["sub"] == "5"
|
||||
assert payload["site_id"] == 50
|
||||
|
||||
def test_decode_refresh_token_accepts_refresh(self):
|
||||
token = create_refresh_token(user_id=6, site_id=60)
|
||||
payload = decode_refresh_token(token)
|
||||
assert payload["sub"] == "6"
|
||||
assert payload["site_id"] == 60
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 令牌过期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenExpiry:
|
||||
def test_expired_token_rejected(self):
|
||||
"""手动构造已过期令牌,验证解码失败。"""
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"site_id": 1,
|
||||
"type": "access",
|
||||
"exp": int(time.time()) - 10, # 10 秒前过期
|
||||
}
|
||||
token = jose_jwt.encode(
|
||||
payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
decode_token(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 无效令牌
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInvalidToken:
|
||||
def test_garbage_token_rejected(self):
|
||||
with pytest.raises(Exception):
|
||||
decode_token("not.a.valid.jwt")
|
||||
|
||||
def test_wrong_secret_rejected(self):
|
||||
"""用不同密钥签发的令牌应被拒绝。"""
|
||||
payload = {"sub": "1", "site_id": 1, "type": "access", "exp": int(time.time()) + 3600}
|
||||
token = jose_jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
with pytest.raises(Exception):
|
||||
decode_token(token)
|
||||
137
apps/backend/tests/test_auth_properties.py
Normal file
137
apps/backend/tests/test_auth_properties.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
认证模块属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证认证系统的通用正确性属性:
|
||||
- Property 2: 无效凭据始终被拒绝
|
||||
- Property 3: 有效 JWT 令牌授权访问
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-property-tests")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.auth.jwt import create_access_token
|
||||
from app.main import app
|
||||
from app.routers.auth import router
|
||||
|
||||
# 确保路由已挂载
|
||||
if router not in [r for r in app.routes]:
|
||||
app.include_router(router)
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 用户名策略:1~64 字符的可打印字符串(排除控制字符)
|
||||
_username_st = st.text(
|
||||
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
|
||||
min_size=1,
|
||||
max_size=64,
|
||||
)
|
||||
|
||||
# 密码策略:1~128 字符的可打印字符串
|
||||
_password_st = st.text(
|
||||
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
|
||||
min_size=1,
|
||||
max_size=128,
|
||||
)
|
||||
|
||||
# user_id 策略:正整数
|
||||
_user_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
# site_id 策略:正整数
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**63 - 1)
|
||||
|
||||
|
||||
def _mock_db_returning(row):
|
||||
"""构造 mock get_connection,cursor.fetchone() 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 2: 无效凭据始终被拒绝
|
||||
# **Validates: Requirements 1.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(username=_username_st, password=_password_st)
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_invalid_credentials_always_rejected(mock_get_conn, username, password):
|
||||
"""
|
||||
Property 2: 无效凭据始终被拒绝。
|
||||
|
||||
对于任意用户名/密码组合,当数据库中不存在该用户时(fetchone 返回 None),
|
||||
登录接口应始终返回 401 状态码。
|
||||
"""
|
||||
# mock 数据库返回 None — 用户不存在
|
||||
mock_get_conn.return_value = _mock_db_returning(None)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": username, "password": password},
|
||||
)
|
||||
assert resp.status_code == 401, (
|
||||
f"期望 401,实际 {resp.status_code},username={username!r}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 3: 有效 JWT 令牌授权访问
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""在同步上下文中执行异步协程,避免 DeprecationWarning。"""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(user_id=_user_id_st, site_id=_site_id_st)
|
||||
def test_valid_jwt_grants_access(user_id, site_id):
|
||||
"""
|
||||
Property 3: 有效 JWT 令牌授权访问。
|
||||
|
||||
对于任意 user_id 和 site_id,由系统签发的未过期 access_token
|
||||
应能被 get_current_user 依赖成功解析为 CurrentUser 对象,
|
||||
且解析出的 user_id 和 site_id 与签发时一致。
|
||||
"""
|
||||
# 生成有效的 access_token
|
||||
token = create_access_token(user_id=user_id, site_id=site_id)
|
||||
|
||||
# 直接调用依赖函数验证令牌解析
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
result = _run_async(get_current_user(credentials))
|
||||
|
||||
assert isinstance(result, CurrentUser)
|
||||
assert result.user_id == user_id, (
|
||||
f"user_id 不匹配:期望 {user_id},实际 {result.user_id}"
|
||||
)
|
||||
assert result.site_id == site_id, (
|
||||
f"site_id 不匹配:期望 {site_id},实际 {result.site_id}"
|
||||
)
|
||||
167
apps/backend/tests/test_auth_router.py
Normal file
167
apps/backend/tests/test_auth_router.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
认证路由单元测试。
|
||||
|
||||
覆盖:登录成功/失败、刷新令牌、账号禁用等场景。
|
||||
通过 mock 数据库连接避免依赖真实 PostgreSQL。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
hash_password,
|
||||
)
|
||||
from app.main import app
|
||||
from app.routers.auth import router
|
||||
|
||||
# 注册路由到 app(测试时确保路由已挂载)
|
||||
if router not in [r for r in app.routes]:
|
||||
app.include_router(router)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# 测试用固定数据
|
||||
_TEST_PASSWORD = "correct_password"
|
||||
_TEST_HASH = hash_password(_TEST_PASSWORD)
|
||||
_TEST_USER_ROW = (1, _TEST_HASH, 100, True) # id, password_hash, site_id, is_active
|
||||
_DISABLED_USER_ROW = (2, _TEST_HASH, 200, False)
|
||||
|
||||
|
||||
def _mock_db_returning(row):
|
||||
"""构造一个 mock get_connection,cursor.fetchone() 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/login
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLogin:
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_success(self, mock_get_conn):
|
||||
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "admin", "password": _TEST_PASSWORD},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
# 验证 access_token payload 包含正确的 user_id 和 site_id
|
||||
payload = decode_access_token(data["access_token"])
|
||||
assert payload["sub"] == "1"
|
||||
assert payload["site_id"] == 100
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_user_not_found(self, mock_get_conn):
|
||||
"""用户不存在时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(None)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "nonexistent", "password": "whatever"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "用户名或密码错误" in resp.json()["detail"]
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_wrong_password(self, mock_get_conn):
|
||||
"""密码错误时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "admin", "password": "wrong_password"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "用户名或密码错误" in resp.json()["detail"]
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_disabled_account(self, mock_get_conn):
|
||||
"""账号已禁用时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "disabled_user", "password": _TEST_PASSWORD},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "禁用" in resp.json()["detail"]
|
||||
|
||||
def test_login_missing_username(self):
|
||||
"""缺少 username 字段时返回 422。"""
|
||||
resp = client.post("/api/auth/login", json={"password": "test"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_login_empty_password(self):
|
||||
"""空密码时返回 422。"""
|
||||
resp = client.post(
|
||||
"/api/auth/login", json={"username": "admin", "password": ""}
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRefresh:
|
||||
def test_refresh_success(self):
|
||||
"""有效的 refresh_token 换取新的 access_token。"""
|
||||
refresh = create_refresh_token(user_id=5, site_id=50)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": refresh}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
# refresh_token 原样返回
|
||||
assert data["refresh_token"] == refresh
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
# 新 access_token 包含正确信息
|
||||
payload = decode_access_token(data["access_token"])
|
||||
assert payload["sub"] == "5"
|
||||
assert payload["site_id"] == 50
|
||||
|
||||
def test_refresh_with_invalid_token(self):
|
||||
"""无效令牌返回 401。"""
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": "garbage.token.here"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "无效的刷新令牌" in resp.json()["detail"]
|
||||
|
||||
def test_refresh_with_access_token_rejected(self):
|
||||
"""用 access_token 做刷新应被拒绝。"""
|
||||
from app.auth.jwt import create_access_token
|
||||
|
||||
access = create_access_token(user_id=1, site_id=1)
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": access}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_missing_token(self):
|
||||
"""缺少 refresh_token 字段时返回 422。"""
|
||||
resp = client.post("/api/auth/refresh", json={})
|
||||
assert resp.status_code == 422
|
||||
259
apps/backend/tests/test_cli_builder.py
Normal file
259
apps/backend/tests/test_cli_builder.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLIBuilder 单元测试
|
||||
|
||||
覆盖:7 种 Flow、3 种处理模式、时间窗口、store_id 自动注入、extra_args 等。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def builder() -> CLIBuilder:
|
||||
return CLIBuilder()
|
||||
|
||||
|
||||
ETL_PATH = "/fake/etl/project"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 基本命令结构
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBasicCommand:
|
||||
def test_minimal_command(self, builder: CLIBuilder):
|
||||
"""最小配置应生成 python -m cli.main --pipeline ... --processing-mode ..."""
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert cmd[:3] == ["python", "-m", "cli.main"]
|
||||
assert "--pipeline" in cmd
|
||||
assert "--processing-mode" in cmd
|
||||
|
||||
def test_custom_python_executable(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
cmd = builder.build_command(config, ETL_PATH, python_executable="python3")
|
||||
assert cmd[0] == "python3"
|
||||
|
||||
def test_tasks_joined_by_comma(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER", "ODS_PAYMENT", "ODS_REFUND"])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--tasks")
|
||||
assert cmd[idx + 1] == "ODS_MEMBER,ODS_PAYMENT,ODS_REFUND"
|
||||
|
||||
def test_empty_tasks_no_tasks_arg(self, builder: CLIBuilder):
|
||||
"""空任务列表不应生成 --tasks 参数"""
|
||||
config = TaskConfigSchema(tasks=[])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--tasks" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7 种 Flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlows:
|
||||
@pytest.mark.parametrize("flow_id", sorted(VALID_FLOWS))
|
||||
def test_all_flows_accepted(self, builder: CLIBuilder, flow_id: str):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], pipeline=flow_id)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--pipeline")
|
||||
assert cmd[idx + 1] == flow_id
|
||||
|
||||
def test_default_flow_is_api_ods_dwd(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--pipeline")
|
||||
assert cmd[idx + 1] == "api_ods_dwd"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3 种处理模式
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessingModes:
|
||||
@pytest.mark.parametrize("mode", sorted(VALID_PROCESSING_MODES))
|
||||
def test_all_modes_accepted(self, builder: CLIBuilder, mode: str):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], processing_mode=mode)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--processing-mode")
|
||||
assert cmd[idx + 1] == mode
|
||||
|
||||
def test_fetch_before_verify_only_in_verify_mode(self, builder: CLIBuilder):
|
||||
"""--fetch-before-verify 仅在 verify_only 模式下生效"""
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
processing_mode="verify_only",
|
||||
fetch_before_verify=True,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--fetch-before-verify" in cmd
|
||||
|
||||
def test_fetch_before_verify_ignored_in_increment_mode(self, builder: CLIBuilder):
|
||||
"""increment_only 模式下 fetch_before_verify=True 不应生成参数"""
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
processing_mode="increment_only",
|
||||
fetch_before_verify=True,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--fetch-before-verify" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 时间窗口
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTimeWindow:
|
||||
def test_lookback_mode_generates_lookback_args(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="lookback",
|
||||
lookback_hours=48,
|
||||
overlap_seconds=1200,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx_lb = cmd.index("--lookback-hours")
|
||||
assert cmd[idx_lb + 1] == "48"
|
||||
idx_ol = cmd.index("--overlap-seconds")
|
||||
assert cmd[idx_ol + 1] == "1200"
|
||||
# lookback 模式不应生成 --window-start / --window-end
|
||||
assert "--window-start" not in cmd
|
||||
assert "--window-end" not in cmd
|
||||
|
||||
def test_custom_mode_generates_window_args(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="custom",
|
||||
window_start="2026-01-01",
|
||||
window_end="2026-01-31",
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx_s = cmd.index("--window-start")
|
||||
assert cmd[idx_s + 1] == "2026-01-01"
|
||||
idx_e = cmd.index("--window-end")
|
||||
assert cmd[idx_e + 1] == "2026-01-31"
|
||||
# custom 模式不应生成 --lookback-hours / --overlap-seconds
|
||||
assert "--lookback-hours" not in cmd
|
||||
assert "--overlap-seconds" not in cmd
|
||||
|
||||
def test_window_split_with_days(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_split="day",
|
||||
window_split_days=10,
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--window-split")
|
||||
assert cmd[idx + 1] == "day"
|
||||
idx_d = cmd.index("--window-split-days")
|
||||
assert cmd[idx_d + 1] == "10"
|
||||
|
||||
def test_window_split_none_not_generated(self, builder: CLIBuilder):
|
||||
"""window_split='none' 不应生成 --window-split 参数"""
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], window_split="none")
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--window-split" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# store_id 自动注入
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStoreId:
|
||||
def test_store_id_injected(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=42)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--store-id")
|
||||
assert cmd[idx + 1] == "42"
|
||||
|
||||
def test_store_id_none_not_generated(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], store_id=None)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--store-id" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dry_run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDryRun:
|
||||
def test_dry_run_flag(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=True)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--dry-run" in cmd
|
||||
|
||||
def test_no_dry_run_flag(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"], dry_run=False)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--dry-run" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extra_args
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtraArgs:
|
||||
def test_supported_value_arg(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"pg_dsn": "postgresql://localhost/test"},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
idx = cmd.index("--pg-dsn")
|
||||
assert cmd[idx + 1] == "postgresql://localhost/test"
|
||||
|
||||
def test_supported_bool_arg(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"force_window_override": True},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--force-window-override" in cmd
|
||||
|
||||
def test_unsupported_arg_ignored(self, builder: CLIBuilder):
|
||||
"""不在 CLI_SUPPORTED_ARGS 中的键应被忽略"""
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"unknown_param": "value"},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--unknown-param" not in cmd
|
||||
|
||||
def test_none_value_ignored(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"pg_dsn": None},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--pg-dsn" not in cmd
|
||||
|
||||
def test_false_bool_arg_not_generated(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"force_window_override": False},
|
||||
)
|
||||
cmd = builder.build_command(config, ETL_PATH)
|
||||
assert "--force-window-override" not in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_command_string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildCommandString:
|
||||
def test_returns_string(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(tasks=["ODS_MEMBER"])
|
||||
result = builder.build_command_string(config, ETL_PATH)
|
||||
assert isinstance(result, str)
|
||||
assert "python -m cli.main" in result
|
||||
|
||||
def test_quotes_args_with_spaces(self, builder: CLIBuilder):
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
extra_args={"pg_dsn": "host=localhost dbname=test"},
|
||||
)
|
||||
result = builder.build_command_string(config, ETL_PATH)
|
||||
# 包含空格的值应被引号包裹
|
||||
assert '"host=localhost dbname=test"' in result
|
||||
94
apps/backend/tests/test_database.py
Normal file
94
apps/backend/tests/test_database.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
数据库连接模块单元测试。
|
||||
|
||||
覆盖:ETL 只读连接的创建、RLS site_id 设置、只读模式、异常处理。
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from app.database import get_etl_readonly_connection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_etl_readonly_connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEtlReadonlyConnection:
|
||||
"""ETL 只读连接:验证连接参数、只读设置、RLS 隔离。"""
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_sets_readonly_and_site_id(self, mock_connect):
|
||||
"""连接后应依次执行 SET read_only 和 SET LOCAL site_id。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
conn = get_etl_readonly_connection(site_id=42)
|
||||
|
||||
# 验证 autocommit 被关闭
|
||||
assert mock_conn.autocommit is False
|
||||
|
||||
# 验证执行了两条 SET 语句
|
||||
executed = [c.args[0] for c in mock_cursor.execute.call_args_list]
|
||||
assert "SET default_transaction_read_only = on" in executed[0]
|
||||
assert "SET LOCAL app.current_site_id" in executed[1]
|
||||
|
||||
# 验证 site_id 参数化传递(防 SQL 注入)
|
||||
site_id_call = mock_cursor.execute.call_args_list[1]
|
||||
assert site_id_call.args[1] == ("42",)
|
||||
|
||||
# 验证提交
|
||||
mock_conn.commit.assert_called_once()
|
||||
assert conn is mock_conn
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_accepts_string_site_id(self, mock_connect):
|
||||
"""site_id 为字符串时也应正常工作。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
get_etl_readonly_connection(site_id="99")
|
||||
|
||||
site_id_call = mock_cursor.execute.call_args_list[1]
|
||||
assert site_id_call.args[1] == ("99",)
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_closes_connection_on_setup_error(self, mock_connect):
|
||||
"""SET 语句执行失败时应关闭连接并抛出异常。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.execute.side_effect = Exception("SET failed")
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
with pytest.raises(Exception, match="SET failed"):
|
||||
get_etl_readonly_connection(site_id=1)
|
||||
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("app.database.psycopg2.connect")
|
||||
def test_uses_etl_config_params(self, mock_connect):
|
||||
"""应使用 ETL_DB_* 配置项连接。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
get_etl_readonly_connection(site_id=1)
|
||||
|
||||
connect_kwargs = mock_connect.call_args.kwargs
|
||||
# 验证使用了 ETL 数据库名(默认 etl_feiqiu)
|
||||
assert connect_kwargs["dbname"] == "etl_feiqiu"
|
||||
139
apps/backend/tests/test_db_viewer_properties.py
Normal file
139
apps/backend/tests/test_db_viewer_properties.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证数据库查看器的通用正确性属性:
|
||||
- Property 17: SQL 写操作拦截
|
||||
- Property 18: SQL 查询结果行数限制
|
||||
|
||||
测试策略:
|
||||
- Property 17: 生成包含写操作关键词(随机大小写混合)的 SQL 字符串,
|
||||
验证 _WRITE_KEYWORDS 正则表达式能匹配到
|
||||
- Property 18: 生成随机长度的行列表(可能超过 1000 行),
|
||||
验证截取前 _MAX_ROWS 个元素后长度 <= 1000
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-db-viewer-properties")
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.routers.db_viewer import _WRITE_KEYWORDS, _MAX_ROWS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 写操作关键词列表
|
||||
_WRITE_OPS = ["INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE"]
|
||||
|
||||
# SQL 前缀/后缀:不含写操作关键词的简单文本
|
||||
_sql_filler_st = st.text(
|
||||
alphabet=st.characters(
|
||||
whitelist_categories=("L", "N", "S"),
|
||||
blacklist_characters="\x00",
|
||||
),
|
||||
min_size=0,
|
||||
max_size=50,
|
||||
)
|
||||
|
||||
# 随机大小写混合的写操作关键词
|
||||
_random_case_keyword_st = st.sampled_from(_WRITE_OPS).flatmap(
|
||||
lambda kw: st.tuples(
|
||||
st.just(kw),
|
||||
st.lists(
|
||||
st.booleans(),
|
||||
min_size=len(kw),
|
||||
max_size=len(kw),
|
||||
),
|
||||
).map(
|
||||
lambda pair: "".join(
|
||||
c.upper() if flag else c.lower()
|
||||
for c, flag in zip(pair[0], pair[1])
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 17: SQL 写操作拦截
|
||||
# **Validates: Requirements 7.5**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(
|
||||
prefix=_sql_filler_st,
|
||||
keyword=_random_case_keyword_st,
|
||||
suffix=_sql_filler_st,
|
||||
)
|
||||
def test_write_keywords_always_detected(prefix, keyword, suffix):
|
||||
"""Property 17: SQL 写操作拦截。
|
||||
|
||||
包含 INSERT、UPDATE、DELETE、DROP、TRUNCATE 关键词(不区分大小写)的
|
||||
SQL 语句,_WRITE_KEYWORDS 正则表达式应能匹配到。
|
||||
|
||||
策略:在随机前缀和后缀之间插入一个随机大小写混合的写操作关键词,
|
||||
用空格分隔以确保 \\b 词边界能匹配。
|
||||
"""
|
||||
# 用空格分隔确保词边界匹配
|
||||
sql = f"{prefix} {keyword} {suffix}"
|
||||
|
||||
match = _WRITE_KEYWORDS.search(sql)
|
||||
assert match is not None, (
|
||||
f"正则表达式未能匹配到写操作关键词:sql={sql!r}, keyword={keyword!r}"
|
||||
)
|
||||
# 匹配到的关键词(转大写后)应在写操作列表中
|
||||
assert match.group(1).upper() in _WRITE_OPS, (
|
||||
f"匹配到的关键词 '{match.group(1)}' 不在写操作列表中"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 18: SQL 查询结果行数限制
|
||||
# **Validates: Requirements 7.4**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 模拟数据库返回的行:每行是一个简单列表
|
||||
_row_st = st.lists(
|
||||
st.one_of(st.integers(), st.text(max_size=20), st.none()),
|
||||
min_size=1,
|
||||
max_size=5,
|
||||
)
|
||||
|
||||
# 行列表策略:0 到 3000 行(覆盖超过 _MAX_ROWS 的情况)
|
||||
_rows_st = st.lists(_row_st, min_size=0, max_size=3000)
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(rows=_rows_st)
|
||||
def test_row_count_never_exceeds_max(rows):
|
||||
"""Property 18: SQL 查询结果行数限制。
|
||||
|
||||
对任意长度的行列表,取前 _MAX_ROWS 个元素后,
|
||||
结果长度应 <= 1000。
|
||||
|
||||
这等价于 cur.fetchmany(_MAX_ROWS) 的行为:
|
||||
数据库游标最多返回 _MAX_ROWS 行。
|
||||
"""
|
||||
# 模拟 fetchmany(_MAX_ROWS) 的行为
|
||||
truncated = rows[:_MAX_ROWS]
|
||||
|
||||
assert len(truncated) <= _MAX_ROWS, (
|
||||
f"截取后行数 {len(truncated)} 超过上限 {_MAX_ROWS}"
|
||||
)
|
||||
|
||||
# 额外验证:如果原始行数 <= _MAX_ROWS,截取后应保留全部
|
||||
if len(rows) <= _MAX_ROWS:
|
||||
assert len(truncated) == len(rows), (
|
||||
f"原始行数 {len(rows)} <= {_MAX_ROWS},截取后应保留全部,"
|
||||
f"实际 {len(truncated)}"
|
||||
)
|
||||
|
||||
# 额外验证:如果原始行数 > _MAX_ROWS,截取后应恰好为 _MAX_ROWS
|
||||
if len(rows) > _MAX_ROWS:
|
||||
assert len(truncated) == _MAX_ROWS, (
|
||||
f"原始行数 {len(rows)} > {_MAX_ROWS},截取后应恰好为 {_MAX_ROWS},"
|
||||
f"实际 {len(truncated)}"
|
||||
)
|
||||
321
apps/backend/tests/test_db_viewer_router.py
Normal file
321
apps/backend/tests/test_db_viewer_router.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器路由单元测试
|
||||
|
||||
覆盖 4 个端点:
|
||||
- GET /api/db/schemas
|
||||
- GET /api/db/schemas/{name}/tables
|
||||
- GET /api/db/tables/{schema}/{table}/columns
|
||||
- POST /api/db/query
|
||||
|
||||
通过 mock 绕过数据库连接,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from psycopg2 import errors as pg_errors
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
|
||||
|
||||
|
||||
def _make_mock_conn(rows, description=None):
|
||||
"""构造 mock 数据库连接,cursor 返回指定行和列描述。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = rows
|
||||
mock_cur.fetchmany.return_value = rows
|
||||
mock_cur.description = description
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cur
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListSchemas:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_schema_list(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 3
|
||||
assert data[0]["name"] == "dwd"
|
||||
assert data[2]["name"] == "ods"
|
||||
|
||||
# 验证 site_id 传递
|
||||
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_schemas(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas/{name}/tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListTables:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_tables_with_row_count(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("dim_member", 1500),
|
||||
("fact_order", 32000),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/dwd/tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["name"] == "dim_member"
|
||||
assert data[0]["row_count"] == 1500
|
||||
assert data[1]["name"] == "fact_order"
|
||||
assert data[1]["row_count"] == 32000
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_null_row_count(self, mock_get_conn):
|
||||
"""pg_stat_user_tables 可能没有统计信息,row_count 为 None。"""
|
||||
conn, cur = _make_mock_conn([("new_table", None)])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/ods/tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data[0]["row_count"] is None
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_schema(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/empty_schema/tables")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/tables/{schema}/{table}/columns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListColumns:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_column_definitions(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("id", "bigint", "NO", None),
|
||||
("name", "character varying", "YES", None),
|
||||
("created_at", "timestamp with time zone", "NO", "now()"),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/tables/dwd/dim_member/columns")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 3
|
||||
|
||||
assert data[0]["name"] == "id"
|
||||
assert data[0]["data_type"] == "bigint"
|
||||
assert data[0]["is_nullable"] is False
|
||||
assert data[0]["column_default"] is None
|
||||
|
||||
assert data[1]["is_nullable"] is True
|
||||
assert data[2]["column_default"] == "now()"
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_table(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/tables/dwd/nonexistent/columns")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/db/query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecuteQuery:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_successful_select(self, mock_get_conn):
|
||||
description = [("id",), ("name",)]
|
||||
conn, cur = _make_mock_conn(
|
||||
[(1, "Alice"), (2, "Bob")],
|
||||
description=description,
|
||||
)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["columns"] == ["id", "name"]
|
||||
assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
|
||||
assert data["row_count"] == 2
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_result(self, mock_get_conn):
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn([], description=description)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["columns"] == ["id"]
|
||||
assert data["rows"] == []
|
||||
assert data["row_count"] == 0
|
||||
|
||||
# ── 写操作拦截 ──
|
||||
|
||||
@pytest.mark.parametrize("keyword", [
|
||||
"INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
|
||||
"insert", "update", "delete", "drop", "truncate",
|
||||
"Insert", "Update", "Delete", "Drop", "Truncate",
|
||||
])
|
||||
def test_blocks_write_operations(self, keyword):
|
||||
resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
|
||||
assert resp.status_code == 400
|
||||
assert "只读" in resp.json()["detail"] or "禁止" in resp.json()["detail"]
|
||||
|
||||
def test_blocks_mixed_case_write(self):
|
||||
resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_blocks_write_in_subquery(self):
|
||||
"""写操作关键词出现在 SQL 任意位置都应拦截。"""
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
# ── 空 SQL ──
|
||||
|
||||
def test_empty_sql(self):
|
||||
resp = client.post("/api/db/query", json={"sql": ""})
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_whitespace_only_sql(self):
|
||||
resp = client.post("/api/db/query", json={"sql": " "})
|
||||
assert resp.status_code == 400
|
||||
|
||||
# ── SQL 语法错误 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_sql_syntax_error(self, mock_get_conn):
|
||||
conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
# 第一次 execute 设置 timeout 成功,第二次抛异常
|
||||
mock_cur.execute.side_effect = [None, Exception("syntax error at or near \"SELEC\"")]
|
||||
mock_cur.description = None
|
||||
conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
|
||||
assert resp.status_code == 400
|
||||
assert "SQL 执行错误" in resp.json()["detail"]
|
||||
|
||||
# ── 查询超时 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_query_timeout(self, mock_get_conn):
|
||||
conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
|
||||
mock_cur.description = None
|
||||
conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
|
||||
assert resp.status_code == 408
|
||||
assert "超时" in resp.json()["detail"]
|
||||
|
||||
# ── 行数限制验证 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_row_limit(self, mock_get_conn):
|
||||
"""验证 fetchmany 被调用时传入 1000 行限制。"""
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn(
|
||||
[(i,) for i in range(1000)],
|
||||
description=description,
|
||||
)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
|
||||
assert resp.status_code == 200
|
||||
# 验证 fetchmany 被调用时传入了 1000
|
||||
cur.fetchmany.assert_called_once_with(1000)
|
||||
|
||||
# ── 超时设置验证 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_sets_statement_timeout(self, mock_get_conn):
|
||||
"""验证查询前设置了 statement_timeout。"""
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn([(1,)], description=description)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
client.post("/api/db/query", json={"sql": "SELECT 1"})
|
||||
|
||||
# 第一次 execute 应该是设置超时
|
||||
first_call = cur.execute.call_args_list[0]
|
||||
assert "statement_timeout" in first_call[0][0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDbViewerAuth:
|
||||
|
||||
def test_requires_auth(self):
|
||||
"""移除 auth override 后,所有端点应返回 401/403。"""
|
||||
original = app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
endpoints = [
|
||||
("GET", "/api/db/schemas"),
|
||||
("GET", "/api/db/schemas/dwd/tables"),
|
||||
("GET", "/api/db/tables/dwd/dim_member/columns"),
|
||||
("POST", "/api/db/query"),
|
||||
]
|
||||
for method, url in endpoints:
|
||||
if method == "POST":
|
||||
resp = client.request(method, url, json={"sql": "SELECT 1"})
|
||||
else:
|
||||
resp = client.request(method, url)
|
||||
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
|
||||
finally:
|
||||
if original:
|
||||
app.dependency_overrides[get_current_user] = original
|
||||
191
apps/backend/tests/test_env_config_properties.py
Normal file
191
apps/backend/tests/test_env_config_properties.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证环境配置管理的通用正确性属性:
|
||||
- Property 15: .env 解析与敏感值掩码
|
||||
- Property 16: .env 写入往返一致性
|
||||
|
||||
测试策略:
|
||||
- Property 15: 生成随机 .env 内容(含敏感和非敏感键),验证 _parse_env + _is_sensitive 对敏感值掩码
|
||||
- Property 16: 生成随机键值对,序列化为 .env 格式后再解析,验证往返一致性
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-env-config-properties")
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.routers.env_config import _parse_env, _is_sensitive, _MASK, _SENSITIVE_KEYWORDS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 合法的环境变量键名:字母或下划线开头,后跟字母、数字、下划线
|
||||
_key_start_char = st.sampled_from(
|
||||
list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_")
|
||||
)
|
||||
_key_rest_char = st.sampled_from(
|
||||
list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_")
|
||||
)
|
||||
|
||||
_env_key_st = st.builds(
|
||||
lambda first, rest: first + rest,
|
||||
first=_key_start_char,
|
||||
rest=st.text(alphabet=_key_rest_char, min_size=0, max_size=30),
|
||||
)
|
||||
|
||||
# 值:不含换行符的可打印字符串(排除引号以避免解析歧义)
|
||||
_env_value_st = st.text(
|
||||
alphabet=st.characters(
|
||||
whitelist_categories=("L", "N", "P", "S"),
|
||||
blacklist_characters='\n\r"\'#',
|
||||
),
|
||||
min_size=0,
|
||||
max_size=50,
|
||||
)
|
||||
|
||||
# 敏感键:在随机键名中嵌入敏感关键词
|
||||
_sensitive_keyword_st = st.sampled_from(list(_SENSITIVE_KEYWORDS))
|
||||
|
||||
_sensitive_key_st = st.builds(
|
||||
lambda prefix, kw, suffix: prefix + kw + suffix,
|
||||
prefix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
|
||||
kw=_sensitive_keyword_st,
|
||||
suffix=st.text(alphabet=_key_rest_char, min_size=0, max_size=10),
|
||||
).filter(lambda k: len(k) > 0 and k[0].isalpha() or k[0] == "_")
|
||||
|
||||
# 确保敏感键以字母或下划线开头
|
||||
_safe_sensitive_key_st = st.builds(
|
||||
lambda prefix, kw: prefix + "_" + kw,
|
||||
prefix=st.sampled_from(["DB", "API", "ETL", "APP", "MY"]),
|
||||
kw=_sensitive_keyword_st,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 15: .env 解析与敏感值掩码
|
||||
# **Validates: Requirements 6.1, 6.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
sensitive_keys=st.lists(_safe_sensitive_key_st, min_size=1, max_size=5, unique=True),
|
||||
sensitive_values=st.lists(
|
||||
st.text(min_size=1, max_size=30, alphabet=st.characters(
|
||||
whitelist_categories=("L", "N"),
|
||||
)),
|
||||
min_size=1, max_size=5,
|
||||
),
|
||||
normal_keys=st.lists(_env_key_st, min_size=1, max_size=5, unique=True),
|
||||
normal_values=st.lists(_env_value_st, min_size=1, max_size=5),
|
||||
)
|
||||
def test_sensitive_values_masked(sensitive_keys, sensitive_values, normal_keys, normal_values):
|
||||
"""Property 15: .env 解析与敏感值掩码。
|
||||
|
||||
包含敏感键(PASSWORD、TOKEN、SECRET、DSN)的 .env 文件内容,
|
||||
API 返回的键值对列表中这些键的值应被掩码替换,不包含原始敏感值。
|
||||
"""
|
||||
# 确保敏感键和普通键不重叠
|
||||
normal_keys_filtered = [k for k in normal_keys if k not in sensitive_keys]
|
||||
assume(len(normal_keys_filtered) >= 1)
|
||||
|
||||
# 对齐列表长度
|
||||
s_vals = (sensitive_values * ((len(sensitive_keys) // len(sensitive_values)) + 1))[:len(sensitive_keys)]
|
||||
n_vals = (normal_values * ((len(normal_keys_filtered) // len(normal_values)) + 1))[:len(normal_keys_filtered)]
|
||||
|
||||
# 构造 .env 内容
|
||||
lines = []
|
||||
for k, v in zip(sensitive_keys, s_vals):
|
||||
lines.append(f"{k}={v}")
|
||||
for k, v in zip(normal_keys_filtered, n_vals):
|
||||
lines.append(f"{k}={v}")
|
||||
env_content = "\n".join(lines) + "\n"
|
||||
|
||||
# 解析
|
||||
parsed = _parse_env(env_content)
|
||||
entries = [line for line in parsed if line["type"] == "entry"]
|
||||
|
||||
# 模拟 GET 端点的掩码逻辑
|
||||
masked_entries = {}
|
||||
for entry in entries:
|
||||
if _is_sensitive(entry["key"]):
|
||||
masked_entries[entry["key"]] = _MASK
|
||||
else:
|
||||
masked_entries[entry["key"]] = entry["value"]
|
||||
|
||||
# 验证:敏感键的值应被掩码
|
||||
for k, v in zip(sensitive_keys, s_vals):
|
||||
assert k in masked_entries, f"敏感键 {k} 应出现在解析结果中"
|
||||
assert masked_entries[k] == _MASK, (
|
||||
f"敏感键 {k} 的值应为掩码 '{_MASK}',实际为 '{masked_entries[k]}'"
|
||||
)
|
||||
# 原始敏感值不应出现在掩码后的结果中
|
||||
assert masked_entries[k] != v, (
|
||||
f"敏感键 {k} 的原始值 '{v}' 不应出现在掩码结果中"
|
||||
)
|
||||
|
||||
# 验证:非敏感键的值应保持原样
|
||||
for k, v in zip(normal_keys_filtered, n_vals):
|
||||
if not _is_sensitive(k):
|
||||
assert k in masked_entries, f"普通键 {k} 应出现在解析结果中"
|
||||
assert masked_entries[k] == v, (
|
||||
f"普通键 {k} 的值应为 '{v}',实际为 '{masked_entries[k]}'"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 16: .env 写入往返一致性
|
||||
# **Validates: Requirements 6.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
entries=st.lists(
|
||||
st.tuples(_env_key_st, _env_value_st),
|
||||
min_size=1,
|
||||
max_size=10,
|
||||
unique_by=lambda t: t[0], # 键唯一
|
||||
),
|
||||
)
|
||||
def test_env_write_read_round_trip(entries):
|
||||
"""Property 16: .env 写入往返一致性。
|
||||
|
||||
有效的键值对集合(不含注释和空行),写入 .env 文件后再读取解析,
|
||||
应得到与原始集合等价的键值对。
|
||||
"""
|
||||
# 过滤掉值中可能导致解析歧义的情况(值前后空白会被 strip)
|
||||
clean_entries = [(k, v.strip()) for k, v in entries]
|
||||
# 排除空键(策略已保证非空,但防御性检查)
|
||||
clean_entries = [(k, v) for k, v in clean_entries if k]
|
||||
|
||||
assume(len(clean_entries) >= 1)
|
||||
|
||||
# 模拟写入:构造 .env 文件内容
|
||||
lines = [f"{k}={v}" for k, v in clean_entries]
|
||||
env_content = "\n".join(lines) + "\n"
|
||||
|
||||
# 解析
|
||||
parsed = _parse_env(env_content)
|
||||
parsed_entries = {
|
||||
line["key"]: line["value"]
|
||||
for line in parsed
|
||||
if line["type"] == "entry"
|
||||
}
|
||||
|
||||
# 验证往返一致性:每个写入的键值对都应在解析结果中
|
||||
for k, v in clean_entries:
|
||||
assert k in parsed_entries, (
|
||||
f"键 '{k}' 应出现在解析结果中,实际键集合: {list(parsed_entries.keys())}"
|
||||
)
|
||||
assert parsed_entries[k] == v, (
|
||||
f"键 '{k}' 的值不一致:写入='{v}',解析='{parsed_entries[k]}'"
|
||||
)
|
||||
|
||||
# 验证:解析结果的键数量应与写入的一致
|
||||
assert len(parsed_entries) == len(clean_entries), (
|
||||
f"解析结果键数量 {len(parsed_entries)} 应等于写入数量 {len(clean_entries)}"
|
||||
)
|
||||
291
apps/backend/tests/test_env_config_router.py
Normal file
291
apps/backend/tests/test_env_config_router.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""环境配置路由单元测试
|
||||
|
||||
覆盖 3 个端点:GET / PUT / GET /export
|
||||
通过 mock 绕过文件 I/O,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
# 模拟 .env 文件内容
|
||||
_SAMPLE_ENV = """\
|
||||
# 数据库配置
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_PASSWORD=super_secret_123
|
||||
JWT_SECRET_KEY=my-jwt-secret
|
||||
|
||||
# ETL 配置
|
||||
ETL_DB_DSN=postgresql://user:pass@host/db
|
||||
TIMEZONE=Asia/Shanghai
|
||||
"""
|
||||
|
||||
_MOCK_ENV_PATH = "app.routers.env_config._ENV_PATH"
|
||||
|
||||
|
||||
def _mock_path(content: str | None = _SAMPLE_ENV, exists: bool = True):
|
||||
"""构造 mock Path 对象。"""
|
||||
mock = MagicMock()
|
||||
mock.exists.return_value = exists
|
||||
if content is not None:
|
||||
mock.read_text.return_value = content
|
||||
return mock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/env-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetEnvConfig:
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_returns_entries_with_masked_sensitive(self, mock_path_obj):
|
||||
mock_path_obj.__class__ = type(MagicMock())
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = _SAMPLE_ENV
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
entries = {e["key"]: e["value"] for e in data["entries"]}
|
||||
|
||||
# 非敏感值原样返回
|
||||
assert entries["DB_HOST"] == "localhost"
|
||||
assert entries["DB_PORT"] == "5432"
|
||||
assert entries["TIMEZONE"] == "Asia/Shanghai"
|
||||
|
||||
# 敏感值掩码
|
||||
assert entries["DB_PASSWORD"] == "****"
|
||||
assert entries["JWT_SECRET_KEY"] == "****"
|
||||
assert entries["ETL_DB_DSN"] == "****"
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_file_not_found(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = False
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_empty_file(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = ""
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["entries"] == []
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_comments_and_blank_lines_excluded(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "# comment\n\nKEY=val\n"
|
||||
|
||||
resp = client.get("/api/env-config")
|
||||
assert resp.status_code == 200
|
||||
entries = resp.json()["entries"]
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["key"] == "KEY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/env-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateEnvConfig:
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_update_existing_key(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_HOST=localhost\nDB_PORT=5432\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_HOST", "value": "192.168.1.1"},
|
||||
{"key": "DB_PORT", "value": "5433"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# 验证写入内容
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "DB_HOST=192.168.1.1" in written
|
||||
assert "DB_PORT=5433" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_add_new_key(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_HOST=localhost\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_HOST", "value": "localhost"},
|
||||
{"key": "NEW_KEY", "value": "new_value"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "NEW_KEY=new_value" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_masked_value_preserves_original(self, mock_path_obj):
|
||||
"""掩码值(****)不应覆盖原始敏感值。"""
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_PASSWORD=real_secret\nDB_HOST=localhost\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [
|
||||
{"key": "DB_PASSWORD", "value": "****"},
|
||||
{"key": "DB_HOST", "value": "newhost"},
|
||||
]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
# 原始密码应保留
|
||||
assert "DB_PASSWORD=real_secret" in written
|
||||
assert "DB_HOST=newhost" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_preserves_comments(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "# 注释行\nDB_HOST=localhost\n\n# 另一个注释\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "DB_HOST", "value": "newhost"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "# 注释行" in written
|
||||
assert "# 另一个注释" in written
|
||||
|
||||
def test_invalid_key_format(self):
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "123BAD", "value": "val"}]
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_empty_key(self):
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "", "value": "val"}]
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_file_not_exists_creates_new(self, mock_path_obj):
|
||||
"""文件不存在时,应创建新文件。"""
|
||||
mock_path_obj.exists.return_value = False
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "NEW_KEY", "value": "value"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "NEW_KEY=value" in written
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_update_sensitive_with_new_value(self, mock_path_obj):
|
||||
"""显式提供新密码时应更新。"""
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = "DB_PASSWORD=old_secret\n"
|
||||
|
||||
resp = client.put("/api/env-config", json={
|
||||
"entries": [{"key": "DB_PASSWORD", "value": "new_secret"}]
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
written = mock_path_obj.write_text.call_args[0][0]
|
||||
assert "DB_PASSWORD=new_secret" in written
|
||||
|
||||
# 返回值中敏感键仍然掩码
|
||||
entries = {e["key"]: e["value"] for e in resp.json()["entries"]}
|
||||
assert entries["DB_PASSWORD"] == "****"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/env-config/export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExportEnvConfig:
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_export_masks_sensitive(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = _SAMPLE_ENV
|
||||
|
||||
resp = client.get("/api/env-config/export")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/plain")
|
||||
assert "attachment" in resp.headers.get("content-disposition", "")
|
||||
|
||||
content = resp.text
|
||||
# 非敏感值保留
|
||||
assert "DB_HOST=localhost" in content
|
||||
assert "TIMEZONE=Asia/Shanghai" in content
|
||||
|
||||
# 敏感值掩码
|
||||
assert "super_secret_123" not in content
|
||||
assert "my-jwt-secret" not in content
|
||||
assert "DB_PASSWORD=****" in content
|
||||
assert "JWT_SECRET_KEY=****" in content
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_export_preserves_comments(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = True
|
||||
mock_path_obj.read_text.return_value = _SAMPLE_ENV
|
||||
|
||||
content = client.get("/api/env-config/export").text
|
||||
assert "# 数据库配置" in content
|
||||
assert "# ETL 配置" in content
|
||||
|
||||
@patch(_MOCK_ENV_PATH)
|
||||
def test_export_file_not_found(self, mock_path_obj):
|
||||
mock_path_obj.exists.return_value = False
|
||||
|
||||
resp = client.get("/api/env-config/export")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnvConfigAuth:
|
||||
|
||||
def test_requires_auth(self):
|
||||
"""移除 auth override 后,所有端点应返回 401/403。"""
|
||||
# 临时移除 override
|
||||
original = app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
for method, url in [
|
||||
("GET", "/api/env-config"),
|
||||
("PUT", "/api/env-config"),
|
||||
("GET", "/api/env-config/export"),
|
||||
]:
|
||||
resp = client.request(method, url)
|
||||
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
|
||||
finally:
|
||||
if original:
|
||||
app.dependency_overrides[get_current_user] = original
|
||||
246
apps/backend/tests/test_etl_status_router.py
Normal file
246
apps/backend/tests/test_etl_status_router.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 状态路由单元测试
|
||||
|
||||
覆盖 2 个端点:
|
||||
- GET /api/etl-status/cursors
|
||||
- GET /api/etl-status/recent-runs
|
||||
|
||||
通过 mock 绕过数据库连接,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_MOCK_ETL_CONN = "app.routers.etl_status.get_etl_readonly_connection"
|
||||
_MOCK_APP_CONN = "app.routers.etl_status.get_connection"
|
||||
|
||||
|
||||
def _make_mock_conn(rows):
|
||||
"""构造 mock 数据库连接,cursor 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = rows
|
||||
mock_cur.fetchone.return_value = None
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cur
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/etl-status/cursors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListCursors:
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_returns_cursor_list(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("ODS_FETCH_ORDERS", "2024-06-15 10:30:00+08", 1500),
|
||||
("ODS_FETCH_MEMBERS", "2024-06-15 09:00:00+08", 800),
|
||||
])
|
||||
# fetchone 用于 EXISTS 检查
|
||||
cur.fetchone.return_value = (True,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["task_code"] == "ODS_FETCH_ORDERS"
|
||||
assert data[0]["last_fetch_time"] == "2024-06-15 10:30:00+08"
|
||||
assert data[0]["record_count"] == 1500
|
||||
assert data[1]["task_code"] == "ODS_FETCH_MEMBERS"
|
||||
|
||||
# 验证 site_id 传递
|
||||
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_table_not_exists_returns_empty(self, mock_get_conn):
|
||||
"""etl_admin.etl_cursor 表不存在时返回空列表。"""
|
||||
conn, cur = _make_mock_conn([])
|
||||
cur.fetchone.return_value = (False,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_null_fields(self, mock_get_conn):
|
||||
"""游标字段可能为 None(任务从未执行过)。"""
|
||||
conn, cur = _make_mock_conn([
|
||||
("ODS_FETCH_INVENTORY", None, None),
|
||||
])
|
||||
cur.fetchone.return_value = (True,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data[0]["task_code"] == "ODS_FETCH_INVENTORY"
|
||||
assert data[0]["last_fetch_time"] is None
|
||||
assert data[0]["record_count"] is None
|
||||
|
||||
@patch(_MOCK_ETL_CONN)
|
||||
def test_empty_cursors(self, mock_get_conn):
|
||||
"""表存在但无数据。"""
|
||||
conn, cur = _make_mock_conn([])
|
||||
cur.fetchone.return_value = (True,)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/cursors")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/etl-status/recent-runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListRecentRuns:
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_returns_recent_runs(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000001",
|
||||
["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"],
|
||||
"success",
|
||||
"2024-06-15 10:30:00+08",
|
||||
"2024-06-15 10:35:00+08",
|
||||
300000,
|
||||
0,
|
||||
),
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000002",
|
||||
["DWS_AGGREGATE"],
|
||||
"failed",
|
||||
"2024-06-15 09:00:00+08",
|
||||
"2024-06-15 09:01:00+08",
|
||||
60000,
|
||||
1,
|
||||
),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
|
||||
run0 = data[0]
|
||||
assert run0["id"] == "a1b2c3d4-0000-0000-0000-000000000001"
|
||||
assert run0["task_codes"] == ["ODS_FETCH_ORDERS", "DWD_LOAD_FROM_ODS"]
|
||||
assert run0["status"] == "success"
|
||||
assert run0["duration_ms"] == 300000
|
||||
assert run0["exit_code"] == 0
|
||||
|
||||
run1 = data[1]
|
||||
assert run1["status"] == "failed"
|
||||
assert run1["exit_code"] == 1
|
||||
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_empty_runs(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_null_optional_fields(self, mock_get_conn):
|
||||
"""正在执行的任务 finished_at / duration_ms / exit_code 为 None。"""
|
||||
conn, cur = _make_mock_conn([
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000003",
|
||||
["ODS_FETCH_MEMBERS"],
|
||||
"running",
|
||||
"2024-06-15 11:00:00+08",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data[0]["status"] == "running"
|
||||
assert data[0]["finished_at"] is None
|
||||
assert data[0]["duration_ms"] is None
|
||||
assert data[0]["exit_code"] is None
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_site_id_filter(self, mock_get_conn):
|
||||
"""验证查询时传入了正确的 site_id 参数。"""
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
client.get("/api/etl-status/recent-runs")
|
||||
|
||||
# 验证 SQL 中传入了 site_id 和 limit
|
||||
call_args = cur.execute.call_args
|
||||
params = call_args[0][1]
|
||||
assert params[0] == _TEST_USER.site_id
|
||||
assert params[1] == 50
|
||||
|
||||
@patch(_MOCK_APP_CONN)
|
||||
def test_empty_task_codes(self, mock_get_conn):
|
||||
"""task_codes 为 None 时应返回空列表。"""
|
||||
conn, cur = _make_mock_conn([
|
||||
(
|
||||
"a1b2c3d4-0000-0000-0000-000000000004",
|
||||
None,
|
||||
"pending",
|
||||
"2024-06-15 12:00:00+08",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/etl-status/recent-runs")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()[0]["task_codes"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEtlStatusAuth:
|
||||
|
||||
def test_requires_auth(self):
|
||||
"""移除 auth override 后,所有端点应返回 401/403。"""
|
||||
original = app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
for url in ["/api/etl-status/cursors", "/api/etl-status/recent-runs"]:
|
||||
resp = client.get(url)
|
||||
assert resp.status_code in (401, 403), f"GET {url} 应需要认证"
|
||||
finally:
|
||||
if original:
|
||||
app.dependency_overrides[get_current_user] = original
|
||||
339
apps/backend/tests/test_execution_router.py
Normal file
339
apps/backend/tests/test_execution_router.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""执行与队列路由单元测试
|
||||
|
||||
覆盖 8 个端点:run / queue CRUD / cancel / history / logs
|
||||
通过 mock 绕过数据库和服务层,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.services.task_queue import QueuedTask
|
||||
|
||||
# 固定测试用户
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# 构造测试用的 TaskConfig payload
|
||||
_VALID_CONFIG = {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunTask:
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_run_returns_execution_id(self, mock_executor):
|
||||
mock_executor.execute = AsyncMock()
|
||||
resp = client.post("/api/execution/run", json=_VALID_CONFIG)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "execution_id" in data
|
||||
assert data["message"] == "任务已提交执行"
|
||||
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_run_injects_store_id(self, mock_executor):
|
||||
"""store_id 应从 JWT 注入"""
|
||||
mock_executor.execute = AsyncMock()
|
||||
resp = client.post("/api/execution/run", json={
|
||||
**_VALID_CONFIG,
|
||||
"store_id": 999, # 前端传的值应被覆盖
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_run_requires_auth(self):
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
resp = client.post("/api/execution/run", json=_VALID_CONFIG)
|
||||
assert resp.status_code in (401, 403)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
|
||||
def test_run_invalid_config_returns_422(self):
|
||||
"""缺少必填字段 tasks 时返回 422"""
|
||||
resp = client.post("/api/execution/run", json={"pipeline": "api_ods"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetQueue:
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_get_queue_returns_list(self, mock_queue):
|
||||
mock_queue.list_pending.return_value = [
|
||||
QueuedTask(
|
||||
id="task-1", site_id=100, config={"tasks": ["ODS_MEMBER"]},
|
||||
status="pending", position=1, created_at=_NOW,
|
||||
),
|
||||
]
|
||||
resp = client.get("/api/execution/queue")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "task-1"
|
||||
assert data[0]["status"] == "pending"
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_get_queue_empty(self, mock_queue):
|
||||
mock_queue.list_pending.return_value = []
|
||||
resp = client.get("/api/execution/queue")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_get_queue_filters_by_site_id(self, mock_queue):
|
||||
"""确认调用 list_pending 时传入了正确的 site_id"""
|
||||
mock_queue.list_pending.return_value = []
|
||||
client.get("/api/execution/queue")
|
||||
mock_queue.list_pending.assert_called_once_with(100)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnqueueTask:
|
||||
@patch("app.routers.execution.get_connection")
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
|
||||
mock_queue.enqueue.return_value = "new-task-id"
|
||||
|
||||
# mock 数据库查询返回
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = (
|
||||
"new-task-id", 100, '{"tasks": ["ODS_MEMBER"]}',
|
||||
"pending", 1, _NOW, None, None, None, None,
|
||||
)
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "new-task-id"
|
||||
assert data["status"] == "pending"
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_enqueue_calls_with_site_id(self, mock_queue):
|
||||
"""确认 enqueue 时传入了 JWT 的 site_id"""
|
||||
mock_queue.enqueue.return_value = "id-1"
|
||||
|
||||
# 让后续的 DB 查询抛异常来快速结束(enqueue 本身已验证)
|
||||
with patch("app.routers.execution.get_connection") as mock_conn:
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = (
|
||||
"id-1", 100, '{"tasks": []}', "pending", 1,
|
||||
_NOW, None, None, None, None,
|
||||
)
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.return_value = conn
|
||||
|
||||
client.post("/api/execution/queue", json=_VALID_CONFIG)
|
||||
|
||||
# 验证 enqueue 的第二个参数是 site_id=100
|
||||
call_args = mock_queue.enqueue.call_args
|
||||
assert call_args[0][1] == 100 # site_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/execution/queue/reorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReorderQueue:
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_reorder_success(self, mock_queue):
|
||||
mock_queue.reorder.return_value = None
|
||||
resp = client.put("/api/execution/queue/reorder", json={
|
||||
"task_id": "task-1",
|
||||
"new_position": 3,
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
mock_queue.reorder.assert_called_once_with("task-1", 3, 100)
|
||||
|
||||
def test_reorder_missing_fields_returns_422(self):
|
||||
resp = client.put("/api/execution/queue/reorder", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/execution/queue/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteQueueTask:
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_delete_success(self, mock_queue):
|
||||
mock_queue.delete.return_value = True
|
||||
resp = client.delete("/api/execution/queue/task-1")
|
||||
assert resp.status_code == 200
|
||||
mock_queue.delete.assert_called_once_with("task-1", 100)
|
||||
|
||||
@patch("app.routers.execution.task_queue")
|
||||
def test_delete_nonexistent_returns_409(self, mock_queue):
|
||||
mock_queue.delete.return_value = False
|
||||
resp = client.delete("/api/execution/queue/nonexistent")
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/execution/{id}/cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCancelExecution:
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_cancel_success(self, mock_executor):
|
||||
mock_executor.cancel = AsyncMock(return_value=True)
|
||||
resp = client.post("/api/execution/some-id/cancel")
|
||||
assert resp.status_code == 200
|
||||
assert "取消" in resp.json()["message"]
|
||||
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_cancel_not_found(self, mock_executor):
|
||||
mock_executor.cancel = AsyncMock(return_value=False)
|
||||
resp = client.post("/api/execution/nonexistent/cancel")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/history
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecutionHistory:
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_history_returns_list(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = [
|
||||
(
|
||||
"exec-1", 100, ["ODS_MEMBER"], "success",
|
||||
_NOW, _NOW, 0, 1234, "python -m cli.main", None,
|
||||
),
|
||||
]
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/history")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "exec-1"
|
||||
assert data[0]["status"] == "success"
|
||||
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_history_respects_limit(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = []
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/history?limit=10")
|
||||
assert resp.status_code == 200
|
||||
|
||||
# 验证 SQL 中传入了 limit=10
|
||||
call_args = mock_cursor.execute.call_args
|
||||
assert call_args[0][1] == (100, 10) # (site_id, limit)
|
||||
|
||||
def test_history_limit_validation(self):
|
||||
"""limit 超出范围时返回 422"""
|
||||
resp = client.get("/api/execution/history?limit=0")
|
||||
assert resp.status_code == 422
|
||||
resp = client.get("/api/execution/history?limit=999")
|
||||
assert resp.status_code == 422
|
||||
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_history_empty(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = []
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/history")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/execution/{id}/logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecutionLogs:
|
||||
@patch("app.routers.execution.get_connection")
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_logs_from_db(self, mock_executor, mock_get_conn):
|
||||
"""已完成任务从数据库读取日志"""
|
||||
mock_executor.is_running.return_value = False
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = ("stdout output", "stderr output")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/exec-1/logs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["execution_id"] == "exec-1"
|
||||
assert data["output_log"] == "stdout output"
|
||||
assert data["error_log"] == "stderr output"
|
||||
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_logs_from_memory(self, mock_executor):
|
||||
"""执行中的任务从内存缓冲区读取"""
|
||||
mock_executor.is_running.return_value = True
|
||||
mock_executor.get_logs.return_value = ["line1", "line2"]
|
||||
|
||||
resp = client.get("/api/execution/running-id/logs")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["execution_id"] == "running-id"
|
||||
assert "line1" in data["output_log"]
|
||||
assert "line2" in data["output_log"]
|
||||
|
||||
@patch("app.routers.execution.get_connection")
|
||||
@patch("app.routers.execution.task_executor")
|
||||
def test_logs_not_found(self, mock_executor, mock_get_conn):
|
||||
mock_executor.is_running.return_value = False
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = None
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/execution/nonexistent/logs")
|
||||
assert resp.status_code == 404
|
||||
510
apps/backend/tests/test_queue_properties.py
Normal file
510
apps/backend/tests/test_queue_properties.py
Normal file
@@ -0,0 +1,510 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""队列属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证队列管理的通用正确性属性:
|
||||
- Property 8: 队列 CRUD 不变量
|
||||
- Property 9: 队列出队顺序
|
||||
- Property 10: 队列重排一致性
|
||||
- Property 11: 执行历史排序与限制
|
||||
|
||||
测试策略:
|
||||
- Property 8-10 通过内存模拟队列状态,mock 数据库操作,验证 TaskQueue 的核心逻辑
|
||||
- Property 11 通过 mock 数据库返回,验证执行历史端点的排序与限制逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-queue-properties")
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_queue import TaskQueue, QueuedTask
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
# 简单的任务代码列表
|
||||
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
|
||||
|
||||
_simple_config_st = st.builds(
|
||||
TaskConfigSchema,
|
||||
tasks=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
pipeline=st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd"]),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 内存队列模拟器 — 用于 mock 数据库交互
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InMemoryQueueDB:
|
||||
"""模拟 task_queue 表的内存存储,为 TaskQueue 方法提供 mock 数据库行为。"""
|
||||
|
||||
def __init__(self, site_id: int):
|
||||
self.site_id = site_id
|
||||
# 存储格式:{task_id: {config, status, position, ...}}
|
||||
self.rows: dict[str, dict] = {}
|
||||
|
||||
@property
|
||||
def pending_tasks(self) -> list[dict]:
|
||||
"""按 position 排序的 pending 任务列表。"""
|
||||
return sorted(
|
||||
[r for r in self.rows.values() if r["status"] == "pending"],
|
||||
key=lambda r: r["position"],
|
||||
)
|
||||
|
||||
def mock_enqueue_connection(self):
|
||||
"""为 enqueue 方法构造 mock connection。
|
||||
|
||||
enqueue 执行两条 SQL:
|
||||
1. SELECT COALESCE(MAX(position), 0) → 返回当前最大 position
|
||||
2. INSERT INTO task_queue → 插入新行
|
||||
"""
|
||||
pending = self.pending_tasks
|
||||
max_pos = max((r["position"] for r in pending), default=0)
|
||||
|
||||
call_count = [0]
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
executed_sqls = []
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
executed_sqls.append((sql, params))
|
||||
call_count[0] += 1
|
||||
if "MAX(position)" in sql:
|
||||
cur.fetchone.return_value = (max_pos,)
|
||||
elif "INSERT INTO task_queue" in sql:
|
||||
# 记录插入的行
|
||||
task_id, site_id, config_json, new_pos = params
|
||||
db.rows[task_id] = {
|
||||
"id": task_id,
|
||||
"site_id": site_id,
|
||||
"config": json.loads(config_json),
|
||||
"status": "pending",
|
||||
"position": new_pos,
|
||||
}
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_dequeue_connection(self):
|
||||
"""为 dequeue 方法构造 mock connection。
|
||||
|
||||
dequeue 执行两条 SQL:
|
||||
1. SELECT ... ORDER BY position ASC LIMIT 1 FOR UPDATE → 返回队首任务
|
||||
2. UPDATE ... SET status = 'running' → 更新状态
|
||||
"""
|
||||
pending = self.pending_tasks
|
||||
first = pending[0] if pending else None
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if "ORDER BY position ASC" in sql:
|
||||
if first:
|
||||
cur.fetchone.return_value = (
|
||||
first["id"], first["site_id"],
|
||||
json.dumps(first["config"]),
|
||||
first["status"], first["position"],
|
||||
None, None, None, None, None,
|
||||
)
|
||||
else:
|
||||
cur.fetchone.return_value = None
|
||||
elif "SET status = 'running'" in sql:
|
||||
if first:
|
||||
db.rows[first["id"]]["status"] = "running"
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_delete_connection(self, task_id: str):
|
||||
"""为 delete 方法构造 mock connection。"""
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
tid = params[0]
|
||||
if tid in db.rows and db.rows[tid]["status"] == "pending":
|
||||
del db.rows[tid]
|
||||
cur.rowcount = 1
|
||||
else:
|
||||
cur.rowcount = 0
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.rowcount = 0
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_reorder_connection(self):
|
||||
"""为 reorder 方法构造 mock connection。
|
||||
|
||||
reorder 执行:
|
||||
1. SELECT id FROM task_queue WHERE ... ORDER BY position ASC
|
||||
2. 多次 UPDATE task_queue SET position = %s WHERE id = %s
|
||||
"""
|
||||
pending = self.pending_tasks
|
||||
db = self
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
call_idx = [0]
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if "SELECT id FROM task_queue" in sql:
|
||||
cur.fetchall.return_value = [(r["id"],) for r in pending]
|
||||
elif "UPDATE task_queue SET position" in sql:
|
||||
pos, tid = params
|
||||
if tid in db.rows:
|
||||
db.rows[tid]["position"] = pos
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
def mock_list_pending_connection(self):
|
||||
"""为 list_pending 方法构造 mock connection。"""
|
||||
pending = self.pending_tasks
|
||||
|
||||
def make_cursor():
|
||||
cur = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
cur.fetchall.return_value = [
|
||||
(
|
||||
r["id"], r["site_id"], json.dumps(r["config"]),
|
||||
r["status"], r["position"],
|
||||
None, None, None, None, None,
|
||||
)
|
||||
for r in pending
|
||||
]
|
||||
|
||||
cur.execute = MagicMock(side_effect=execute_side_effect)
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = make_cursor()
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 8: 队列 CRUD 不变量
|
||||
# **Validates: Requirements 4.1, 4.4**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
config=_simple_config_st,
|
||||
site_id=_site_id_st,
|
||||
initial_count=st.integers(min_value=0, max_value=5),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_crud_invariant(mock_get_conn, config, site_id, initial_count):
|
||||
"""Property 8: 队列 CRUD 不变量。
|
||||
|
||||
入队一个任务后队列长度增加 1 且新任务状态为 pending;
|
||||
删除一个 pending 任务后队列长度减少 1 且该任务不再出现在队列中。
|
||||
"""
|
||||
queue = TaskQueue()
|
||||
db = InMemoryQueueDB(site_id)
|
||||
|
||||
# 预填充若干任务
|
||||
for i in range(initial_count):
|
||||
tid = str(uuid.uuid4())
|
||||
db.rows[tid] = {
|
||||
"id": tid,
|
||||
"site_id": site_id,
|
||||
"config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
|
||||
"status": "pending",
|
||||
"position": i + 1,
|
||||
}
|
||||
|
||||
before_count = len(db.pending_tasks)
|
||||
|
||||
# --- 入队 ---
|
||||
mock_get_conn.return_value = db.mock_enqueue_connection()
|
||||
new_id = queue.enqueue(config, site_id)
|
||||
|
||||
after_enqueue_count = len(db.pending_tasks)
|
||||
assert after_enqueue_count == before_count + 1, (
|
||||
f"入队后长度应 +1:期望 {before_count + 1},实际 {after_enqueue_count}"
|
||||
)
|
||||
assert new_id in db.rows, "新任务应存在于队列中"
|
||||
assert db.rows[new_id]["status"] == "pending", "新任务状态应为 pending"
|
||||
|
||||
# --- 删除刚入队的任务 ---
|
||||
mock_get_conn.return_value = db.mock_delete_connection(new_id)
|
||||
deleted = queue.delete(new_id, site_id)
|
||||
|
||||
after_delete_count = len(db.pending_tasks)
|
||||
assert deleted is True, "删除 pending 任务应返回 True"
|
||||
assert after_delete_count == before_count, (
|
||||
f"删除后长度应恢复:期望 {before_count},实际 {after_delete_count}"
|
||||
)
|
||||
assert new_id not in db.rows, "已删除任务不应出现在队列中"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 9: 队列出队顺序
|
||||
# **Validates: Requirements 4.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
num_tasks=st.integers(min_value=1, max_value=8),
|
||||
positions=st.data(),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_dequeue_order(mock_get_conn, site_id, num_tasks, positions):
|
||||
"""Property 9: 队列出队顺序。
|
||||
|
||||
包含多个 pending 任务的队列,dequeue 操作应返回 position 值最小的任务。
|
||||
"""
|
||||
queue = TaskQueue()
|
||||
db = InMemoryQueueDB(site_id)
|
||||
|
||||
# 生成不重复的 position 值
|
||||
pos_list = positions.draw(
|
||||
st.lists(
|
||||
st.integers(min_value=1, max_value=1000),
|
||||
min_size=num_tasks,
|
||||
max_size=num_tasks,
|
||||
unique=True,
|
||||
)
|
||||
)
|
||||
|
||||
# 填充队列
|
||||
task_ids = []
|
||||
for i, pos in enumerate(pos_list):
|
||||
tid = str(uuid.uuid4())
|
||||
task_ids.append(tid)
|
||||
db.rows[tid] = {
|
||||
"id": tid,
|
||||
"site_id": site_id,
|
||||
"config": {"tasks": [_task_codes[i % len(_task_codes)]], "pipeline": "api_ods"},
|
||||
"status": "pending",
|
||||
"position": pos,
|
||||
}
|
||||
|
||||
# 找出 position 最小的任务
|
||||
expected_first = min(db.pending_tasks, key=lambda r: r["position"])
|
||||
|
||||
# dequeue
|
||||
mock_get_conn.return_value = db.mock_dequeue_connection()
|
||||
result = queue.dequeue(site_id)
|
||||
|
||||
assert result is not None, "队列非空时 dequeue 不应返回 None"
|
||||
assert result.id == expected_first["id"], (
|
||||
f"应返回 position 最小的任务:期望 id={expected_first['id']} "
|
||||
f"(pos={expected_first['position']}),实际 id={result.id}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 10: 队列重排一致性
|
||||
# **Validates: Requirements 4.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
num_tasks=st.integers(min_value=2, max_value=6),
|
||||
data=st.data(),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_reorder_consistency(mock_get_conn, site_id, num_tasks, data):
|
||||
"""Property 10: 队列重排一致性。
|
||||
|
||||
重排操作(将任务移动到新位置)后,队列中任务的相对顺序应与请求一致:
|
||||
- 被移动的任务应出现在目标位置(clamp 到有效范围)
|
||||
- 其余任务保持原有相对顺序
|
||||
- 所有任务仍在队列中(不丢失)
|
||||
"""
|
||||
queue = TaskQueue()
|
||||
db = InMemoryQueueDB(site_id)
|
||||
|
||||
# 填充队列,position 从 1 开始连续编号
|
||||
task_ids = []
|
||||
for i in range(num_tasks):
|
||||
tid = str(uuid.uuid4())
|
||||
task_ids.append(tid)
|
||||
db.rows[tid] = {
|
||||
"id": tid,
|
||||
"site_id": site_id,
|
||||
"config": {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"},
|
||||
"status": "pending",
|
||||
"position": i + 1,
|
||||
}
|
||||
|
||||
# 随机选择要移动的任务和目标位置
|
||||
move_idx = data.draw(st.integers(min_value=0, max_value=num_tasks - 1))
|
||||
move_task_id = task_ids[move_idx]
|
||||
new_position = data.draw(st.integers(min_value=1, max_value=num_tasks + 2))
|
||||
|
||||
# 执行 reorder
|
||||
mock_get_conn.return_value = db.mock_reorder_connection()
|
||||
queue.reorder(move_task_id, new_position, site_id)
|
||||
|
||||
# 验证:所有任务仍在队列中
|
||||
remaining_ids = {r["id"] for r in db.rows.values() if r["status"] == "pending"}
|
||||
assert remaining_ids == set(task_ids), "重排后不应丢失任何任务"
|
||||
|
||||
# 验证:position 值连续且唯一(1-based)
|
||||
positions = sorted(r["position"] for r in db.pending_tasks)
|
||||
assert positions == list(range(1, num_tasks + 1)), (
|
||||
f"重排后 position 应为连续编号 1..{num_tasks},实际 {positions}"
|
||||
)
|
||||
|
||||
# 验证:被移动的任务在正确位置
|
||||
# reorder 内部逻辑:clamp new_position 到 [1, len(others)+1]
|
||||
clamped_pos = max(1, min(new_position, num_tasks))
|
||||
actual_pos = db.rows[move_task_id]["position"]
|
||||
assert actual_pos == clamped_pos, (
|
||||
f"被移动任务的 position 应为 {clamped_pos}(clamp 后),实际 {actual_pos}"
|
||||
)
|
||||
|
||||
# 验证:其余任务保持原有相对顺序
|
||||
others_before = [tid for tid in task_ids if tid != move_task_id]
|
||||
others_after = sorted(
|
||||
[r for r in db.pending_tasks if r["id"] != move_task_id],
|
||||
key=lambda r: r["position"],
|
||||
)
|
||||
others_after_ids = [r["id"] for r in others_after]
|
||||
assert others_after_ids == others_before, (
|
||||
"其余任务的相对顺序应保持不变"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 11: 执行历史排序与限制
|
||||
# **Validates: Requirements 4.5, 8.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 导入 FastAPI 测试客户端
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def _make_history_rows(count: int, site_id: int) -> list[tuple]:
|
||||
"""生成 count 条执行历史记录,started_at 随机但可排序。"""
|
||||
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
["ODS_MEMBER"], # task_codes
|
||||
"success", # status
|
||||
base_time + timedelta(hours=i), # started_at
|
||||
base_time + timedelta(hours=i, minutes=30), # finished_at
|
||||
0, # exit_code
|
||||
1800000, # duration_ms
|
||||
"python -m cli.main", # command
|
||||
None, # summary
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
total_records=st.integers(min_value=0, max_value=30),
|
||||
limit=st.integers(min_value=1, max_value=200),
|
||||
)
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_execution_history_sort_and_limit(mock_get_conn, site_id, total_records, limit):
|
||||
"""Property 11: 执行历史排序与限制。
|
||||
|
||||
执行历史记录集合,API 返回的结果应按 started_at 降序排列,
|
||||
且结果数量不超过请求的 limit 值。
|
||||
"""
|
||||
# 生成测试数据
|
||||
all_rows = _make_history_rows(total_records, site_id)
|
||||
|
||||
# 模拟数据库:按 started_at DESC 排序后取 limit 条
|
||||
sorted_rows = sorted(all_rows, key=lambda r: r[4], reverse=True)
|
||||
returned_rows = sorted_rows[:limit]
|
||||
|
||||
# mock 数据库连接
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = returned_rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
# 覆盖认证依赖
|
||||
test_user = CurrentUser(user_id=1, site_id=site_id)
|
||||
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||
|
||||
try:
|
||||
client = TestClient(app)
|
||||
# limit 必须在 [1, 200] 范围内(API 约束)
|
||||
clamped_limit = max(1, min(limit, 200))
|
||||
resp = client.get(f"/api/execution/history?limit={clamped_limit}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证 1:结果数量不超过 limit
|
||||
assert len(data) <= clamped_limit, (
|
||||
f"结果数量 {len(data)} 超过 limit {clamped_limit}"
|
||||
)
|
||||
|
||||
# 验证 2:结果数量不超过总记录数
|
||||
assert len(data) <= total_records, (
|
||||
f"结果数量 {len(data)} 超过总记录数 {total_records}"
|
||||
)
|
||||
|
||||
# 验证 3:按 started_at 降序排列
|
||||
if len(data) >= 2:
|
||||
for i in range(len(data) - 1):
|
||||
t1 = data[i]["started_at"]
|
||||
t2 = data[i + 1]["started_at"]
|
||||
assert t1 >= t2, (
|
||||
f"结果未按 started_at 降序排列:data[{i}]={t1} < data[{i+1}]={t2}"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
439
apps/backend/tests/test_schedule_properties.py
Normal file
439
apps/backend/tests/test_schedule_properties.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证调度管理的通用正确性属性:
|
||||
- Property 12: 调度任务 CRUD 往返
|
||||
- Property 13: 到期调度任务自动入队
|
||||
- Property 14: 调度任务启用/禁用状态
|
||||
|
||||
测试策略:
|
||||
- Property 12: 通过 mock 数据库,验证 POST 创建后 GET 返回的 schedule_config 与提交的一致
|
||||
- Property 13: 通过 mock 数据库返回到期任务,验证 check_and_enqueue 调用了 task_queue.enqueue
|
||||
- Property 14: 通过 mock 数据库,验证 toggle 端点的 next_run_at 行为
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-schedule-properties")
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.schemas.schedules import ScheduleConfigSchema
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.scheduler import Scheduler, calculate_next_run
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
_task_codes = ["ODS_MEMBER", "ODS_PAYMENT", "ODS_ORDER", "DWD_LOAD_FROM_ODS", "DWS_SUMMARY"]
|
||||
|
||||
_simple_task_config_st = st.fixed_dictionaries({
|
||||
"tasks": st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
"pipeline": st.sampled_from(["api_ods", "api_ods_dwd", "ods_dwd", "api_full"]),
|
||||
})
|
||||
|
||||
# 调度配置策略:覆盖 5 种调度类型
|
||||
_schedule_type_st = st.sampled_from(["once", "interval", "daily", "weekly", "cron"])
|
||||
|
||||
_interval_unit_st = st.sampled_from(["minutes", "hours", "days"])
|
||||
|
||||
# HH:MM 格式的时间字符串
|
||||
_time_str_st = st.builds(
|
||||
lambda h, m: f"{h:02d}:{m:02d}",
|
||||
h=st.integers(min_value=0, max_value=23),
|
||||
m=st.integers(min_value=0, max_value=59),
|
||||
)
|
||||
|
||||
# ISO weekday 列表(1=Monday ... 7=Sunday)
|
||||
_weekly_days_st = st.lists(
|
||||
st.integers(min_value=1, max_value=7),
|
||||
min_size=1, max_size=7, unique=True,
|
||||
)
|
||||
|
||||
# 简单 cron 表达式(minute hour * * *)
|
||||
_cron_st = st.builds(
|
||||
lambda m, h: f"{m} {h} * * *",
|
||||
m=st.integers(min_value=0, max_value=59),
|
||||
h=st.integers(min_value=0, max_value=23),
|
||||
)
|
||||
|
||||
|
||||
def _build_schedule_config(schedule_type, interval_value, interval_unit,
|
||||
daily_time, weekly_days, weekly_time, cron_expression):
|
||||
"""根据 schedule_type 构建 ScheduleConfigSchema。"""
|
||||
return ScheduleConfigSchema(
|
||||
schedule_type=schedule_type,
|
||||
interval_value=interval_value,
|
||||
interval_unit=interval_unit,
|
||||
daily_time=daily_time,
|
||||
weekly_days=weekly_days,
|
||||
weekly_time=weekly_time,
|
||||
cron_expression=cron_expression,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
_schedule_config_st = st.builds(
|
||||
_build_schedule_config,
|
||||
schedule_type=_schedule_type_st,
|
||||
interval_value=st.integers(min_value=1, max_value=168),
|
||||
interval_unit=_interval_unit_st,
|
||||
daily_time=_time_str_st,
|
||||
weekly_days=_weekly_days_st,
|
||||
weekly_time=_time_str_st,
|
||||
cron_expression=_cron_st,
|
||||
)
|
||||
|
||||
# 用于 Property 14 的非 once 调度配置(启用后 next_run_at 应非 NULL)
|
||||
_non_once_schedule_type_st = st.sampled_from(["interval", "daily", "weekly", "cron"])
|
||||
|
||||
_non_once_schedule_config_st = st.builds(
|
||||
_build_schedule_config,
|
||||
schedule_type=_non_once_schedule_type_st,
|
||||
interval_value=st.integers(min_value=1, max_value=168),
|
||||
interval_unit=_interval_unit_st,
|
||||
daily_time=_time_str_st,
|
||||
weekly_days=_weekly_days_st,
|
||||
weekly_time=_time_str_st,
|
||||
cron_expression=_cron_st,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NOW = datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# 模拟数据库行的列顺序(与 _SELECT_COLS 对应,共 13 列)
|
||||
# id, site_id, name, task_codes, task_config, schedule_config,
|
||||
# enabled, last_run_at, next_run_at, run_count, last_status,
|
||||
# created_at, updated_at
|
||||
|
||||
|
||||
def _make_db_row(
|
||||
schedule_id: str,
|
||||
site_id: int,
|
||||
name: str,
|
||||
task_codes: list[str],
|
||||
task_config: dict,
|
||||
schedule_config: dict,
|
||||
enabled: bool = True,
|
||||
next_run_at: datetime | None = None,
|
||||
) -> tuple:
|
||||
"""构造模拟数据库行。"""
|
||||
return (
|
||||
schedule_id, site_id, name, task_codes,
|
||||
json.dumps(task_config) if isinstance(task_config, dict) else task_config,
|
||||
json.dumps(schedule_config) if isinstance(schedule_config, dict) else schedule_config,
|
||||
enabled, None, next_run_at, 0, None, _NOW, _NOW,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 12: 调度任务 CRUD 往返
|
||||
# **Validates: Requirements 5.1, 5.4**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
schedule_config=_schedule_config_st,
|
||||
task_config=_simple_task_config_st,
|
||||
name=st.text(min_size=1, max_size=50, alphabet=st.characters(
|
||||
whitelist_categories=("L", "N"), whitelist_characters="_- "
|
||||
)),
|
||||
task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_schedule_crud_round_trip(
|
||||
mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
|
||||
):
|
||||
"""Property 12: 调度任务 CRUD 往返。
|
||||
|
||||
有效的 ScheduleConfigSchema,创建调度任务后再查询该任务,
|
||||
返回的调度配置应与创建时提交的配置等价。
|
||||
"""
|
||||
schedule_config_dict = schedule_config.model_dump()
|
||||
next_run = calculate_next_run(schedule_config, _NOW)
|
||||
|
||||
# 构造创建后数据库返回的行
|
||||
created_row = _make_db_row(
|
||||
schedule_id="test-sched-id",
|
||||
site_id=site_id,
|
||||
name=name,
|
||||
task_codes=task_codes,
|
||||
task_config=task_config,
|
||||
schedule_config=schedule_config_dict,
|
||||
enabled=schedule_config.enabled,
|
||||
next_run_at=next_run,
|
||||
)
|
||||
|
||||
# --- 创建阶段 ---
|
||||
# mock POST 的数据库连接(INSERT ... RETURNING)
|
||||
create_cursor = MagicMock()
|
||||
create_cursor.fetchone.return_value = created_row
|
||||
create_conn = MagicMock()
|
||||
create_conn.cursor.return_value.__enter__ = MagicMock(return_value=create_cursor)
|
||||
create_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# --- 查询阶段 ---
|
||||
# mock GET 的数据库连接(SELECT ... fetchall)
|
||||
list_cursor = MagicMock()
|
||||
list_cursor.fetchall.return_value = [created_row]
|
||||
list_conn = MagicMock()
|
||||
list_conn.cursor.return_value.__enter__ = MagicMock(return_value=list_cursor)
|
||||
list_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# 依次返回 create_conn 和 list_conn
|
||||
mock_get_conn.side_effect = [create_conn, list_conn]
|
||||
|
||||
# 覆盖认证
|
||||
test_user = CurrentUser(user_id=1, site_id=site_id)
|
||||
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||
|
||||
try:
|
||||
client = TestClient(app)
|
||||
|
||||
# 创建调度任务
|
||||
create_body = {
|
||||
"name": name,
|
||||
"task_codes": task_codes,
|
||||
"task_config": task_config,
|
||||
"schedule_config": schedule_config_dict,
|
||||
}
|
||||
create_resp = client.post("/api/schedules", json=create_body)
|
||||
assert create_resp.status_code == 201, (
|
||||
f"创建应返回 201,实际 {create_resp.status_code}: {create_resp.text}"
|
||||
)
|
||||
created_data = create_resp.json()
|
||||
|
||||
# 查询调度任务列表
|
||||
list_resp = client.get("/api/schedules")
|
||||
assert list_resp.status_code == 200
|
||||
list_data = list_resp.json()
|
||||
|
||||
assert len(list_data) >= 1, "查询结果应至少包含刚创建的任务"
|
||||
|
||||
# 找到刚创建的任务
|
||||
found = next((s for s in list_data if s["id"] == created_data["id"]), None)
|
||||
assert found is not None, "查询结果应包含刚创建的任务"
|
||||
|
||||
# 核心验证:schedule_config 往返一致
|
||||
returned_config = found["schedule_config"]
|
||||
for key in schedule_config_dict:
|
||||
assert returned_config[key] == schedule_config_dict[key], (
|
||||
f"schedule_config.{key} 不一致:"
|
||||
f"提交={schedule_config_dict[key]},返回={returned_config[key]}"
|
||||
)
|
||||
|
||||
# 验证 task_config 往返一致
|
||||
returned_task_config = found["task_config"]
|
||||
for key in task_config:
|
||||
assert returned_task_config[key] == task_config[key], (
|
||||
f"task_config.{key} 不一致:提交={task_config[key]},返回={returned_task_config[key]}"
|
||||
)
|
||||
|
||||
# 验证基本字段
|
||||
assert found["name"] == name
|
||||
assert found["task_codes"] == task_codes
|
||||
assert found["site_id"] == site_id
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 13: 到期调度任务自动入队
|
||||
# **Validates: Requirements 5.2**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
schedule_config=_schedule_config_st,
|
||||
task_config=_simple_task_config_st,
|
||||
)
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
def test_due_schedule_auto_enqueue(
|
||||
mock_get_conn, mock_tq, site_id, schedule_config, task_config,
|
||||
):
|
||||
"""Property 13: 到期调度任务自动入队。
|
||||
|
||||
enabled 为 true 且 next_run_at 早于当前时间的调度任务,
|
||||
check_and_enqueue 执行后该任务的 TaskConfig 应出现在执行队列中。
|
||||
"""
|
||||
sched = Scheduler()
|
||||
schedule_config_dict = schedule_config.model_dump()
|
||||
|
||||
# 构造到期任务:next_run_at 在过去(比 now 早 5 分钟)
|
||||
task_id = "due-task-001"
|
||||
|
||||
# --- mock SELECT 到期任务 ---
|
||||
select_cursor = MagicMock()
|
||||
select_cursor.fetchall.return_value = [
|
||||
(task_id, site_id, json.dumps(task_config), json.dumps(schedule_config_dict)),
|
||||
]
|
||||
select_cursor.__enter__ = MagicMock(return_value=select_cursor)
|
||||
select_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# --- mock UPDATE 调度状态 ---
|
||||
update_cursor = MagicMock()
|
||||
update_cursor.__enter__ = MagicMock(return_value=update_cursor)
|
||||
update_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
conn = MagicMock()
|
||||
conn.cursor.side_effect = [select_cursor, update_cursor]
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
mock_tq.enqueue.return_value = "queue-id-123"
|
||||
|
||||
# 执行
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
# 验证:到期任务被入队
|
||||
assert count == 1, f"应有 1 个任务入队,实际 {count}"
|
||||
mock_tq.enqueue.assert_called_once()
|
||||
|
||||
# 验证入队参数
|
||||
call_args = mock_tq.enqueue.call_args
|
||||
enqueued_config = call_args[0][0]
|
||||
enqueued_site_id = call_args[0][1]
|
||||
|
||||
# site_id 应匹配
|
||||
assert enqueued_site_id == site_id, (
|
||||
f"入队的 site_id 应为 {site_id},实际 {enqueued_site_id}"
|
||||
)
|
||||
|
||||
# TaskConfig 应与原始配置一致
|
||||
assert isinstance(enqueued_config, TaskConfigSchema)
|
||||
assert enqueued_config.tasks == task_config["tasks"], (
|
||||
f"入队的 tasks 应为 {task_config['tasks']},实际 {enqueued_config.tasks}"
|
||||
)
|
||||
assert enqueued_config.pipeline == task_config["pipeline"], (
|
||||
f"入队的 pipeline 应为 {task_config['pipeline']},实际 {enqueued_config.pipeline}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 14: 调度任务启用/禁用状态
|
||||
# **Validates: Requirements 5.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id=_site_id_st,
|
||||
schedule_config=_non_once_schedule_config_st,
|
||||
task_config=_simple_task_config_st,
|
||||
name=st.text(min_size=1, max_size=30, alphabet=st.characters(
|
||||
whitelist_categories=("L", "N"), whitelist_characters="_- "
|
||||
)),
|
||||
task_codes=st.lists(st.sampled_from(_task_codes), min_size=1, max_size=3, unique=True),
|
||||
)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_schedule_toggle_next_run(
|
||||
mock_get_conn, site_id, schedule_config, task_config, name, task_codes,
|
||||
):
|
||||
"""Property 14: 调度任务启用/禁用状态。
|
||||
|
||||
禁用后 next_run_at 应为 NULL;
|
||||
重新启用后 next_run_at 应被重新计算为非 NULL 值(对于非一次性调度)。
|
||||
"""
|
||||
schedule_config_dict = schedule_config.model_dump()
|
||||
next_run_enabled = calculate_next_run(schedule_config, _NOW)
|
||||
|
||||
# --- 第一步:禁用(enabled=True → False)---
|
||||
# toggle 端点先 SELECT 当前状态,再 UPDATE RETURNING
|
||||
|
||||
# 禁用后的数据库行
|
||||
disabled_row = _make_db_row(
|
||||
schedule_id="sched-toggle-1",
|
||||
site_id=site_id,
|
||||
name=name,
|
||||
task_codes=task_codes,
|
||||
task_config=task_config,
|
||||
schedule_config=schedule_config_dict,
|
||||
enabled=False,
|
||||
next_run_at=None, # 禁用后 next_run_at 为 NULL
|
||||
)
|
||||
|
||||
# mock 禁用操作的数据库连接
|
||||
disable_cursor = MagicMock()
|
||||
disable_cursor.fetchone.side_effect = [
|
||||
(True, json.dumps(schedule_config_dict)), # SELECT 当前状态(enabled=True)
|
||||
disabled_row, # UPDATE RETURNING
|
||||
]
|
||||
disable_conn = MagicMock()
|
||||
disable_conn.cursor.return_value.__enter__ = MagicMock(return_value=disable_cursor)
|
||||
disable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# --- 第二步:启用(enabled=False → True)---
|
||||
enabled_row = _make_db_row(
|
||||
schedule_id="sched-toggle-1",
|
||||
site_id=site_id,
|
||||
name=name,
|
||||
task_codes=task_codes,
|
||||
task_config=task_config,
|
||||
schedule_config=schedule_config_dict,
|
||||
enabled=True,
|
||||
next_run_at=next_run_enabled, # 启用后 next_run_at 被重新计算
|
||||
)
|
||||
|
||||
enable_cursor = MagicMock()
|
||||
enable_cursor.fetchone.side_effect = [
|
||||
(False, json.dumps(schedule_config_dict)), # SELECT 当前状态(enabled=False)
|
||||
enabled_row, # UPDATE RETURNING
|
||||
]
|
||||
enable_conn = MagicMock()
|
||||
enable_conn.cursor.return_value.__enter__ = MagicMock(return_value=enable_cursor)
|
||||
enable_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# 依次返回两个连接
|
||||
mock_get_conn.side_effect = [disable_conn, enable_conn]
|
||||
|
||||
# 覆盖认证
|
||||
test_user = CurrentUser(user_id=1, site_id=site_id)
|
||||
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||
|
||||
try:
|
||||
client = TestClient(app)
|
||||
|
||||
# 禁用
|
||||
disable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
|
||||
assert disable_resp.status_code == 200, (
|
||||
f"禁用应返回 200,实际 {disable_resp.status_code}: {disable_resp.text}"
|
||||
)
|
||||
disable_data = disable_resp.json()
|
||||
|
||||
# 验证:禁用后 enabled=False,next_run_at=NULL
|
||||
assert disable_data["enabled"] is False, "禁用后 enabled 应为 False"
|
||||
assert disable_data["next_run_at"] is None, "禁用后 next_run_at 应为 NULL"
|
||||
|
||||
# 启用
|
||||
enable_resp = client.patch("/api/schedules/sched-toggle-1/toggle")
|
||||
assert enable_resp.status_code == 200, (
|
||||
f"启用应返回 200,实际 {enable_resp.status_code}: {enable_resp.text}"
|
||||
)
|
||||
enable_data = enable_resp.json()
|
||||
|
||||
# 验证:启用后 enabled=True,next_run_at 非 NULL(非一次性调度)
|
||||
assert enable_data["enabled"] is True, "启用后 enabled 应为 True"
|
||||
assert enable_data["next_run_at"] is not None, (
|
||||
"非一次性调度启用后 next_run_at 应被重新计算为非 NULL 值"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(user_id=1, site_id=100)
|
||||
384
apps/backend/tests/test_scheduler.py
Normal file
384
apps/backend/tests/test_scheduler.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Scheduler 单元测试
|
||||
|
||||
覆盖:
|
||||
- calculate_next_run:各种调度类型的下次执行时间计算
|
||||
- _parse_simple_cron:简单 cron 表达式解析
|
||||
- check_and_enqueue:到期检查与入队逻辑
|
||||
- start / stop:后台循环生命周期
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.schedules import ScheduleConfigSchema
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.scheduler import (
|
||||
Scheduler,
|
||||
calculate_next_run,
|
||||
_parse_simple_cron,
|
||||
_parse_time,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def sched() -> Scheduler:
|
||||
return Scheduler()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def now() -> datetime:
|
||||
"""固定时间点:2025-06-10 10:00:00 UTC(周二)"""
|
||||
return datetime(2025, 6, 10, 10, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
|
||||
cur = MagicMock()
|
||||
cur.fetchone.return_value = fetchone_val
|
||||
cur.fetchall.return_value = fetchall_val or []
|
||||
cur.rowcount = rowcount
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
|
||||
def _mock_conn(cursor):
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_time
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseTime:
|
||||
def test_standard_format(self):
|
||||
assert _parse_time("04:00") == (4, 0)
|
||||
|
||||
def test_with_minutes(self):
|
||||
assert _parse_time("23:45") == (23, 45)
|
||||
|
||||
def test_midnight(self):
|
||||
assert _parse_time("00:00") == (0, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunOnce:
|
||||
def test_once_returns_none(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="once")
|
||||
assert calculate_next_run(cfg, now) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — interval
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunInterval:
|
||||
def test_interval_minutes(self, now):
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="interval", interval_value=15, interval_unit="minutes",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
assert result == now + timedelta(minutes=15)
|
||||
|
||||
def test_interval_hours(self, now):
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="interval", interval_value=2, interval_unit="hours",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
assert result == now + timedelta(hours=2)
|
||||
|
||||
def test_interval_days(self, now):
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="interval", interval_value=3, interval_unit="days",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
assert result == now + timedelta(days=3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — daily
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunDaily:
|
||||
def test_daily_next_day(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="04:00")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_daily_custom_time(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="daily", daily_time="18:30")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 11, 18, 30, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — weekly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunWeekly:
|
||||
def test_weekly_later_this_week(self, now):
|
||||
# now 是周二(2),weekly_days=[5] 周五 → 3 天后
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="weekly", weekly_days=[5], weekly_time="08:00",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_weekly_next_week(self, now):
|
||||
# now 是周二(2),weekly_days=[1] 周一 → 下周一(6天后)
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="weekly", weekly_days=[1], weekly_time="04:00",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 16, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_weekly_multiple_days_picks_next(self, now):
|
||||
# now 是周二(2),weekly_days=[1, 4, 6] → 周四(4),2 天后
|
||||
cfg = ScheduleConfigSchema(
|
||||
schedule_type="weekly", weekly_days=[1, 4, 6], weekly_time="09:00",
|
||||
)
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 12, 9, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# calculate_next_run — cron
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNextRunCron:
|
||||
def test_cron_daily(self, now):
|
||||
cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="30 4 * * *")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 11, 4, 30, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_cron_with_dow(self, now):
|
||||
# "0 8 * * 5" → 每周五 08:00,now 是周二 → 周五(3天后)
|
||||
cfg = ScheduleConfigSchema(schedule_type="cron", cron_expression="0 8 * * 5")
|
||||
result = calculate_next_run(cfg, now)
|
||||
expected = datetime(2025, 6, 13, 8, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_simple_cron
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseSimpleCron:
|
||||
def test_daily_cron(self, now):
|
||||
result = _parse_simple_cron("0 4 * * *", now)
|
||||
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_invalid_field_count_fallback(self, now):
|
||||
# 字段数不对,回退到明天 04:00
|
||||
result = _parse_simple_cron("0 4 *", now)
|
||||
expected = datetime(2025, 6, 11, 4, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_wildcard_hour_minute(self, now):
|
||||
# "* * * * *" → hour=0, minute=0,明天 00:00
|
||||
result = _parse_simple_cron("* * * * *", now)
|
||||
expected = datetime(2025, 6, 11, 0, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_dow_sunday(self, now):
|
||||
# "0 6 * * 0" → 每周日 06:00,now 是周二 → 周日(5天后)
|
||||
result = _parse_simple_cron("0 6 * * 0", now)
|
||||
expected = datetime(2025, 6, 15, 6, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_dow_same_day_future_time(self):
|
||||
# 周二 08:00,cron 指定周二 12:00 → 当天
|
||||
now = datetime(2025, 6, 10, 8, 0, 0, tzinfo=timezone.utc)
|
||||
result = _parse_simple_cron("0 12 * * 2", now)
|
||||
expected = datetime(2025, 6, 10, 12, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
def test_dow_same_day_past_time(self):
|
||||
# 周二 14:00,cron 指定周二 12:00 → 下周二
|
||||
now = datetime(2025, 6, 10, 14, 0, 0, tzinfo=timezone.utc)
|
||||
result = _parse_simple_cron("0 12 * * 2", now)
|
||||
expected = datetime(2025, 6, 17, 12, 0, 0, tzinfo=timezone.utc)
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_and_enqueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckAndEnqueue:
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_enqueues_due_tasks(self, mock_tq, mock_get_conn, sched):
|
||||
"""到期任务应被入队,且更新 last_run_at / run_count / next_run_at"""
|
||||
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
|
||||
schedule_config = {
|
||||
"schedule_type": "interval",
|
||||
"interval_value": 1,
|
||||
"interval_unit": "hours",
|
||||
}
|
||||
|
||||
# 第一次 cursor:SELECT 到期任务
|
||||
select_cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
# 第二次 cursor:UPDATE
|
||||
update_cur = _mock_cursor()
|
||||
|
||||
conn = MagicMock()
|
||||
# cursor() 依次返回 select_cur 和 update_cur
|
||||
conn.cursor.side_effect = [select_cur, update_cur]
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
mock_tq.enqueue.return_value = "queue-id-1"
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 1
|
||||
mock_tq.enqueue.assert_called_once()
|
||||
# 验证 enqueue 的参数
|
||||
call_args = mock_tq.enqueue.call_args
|
||||
assert call_args[0][1] == 42 # site_id
|
||||
assert isinstance(call_args[0][0], TaskConfigSchema)
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_no_due_tasks(self, mock_tq, mock_get_conn, sched):
|
||||
"""没有到期任务时,不入队"""
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 0
|
||||
mock_tq.enqueue.assert_not_called()
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_skips_invalid_config(self, mock_tq, mock_get_conn, sched):
|
||||
"""配置反序列化失败的任务应被跳过"""
|
||||
# task_config 缺少必填字段 tasks
|
||||
bad_config = {"pipeline": "api_ods_dwd"}
|
||||
schedule_config = {"schedule_type": "once"}
|
||||
|
||||
cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-uuid-bad", 42, json.dumps(bad_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 0
|
||||
mock_tq.enqueue.assert_not_called()
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_enqueue_failure_continues(self, mock_tq, mock_get_conn, sched):
|
||||
"""入队失败时应跳过该任务,继续处理后续任务"""
|
||||
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
|
||||
schedule_config = {"schedule_type": "once"}
|
||||
|
||||
cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
("task-2", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
# 需要额外的 cursor 给 UPDATE 用
|
||||
update_cur = _mock_cursor()
|
||||
conn = MagicMock()
|
||||
conn.cursor.side_effect = [cur, update_cur]
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
# 第一次入队失败,第二次成功
|
||||
mock_tq.enqueue.side_effect = [Exception("DB error"), "queue-id-2"]
|
||||
|
||||
count = sched.check_and_enqueue()
|
||||
|
||||
assert count == 1
|
||||
assert mock_tq.enqueue.call_count == 2
|
||||
|
||||
@patch("app.services.scheduler.get_connection")
|
||||
@patch("app.services.scheduler.task_queue")
|
||||
def test_once_type_sets_next_run_none(self, mock_tq, mock_get_conn, sched):
|
||||
"""once 类型任务入队后,next_run_at 应被设为 NULL"""
|
||||
task_config = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods_dwd"}
|
||||
schedule_config = {"schedule_type": "once"}
|
||||
|
||||
select_cur = _mock_cursor(
|
||||
fetchall_val=[
|
||||
("task-uuid-1", 42, json.dumps(task_config), json.dumps(schedule_config)),
|
||||
]
|
||||
)
|
||||
update_cur = _mock_cursor()
|
||||
conn = MagicMock()
|
||||
conn.cursor.side_effect = [select_cur, update_cur]
|
||||
mock_get_conn.return_value = conn
|
||||
mock_tq.enqueue.return_value = "queue-id-1"
|
||||
|
||||
sched.check_and_enqueue()
|
||||
|
||||
# 验证 UPDATE 语句中 next_run_at 参数为 None
|
||||
update_call = update_cur.__enter__().execute.call_args
|
||||
# 参数元组的第一个元素是 next_run_at
|
||||
assert update_call[0][1][0] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start / stop 生命周期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sets_running_false(self, sched):
|
||||
sched._running = True
|
||||
await sched.stop()
|
||||
assert sched._running is False
|
||||
assert sched._loop_task is None
|
||||
|
||||
def test_start_creates_task(self, sched):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
# 在事件循环中启动
|
||||
async def _run():
|
||||
sched.start()
|
||||
assert sched._loop_task is not None
|
||||
assert not sched._loop_task.done()
|
||||
await sched.stop()
|
||||
|
||||
loop.run_until_complete(_run())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_stop_idempotent(self, sched):
|
||||
"""多次 stop 不应报错"""
|
||||
await sched.stop()
|
||||
await sched.stop()
|
||||
assert sched._loop_task is None
|
||||
310
apps/backend/tests/test_schedules_router.py
Normal file
310
apps/backend/tests/test_schedules_router.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""调度路由单元测试
|
||||
|
||||
覆盖 5 个端点:list / create / update / delete / toggle
|
||||
通过 mock 绕过数据库,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
_NEXT = datetime(2024, 6, 2, 4, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
_SCHEDULE_CONFIG = {
|
||||
"schedule_type": "daily",
|
||||
"daily_time": "04:00",
|
||||
}
|
||||
|
||||
_VALID_CREATE = {
|
||||
"name": "每日全量同步",
|
||||
"task_codes": ["ODS_MEMBER", "ODS_ORDER"],
|
||||
"task_config": {"tasks": ["ODS_MEMBER", "ODS_ORDER"], "pipeline": "api_ods"},
|
||||
"schedule_config": _SCHEDULE_CONFIG,
|
||||
}
|
||||
|
||||
# 模拟数据库返回的完整行(13 列,与 _SELECT_COLS 对应)
|
||||
_DB_ROW = (
|
||||
"sched-1", 100, "每日全量同步", ["ODS_MEMBER", "ODS_ORDER"],
|
||||
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}),
|
||||
json.dumps(_SCHEDULE_CONFIG),
|
||||
True, None, _NEXT, 0, None, _NOW, _NOW,
|
||||
)
|
||||
|
||||
|
||||
def _mock_conn_with_fetchall(rows):
|
||||
"""构造返回 fetchall 的 mock 连接。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cursor
|
||||
|
||||
|
||||
def _mock_conn_with_fetchone(row):
|
||||
"""构造返回 fetchone 的 mock 连接。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cursor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListSchedules:
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_list_returns_schedules(self, mock_get_conn):
|
||||
mock_conn, _ = _mock_conn_with_fetchall([_DB_ROW])
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/schedules")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == "sched-1"
|
||||
assert data[0]["name"] == "每日全量同步"
|
||||
assert data[0]["site_id"] == 100
|
||||
assert data[0]["enabled"] is True
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_list_empty(self, mock_get_conn):
|
||||
mock_conn, _ = _mock_conn_with_fetchall([])
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.get("/api/schedules")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_list_filters_by_site_id(self, mock_get_conn):
|
||||
mock_conn, mock_cursor = _mock_conn_with_fetchall([])
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
client.get("/api/schedules")
|
||||
call_args = mock_cursor.execute.call_args
|
||||
assert call_args[0][1] == (100,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCreateSchedule:
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_create_returns_201(self, mock_get_conn, mock_calc):
|
||||
mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.post("/api/schedules", json=_VALID_CREATE)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "sched-1"
|
||||
assert data["name"] == "每日全量同步"
|
||||
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_create_injects_site_id(self, mock_get_conn, mock_calc):
|
||||
mock_conn, mock_cursor = _mock_conn_with_fetchone(_DB_ROW)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
client.post("/api/schedules", json=_VALID_CREATE)
|
||||
# INSERT 的第一个参数应为 site_id=100
|
||||
insert_params = mock_cursor.execute.call_args[0][1]
|
||||
assert insert_params[0] == 100
|
||||
|
||||
def test_create_missing_name_returns_422(self):
|
||||
body = {**_VALID_CREATE}
|
||||
del body["name"]
|
||||
resp = client.post("/api/schedules", json=body)
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_create_invalid_schedule_type_returns_422(self):
|
||||
body = {**_VALID_CREATE, "schedule_config": {"schedule_type": "invalid"}}
|
||||
resp = client.post("/api/schedules", json=body)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/schedules/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateSchedule:
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_update_name(self, mock_get_conn):
|
||||
updated_row = list(_DB_ROW)
|
||||
updated_row[2] = "新名称"
|
||||
mock_conn, _ = _mock_conn_with_fetchone(tuple(updated_row))
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.put("/api/schedules/sched-1", json={"name": "新名称"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "新名称"
|
||||
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_update_schedule_config_recalculates_next_run(self, mock_get_conn, mock_calc):
|
||||
mock_conn, _ = _mock_conn_with_fetchone(_DB_ROW)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.put("/api/schedules/sched-1", json={
|
||||
"schedule_config": {"schedule_type": "interval", "interval_value": 2, "interval_unit": "hours"},
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
mock_calc.assert_called_once()
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_update_not_found(self, mock_get_conn):
|
||||
mock_conn, _ = _mock_conn_with_fetchone(None)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.put("/api/schedules/nonexistent", json={"name": "x"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_update_empty_body_returns_422(self):
|
||||
resp = client.put("/api/schedules/sched-1", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/schedules/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteSchedule:
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_delete_success(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.rowcount = 1
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.delete("/api/schedules/sched-1")
|
||||
assert resp.status_code == 200
|
||||
assert "已删除" in resp.json()["message"]
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_delete_not_found(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.rowcount = 0
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.delete("/api/schedules/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /api/schedules/{id}/toggle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToggleSchedule:
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_toggle_disable(self, mock_get_conn, mock_calc):
|
||||
"""启用 → 禁用:next_run_at 应置 NULL"""
|
||||
# 第一次 fetchone 返回当前状态(enabled=True)
|
||||
# 第二次 fetchone 返回更新后的行
|
||||
disabled_row = list(_DB_ROW)
|
||||
disabled_row[6] = False # enabled
|
||||
disabled_row[8] = None # next_run_at
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
(True, json.dumps(_SCHEDULE_CONFIG)), # SELECT 当前状态
|
||||
tuple(disabled_row), # UPDATE RETURNING
|
||||
]
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.patch("/api/schedules/sched-1/toggle")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["enabled"] is False
|
||||
assert data["next_run_at"] is None
|
||||
|
||||
@patch("app.routers.schedules.calculate_next_run", return_value=_NEXT)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_toggle_enable(self, mock_get_conn, mock_calc):
|
||||
"""禁用 → 启用:next_run_at 应被重新计算"""
|
||||
enabled_row = list(_DB_ROW)
|
||||
enabled_row[6] = True
|
||||
enabled_row[8] = _NEXT
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.side_effect = [
|
||||
(False, json.dumps(_SCHEDULE_CONFIG)), # SELECT 当前状态(disabled)
|
||||
tuple(enabled_row), # UPDATE RETURNING
|
||||
]
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.patch("/api/schedules/sched-1/toggle")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["enabled"] is True
|
||||
assert data["next_run_at"] is not None
|
||||
mock_calc.assert_called_once()
|
||||
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_toggle_not_found(self, mock_get_conn):
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = None
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
resp = client.patch("/api/schedules/nonexistent/toggle")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSchedulesAuth:
|
||||
def test_requires_auth(self):
|
||||
"""移除认证覆盖后,所有端点应返回 401/403"""
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
assert client.get("/api/schedules").status_code in (401, 403)
|
||||
assert client.post("/api/schedules", json=_VALID_CREATE).status_code in (401, 403)
|
||||
assert client.put("/api/schedules/x", json={"name": "x"}).status_code in (401, 403)
|
||||
assert client.delete("/api/schedules/x").status_code in (401, 403)
|
||||
assert client.patch("/api/schedules/x/toggle").status_code in (401, 403)
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
336
apps/backend/tests/test_site_isolation_properties.py
Normal file
336
apps/backend/tests/test_site_isolation_properties.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""门店隔离属性测试(Property-Based Testing)。
|
||||
|
||||
Property 20: 对于任意两个不同 site_id 的 Operator,一个 Operator 查询
|
||||
队列/调度/执行历史时,结果中不应包含另一个 site_id 的数据。
|
||||
|
||||
Validates: Requirements 1.3
|
||||
|
||||
测试策略:
|
||||
- 通过 mock 数据库交互,验证 API 路由在不同 site_id 下的数据隔离
|
||||
- 队列隔离:为 site_id_a 入队任务,用 site_id_b 的 JWT 查询队列,结果应为空
|
||||
- 调度隔离:为 site_id_a 创建调度任务,用 site_id_b 的 JWT 查询调度列表,结果应为空
|
||||
- 执行历史隔离:site_id_a 的执行历史,用 site_id_b 的 JWT 查询不到
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-isolation")
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 通用策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_site_id_st = st.integers(min_value=1, max_value=2**31 - 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mock_user(site_id: int) -> CurrentUser:
|
||||
"""构造指定 site_id 的 mock 用户。"""
|
||||
return CurrentUser(user_id=1, site_id=site_id)
|
||||
|
||||
|
||||
def _make_queue_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的队列行。"""
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}), # config
|
||||
"pending", # status
|
||||
i + 1, # position
|
||||
datetime(2024, 1, 1, tzinfo=timezone.utc), # created_at
|
||||
None, # started_at
|
||||
None, # finished_at
|
||||
None, # exit_code
|
||||
None, # error_message
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _make_schedule_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的调度行。"""
|
||||
now = datetime.now(timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
f"调度任务_{i}", # name
|
||||
["ODS_MEMBER"], # task_codes
|
||||
{"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}, # task_config
|
||||
{"schedule_type": "daily", "daily_time": "04:00", # schedule_config
|
||||
"interval_value": 1, "interval_unit": "hours",
|
||||
"weekly_days": [1], "weekly_time": "04:00",
|
||||
"cron_expression": "0 4 * * *", "enabled": True,
|
||||
"start_date": None, "end_date": None},
|
||||
True, # enabled
|
||||
None, # last_run_at
|
||||
now + timedelta(hours=1), # next_run_at
|
||||
0, # run_count
|
||||
None, # last_status
|
||||
now, # created_at
|
||||
now, # updated_at
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _make_history_rows(site_id: int, count: int) -> list[tuple]:
|
||||
"""生成 count 条属于 site_id 的执行历史行。"""
|
||||
base_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
rows = []
|
||||
for i in range(count):
|
||||
rows.append((
|
||||
str(uuid.uuid4()), # id
|
||||
site_id, # site_id
|
||||
["ODS_MEMBER"], # task_codes
|
||||
"success", # status
|
||||
base_time + timedelta(hours=i), # started_at
|
||||
base_time + timedelta(hours=i, minutes=30), # finished_at
|
||||
0, # exit_code
|
||||
1800000, # duration_ms
|
||||
"python -m cli.main", # command
|
||||
None, # summary
|
||||
))
|
||||
return rows
|
||||
|
||||
|
||||
def _mock_conn_returning(rows: list[tuple]) -> MagicMock:
|
||||
"""构造一个 mock connection,其 cursor.fetchall 返回指定行。"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = rows
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.1: 队列隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id_a=_site_id_st,
|
||||
site_id_b=_site_id_st,
|
||||
queue_count=st.integers(min_value=1, max_value=5),
|
||||
)
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_queue_isolation(mock_get_conn, site_id_a, site_id_b, queue_count):
|
||||
"""Property 20.1: 队列隔离。
|
||||
|
||||
为 site_id_a 入队若干任务后,用 site_id_b 的身份查询队列,
|
||||
结果应为空——不同门店的队列数据互不可见。
|
||||
"""
|
||||
assume(site_id_a != site_id_b)
|
||||
|
||||
# site_id_a 的队列数据
|
||||
rows_a = _make_queue_rows(site_id_a, queue_count)
|
||||
|
||||
# 核心隔离逻辑:根据查询时传入的 site_id 过滤
|
||||
# list_pending 内部 SQL: WHERE site_id = %s AND status = 'pending'
|
||||
def conn_for_site(querying_site_id):
|
||||
"""模拟数据库行为:只返回匹配 site_id 的行。"""
|
||||
if querying_site_id == site_id_a:
|
||||
return rows_a
|
||||
return [] # site_id_b 查不到 site_id_a 的数据
|
||||
|
||||
captured_params = {}
|
||||
|
||||
def make_mock_conn():
|
||||
mock_cursor = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if params:
|
||||
captured_params["site_id"] = params[0]
|
||||
# 根据 SQL 中的 site_id 参数返回对应数据
|
||||
mock_cursor.fetchall.return_value = conn_for_site(params[0])
|
||||
|
||||
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
return mock_conn
|
||||
|
||||
mock_get_conn.return_value = make_mock_conn()
|
||||
|
||||
# 用 site_id_b 的身份查询队列
|
||||
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/execution/queue")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证:site_id_b 查不到 site_id_a 的任何数据
|
||||
assert len(data) == 0, (
|
||||
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的队列数据,"
|
||||
f"但返回了 {len(data)} 条记录"
|
||||
)
|
||||
|
||||
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
|
||||
for item in data:
|
||||
assert item.get("site_id") != site_id_a, (
|
||||
f"结果中不应包含 site_id_a={site_id_a} 的数据"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.2: 调度隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
site_id_a=_site_id_st,
|
||||
site_id_b=_site_id_st,
|
||||
schedule_count=st.integers(min_value=1, max_value=5),
|
||||
)
|
||||
@patch("app.routers.schedules.get_connection")
|
||||
def test_schedule_isolation(mock_get_conn, site_id_a, site_id_b, schedule_count):
|
||||
"""Property 20.2: 调度隔离。
|
||||
|
||||
为 site_id_a 创建若干调度任务后,用 site_id_b 的身份查询调度列表,
|
||||
结果应为空——不同门店的调度数据互不可见。
|
||||
"""
|
||||
assume(site_id_a != site_id_b)
|
||||
|
||||
# site_id_a 的调度数据
|
||||
rows_a = _make_schedule_rows(site_id_a, schedule_count)
|
||||
|
||||
def make_mock_conn():
|
||||
mock_cursor = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if params:
|
||||
querying_site_id = params[0]
|
||||
# 只返回匹配 site_id 的行
|
||||
if querying_site_id == site_id_a:
|
||||
mock_cursor.fetchall.return_value = rows_a
|
||||
else:
|
||||
mock_cursor.fetchall.return_value = []
|
||||
|
||||
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
return mock_conn
|
||||
|
||||
mock_get_conn.return_value = make_mock_conn()
|
||||
|
||||
# 用 site_id_b 的身份查询调度列表
|
||||
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/schedules")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证:site_id_b 查不到 site_id_a 的任何调度数据
|
||||
assert len(data) == 0, (
|
||||
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的调度数据,"
|
||||
f"但返回了 {len(data)} 条记录"
|
||||
)
|
||||
|
||||
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
|
||||
for item in data:
|
||||
assert item.get("site_id") != site_id_a, (
|
||||
f"结果中不应包含 site_id_a={site_id_a} 的调度数据"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 20.3: 执行历史隔离
|
||||
# **Validates: Requirements 1.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, deadline=None)
|
||||
@given(
|
||||
site_id_a=_site_id_st,
|
||||
site_id_b=_site_id_st,
|
||||
history_count=st.integers(min_value=1, max_value=10),
|
||||
)
|
||||
@patch("app.routers.execution.get_connection")
|
||||
def test_execution_history_isolation(mock_get_conn, site_id_a, site_id_b, history_count):
|
||||
"""Property 20.3: 执行历史隔离。
|
||||
|
||||
site_id_a 有若干执行历史记录,用 site_id_b 的身份查询执行历史,
|
||||
结果应为空——不同门店的执行历史互不可见。
|
||||
"""
|
||||
assume(site_id_a != site_id_b)
|
||||
|
||||
# site_id_a 的执行历史数据
|
||||
rows_a = _make_history_rows(site_id_a, history_count)
|
||||
|
||||
def make_mock_conn():
|
||||
mock_cursor = MagicMock()
|
||||
|
||||
def execute_side_effect(sql, params=None):
|
||||
if params:
|
||||
querying_site_id = params[0]
|
||||
# 只返回匹配 site_id 的行
|
||||
if querying_site_id == site_id_a:
|
||||
mock_cursor.fetchall.return_value = rows_a
|
||||
else:
|
||||
mock_cursor.fetchall.return_value = []
|
||||
|
||||
mock_cursor.execute = MagicMock(side_effect=execute_side_effect)
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
return mock_conn
|
||||
|
||||
mock_get_conn.return_value = make_mock_conn()
|
||||
|
||||
# 用 site_id_b 的身份查询执行历史
|
||||
app.dependency_overrides[get_current_user] = lambda: _make_mock_user(site_id_b)
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/execution/history")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# 验证:site_id_b 查不到 site_id_a 的任何执行历史
|
||||
assert len(data) == 0, (
|
||||
f"site_id_b={site_id_b} 不应看到 site_id_a={site_id_a} 的执行历史,"
|
||||
f"但返回了 {len(data)} 条记录"
|
||||
)
|
||||
|
||||
# 额外验证:即使有数据返回,也不应包含 site_id_a 的记录
|
||||
for item in data:
|
||||
assert item.get("site_id") != site_id_a, (
|
||||
f"结果中不应包含 site_id_a={site_id_a} 的执行历史"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
275
apps/backend/tests/test_task_config_properties.py
Normal file
275
apps/backend/tests/test_task_config_properties.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskConfig 属性测试(Property-Based Testing)。
|
||||
|
||||
使用 hypothesis 验证 TaskConfig 相关的通用正确性属性:
|
||||
- Property 1: TaskConfig 序列化往返一致性
|
||||
- Property 6: 时间窗口验证
|
||||
- Property 7: TaskConfig 到 CLI 命令转换完整性
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.cli_builder import CLIBuilder, VALID_FLOWS, VALID_PROCESSING_MODES
|
||||
from app.services.task_registry import ALL_TASKS
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 策略(Strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 从真实任务注册表中采样任务代码
|
||||
_task_codes = [t.code for t in ALL_TASKS]
|
||||
|
||||
_tasks_st = st.lists(
|
||||
st.sampled_from(_task_codes),
|
||||
min_size=1,
|
||||
max_size=5,
|
||||
unique=True,
|
||||
)
|
||||
|
||||
_pipeline_st = st.sampled_from(sorted(VALID_FLOWS))
|
||||
_processing_mode_st = st.sampled_from(sorted(VALID_PROCESSING_MODES))
|
||||
_window_mode_st = st.sampled_from(["lookback", "custom"])
|
||||
|
||||
# 日期策略:生成 YYYY-MM-DD 格式字符串
|
||||
_date_st = st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
).map(lambda d: d.isoformat())
|
||||
|
||||
_window_split_st = st.sampled_from([None, "none", "day"])
|
||||
_window_split_days_st = st.one_of(st.none(), st.sampled_from([1, 10, 30]))
|
||||
_lookback_hours_st = st.integers(min_value=1, max_value=720)
|
||||
_overlap_seconds_st = st.integers(min_value=0, max_value=7200)
|
||||
_store_id_st = st.one_of(st.none(), st.integers(min_value=1, max_value=2**31 - 1))
|
||||
|
||||
# DWD 表名采样
|
||||
_dwd_table_names = [
|
||||
"dwd.dim_site",
|
||||
"dwd.dim_member",
|
||||
"dwd.dwd_settlement_head",
|
||||
]
|
||||
_dwd_only_tables_st = st.one_of(
|
||||
st.none(),
|
||||
st.lists(st.sampled_from(_dwd_table_names), min_size=1, max_size=3, unique=True),
|
||||
)
|
||||
|
||||
|
||||
def _valid_task_config_st():
|
||||
"""生成有效的 TaskConfigSchema 的复合策略。
|
||||
|
||||
确保 window_mode=custom 时 window_end >= window_start,
|
||||
避免触发 Pydantic 验证错误。
|
||||
"""
|
||||
|
||||
@st.composite
|
||||
def _build(draw):
|
||||
tasks = draw(_tasks_st)
|
||||
pipeline = draw(_pipeline_st)
|
||||
processing_mode = draw(_processing_mode_st)
|
||||
dry_run = draw(st.booleans())
|
||||
window_mode = draw(_window_mode_st)
|
||||
store_id = draw(_store_id_st)
|
||||
dwd_only_tables = draw(_dwd_only_tables_st)
|
||||
window_split = draw(_window_split_st)
|
||||
window_split_days = draw(_window_split_days_st)
|
||||
fetch_before_verify = draw(st.booleans())
|
||||
skip_ods = draw(st.booleans())
|
||||
ods_local = draw(st.booleans())
|
||||
|
||||
if window_mode == "custom":
|
||||
d1 = draw(st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
))
|
||||
d2 = draw(st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
))
|
||||
# 保证 end >= start
|
||||
window_start = min(d1, d2).isoformat()
|
||||
window_end = max(d1, d2).isoformat()
|
||||
lookback_hours = 24
|
||||
overlap_seconds = 600
|
||||
else:
|
||||
window_start = None
|
||||
window_end = None
|
||||
lookback_hours = draw(_lookback_hours_st)
|
||||
overlap_seconds = draw(_overlap_seconds_st)
|
||||
|
||||
return TaskConfigSchema(
|
||||
tasks=tasks,
|
||||
pipeline=pipeline,
|
||||
processing_mode=processing_mode,
|
||||
dry_run=dry_run,
|
||||
window_mode=window_mode,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
window_split=window_split,
|
||||
window_split_days=window_split_days,
|
||||
lookback_hours=lookback_hours,
|
||||
overlap_seconds=overlap_seconds,
|
||||
fetch_before_verify=fetch_before_verify,
|
||||
skip_ods_when_fetch_before_verify=skip_ods,
|
||||
ods_use_local_json=ods_local,
|
||||
store_id=store_id,
|
||||
dwd_only_tables=dwd_only_tables,
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
return _build()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 1: TaskConfig 序列化往返一致性
|
||||
# **Validates: Requirements 11.1, 11.2, 11.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(config=_valid_task_config_st())
|
||||
def test_task_config_round_trip(config: TaskConfigSchema):
|
||||
"""Property 1: 序列化为 JSON 后再反序列化,应产生与原始对象等价的结果。"""
|
||||
json_str = config.model_dump_json()
|
||||
restored = TaskConfigSchema.model_validate_json(json_str)
|
||||
assert restored == config, (
|
||||
f"往返不一致:\n原始={config}\n还原={restored}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 6: 时间窗口验证
|
||||
# **Validates: Requirements 2.3**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(
|
||||
d1=st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
),
|
||||
d2=st.dates(
|
||||
min_value=datetime.date(2020, 1, 1),
|
||||
max_value=datetime.date(2030, 12, 31),
|
||||
),
|
||||
)
|
||||
def test_time_window_validation(d1: datetime.date, d2: datetime.date):
|
||||
"""Property 6: window_end < window_start 时验证应失败,否则应通过。"""
|
||||
start_str = d1.isoformat()
|
||||
end_str = d2.isoformat()
|
||||
|
||||
if end_str < start_str:
|
||||
# window_end 早于 window_start → 验证应失败
|
||||
try:
|
||||
TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="custom",
|
||||
window_start=start_str,
|
||||
window_end=end_str,
|
||||
)
|
||||
raise AssertionError(
|
||||
f"期望 ValidationError,但验证通过了:start={start_str}, end={end_str}"
|
||||
)
|
||||
except ValidationError:
|
||||
pass # 预期行为
|
||||
else:
|
||||
# window_end >= window_start → 验证应通过
|
||||
config = TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER"],
|
||||
window_mode="custom",
|
||||
window_start=start_str,
|
||||
window_end=end_str,
|
||||
)
|
||||
assert config.window_start == start_str
|
||||
assert config.window_end == end_str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature: admin-web-console, Property 7: TaskConfig 到 CLI 命令转换完整性
|
||||
# **Validates: Requirements 2.5, 2.6**
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_builder = CLIBuilder()
|
||||
_ETL_PATH = "/fake/etl/project"
|
||||
|
||||
|
||||
@settings(max_examples=200)
|
||||
@given(config=_valid_task_config_st())
|
||||
def test_task_config_to_cli_completeness(config: TaskConfigSchema):
|
||||
"""Property 7: CLIBuilder 生成的命令应包含 TaskConfig 中所有非空字段对应的 CLI 参数。"""
|
||||
cmd = _builder.build_command(config, _ETL_PATH)
|
||||
|
||||
# 1) --pipeline 始终存在且值正确
|
||||
assert "--pipeline" in cmd
|
||||
idx = cmd.index("--pipeline")
|
||||
assert cmd[idx + 1] == config.pipeline
|
||||
|
||||
# 2) --processing-mode 始终存在且值正确
|
||||
assert "--processing-mode" in cmd
|
||||
idx = cmd.index("--processing-mode")
|
||||
assert cmd[idx + 1] == config.processing_mode
|
||||
|
||||
# 3) 非空任务列表 → --tasks 存在
|
||||
if config.tasks:
|
||||
assert "--tasks" in cmd
|
||||
idx = cmd.index("--tasks")
|
||||
assert set(cmd[idx + 1].split(",")) == set(config.tasks)
|
||||
|
||||
# 4) 时间窗口参数
|
||||
if config.window_mode == "lookback":
|
||||
# lookback 模式 → --lookback-hours 和 --overlap-seconds
|
||||
if config.lookback_hours is not None:
|
||||
assert "--lookback-hours" in cmd
|
||||
idx = cmd.index("--lookback-hours")
|
||||
assert cmd[idx + 1] == str(config.lookback_hours)
|
||||
if config.overlap_seconds is not None:
|
||||
assert "--overlap-seconds" in cmd
|
||||
idx = cmd.index("--overlap-seconds")
|
||||
assert cmd[idx + 1] == str(config.overlap_seconds)
|
||||
# lookback 模式不应出现 custom 参数
|
||||
assert "--window-start" not in cmd
|
||||
assert "--window-end" not in cmd
|
||||
else:
|
||||
# custom 模式 → --window-start / --window-end
|
||||
if config.window_start:
|
||||
assert "--window-start" in cmd
|
||||
if config.window_end:
|
||||
assert "--window-end" in cmd
|
||||
# custom 模式不应出现 lookback 参数
|
||||
assert "--lookback-hours" not in cmd
|
||||
assert "--overlap-seconds" not in cmd
|
||||
|
||||
# 5) dry_run → --dry-run
|
||||
if config.dry_run:
|
||||
assert "--dry-run" in cmd
|
||||
else:
|
||||
assert "--dry-run" not in cmd
|
||||
|
||||
# 6) store_id → --store-id
|
||||
if config.store_id is not None:
|
||||
assert "--store-id" in cmd
|
||||
idx = cmd.index("--store-id")
|
||||
assert cmd[idx + 1] == str(config.store_id)
|
||||
else:
|
||||
assert "--store-id" not in cmd
|
||||
|
||||
# 7) fetch_before_verify → 仅 verify_only 模式下生成
|
||||
if config.fetch_before_verify and config.processing_mode == "verify_only":
|
||||
assert "--fetch-before-verify" in cmd
|
||||
else:
|
||||
assert "--fetch-before-verify" not in cmd
|
||||
|
||||
# 8) window_split(非 None 且非 "none")→ --window-split
|
||||
if config.window_split and config.window_split != "none":
|
||||
assert "--window-split" in cmd
|
||||
idx = cmd.index("--window-split")
|
||||
assert cmd[idx + 1] == config.window_split
|
||||
if config.window_split_days is not None:
|
||||
assert "--window-split-days" in cmd
|
||||
else:
|
||||
assert "--window-split" not in cmd
|
||||
373
apps/backend/tests/test_task_executor.py
Normal file
373
apps/backend/tests/test_task_executor.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskExecutor 单元测试
|
||||
|
||||
覆盖:子进程启动、stdout/stderr 读取、日志广播、取消、数据库记录。
|
||||
使用 asyncio 测试,mock 子进程和数据库连接避免外部依赖。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_executor import TaskExecutor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def executor() -> TaskExecutor:
|
||||
return TaskExecutor()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config() -> TaskConfigSchema:
|
||||
return TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER", "ODS_PAYMENT"],
|
||||
pipeline="api_ods_dwd",
|
||||
store_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _make_stream(lines: list[bytes]) -> AsyncMock:
|
||||
"""构造一个模拟的 asyncio.StreamReader,按行返回数据。"""
|
||||
stream = AsyncMock()
|
||||
# readline 依次返回每行,最后返回 b"" 表示 EOF
|
||||
stream.readline = AsyncMock(side_effect=[*lines, b""])
|
||||
return stream
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 订阅 / 取消订阅
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubscription:
|
||||
def test_subscribe_returns_queue(self, executor: TaskExecutor):
|
||||
q = executor.subscribe("exec-1")
|
||||
assert isinstance(q, asyncio.Queue)
|
||||
|
||||
def test_subscribe_multiple(self, executor: TaskExecutor):
|
||||
q1 = executor.subscribe("exec-1")
|
||||
q2 = executor.subscribe("exec-1")
|
||||
assert q1 is not q2
|
||||
assert len(executor._subscribers["exec-1"]) == 2
|
||||
|
||||
def test_unsubscribe_removes_queue(self, executor: TaskExecutor):
|
||||
q = executor.subscribe("exec-1")
|
||||
executor.unsubscribe("exec-1", q)
|
||||
# 最后一个订阅者移除后,键也被清理
|
||||
assert "exec-1" not in executor._subscribers
|
||||
|
||||
def test_unsubscribe_nonexistent_is_safe(self, executor: TaskExecutor):
|
||||
"""对不存在的 execution_id 取消订阅不应报错"""
|
||||
q: asyncio.Queue = asyncio.Queue()
|
||||
executor.unsubscribe("nonexistent", q)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 广播
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBroadcast:
|
||||
def test_broadcast_to_subscribers(self, executor: TaskExecutor):
|
||||
q1 = executor.subscribe("exec-1")
|
||||
q2 = executor.subscribe("exec-1")
|
||||
executor._broadcast("exec-1", "hello")
|
||||
assert q1.get_nowait() == "hello"
|
||||
assert q2.get_nowait() == "hello"
|
||||
|
||||
def test_broadcast_no_subscribers_is_safe(self, executor: TaskExecutor):
|
||||
"""无订阅者时广播不应报错"""
|
||||
executor._broadcast("nonexistent", "hello")
|
||||
|
||||
def test_broadcast_end_sends_none(self, executor: TaskExecutor):
|
||||
q = executor.subscribe("exec-1")
|
||||
executor._broadcast_end("exec-1")
|
||||
assert q.get_nowait() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 日志缓冲区
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLogBuffer:
|
||||
def test_get_logs_empty(self, executor: TaskExecutor):
|
||||
assert executor.get_logs("nonexistent") == []
|
||||
|
||||
def test_get_logs_returns_copy(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = ["line1", "line2"]
|
||||
logs = executor.get_logs("exec-1")
|
||||
assert logs == ["line1", "line2"]
|
||||
# 修改副本不影响原始
|
||||
logs.append("line3")
|
||||
assert len(executor._log_buffers["exec-1"]) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 执行状态查询
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunningState:
|
||||
def test_is_running_false_when_no_process(self, executor: TaskExecutor):
|
||||
assert executor.is_running("nonexistent") is False
|
||||
|
||||
def test_is_running_true_when_process_active(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = None
|
||||
executor._processes["exec-1"] = proc
|
||||
assert executor.is_running("exec-1") is True
|
||||
|
||||
def test_is_running_false_when_process_exited(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = 0
|
||||
executor._processes["exec-1"] = proc
|
||||
assert executor.is_running("exec-1") is False
|
||||
|
||||
def test_get_running_ids(self, executor: TaskExecutor):
|
||||
running = MagicMock()
|
||||
running.returncode = None
|
||||
exited = MagicMock()
|
||||
exited.returncode = 0
|
||||
executor._processes["a"] = running
|
||||
executor._processes["b"] = exited
|
||||
assert executor.get_running_ids() == ["a"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadStream:
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_stdout_lines(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = []
|
||||
stream = _make_stream([b"line1\n", b"line2\n"])
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", stream, "stdout", collector)
|
||||
assert collector == ["line1", "line2"]
|
||||
assert executor._log_buffers["exec-1"] == [
|
||||
"[stdout] line1",
|
||||
"[stdout] line2",
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_stderr_lines(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = []
|
||||
stream = _make_stream([b"err1\n"])
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", stream, "stderr", collector)
|
||||
assert collector == ["err1"]
|
||||
assert executor._log_buffers["exec-1"] == ["[stderr] err1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_stream_none_is_safe(self, executor: TaskExecutor):
|
||||
"""stream 为 None 时不应报错"""
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", None, "stdout", collector)
|
||||
assert collector == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_during_read(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = []
|
||||
q = executor.subscribe("exec-1")
|
||||
stream = _make_stream([b"hello\n"])
|
||||
collector: list[str] = []
|
||||
await executor._read_stream("exec-1", stream, "stdout", collector)
|
||||
assert q.get_nowait() == "[stdout] hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# execute(集成级,mock 子进程和数据库)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecute:
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_successful_execution(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
# 模拟子进程
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([b"processing...\n", b"done\n"])
|
||||
proc.stderr = _make_stream([])
|
||||
proc.wait = AsyncMock(return_value=0)
|
||||
# wait 调用后设置 returncode
|
||||
async def _wait():
|
||||
proc.returncode = 0
|
||||
return 0
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
await executor.execute(sample_config, "exec-1", site_id=42)
|
||||
|
||||
# 验证写入了 running 状态
|
||||
mock_write_log.assert_called_once()
|
||||
call_kwargs = mock_write_log.call_args[1]
|
||||
assert call_kwargs["status"] == "running"
|
||||
assert call_kwargs["execution_id"] == "exec-1"
|
||||
|
||||
# 验证更新了 success 状态
|
||||
mock_update_log.assert_called_once()
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert update_kwargs["status"] == "success"
|
||||
assert update_kwargs["exit_code"] == 0
|
||||
assert "processing..." in update_kwargs["output_log"]
|
||||
assert "done" in update_kwargs["output_log"]
|
||||
|
||||
# 进程已从跟踪表移除
|
||||
assert "exec-1" not in executor._processes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_failed_execution(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([])
|
||||
proc.stderr = _make_stream([b"error occurred\n"])
|
||||
async def _wait():
|
||||
proc.returncode = 1
|
||||
return 1
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
await executor.execute(sample_config, "exec-2", site_id=42)
|
||||
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert update_kwargs["status"] == "failed"
|
||||
assert update_kwargs["exit_code"] == 1
|
||||
assert "error occurred" in update_kwargs["error_log"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_exception_during_execution(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
"""子进程创建失败时应记录 failed 状态"""
|
||||
mock_create.side_effect = OSError("command not found")
|
||||
|
||||
await executor.execute(sample_config, "exec-3", site_id=42)
|
||||
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert update_kwargs["status"] == "failed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_subscribers_notified_on_completion(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([b"line\n"])
|
||||
proc.stderr = _make_stream([])
|
||||
async def _wait():
|
||||
proc.returncode = 0
|
||||
return 0
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
q = executor.subscribe("exec-4")
|
||||
await executor.execute(sample_config, "exec-4", site_id=42)
|
||||
|
||||
# 应收到日志行 + None 哨兵
|
||||
messages = []
|
||||
while not q.empty():
|
||||
messages.append(q.get_nowait())
|
||||
assert "[stdout] line" in messages
|
||||
assert None in messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.task_executor.TaskExecutor._update_execution_log")
|
||||
@patch("app.services.task_executor.TaskExecutor._write_execution_log")
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_duration_ms_recorded(
|
||||
self, mock_create, mock_write_log, mock_update_log,
|
||||
executor: TaskExecutor, sample_config: TaskConfigSchema,
|
||||
):
|
||||
proc = AsyncMock()
|
||||
proc.returncode = None
|
||||
proc.stdout = _make_stream([])
|
||||
proc.stderr = _make_stream([])
|
||||
async def _wait():
|
||||
proc.returncode = 0
|
||||
return 0
|
||||
proc.wait = _wait
|
||||
mock_create.return_value = proc
|
||||
|
||||
await executor.execute(sample_config, "exec-5", site_id=42)
|
||||
|
||||
update_kwargs = mock_update_log.call_args[1]
|
||||
assert isinstance(update_kwargs["duration_ms"], int)
|
||||
assert update_kwargs["duration_ms"] >= 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCancel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_running_process(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = None
|
||||
proc.terminate = MagicMock()
|
||||
executor._processes["exec-1"] = proc
|
||||
|
||||
result = await executor.cancel("exec-1")
|
||||
assert result is True
|
||||
proc.terminate.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_returns_false(self, executor: TaskExecutor):
|
||||
result = await executor.cancel("nonexistent")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_already_exited_returns_false(self, executor: TaskExecutor):
|
||||
proc = MagicMock()
|
||||
proc.returncode = 0
|
||||
executor._processes["exec-1"] = proc
|
||||
|
||||
result = await executor.cancel("exec-1")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_process_lookup_error(self, executor: TaskExecutor):
|
||||
"""进程已消失时 terminate 抛出 ProcessLookupError"""
|
||||
proc = MagicMock()
|
||||
proc.returncode = None
|
||||
proc.terminate = MagicMock(side_effect=ProcessLookupError)
|
||||
executor._processes["exec-1"] = proc
|
||||
|
||||
result = await executor.cancel("exec-1")
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanup:
|
||||
def test_cleanup_removes_buffers_and_subscribers(self, executor: TaskExecutor):
|
||||
executor._log_buffers["exec-1"] = ["line"]
|
||||
executor.subscribe("exec-1")
|
||||
executor.cleanup("exec-1")
|
||||
assert "exec-1" not in executor._log_buffers
|
||||
assert "exec-1" not in executor._subscribers
|
||||
|
||||
def test_cleanup_nonexistent_is_safe(self, executor: TaskExecutor):
|
||||
executor.cleanup("nonexistent")
|
||||
482
apps/backend/tests/test_task_queue.py
Normal file
482
apps/backend/tests/test_task_queue.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskQueue 单元测试
|
||||
|
||||
覆盖:enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
|
||||
使用 mock 数据库操作,专注于业务逻辑验证。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_queue import TaskQueue, QueuedTask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue() -> TaskQueue:
|
||||
return TaskQueue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config() -> TaskConfigSchema:
|
||||
return TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER", "ODS_PAYMENT"],
|
||||
pipeline="api_ods_dwd",
|
||||
store_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
|
||||
"""构造 mock cursor,支持 context manager 协议。"""
|
||||
cur = MagicMock()
|
||||
cur.fetchone.return_value = fetchone_val
|
||||
cur.fetchall.return_value = fetchall_val or []
|
||||
cur.rowcount = rowcount
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
|
||||
def _mock_conn(cursor):
|
||||
"""构造 mock connection,支持 cursor() context manager。"""
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnqueue:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
task_id = queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
# 返回有效 UUID
|
||||
uuid.UUID(task_id)
|
||||
conn.commit.assert_called_once()
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
|
||||
"""新任务 position = 当前最大 position + 1"""
|
||||
cur = _mock_cursor(fetchone_val=(5,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
# 检查 INSERT 调用中的 position 参数
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
args = insert_call[0][1]
|
||||
# args = (task_id, site_id, config_json, new_pos)
|
||||
assert args[3] == 6 # 5 + 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
|
||||
"""空队列时 position = 1"""
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
args = insert_call[0][1]
|
||||
assert args[3] == 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
|
||||
"""config 被序列化为 JSON 字符串"""
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
config_json_str = insert_call[0][1][2]
|
||||
parsed = json.loads(config_json_str)
|
||||
assert parsed["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
|
||||
assert parsed["pipeline"] == "api_ods_dwd"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dequeue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDequeue:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=None)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.dequeue(site_id=42)
|
||||
|
||||
assert result is None
|
||||
conn.commit.assert_called()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_returns_task(self, mock_get_conn, queue):
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
|
||||
row = (
|
||||
task_id, 42, json.dumps(config_dict), "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.dequeue(site_id=42)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == task_id
|
||||
assert result.site_id == 42
|
||||
assert result.status == "running" # dequeue 后状态变为 running
|
||||
assert result.config["tasks"] == ["ODS_MEMBER"]
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
|
||||
row = (
|
||||
task_id, 42, json.dumps(config_dict), "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.dequeue(site_id=42)
|
||||
|
||||
# 第二次 execute 调用应该是 UPDATE status = 'running'
|
||||
update_call = cur.execute.call_args_list[1]
|
||||
sql = update_call[0][0]
|
||||
assert "running" in sql
|
||||
assert task_id in update_call[0][1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReorder:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_moves_task(self, mock_get_conn, queue):
|
||||
"""将第 3 个任务移到第 1 位"""
|
||||
ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
rows = [(i,) for i in ids]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.reorder(ids[2], new_position=1, site_id=42)
|
||||
|
||||
# 重排后顺序应为 [ids[2], ids[0], ids[1]]
|
||||
update_calls = cur.execute.call_args_list[1:] # 跳过 SELECT
|
||||
positions = {}
|
||||
for c in update_calls:
|
||||
pos, tid = c[0][1]
|
||||
positions[tid] = pos
|
||||
assert positions[ids[2]] == 1
|
||||
assert positions[ids[0]] == 2
|
||||
assert positions[ids[1]] == 3
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
|
||||
"""重排不存在的任务不报错"""
|
||||
rows = [(str(uuid.uuid4()),)]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.reorder("nonexistent-id", new_position=1, site_id=42)
|
||||
|
||||
# 只有 SELECT,没有 UPDATE
|
||||
assert cur.execute.call_count == 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_clamps_position(self, mock_get_conn, queue):
|
||||
"""position 超出范围时 clamp 到有效范围"""
|
||||
ids = [str(uuid.uuid4()) for _ in range(2)]
|
||||
rows = [(i,) for i in ids]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
# new_position=100 超出范围,应 clamp 到末尾
|
||||
queue.reorder(ids[0], new_position=100, site_id=42)
|
||||
|
||||
update_calls = cur.execute.call_args_list[1:]
|
||||
positions = {}
|
||||
for c in update_calls:
|
||||
pos, tid = c[0][1]
|
||||
positions[tid] = pos
|
||||
# ids[0] 移到末尾
|
||||
assert positions[ids[1]] == 1
|
||||
assert positions[ids[0]] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDelete:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_pending_task(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(rowcount=1)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.delete("task-1", site_id=42)
|
||||
|
||||
assert result is True
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(rowcount=0)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.delete("nonexistent", site_id=42)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_only_affects_pending(self, mock_get_conn, queue):
|
||||
"""DELETE SQL 包含 status = 'pending' 条件"""
|
||||
cur = _mock_cursor(rowcount=0)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.delete("task-1", site_id=42)
|
||||
|
||||
sql = cur.execute.call_args[0][0]
|
||||
assert "pending" in sql
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_pending / has_running
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuery:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_list_pending_empty(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.list_pending(site_id=42)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_list_pending_returns_tasks(self, mock_get_conn, queue):
|
||||
tid = str(uuid.uuid4())
|
||||
config = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
|
||||
rows = [(tid, 42, config, "pending", 1, None, None, None, None, None)]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.list_pending(site_id=42)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == tid
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_has_running_true(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=(True,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
assert queue.has_running(site_id=42) is True
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_has_running_false(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=(False,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
assert queue.has_running(site_id=42) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_loop / _process_once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessLoop:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_skips_when_running(self, mock_get_conn, queue):
|
||||
"""有 running 任务时不 dequeue"""
|
||||
# _get_pending_site_ids 返回 [42]
|
||||
# has_running(42) 返回 True
|
||||
call_count = 0
|
||||
|
||||
def side_effect_conn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# _get_pending_site_ids
|
||||
cur = _mock_cursor(fetchall_val=[(42,)])
|
||||
return _mock_conn(cur)
|
||||
else:
|
||||
# has_running
|
||||
cur = _mock_cursor(fetchone_val=(True,))
|
||||
return _mock_conn(cur)
|
||||
|
||||
mock_get_conn.side_effect = side_effect_conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
# 不应调用 execute
|
||||
mock_executor.execute.assert_not_called()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
|
||||
"""无 running 任务时 dequeue 并执行"""
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods_dwd",
|
||||
"processing_mode": "increment_only",
|
||||
"dry_run": False,
|
||||
"window_mode": "lookback",
|
||||
"lookback_hours": 24,
|
||||
"overlap_seconds": 600,
|
||||
"fetch_before_verify": False,
|
||||
"skip_ods_when_fetch_before_verify": False,
|
||||
"ods_use_local_json": False,
|
||||
"extra_args": {},
|
||||
}
|
||||
config_json = json.dumps(config_dict)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def side_effect_conn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# _get_pending_site_ids
|
||||
cur = _mock_cursor(fetchall_val=[(42,)])
|
||||
return _mock_conn(cur)
|
||||
elif call_count == 2:
|
||||
# has_running → False
|
||||
cur = _mock_cursor(fetchone_val=(False,))
|
||||
return _mock_conn(cur)
|
||||
else:
|
||||
# dequeue → 返回任务
|
||||
row = (
|
||||
task_id, 42, config_json, "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
return _mock_conn(cur)
|
||||
|
||||
mock_get_conn.side_effect = side_effect_conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
mock_executor.execute = AsyncMock()
|
||||
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
# 给 create_task 一点时间启动
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_no_pending(self, mock_get_conn, queue):
|
||||
"""无 pending 任务时什么都不做"""
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
mock_executor.execute.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sets_running_false(self, queue):
|
||||
queue._running = True
|
||||
queue._loop_task = None
|
||||
|
||||
await queue.stop()
|
||||
|
||||
assert queue._running is False
|
||||
|
||||
def test_start_creates_task(self, queue):
|
||||
"""start() 应创建 asyncio.Task(需要事件循环)"""
|
||||
# 仅验证 _running 初始状态
|
||||
assert queue._running is False
|
||||
assert queue._loop_task is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _mark_failed / _update_queue_status_from_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInternalHelpers:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_mark_failed(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor()
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._mark_failed("queue-1", "测试错误")
|
||||
|
||||
sql = cur.execute.call_args[0][0]
|
||||
assert "failed" in sql
|
||||
args = cur.execute.call_args[0][1]
|
||||
assert args[0] == "测试错误"
|
||||
assert args[1] == "queue-1"
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_update_queue_status_from_log(self, mock_get_conn, queue):
|
||||
"""从 execution_log 同步状态到 task_queue"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
finished = datetime.now(timezone.utc)
|
||||
# 第一次 fetchone 返回 execution_log 行
|
||||
cur = _mock_cursor(fetchone_val=("success", finished, 0, None))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._update_queue_status_from_log("queue-1")
|
||||
|
||||
# 应有 SELECT + UPDATE 两次 execute
|
||||
assert cur.execute.call_count == 2
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_update_queue_status_no_log(self, mock_get_conn, queue):
|
||||
"""execution_log 无记录时不更新"""
|
||||
cur = _mock_cursor(fetchone_val=None)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._update_queue_status_from_log("queue-1")
|
||||
|
||||
# 只有 SELECT,没有 UPDATE
|
||||
assert cur.execute.call_count == 1
|
||||
299
apps/backend/tests/test_task_registry_properties.py
Normal file
299
apps/backend/tests/test_task_registry_properties.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表分组属性测试(Property-Based Testing)。
|
||||
|
||||
Property 4: 对于 Task_Registry 中的任务集合,分组结果中每个任务应出现在
|
||||
且仅出现在其所属业务域的分组中。
|
||||
|
||||
Validates: Requirements 2.1
|
||||
|
||||
测试策略:
|
||||
1. 直接测试 get_tasks_grouped_by_domain 函数:
|
||||
- 每个任务出现在且仅出现在其 domain 对应的分组中
|
||||
- 分组中的任务总数等于全部任务数(不多不少)
|
||||
- 每个分组的 key 等于该分组内所有任务的 domain
|
||||
2. 通过 API 端点测试(TestClient + mock auth):
|
||||
- 返回的 groups 中每个任务的 domain 与其所在分组 key 一致
|
||||
- 所有任务都出现在结果中
|
||||
3. 随机子集验证:
|
||||
- 随机选取任务子集,验证分组逻辑的一致性
|
||||
- 随机选取 domain,验证该 domain 下的任务都正确
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-registry")
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.services.task_registry import (
|
||||
get_all_tasks,
|
||||
get_tasks_grouped_by_domain,
|
||||
TaskDefinition,
|
||||
)
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app
|
||||
from app.auth.dependencies import get_current_user, CurrentUser
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ALL_TASKS = get_all_tasks()
|
||||
ALL_CODES = [t.code for t in ALL_TASKS]
|
||||
ALL_DOMAINS = list({t.domain for t in ALL_TASKS})
|
||||
|
||||
|
||||
def _mock_user() -> CurrentUser:
|
||||
return CurrentUser(user_id=1, site_id=1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.1: 分组完整性 — 每个任务出现在且仅出现在其 domain 分组中
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_every_task_in_exactly_its_domain_group(data):
|
||||
"""Property 4.1: 每个任务出现在且仅出现在其所属业务域的分组中。
|
||||
|
||||
从全量任务中随机选取一个任务,验证它只出现在对应 domain 的分组里,
|
||||
且不出现在其他任何分组中。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
# 随机选取一个任务
|
||||
task = data.draw(st.sampled_from(ALL_TASKS))
|
||||
|
||||
# 该任务必须出现在其 domain 分组中
|
||||
assert task.domain in grouped, (
|
||||
f"任务 {task.code} 的 domain '{task.domain}' 不在分组 keys 中"
|
||||
)
|
||||
domain_codes = [t.code for t in grouped[task.domain]]
|
||||
assert task.code in domain_codes, (
|
||||
f"任务 {task.code} 未出现在其 domain '{task.domain}' 的分组中"
|
||||
)
|
||||
|
||||
# 该任务不应出现在其他任何分组中
|
||||
for other_domain, other_tasks in grouped.items():
|
||||
if other_domain == task.domain:
|
||||
continue
|
||||
other_codes = [t.code for t in other_tasks]
|
||||
assert task.code not in other_codes, (
|
||||
f"任务 {task.code}(domain={task.domain})错误地出现在 "
|
||||
f"domain '{other_domain}' 的分组中"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.2: 分组总数守恒 — 分组中的任务总数等于全部任务数
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_grouped_total_equals_all_tasks(data):
|
||||
"""Property 4.2: 分组中的任务总数等于全部任务数(不多不少)。
|
||||
|
||||
随机选取若干 domain 进行局部验证,同时验证全局总数守恒。
|
||||
"""
|
||||
all_tasks = get_all_tasks()
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
|
||||
# 全局守恒:分组内任务总数 == 全量任务数
|
||||
grouped_total = sum(len(tasks) for tasks in grouped.values())
|
||||
assert grouped_total == len(all_tasks), (
|
||||
f"分组总数 {grouped_total} != 全量任务数 {len(all_tasks)}"
|
||||
)
|
||||
|
||||
# 随机选取一个 domain,验证该 domain 下的任务数量正确
|
||||
domain = data.draw(st.sampled_from(ALL_DOMAINS))
|
||||
expected_count = sum(1 for t in all_tasks if t.domain == domain)
|
||||
actual_count = len(grouped[domain])
|
||||
assert actual_count == expected_count, (
|
||||
f"domain '{domain}' 分组内任务数 {actual_count} != 预期 {expected_count}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.3: 分组 key 一致性 — 每个分组的 key 等于组内所有任务的 domain
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_group_key_matches_task_domains(data):
|
||||
"""Property 4.3: 每个分组的 key 等于该分组内所有任务的 domain。
|
||||
|
||||
随机选取一个 domain 分组,验证组内每个任务的 domain 字段都等于分组 key。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
domain = data.draw(st.sampled_from(list(grouped.keys())))
|
||||
|
||||
for task in grouped[domain]:
|
||||
assert task.domain == domain, (
|
||||
f"分组 '{domain}' 中的任务 {task.code} 的 domain 为 "
|
||||
f"'{task.domain}',与分组 key 不一致"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.4: 任务 code 全局唯一 — 分组后不应出现重复 code
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_no_duplicate_codes_across_groups(data):
|
||||
"""Property 4.4: 分组后所有任务的 code 全局唯一,无重复。
|
||||
|
||||
随机选取若干 domain 的任务合并,验证 code 不重复。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
|
||||
# 收集所有分组中的 code
|
||||
all_codes_in_groups = []
|
||||
for tasks in grouped.values():
|
||||
all_codes_in_groups.extend(t.code for t in tasks)
|
||||
|
||||
assert len(all_codes_in_groups) == len(set(all_codes_in_groups)), (
|
||||
"分组中存在重复的任务 code"
|
||||
)
|
||||
|
||||
# 随机选取两个不同 domain,验证它们的任务 code 无交集
|
||||
if len(ALL_DOMAINS) >= 2:
|
||||
domains = data.draw(
|
||||
st.lists(st.sampled_from(ALL_DOMAINS), min_size=2, max_size=2, unique=True)
|
||||
)
|
||||
codes_a = {t.code for t in grouped[domains[0]]}
|
||||
codes_b = {t.code for t in grouped[domains[1]]}
|
||||
overlap = codes_a & codes_b
|
||||
assert not overlap, (
|
||||
f"domain '{domains[0]}' 和 '{domains[1]}' 存在重叠任务 code: {overlap}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.5: 随机子集分组一致性 — 子集中的任务分组结果与全量一致
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(
|
||||
indices=st.lists(
|
||||
st.integers(min_value=0, max_value=len(ALL_TASKS) - 1),
|
||||
min_size=1,
|
||||
max_size=min(20, len(ALL_TASKS)),
|
||||
unique=True,
|
||||
)
|
||||
)
|
||||
def test_subset_grouping_consistency(indices):
|
||||
"""Property 4.5: 随机选取任务子集,验证每个任务在全量分组中的归属正确。
|
||||
|
||||
对于随机选取的任务子集,每个任务在 get_tasks_grouped_by_domain() 的结果中
|
||||
都应出现在其 domain 对应的分组里。
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
subset = [ALL_TASKS[i] for i in indices]
|
||||
|
||||
for task in subset:
|
||||
# 任务的 domain 必须是分组的 key 之一
|
||||
assert task.domain in grouped
|
||||
# 任务必须在对应分组中
|
||||
group_codes = {t.code for t in grouped[task.domain]}
|
||||
assert task.code in group_codes, (
|
||||
f"任务 {task.code} 未出现在 domain '{task.domain}' 的分组中"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.6: API 端点分组正确性 — GET /api/tasks/registry 返回一致的分组
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(data=st.data())
|
||||
def test_api_registry_grouping_correctness(data):
|
||||
"""Property 4.6: API 端点返回的分组中,每个任务的 domain 与分组 key 一致,
|
||||
且所有任务都出现在结果中。
|
||||
"""
|
||||
app.dependency_overrides[get_current_user] = _mock_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/tasks/registry")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
groups = body["groups"]
|
||||
|
||||
# 收集 API 返回的所有任务 code
|
||||
api_codes: set[str] = set()
|
||||
for domain_key, task_list in groups.items():
|
||||
for task_item in task_list:
|
||||
# 每个任务的 domain 必须等于分组 key
|
||||
assert task_item["domain"] == domain_key, (
|
||||
f"API 返回的任务 {task_item['code']}(domain={task_item['domain']})"
|
||||
f"出现在分组 '{domain_key}' 中,不一致"
|
||||
)
|
||||
api_codes.add(task_item["code"])
|
||||
|
||||
# 所有任务都应出现在 API 结果中
|
||||
all_codes_set = {t.code for t in get_all_tasks()}
|
||||
assert api_codes == all_codes_set, (
|
||||
f"API 返回的任务集合与全量任务不一致。"
|
||||
f"缺失: {all_codes_set - api_codes},"
|
||||
f"多余: {api_codes - all_codes_set}"
|
||||
)
|
||||
|
||||
# 随机选取一个 domain,验证该 domain 下的任务数量与服务层一致
|
||||
if groups:
|
||||
domain = data.draw(st.sampled_from(list(groups.keys())))
|
||||
expected = get_tasks_grouped_by_domain()
|
||||
assert len(groups[domain]) == len(expected[domain]), (
|
||||
f"API 返回的 domain '{domain}' 任务数 {len(groups[domain])} "
|
||||
f"!= 服务层 {len(expected[domain])}"
|
||||
)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4.7: 随机 domain 过滤验证
|
||||
# Validates: Requirements 2.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(domain=st.sampled_from(ALL_DOMAINS))
|
||||
def test_random_domain_tasks_all_correct(domain):
|
||||
"""Property 4.7: 随机选取一个 domain,验证该 domain 下的所有任务都正确归属。
|
||||
|
||||
对于选定的 domain:
|
||||
- 分组中的每个任务的 domain 字段都等于选定的 domain
|
||||
- 全量任务中所有属于该 domain 的任务都出现在分组中
|
||||
"""
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
all_tasks = get_all_tasks()
|
||||
|
||||
# 分组中该 domain 的任务
|
||||
group_tasks = grouped.get(domain, [])
|
||||
|
||||
# 全量任务中属于该 domain 的任务
|
||||
expected_tasks = [t for t in all_tasks if t.domain == domain]
|
||||
|
||||
# 数量一致
|
||||
assert len(group_tasks) == len(expected_tasks), (
|
||||
f"domain '{domain}': 分组中 {len(group_tasks)} 个任务,"
|
||||
f"预期 {len(expected_tasks)} 个"
|
||||
)
|
||||
|
||||
# code 集合一致
|
||||
group_codes = {t.code for t in group_tasks}
|
||||
expected_codes = {t.code for t in expected_tasks}
|
||||
assert group_codes == expected_codes, (
|
||||
f"domain '{domain}': 分组 codes {group_codes} != 预期 {expected_codes}"
|
||||
)
|
||||
|
||||
# 每个任务的 domain 字段都正确
|
||||
for task in group_tasks:
|
||||
assert task.domain == domain
|
||||
274
apps/backend/tests/test_tasks_router.py
Normal file
274
apps/backend/tests/test_tasks_router.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务注册表 API 单元测试
|
||||
|
||||
覆盖 4 个端点:registry / dwd-tables / flows / validate
|
||||
通过 JWT mock 绕过认证依赖。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
from app.services.task_registry import (
|
||||
ALL_TASKS,
|
||||
DWD_TABLES,
|
||||
FLOW_LAYER_MAP,
|
||||
get_tasks_grouped_by_domain,
|
||||
)
|
||||
|
||||
# 固定测试用户
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTaskRegistry:
|
||||
def setup_method(self):
|
||||
"""每个测试方法前重新设置 auth 覆盖,防止其他测试文件的 clear/pop 导致状态泄漏"""
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
|
||||
def test_registry_returns_grouped_tasks(self):
|
||||
resp = client.get("/api/tasks/registry")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "groups" in data
|
||||
|
||||
# 所有任务都应出现在某个分组中
|
||||
all_codes_in_response = set()
|
||||
for domain, tasks in data["groups"].items():
|
||||
for t in tasks:
|
||||
all_codes_in_response.add(t["code"])
|
||||
assert t["domain"] == domain
|
||||
|
||||
expected_codes = {t.code for t in ALL_TASKS}
|
||||
assert all_codes_in_response == expected_codes
|
||||
|
||||
def test_registry_task_fields_complete(self):
|
||||
"""每个任务项包含所有必要字段"""
|
||||
resp = client.get("/api/tasks/registry")
|
||||
data = resp.json()
|
||||
required_fields = {"code", "name", "description", "domain", "layer",
|
||||
"requires_window", "is_ods", "is_dimension", "default_enabled"}
|
||||
for tasks in data["groups"].values():
|
||||
for t in tasks:
|
||||
assert required_fields.issubset(t.keys())
|
||||
|
||||
def test_registry_requires_auth(self):
|
||||
"""未认证时返回 401"""
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
resp = client.get("/api/tasks/registry")
|
||||
assert resp.status_code == 401
|
||||
finally:
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/dwd-tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDwdTables:
|
||||
def test_dwd_tables_returns_grouped(self):
|
||||
resp = client.get("/api/tasks/dwd-tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "groups" in data
|
||||
|
||||
all_tables_in_response = set()
|
||||
for domain, tables in data["groups"].items():
|
||||
for t in tables:
|
||||
all_tables_in_response.add(t["table_name"])
|
||||
assert t["domain"] == domain
|
||||
|
||||
expected_tables = {t.table_name for t in DWD_TABLES}
|
||||
assert all_tables_in_response == expected_tables
|
||||
|
||||
def test_dwd_tables_fields_complete(self):
|
||||
resp = client.get("/api/tasks/dwd-tables")
|
||||
data = resp.json()
|
||||
required_fields = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
|
||||
for tables in data["groups"].values():
|
||||
for t in tables:
|
||||
assert required_fields.issubset(t.keys())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/tasks/flows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlows:
|
||||
def test_flows_returns_seven_flows(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["flows"]) == 7
|
||||
|
||||
def test_flows_returns_three_processing_modes(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
assert len(data["processing_modes"]) == 3
|
||||
|
||||
def test_flow_ids_match_registry(self):
|
||||
"""Flow ID 与 FLOW_LAYER_MAP 一致"""
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
flow_ids = {f["id"] for f in data["flows"]}
|
||||
assert flow_ids == set(FLOW_LAYER_MAP.keys())
|
||||
|
||||
def test_flow_layers_non_empty(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
for f in data["flows"]:
|
||||
assert len(f["layers"]) > 0
|
||||
|
||||
def test_processing_mode_ids(self):
|
||||
resp = client.get("/api/tasks/flows")
|
||||
data = resp.json()
|
||||
mode_ids = {m["id"] for m in data["processing_modes"]}
|
||||
assert mode_ids == {"increment_only", "verify_only", "increment_verify"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/tasks/validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidate:
|
||||
def test_validate_success(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
|
||||
"pipeline": "api_ods",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is True
|
||||
assert data["errors"] == []
|
||||
assert len(data["command_args"]) > 0
|
||||
assert "--store-id" in data["command"]
|
||||
# store_id 应从 JWT 注入(测试用户 site_id=100)
|
||||
assert "100" in data["command"]
|
||||
|
||||
def test_validate_injects_store_id(self):
|
||||
"""即使前端传了 store_id,后端也用 JWT 中的值覆盖"""
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["DWD_LOAD_FROM_ODS"],
|
||||
"pipeline": "ods_dwd",
|
||||
"store_id": 999,
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# 命令中应包含 JWT 的 site_id=100,而非前端传的 999
|
||||
assert "--store-id" in data["command"]
|
||||
idx = data["command_args"].index("--store-id")
|
||||
assert data["command_args"][idx + 1] == "100"
|
||||
|
||||
def test_validate_invalid_flow(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "nonexistent_flow",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is False
|
||||
assert any("无效的执行流程" in e for e in data["errors"])
|
||||
|
||||
def test_validate_empty_tasks(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": [],
|
||||
"pipeline": "api_ods",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is False
|
||||
assert any("任务列表不能为空" in e for e in data["errors"])
|
||||
|
||||
def test_validate_custom_window(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
"window_mode": "custom",
|
||||
"window_start": "2024-01-01",
|
||||
"window_end": "2024-01-31",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is True
|
||||
assert "--window-start" in data["command"]
|
||||
assert "--window-end" in data["command"]
|
||||
|
||||
def test_validate_window_end_before_start_rejected(self):
|
||||
"""window_end 早于 window_start 时 Pydantic 验证失败 → 422"""
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
"window_mode": "custom",
|
||||
"window_start": "2024-12-31",
|
||||
"window_end": "2024-01-01",
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_validate_dry_run_flag(self):
|
||||
resp = client.post("/api/tasks/validate", json={
|
||||
"config": {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods",
|
||||
"dry_run": True,
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "--dry-run" in data["command"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# task_registry 服务层测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTaskRegistryService:
|
||||
def test_all_tasks_have_unique_codes(self):
|
||||
codes = [t.code for t in ALL_TASKS]
|
||||
assert len(codes) == len(set(codes))
|
||||
|
||||
def test_grouped_tasks_cover_all(self):
|
||||
grouped = get_tasks_grouped_by_domain()
|
||||
all_codes = set()
|
||||
for tasks in grouped.values():
|
||||
for t in tasks:
|
||||
all_codes.add(t.code)
|
||||
assert all_codes == {t.code for t in ALL_TASKS}
|
||||
|
||||
def test_ods_tasks_marked_is_ods(self):
|
||||
for t in ALL_TASKS:
|
||||
if t.layer == "ODS":
|
||||
assert t.is_ods is True
|
||||
|
||||
def test_flow_layer_map_covers_all_flows(self):
|
||||
expected_flows = {"api_ods", "api_ods_dwd", "api_full", "ods_dwd",
|
||||
"dwd_dws", "dwd_dws_index", "dwd_index"}
|
||||
assert set(FLOW_LAYER_MAP.keys()) == expected_flows
|
||||
186
apps/backend/tests/test_ws_logs.py
Normal file
186
apps/backend/tests/test_ws_logs.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""WebSocket 日志推送端点测试
|
||||
|
||||
测试 /ws/logs/{execution_id} 端点的连接、日志回放、实时推送和断开行为。
|
||||
利用 TaskExecutor 已有的 subscribe/broadcast 机制进行验证。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
from app.main import app
|
||||
from app.services.task_executor import task_executor
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_executor():
|
||||
"""每个测试前后清理 TaskExecutor 内部状态。"""
|
||||
yield
|
||||
# 清理所有残留的缓冲区和订阅者
|
||||
for eid in list(task_executor._log_buffers.keys()):
|
||||
task_executor.cleanup(eid)
|
||||
task_executor._subscribers.clear()
|
||||
task_executor._log_buffers.clear()
|
||||
|
||||
|
||||
class TestWebSocketConnection:
|
||||
"""WebSocket 连接/断开基本行为"""
|
||||
|
||||
def test_connect_and_disconnect(self):
|
||||
"""客户端能成功建立和关闭 WebSocket 连接。"""
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws/logs/test-exec-001") as ws:
|
||||
# 连接成功,直接关闭
|
||||
pass # __exit__ 会关闭连接
|
||||
|
||||
def test_connect_registers_subscriber(self):
|
||||
"""连接后 TaskExecutor 应注册订阅者。"""
|
||||
client = TestClient(app)
|
||||
# 预先初始化缓冲区(模拟有任务在运行)
|
||||
task_executor._log_buffers["test-exec-002"] = []
|
||||
|
||||
with client.websocket_connect("/ws/logs/test-exec-002"):
|
||||
# 连接期间应有订阅者
|
||||
assert "test-exec-002" in task_executor._subscribers
|
||||
assert len(task_executor._subscribers["test-exec-002"]) >= 1
|
||||
|
||||
|
||||
class TestLogReplay:
|
||||
"""历史日志回放"""
|
||||
|
||||
def test_replay_existing_logs(self):
|
||||
"""连接时应先收到内存缓冲区中已有的日志行。"""
|
||||
eid = "test-exec-replay"
|
||||
# 预填充日志缓冲区
|
||||
task_executor._log_buffers[eid] = [
|
||||
"[stdout] 第一行",
|
||||
"[stdout] 第二行",
|
||||
"[stderr] 警告信息",
|
||||
]
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 应按顺序收到 3 条历史日志
|
||||
msg1 = ws.receive_text()
|
||||
msg2 = ws.receive_text()
|
||||
msg3 = ws.receive_text()
|
||||
|
||||
assert msg1 == "[stdout] 第一行"
|
||||
assert msg2 == "[stdout] 第二行"
|
||||
assert msg3 == "[stderr] 警告信息"
|
||||
|
||||
def test_no_logs_no_replay(self):
|
||||
"""没有历史日志时不应收到回放消息。"""
|
||||
eid = "test-exec-empty"
|
||||
task_executor._log_buffers[eid] = []
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 发送结束信号让连接正常关闭
|
||||
task_executor._broadcast_end(eid)
|
||||
# 不应有回放消息,直接收到的是后续的结束信号处理
|
||||
|
||||
|
||||
class TestBroadcastReceive:
|
||||
"""实时日志广播接收"""
|
||||
|
||||
def test_receive_broadcast_messages(self):
|
||||
"""连接后应能收到 TaskExecutor 广播的实时日志。"""
|
||||
eid = "test-exec-broadcast"
|
||||
task_executor._log_buffers[eid] = []
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 模拟 TaskExecutor 广播日志
|
||||
task_executor._broadcast(eid, "[stdout] 实时日志行1")
|
||||
task_executor._broadcast(eid, "[stderr] 实时错误行")
|
||||
|
||||
msg1 = ws.receive_text()
|
||||
msg2 = ws.receive_text()
|
||||
|
||||
assert msg1 == "[stdout] 实时日志行1"
|
||||
assert msg2 == "[stderr] 实时错误行"
|
||||
|
||||
def test_end_signal_closes_connection(self):
|
||||
"""收到 None 结束信号后 WebSocket 应正常关闭。"""
|
||||
eid = "test-exec-end"
|
||||
task_executor._log_buffers[eid] = []
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 广播一条日志后发送结束信号
|
||||
task_executor._broadcast(eid, "[stdout] 最后一行")
|
||||
task_executor._broadcast_end(eid)
|
||||
|
||||
msg = ws.receive_text()
|
||||
assert msg == "[stdout] 最后一行"
|
||||
|
||||
def test_replay_then_broadcast(self):
|
||||
"""先回放历史日志,再接收实时广播。"""
|
||||
eid = "test-exec-mixed"
|
||||
task_executor._log_buffers[eid] = ["[stdout] 历史行"]
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect(f"/ws/logs/{eid}") as ws:
|
||||
# 先收到历史回放
|
||||
replay = ws.receive_text()
|
||||
assert replay == "[stdout] 历史行"
|
||||
|
||||
# 再收到实时广播
|
||||
task_executor._broadcast(eid, "[stdout] 新行")
|
||||
task_executor._broadcast_end(eid)
|
||||
|
||||
live = ws.receive_text()
|
||||
assert live == "[stdout] 新行"
|
||||
|
||||
|
||||
class TestLogBroadcasterUnit:
|
||||
"""直接测试 TaskExecutor 的 subscribe/unsubscribe/broadcast 方法
|
||||
(作为 LogBroadcaster 功能的单元测试)。
|
||||
"""
|
||||
|
||||
def test_subscribe_creates_queue(self):
|
||||
eid = "unit-sub"
|
||||
q = task_executor.subscribe(eid)
|
||||
assert isinstance(q, asyncio.Queue)
|
||||
assert eid in task_executor._subscribers
|
||||
task_executor.unsubscribe(eid, q)
|
||||
|
||||
def test_unsubscribe_removes_queue(self):
|
||||
eid = "unit-unsub"
|
||||
q = task_executor.subscribe(eid)
|
||||
task_executor.unsubscribe(eid, q)
|
||||
# 最后一个订阅者移除后,key 也应被清理
|
||||
assert eid not in task_executor._subscribers
|
||||
|
||||
def test_broadcast_delivers_to_all_subscribers(self):
|
||||
eid = "unit-multi"
|
||||
q1 = task_executor.subscribe(eid)
|
||||
q2 = task_executor.subscribe(eid)
|
||||
|
||||
task_executor._broadcast(eid, "测试消息")
|
||||
|
||||
assert q1.get_nowait() == "测试消息"
|
||||
assert q2.get_nowait() == "测试消息"
|
||||
|
||||
task_executor.unsubscribe(eid, q1)
|
||||
task_executor.unsubscribe(eid, q2)
|
||||
|
||||
def test_broadcast_end_sends_none(self):
|
||||
eid = "unit-end"
|
||||
q = task_executor.subscribe(eid)
|
||||
|
||||
task_executor._broadcast_end(eid)
|
||||
|
||||
assert q.get_nowait() is None
|
||||
task_executor.unsubscribe(eid, q)
|
||||
|
||||
def test_broadcast_no_subscribers_is_safe(self):
|
||||
"""没有订阅者时广播不应报错。"""
|
||||
task_executor._broadcast("nonexistent", "无人接收")
|
||||
task_executor._broadcast_end("nonexistent")
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
## 作用说明
|
||||
|
||||
ETL 数据管线集合。每个上游数据源对应 `pipelines/` 下的一个子目录,当前仅有飞球平台(`feiqiu`)。管线负责从 SaaS API 抽取数据,经 ODS→DWD→Core→DWS 逐层处理后落库。
|
||||
ETL Connector(数据源连接器)集合。每个上游数据源对应 `pipelines/` 下的一个子目录(即一个 Connector),当前仅有飞球平台(`feiqiu`)。Connector 负责从 SaaS API 抽取数据,经 ODS→DWD→Core→DWS 逐层处理后落库。
|
||||
|
||||
## 内部结构
|
||||
|
||||
- `pipelines/feiqiu/` — 飞球平台 ETL(api、cli、config、loaders、models、orchestration、scd、tasks、utils、quality、tests)
|
||||
- `pipelines/feiqiu/` — 飞球 Connector(api、cli、config、loaders、models、orchestration、scd、tasks、utils、quality、tests)
|
||||
|
||||
## Roadmap
|
||||
|
||||
- 将通用抽取/加载逻辑抽离为 `etl_sdk` 共享包,供多管线复用
|
||||
- 将通用抽取/加载逻辑抽离为 `etl_sdk` 共享包,供多 Connector 复用
|
||||
- 将各平台 API 客户端拆分为独立 `connectors` 包,实现可插拔数据源接入
|
||||
- 新增管线时在 `pipelines/` 下创建同构子目录
|
||||
- 新增 Connector 时在 `pipelines/` 下创建同构子目录
|
||||
|
||||
212
apps/etl/connectors/feiqiu/.env
Normal file
212
apps/etl/connectors/feiqiu/.env
Normal file
@@ -0,0 +1,212 @@
|
||||
# ==============================================================================
|
||||
# NeoZQYY ETL Connector(飞球)配置
|
||||
# ==============================================================================
|
||||
# ETL env_parser.py 从此文件加载
|
||||
# 优先级:DEFAULTS < 此 .env < 环境变量 < CLI 参数
|
||||
# 敏感值禁止提交;本文件已在 .gitignore 中排除
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 门店配置
|
||||
# ------------------------------------------------------------------------------
|
||||
STORE_ID=2790685415443269
|
||||
TIMEZONE=Asia/Shanghai
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 数据库配置
|
||||
# ------------------------------------------------------------------------------
|
||||
# CHANGE 2026-02-15 | 默认指向测试库,生产环境切换为 etl_feiqiu
|
||||
PG_DSN=postgresql://local-Python:Neo-local-1991125@100.64.0.4:5432/test_etl_feiqiu
|
||||
PG_CONNECT_TIMEOUT=10
|
||||
|
||||
# 数据库 Schema
|
||||
SCHEMA_OLTP=ods
|
||||
SCHEMA_ETL=meta
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# API 配置(上游 SaaS API)
|
||||
# ------------------------------------------------------------------------------
|
||||
API_BASE=https://pc.ficoo.vip/apiprod/admin/v1/
|
||||
API_TOKEN=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjbGllbnQtdHlwZSI6IjQiLCJ1c2VyLXR5cGUiOiIxIiwiaHR0cDovL3NjaGVtYXMubWljcm9zb2Z0LmNvbS93cy8yMDA4LzA2L2lkZW50aXR5L2NsYWltcy9yb2xlIjoiMTIiLCJyb2xlLWlkIjoiMTIiLCJ0ZW5hbnQtaWQiOiIyNzkwNjgzMTYwNzA5OTU3Iiwibmlja25hbWUiOiLnp5_miLfnrqHnkIblkZjvvJrmganmgakxIiwic2l0ZS1pZCI6IjAiLCJtb2JpbGUiOiIxMzgxMDUwMjMwNCIsInNpZCI6IjI5NTA0ODk2NTgzOTU4NDUiLCJzdGFmZi1pZCI6IjMwMDk5MTg2OTE1NTkwNDUiLCJvcmctaWQiOiIwIiwicm9sZS10eXBlIjoiMyIsInJlZnJlc2hUb2tlbiI6InoxazVzWjlDeEFKYnFkNG1pT3NwUzBsQTRMYUNGcURkQjJBdFdsQk1DbDA9IiwicmVmcmVzaEV4cGlyeVRpbWUiOiIyMDI2LzIvMjIg5LiL5Y2IMTE6NTk6MzAiLCJuZWVkQ2hlY2tUb2tlbiI6ImZhbHNlIiwiZXhwIjoxNzcxNzc1OTcwLCJpc3MiOiJ0ZXN0IiwiYXVkIjoiVXNlciJ9.27D1QgKFYGgMKR9bS5NbCSl4kIf9oFVOQLsFl_ITxdI
|
||||
API_TIMEOUT=20
|
||||
API_PAGE_SIZE=200
|
||||
API_RETRY_MAX=3
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 路径配置(已更新为 NeoZQYY 路径)
|
||||
# ------------------------------------------------------------------------------
|
||||
EXPORT_ROOT=C:/NeoZQYY/export/ETL/JSON
|
||||
LOG_ROOT=C:/NeoZQYY/export/ETL/LOG
|
||||
FETCH_ROOT=C:/NeoZQYY/export/ETL/JSON
|
||||
WRITE_PRETTY_JSON=true
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 管线流程配置
|
||||
# ------------------------------------------------------------------------------
|
||||
PIPELINE_FLOW=FULL
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 时间窗口配置
|
||||
# ------------------------------------------------------------------------------
|
||||
OVERLAP_SECONDS=600
|
||||
WINDOW_BUSY_MIN=30
|
||||
WINDOW_IDLE_MIN=180
|
||||
IDLE_START=04:00
|
||||
IDLE_END=16:00
|
||||
WINDOW_SPLIT_UNIT=day
|
||||
WINDOW_SPLIT_DAYS=10
|
||||
WINDOW_COMPENSATION_HOURS=2
|
||||
ALLOW_EMPTY_RESULT_ADVANCE=true
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 快照配置
|
||||
# ------------------------------------------------------------------------------
|
||||
SNAPSHOT_MISSING_DELETE=true
|
||||
SNAPSHOT_ALLOW_EMPTY_DELETE=false
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 数据完整性检查配置
|
||||
# ------------------------------------------------------------------------------
|
||||
INTEGRITY_MODE=history
|
||||
INTEGRITY_HISTORY_START=2025-07-01
|
||||
INTEGRITY_INCLUDE_DIMENSIONS=true
|
||||
INTEGRITY_AUTO_CHECK=false
|
||||
INTEGRITY_AUTO_BACKFILL=false
|
||||
INTEGRITY_COMPARE_CONTENT=true
|
||||
INTEGRITY_CONTENT_SAMPLE_LIMIT=50
|
||||
INTEGRITY_BACKFILL_MISMATCH=true
|
||||
INTEGRITY_RECHECK_AFTER_BACKFILL=true
|
||||
|
||||
# 指定 ODS 任务代码(逗号分隔,为空则全部)
|
||||
# INTEGRITY_ODS_TASK_CODES=
|
||||
|
||||
# 是否强制按月切分(默认 true)
|
||||
# INTEGRITY_FORCE_MONTHLY_SPLIT=true
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 校验配置
|
||||
# ------------------------------------------------------------------------------
|
||||
VERIFY_SKIP_ODS_ON_FETCH=true
|
||||
VERIFY_ODS_LOCAL_JSON=true
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 数据库会话参数(defaults.py → db.session.*)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 会话时区(默认跟随 TIMEZONE)
|
||||
# DB_SESSION_TIMEZONE=Asia/Shanghai
|
||||
|
||||
# SQL 语句超时(毫秒,默认 30000)
|
||||
# DB_STATEMENT_TIMEOUT_MS=30000
|
||||
|
||||
# 锁等待超时(毫秒,默认 5000)
|
||||
# DB_LOCK_TIMEOUT_MS=5000
|
||||
|
||||
# 事务空闲超时(毫秒,默认 600000 = 10 分钟)
|
||||
# DB_IDLE_IN_TX_TIMEOUT_MS=600000
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 清洗配置(defaults.py → clean.*)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 是否记录未知字段(默认 true)
|
||||
# CLEAN_LOG_UNKNOWN_FIELDS=true
|
||||
|
||||
# 未知字段日志上限(默认 50)
|
||||
# CLEAN_UNKNOWN_FIELDS_LIMIT=50
|
||||
|
||||
# 哈希算法(默认 sha1)
|
||||
# CLEAN_HASH_ALGO=sha1
|
||||
|
||||
# 哈希盐值(默认空)
|
||||
# CLEAN_HASH_SALT=
|
||||
|
||||
# 严格数值校验(默认 true)
|
||||
# CLEAN_STRICT_NUMERIC=true
|
||||
|
||||
# 金额舍入精度(默认 2 位小数)
|
||||
# CLEAN_ROUND_MONEY_SCALE=2
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 安全配置(defaults.py → security.*)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 日志中是否脱敏(默认 true)
|
||||
# SECURITY_REDACT_IN_LOGS=true
|
||||
|
||||
# 需脱敏的键名(JSON 数组,默认 ["token","password","Authorization"])
|
||||
# SECURITY_REDACT_KEYS=["token","password","Authorization"]
|
||||
|
||||
# 日志中是否回显 token(默认 false,调试用)
|
||||
# SECURITY_ECHO_TOKEN_IN_LOGS=false
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# IO 文件大小限制(defaults.py → io.max_file_bytes)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 单文件最大字节数(默认 50MB = 52428800)
|
||||
# MAX_FILE_BYTES=52428800
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# DWD 层配置
|
||||
# ------------------------------------------------------------------------------
|
||||
DWD_FACT_UPSERT=true
|
||||
|
||||
# 事实表 UPSERT 批量大小(默认 1000)
|
||||
# DWD_FACT_UPSERT_BATCH_SIZE=1000
|
||||
|
||||
# 最小批量大小(锁冲突时自动缩小,默认 100)
|
||||
# DWD_FACT_UPSERT_MIN_BATCH_SIZE=100
|
||||
|
||||
# 最大重试次数(默认 2)
|
||||
# DWD_FACT_UPSERT_MAX_RETRIES=2
|
||||
|
||||
# 重试退避时间(JSON 数组,秒,默认 [1,2,4])
|
||||
# DWD_FACT_UPSERT_RETRY_BACKOFF=[1,2,4]
|
||||
|
||||
# 事实表 backfill 锁等待超时(毫秒,为空则沿用 DB_LOCK_TIMEOUT_MS)
|
||||
# DWD_FACT_UPSERT_LOCK_TIMEOUT_MS=
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 任务列表配置
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN_TASKS=PRODUCTS,TABLES,MEMBERS,ASSISTANTS,PACKAGES_DEF,ORDERS,PAYMENTS,REFUNDS,COUPON_USAGE,INVENTORY_CHANGE,TOPUPS,TABLE_DISCOUNT,ASSISTANT_ABOLISH,LEDGER
|
||||
INDEX_LOOKBACK_DAYS=60
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# DWS 月度/薪资配置(defaults.py → dws.*)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 是否允许历史月度重算(默认 false)
|
||||
# DWS_MONTHLY_ALLOW_HISTORY=false
|
||||
|
||||
# 上月宽限天数(默认 5,即次月 1-5 号仍可计算上月)
|
||||
# DWS_MONTHLY_PREV_GRACE_DAYS=5
|
||||
|
||||
# 历史月份数(默认 0,即不回溯)
|
||||
# DWS_MONTHLY_HISTORY_MONTHS=0
|
||||
|
||||
# 新人封顶生效日期(默认 2026-03-01)
|
||||
# DWS_MONTHLY_NEW_HIRE_CAP_EFFECTIVE_FROM=2026-03-01
|
||||
|
||||
# 新人封顶天数(默认 25)
|
||||
# DWS_MONTHLY_NEW_HIRE_CAP_DAY=25
|
||||
|
||||
# 新人最高等级(默认 2)
|
||||
# DWS_MONTHLY_NEW_HIRE_MAX_TIER_LEVEL=2
|
||||
|
||||
# 薪资计算运行天数(默认 5)
|
||||
# DWS_SALARY_RUN_DAYS=5
|
||||
|
||||
# 是否允许非周期内运行(默认 false)
|
||||
# DWS_SALARY_ALLOW_OUT_OF_CYCLE=false
|
||||
|
||||
# 包房课单价(默认 138)
|
||||
# DWS_SALARY_ROOM_COURSE_PRICE=138
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 运行模式(defaults.py → run.data_source)
|
||||
# ------------------------------------------------------------------------------
|
||||
# 数据源模式:hybrid(默认,API+本地)、online(仅 API)、offline(仅本地)
|
||||
# 也可通过 PIPELINE_FLOW 间接设置(FULL→hybrid, FETCH_ONLY→online, INGEST_ONLY→offline)
|
||||
# DATA_SOURCE=hybrid
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# API 额外请求头(defaults.py → api.headers_extra)
|
||||
# ------------------------------------------------------------------------------
|
||||
# JSON 对象格式,默认空
|
||||
# API_HEADERS_EXTRA={}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user