初始提交:飞球 ETL 系统全量代码
This commit is contained in:
38
scripts/README.md
Normal file
38
scripts/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# scripts/ — 运维与工具脚本
|
||||
|
||||
## 子目录
|
||||
|
||||
| 目录 | 用途 | 典型场景 |
|
||||
|------|------|----------|
|
||||
| `audit/` | 仓库审计(文件清单、调用流、文档对齐分析) | `python -m scripts.audit.run_audit` |
|
||||
| `check/` | 数据检查(ODS 缺口、内容哈希、完整性校验) | `python -m scripts.check.check_data_integrity` |
|
||||
| `db_admin/` | 数据库管理(Excel 导入 DWS 支出/回款/提成) | `python scripts/db_admin/import_dws_excel.py --type expense` |
|
||||
| `export/` | 数据导出(指数、团购、亲密度、会员明细等) | `python scripts/export/export_index_tables.py` |
|
||||
| `rebuild/` | 数据重建(全量 ODS→DWD 重建) | `python scripts/rebuild/rebuild_db_and_run_ods_to_dwd.py` |
|
||||
| `repair/` | 数据修复(回填、去重、hash 修复、维度修复) | `python scripts/repair/dedupe_ods_snapshots.py` |
|
||||
|
||||
## 根目录脚本
|
||||
|
||||
- `run_update.py` — 一键增量更新(ODS → DWD → DWS),适合 cron/计划任务调用
|
||||
- `run_ods.bat` — Windows 批处理:ODS 建表 + 灌入示例 JSON
|
||||
|
||||
## 运行方式
|
||||
|
||||
所有脚本在项目根目录(`C:\ZQYY\FQ-ETL`)执行:
|
||||
|
||||
```bash
|
||||
# 审计报告生成
|
||||
python -m scripts.audit.run_audit
|
||||
|
||||
# 一键增量更新
|
||||
python scripts/run_update.py
|
||||
|
||||
# 数据完整性检查(需要数据库连接)
|
||||
python -m scripts.check.check_data_integrity --window-start "2025-01-01" --window-end "2025-02-01"
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 所有脚本依赖 `.env` 中的 `PG_DSN` 配置(或环境变量)
|
||||
- `rebuild/` 下的脚本会重建 Schema,生产环境慎用
|
||||
- `repair/` 下的脚本会修改数据,建议先 `--dry-run`(如支持)
|
||||
1
scripts/__init__.py
Normal file
1
scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# 脚本辅助工具包标记。
|
||||
107
scripts/audit/__init__.py
Normal file
107
scripts/audit/__init__.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
仓库治理只读审计 — 共享数据模型
|
||||
|
||||
定义审计脚本各模块共用的 dataclass 和枚举类型。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 文件元信息
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class FileEntry:
    """Metadata for a single file or directory found during the repo scan."""

    rel_path: str        # path relative to the repository root
    is_dir: bool         # True when the entry is a directory
    size_bytes: int      # file size in bytes (0 for directories)
    extension: str       # lower-cased file extension, including the dot
    is_empty_dir: bool   # True when the entry is an empty directory
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 用途分类与处置标签
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Category(str, Enum):
    """File-purpose classification used by the audit inventory.

    Values are the Chinese display labels; str mixin keeps them usable
    anywhere a plain string is expected.
    """

    CORE_CODE = "核心代码"
    CONFIG = "配置"
    DATABASE_DEF = "数据库定义"
    TEST = "测试"
    DOCS = "文档"
    SCRIPTS = "脚本工具"
    GUI = "GUI"
    BUILD_DEPLOY = "构建与部署"
    LOG_OUTPUT = "日志与输出"
    TEMP_DEBUG = "临时与调试"
    OTHER = "其他"
|
||||
|
||||
|
||||
class Disposition(str, Enum):
    """Disposition label assigned to an inventory entry."""

    KEEP = "保留"
    CANDIDATE_DELETE = "候选删除"
    CANDIDATE_ARCHIVE = "候选归档"
    NEEDS_REVIEW = "待确认"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 文件清单条目
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class InventoryItem:
    """One inventory row: path plus classification, disposition and note."""

    rel_path: str
    category: "Category"
    disposition: "Disposition"
    description: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 流程树节点
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class FlowNode:
    """Node in the import/flow tree built by the flow analyzer."""

    name: str                 # module / class / function name
    source_file: str          # source file the node was discovered in
    node_type: str            # one of: entry / module / class / function
    children: list["FlowNode"] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 文档对齐
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class DocMapping:
    """Mapping between one document and the code it references."""

    doc_path: str            # document file path
    doc_topic: str           # inferred topic of the document
    related_code: list[str]  # referenced code files / modules
    status: str              # one of: aligned / stale / conflict / orphan
|
||||
|
||||
|
||||
@dataclass
class AlignmentIssue:
    """A single doc/code alignment problem reported by the analyzers."""

    doc_path: str      # path of the affected document
    issue_type: str    # one of: stale / conflict / missing
    description: str   # human-readable description of the issue
    related_code: str  # path of the related code artefact
|
||||
617
scripts/audit/doc_alignment_analyzer.py
Normal file
617
scripts/audit/doc_alignment_analyzer.py
Normal file
@@ -0,0 +1,617 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
文档对齐分析器 — 检查文档与代码之间的映射关系、过期点、冲突点和缺失点。
|
||||
|
||||
文档来源:
|
||||
- docs/ 目录(.md, .txt, .csv, .json)
|
||||
- 根目录 README.md
|
||||
- 开发笔记/ 目录
|
||||
- 各模块内的 README.md
|
||||
- .kiro/steering/ 引导文件
|
||||
- docs/test-json-doc/ API 响应样本
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from scripts.audit import AlignmentIssue, DocMapping
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 常量
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 文档文件扩展名
|
||||
_DOC_EXTENSIONS = {".md", ".txt", ".csv"}
|
||||
|
||||
# 核心代码目录——缺少文档时应报告
|
||||
_CORE_CODE_DIRS = {
|
||||
"tasks",
|
||||
"loaders",
|
||||
"orchestration",
|
||||
"quality",
|
||||
"models",
|
||||
"utils",
|
||||
"api",
|
||||
"scd",
|
||||
"config",
|
||||
"database",
|
||||
}
|
||||
|
||||
# ODS 表中的通用元数据列,比对时忽略
|
||||
_ODS_META_COLUMNS = {"content_hash", "payload", "created_at", "updated_at", "id"}
|
||||
|
||||
# SQL 关键字,解析 DDL 列名时排除
|
||||
_SQL_KEYWORDS = {
|
||||
"primary", "key", "not", "null", "default", "unique", "check",
|
||||
"references", "foreign", "constraint", "index", "create", "table",
|
||||
"if", "exists", "serial", "bigserial", "true", "false",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 安全读取文件(编码回退)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_read(path: Path) -> str:
|
||||
"""尝试以 utf-8 → gbk → latin-1 回退读取文件内容。"""
|
||||
for enc in ("utf-8", "gbk", "latin-1"):
|
||||
try:
|
||||
return path.read_text(encoding=enc)
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scan_docs — 扫描所有文档来源
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def scan_docs(repo_root: Path) -> list[str]:
    """Collect every documentation file in the repository.

    Sources scanned:
    1. docs/ recursively (.md/.txt/.csv plus .json samples)
    2. the top-level README.md
    3. the 开发笔记/ directory (all files)
    4. README.md inside each other top-level directory
    5. files directly under .kiro/steering/

    Returns a sorted, de-duplicated list of forward-slash relative paths.
    """
    found: set[str] = set()

    def relative(p: Path) -> str:
        # Normalise to forward slashes so results are stable across OSes.
        return str(p.relative_to(repo_root)).replace("\\", "/")

    # 1. docs/ (recursive; .json samples under test-json-doc included)
    doc_exts = _DOC_EXTENSIONS | {".json"}
    docs_dir = repo_root / "docs"
    if docs_dir.is_dir():
        for item in docs_dir.rglob("*"):
            if item.is_file() and item.suffix.lower() in doc_exts:
                found.add(relative(item))

    # 2. top-level README
    if (repo_root / "README.md").is_file():
        found.add("README.md")

    # 3. development-notes directory (all files, any extension)
    notes_dir = repo_root / "开发笔记"
    if notes_dir.is_dir():
        found.update(relative(item) for item in notes_dir.rglob("*") if item.is_file())

    # 4. per-module READMEs (directories already covered above are skipped)
    skip_dirs = {"docs", "开发笔记", ".kiro"}
    for child in sorted(repo_root.iterdir()):
        if child.is_dir() and child.name not in skip_dirs:
            readme = child / "README.md"
            if readme.is_file():
                found.add(relative(readme))

    # 5. steering files
    steering = repo_root / ".kiro" / "steering"
    if steering.is_dir():
        found.update(relative(item) for item in sorted(steering.iterdir()) if item.is_file())

    return sorted(found)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_code_references — 从文档提取代码引用
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def extract_code_references(doc_path: Path) -> list[str]:
    """Pull code references (backtick-quoted paths/identifiers) out of a doc.

    Filters applied:
    - single-character references are dropped
    - bare numbers / version strings are dropped
    - backslashes are normalised to forward slashes
    - duplicates are removed while preserving first-seen order
    """
    if not doc_path.is_file():
        return []

    text = _safe_read(doc_path)
    if not text:
        return []

    # dict keys keep insertion order, giving ordered de-duplication for free
    ordered: dict[str, None] = {}
    for raw in re.findall(r"`([^`]+)`", text):
        candidate = raw.strip().replace("\\", "/")
        if len(candidate) <= 1:
            continue
        if re.fullmatch(r"[\d.]+", candidate):
            continue
        ordered.setdefault(candidate, None)

    return list(ordered)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_reference_validity — 检查引用有效性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def check_reference_validity(ref: str, repo_root: Path) -> bool:
    """Return True when a documented code reference still resolves.

    Resolution order:
    1. the reference taken verbatim as a path under the repo root
    2. the reference with a legacy package prefix stripped
    3. a dotted module path mapped to a .py file or a package directory
    """
    # 1. verbatim path
    if (repo_root / ref).exists():
        return True

    # 2. historical docs sometimes prefix paths with an old package name
    for legacy in ("FQ-ETL/", "etl_billiards/"):
        if ref.startswith(legacy) and (repo_root / ref[len(legacy):]).exists():
            return True

    # 3. dotted module name -> module file or package directory
    if "." in ref and "/" not in ref:
        dotted = ref.replace(".", "/")
        if (repo_root / (dotted + ".py")).exists() or (repo_root / dotted).is_dir():
            return True

    return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_undocumented_modules — 找出缺少文档的核心代码模块
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_undocumented_modules(
    repo_root: Path,
    documented: set[str],
) -> list[str]:
    """List core-code .py files missing from the *documented* set.

    Only directories named in _CORE_CODE_DIRS are inspected; __init__.py
    files are skipped. Returns sorted forward-slash relative paths.
    """
    missing: list[str] = []

    for dir_name in sorted(_CORE_CODE_DIRS):
        base = repo_root / dir_name
        if not base.is_dir():
            continue
        for source in base.rglob("*.py"):
            if source.name == "__init__.py":
                continue
            rel = source.relative_to(repo_root).as_posix()
            if rel not in documented:
                missing.append(rel)

    return sorted(missing)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DDL / 数据字典解析辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_ddl_tables(sql: str) -> dict[str, set[str]]:
|
||||
"""从 DDL SQL 中提取表名和列名。
|
||||
|
||||
返回 {表名: {列名集合}} 字典。
|
||||
支持带 schema 前缀的表名(如 billiards_dwd.dim_member → dim_member)。
|
||||
"""
|
||||
tables: dict[str, set[str]] = {}
|
||||
|
||||
# 匹配 CREATE TABLE [IF NOT EXISTS] [schema.]table_name (
|
||||
create_re = re.compile(
|
||||
r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
|
||||
r"(?:\w+\.)?(\w+)\s*\(",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
for match in create_re.finditer(sql):
|
||||
table_name = match.group(1)
|
||||
# 找到对应的括号内容
|
||||
start = match.end()
|
||||
depth = 1
|
||||
pos = start
|
||||
while pos < len(sql) and depth > 0:
|
||||
if sql[pos] == "(":
|
||||
depth += 1
|
||||
elif sql[pos] == ")":
|
||||
depth -= 1
|
||||
pos += 1
|
||||
body = sql[start:pos - 1]
|
||||
|
||||
columns: set[str] = set()
|
||||
# 逐行提取列名——取每行第一个标识符
|
||||
for line in body.split("\n"):
|
||||
line = line.strip().rstrip(",")
|
||||
if not line:
|
||||
continue
|
||||
# 提取第一个单词
|
||||
col_match = re.match(r"(\w+)", line)
|
||||
if col_match:
|
||||
col_name = col_match.group(1).lower()
|
||||
# 排除 SQL 关键字
|
||||
if col_name not in _SQL_KEYWORDS:
|
||||
columns.add(col_name)
|
||||
|
||||
tables[table_name] = columns
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
def _parse_dictionary_tables(md: str) -> dict[str, set[str]]:
|
||||
"""从数据字典 Markdown 中提取表名和字段名。
|
||||
|
||||
约定:
|
||||
- 表名出现在 ## 标题中(可能带反引号)
|
||||
- 字段名出现在 Markdown 表格的第一列
|
||||
- 跳过表头行(含"字段"字样)和分隔行(含 ---)
|
||||
"""
|
||||
tables: dict[str, set[str]] = {}
|
||||
current_table: str | None = None
|
||||
|
||||
for line in md.split("\n"):
|
||||
# 匹配 ## 标题中的表名
|
||||
heading = re.match(r"^##\s+`?(\w+)`?", line)
|
||||
if heading:
|
||||
current_table = heading.group(1)
|
||||
tables[current_table] = set()
|
||||
continue
|
||||
|
||||
if current_table is None:
|
||||
continue
|
||||
|
||||
# 跳过分隔行
|
||||
if re.match(r"^\s*\|[-\s|]+\|\s*$", line):
|
||||
continue
|
||||
|
||||
# 解析表格行
|
||||
row_match = re.match(r"^\s*\|\s*(\S+)", line)
|
||||
if row_match:
|
||||
field = row_match.group(1)
|
||||
# 跳过表头(含"字段"字样)
|
||||
if field in ("字段",):
|
||||
continue
|
||||
tables[current_table].add(field)
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_ddl_vs_dictionary — DDL 与数据字典比对
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def check_ddl_vs_dictionary(repo_root: Path) -> list[AlignmentIssue]:
    """Compare DDL table definitions against the data-dictionary docs.

    Reports:
    1. tables present in DDL but absent from the dictionary -> "missing"
    2. columns present in DDL but absent from the dictionary -> "conflict"
    """

    def _merge(dest: dict[str, set[str]], src: dict[str, set[str]]) -> None:
        # Accumulate column/field sets across multiple files per table.
        for tbl, cols in src.items():
            dest.setdefault(tbl, set()).update(cols)

    # All DDL table definitions from database/schema_*.sql
    ddl_tables: dict[str, set[str]] = {}
    db_dir = repo_root / "database"
    if db_dir.is_dir():
        for sql_file in sorted(db_dir.glob("schema_*.sql")):
            _merge(ddl_tables, _parse_ddl_tables(_safe_read(sql_file)))

    # All dictionary table definitions from docs/*dictionary*.md
    dict_tables: dict[str, set[str]] = {}
    docs_dir = repo_root / "docs"
    if docs_dir.is_dir():
        for dict_file in sorted(docs_dir.glob("*dictionary*.md")):
            _merge(dict_tables, _parse_dictionary_tables(_safe_read(dict_file)))

    issues: list[AlignmentIssue] = []
    for tbl, ddl_cols in sorted(ddl_tables.items()):
        if tbl not in dict_tables:
            issues.append(AlignmentIssue(
                doc_path="docs/*dictionary*.md",
                issue_type="missing",
                description=f"DDL 定义了表 `{tbl}`,但数据字典中未收录",
                related_code=f"database/schema_*.sql ({tbl})",
            ))
            continue
        for col in sorted(ddl_cols - dict_tables[tbl]):
            issues.append(AlignmentIssue(
                doc_path="docs/*dictionary*.md",
                issue_type="conflict",
                description=f"表 `{tbl}` 的列 `{col}` 在 DDL 中存在但数据字典中缺失",
                related_code=f"database/schema_*.sql ({tbl}.{col})",
            ))

    return issues
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_api_samples_vs_parsers — API 样本与解析器比对
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def check_api_samples_vs_parsers(repo_root: Path) -> list[AlignmentIssue]:
    """Compare API response samples with ODS table structures.

    Steps:
    1. scan docs/test-json-doc/*.json
    2. take the top-level field names of each sample
    3. match the sample (by file stem) against the ODS DDL table names
    4. report sample fields the matched table lacks as "conflict" issues

    NOTE(review): _ODS_META_COLUMNS is documented elsewhere as "ignored
    during the comparison", but it is never referenced here — only
    sample-side extras are flagged, so the meta columns never enter the
    diff. Confirm whether the constant should actually be applied.
    """
    issues: list[AlignmentIssue] = []

    sample_dir = repo_root / "docs" / "test-json-doc"
    if not sample_dir.is_dir():
        return issues

    # ODS table -> full column set, gathered from the ODS DDL files.
    ods_tables: dict[str, set[str]] = {}
    db_dir = repo_root / "database"
    if db_dir.is_dir():
        for sql_file in sorted(db_dir.glob("schema_*ODS*.sql")):
            ods_tables.update(_parse_ddl_tables(_safe_read(sql_file)))

    for sample in sorted(sample_dir.glob("*.json")):
        try:
            payload = json.loads(_safe_read(sample))
        except (json.JSONDecodeError, ValueError):
            continue  # unparseable sample: skip silently

        # Top-level keys: dict directly, or the first element of a list.
        if isinstance(payload, list) and payload and isinstance(payload[0], dict):
            sample_fields = set(payload[0].keys())
        elif isinstance(payload, dict):
            sample_fields = set(payload.keys())
        else:
            sample_fields = set()
        if not sample_fields:
            continue

        # First ODS table whose name contains the sample's stem.
        entity = sample.stem.lower()
        match = next(
            ((tbl, cols) for tbl, cols in ods_tables.items() if entity in tbl.lower()),
            None,
        )
        if match is None:
            continue
        matched_table, matched_cols = match

        # Sample fields with no corresponding ODS column.
        for missing_field in sorted(sample_fields - matched_cols):
            issues.append(AlignmentIssue(
                doc_path=f"docs/test-json-doc/{sample.name}",
                issue_type="conflict",
                description=(
                    f"API 样本字段 `{missing_field}` 在 ODS 表 `{matched_table}` 中未定义"
                ),
                related_code=f"database/schema_*ODS*.sql ({matched_table})",
            ))

    return issues
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_mappings — 构建文档与代码的映射关系
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_mappings(
    doc_paths: list[str],
    repo_root: Path,
) -> list[DocMapping]:
    """Build a DocMapping for every document path in *doc_paths*.

    Status rules: no references at all -> "orphan"; any dangling
    reference -> "stale"; otherwise -> "aligned". Only references that
    still resolve are recorded in related_code.
    """
    mappings: list[DocMapping] = []

    for rel in doc_paths:
        full_path = repo_root / rel
        refs = extract_code_references(full_path)
        live = [r for r in refs if check_reference_validity(r, repo_root)]

        if not refs:
            status = "orphan"
        elif len(live) < len(refs):
            status = "stale"
        else:
            status = "aligned"

        mappings.append(DocMapping(
            doc_path=rel,
            doc_topic=_infer_topic(full_path, rel),
            related_code=live,
            status=status,
        ))

    return mappings
|
||||
|
||||
|
||||
def _infer_topic(doc_path: Path, doc_rel: str) -> str:
|
||||
"""从文档推断主题——优先取 Markdown 一级标题,否则用文件名。"""
|
||||
if doc_path.is_file() and doc_path.suffix.lower() in (".md", ".txt"):
|
||||
try:
|
||||
text = _safe_read(doc_path)
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
return doc_rel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_alignment_report — 生成 Markdown 格式的文档对齐报告
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def render_alignment_report(
    mappings: list[DocMapping],
    issues: list[AlignmentIssue],
    repo_root: str,
) -> str:
    """Render the document-alignment report as Markdown.

    Sections: header, mapping table, stale / conflict / missing issue
    tables, and a summary of counts.
    """
    out: list[str] = []
    append = out.append

    # --- header ---
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    append("# 文档对齐报告")
    append("")
    append(f"- 生成时间:{stamp}")
    append(f"- 仓库路径:`{repo_root}`")
    append("")

    # --- mapping table ---
    append("## 映射关系")
    append("")
    if mappings:
        append("| 文档路径 | 主题 | 关联代码 | 状态 |")
        append("|---|---|---|---|")
        for m in mappings:
            code_str = ", ".join(f"`{c}`" for c in m.related_code) if m.related_code else "—"
            append(f"| `{m.doc_path}` | {m.doc_topic} | {code_str} | {m.status} |")
    else:
        append("未发现文档映射关系。")
    append("")

    # --- group issues by type (unknown types are dropped, as before) ---
    grouped: dict[str, list[AlignmentIssue]] = {"stale": [], "conflict": [], "missing": []}
    for issue in issues:
        if issue.issue_type in grouped:
            grouped[issue.issue_type].append(issue)

    def issue_section(title: str, empty_msg: str, rows: list[AlignmentIssue]) -> None:
        # All three issue sections share the same table layout.
        append(title)
        append("")
        if rows:
            append("| 文档路径 | 描述 | 关联代码 |")
            append("|---|---|---|")
            for i in rows:
                append(f"| `{i.doc_path}` | {i.description} | `{i.related_code}` |")
        else:
            append(empty_msg)
        append("")

    issue_section("## 过期点", "未发现过期点。", grouped["stale"])
    issue_section("## 冲突点", "未发现冲突点。", grouped["conflict"])
    issue_section("## 缺失点", "未发现缺失点。", grouped["missing"])

    # --- summary ---
    append("## 统计摘要")
    append("")
    append(f"- 文档总数:{len(mappings)}")
    append(f"- 过期点数量:{len(grouped['stale'])}")
    append(f"- 冲突点数量:{len(grouped['conflict'])}")
    append(f"- 缺失点数量:{len(grouped['missing'])}")
    append("")

    return "\n".join(out)
|
||||
618
scripts/audit/flow_analyzer.py
Normal file
618
scripts/audit/flow_analyzer.py
Normal file
@@ -0,0 +1,618 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
流程树分析器 — 通过静态分析 Python 源码的 import 语句和类继承关系,
|
||||
构建从入口到末端模块的调用树。
|
||||
|
||||
仅执行只读操作:读取并解析 Python 源文件,不修改任何文件。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from scripts.audit import FileEntry, FlowNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 项目内部包名列表(顶层目录中属于项目代码的包)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PROJECT_PACKAGES: set[str] = {
|
||||
"cli", "config", "api", "database", "tasks", "loaders",
|
||||
"scd", "orchestration", "quality", "models", "utils",
|
||||
"gui", "scripts",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 已知的第三方包和标准库顶层模块(用于排除非项目导入)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_KNOWN_THIRD_PARTY: set[str] = {
|
||||
"psycopg2", "requests", "dateutil", "python_dateutil",
|
||||
"dotenv", "openpyxl", "PySide6", "flask", "pyinstaller",
|
||||
"PyInstaller", "hypothesis", "pytest", "_pytest", "py",
|
||||
"pluggy", "pkg_resources", "setuptools", "pip", "wheel",
|
||||
"tzdata", "six", "certifi", "urllib3", "charset_normalizer",
|
||||
"idna", "shiboken6",
|
||||
}
|
||||
|
||||
|
||||
def _is_project_module(module_name: str) -> bool:
|
||||
"""判断模块名是否属于项目内部模块。"""
|
||||
top = module_name.split(".")[0]
|
||||
if top in _PROJECT_PACKAGES:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_stdlib_or_third_party(module_name: str) -> bool:
|
||||
"""判断模块名是否属于标准库或已知第三方包。"""
|
||||
top = module_name.split(".")[0]
|
||||
if top in _KNOWN_THIRD_PARTY:
|
||||
return True
|
||||
# 检查标准库
|
||||
if top in sys.stdlib_module_names:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 文件读取(多编码回退)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _read_source(filepath: Path) -> str | None:
|
||||
"""读取 Python 源文件内容,尝试 utf-8 → gbk → latin-1 回退。
|
||||
|
||||
返回文件内容字符串,读取失败时返回 None。
|
||||
"""
|
||||
for encoding in ("utf-8", "gbk", "latin-1"):
|
||||
try:
|
||||
return filepath.read_text(encoding=encoding)
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
except (OSError, PermissionError) as exc:
|
||||
logger.warning("无法读取文件 %s: %s", filepath, exc)
|
||||
return None
|
||||
logger.warning("无法以任何编码读取文件 %s", filepath)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路径 ↔ 模块名转换
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _path_to_module_name(rel_path: str) -> str:
|
||||
"""将相对路径转换为 Python 模块名。
|
||||
|
||||
例如:
|
||||
- "cli/main.py" → "cli.main"
|
||||
- "cli/__init__.py" → "cli"
|
||||
- "tasks/dws/assistant.py" → "tasks.dws.assistant"
|
||||
"""
|
||||
p = rel_path.replace("\\", "/")
|
||||
if p.endswith("/__init__.py"):
|
||||
p = p[: -len("/__init__.py")]
|
||||
elif p.endswith(".py"):
|
||||
p = p[:-3]
|
||||
return p.replace("/", ".")
|
||||
|
||||
|
||||
def _module_to_path(module_name: str) -> str:
|
||||
"""将模块名转换为相对文件路径(优先 .py 文件)。
|
||||
|
||||
例如:
|
||||
- "cli.main" → "cli/main.py"
|
||||
- "cli" → "cli/__init__.py"
|
||||
"""
|
||||
return module_name.replace(".", "/") + ".py"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_imports — 解析 Python 文件的 import 语句
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_imports(filepath: Path) -> list[str]:
    """Return the project-local modules imported by a Python file.

    Uses ast to collect `import x` targets and absolute `from x import ...`
    sources, keeping only project packages (stdlib and third-party modules
    are excluded). Duplicates are dropped while preserving first occurrence.
    Missing files, undecodable files, or syntax errors yield an empty list.
    """
    if not filepath.exists():
        return []

    source = _read_source(filepath)
    if source is None:
        return []

    try:
        tree = ast.parse(source, filename=str(filepath))
    except SyntaxError:
        logger.warning("语法错误,无法解析 %s", filepath)
        return []

    found: dict[str, None] = {}  # insertion-ordered de-duplication
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            candidates = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            # level > 0 means a relative import — skipped, as before.
            candidates = [node.module]
        else:
            continue
        for name in candidates:
            if _is_project_module(name) and not _is_stdlib_or_third_party(name):
                found.setdefault(name, None)

    return list(found)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_flow_tree — 从入口递归追踪 import 链,构建流程树
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_flow_tree(
    repo_root: Path,
    entry_file: str,
    _visited: set[str] | None = None,
) -> FlowNode:
    """Recursively follow imports from *entry_file* and return a FlowNode tree.

    Parameters
    ----------
    repo_root : Path
        Repository root directory.
    entry_file : str
        Entry file as a repo-relative path (e.g. "cli/main.py").
    _visited : set[str] | None
        Internal recursion guard against import cycles; callers omit it.

    Returns
    -------
    FlowNode
        Tree rooted at the entry file; the root gets node_type "entry",
        everything below gets "module".
    """
    at_root = _visited is None
    visited = set() if _visited is None else _visited
    visited.add(entry_file)

    children: list[FlowNode] = []
    full_path = repo_root / entry_file
    if full_path.exists():
        for mod in parse_imports(full_path):
            target = _module_to_path(mod)
            if not (repo_root / target).exists():
                # Packages resolve to their __init__.py instead of <pkg>.py.
                package_init = mod.replace(".", "/") + "/__init__.py"
                if (repo_root / package_init).exists():
                    target = package_init
            if target not in visited:
                children.append(build_flow_tree(repo_root, target, visited))

    return FlowNode(
        name=_path_to_module_name(entry_file),
        source_file=entry_file,
        node_type="entry" if at_root else "module",
        children=children,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 批处理文件解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_bat_python_target(bat_path: Path) -> str | None:
|
||||
"""从批处理文件中解析 python -m 命令的目标模块名。
|
||||
|
||||
返回模块名(如 "cli.main"),未找到时返回 None。
|
||||
"""
|
||||
if not bat_path.exists():
|
||||
return None
|
||||
|
||||
content = _read_source(bat_path)
|
||||
if content is None:
|
||||
return None
|
||||
|
||||
# 匹配 python -m module.name 或 python3 -m module.name
|
||||
pattern = re.compile(r"python[3]?\s+-m\s+([\w.]+)", re.IGNORECASE)
|
||||
for line in content.splitlines():
|
||||
m = pattern.search(line)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 入口点识别
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def discover_entry_points(repo_root: Path) -> list[dict[str, str]]:
    """Identify all entry points of the project.

    Returns a list of dicts, each with:
    - type: entry kind (CLI / GUI / batch / ops script)
    - file: relative path
    - description: short summary

    Detection rules:
    - cli/main.py  -> CLI entry
    - gui/main.py  -> GUI entry
    - *.bat files  -> parse the ``python -m`` command inside
    - scripts/*.py containing ``if __name__ == "__main__"``
      (excluding __init__.py; the audit/ subdirectory is not scanned
      because only files directly under scripts/ are globbed)
    """
    entries: list[dict[str, str]] = []

    # CLI entry point.
    cli_main = repo_root / "cli" / "main.py"
    if cli_main.exists():
        entries.append({
            "type": "CLI",
            "file": "cli/main.py",
            "description": "CLI 主入口 (`python -m cli.main`)",
        })

    # GUI entry point.
    gui_main = repo_root / "gui" / "main.py"
    if gui_main.exists():
        entries.append({
            "type": "GUI",
            "file": "gui/main.py",
            "description": "GUI 主入口 (`python -m gui.main`)",
        })

    # Batch files: annotate with the python -m target when one is found.
    for bat in sorted(repo_root.glob("*.bat")):
        target = _parse_bat_python_target(bat)
        # FIX: was an f-string with no placeholders (f"批处理脚本").
        desc = "批处理脚本"
        if target:
            desc += f",调用 `{target}`"
        entries.append({
            "type": "批处理",
            "file": bat.name,
            "description": desc,
        })

    # Ops scripts: .py files directly under scripts/ that look runnable.
    # The "__name__"/"__main__" containment check is a cheap textual
    # heuristic, not an AST check.
    scripts_dir = repo_root / "scripts"
    if scripts_dir.is_dir():
        for py_file in sorted(scripts_dir.glob("*.py")):
            if py_file.name == "__init__.py":
                continue
            source = _read_source(py_file)
            if source and '__name__' in source and '__main__' in source:
                rel = py_file.relative_to(repo_root).as_posix()
                entries.append({
                    "type": "运维脚本",
                    "file": rel,
                    "description": f"运维脚本 `{py_file.name}`",
                })

    return entries
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 任务类型和加载器类型区分
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def classify_task_type(rel_path: str) -> str:
    """Classify a file path into a task type for report annotations.

    Returns one of:
    - "ODS 抓取任务"
    - "DWD 加载任务"
    - "DWS 汇总任务"
    - "校验任务"
    - "Schema 初始化任务"
    - "任务" (default when no rule matches)
    """
    # Normalize separators once; after this replace, no backslash can
    # remain, so the original's extra '"...\\\\" in p' checks were dead
    # code and are removed here.
    p = rel_path.replace("\\", "/").lower()

    # Directory-level rules first.
    if "verification/" in p:
        return "校验任务"
    if "dws/" in p:
        return "DWS 汇总任务"

    # Filename-level prefixes (rsplit handles paths with or without "/").
    basename = p.rsplit("/", 1)[-1]
    if basename.startswith(("ods_", "ods.")):
        return "ODS 抓取任务"
    if basename.startswith(("dwd_", "dwd.")):
        return "DWD 加载任务"
    if basename.startswith("dws_"):
        return "DWS 汇总任务"
    if "init" in basename and "schema" in basename:
        return "Schema 初始化任务"
    return "任务"
|
||||
|
||||
|
||||
def classify_loader_type(rel_path: str) -> str:
    """Classify a file path into a loader type for report annotations.

    Returns one of:
    - "维度加载器 (SCD2)"
    - "事实表加载器"
    - "ODS 通用加载器"
    - "加载器" (default when no rule matches)
    """
    # After normalization no backslash can remain, so the original
    # '"...\\\\" in p' checks were dead code and are dropped.
    p = rel_path.replace("\\", "/").lower()

    if "dimensions/" in p:
        return "维度加载器 (SCD2)"
    if "facts/" in p:
        return "事实表加载器"
    if "ods/" in p:
        return "ODS 通用加载器"
    return "加载器"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_orphan_modules — 找出未被任何入口直接或间接引用的 Python 模块
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_orphan_modules(
    repo_root: Path,
    all_entries: list[FileEntry],
    reachable: set[str],
) -> list[str]:
    """Find Python modules not reachable from any entry point.

    Exclusions (never reported as orphans):
    - __init__.py files
    - files under tests/
    - files under scripts/audit/ (the audit tooling itself)
    - directory entries and non-.py files
    - files outside the known project packages

    Returns the orphan paths sorted lexicographically.
    """
    orphans: list[str] = []

    for entry in all_entries:
        # Only regular .py files can be orphan modules.
        if entry.is_dir or entry.extension != ".py":
            continue

        # Normalize separators once; every later test assumes "/".
        # (The original also matched literal backslashes afterwards,
        # which was dead code after this replace.)
        rel = entry.rel_path.replace("\\", "/")

        if rel.endswith("/__init__.py") or rel == "__init__.py":
            continue
        if rel.startswith("tests/"):
            continue
        if rel.startswith("scripts/audit/"):
            continue

        # Only files inside a known project package are checked;
        # top-level files (no "/") have no package and are skipped.
        top_dir = rel.split("/")[0] if "/" in rel else ""
        if top_dir not in _PROJECT_PACKAGES:
            continue

        # Not in the reachable set -> orphan.
        if rel not in reachable:
            orphans.append(rel)

    orphans.sort()
    return orphans
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 统计辅助
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _count_nodes_by_type(trees: list[FlowNode]) -> dict[str, int]:
|
||||
"""递归统计流程树中各类型节点的数量。"""
|
||||
counts: dict[str, int] = {"entry": 0, "module": 0, "class": 0, "function": 0}
|
||||
|
||||
def _walk(node: FlowNode) -> None:
|
||||
t = node.node_type
|
||||
counts[t] = counts.get(t, 0) + 1
|
||||
for child in node.children:
|
||||
_walk(child)
|
||||
|
||||
for tree in trees:
|
||||
_walk(tree)
|
||||
return counts
|
||||
|
||||
|
||||
def _count_tasks_and_loaders(trees: list[FlowNode]) -> tuple[int, int]:
|
||||
"""统计流程树中任务模块和加载器模块的数量。"""
|
||||
tasks = 0
|
||||
loaders = 0
|
||||
seen: set[str] = set()
|
||||
|
||||
def _walk(node: FlowNode) -> None:
|
||||
nonlocal tasks, loaders
|
||||
if node.source_file in seen:
|
||||
return
|
||||
seen.add(node.source_file)
|
||||
sf = node.source_file.replace("\\", "/")
|
||||
if sf.startswith("tasks/") and not sf.endswith("__init__.py"):
|
||||
base = sf.rsplit("/", 1)[-1]
|
||||
if not base.startswith("base_"):
|
||||
tasks += 1
|
||||
if sf.startswith("loaders/") and not sf.endswith("__init__.py"):
|
||||
base = sf.rsplit("/", 1)[-1]
|
||||
if not base.startswith("base_"):
|
||||
loaders += 1
|
||||
for child in node.children:
|
||||
_walk(child)
|
||||
|
||||
for tree in trees:
|
||||
_walk(tree)
|
||||
return tasks, loaders
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 类型标注辅助
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_type_annotation(source_file: str) -> str:
|
||||
"""根据源文件路径返回类型标注字符串(用于报告中的节点标注)。"""
|
||||
sf = source_file.replace("\\", "/")
|
||||
if sf.startswith("tasks/"):
|
||||
return f" [{classify_task_type(sf)}]"
|
||||
if sf.startswith("loaders/"):
|
||||
return f" [{classify_loader_type(sf)}]"
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mermaid 图生成
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _render_mermaid(trees: list[FlowNode]) -> str:
|
||||
"""生成 Mermaid 流程图代码。"""
|
||||
lines: list[str] = ["```mermaid", "graph TD"]
|
||||
seen_edges: set[tuple[str, str]] = set()
|
||||
node_ids: dict[str, str] = {}
|
||||
counter = [0]
|
||||
|
||||
def _node_id(name: str) -> str:
|
||||
if name not in node_ids:
|
||||
node_ids[name] = f"N{counter[0]}"
|
||||
counter[0] += 1
|
||||
return node_ids[name]
|
||||
|
||||
def _walk(node: FlowNode) -> None:
|
||||
nid = _node_id(node.name)
|
||||
annotation = _get_type_annotation(node.source_file)
|
||||
label = f"{node.name}{annotation}"
|
||||
# 声明节点
|
||||
lines.append(f" {nid}[\"`{label}`\"]")
|
||||
for child in node.children:
|
||||
cid = _node_id(child.name)
|
||||
edge = (nid, cid)
|
||||
if edge not in seen_edges:
|
||||
seen_edges.add(edge)
|
||||
lines.append(f" {nid} --> {cid}")
|
||||
_walk(child)
|
||||
|
||||
for tree in trees:
|
||||
_walk(tree)
|
||||
|
||||
lines.append("```")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 缩进文本树生成
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _render_text_tree(trees: list[FlowNode]) -> str:
|
||||
"""生成缩进文本形式的流程树。"""
|
||||
lines: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
def _walk(node: FlowNode, depth: int) -> None:
|
||||
indent = " " * depth
|
||||
annotation = _get_type_annotation(node.source_file)
|
||||
line = f"{indent}- `{node.name}` (`{node.source_file}`){annotation}"
|
||||
lines.append(line)
|
||||
|
||||
key = node.source_file
|
||||
if key in seen:
|
||||
# 已展开过,不再递归(避免循环)
|
||||
if node.children:
|
||||
lines.append(f"{indent} - *(已展开)*")
|
||||
return
|
||||
seen.add(key)
|
||||
|
||||
for child in node.children:
|
||||
_walk(child, depth + 1)
|
||||
|
||||
for tree in trees:
|
||||
_walk(tree, 0)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_flow_report — 生成 Markdown 格式的流程树报告
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def render_flow_report(
    trees: list[FlowNode],
    orphans: list[str],
    repo_root: str,
) -> str:
    """Render the flow-tree report as Markdown.

    Report layout:
    1. Header (timestamp, repository path)
    2. Mermaid flow graph
    3. Indented text tree
    4. Orphan-module list
    5. Summary statistics
    """
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    sections: list[str] = []

    # --- Header ---
    sections.append("# 项目流程树报告\n")
    sections.append(f"- 生成时间: {timestamp}")
    sections.append(f"- 仓库路径: `{repo_root}`\n")

    # --- Mermaid graph ---
    sections.append("## 流程图(Mermaid)\n")
    sections.append(_render_mermaid(trees))
    sections.append("")

    # --- Indented text tree ---
    sections.append("## 流程树(缩进文本)\n")
    sections.append(_render_text_tree(trees))
    sections.append("")

    # --- Orphan modules ---
    sections.append("## 孤立模块\n")
    if orphans:
        for o in orphans:
            sections.append(f"- `{o}`")
    else:
        sections.append("未发现孤立模块。")
    sections.append("")

    # --- Summary statistics ---
    entry_count = sum(1 for t in trees if t.node_type == "entry")
    task_count, loader_count = _count_tasks_and_loaders(trees)
    orphan_count = len(orphans)

    sections.append("## 统计摘要\n")
    # FIX: the two table-header rows below were f-strings with no
    # placeholders; they are plain string constants.
    sections.append("| 指标 | 数量 |")
    sections.append("|------|------|")
    sections.append(f"| 入口点 | {entry_count} |")
    sections.append(f"| 任务 | {task_count} |")
    sections.append(f"| 加载器 | {loader_count} |")
    sections.append(f"| 孤立模块 | {orphan_count} |")
    sections.append("")

    return "\n".join(sections)
|
||||
449
scripts/audit/inventory_analyzer.py
Normal file
449
scripts/audit/inventory_analyzer.py
Normal file
@@ -0,0 +1,449 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
文件清单分析器 — 对扫描结果进行用途分类和处置标签分配。
|
||||
|
||||
分类规则按优先级从高到低排列:
|
||||
1. tmp/ 下所有文件 → 临时与调试 / 候选删除或候选归档
|
||||
2. logs/、export/ 下的运行时产出 → 日志与输出 / 候选归档
|
||||
3. *.lnk、*.rar 文件 → 其他 / 候选删除
|
||||
4. 空目录 → 其他 / 候选删除
|
||||
5. 核心代码目录(tasks/ 等)→ 核心代码 / 保留
|
||||
6. config/ → 配置 / 保留
|
||||
7. database/*.sql、database/migrations/ → 数据库定义 / 保留
|
||||
8. database/*.py → 核心代码 / 保留
|
||||
9. tests/ → 测试 / 保留
|
||||
10. docs/ → 文档 / 保留
|
||||
11. scripts/ 下的 .py 文件 → 脚本工具 / 保留
|
||||
12. gui/ → GUI / 保留
|
||||
13. 构建与部署文件 → 构建与部署 / 保留
|
||||
14. 其余 → 其他 / 待确认
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import Counter
|
||||
from datetime import datetime, timezone
|
||||
from itertools import groupby
|
||||
|
||||
from scripts.audit import Category, Disposition, FileEntry, InventoryItem
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 常量
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 核心代码顶层目录
|
||||
# Top-level directories that hold core application code.
_CORE_CODE_DIRS = (
    "tasks/", "loaders/", "scd/", "orchestration/",
    "quality/", "models/", "utils/", "api/",
)

# Build/deploy file names recognized at the repository root.
_BUILD_DEPLOY_BASENAMES = {"setup.py", "build_exe.py"}

# File extensions treated as build/deploy scripts.
_BUILD_DEPLOY_EXTENSIONS = {".bat", ".sh", ".ps1"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _top_dir(rel_path: str) -> str:
|
||||
"""返回相对路径的第一级目录名(含尾部斜杠),如 'tmp/foo.py' → 'tmp/'。"""
|
||||
idx = rel_path.find("/")
|
||||
if idx == -1:
|
||||
return ""
|
||||
return rel_path[: idx + 1]
|
||||
|
||||
|
||||
def _basename(rel_path: str) -> str:
|
||||
"""返回路径的最后一段文件名。"""
|
||||
return rel_path.rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
def _is_init_py(rel_path: str) -> bool:
|
||||
"""判断路径是否为 __init__.py。"""
|
||||
return _basename(rel_path) == "__init__.py"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# classify — 核心分类函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def classify(entry: FileEntry) -> InventoryItem:
    """Classify a single file/directory entry and assign a disposition label.

    Rules are evaluated in strict priority order, highest first; the
    first matching rule decides both the category and the disposition,
    so the branch ordering below is load-bearing and must not change.
    """
    path = entry.rel_path
    top = _top_dir(path)          # first path segment, e.g. "tmp/"
    ext = entry.extension.lower()
    base = _basename(path)

    # --- Priority 1: everything under tmp/ ---
    if top == "tmp/" or path == "tmp":
        return _classify_tmp(entry)

    # --- Priority 2: runtime output under logs/ or export/ ---
    if top in ("logs/", "export/") or path in ("logs", "export"):
        return _classify_runtime_output(entry)

    # --- Priority 3: .lnk / .rar files ---
    if ext in (".lnk", ".rar"):
        return InventoryItem(
            rel_path=path,
            category=Category.OTHER,
            disposition=Disposition.CANDIDATE_DELETE,
            description=f"快捷方式/压缩包文件(`{ext}`),建议删除",
        )

    # --- Priority 4: empty directories ---
    if entry.is_empty_dir:
        return InventoryItem(
            rel_path=path,
            category=Category.OTHER,
            disposition=Disposition.CANDIDATE_DELETE,
            description="空目录,建议删除",
        )

    # --- Priority 5: core-code top-level directories ---
    # Matches both files inside a core directory (startswith) and the
    # directory entry itself (path + "/" == d).
    if any(path.startswith(d) or path + "/" == d for d in _CORE_CODE_DIRS):
        return InventoryItem(
            rel_path=path,
            category=Category.CORE_CODE,
            disposition=Disposition.KEEP,
            description=f"核心代码(`{top.rstrip('/')}`)",
        )

    # --- Priority 6: config/ ---
    if top == "config/" or path == "config":
        return InventoryItem(
            rel_path=path,
            category=Category.CONFIG,
            disposition=Disposition.KEEP,
            description="配置文件",
        )

    # --- Priority 7: database/*.sql and database/migrations/ ---
    if top == "database/" or path == "database":
        return _classify_database(entry)

    # --- Priority 8: tests/ ---
    if top == "tests/" or path == "tests":
        return InventoryItem(
            rel_path=path,
            category=Category.TEST,
            disposition=Disposition.KEEP,
            description="测试文件",
        )

    # --- Priority 9: docs/ ---
    if top == "docs/" or path == "docs":
        return InventoryItem(
            rel_path=path,
            category=Category.DOCS,
            disposition=Disposition.KEEP,
            description="文档",
        )

    # --- Priority 10: .py files under scripts/ ---
    if top == "scripts/" or path == "scripts":
        cat = Category.SCRIPTS
        # Python files and sub-directories are tooling worth keeping;
        # anything else under scripts/ needs a human decision.
        if ext == ".py" or entry.is_dir:
            return InventoryItem(
                rel_path=path,
                category=cat,
                disposition=Disposition.KEEP,
                description="脚本工具",
            )
        return InventoryItem(
            rel_path=path,
            category=cat,
            disposition=Disposition.NEEDS_REVIEW,
            description="脚本目录下的非 Python 文件,需确认用途",
        )

    # --- Priority 11: gui/ ---
    if top == "gui/" or path == "gui":
        return InventoryItem(
            rel_path=path,
            category=Category.GUI,
            disposition=Disposition.KEEP,
            description="GUI 模块",
        )

    # --- Priority 12: build & deploy files ---
    if base in _BUILD_DEPLOY_BASENAMES or ext in _BUILD_DEPLOY_EXTENSIONS:
        return InventoryItem(
            rel_path=path,
            category=Category.BUILD_DEPLOY,
            disposition=Disposition.KEEP,
            description="构建与部署文件",
        )

    # --- Priority 13: cli/ ---
    if top == "cli/" or path == "cli":
        return InventoryItem(
            rel_path=path,
            category=Category.CORE_CODE,
            disposition=Disposition.KEEP,
            description="CLI 入口模块",
        )

    # --- Priority 14: known root-level files ---
    if "/" not in path:
        return _classify_root_file(entry)

    # --- Fallback: unmatched paths need human review ---
    return InventoryItem(
        rel_path=path,
        category=Category.OTHER,
        disposition=Disposition.NEEDS_REVIEW,
        description="未匹配已知规则,需人工确认用途",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 子分类函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _classify_tmp(entry: FileEntry) -> InventoryItem:
    """Classify entries under tmp/: delete by default, archive useful .py files."""
    suffix = entry.extension.lower()
    name = _basename(entry.rel_path)

    def _tagged(disposition: Disposition, description: str) -> InventoryItem:
        # All tmp/ entries share the same category; only the disposition
        # and description vary.
        return InventoryItem(
            rel_path=entry.rel_path,
            category=Category.TEMP_DEBUG,
            disposition=disposition,
            description=description,
        )

    # Empty directories carry no information at all.
    if entry.is_empty_dir:
        return _tagged(Disposition.CANDIDATE_DELETE, "临时目录下的空目录")

    # Python scripts with a non-trivial name may still have reference value.
    if suffix == ".py" and len(name) > 4:
        return _tagged(Disposition.CANDIDATE_ARCHIVE, "临时 Python 脚本,可能有参考价值")

    return _tagged(Disposition.CANDIDATE_DELETE, "临时/调试文件,建议删除")
|
||||
|
||||
|
||||
def _classify_runtime_output(entry: FileEntry) -> InventoryItem:
    """Classify runtime output under logs/ or export/.

    Package markers (__init__.py) are kept; everything else is a
    candidate for archiving.
    """
    is_marker = _is_init_py(entry.rel_path)
    return InventoryItem(
        rel_path=entry.rel_path,
        category=Category.LOG_OUTPUT,
        disposition=Disposition.KEEP if is_marker else Disposition.CANDIDATE_ARCHIVE,
        description="包初始化文件" if is_marker else "运行时产出,建议归档",
    )
|
||||
|
||||
|
||||
def _classify_database(entry: FileEntry) -> InventoryItem:
    """Classify entries under database/ (migrations, DDL, code, misc)."""
    path = entry.rel_path
    suffix = entry.extension.lower()

    def _item(category: Category, disposition: Disposition,
              description: str) -> InventoryItem:
        return InventoryItem(
            rel_path=path,
            category=category,
            disposition=disposition,
            description=description,
        )

    # Migration scripts (the migrations/ subtree) are always kept.
    if "migrations/" in path or path.endswith("migrations"):
        return _item(Category.DATABASE_DEF, Disposition.KEEP, "数据库迁移脚本")

    # SQL definition/manipulation scripts.
    if suffix == ".sql":
        return _item(Category.DATABASE_DEF, Disposition.KEEP, "数据库 DDL/DML 脚本")

    # Python files here are treated as application code, not schema.
    if suffix == ".py":
        return _item(Category.CORE_CODE, Disposition.KEEP, "数据库操作模块")

    # Directory entries: empty ones are deletion candidates.
    if entry.is_dir:
        if entry.is_empty_dir:
            return _item(Category.OTHER, Disposition.CANDIDATE_DELETE,
                         "数据库目录下的空目录")
        return _item(Category.DATABASE_DEF, Disposition.KEEP, "数据库子目录")

    # Anything else is unexpected and needs a human decision.
    return _item(Category.DATABASE_DEF, Disposition.NEEDS_REVIEW,
                 "数据库目录下的非标准文件,需确认")
|
||||
|
||||
|
||||
def _classify_root_file(entry: FileEntry) -> InventoryItem:
    """Classify loose files sitting at the repository root."""
    suffix = entry.extension.lower()
    name = _basename(entry.rel_path)

    def _item(category: Category, disposition: Disposition,
              description: str) -> InventoryItem:
        return InventoryItem(
            rel_path=entry.rel_path,
            category=category,
            disposition=disposition,
            description=description,
        )

    # Known build/deploy artifacts.
    if name in _BUILD_DEPLOY_BASENAMES or suffix in _BUILD_DEPLOY_EXTENSIONS:
        return _item(Category.BUILD_DEPLOY, Disposition.KEEP, "构建与部署文件")

    # Well-known project configuration files.
    known_config = (
        "requirements.txt", "pytest.ini", ".env", ".env.example",
        ".gitignore", ".flake8", "pyproject.toml",
    )
    if name in known_config:
        return _item(Category.CONFIG, Disposition.KEEP, "项目配置文件")

    # README* in any casing.
    if name.lower().startswith("readme"):
        return _item(Category.DOCS, Disposition.KEEP, "项目说明文档")

    # Everything else at the root needs confirmation.
    return _item(
        Category.OTHER,
        Disposition.NEEDS_REVIEW,
        f"根目录散落文件(`{name}`),需确认用途",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_inventory — 批量分类
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_inventory(entries: list[FileEntry]) -> list[InventoryItem]:
|
||||
"""对所有文件条目执行分类,返回清单列表。"""
|
||||
return [classify(e) for e in entries]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_inventory_report — Markdown 渲染
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def render_inventory_report(items: list[InventoryItem], repo_root: str) -> str:
    """Render the classified inventory as a Markdown report.

    Layout: header (title, timestamp, repository path), one table per
    category in Category declaration order, then summary counts by
    category and by disposition label.
    """
    out: list[str] = []
    emit = out.append

    # --- Header ---
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    emit("# 文件清单报告")
    emit("")
    emit(f"- 生成时间:{stamp}")
    emit(f"- 仓库路径:`{repo_root}`")
    emit("")

    # --- One table per category, keeping enum declaration order ---
    rank = {c: i for i, c in enumerate(Category)}
    ordered = sorted(items, key=lambda it: rank[it.category])
    for cat, group in groupby(ordered, key=lambda it: it.category):
        emit(f"## {cat.value}")
        emit("")
        emit("| 相对路径 | 处置标签 | 简要说明 |")
        emit("|---|---|---|")
        for item in group:
            emit(f"| `{item.rel_path}` | {item.disposition.value} | {item.description} |")
        emit("")

    # --- Summary ---
    emit("## 统计摘要")
    emit("")

    by_category: Counter[Category] = Counter(item.category for item in items)
    by_disposition: Counter[Disposition] = Counter(
        item.disposition for item in items
    )

    emit("### 按用途分类")
    emit("")
    emit("| 分类 | 数量 |")
    emit("|---|---|")
    for cat in Category:
        # Counter defaults to 0 for absent keys; skip empty categories.
        if by_category[cat]:
            emit(f"| {cat.value} | {by_category[cat]} |")
    emit("")

    emit("### 按处置标签")
    emit("")
    emit("| 标签 | 数量 |")
    emit("|---|---|")
    for disp in Disposition:
        if by_disposition[disp]:
            emit(f"| {disp.value} | {by_disposition[disp]} |")
    emit("")

    emit(f"**总计:{len(items)} 个条目**")
    emit("")

    return "\n".join(out)
|
||||
255
scripts/audit/run_audit.py
Normal file
255
scripts/audit/run_audit.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
审计主入口 — 依次调用扫描器和三个分析器,生成三份报告到 docs/audit/。
|
||||
|
||||
仅在 docs/audit/ 目录下创建文件,不修改仓库中的任何现有文件。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from scripts.audit.scanner import scan_repo
|
||||
from scripts.audit.inventory_analyzer import (
|
||||
build_inventory,
|
||||
render_inventory_report,
|
||||
)
|
||||
from scripts.audit.flow_analyzer import (
|
||||
build_flow_tree,
|
||||
discover_entry_points,
|
||||
find_orphan_modules,
|
||||
render_flow_report,
|
||||
)
|
||||
from scripts.audit.doc_alignment_analyzer import (
|
||||
build_mappings,
|
||||
check_api_samples_vs_parsers,
|
||||
check_ddl_vs_dictionary,
|
||||
find_undocumented_modules,
|
||||
render_alignment_report,
|
||||
scan_docs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 仓库根目录自动检测
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_repo_root() -> Path:
|
||||
"""从当前文件向上查找仓库根目录。
|
||||
|
||||
判断依据:包含 cli/ 目录或 .git/ 目录的祖先目录。
|
||||
"""
|
||||
current = Path(__file__).resolve().parent
|
||||
for parent in (current, *current.parents):
|
||||
if (parent / "cli").is_dir() or (parent / ".git").is_dir():
|
||||
return parent
|
||||
# 回退:假设 scripts/audit/ 在仓库根目录下
|
||||
return current.parent.parent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 报告输出目录
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_report_dir(repo_root: Path) -> Path:
|
||||
"""检查并创建 docs/audit/ 目录。
|
||||
|
||||
如果目录已存在则直接返回;不存在则创建。
|
||||
创建失败时抛出 RuntimeError(因为无法输出报告)。
|
||||
"""
|
||||
audit_dir = repo_root / "docs" / "audit"
|
||||
if audit_dir.is_dir():
|
||||
return audit_dir
|
||||
try:
|
||||
audit_dir.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"无法创建报告输出目录 {audit_dir}: {exc}") from exc
|
||||
logger.info("已创建报告输出目录: %s", audit_dir)
|
||||
return audit_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 报告头部元信息注入
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HEADER_PATTERN = re.compile(r"生成时间[::]")
|
||||
_ISO_TS_PATTERN = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")
|
||||
# 匹配非 ISO 格式的时间戳行,用于替换
|
||||
_NON_ISO_TS_LINE = re.compile(
|
||||
r"([-*]\s*生成时间[::]\s*)\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}"
|
||||
)
|
||||
|
||||
|
||||
def _inject_header(report: str, timestamp: str, repo_path: str) -> str:
|
||||
"""确保报告头部包含 ISO 格式时间戳和仓库路径。
|
||||
|
||||
- 已有 ISO 时间戳 → 不修改
|
||||
- 有非 ISO 时间戳 → 替换为 ISO 格式
|
||||
- 无头部 → 在标题后注入
|
||||
"""
|
||||
if _HEADER_PATTERN.search(report):
|
||||
# 已有头部——检查时间戳格式是否为 ISO
|
||||
if _ISO_TS_PATTERN.search(report):
|
||||
return report
|
||||
# 非 ISO 格式 → 替换时间戳
|
||||
report = _NON_ISO_TS_LINE.sub(
|
||||
lambda m: m.group(1) + timestamp, report,
|
||||
)
|
||||
# 同时确保仓库路径使用统一值(用 lambda 避免反斜杠转义问题)
|
||||
safe_path = repo_path
|
||||
report = re.sub(
|
||||
r"([-*]\s*仓库路径[::]\s*)`[^`]*`",
|
||||
lambda m: m.group(1) + "`" + safe_path + "`",
|
||||
report,
|
||||
)
|
||||
return report
|
||||
|
||||
# 无头部 → 在第一个标题行之后插入
|
||||
lines = report.split("\n")
|
||||
insert_idx = 1
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("# "):
|
||||
insert_idx = i + 1
|
||||
break
|
||||
|
||||
header_lines = [
|
||||
"",
|
||||
f"- 生成时间: {timestamp}",
|
||||
f"- 仓库路径: `{repo_path}`",
|
||||
"",
|
||||
]
|
||||
lines[insert_idx:insert_idx] = header_lines
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 主函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_audit(repo_root: Path | None = None) -> None:
    """Run the full audit pipeline and write three reports to docs/audit/.

    Each report (file inventory, flow tree, doc alignment) is generated
    inside its own try/except, so a failure in one analyzer still lets
    the others produce output. Only files under docs/audit/ are created;
    nothing else in the repository is modified.

    Parameters
    ----------
    repo_root : Path | None
        Repository root directory. Auto-detected when None.
    """
    # 1. Resolve the repository root.
    if repo_root is None:
        repo_root = _detect_repo_root()
    repo_root = repo_root.resolve()
    repo_path_str = str(repo_root)

    logger.info("审计开始 — 仓库路径: %s", repo_path_str)

    # 2. Ensure the output directory exists (raises if it cannot be made).
    audit_dir = _ensure_report_dir(repo_root)

    # 3. One UTC timestamp shared by all three reports.
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    # 4. Scan the repository once; every analyzer reuses the result.
    logger.info("正在扫描仓库文件...")
    entries = scan_repo(repo_root)
    logger.info("扫描完成,共 %d 个条目", len(entries))

    # 5. File inventory report.
    logger.info("正在生成文件清单报告...")
    try:
        inventory_items = build_inventory(entries)
        inventory_report = render_inventory_report(inventory_items, repo_path_str)
        inventory_report = _inject_header(inventory_report, timestamp, repo_path_str)
        (audit_dir / "file_inventory.md").write_text(
            inventory_report, encoding="utf-8",
        )
        logger.info("文件清单报告已写入: file_inventory.md")
    except Exception:
        logger.exception("生成文件清单报告时出错")

    # 6. Flow-tree report.
    logger.info("正在生成流程树报告...")
    try:
        entry_points = discover_entry_points(repo_root)
        trees = []
        reachable: set[str] = set()
        for ep in entry_points:
            ep_file = ep["file"]
            # Batch files have no Python import graph to walk.
            if not ep_file.endswith(".py"):
                continue
            tree = build_flow_tree(repo_root, ep_file)
            trees.append(tree)
            # Record every module reachable from this entry point.
            _collect_reachable(tree, reachable)

        orphans = find_orphan_modules(repo_root, entries, reachable)
        flow_report = render_flow_report(trees, orphans, repo_path_str)
        flow_report = _inject_header(flow_report, timestamp, repo_path_str)
        (audit_dir / "flow_tree.md").write_text(
            flow_report, encoding="utf-8",
        )
        logger.info("流程树报告已写入: flow_tree.md")
    except Exception:
        logger.exception("生成流程树报告时出错")

    # 7. Documentation-alignment report.
    logger.info("正在生成文档对齐报告...")
    try:
        doc_paths = scan_docs(repo_root)
        mappings = build_mappings(doc_paths, repo_root)

        issues = []
        issues.extend(check_ddl_vs_dictionary(repo_root))
        issues.extend(check_api_samples_vs_parsers(repo_root))

        # Flag core modules not referenced by any document.
        documented: set[str] = set()
        for m in mappings:
            documented.update(m.related_code)
        undoc_modules = find_undocumented_modules(repo_root, documented)
        from scripts.audit import AlignmentIssue
        for mod in undoc_modules:
            issues.append(AlignmentIssue(
                doc_path="—",
                issue_type="missing",
                description=f"核心代码模块 `{mod}` 缺少对应文档",
                related_code=mod,
            ))

        alignment_report = render_alignment_report(mappings, issues, repo_path_str)
        alignment_report = _inject_header(alignment_report, timestamp, repo_path_str)
        (audit_dir / "doc_alignment.md").write_text(
            alignment_report, encoding="utf-8",
        )
        logger.info("文档对齐报告已写入: doc_alignment.md")
    except Exception:
        logger.exception("生成文档对齐报告时出错")

    logger.info("审计完成 — 报告输出目录: %s", audit_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:收集可达模块
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _collect_reachable(node, reachable: set[str]) -> None:
|
||||
"""递归收集流程树中所有节点的 source_file。"""
|
||||
reachable.add(node.source_file)
|
||||
for child in node.children:
|
||||
_collect_reachable(child, reachable)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 入口
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Script entry point: set up basic console logging, then run the full audit.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    run_audit()
|
||||
150
scripts/audit/scanner.py
Normal file
150
scripts/audit/scanner.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
仓库扫描器 — 递归遍历仓库文件系统,返回结构化的文件元信息。
|
||||
|
||||
仅执行只读操作:读取文件元信息(大小、类型),不修改任何文件。
|
||||
遇到权限错误时跳过并记录日志,不中断扫描流程。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from scripts.audit import FileEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 排除模式
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Names/patterns skipped during the scan.  Both exact directory/file names
# (".git", "__pycache__") and fnmatch-style wildcards ("*.pyc") are supported;
# matching is done per path component by _is_excluded().
EXCLUDED_PATTERNS: list[str] = [
    ".git",
    "__pycache__",
    ".pytest_cache",
    "*.pyc",
    ".kiro",
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 排除匹配逻辑
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _is_excluded(name: str, patterns: list[str]) -> bool:
|
||||
"""判断文件/目录名是否匹配任一排除模式。
|
||||
|
||||
支持两种模式:
|
||||
- 精确匹配(如 ".git"、"__pycache__")
|
||||
- 通配符匹配(如 "*.pyc"),使用 fnmatch 语义
|
||||
"""
|
||||
for pat in patterns:
|
||||
if fnmatch.fnmatch(name, pat):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 递归遍历
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _walk(
    root: Path,
    base: Path,
    exclude: list[str],
    results: list[FileEntry],
) -> None:
    """Recursively traverse *root*, appending one FileEntry per visible entry.

    Parameters
    ----------
    root : Path
        Directory currently being traversed.
    base : Path
        Repository root; relative paths are computed against it (POSIX form).
    exclude : list[str]
        Exclusion patterns (exact names or fnmatch wildcards).
    results : list[FileEntry]
        Output accumulator, modified in place.

    Read-only by design: permission/OS errors are logged and the offending
    entry is skipped so one unreadable path never aborts the whole scan.
    (The previously assigned ``visible_count`` local was never read and has
    been removed.)
    """
    try:
        children = sorted(root.iterdir(), key=lambda p: p.name)
    except (PermissionError, OSError) as exc:
        logger.warning("无法读取目录 %s: %s", root, exc)
        return

    for child in children:
        if _is_excluded(child.name, exclude):
            continue

        rel = child.relative_to(base).as_posix()

        if child.is_dir():
            # Recurse first; if the subtree contributed no entries, the
            # directory is "empty" once exclusions are applied.
            sub_start = len(results)
            _walk(child, base, exclude, results)
            is_empty = len(results) == sub_start

            results.append(FileEntry(
                rel_path=rel,
                is_dir=True,
                size_bytes=0,
                extension="",
                is_empty_dir=is_empty,
            ))
        else:
            # Regular file: skip it if metadata cannot be read.
            try:
                size = child.stat().st_size
            except (PermissionError, OSError) as exc:
                logger.warning("无法获取文件信息 %s: %s", child, exc)
                continue

            results.append(FileEntry(
                rel_path=rel,
                is_dir=False,
                size_bytes=size,
                extension=child.suffix.lower(),
                is_empty_dir=False,
            ))

    # The repository root itself is intentionally not emitted as an entry.
|
||||
|
||||
|
||||
def scan_repo(
    root: Path,
    exclude: list[str] | None = None,
) -> list[FileEntry]:
    """Scan the repository and return metadata for every visible file/dir.

    Parameters
    ----------
    root : Path
        Repository root directory.
    exclude : list[str] | None
        Exclusion patterns; EXCLUDED_PATTERNS is used when omitted.

    Returns
    -------
    list[FileEntry]
        Entries sorted by rel_path so the output is stable across runs.
    """
    patterns = EXCLUDED_PATTERNS if exclude is None else exclude

    entries: list[FileEntry] = []
    _walk(root, root, patterns, entries)

    # Sort by relative path for deterministic, reproducible reports.
    return sorted(entries, key=lambda entry: entry.rel_path)
|
||||
193
scripts/check/check_data_integrity.py
Normal file
193
scripts/check/check_data_integrity.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Run data integrity checks across API -> ODS -> DWD."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from dateutil import parser as dtparser
|
||||
|
||||
from config.settings import AppConfig
|
||||
from quality.integrity_service import run_history_flow, run_window_flow, write_report
|
||||
from utils.logging_utils import build_log_path, configure_logging
|
||||
from utils.windowing import split_window
|
||||
|
||||
|
||||
def _parse_dt(value: str, tz: ZoneInfo) -> datetime:
    """Parse *value* and normalise it to *tz*.

    Naive datetimes are assumed to already be expressed in *tz*; aware ones
    are converted into it.
    """
    parsed = dtparser.parse(value)
    if parsed.tzinfo is not None:
        return parsed.astimezone(tz)
    return parsed.replace(tzinfo=tz)
|
||||
|
||||
|
||||
def _resolve_flag(enabled, disabled, enable_name: str, disable_name: str, default):
    """Resolve a paired ``--x`` / ``--no-x`` CLI switch against a default.

    Raises SystemExit when both switches are supplied; otherwise returns
    True, False, or *default* when neither switch was given.
    """
    if enabled and disabled:
        raise SystemExit(f"cannot set both {enable_name} and {disable_name}")
    if enabled:
        return True
    if disabled:
        return False
    return default


def main() -> int:
    """CLI entry point: parse arguments, configure logging, run integrity checks.

    Two modes:
    - history: check from --start up to --end (or the last ETL end when omitted).
    - window:  check the explicit [--window-start, --window-end) range,
      optionally split into sub-windows per configuration.

    Returns 0 on completion; raises SystemExit on conflicting CLI switches.
    """
    # Best-effort UTF-8 stdout (matters mainly on Windows consoles).
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass

    ap = argparse.ArgumentParser(description="Data integrity checks (API -> ODS -> DWD)")
    ap.add_argument("--mode", choices=["history", "window"], default="history")
    ap.add_argument(
        "--flow",
        choices=["verify", "update_and_verify"],
        default="verify",
        help="verify only or update+verify (auto backfill then optional recheck)",
    )
    ap.add_argument("--start", default="2025-07-01", help="history start date (default: 2025-07-01)")
    ap.add_argument("--end", default="", help="history end datetime (default: last ETL end)")
    ap.add_argument("--window-start", default="", help="window start datetime (mode=window)")
    ap.add_argument("--window-end", default="", help="window end datetime (mode=window)")
    ap.add_argument("--window-split-unit", default="", help="split unit (month/none), default from config")
    ap.add_argument("--window-compensation-hours", type=int, default=None, help="window compensation hours, default from config")
    # default=None (not False) so "switch absent" is distinguishable from
    # "switch explicitly off" when merging with the config default.
    ap.add_argument(
        "--include-dimensions",
        action="store_true",
        default=None,
        help="include dimension tables in ODS->DWD checks",
    )
    ap.add_argument(
        "--no-include-dimensions",
        action="store_true",
        help="exclude dimension tables in ODS->DWD checks",
    )
    ap.add_argument("--ods-task-codes", default="", help="comma-separated ODS task codes for API checks")
    ap.add_argument("--compare-content", action="store_true", help="compare API vs ODS content hash")
    ap.add_argument("--no-compare-content", action="store_true", help="disable content comparison even if enabled in config")
    ap.add_argument("--include-mismatch", action="store_true", help="backfill mismatch records as well")
    ap.add_argument("--no-include-mismatch", action="store_true", help="disable mismatch backfill")
    ap.add_argument("--recheck", action="store_true", help="re-run checks after backfill")
    ap.add_argument("--no-recheck", action="store_true", help="skip recheck after backfill")
    ap.add_argument("--content-sample-limit", type=int, default=None, help="max mismatch samples per table")
    ap.add_argument("--out", default="", help="output JSON path")
    ap.add_argument("--log-file", default="", help="log file path")
    ap.add_argument("--log-dir", default="", help="log directory")
    ap.add_argument("--log-level", default="INFO", help="log level")
    ap.add_argument("--no-log-console", action="store_true", help="disable console logging")
    args = ap.parse_args()

    # Default log location is <this script's directory>/logs.
    log_dir = Path(args.log_dir) if args.log_dir else (Path(__file__).resolve().parent / "logs")
    log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "data_integrity")
    log_console = not args.no_log_console

    with configure_logging(
        "data_integrity",
        log_file,
        level=args.log_level,
        console=log_console,
        tee_std=True,
    ) as logger:
        cfg = AppConfig.load({})
        tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
        report_path = Path(args.out) if args.out else None

        # Resolve paired switches; conflict-raise order matches the original
        # sequence of checks (recheck, mismatch, dimensions, compare-content).
        recheck_after_backfill = _resolve_flag(
            args.recheck, args.no_recheck,
            "--recheck", "--no-recheck",
            cfg.get("integrity.recheck_after_backfill", True),
        )
        include_mismatch = _resolve_flag(
            args.include_mismatch, args.no_include_mismatch,
            "--include-mismatch", "--no-include-mismatch",
            cfg.get("integrity.backfill_mismatch", True),
        )
        include_dimensions = _resolve_flag(
            args.include_dimensions, args.no_include_dimensions,
            "--include-dimensions", "--no-include-dimensions",
            cfg.get("integrity.include_dimensions", True),
        )
        # None means: defer the decision to configuration inside the flow.
        compare_content = _resolve_flag(
            args.compare_content, args.no_compare_content,
            "--compare-content", "--no-compare-content",
            None,
        )

        if args.mode == "window":
            if not args.window_start or not args.window_end:
                raise SystemExit("window-start and window-end are required for mode=window")
            start_dt = _parse_dt(args.window_start, tz)
            end_dt = _parse_dt(args.window_end, tz)
            split_unit = (args.window_split_unit or cfg.get("run.window_split.unit", "month") or "month").strip()
            comp_hours = args.window_compensation_hours
            if comp_hours is None:
                comp_hours = cfg.get("run.window_split.compensation_hours", 0)

            # Split the requested range into sub-windows; fall back to the
            # whole range when splitting yields nothing.
            windows = split_window(
                start_dt,
                end_dt,
                tz=tz,
                split_unit=split_unit,
                compensation_hours=comp_hours,
            )
            if not windows:
                windows = [(start_dt, end_dt)]

            report, counts = run_window_flow(
                cfg=cfg,
                windows=windows,
                include_dimensions=bool(include_dimensions),
                task_codes=args.ods_task_codes,
                logger=logger,
                compare_content=compare_content,
                content_sample_limit=args.content_sample_limit,
                do_backfill=args.flow == "update_and_verify",
                include_mismatch=bool(include_mismatch),
                recheck_after_backfill=bool(recheck_after_backfill),
                page_size=int(cfg.get("api.page_size") or 200),
                chunk_size=500,
            )
            report_path = write_report(report, prefix="data_integrity_window", tz=tz, report_path=report_path)
            report["report_path"] = report_path
            logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
        else:
            start_dt = _parse_dt(args.start, tz)
            if args.end:
                end_dt = _parse_dt(args.end, tz)
            else:
                # None lets the history flow pick the last ETL end itself.
                end_dt = None
            report, counts = run_history_flow(
                cfg=cfg,
                start_dt=start_dt,
                end_dt=end_dt,
                include_dimensions=bool(include_dimensions),
                task_codes=args.ods_task_codes,
                logger=logger,
                compare_content=compare_content,
                content_sample_limit=args.content_sample_limit,
                do_backfill=args.flow == "update_and_verify",
                include_mismatch=bool(include_mismatch),
                recheck_after_backfill=bool(recheck_after_backfill),
                page_size=int(cfg.get("api.page_size") or 200),
                chunk_size=500,
            )
            report_path = write_report(report, prefix="data_integrity_history", tz=tz, report_path=report_path)
            report["report_path"] = report_path
            logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))

        logger.info(
            "SUMMARY missing=%s mismatch=%s errors=%s",
            counts.get("missing"),
            counts.get("mismatch"),
            counts.get("errors"),
        )

    return 0
|
||||
|
||||
|
||||
# Script entry: propagate main()'s return code to the shell.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
82
scripts/check/check_dwd_service.py
Normal file
82
scripts/check/check_dwd_service.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
"""Ad-hoc report: distribution of assistant service records in the DWD layer."""
import sys
sys.path.insert(0, '.')
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations

config = AppConfig.load()
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)

# Inspect how DWD-layer service records are distributed.
print("=== DWD层服务记录分析 ===")
print()

# 1. Overall totals: record count, distinct members/assistants/pairs.
sql1 = """
SELECT
    COUNT(*) as total_records,
    COUNT(DISTINCT tenant_member_id) as unique_members,
    COUNT(DISTINCT site_assistant_id) as unique_assistants,
    COUNT(DISTINCT (tenant_member_id, site_assistant_id)) as unique_pairs
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
"""
r = dict(db.query(sql1)[0])
print("总体统计:")
print(f" 总服务记录数: {r['total_records']}")
print(f" 唯一会员数: {r['unique_members']}")
print(f" 唯一助教数: {r['unique_assistants']}")
print(f" 唯一客户-助教对: {r['unique_pairs']}")

# 2. Distribution of distinct members served per assistant (top 10).
print()
print("助教服务会员数分布 (Top 10):")
sql2 = """
SELECT site_assistant_id, COUNT(DISTINCT tenant_member_id) as member_count
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
GROUP BY site_assistant_id
ORDER BY member_count DESC
LIMIT 10
"""
for row in db.query(sql2):
    r = dict(row)
    print(f" 助教 {r['site_assistant_id']}: 服务 {r['member_count']} 个会员")

# 3. Service counts per member-assistant pair (top 10).
print()
print("客户-助教对 服务次数分布 (Top 10):")
sql3 = """
SELECT tenant_member_id, site_assistant_id, COUNT(*) as service_count
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
GROUP BY tenant_member_id, site_assistant_id
ORDER BY service_count DESC
LIMIT 10
"""
for row in db.query(sql3):
    r = dict(row)
    print(f" 会员 {r['tenant_member_id']} - 助教 {r['site_assistant_id']}: {r['service_count']} 次服务")

# 4. Same totals restricted to the last 60 days (by last_use_time).
print()
print("=== 近60天数据 ===")
sql4 = """
SELECT
    COUNT(*) as total_records,
    COUNT(DISTINCT tenant_member_id) as unique_members,
    COUNT(DISTINCT site_assistant_id) as unique_assistants,
    COUNT(DISTINCT (tenant_member_id, site_assistant_id)) as unique_pairs
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
  AND last_use_time >= NOW() - INTERVAL '60 days'
"""
r4 = dict(db.query(sql4)[0])
print(f" 总服务记录数: {r4['total_records']}")
print(f" 唯一会员数: {r4['unique_members']}")
print(f" 唯一助教数: {r4['unique_assistants']}")
print(f" 唯一客户-助教对: {r4['unique_pairs']}")

db_conn.close()
|
||||
248
scripts/check/check_ods_content_hash.py
Normal file
248
scripts/check/check_ods_content_hash.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Validate that ODS payload content matches stored content_hash.
|
||||
|
||||
Usage:
|
||||
PYTHONPATH=. python -m scripts.check.check_ods_content_hash
|
||||
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --schema billiards_ods
|
||||
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --tables member_profiles,orders
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Sequence
|
||||
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
# Repository root.  This file lives in scripts/check/, so we must climb
# three levels: parents[0]=scripts/check, parents[1]=scripts, parents[2]=root.
# The previous parents[1] put scripts/ on sys.path, which does not make
# `from config.settings import AppConfig` resolvable, and also sent reports
# to scripts/reports instead of <root>/reports.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from tasks.ods.ods_tasks import BaseOdsTask
|
||||
|
||||
|
||||
def _reconfigure_stdout_utf8() -> None:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
try:
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _fetch_tables(conn, schema: str) -> list[str]:
|
||||
sql = """
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = %s AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_name
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (schema,))
|
||||
return [r[0] for r in cur.fetchall()]
|
||||
|
||||
|
||||
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
|
||||
sql = """
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, table))
|
||||
cols = [r[0] for r in cur.fetchall()]
|
||||
return [c for c in cols if c]
|
||||
|
||||
|
||||
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
|
||||
sql = """
|
||||
SELECT kcu.column_name
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
WHERE tc.constraint_type = 'PRIMARY KEY'
|
||||
AND tc.table_schema = %s
|
||||
AND tc.table_name = %s
|
||||
ORDER BY kcu.ordinal_position
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, table))
|
||||
cols = [r[0] for r in cur.fetchall()]
|
||||
return [c for c in cols if c.lower() != "content_hash"]
|
||||
|
||||
|
||||
def _fetch_row_count(conn, schema: str, table: str) -> int:
|
||||
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql)
|
||||
row = cur.fetchone()
|
||||
return int(row[0] if row else 0)
|
||||
|
||||
|
||||
def _iter_rows(
    conn,
    schema: str,
    table: str,
    select_cols: Sequence[str],
    batch_size: int,
) -> Iterable[dict]:
    """Stream the selected columns of a table as dict rows.

    Uses a named (server-side) cursor so memory stays bounded on large
    tables; ``itersize`` controls how many rows are fetched per round trip.
    """
    column_list = ", ".join(f'"{name}"' for name in select_cols)
    statement = f'SELECT {column_list} FROM "{schema}"."{table}"'
    with conn.cursor(name=f"ods_hash_{table}", cursor_factory=RealDictCursor) as cur:
        cur.itersize = max(1, int(batch_size or 500))
        cur.execute(statement)
        yield from cur
|
||||
|
||||
|
||||
def _build_report_path(out_arg: str | None) -> Path:
|
||||
if out_arg:
|
||||
return Path(out_arg)
|
||||
reports_dir = PROJECT_ROOT / "reports"
|
||||
reports_dir.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return reports_dir / f"ods_content_hash_check_{ts}.json"
|
||||
|
||||
|
||||
def _print_progress(
|
||||
table_label: str,
|
||||
processed: int,
|
||||
total: int,
|
||||
mismatched: int,
|
||||
missing_hash: int,
|
||||
invalid_payload: int,
|
||||
) -> None:
|
||||
if total:
|
||||
msg = (
|
||||
f"[{table_label}] checked {processed}/{total} "
|
||||
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f"[{table_label}] checked {processed} "
|
||||
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
|
||||
)
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
def main() -> int:
    """Scan ODS tables and verify each stored content_hash against a recomputed one.

    Prints per-table progress to stdout, writes a JSON report, and returns 0.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Validate ODS payload vs content_hash consistency")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
    ap.add_argument("--sample-limit", type=int, default=5, help="sample mismatch rows per table")
    ap.add_argument("--out", default="", help="output report JSON path")
    args = ap.parse_args()

    cfg = AppConfig.load({})
    db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    conn = db.conn

    # Optional whitelist narrows the scan to the requested tables only.
    tables = _fetch_tables(conn, args.schema)
    if args.tables.strip():
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]

    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": 0,
            "checked_tables": 0,
            "total_rows": 0,
            "checked_rows": 0,
            "mismatch_rows": 0,
            "missing_hash_rows": 0,
            "invalid_payload_rows": 0,
        },
    }

    for table in tables:
        table_label = f"{args.schema}.{table}"
        cols = _fetch_columns(conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        # Only tables that carry both payload and content_hash are checkable.
        if "payload" not in cols_lower or "content_hash" not in cols_lower:
            print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
            continue

        total = _fetch_row_count(conn, args.schema, table)
        pk_cols = _fetch_pk_columns(conn, args.schema, table)
        select_cols = ["content_hash", "payload", *pk_cols]

        processed = 0
        mismatched = 0
        missing_hash = 0
        invalid_payload = 0
        samples: list[dict[str, Any]] = []

        print(f"[{table_label}] start: total_rows={total}", flush=True)

        for row in _iter_rows(conn, args.schema, table, select_cols, args.batch_size):
            processed += 1
            content_hash = row.get("content_hash")
            payload = row.get("payload")
            recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)

            # Classify each row: missing stored hash, payload that cannot be
            # hashed, or stored hash disagreeing with the recomputed one.
            # All three categories count toward the mismatch total.
            row_mismatch = False
            if not content_hash:
                missing_hash += 1
                mismatched += 1
                row_mismatch = True
            elif not recomputed:
                invalid_payload += 1
                mismatched += 1
                row_mismatch = True
            elif content_hash != recomputed:
                mismatched += 1
                row_mismatch = True

            # Keep a bounded number of offending rows, identified by PK columns.
            if row_mismatch and len(samples) < max(0, int(args.sample_limit or 0)):
                sample = {k: row.get(k) for k in pk_cols}
                sample["content_hash"] = content_hash
                sample["recomputed_hash"] = recomputed
                samples.append(sample)

            if args.progress_every and processed % int(args.progress_every) == 0:
                _print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)

        # Final progress line unless the loop ended exactly on a progress tick.
        if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
            _print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)

        report["tables"].append(
            {
                "table": table_label,
                "total_rows": total,
                "checked_rows": processed,
                "mismatch_rows": mismatched,
                "missing_hash_rows": missing_hash,
                "invalid_payload_rows": invalid_payload,
                "sample_mismatches": samples,
            }
        )

        # Roll per-table counters into the overall summary.
        report["summary"]["checked_tables"] += 1
        report["summary"]["total_rows"] += total
        report["summary"]["checked_rows"] += processed
        report["summary"]["mismatch_rows"] += mismatched
        report["summary"]["missing_hash_rows"] += missing_hash
        report["summary"]["invalid_payload_rows"] += invalid_payload

    report["summary"]["total_tables"] = len(tables)

    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0
|
||||
|
||||
|
||||
# Script entry: exit with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
1004
scripts/check/check_ods_gaps.py
Normal file
1004
scripts/check/check_ods_gaps.py
Normal file
File diff suppressed because it is too large
Load Diff
117
scripts/check/check_ods_json_vs_table.py
Normal file
117
scripts/check/check_ods_json_vs_table.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
ODS JSON 字段核对脚本:对照当前数据库中的 ODS 表字段,检查示例 JSON(默认目录 export/test-json-doc)
|
||||
是否包含同名键,并输出每表未命中的字段,便于补充映射或确认确实无源字段。
|
||||
|
||||
使用方法:
|
||||
set PG_DSN=postgresql://... # 如 .env 中配置
|
||||
python -m scripts.check.check_ods_json_vs_table
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Dict, Iterable, Set, Tuple
|
||||
|
||||
import psycopg2
|
||||
|
||||
from tasks.manual_ingest_task import ManualIngestTask
|
||||
|
||||
|
||||
def _flatten_keys(obj, prefix: str = "") -> Set[str]:
|
||||
"""递归展开 JSON 所有键路径,返回形如 data.assistantInfos.id 的集合。列表不保留索引,仅继续向下展开。"""
|
||||
keys: Set[str] = set()
|
||||
if isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
new_prefix = f"{prefix}.{k}" if prefix else k
|
||||
keys.add(new_prefix)
|
||||
keys |= _flatten_keys(v, new_prefix)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
keys |= _flatten_keys(item, prefix)
|
||||
return keys
|
||||
|
||||
|
||||
def _load_json_keys(path: pathlib.Path) -> Tuple[Set[str], dict[str, Set[str]]]:
    """Load one JSON file and index its flattened key paths.

    Returns ``(all_dotted_paths, last_segment -> set_of_paths)``. A missing
    file yields empty results; JSON parse errors propagate to the caller.
    Last segments are lower-cased for case-insensitive column matching.
    """
    if not path.exists():
        return set(), {}
    parsed = json.loads(path.read_text(encoding="utf-8"))
    all_paths = _flatten_keys(parsed)
    by_last_segment: dict[str, Set[str]] = {}
    for dotted in all_paths:
        segment = dotted.rsplit(".", 1)[-1].lower()
        by_last_segment.setdefault(segment, set()).add(dotted)
    return all_paths, by_last_segment
|
||||
|
||||
|
||||
def _load_ods_columns(dsn: str) -> Dict[str, Set[str]]:
    """Read the column names of every billiards_ods table, keyed by table name.

    Column names are lower-cased so later lookups can ignore case.  The
    connection (and cursor) are now closed even when the query raises, which
    the original code leaked on the error path.
    """
    conn = psycopg2.connect(dsn)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
        SELECT table_name, column_name
        FROM information_schema.columns
        WHERE table_schema='billiards_ods'
        ORDER BY table_name, ordinal_position
        """
            )
            result: Dict[str, Set[str]] = {}
            for table, col in cur.fetchall():
                result.setdefault(table, set()).add(col.lower())
        return result
    finally:
        conn.close()
|
||||
|
||||
|
||||
def main() -> None:
    """Compare ODS table columns against sample-JSON keys and print a report.

    For every (keywords, ods_table) pair in ManualIngestTask.FILE_MAPPING,
    load the matching sample JSON from JSON_DOC_DIR, flatten its keys, and
    report which table columns have no same-named JSON key (and which JSON
    keys have no matching column).
    """
    dsn = os.environ.get("PG_DSN")
    if not dsn:
        # Fail fast with a clear message instead of an obscure connect error.
        raise SystemExit("PG_DSN environment variable is not set")
    json_dir = pathlib.Path(os.environ.get("JSON_DOC_DIR", "export/test-json-doc"))

    ods_cols_map = _load_ods_columns(dsn)

    print(f"使用 JSON 目录: {json_dir}")
    print(f"连接 DSN: {dsn}")
    print("=" * 80)

    for keywords, ods_table in ManualIngestTask.FILE_MAPPING:
        table = ods_table.split(".")[-1]
        cols = ods_cols_map.get(table, set())
        file_name = f"{keywords[0]}.json"
        file_path = json_dir / file_name
        keys_full, path_map = _load_json_keys(file_path)
        key_last_parts = set(path_map.keys())

        # A column is "covered" when some JSON key's last segment matches it.
        missing: Set[str] = set()
        extra_keys: Set[str] = set()
        present: Set[str] = set()
        for col in sorted(cols):
            if col in key_last_parts:
                present.add(col)
            else:
                missing.add(col)
        for k in key_last_parts:
            if k not in cols:
                extra_keys.add(k)

        print(f"[{table}] 文件={file_name} 列数={len(cols)} JSON键(末段)覆盖={len(present)}/{len(cols)}")
        if missing:
            print(" 未命中列:", ", ".join(sorted(missing)))
        else:
            print(" 未命中列: 无")
        if extra_keys:
            # Show each uncovered JSON key with its full dotted path(s).
            extras = []
            for k in sorted(extra_keys):
                paths = ", ".join(sorted(path_map.get(k, [])))
                extras.append(f"{k} ({paths})")
            print(" JSON 仅有(表无此列):", "; ".join(extras))
        else:
            print(" JSON 仅有(表无此列): 无")
        print("-" * 80)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
34
scripts/check/verify_dws_config.py
Normal file
34
scripts/check/verify_dws_config.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""验证DWS配置数据"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import psycopg2
|
||||
|
||||
def main():
    """Print row counts of the DWS config tables (sanity check after import)."""
    # .env lives at the project root; this file sits in scripts/check/, so we
    # must climb three levels — parent.parent only reached scripts/ and the
    # .env was never found there.
    load_dotenv(Path(__file__).resolve().parents[2] / ".env")
    dsn = os.getenv("PG_DSN")
    if not dsn:
        # Fail fast with a clear message instead of a psycopg2 connect error.
        raise SystemExit("PG_DSN is not configured (.env or environment)")
    conn = psycopg2.connect(dsn)

    tables = [
        "cfg_performance_tier",
        "cfg_assistant_level_price",
        "cfg_bonus_rules",
        "cfg_area_category",
        "cfg_skill_type"
    ]

    print("DWS 配置表数据统计:")
    print("-" * 40)

    try:
        with conn.cursor() as cur:
            for t in tables:
                cur.execute(f"SELECT COUNT(*) FROM billiards_dws.{t}")
                cnt = cur.fetchone()[0]
                print(f"{t}: {cnt} 行")
    finally:
        # Close the connection even if a query fails mid-loop.
        conn.close()
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
605
scripts/db_admin/import_dws_excel.py
Normal file
605
scripts/db_admin/import_dws_excel.py
Normal file
@@ -0,0 +1,605 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
DWS Excel导入脚本
|
||||
|
||||
功能说明:
|
||||
支持三类Excel数据的导入:
|
||||
1. 支出结构(dws_finance_expense_summary)
|
||||
2. 平台结算(dws_platform_settlement)
|
||||
3. 充值提成(dws_assistant_recharge_commission)
|
||||
|
||||
导入规范:
|
||||
- 字段定义:按照目标表字段要求
|
||||
- 时间粒度:支出按月,平台结算按日,充值提成按月
|
||||
- 门店维度:使用配置的site_id
|
||||
- 去重规则:按import_batch_no去重
|
||||
- 校验规则:金额字段非负,日期格式校验
|
||||
|
||||
使用方式:
|
||||
python import_dws_excel.py --type expense --file expenses.xlsx
|
||||
python import_dws_excel.py --type platform --file platform_settlement.xlsx
|
||||
python import_dws_excel.py --type commission --file recharge_commission.xlsx
|
||||
|
||||
作者:ETL团队
|
||||
创建日期:2026-02-01
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
print("请安装 pandas: pip install pandas openpyxl")
|
||||
sys.exit(1)
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 常量定义
|
||||
# =============================================================================
|
||||
|
||||
# 支出类型枚举
|
||||
EXPENSE_TYPES = {
|
||||
'房租': 'RENT',
|
||||
'水电费': 'UTILITY',
|
||||
'物业费': 'PROPERTY',
|
||||
'工资': 'SALARY',
|
||||
'报销': 'REIMBURSE',
|
||||
'平台服务费': 'PLATFORM_FEE',
|
||||
'其他': 'OTHER',
|
||||
}
|
||||
|
||||
# 支出大类映射
|
||||
EXPENSE_CATEGORIES = {
|
||||
'RENT': 'FIXED_COST',
|
||||
'UTILITY': 'VARIABLE_COST',
|
||||
'PROPERTY': 'FIXED_COST',
|
||||
'SALARY': 'FIXED_COST',
|
||||
'REIMBURSE': 'VARIABLE_COST',
|
||||
'PLATFORM_FEE': 'VARIABLE_COST',
|
||||
'OTHER': 'OTHER',
|
||||
}
|
||||
|
||||
# 平台类型枚举
|
||||
PLATFORM_TYPES = {
|
||||
'美团': 'MEITUAN',
|
||||
'抖音': 'DOUYIN',
|
||||
'大众点评': 'DIANPING',
|
||||
'其他': 'OTHER',
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 导入基类
|
||||
# =============================================================================
|
||||
|
||||
class BaseImporter:
    """Shared base for the three DWS Excel importers.

    Provides batch-number generation and defensive value-coercion helpers;
    subclasses implement the read → validate → transform → insert pipeline.
    """

    def __init__(self, config: "AppConfig", db: "DatabaseConnection"):
        # FIX: the parameter annotation previously referenced the undefined
        # name `Config` (the module imports `AppConfig`), which raised
        # NameError when the class body executed — this module does not use
        # `from __future__ import annotations`, so annotations are evaluated
        # eagerly. String annotations defer evaluation entirely.
        self.config = config
        self.db = db
        # Store / tenant scope stamped onto every imported row.
        self.site_id = config.get("app.store_id")
        self.tenant_id = config.get("app.tenant_id", self.site_id)
        # One batch number per importer instance; used for dedup/traceability.
        self.batch_no = self._generate_batch_no()

    def _generate_batch_no(self) -> str:
        """Generate an import batch number: ``<YYYYmmddHHMMSS>_<8-char uuid>``."""
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        unique_id = str(uuid.uuid4())[:8]
        return f"{timestamp}_{unique_id}"

    def _safe_decimal(self, value: Any, default: Decimal = Decimal('0')) -> Decimal:
        """Coerce *value* to Decimal; None/NaN/unparseable become *default*."""
        if value is None or pd.isna(value):
            return default
        try:
            return Decimal(str(value))
        except (ValueError, InvalidOperation):
            return default

    def _safe_date(self, value: Any) -> Optional[date]:
        """Coerce *value* to a date; returns None when empty or unparseable."""
        if value is None or pd.isna(value):
            return None
        if isinstance(value, datetime):
            return value.date()
        if isinstance(value, date):
            return value
        try:
            return pd.to_datetime(value).date()
        except Exception:
            # FIX: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; narrow to Exception.
            return None

    def _safe_month(self, value: Any) -> Optional[date]:
        """Coerce *value* to a month key (first day of its month), or None."""
        dt = self._safe_date(value)
        if dt:
            return dt.replace(day=1)
        return None

    def import_file(self, file_path: str) -> Dict[str, Any]:
        """Import one Excel file; subclasses return a summary dict."""
        raise NotImplementedError

    def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
        """Validate one row; return a list of error messages (empty = valid)."""
        return []

    def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Map one validated Excel row onto target-table columns."""
        raise NotImplementedError

    def insert_records(self, records: List[Dict[str, Any]]) -> int:
        """Insert transformed records; return the number of rows inserted."""
        raise NotImplementedError
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 支出导入
|
||||
# =============================================================================
|
||||
|
||||
class ExpenseImporter(BaseImporter):
    """
    Expense import (monthly granularity).

    Expected Excel columns:
    - 月份: month, "2026-01" or "2026/01/01" style
    - 支出类型: 房租/水电费/物业费/工资/报销/平台服务费/其他
    - 金额: numeric amount
    - 明细 / 备注: optional
    """

    TARGET_TABLE = "billiards_dws.dws_finance_expense_summary"

    REQUIRED_COLUMNS = ['月份', '支出类型', '金额']
    OPTIONAL_COLUMNS = ['明细', '备注']

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Return a stripped string; empty cells (None/NaN) become ''.

        FIX: the previous code called ``.strip()`` directly on the cell value;
        pandas yields ``float('nan')`` for blank cells, so a blank 支出类型
        cell crashed with AttributeError instead of producing a row-level
        validation error.
        """
        if value is None or pd.isna(value):
            return ''
        return str(value).strip()

    def import_file(self, file_path: str) -> Dict[str, Any]:
        """Read one expense Excel file, validate rows, insert the valid ones.

        Returns a summary dict with status/batch_no/total_rows/inserted/errors.
        """
        print(f"开始导入支出文件: {file_path}")

        df = pd.read_excel(file_path)

        # Fail fast when required columns are missing.
        missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
        if missing_cols:
            return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}

        records = []
        errors = []

        for idx, row in df.iterrows():
            row_dict = row.to_dict()
            # Excel data rows start at 2 (row 1 is the header).
            row_errors = self.validate_row(row_dict, idx + 2)

            if row_errors:
                errors.extend(row_errors)
                continue

            records.append(self.transform_row(row_dict))

        if errors:
            print(f"校验错误: {len(errors)} 条")
            for err in errors[:10]:
                print(f" - {err}")

        inserted = 0
        if records:
            inserted = self.insert_records(records)

        return {
            "status": "SUCCESS" if not errors else "PARTIAL",
            "batch_no": self.batch_no,
            "total_rows": len(df),
            "inserted": inserted,
            "errors": len(errors),
            "error_messages": errors[:10]
        }

    def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
        """Validate one row; returns a (possibly empty) list of error strings."""
        errors = []

        # Month must parse to a first-of-month date.
        month = self._safe_month(row.get('月份'))
        if not month:
            errors.append(f"行{row_idx}: 月份格式错误")

        # Expense type must be one of the known labels (blank/NaN → '').
        expense_type = self._clean_text(row.get('支出类型'))
        if expense_type not in EXPENSE_TYPES:
            errors.append(f"行{row_idx}: 支出类型无效 '{expense_type}'")

        # Amount must be non-negative (unparseable values coerce to 0).
        amount = self._safe_decimal(row.get('金额'))
        if amount < 0:
            errors.append(f"行{row_idx}: 金额不能为负数")

        return errors

    def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Map a validated Excel row onto the target-table column dict."""
        expense_type_name = self._clean_text(row.get('支出类型'))
        expense_type_code = EXPENSE_TYPES.get(expense_type_name, 'OTHER')
        expense_category = EXPENSE_CATEGORIES.get(expense_type_code, 'OTHER')

        return {
            'site_id': self.site_id,
            'tenant_id': self.tenant_id,
            'expense_month': self._safe_month(row.get('月份')),
            'expense_type_code': expense_type_code,
            'expense_type_name': expense_type_name,
            'expense_category': expense_category,
            'expense_amount': self._safe_decimal(row.get('金额')),
            'expense_detail': row.get('明细'),
            'import_batch_no': self.batch_no,
            # NOTE(review): '_file_path' is never placed into the row dict by
            # import_file, so this is always '' — confirm intent or thread the
            # file name through explicitly.
            'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
            'import_time': datetime.now(),
            'import_user': os.getenv('USERNAME', 'system'),
            'remark': row.get('备注'),
        }

    def insert_records(self, records: List[Dict[str, Any]]) -> int:
        """Insert transformed records row-by-row; returns inserted count."""
        columns = [
            'site_id', 'tenant_id', 'expense_month', 'expense_type_code',
            'expense_type_name', 'expense_category', 'expense_amount',
            'expense_detail', 'import_batch_no', 'import_file_name',
            'import_time', 'import_user', 'remark'
        ]

        cols_str = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"

        inserted = 0
        with self.db.conn.cursor() as cur:
            for record in records:
                values = [record.get(col) for col in columns]
                cur.execute(sql, values)
                inserted += cur.rowcount

        self.db.commit()
        return inserted
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 平台结算导入
|
||||
# =============================================================================
|
||||
|
||||
class PlatformSettlementImporter(BaseImporter):
    """
    Platform-settlement import (daily granularity).

    Expected Excel columns:
    - 回款日期: date
    - 平台类型: 美团/抖音/大众点评/其他
    - 平台订单号: string
    - 订单原始金额 / 佣金 / 服务费 / 回款金额: numeric
    - 备注: optional
    """

    TARGET_TABLE = "billiards_dws.dws_platform_settlement"

    REQUIRED_COLUMNS = ['回款日期', '平台类型', '回款金额']
    OPTIONAL_COLUMNS = ['平台订单号', '订单原始金额', '佣金', '服务费', '关联订单ID', '备注']

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Return a stripped string; empty cells (None/NaN) become ''.

        FIX: ``.strip()`` was previously called directly on the cell value;
        a blank 平台类型 cell (pandas NaN float) crashed with AttributeError
        instead of being reported as a validation error.
        """
        if value is None or pd.isna(value):
            return ''
        return str(value).strip()

    def import_file(self, file_path: str) -> Dict[str, Any]:
        """Read one settlement Excel file, validate rows, insert valid ones."""
        print(f"开始导入平台结算文件: {file_path}")

        df = pd.read_excel(file_path)

        # Fail fast when required columns are missing.
        missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
        if missing_cols:
            return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}

        records = []
        errors = []

        for idx, row in df.iterrows():
            row_dict = row.to_dict()
            # Excel data rows start at 2 (row 1 is the header).
            row_errors = self.validate_row(row_dict, idx + 2)

            if row_errors:
                errors.extend(row_errors)
                continue

            records.append(self.transform_row(row_dict))

        if errors:
            print(f"校验错误: {len(errors)} 条")
            for err in errors[:10]:
                print(f" - {err}")

        inserted = 0
        if records:
            inserted = self.insert_records(records)

        return {
            "status": "SUCCESS" if not errors else "PARTIAL",
            "batch_no": self.batch_no,
            "total_rows": len(df),
            "inserted": inserted,
            "errors": len(errors),
        }

    def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
        """Validate one row; returns a (possibly empty) list of error strings."""
        errors = []

        # Settlement date must parse.
        settlement_date = self._safe_date(row.get('回款日期'))
        if not settlement_date:
            errors.append(f"行{row_idx}: 回款日期格式错误")

        # Platform must be one of the known labels (blank/NaN → '').
        platform_type = self._clean_text(row.get('平台类型'))
        if platform_type not in PLATFORM_TYPES:
            errors.append(f"行{row_idx}: 平台类型无效 '{platform_type}'")

        # Settlement amount must be non-negative.
        amount = self._safe_decimal(row.get('回款金额'))
        if amount < 0:
            errors.append(f"行{row_idx}: 回款金额不能为负数")

        return errors

    def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Map a validated Excel row onto the target-table column dict."""
        platform_name = self._clean_text(row.get('平台类型'))
        platform_type = PLATFORM_TYPES.get(platform_name, 'OTHER')

        return {
            'site_id': self.site_id,
            'tenant_id': self.tenant_id,
            'settlement_date': self._safe_date(row.get('回款日期')),
            'platform_type': platform_type,
            'platform_name': platform_name,
            'platform_order_no': row.get('平台订单号'),
            'order_settle_id': row.get('关联订单ID'),
            'settlement_amount': self._safe_decimal(row.get('回款金额')),
            'commission_amount': self._safe_decimal(row.get('佣金')),
            'service_fee': self._safe_decimal(row.get('服务费')),
            'gross_amount': self._safe_decimal(row.get('订单原始金额')),
            'import_batch_no': self.batch_no,
            # NOTE(review): '_file_path' is never set on the row dict, so this
            # is always '' — confirm intent.
            'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
            'import_time': datetime.now(),
            'import_user': os.getenv('USERNAME', 'system'),
            'remark': row.get('备注'),
        }

    def insert_records(self, records: List[Dict[str, Any]]) -> int:
        """Insert transformed records row-by-row; returns inserted count."""
        columns = [
            'site_id', 'tenant_id', 'settlement_date', 'platform_type',
            'platform_name', 'platform_order_no', 'order_settle_id',
            'settlement_amount', 'commission_amount', 'service_fee',
            'gross_amount', 'import_batch_no', 'import_file_name',
            'import_time', 'import_user', 'remark'
        ]

        cols_str = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"

        inserted = 0
        with self.db.conn.cursor() as cur:
            for record in records:
                values = [record.get(col) for col in columns]
                cur.execute(sql, values)
                inserted += cur.rowcount

        self.db.commit()
        return inserted
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 充值提成导入
|
||||
# =============================================================================
|
||||
|
||||
class RechargeCommissionImporter(BaseImporter):
    """
    Recharge-commission import (monthly granularity).

    Expected Excel columns:
    - 月份: "2026-01" style month
    - 助教ID: numeric assistant id
    - 助教花名: assistant nickname
    - 充值订单金额 / 提成金额: numeric
    - 充值订单号 / 备注: optional
    """

    TARGET_TABLE = "billiards_dws.dws_assistant_recharge_commission"

    REQUIRED_COLUMNS = ['月份', '助教ID', '提成金额']
    OPTIONAL_COLUMNS = ['助教花名', '充值订单金额', '充值订单ID', '充值订单号', '备注']

    @staticmethod
    def _safe_int(value: Any) -> Optional[int]:
        """Coerce an Excel cell to int; None for empty/non-numeric cells.

        Excel numeric cells arrive from pandas as floats (e.g. 3.0), so the
        value is routed through Decimal before truncating to int.
        """
        if value is None or pd.isna(value):
            return None
        try:
            return int(Decimal(str(value)))
        except (ValueError, InvalidOperation):
            return None

    def import_file(self, file_path: str) -> Dict[str, Any]:
        """Read one commission Excel file, validate rows, insert valid ones."""
        print(f"开始导入充值提成文件: {file_path}")

        df = pd.read_excel(file_path)

        # Fail fast when required columns are missing.
        missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
        if missing_cols:
            return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}

        records = []
        errors = []

        for idx, row in df.iterrows():
            row_dict = row.to_dict()
            # Excel data rows start at 2 (row 1 is the header).
            row_errors = self.validate_row(row_dict, idx + 2)

            if row_errors:
                errors.extend(row_errors)
                continue

            records.append(self.transform_row(row_dict))

        if errors:
            print(f"校验错误: {len(errors)} 条")
            for err in errors[:10]:
                print(f" - {err}")

        inserted = 0
        if records:
            inserted = self.insert_records(records)

        return {
            "status": "SUCCESS" if not errors else "PARTIAL",
            "batch_no": self.batch_no,
            "total_rows": len(df),
            "inserted": inserted,
            "errors": len(errors),
        }

    def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
        """Validate one row; returns a (possibly empty) list of error strings."""
        errors = []

        # Month must parse to a first-of-month date.
        month = self._safe_month(row.get('月份'))
        if not month:
            errors.append(f"行{row_idx}: 月份格式错误")

        # Assistant id must be present AND numeric.
        raw_id = row.get('助教ID')
        if raw_id is None or pd.isna(raw_id):
            errors.append(f"行{row_idx}: 助教ID不能为空")
        elif self._safe_int(raw_id) is None:
            # FIX: transform_row previously called int() unguarded, so a
            # non-numeric cell crashed the import with ValueError instead of
            # being reported as a row error.
            errors.append(f"行{row_idx}: 助教ID无效 '{raw_id}'")

        # Commission amount must be non-negative.
        amount = self._safe_decimal(row.get('提成金额'))
        if amount < 0:
            errors.append(f"行{row_idx}: 提成金额不能为负数")

        return errors

    def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Map a validated Excel row onto the target-table column dict."""
        recharge_amount = self._safe_decimal(row.get('充值订单金额'))
        commission_amount = self._safe_decimal(row.get('提成金额'))
        # Ratio is derivable only when the recharge amount is known & positive.
        commission_ratio = commission_amount / recharge_amount if recharge_amount > 0 else None

        return {
            'site_id': self.site_id,
            'tenant_id': self.tenant_id,
            'assistant_id': self._safe_int(row.get('助教ID')),
            'assistant_nickname': row.get('助教花名'),
            'commission_month': self._safe_month(row.get('月份')),
            'recharge_order_id': row.get('充值订单ID'),
            'recharge_order_no': row.get('充值订单号'),
            'recharge_amount': recharge_amount,
            'commission_amount': commission_amount,
            'commission_ratio': commission_ratio,
            'import_batch_no': self.batch_no,
            # NOTE(review): '_file_path' is never set on the row dict, so this
            # is always '' — confirm intent.
            'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
            'import_time': datetime.now(),
            'import_user': os.getenv('USERNAME', 'system'),
            'remark': row.get('备注'),
        }

    def insert_records(self, records: List[Dict[str, Any]]) -> int:
        """Insert transformed records row-by-row; returns inserted count."""
        columns = [
            'site_id', 'tenant_id', 'assistant_id', 'assistant_nickname',
            'commission_month', 'recharge_order_id', 'recharge_order_no',
            'recharge_amount', 'commission_amount', 'commission_ratio',
            'import_batch_no', 'import_file_name', 'import_time',
            'import_user', 'remark'
        ]

        cols_str = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"

        inserted = 0
        with self.db.conn.cursor() as cur:
            for record in records:
                values = [record.get(col) for col in columns]
                cur.execute(sql, values)
                inserted += cur.rowcount

        self.db.commit()
        return inserted
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 主函数
|
||||
# =============================================================================
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, pick an importer, run it, report."""
    parser = argparse.ArgumentParser(description='DWS Excel导入工具')
    parser.add_argument(
        '--type', '-t',
        choices=['expense', 'platform', 'commission'],
        required=True,
        help='导入类型: expense(支出), platform(平台结算), commission(充值提成)'
    )
    parser.add_argument(
        '--file', '-f',
        required=True,
        help='Excel文件路径'
    )

    args = parser.parse_args()

    # Bail out early when the input file cannot be found.
    if not os.path.exists(args.file):
        print(f"文件不存在: {args.file}")
        sys.exit(1)

    # Load configuration and open the database connection.
    config = AppConfig.load()
    dsn = config["db"]["dsn"]
    db_conn = DatabaseConnection(dsn=dsn)
    db = DatabaseOperations(db_conn)

    # Map import type → importer class (replaces the if/elif chain).
    importer_classes = {
        'expense': ExpenseImporter,
        'platform': PlatformSettlementImporter,
        'commission': RechargeCommissionImporter,
    }

    try:
        importer_cls = importer_classes.get(args.type)
        if importer_cls is None:
            print(f"未知的导入类型: {args.type}")
            sys.exit(1)
        importer = importer_cls(config, db)

        result = importer.import_file(args.file)

        # Summarise the outcome for the operator.
        print("\n" + "=" * 50)
        print("导入结果:")
        print(f" 状态: {result.get('status')}")
        print(f" 批次号: {result.get('batch_no')}")
        print(f" 总行数: {result.get('total_rows')}")
        print(f" 插入行数: {result.get('inserted')}")
        print(f" 错误行数: {result.get('errors')}")

        if result.get('status') == 'ERROR':
            print(f" 错误信息: {result.get('message')}")
            sys.exit(1)

    except Exception as e:
        print(f"导入失败: {e}")
        db_conn.rollback()
        sys.exit(1)
    finally:
        db_conn.close()
|
||||
|
||||
|
||||
# Script entry point when executed directly (not imported).
if __name__ == "__main__":
    main()
|
||||
404
scripts/rebuild/rebuild_db_and_run_ods_to_dwd.py
Normal file
404
scripts/rebuild/rebuild_db_and_run_ods_to_dwd.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
一键重建 ETL 相关 Schema,并执行 ODS → DWD。
|
||||
|
||||
本脚本面向“离线示例 JSON 回放”的开发/运维场景,使用当前项目内的任务实现:
|
||||
1) (可选)DROP 并重建 schema:`etl_admin` / `billiards_ods` / `billiards_dwd`
|
||||
2) 执行 `INIT_ODS_SCHEMA`:创建 `etl_admin` 元数据表 + 执行 `schema_ODS_doc.sql`(内部会做轻量清洗)
|
||||
3) 执行 `INIT_DWD_SCHEMA`:执行 `schema_dwd_doc.sql`
|
||||
4) 执行 `MANUAL_INGEST`:从本地 JSON 目录灌入 ODS
|
||||
5) 执行 `DWD_LOAD_FROM_ODS`:从 ODS 装载到 DWD
|
||||
|
||||
用法(推荐):
|
||||
python -m scripts.rebuild.rebuild_db_and_run_ods_to_dwd ^
|
||||
--dsn "postgresql://user:pwd@host:5432/db" ^
|
||||
--store-id 1 ^
|
||||
--json-dir "export/test-json-doc" ^
|
||||
--drop-schemas
|
||||
|
||||
环境变量(可选):
|
||||
PG_DSN、STORE_ID、INGEST_SOURCE_DIR
|
||||
|
||||
日志:
|
||||
默认同时输出到控制台与文件;文件路径为 `io.log_root/rebuild_db_<时间戳>.log`。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
from tasks.utility.init_dwd_schema_task import InitDwdSchemaTask
|
||||
from tasks.utility.init_schema_task import InitOdsSchemaTask
|
||||
from tasks.utility.manual_ingest_task import ManualIngestTask
|
||||
|
||||
|
||||
DEFAULT_JSON_DIR = "export/test-json-doc"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunArgs:
    """Immutable script-parameter object (avoids scattered argument passing)."""

    dsn: str                          # PostgreSQL DSN (from --dsn or PG_DSN)
    store_id: int                     # store/tenant id (from --store-id or STORE_ID)
    json_dir: str                     # directory of sample JSON files to replay
    drop_schemas: bool                # drop & recreate the ETL schemas first
    terminate_own_sessions: bool      # kill own idle-in-transaction sessions before DROP
    demo: bool                        # minimal demo: member_profiles → dim_member(_ex) only
    only_files: list[str]             # restrict ingest to these JSON files (lowercased)
    only_dwd_tables: list[str]        # restrict DWD load to these tables (lowercased)
    stop_after: str | None            # stop after this stage name (uppercased), or run all
|
||||
|
||||
|
||||
def _attach_file_logger(log_root: str | Path, filename: str, logger: logging.Logger) -> logging.Handler | None:
|
||||
"""
|
||||
给 root logger 附加文件日志处理器(UTF-8)。
|
||||
|
||||
说明:
|
||||
- 使用 root logger 是为了覆盖项目中不同命名的 logger(包含第三方/子模块)。
|
||||
- 若创建失败仅记录 warning,不中断主流程。
|
||||
|
||||
返回值:
|
||||
创建成功返回 handler(调用方负责 removeHandler/close),失败返回 None。
|
||||
"""
|
||||
log_dir = Path(log_root)
|
||||
try:
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("创建日志目录失败:%s(%s)", log_dir, exc)
|
||||
return None
|
||||
|
||||
log_path = log_dir / filename
|
||||
try:
|
||||
handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("创建文件日志失败:%s(%s)", log_path, exc)
|
||||
return None
|
||||
|
||||
handler.setLevel(logging.INFO)
|
||||
handler.setFormatter(
|
||||
logging.Formatter(
|
||||
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
logging.getLogger().addHandler(handler)
|
||||
logger.info("文件日志已启用:%s", log_path)
|
||||
return handler
|
||||
|
||||
|
||||
def _parse_args() -> RunArgs:
    """Collect options from the command line and environment variables."""

    def _csv_list(raw: object) -> list[str]:
        # Comma-separated values, case-folded, blanks dropped.
        return [piece.strip().lower() for piece in str(raw or "").split(",") if piece.strip()]

    ap = argparse.ArgumentParser(description="重建 Schema 并执行 ODS→DWD(离线 JSON 回放)")
    ap.add_argument("--dsn", default=os.environ.get("PG_DSN"), help="PostgreSQL DSN(默认读取 PG_DSN)")
    ap.add_argument(
        "--store-id",
        type=int,
        default=int(os.environ.get("STORE_ID") or 1),
        help="门店/租户 store_id(默认读取 STORE_ID,否则为 1)",
    )
    ap.add_argument(
        "--json-dir",
        default=os.environ.get("INGEST_SOURCE_DIR") or DEFAULT_JSON_DIR,
        help=f"示例 JSON 目录(默认 {DEFAULT_JSON_DIR},也可读 INGEST_SOURCE_DIR)",
    )
    ap.add_argument(
        "--drop-schemas",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="是否先 DROP 并重建 etl_admin/billiards_ods/billiards_dwd(默认:是)",
    )
    ap.add_argument(
        "--terminate-own-sessions",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="执行 DROP 前是否终止当前用户的 idle-in-transaction 会话(默认:是)",
    )
    ap.add_argument(
        "--demo",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="运行最小 Demo(仅导入 member_profiles 并生成 dim_member/dim_member_ex)",
    )
    ap.add_argument(
        "--only-files",
        default="",
        help="仅处理指定 JSON 文件(逗号分隔,不含 .json,例如:member_profiles,settlement_records)",
    )
    ap.add_argument(
        "--only-dwd-tables",
        default="",
        help="仅处理指定 DWD 表(逗号分隔,支持完整名或表名,例如:billiards_dwd.dim_member,dim_member_ex)",
    )
    ap.add_argument(
        "--stop-after",
        default="",
        help="在指定阶段后停止(可选:DROP_SCHEMAS/INIT_ODS_SCHEMA/INIT_DWD_SCHEMA/MANUAL_INGEST/DWD_LOAD_FROM_ODS/BASIC_VALIDATE)",
    )
    ns = ap.parse_args()

    # The DSN is the only option without a usable default.
    if not ns.dsn:
        raise SystemExit("缺少 DSN:请传入 --dsn 或设置环境变量 PG_DSN")

    return RunArgs(
        dsn=ns.dsn,
        store_id=ns.store_id,
        json_dir=str(ns.json_dir),
        drop_schemas=bool(ns.drop_schemas),
        terminate_own_sessions=bool(ns.terminate_own_sessions),
        demo=bool(ns.demo),
        only_files=_csv_list(ns.only_files),
        only_dwd_tables=_csv_list(ns.only_dwd_tables),
        stop_after=str(ns.stop_after or "").strip().upper() or None,
    )
|
||||
|
||||
|
||||
def _build_config(args: RunArgs) -> AppConfig:
    """Assemble the minimal config overrides this run needs."""
    manual_overrides: dict[str, Any] = {}
    dwd_overrides: dict[str, Any] = {}
    if args.demo:
        manual_overrides["include_files"] = ["member_profiles"]
        dwd_overrides["only_tables"] = ["billiards_dwd.dim_member", "billiards_dwd.dim_member_ex"]
    if args.only_files:
        manual_overrides["include_files"] = args.only_files
    if args.only_dwd_tables:
        dwd_overrides["only_tables"] = args.only_dwd_tables

    # Offline replay / schema builds can run long: disable statement_timeout
    # (the default 30s would abort them) and lock_timeout (DROP/DDL may have
    # to wait a little on locks without failing outright).
    db_overrides = {
        "dsn": args.dsn,
        "session": {"statement_timeout_ms": 0, "lock_timeout_ms": 0},
    }
    overrides: dict[str, Any] = {
        "app": {"store_id": args.store_id},
        "pipeline": {"flow": "INGEST_ONLY", "ingest_source_dir": args.json_dir},
        "manual": manual_overrides,
        "dwd": dwd_overrides,
        "db": db_overrides,
    }
    return AppConfig.load(overrides)
|
||||
|
||||
|
||||
def _drop_schemas(db: DatabaseOperations, logger: logging.Logger) -> None:
    """Drop the ETL schemas so later stages recreate them (destructive!)."""
    doomed_schemas = ("billiards_dwd", "billiards_ods", "etl_admin")
    with db.conn.cursor() as cur:
        # Bound lock waits so a blocking session surfaces as an error instead
        # of hanging forever; the operator can then release/kill the blocker.
        cur.execute("SET lock_timeout TO '5s'")
        for schema in doomed_schemas:
            logger.info("DROP SCHEMA IF EXISTS %s CASCADE ...", schema)
            cur.execute(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE;')
|
||||
|
||||
|
||||
def _terminate_own_idle_in_tx(db: DatabaseOperations, logger: logging.Logger) -> int:
    """Kill this user's idle-in-transaction sessions so DROP/DDL can proceed.

    Returns the number of sessions successfully terminated.
    """
    terminated = 0
    with db.conn.cursor() as cur:
        cur.execute(
            """
            SELECT pid
            FROM pg_stat_activity
            WHERE datname = current_database()
              AND usename = current_user
              AND pid <> pg_backend_pid()
              AND state = 'idle in transaction'
            """
        )
        victim_pids = [record[0] for record in cur.fetchall()]
        for victim in victim_pids:
            cur.execute("SELECT pg_terminate_backend(%s)", (victim,))
            succeeded = bool(cur.fetchone()[0])
            logger.info("终止会话 pid=%s ok=%s", victim, succeeded)
            if succeeded:
                terminated += 1
    return terminated
|
||||
|
||||
|
||||
def _run_task(task, logger: logging.Logger) -> dict:
|
||||
"""统一运行任务并打印关键结果。"""
|
||||
result = task.execute(None)
|
||||
logger.info("%s: status=%s counts=%s", task.get_task_code(), result.get("status"), result.get("counts"))
|
||||
return result
|
||||
|
||||
|
||||
def _basic_validate(db: DatabaseOperations, logger: logging.Logger) -> None:
|
||||
"""做最基础的可用性校验:schema 存在、关键表行数可查询。"""
|
||||
checks = [
|
||||
("billiards_ods", "member_profiles"),
|
||||
("billiards_ods", "settlement_records"),
|
||||
("billiards_dwd", "dim_member"),
|
||||
("billiards_dwd", "dwd_settlement_head"),
|
||||
]
|
||||
for schema, table in checks:
|
||||
try:
|
||||
rows = db.query(f'SELECT COUNT(1) AS cnt FROM "{schema}"."{table}"')
|
||||
logger.info("校验行数:%s.%s = %s", schema, table, (rows[0] or {}).get("cnt") if rows else None)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("校验失败:%s.%s(%s)", schema, table, exc)
|
||||
|
||||
|
||||
def _connect_db_with_retry(cfg: AppConfig, logger: logging.Logger) -> DatabaseConnection:
    """Open a DB connection, backing off on failure to ride out brief network blips."""
    dsn = cfg["db"]["dsn"]
    session_cfg = cfg["db"].get("session")
    timeout_sec = cfg["db"].get("connect_timeout_sec")

    # First try immediately, then wait 1/2/4/8/16 seconds between retries.
    delays = [0, 1, 2, 4, 8, 16]
    failure: Exception | None = None
    for attempt, delay in enumerate(delays, start=1):
        if delay:
            time.sleep(delay)
        try:
            return DatabaseConnection(dsn=dsn, session=session_cfg, connect_timeout=timeout_sec)
        except Exception as exc:  # noqa: BLE001
            failure = exc
            logger.warning("数据库连接失败(第 %s 次):%s", attempt, exc)
    raise failure or RuntimeError("数据库连接失败")
|
||||
|
||||
|
||||
def _is_connection_error(exc: Exception) -> bool:
    """Return True for disconnect/server-side errors that merit a retry."""
    retryable_types = (psycopg2.OperationalError, psycopg2.InterfaceError)
    return isinstance(exc, retryable_types)
|
||||
|
||||
|
||||
def _run_stage_with_reconnect(
    cfg: AppConfig,
    logger: logging.Logger,
    stage_name: str,
    fn,
    max_attempts: int = 3,
) -> dict | None:
    """
    Run a single stage, reconnecting and retrying on (connection) failure.

    Args:
        cfg: application config used to open a fresh connection per attempt.
        logger: receives stage start/finish/failure messages.
        stage_name: label used in log lines only.
        fn: callable ``(db_ops) -> dict | None`` performing the stage's work.
        max_attempts: upper bound on tries; each try gets a new connection.

    Returns:
        Whatever ``fn`` returns on the first successful attempt.

    Raises:
        Immediately re-raises non-connection errors (so logic bugs are not
        masked by retries); raises the last exception once all attempts fail.
    """
    last_exc: Exception | None = None
    for attempt in range(1, max_attempts + 1):
        # A fresh connection per attempt: a broken one must not be reused.
        db_conn = _connect_db_with_retry(cfg, logger)
        db_ops = DatabaseOperations(db_conn)
        try:
            logger.info("阶段开始:%s(第 %s/%s 次)", stage_name, attempt, max_attempts)
            result = fn(db_ops)
            logger.info("阶段完成:%s", stage_name)
            return result
        except Exception as exc:  # noqa: BLE001
            last_exc = exc
            logger.exception("阶段失败:%s(第 %s/%s 次):%s", stage_name, attempt, max_attempts, exc)
            # Only connection-class errors are retried; anything else is
            # re-raised at once to avoid hiding logic problems.
            if not _is_connection_error(exc):
                raise
            # Exponential backoff, capped at 10 seconds.
            time.sleep(min(2**attempt, 10))
        finally:
            # Best-effort cleanup runs on success and failure alike (the
            # success path has already returned its result by this point).
            try:
                db_ops.close()  # type: ignore[attr-defined]
            except Exception:
                pass
            try:
                db_conn.close()
            except Exception:
                pass
    raise last_exc or RuntimeError(f"阶段失败:{stage_name}")
|
||||
|
||||
|
||||
def main() -> int:
    """Script entry point: rebuild schemas and run ODS→DWD in order.

    Returns a process exit code: 0 on success (or when stopped early via
    --stop-after), 2 when the sample JSON directory is missing.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger("fq_etl.rebuild_db")

    args = _parse_args()
    cfg = _build_config(args)

    # Enable file logging by default for post-mortem analysis; attach it as
    # early as possible so even a failed run leaves a trace on disk.
    file_handler = _attach_file_logger(
        log_root=cfg["io"]["log_root"],
        filename=time.strftime("rebuild_db_%Y%m%d-%H%M%S.log"),
        logger=logger,
    )

    try:
        json_dir = Path(args.json_dir)
        if not json_dir.exists():
            logger.error("示例 JSON 目录不存在:%s", json_dir)
            return 2

        # Each stage is a closure over cfg/args/logger and receives a fresh
        # DatabaseOperations from _run_stage_with_reconnect.
        def stage_drop(db_ops: DatabaseOperations):
            # Optional destructive stage: kill own blockers, then drop schemas.
            if not args.drop_schemas:
                return None
            if args.terminate_own_sessions:
                killed = _terminate_own_idle_in_tx(db_ops, logger)
                if killed:
                    db_ops.commit()
            _drop_schemas(db_ops, logger)
            db_ops.commit()
            return None

        def stage_init_ods(db_ops: DatabaseOperations):
            # Creates etl_admin metadata tables + runs schema_ODS_doc.sql.
            return _run_task(InitOdsSchemaTask(cfg, db_ops, None, logger), logger)

        def stage_init_dwd(db_ops: DatabaseOperations):
            # Runs schema_dwd_doc.sql.
            return _run_task(InitDwdSchemaTask(cfg, db_ops, None, logger), logger)

        def stage_manual_ingest(db_ops: DatabaseOperations):
            # Loads the local sample JSON files into ODS.
            logger.info("开始执行:MANUAL_INGEST(json_dir=%s)", json_dir)
            return _run_task(ManualIngestTask(cfg, db_ops, None, logger), logger)

        def stage_dwd_load(db_ops: DatabaseOperations):
            # Transforms ODS rows into DWD tables.
            logger.info("开始执行:DWD_LOAD_FROM_ODS")
            return _run_task(DwdLoadTask(cfg, db_ops, None, logger), logger)

        # Stage pipeline; --stop-after allows stopping after any stage.
        # Ingest/load get more attempts because they run longest and are the
        # most exposed to connection drops.
        _run_stage_with_reconnect(cfg, logger, "DROP_SCHEMAS", stage_drop, max_attempts=3)
        if args.stop_after == "DROP_SCHEMAS":
            return 0
        _run_stage_with_reconnect(cfg, logger, "INIT_ODS_SCHEMA", stage_init_ods, max_attempts=3)
        if args.stop_after == "INIT_ODS_SCHEMA":
            return 0
        _run_stage_with_reconnect(cfg, logger, "INIT_DWD_SCHEMA", stage_init_dwd, max_attempts=3)
        if args.stop_after == "INIT_DWD_SCHEMA":
            return 0
        _run_stage_with_reconnect(cfg, logger, "MANUAL_INGEST", stage_manual_ingest, max_attempts=5)
        if args.stop_after == "MANUAL_INGEST":
            return 0
        _run_stage_with_reconnect(cfg, logger, "DWD_LOAD_FROM_ODS", stage_dwd_load, max_attempts=5)
        if args.stop_after == "DWD_LOAD_FROM_ODS":
            return 0

        # Validation only needs one fresh connection of its own.
        _run_stage_with_reconnect(
            cfg,
            logger,
            "BASIC_VALIDATE",
            lambda db_ops: _basic_validate(db_ops, logger),
            max_attempts=3,
        )
        if args.stop_after == "BASIC_VALIDATE":
            return 0
        return 0
    finally:
        # Detach and close the file handler so the log file is flushed even
        # when a stage raised.
        if file_handler is not None:
            try:
                logging.getLogger().removeHandler(file_handler)
            except Exception:
                pass
            try:
                file_handler.close()
            except Exception:
                pass
|
||||
|
||||
|
||||
# Propagate main()'s integer return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
717
scripts/repair/backfill_missing_data.py
Normal file
717
scripts/repair/backfill_missing_data.py
Normal file
@@ -0,0 +1,717 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
补全丢失的 ODS 数据
|
||||
|
||||
通过运行数据完整性检查,找出 API 与 ODS 之间的差异,
|
||||
然后重新从 API 获取丢失的数据并插入 ODS。
|
||||
|
||||
用法:
    python -m scripts.repair.backfill_missing_data --start 2025-07-01 --end 2026-01-19
    python -m scripts.repair.backfill_missing_data --task-codes ODS_ORDERS --dry-run
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time as time_mod
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from dateutil import parser as dtparser
|
||||
from psycopg2.extras import Json, execute_values
|
||||
|
||||
# NOTE(review): this file lives in scripts/repair/, so parents[1] resolves to
# the "scripts" directory, not the repository root — parents[2] looks intended.
# It works today only when the script is launched from the project root
# (cwd is already on sys.path); confirm before relying on direct execution.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from api.recording_client import build_recording_client
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from models.parsers import TypeParser
|
||||
from tasks.ods.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
|
||||
from scripts.check.check_ods_gaps import run_gap_check
|
||||
from utils.logging_utils import build_log_path, configure_logging
|
||||
from utils.ods_record_utils import (
|
||||
get_value_case_insensitive,
|
||||
merge_record_layers,
|
||||
normalize_pk_value,
|
||||
pk_tuple_from_record,
|
||||
)
|
||||
|
||||
|
||||
def _reconfigure_stdout_utf8() -> None:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
try:
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
    """Parse *value* into a timezone-aware datetime in *tz*.

    Date-only inputs (no ``:`` or ``T``) are expanded to the start of the
    day, or to 23:59:59 when *is_end* is true.  Naive results are localized
    to *tz*; aware ones are converted.  Raises ValueError on blank input.
    """
    text = (value or "").strip()
    if not text:
        raise ValueError("empty datetime")
    date_only = ":" not in text and "T" not in text
    parsed = dtparser.parse(text)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=tz)
    else:
        parsed = parsed.astimezone(tz)
    if date_only:
        if is_end:
            parsed = parsed.replace(hour=23, minute=59, second=59, microsecond=0)
        else:
            parsed = parsed.replace(hour=0, minute=0, second=0, microsecond=0)
    return parsed
|
||||
|
||||
|
||||
def _get_spec(code: str) -> Optional[OdsTaskSpec]:
    """Return the ODS task spec whose ``code`` equals *code*, or None."""
    return next((spec for spec in ODS_TASK_SPECS if spec.code == code), None)
|
||||
|
||||
|
||||
def _merge_record_layers(record: dict) -> dict:
    """Flatten nested data layers into a single dict.

    Thin wrapper around ``utils.ods_record_utils.merge_record_layers``,
    kept so call sites inside this module stay short.
    """
    return merge_record_layers(record)
|
||||
|
||||
|
||||
def _get_value_case_insensitive(record: dict | None, col: str | None):
    """Fetch value without case sensitivity.

    Delegates to ``utils.ods_record_utils.get_value_case_insensitive`` so
    API payload keys (often camelCase) match lowercase DB column names.
    """
    return get_value_case_insensitive(record, col)
|
||||
|
||||
|
||||
def _normalize_pk_value(value):
    """Normalize a primary-key value for comparison.

    Delegates to ``utils.ods_record_utils.normalize_pk_value``.
    """
    return normalize_pk_value(value)
|
||||
|
||||
|
||||
def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
    """Extract the (normalized) PK tuple for *record*, or None if incomplete.

    Delegates to ``utils.ods_record_utils.pk_tuple_from_record``.
    """
    return pk_tuple_from_record(record, pk_cols)
|
||||
|
||||
|
||||
def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
    """Return the PRIMARY KEY column names of *table*, in key order.

    *table* may be schema-qualified ("schema.name"); unqualified names are
    looked up in "public".  By default the synthetic ``content_hash`` column
    is dropped from the result so the remaining columns identify the business
    row; pass include_content_hash=True to get the full constraint (used for
    ON CONFLICT targets).
    """
    if "." in table:
        schema, name = table.split(".", 1)
    else:
        schema, name = "public", table
    sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        cols = [r[0] for r in cur.fetchall()]
    if include_content_hash:
        return cols
    return [c for c in cols if c.lower() != "content_hash"]
|
||||
|
||||
|
||||
def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
    """Return (column_name, data_type, udt_name) for every column of *table*.

    data_type/udt_name are lowercased (empty string when NULL) so callers can
    compare without worrying about case.  Unqualified table names default to
    the "public" schema.
    """
    if "." in table:
        schema, name = table.split(".", 1)
    else:
        schema, name = "public", table
    sql = """
        SELECT column_name, data_type, udt_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
|
||||
|
||||
|
||||
def _fetch_existing_pk_set(
    conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
    """Return the subset of *pk_values* that already exist in *table*.

    Probes in chunks of *chunk_size* by joining the table against a VALUES
    list, so arbitrarily long PK lists don't blow up a single statement.
    """
    if not pk_values:
        return set()
    select_cols = ", ".join(f't."{c}"' for c in pk_cols)
    value_cols = ", ".join(f'"{c}"' for c in pk_cols)
    join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
    sql = (
        f"SELECT {select_cols} FROM {table} t "
        f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
    )
    existing: Set[Tuple] = set()
    with conn.cursor() as cur:
        for i in range(0, len(pk_values), chunk_size):
            chunk = pk_values[i:i + chunk_size]
            # page_size=len(chunk) keeps one VALUES expansion per execute
            execute_values(cur, sql, chunk, page_size=len(chunk))
            for row in cur.fetchall():
                existing.add(tuple(row))
    return existing
|
||||
|
||||
|
||||
def _cast_value(value, data_type: str):
|
||||
"""类型转换"""
|
||||
if value is None:
|
||||
return None
|
||||
dt = (data_type or "").lower()
|
||||
if dt in ("integer", "bigint", "smallint"):
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return None
|
||||
if dt in ("numeric", "double precision", "real", "decimal"):
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return None
|
||||
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
|
||||
return value if isinstance(value, (str, datetime)) else None
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_scalar(value):
|
||||
"""规范化标量值"""
|
||||
if value == "" or value == "{}" or value == "[]":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
class MissingDataBackfiller:
    """Re-fetches API records that are missing from (or mismatched with) ODS.

    Workflow: run the gap check (scripts.check.check_ods_gaps) over a time
    window, collect the missing/mismatched primary keys per ODS task, then
    page through the task's API endpoint again and insert only those rows
    (``ON CONFLICT DO NOTHING``).  With ``dry_run=True`` nothing is written;
    rows that *would* be inserted are only counted.
    """

    def __init__(
        self,
        cfg: AppConfig,
        logger: logging.Logger,
        dry_run: bool = False,
    ):
        self.cfg = cfg
        self.logger = logger
        # When True, pages are fetched and filtered but never written.
        self.dry_run = dry_run
        self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
        self.store_id = int(cfg.get("app.store_id") or 0)

        # API client; dedicated task code keeps its calls attributable.
        self.api = build_recording_client(cfg, task_code="BACKFILL_MISSING_DATA")

        # Database connection (DatabaseConnection sets autocommit=False).
        self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))

    def close(self):
        """Close the database connection if one is held."""
        if self.db:
            self.db.close()

    def _ensure_db(self):
        """Reconnect when the underlying connection has been closed.

        Long API pulls can outlive server-side idle timeouts; callers invoke
        this before each task so a fresh connection is available.
        """
        if self.db and getattr(self.db, "conn", None) is not None:
            if getattr(self.db.conn, "closed", 0) == 0:
                return
        self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))

    def backfill_from_gap_check(
        self,
        *,
        start: datetime,
        end: datetime,
        task_codes: Optional[str] = None,
        include_mismatch: bool = False,
        page_size: int = 200,
        chunk_size: int = 500,
        content_sample_limit: int | None = None,
    ) -> Dict[str, Any]:
        """Run the gap check over [start, end] and backfill what it reports.

        Args:
            start/end: window boundaries (tz-aware).
            task_codes: comma-separated task codes, empty/None = all.
            include_mismatch: also re-insert rows whose content differs.
            page_size / chunk_size: API paging and DB probing batch sizes.
            content_sample_limit: cap on mismatch samples collected.

        Returns:
            Summary dict with total_missing / total_mismatch / backfilled /
            errors and a per-task ``details`` list.
        """
        self.logger.info("数据补全开始 起始=%s 结束=%s", start.isoformat(), end.isoformat())

        # Derive gap-check window granularity from the requested span:
        # >= 1 day → day-sized windows, otherwise hour-sized (min 1 hour).
        total_seconds = max(0, int((end - start).total_seconds()))
        if total_seconds >= 86400:
            window_days = max(1, total_seconds // 86400)
            window_hours = 0
        else:
            window_days = 0
            window_hours = max(1, total_seconds // 3600 or 1)

        # Run the gap check to learn which PKs are missing per task.
        self.logger.info("正在执行缺失检查...")
        gap_result = run_gap_check(
            cfg=self.cfg,
            start=start,
            end=end,
            window_days=window_days,
            window_hours=window_hours,
            page_size=page_size,
            chunk_size=chunk_size,
            sample_limit=10000,  # collect (effectively) all missing samples
            sleep_per_window=0,
            sleep_per_page=0,
            task_codes=task_codes or "",
            from_cutoff=False,
            cutoff_overlap_hours=24,
            allow_small_window=True,
            logger=self.logger,
            compare_content=include_mismatch,
            content_sample_limit=content_sample_limit or 10000,
        )

        total_missing = gap_result.get("total_missing", 0)
        total_mismatch = gap_result.get("total_mismatch", 0)
        if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
            self.logger.info("Data complete: no missing/mismatch records")
            return {"backfilled": 0, "errors": 0, "details": []}

        if include_mismatch:
            self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
        else:
            self.logger.info("Missing check done missing=%s", total_missing)

        results = []
        total_backfilled = 0
        total_errors = 0

        # Backfill each task independently; one task failing must not stop the rest.
        for task_result in gap_result.get("results", []):
            task_code = task_result.get("task_code")
            missing = task_result.get("missing", 0)
            missing_samples = task_result.get("missing_samples", [])
            mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
            mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
            target_samples = list(missing_samples) + list(mismatch_samples)

            if missing == 0 and mismatch == 0:
                continue

            self.logger.info(
                "Start backfill task task=%s missing=%s mismatch=%s samples=%s",
                task_code, missing, mismatch, len(target_samples)
            )

            try:
                backfilled = self._backfill_task(
                    task_code=task_code,
                    table=task_result.get("table"),
                    pk_columns=task_result.get("pk_columns", []),
                    pk_samples=target_samples,
                    start=start,
                    end=end,
                    page_size=page_size,
                    chunk_size=chunk_size,
                )
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": backfilled,
                    "error": None,
                })
                total_backfilled += backfilled
            except Exception as exc:
                self.logger.exception("补全失败 任务=%s", task_code)
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": 0,
                    "error": str(exc),
                })
                total_errors += 1

        self.logger.info(
            "数据补全完成 总缺失=%s 已补全=%s 错误数=%s",
            total_missing, total_backfilled, total_errors
        )

        return {
            "total_missing": total_missing,
            "total_mismatch": total_mismatch,
            "backfilled": total_backfilled,
            "errors": total_errors,
            "details": results,
        }

    def _backfill_task(
        self,
        *,
        task_code: str,
        table: str,
        pk_columns: List[str],
        pk_samples: List[Dict],
        start: datetime,
        end: datetime,
        page_size: int,
        chunk_size: int,
    ) -> int:
        """Backfill one task: re-pull its endpoint and insert the missing PKs.

        Returns the number of rows actually inserted (or, in dry-run mode,
        the number that would have been inserted).
        """
        self._ensure_db()
        spec = _get_spec(task_code)
        if not spec:
            self.logger.warning("未找到任务规格 任务=%s", task_code)
            return 0

        # Fall back to introspecting the table when the gap report carries no PKs.
        if not pk_columns:
            pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)

        # ON CONFLICT must target the full PK constraint (incl. content_hash).
        conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
        if not conflict_columns:
            conflict_columns = pk_columns

        if not pk_columns:
            self.logger.warning("未找到主键列 任务=%s 表=%s", task_code, table)
            return 0

        # Build the set of PK tuples we still need (skip samples with NULL keys).
        # NOTE(review): sample values are used as-is here while fetched records
        # go through _pk_tuple_from_record's normalization — confirm the gap
        # check emits samples in the same normalized form.
        missing_pks: Set[Tuple] = set()
        for sample in pk_samples:
            pk_tuple = tuple(sample.get(col) for col in pk_columns)
            if all(v is not None for v in pk_tuple):
                missing_pks.add(pk_tuple)

        if not missing_pks:
            self.logger.info("无缺失主键 任务=%s", task_code)
            return 0

        self.logger.info(
            "开始获取数据 任务=%s 缺失主键数=%s",
            task_code, len(missing_pks)
        )

        # Re-pull the whole window from the API and filter to the missing PKs.
        params = self._build_params(spec, start, end)

        backfilled = 0
        cols_info = _get_table_columns(self.db.conn, table)
        db_json_cols_lower = {
            c[0].lower() for c in cols_info
            if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
        }
        # NOTE(review): unused here; _insert_records derives its own column list.
        col_names = [c[0] for c in cols_info]

        # End the implicit read-only transaction so a long API pull does not
        # trip the server's idle-in-transaction timeout.
        try:
            self.db.conn.commit()
        except Exception:
            self.db.conn.rollback()

        try:
            for page_no, records, _, response_payload in self.api.iter_paginated(
                endpoint=spec.endpoint,
                params=params,
                page_size=page_size,
                data_path=spec.data_path,
                list_key=spec.list_key,
            ):
                # Keep only records whose PK is in the missing set.
                records_to_insert = []
                for rec in records:
                    if not isinstance(rec, dict):
                        continue
                    pk_tuple = _pk_tuple_from_record(rec, pk_columns)
                    if pk_tuple and pk_tuple in missing_pks:
                        records_to_insert.append(rec)

                if not records_to_insert:
                    continue

                if self.dry_run:
                    backfilled += len(records_to_insert)
                    self.logger.info(
                        "模拟运行 任务=%s 页=%s 将插入=%s",
                        task_code, page_no, len(records_to_insert)
                    )
                else:
                    inserted = self._insert_records(
                        table=table,
                        records=records_to_insert,
                        cols_info=cols_info,
                        pk_columns=pk_columns,
                        conflict_columns=conflict_columns,
                        db_json_cols_lower=db_json_cols_lower,
                    )
                    backfilled += inserted
                    # Commit per page to avoid long transactions / idle-in-tx timeouts.
                    self.db.conn.commit()
                    self.logger.info(
                        "已插入 任务=%s 页=%s 数量=%s",
                        task_code, page_no, inserted
                    )

            if not self.dry_run:
                self.db.conn.commit()

            self.logger.info("任务补全完成 任务=%s 已补全=%s", task_code, backfilled)
            return backfilled

        except Exception:
            self.db.conn.rollback()
            raise

    def _build_params(
        self,
        spec: OdsTaskSpec,
        start: datetime,
        end: datetime,
    ) -> Dict:
        """Assemble the API request params for *spec* over [start, end].

        Merges (lowest → highest precedence): configured common api.params,
        the computed site/window params, then spec.extra_params.
        """
        base: Dict[str, Any] = {}
        if spec.include_site_id:
            # This one endpoint expects a list of site ids rather than a scalar.
            if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
                base["siteId"] = [self.store_id]
            else:
                base["siteId"] = self.store_id

        if spec.requires_window and spec.time_fields:
            start_key, end_key = spec.time_fields
            base[start_key] = TypeParser.format_timestamp(start, self.tz)
            base[end_key] = TypeParser.format_timestamp(end, self.tz)

        # Merge configured common params under the computed ones.
        common = self.cfg.get("api.params", {}) or {}
        if isinstance(common, dict):
            merged = {**common, **base}
        else:
            merged = base

        merged.update(spec.extra_params or {})
        return merged

    def _insert_records(
        self,
        *,
        table: str,
        records: List[Dict],
        cols_info: List[Tuple[str, str, str]],
        pk_columns: List[str],
        conflict_columns: List[str],
        db_json_cols_lower: Set[str],
    ) -> int:
        """Insert *records* into *table* with ON CONFLICT DO NOTHING.

        Each record is flattened, mapped onto the table's columns (filling the
        bookkeeping columns payload/source_file/source_endpoint/fetched_at/
        content_hash), and written in batches of 200.  Records lacking any
        business PK value are skipped.  Returns the number of rows inserted.
        """
        if not records:
            return 0

        col_names = [c[0] for c in cols_info]
        needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
        quoted_cols = ", ".join(f'"{c}"' for c in col_names)
        sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
        conflict_cols = conflict_columns or pk_columns
        if conflict_cols:
            pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
            sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"

        now = datetime.now(self.tz)
        json_dump = lambda v: json.dumps(v, ensure_ascii=False)

        params: List[Tuple] = []
        for rec in records:
            merged_rec = _merge_record_layers(rec)

            # Skip rows missing any business PK value (content_hash excluded).
            if pk_columns:
                missing_pk = False
                for pk in pk_columns:
                    if str(pk).lower() == "content_hash":
                        continue
                    pk_val = _get_value_case_insensitive(merged_rec, pk)
                    if pk_val is None or pk_val == "":
                        missing_pk = True
                        break
                if missing_pk:
                    continue

            content_hash = None
            if needs_content_hash:
                # Same hashing as the regular ODS load so dedupe stays consistent.
                content_hash = BaseOdsTask._compute_content_hash(
                    merged_rec, include_fetched_at=False
                )

            row_vals: List[Any] = []
            for (col_name, data_type, _udt) in cols_info:
                col_lower = col_name.lower()
                if col_lower == "payload":
                    # Raw (un-flattened) record goes into payload.
                    row_vals.append(Json(rec, dumps=json_dump))
                    continue
                if col_lower == "source_file":
                    row_vals.append("backfill")
                    continue
                if col_lower == "source_endpoint":
                    row_vals.append("backfill")
                    continue
                if col_lower == "fetched_at":
                    row_vals.append(now)
                    continue
                if col_lower == "content_hash":
                    row_vals.append(content_hash)
                    continue

                value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
                if col_lower in db_json_cols_lower:
                    row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
                    continue

                row_vals.append(_cast_value(value, data_type))

            params.append(tuple(row_vals))

        if not params:
            return 0

        inserted = 0
        with self.db.conn.cursor() as cur:
            for i in range(0, len(params), 200):
                chunk = params[i:i + 200]
                execute_values(cur, sql, chunk, page_size=len(chunk))
                # rowcount excludes rows skipped by ON CONFLICT DO NOTHING.
                if cur.rowcount is not None and cur.rowcount > 0:
                    inserted += int(cur.rowcount)

        return inserted
|
||||
|
||||
|
||||
def run_backfill(
    *,
    cfg: AppConfig,
    start: datetime,
    end: datetime,
    task_codes: Optional[str] = None,
    include_mismatch: bool = False,
    dry_run: bool = False,
    page_size: int = 200,
    chunk_size: int = 500,
    content_sample_limit: int | None = None,
    logger: logging.Logger,
) -> Dict[str, Any]:
    """Run the missing-data backfill over [start, end].

    Convenience wrapper that builds a MissingDataBackfiller, delegates to
    ``backfill_from_gap_check`` and guarantees the DB connection is closed.

    Args:
        cfg: application configuration.
        start: window start (tz-aware).
        end: window end (tz-aware).
        task_codes: comma-separated task codes (None/empty = all tasks).
        include_mismatch: also re-insert rows whose content differs.
        dry_run: preview only — nothing is written.
        page_size: API page size.
        chunk_size: database batch size.
        content_sample_limit: cap on mismatch samples collected.
        logger: logger to report progress on.

    Returns:
        Summary dict (see ``MissingDataBackfiller.backfill_from_gap_check``).
    """
    backfiller = MissingDataBackfiller(cfg, logger, dry_run)
    try:
        return backfiller.backfill_from_gap_check(
            start=start,
            end=end,
            task_codes=task_codes,
            include_mismatch=include_mismatch,
            page_size=page_size,
            chunk_size=chunk_size,
            content_sample_limit=content_sample_limit,
        )
    finally:
        backfiller.close()
|
||||
|
||||
|
||||
def main() -> int:
    """CLI entry point: parse args, configure logging, run the backfill.

    Always returns 0; per-task failures are reported in the summary rather
    than via the exit code.
    """
    _reconfigure_stdout_utf8()

    ap = argparse.ArgumentParser(description="补全丢失的 ODS 数据")
    ap.add_argument("--start", default="2025-07-01", help="开始日期 (默认: 2025-07-01)")
    ap.add_argument("--end", default="", help="结束日期 (默认: 当前时间)")
    ap.add_argument("--task-codes", default="", help="指定任务代码(逗号分隔,留空=全部)")
    ap.add_argument("--include-mismatch", action="store_true", help="同时补全内容不一致的记录")
    ap.add_argument("--content-sample-limit", type=int, default=None, help="不一致样本上限 (默认: 10000)")
    ap.add_argument("--dry-run", action="store_true", help="仅预览,不实际写入")
    ap.add_argument("--page-size", type=int, default=200, help="API 分页大小 (默认: 200)")
    ap.add_argument("--chunk-size", type=int, default=500, help="数据库批量大小 (默认: 500)")
    ap.add_argument("--log-file", default="", help="日志文件路径")
    ap.add_argument("--log-dir", default="", help="日志目录")
    ap.add_argument("--log-level", default="INFO", help="日志级别 (默认: INFO)")
    ap.add_argument("--no-log-console", action="store_true", help="禁用控制台日志")
    args = ap.parse_args()

    # Resolve log destination; default is a timestamped file under <project>/logs.
    log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
    log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
    log_console = not args.no_log_console

    with configure_logging(
        "backfill_missing",
        log_file,
        level=args.log_level,
        console=log_console,
        tee_std=True,
    ) as logger:
        cfg = AppConfig.load({})
        tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))

        # Date-only --end expands to 23:59:59; missing --end means "now".
        start = _parse_dt(args.start, tz)
        end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)

        result = run_backfill(
            cfg=cfg,
            start=start,
            end=end,
            task_codes=args.task_codes or None,
            include_mismatch=args.include_mismatch,
            dry_run=args.dry_run,
            page_size=args.page_size,
            chunk_size=args.chunk_size,
            content_sample_limit=args.content_sample_limit,
            logger=logger,
        )

        # Summary banner.
        logger.info("=" * 60)
        logger.info("补全完成!")
        logger.info("  总丢失: %s", result.get("total_missing", 0))
        if args.include_mismatch:
            logger.info("  总不一致: %s", result.get("total_mismatch", 0))
        logger.info("  已补全: %s", result.get("backfilled", 0))
        logger.info("  错误数: %s", result.get("errors", 0))
        logger.info("=" * 60)

        # Per-task detail: errors at ERROR level, productive tasks at INFO.
        for detail in result.get("details", []):
            if detail.get("error"):
                logger.error(
                    "  %s: 丢失=%s 不一致=%s 补全=%s 错误=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                    detail.get("error"),
                )
            elif detail.get("backfilled", 0) > 0:
                logger.info(
                    "  %s: 丢失=%s 不一致=%s 补全=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                )

        return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
261
scripts/repair/dedupe_ods_snapshots.py
Normal file
261
scripts/repair/dedupe_ods_snapshots.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Deduplicate ODS snapshots by (business PK, content_hash).
|
||||
Keep the latest row by fetched_at (tie-breaker: ctid desc).
|
||||
|
||||
Usage:
|
||||
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots
|
||||
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --schema billiards_ods
|
||||
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --tables member_profiles,orders
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import psycopg2
|
||||
|
||||
# NOTE(review): this file lives in scripts/repair/, so parents[1] resolves to
# the "scripts" directory, not the repository root — parents[2] looks intended.
# Module-mode invocation from the project root masks this; confirm before
# relying on direct execution.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
|
||||
|
||||
def _reconfigure_stdout_utf8() -> None:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
try:
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _quote_ident(name: str) -> str:
|
||||
return '"' + str(name).replace('"', '""') + '"'
|
||||
|
||||
|
||||
def _fetch_tables(conn, schema: str) -> list[str]:
    """Return the names of all base tables in *schema*, sorted by name.

    Views and foreign tables are excluded (table_type = 'BASE TABLE').
    """
    sql = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema,))
        return [r[0] for r in cur.fetchall()]
|
||||
|
||||
|
||||
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
    """Return the column names of *schema.table* in ordinal order."""
    sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        return [r[0] for r in cur.fetchall()]
|
||||
|
||||
|
||||
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
    """Return the PRIMARY KEY columns of *schema.table*, minus content_hash.

    content_hash is part of the dedupe key added separately by callers, so
    it is stripped here to expose only the business key.
    """
    sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        cols = [r[0] for r in cur.fetchall()]
    return [c for c in cols if c.lower() != "content_hash"]
|
||||
|
||||
|
||||
def _build_report_path(out_arg: str | None) -> Path:
|
||||
if out_arg:
|
||||
return Path(out_arg)
|
||||
reports_dir = PROJECT_ROOT / "reports"
|
||||
reports_dir.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return reports_dir / f"ods_snapshot_dedupe_{ts}.json"
|
||||
|
||||
|
||||
def _print_progress(
|
||||
table_label: str,
|
||||
deleted: int,
|
||||
total: int,
|
||||
errors: int,
|
||||
) -> None:
|
||||
if total:
|
||||
msg = f"[{table_label}] deleted {deleted}/{total} errors={errors}"
|
||||
else:
|
||||
msg = f"[{table_label}] deleted {deleted} errors={errors}"
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
def _count_duplicates(conn, schema: str, table: str, key_cols: Sequence[str]) -> int:
    """Count duplicate snapshot rows in *schema.table*.

    Rows are partitioned by (business key columns + content_hash) and ranked
    by fetched_at DESC (ctid DESC as tie-breaker); everything ranked after
    the first row in a partition is a duplicate.  Mirrors the ordering used
    by _delete_duplicate_batch so the count matches what would be deleted.
    """
    keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
    table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
    sql = f"""
        SELECT COUNT(*) FROM (
            SELECT 1
            FROM (
                SELECT ROW_NUMBER() OVER (
                    PARTITION BY {keys_sql}
                    ORDER BY fetched_at DESC NULLS LAST, ctid DESC
                ) AS rn
                FROM {table_sql}
            ) t
            WHERE rn > 1
        ) s
    """
    with conn.cursor() as cur:
        cur.execute(sql)
        row = cur.fetchone()
        return int(row[0] if row else 0)
|
||||
|
||||
|
||||
def _delete_duplicate_batch(
    conn,
    schema: str,
    table: str,
    key_cols: Sequence[str],
    batch_size: int,
) -> int:
    """Delete up to *batch_size* duplicate rows and return how many were removed.

    Duplicates are rows ranked rn > 1 within a (business key + content_hash)
    partition ordered by fetched_at DESC NULLS LAST, ctid DESC — i.e. the
    newest snapshot per key is kept.  Rows are targeted via ctid, which is
    safe here because each batch runs in its own statement; callers loop
    until this returns 0.
    """
    keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
    table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
    sql = f"""
        WITH dupes AS (
            SELECT ctid
            FROM (
                SELECT ctid,
                       ROW_NUMBER() OVER (
                           PARTITION BY {keys_sql}
                           ORDER BY fetched_at DESC NULLS LAST, ctid DESC
                       ) AS rn
                FROM {table_sql}
            ) s
            WHERE rn > 1
            LIMIT %s
        )
        DELETE FROM {table_sql} t
        USING dupes d
        WHERE t.ctid = d.ctid
        RETURNING 1
    """
    with conn.cursor() as cur:
        cur.execute(sql, (int(batch_size),))
        rows = cur.fetchall()
        return len(rows or [])
|
||||
|
||||
|
||||
def main() -> int:
    """CLI entry point: dedupe every eligible ODS table and write a JSON report.

    Tables missing content_hash/fetched_at or a primary key are skipped.
    With --dry-run only duplicate counts are computed.  Always returns 0;
    per-table outcomes are recorded in the report.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Deduplicate ODS snapshot rows by PK+content_hash")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=1000, help="delete batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N deletions")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute duplicate counts")
    args = ap.parse_args()

    cfg = AppConfig.load({})
    db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    # Clear any implicit transaction, then switch to autocommit so each
    # delete batch commits on its own (no long-running transaction).
    try:
        db.conn.rollback()
    except Exception:
        pass
    db.conn.autocommit = True

    # Enumerate candidate tables, optionally narrowed by --tables.
    tables = _fetch_tables(db.conn, args.schema)
    if args.tables.strip():
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]

    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": len(tables),
            "checked_tables": 0,
            "total_duplicates": 0,
            "deleted_rows": 0,
            "error_rows": 0,
            "skipped_tables": 0,
        },
    }

    for table in tables:
        table_label = f"{args.schema}.{table}"
        # Skip tables without the snapshot bookkeeping columns.
        cols = _fetch_columns(db.conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        if "content_hash" not in cols_lower or "fetched_at" not in cols_lower:
            print(f"[{table_label}] skip: missing content_hash/fetched_at", flush=True)
            report["summary"]["skipped_tables"] += 1
            continue

        key_cols = _fetch_pk_columns(db.conn, args.schema, table)
        if not key_cols:
            print(f"[{table_label}] skip: missing primary key", flush=True)
            report["summary"]["skipped_tables"] += 1
            continue

        total_dupes = _count_duplicates(db.conn, args.schema, table, key_cols)
        print(f"[{table_label}] duplicates={total_dupes}", flush=True)
        deleted = 0
        errors = 0

        # Delete in batches until none remain; abort the table on SQL errors.
        if not args.dry_run and total_dupes:
            while True:
                try:
                    batch_deleted = _delete_duplicate_batch(
                        db.conn,
                        args.schema,
                        table,
                        key_cols,
                        args.batch_size,
                    )
                except psycopg2.Error:
                    errors += 1
                    break
                if batch_deleted <= 0:
                    break
                deleted += batch_deleted
                if args.progress_every and deleted % int(args.progress_every) == 0:
                    _print_progress(table_label, deleted, total_dupes, errors)

            # Final progress line unless it was just printed by the loop.
            if deleted and (not args.progress_every or deleted % int(args.progress_every) != 0):
                _print_progress(table_label, deleted, total_dupes, errors)

        report["tables"].append(
            {
                "table": table_label,
                "duplicate_rows": total_dupes,
                "deleted_rows": deleted,
                "error_rows": errors,
            }
        )
        report["summary"]["checked_tables"] += 1
        report["summary"]["total_duplicates"] += total_dupes
        report["summary"]["deleted_rows"] += deleted
        report["summary"]["error_rows"] += errors

    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
86
scripts/repair/fix_dim_assistant_user_id.py
Normal file
86
scripts/repair/fix_dim_assistant_user_id.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# -*- coding: utf-8 -*-
"""One-off repair: back-fill the ``user_id`` column of dim_assistant.

Fills ``billiards_dwd.dim_assistant.user_id`` from the most recent ODS
snapshot (``billiards_ods.assistant_accounts_master``), joining ODS ``id``
to DWD ``assistant_id``. Prints before/after counts, a data sample, and a
match-rate check against the assistant service log.
"""
import sys

# Allow running directly from the repo root without installing the package.
sys.path.insert(0, '.')

from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations

config = AppConfig.load()
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)

print("=== 修复 dim_assistant.user_id ===")

# Approach: update the DWD table's user_id from the ODS table,
# joined via id (ODS) = assistant_id (DWD).

# 1. Snapshot current state (the same query is reused for the "after" check).
print("\n修复前:")
sql_before = """
SELECT
    COUNT(*) as total,
    COUNT(CASE WHEN user_id > 0 THEN 1 END) as has_user_id
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
"""
r = dict(db.query(sql_before)[0])
print(f" 总记录: {r['total']}, 有user_id: {r['has_user_id']}")

# 2. Apply the update. DISTINCT ON (id) ... ORDER BY fetched_at DESC picks
#    the latest ODS snapshot per assistant; only rows whose user_id is still
#    missing (NULL or 0) are touched, so re-running is safe.
print("\n执行更新...")
update_sql = """
UPDATE billiards_dwd.dim_assistant d
SET user_id = o.user_id
FROM (
    SELECT DISTINCT ON (id) id, user_id
    FROM billiards_ods.assistant_accounts_master
    WHERE user_id > 0
    ORDER BY id, fetched_at DESC
) o
WHERE d.assistant_id = o.id
  AND (d.user_id IS NULL OR d.user_id = 0)
"""
with db_conn.conn.cursor() as cur:
    cur.execute(update_sql)
    updated = cur.rowcount
print(f" 更新了 {updated} 条记录")
db_conn.conn.commit()

# 3. Re-run the "before" query to show the effect.
print("\n修复后:")
r2 = dict(db.query(sql_before)[0])
print(f" 总记录: {r2['total']}, 有user_id: {r2['has_user_id']}")

# 4. Show a small sample of the current rows.
print("\n样本数据:")
sql_sample = """
SELECT assistant_id, user_id, assistant_no, nickname
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
ORDER BY assistant_no::int
LIMIT 10
"""
for row in db.query(sql_sample):
    r = dict(row)
    print(f" assistant_id={r['assistant_id']}, user_id={r['user_id']}, no={r['assistant_no']}, nickname={r['nickname']}")

# 5. Verify how many service-log user_ids now resolve to a dim_assistant row.
print("\n验证与服务日志的关联:")
sql_verify = """
SELECT
    COUNT(DISTINCT s.user_id) as service_unique_users,
    COUNT(DISTINCT CASE WHEN d.assistant_id IS NOT NULL THEN s.user_id END) as matched_users
FROM billiards_dwd.dwd_assistant_service_log s
LEFT JOIN billiards_dwd.dim_assistant d
    ON s.user_id = d.user_id AND d.scd2_is_current = 1
WHERE s.is_delete = 0 AND s.user_id > 0
"""
r3 = dict(db.query(sql_verify)[0])
print(f" 服务日志唯一user_id: {r3['service_unique_users']}")
print(f" 能匹配到dim_assistant: {r3['matched_users']}")
match_rate = r3['matched_users'] / r3['service_unique_users'] * 100 if r3['service_unique_users'] > 0 else 0
print(f" 匹配率: {match_rate:.1f}%")

db_conn.close()
print("\n完成!")
|
||||
302
scripts/repair/repair_ods_content_hash.py
Normal file
302
scripts/repair/repair_ods_content_hash.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Repair ODS content_hash values by recomputing from payload.
|
||||
|
||||
Usage:
|
||||
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash
|
||||
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --schema billiards_ods
|
||||
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --tables member_profiles,orders
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Sequence
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
# Make the repository root importable so `config`, `database`, and `tasks`
# resolve when this script is run directly. The file lives at
# scripts/repair/, so the repo root is three levels up: parents[2].
# (parents[1] would be scripts/, which breaks the imports below and would
# also send _build_report_path's output to scripts/reports/.)
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from tasks.ods.ods_tasks import BaseOdsTask
|
||||
|
||||
|
||||
def _reconfigure_stdout_utf8() -> None:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
try:
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _fetch_tables(conn, schema: str) -> list[str]:
    """Return base-table names in *schema*, alphabetically ordered."""
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(query, (schema,))
        fetched = cur.fetchall()
    return [record[0] for record in fetched]
|
||||
|
||||
|
||||
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
    """Return all column names of schema.table in ordinal order; empty names dropped."""
    query = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(query, (schema, table))
        fetched = cur.fetchall()
    return [record[0] for record in fetched if record[0]]
|
||||
|
||||
|
||||
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
    """Return primary-key columns of schema.table in key order, minus content_hash.

    content_hash is excluded because it is the very column being repaired and
    must not be used to identify rows.
    """
    query = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(query, (schema, table))
        names = [record[0] for record in cur.fetchall()]
    return [name for name in names if name.lower() != "content_hash"]
|
||||
|
||||
|
||||
def _fetch_row_count(conn, schema: str, table: str) -> int:
    """Return the total row count of schema.table (0 when no row comes back)."""
    with conn.cursor() as cur:
        cur.execute(f'SELECT COUNT(*) FROM "{schema}"."{table}"')
        first = cur.fetchone()
    return int(first[0]) if first else 0
|
||||
|
||||
|
||||
def _iter_rows(
    conn,
    schema: str,
    table: str,
    select_cols: Sequence[str],
    batch_size: int,
) -> Iterable[dict]:
    """Stream rows of schema.table as dicts via a named (server-side) cursor.

    A server-side cursor keeps memory bounded on large tables; *batch_size*
    controls the fetch chunk (itersize). ``ctid`` is a system column and is
    deliberately left unquoted in the column list.
    """
    column_sql = ", ".join(
        "ctid" if name == "ctid" else f'"{name}"' for name in select_cols
    )
    statement = f'SELECT {column_sql} FROM "{schema}"."{table}"'
    with conn.cursor(name=f"ods_hash_fix_{table}", cursor_factory=RealDictCursor) as cur:
        cur.itersize = max(1, int(batch_size or 500))
        cur.execute(statement)
        yield from cur
|
||||
|
||||
|
||||
def _build_report_path(out_arg: str | None) -> Path:
|
||||
if out_arg:
|
||||
return Path(out_arg)
|
||||
reports_dir = PROJECT_ROOT / "reports"
|
||||
reports_dir.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return reports_dir / f"ods_content_hash_repair_{ts}.json"
|
||||
|
||||
|
||||
def _print_progress(
|
||||
table_label: str,
|
||||
processed: int,
|
||||
total: int,
|
||||
updated: int,
|
||||
skipped: int,
|
||||
conflicts: int,
|
||||
errors: int,
|
||||
missing_hash: int,
|
||||
invalid_payload: int,
|
||||
) -> None:
|
||||
if total:
|
||||
msg = (
|
||||
f"[{table_label}] checked {processed}/{total} "
|
||||
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
|
||||
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f"[{table_label}] checked {processed} "
|
||||
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
|
||||
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
|
||||
)
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
def main() -> int:
    """Recompute content_hash for every eligible ODS table and write a JSON report.

    A table is eligible when it has both a ``payload`` and a ``content_hash``
    column. Rows are streamed through a read connection; mismatching hashes
    are rewritten through a separate autocommit write connection (unless
    ``--dry-run``). Always returns 0 — per-row failures are counted in the
    report rather than raised.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Repair ODS content_hash using payload")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
    ap.add_argument("--sample-limit", type=int, default=10, help="sample conflicts per table")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute stats, do not update")
    args = ap.parse_args()

    cfg = AppConfig.load({})
    # Two connections: the read side holds a server-side cursor open for the
    # whole table scan, so writes go through their own autocommit connection.
    db_read = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    db_write = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    try:
        # Clear any implicit open transaction before flipping autocommit on.
        db_write.conn.rollback()
    except Exception:
        pass
    db_write.conn.autocommit = True

    tables = _fetch_tables(db_read.conn, args.schema)
    if args.tables.strip():
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]

    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": len(tables),
            "checked_tables": 0,
            "total_rows": 0,
            "checked_rows": 0,
            "updated_rows": 0,
            "skipped_rows": 0,
            "conflict_rows": 0,
            "error_rows": 0,
            "missing_hash_rows": 0,
            "invalid_payload_rows": 0,
        },
    }

    for table in tables:
        table_label = f"{args.schema}.{table}"
        cols = _fetch_columns(db_read.conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        if "payload" not in cols_lower or "content_hash" not in cols_lower:
            print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
            continue

        total = _fetch_row_count(db_read.conn, args.schema, table)
        pk_cols = _fetch_pk_columns(db_read.conn, args.schema, table)
        # ctid targets the exact physical row even when the PK is unusable.
        select_cols = ["ctid", "content_hash", "payload", *pk_cols]

        processed = 0
        updated = 0
        skipped = 0
        conflicts = 0
        errors = 0
        missing_hash = 0
        invalid_payload = 0
        samples: list[dict[str, Any]] = []

        print(f"[{table_label}] start: total_rows={total}", flush=True)

        for row in _iter_rows(db_read.conn, args.schema, table, select_cols, args.batch_size):
            processed += 1
            content_hash = row.get("content_hash")
            payload = row.get("payload")
            recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
            row_ctid = row.get("ctid")

            # Anomaly counters (reported, not acted on).
            if not content_hash:
                missing_hash += 1
            if not recomputed:
                invalid_payload += 1

            if not recomputed:
                # Payload could not be hashed: nothing to compare or write.
                skipped += 1
            elif content_hash == recomputed:
                # Stored hash already correct.
                skipped += 1
            else:
                if args.dry_run:
                    updated += 1
                else:
                    try:
                        with db_write.conn.cursor() as cur:
                            cur.execute(
                                f'UPDATE "{args.schema}"."{table}" SET content_hash = %s WHERE ctid = %s',
                                (recomputed, row_ctid),
                            )
                        updated += 1
                    except psycopg2.errors.UniqueViolation:
                        # Another row already owns this hash value; keep a
                        # bounded sample of PKs for the report.
                        conflicts += 1
                        if len(samples) < max(0, int(args.sample_limit or 0)):
                            sample = {k: row.get(k) for k in pk_cols}
                            sample["content_hash"] = content_hash
                            sample["recomputed_hash"] = recomputed
                            samples.append(sample)
                    except psycopg2.Error:
                        errors += 1

            if args.progress_every and processed % int(args.progress_every) == 0:
                _print_progress(
                    table_label,
                    processed,
                    total,
                    updated,
                    skipped,
                    conflicts,
                    errors,
                    missing_hash,
                    invalid_payload,
                )

        # Final progress line, unless the loop just printed one.
        if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
            _print_progress(
                table_label,
                processed,
                total,
                updated,
                skipped,
                conflicts,
                errors,
                missing_hash,
                invalid_payload,
            )

        report["tables"].append(
            {
                "table": table_label,
                "total_rows": total,
                "checked_rows": processed,
                "updated_rows": updated,
                "skipped_rows": skipped,
                "conflict_rows": conflicts,
                "error_rows": errors,
                "missing_hash_rows": missing_hash,
                "invalid_payload_rows": invalid_payload,
                "conflict_samples": samples,
            }
        )

        report["summary"]["checked_tables"] += 1
        report["summary"]["total_rows"] += total
        report["summary"]["checked_rows"] += processed
        report["summary"]["updated_rows"] += updated
        report["summary"]["skipped_rows"] += skipped
        report["summary"]["conflict_rows"] += conflicts
        report["summary"]["error_rows"] += errors
        report["summary"]["missing_hash_rows"] += missing_hash
        report["summary"]["invalid_payload_rows"] += invalid_payload

    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
231
scripts/repair/tune_integrity_indexes.py
Normal file
231
scripts/repair/tune_integrity_indexes.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Create performance indexes for integrity verification and run ANALYZE.
|
||||
|
||||
Usage:
|
||||
python -m scripts.tune_integrity_indexes
|
||||
python -m scripts.tune_integrity_indexes --dry-run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Sequence, Set, Tuple
|
||||
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
|
||||
from config.settings import AppConfig
|
||||
|
||||
|
||||
# Column names that mark a table's time axis when present; _plan_indexes
# creates a single-column index (and a time+PK composite) for each match.
TIME_CANDIDATES = (
    "pay_time",
    "create_time",
    "start_use_time",
    "scd2_start_time",
    "calc_time",
    "order_date",
    "fetched_at",
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class IndexPlan:
    """One planned CREATE INDEX action for a specific schema.table."""

    schema: str
    table: str
    index_name: str  # kept within PostgreSQL's 63-char identifier limit
    columns: Tuple[str, ...]  # index columns in order
|
||||
|
||||
|
||||
def _short_index_name(table: str, tag: str, columns: Sequence[str]) -> str:
|
||||
raw = f"idx_{table}_{tag}_{'_'.join(columns)}"
|
||||
if len(raw) <= 63:
|
||||
return raw
|
||||
digest = hashlib.md5(raw.encode("utf-8")).hexdigest()[:8]
|
||||
shortened = f"idx_{table}_{tag}_{digest}"
|
||||
return shortened[:63]
|
||||
|
||||
|
||||
def _load_table_columns(cur, schema: str, table: str) -> Set[str]:
    """Return the column names of schema.table as a set (order is irrelevant here)."""
    query = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
    """
    cur.execute(query, (schema, table))
    return {record[0] for record in cur.fetchall()}
|
||||
|
||||
|
||||
def _load_pk_columns(cur, schema: str, table: str) -> List[str]:
    """Return the primary-key column names of schema.table, in key order."""
    query = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        WHERE tc.table_schema = %s
          AND tc.table_name = %s
          AND tc.constraint_type = 'PRIMARY KEY'
        ORDER BY kcu.ordinal_position
    """
    cur.execute(query, (schema, table))
    return [record[0] for record in cur.fetchall()]
|
||||
|
||||
|
||||
def _load_tables(cur, schema: str) -> List[str]:
    """Return base-table names of *schema*, alphabetically ordered."""
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s
          AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    cur.execute(query, (schema,))
    return [record[0] for record in cur.fetchall()]
|
||||
|
||||
|
||||
def _plan_indexes(cur, schema: str, table: str) -> List[IndexPlan]:
    """Plan the candidate indexes for one table, keyed to its schema's access pattern.

    ODS tables get fetched_at (and fetched_at+PK) indexes; DWD tables get an
    SCD2 current-row composite plus one index per business-time column found
    in TIME_CANDIDATES. Plans are deduplicated by index name.
    """
    plans: List[IndexPlan] = []
    cols = _load_table_columns(cur, schema, table)
    pk_cols = _load_pk_columns(cur, schema, table)

    if schema == "billiards_ods":
        # ODS snapshots are scanned by fetch time, optionally narrowed by PK.
        if "fetched_at" in cols:
            plans.append(
                IndexPlan(
                    schema=schema,
                    table=table,
                    index_name=_short_index_name(table, "fetched_at", ("fetched_at",)),
                    columns=("fetched_at",),
                )
            )
            # Composite only for short PKs that are all real columns.
            if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
                comp_cols = ("fetched_at", *pk_cols)
                plans.append(
                    IndexPlan(
                        schema=schema,
                        table=table,
                        index_name=_short_index_name(table, "fetched_pk", comp_cols),
                        columns=comp_cols,
                    )
                )

    if schema == "billiards_dwd":
        # SCD2 current-row lookup: (PK..., scd2_is_current).
        if pk_cols and "scd2_is_current" in cols and len(pk_cols) <= 4:
            comp_cols = (*pk_cols, "scd2_is_current")
            plans.append(
                IndexPlan(
                    schema=schema,
                    table=table,
                    index_name=_short_index_name(table, "pk_current", comp_cols),
                    columns=comp_cols,
                )
            )

        # One plan per business-time column present on the table.
        for tcol in TIME_CANDIDATES:
            if tcol in cols:
                plans.append(
                    IndexPlan(
                        schema=schema,
                        table=table,
                        index_name=_short_index_name(table, "time", (tcol,)),
                        columns=(tcol,),
                    )
                )
                if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
                    comp_cols = (tcol, *pk_cols)
                    plans.append(
                        IndexPlan(
                            schema=schema,
                            table=table,
                            index_name=_short_index_name(table, "time_pk", comp_cols),
                            columns=comp_cols,
                        )
                    )

    # Deduplicate by index name (dicts preserve insertion order; last plan wins).
    dedup: Dict[str, IndexPlan] = {}
    for p in plans:
        dedup[p.index_name] = p
    return list(dedup.values())
|
||||
|
||||
|
||||
def _create_index(cur, plan: IndexPlan) -> None:
    """Issue CREATE INDEX IF NOT EXISTS for *plan* (idempotent, identifiers quoted)."""
    column_list = sql.SQL(", ").join(sql.Identifier(name) for name in plan.columns)
    statement = sql.SQL("CREATE INDEX IF NOT EXISTS {idx} ON {sch}.{tbl} ({cols})").format(
        idx=sql.Identifier(plan.index_name),
        sch=sql.Identifier(plan.schema),
        tbl=sql.Identifier(plan.table),
        cols=column_list,
    )
    cur.execute(statement)
|
||||
|
||||
|
||||
def _analyze_table(cur, schema: str, table: str) -> None:
    """Refresh planner statistics for schema.table via ANALYZE."""
    cur.execute(
        sql.SQL("ANALYZE {sch}.{tbl}").format(
            sch=sql.Identifier(schema),
            tbl=sql.Identifier(table),
        )
    )
|
||||
|
||||
|
||||
def main() -> int:
    """Plan and (unless --dry-run) create integrity-check indexes, then ANALYZE.

    All work runs in a single transaction: --dry-run prints the plan and rolls
    back, otherwise the created indexes are committed. Returns 0.
    """
    ap = argparse.ArgumentParser(description="Tune indexes for integrity verification.")
    ap.add_argument("--dry-run", action="store_true", help="Print planned SQL only.")
    ap.add_argument(
        "--skip-analyze",
        action="store_true",
        help="Create indexes but skip ANALYZE.",
    )
    args = ap.parse_args()

    cfg = AppConfig.load({})
    dsn = cfg.get("db.dsn")
    timeout_sec = int(cfg.get("db.connect_timeout_sec", 10) or 10)

    with psycopg2.connect(dsn, connect_timeout=timeout_sec) as conn:
        # Explicit transaction so --dry-run can roll everything back.
        conn.autocommit = False
        with conn.cursor() as cur:
            all_plans: List[IndexPlan] = []
            for schema in ("billiards_ods", "billiards_dwd"):
                for table in _load_tables(cur, schema):
                    all_plans.extend(_plan_indexes(cur, schema, table))

            touched_tables: Set[Tuple[str, str]] = set()
            print(f"planned indexes: {len(all_plans)}")
            for plan in all_plans:
                cols = ", ".join(plan.columns)
                print(f"[INDEX] {plan.schema}.{plan.table} ({cols}) -> {plan.index_name}")
                if not args.dry_run:
                    _create_index(cur, plan)
                    touched_tables.add((plan.schema, plan.table))

            if not args.skip_analyze:
                if args.dry_run:
                    # Dry-run: only report which tables WOULD be analyzed.
                    for schema, table in sorted({(p.schema, p.table) for p in all_plans}):
                        print(f"[ANALYZE] {schema}.{table}")
                else:
                    for schema, table in sorted(touched_tables):
                        _analyze_table(cur, schema, table)
                        print(f"[ANALYZE] {schema}.{table}")

        if args.dry_run:
            conn.rollback()
            print("dry-run complete; transaction rolled back")
        else:
            conn.commit()
            print("index tuning complete")

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
||||
26
scripts/run_ods.bat
Normal file
26
scripts/run_ods.bat
Normal file
@@ -0,0 +1,26 @@
|
||||
@echo off
REM -*- coding: utf-8 -*-
REM One-click ODS rebuild: run INIT_ODS_SCHEMA, then ingest the sample JSON
REM via MANUAL_INGEST. Aborts (exit code 1) as soon as either step fails.

setlocal
REM Switch to the repository root (this script lives in scripts\).
cd /d "%~dp0\.."

REM Change INGEST_DIR below to ingest from a different sample directory.
REM NOTE(review): the doubled backslash is passed through literally; Windows
REM usually tolerates it mid-path - confirm it is intentional.
set "INGEST_DIR=export\\test-json-doc"

echo [INIT_ODS_SCHEMA] 准备执行,源目录=%INGEST_DIR%
python -m cli.main --tasks INIT_ODS_SCHEMA --pipeline-flow INGEST_ONLY --ingest-source "%INGEST_DIR%"
if errorlevel 1 (
    echo INIT_ODS_SCHEMA 失败,退出
    exit /b 1
)

echo [MANUAL_INGEST] 准备执行,源目录=%INGEST_DIR%
python -m cli.main --tasks MANUAL_INGEST --pipeline-flow INGEST_ONLY --ingest-source "%INGEST_DIR%"
if errorlevel 1 (
    echo MANUAL_INGEST 失败,退出
    exit /b 1
)

echo 全部完成。
endlocal
|
||||
516
scripts/run_update.py
Normal file
516
scripts/run_update.py
Normal file
@@ -0,0 +1,516 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
一键增量更新脚本(ODS -> DWD -> DWS)。
|
||||
|
||||
用法:
|
||||
python scripts/run_update.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import subprocess
|
||||
import sys
|
||||
import time as time_mod
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from api.client import APIClient
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations
|
||||
from orchestration.scheduler import ETLScheduler
|
||||
from tasks.utility.check_cutoff_task import CheckCutoffTask
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
from tasks.ods.ods_tasks import ENABLED_ODS_CODES
|
||||
from utils.logging_utils import build_log_path, configure_logging
|
||||
|
||||
# Fallback hard timeout (seconds) for a single pipeline-step subprocess;
# individual steps may override it via step["timeout_sec"]
# (see _run_step_with_timeout).
STEP_TIMEOUT_SEC = 120
|
||||
|
||||
|
||||
|
||||
def _coerce_date(s: str) -> date:
|
||||
s = (s or "").strip()
|
||||
if not s:
|
||||
raise ValueError("empty date")
|
||||
if len(s) >= 10:
|
||||
s = s[:10]
|
||||
return date.fromisoformat(s)
|
||||
|
||||
|
||||
def _compute_dws_window(
    *,
    cfg: AppConfig,
    tz: ZoneInfo,
    rebuild_days: int,
    bootstrap_days: int,
    dws_start: date | None,
    dws_end: date | None,
) -> tuple[datetime, datetime]:
    """Resolve the [start, end] datetime window for the DWS rebuild.

    When *dws_start* is not given it defaults to (latest
    dws_order_summary.order_date - rebuild_days), or to
    (today - bootstrap_days) when the summary table is empty; *dws_end*
    defaults to today. Both bounds are widened to full-day boundaries in *tz*.

    Raises ValueError when both dates are given and end precedes start.
    """
    if dws_start and dws_end and dws_end < dws_start:
        raise ValueError("dws_end must be >= dws_start")

    store_id = int(cfg.get("app.store_id"))
    dsn = cfg["db"]["dsn"]
    session = cfg["db"].get("session")
    conn = DatabaseConnection(dsn=dsn, session=session)
    try:
        if dws_start is None:
            row = conn.query(
                "SELECT MAX(order_date) AS mx FROM billiards_dws.dws_order_summary WHERE site_id=%s",
                (store_id,),
            )
            mx = (row[0] or {}).get("mx") if row else None
            if isinstance(mx, date):
                # Incremental: back up rebuild_days from the latest built day.
                dws_start = mx - timedelta(days=max(0, int(rebuild_days)))
            else:
                # Bootstrap: summary table is empty, rebuild a fixed lookback.
                dws_start = (datetime.now(tz).date()) - timedelta(days=max(1, int(bootstrap_days)))

        if dws_end is None:
            dws_end = datetime.now(tz).date()
    finally:
        conn.close()

    start_dt = datetime.combine(dws_start, time.min).replace(tzinfo=tz)
    # Extend end_dt to 23:59:59.999999 of the end day, so using date() of
    # "now" cannot silently drop the final day from the window.
    end_dt = datetime.combine(dws_end, time.max).replace(tzinfo=tz)
    return start_dt, end_dt
|
||||
|
||||
|
||||
def _run_check_cutoff(cfg: AppConfig, logger: logging.Logger):
    """Run the CHECK_CUTOFF task once with a dedicated, short-lived DB connection."""
    db_conn = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    db_ops = DatabaseOperations(db_conn)
    api_client = APIClient(
        base_url=cfg["api"]["base_url"],
        token=cfg["api"]["token"],
        timeout=cfg["api"]["timeout_sec"],
        retry_max=cfg["api"]["retries"]["max_attempts"],
        headers_extra=cfg["api"].get("headers_extra"),
    )
    try:
        task = CheckCutoffTask(cfg, db_ops, api_client, logger)
        task.execute(None)
    finally:
        # Always release the connection, even when the task raises.
        db_conn.close()
|
||||
|
||||
|
||||
def _iter_daily_windows(window_start: datetime, window_end: datetime) -> list[tuple[datetime, datetime]]:
|
||||
if window_start > window_end:
|
||||
return []
|
||||
tz = window_start.tzinfo
|
||||
windows: list[tuple[datetime, datetime]] = []
|
||||
cur = window_start
|
||||
while cur <= window_end:
|
||||
day_start = datetime.combine(cur.date(), time.min).replace(tzinfo=tz)
|
||||
day_end = datetime.combine(cur.date(), time.max).replace(tzinfo=tz)
|
||||
if day_start < window_start:
|
||||
day_start = window_start
|
||||
if day_end > window_end:
|
||||
day_end = window_end
|
||||
windows.append((day_start, day_end))
|
||||
next_day = cur.date() + timedelta(days=1)
|
||||
cur = datetime.combine(next_day, time.min).replace(tzinfo=tz)
|
||||
return windows
|
||||
|
||||
|
||||
def _run_step_worker(result_queue: "mp.Queue[dict[str, str]]", step: dict[str, str]) -> None:
    """Subprocess entry point: execute one pipeline step and report its outcome.

    *step* is a plain dict (picklable for the spawn start method) whose
    ``type`` key selects the branch below; remaining keys parameterize it.
    Exactly one status dict is put on *result_queue*: ``{"status": "ok"}`` or
    ``{"status": "error", "error": ...}`` — this function never raises to the
    caller.
    """
    # The child gets its own stdout; force UTF-8 where the stream supports it.
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass
    log_file = step.get("log_file") or ""
    log_level = step.get("log_level") or "INFO"
    log_console = bool(step.get("log_console", True))
    log_path = Path(log_file) if log_file else None

    with configure_logging(
        "etl_update",
        log_path,
        level=log_level,
        console=log_console,
        tee_std=True,
    ) as logger:
        cfg_base = AppConfig.load({})
        step_type = step.get("type", "")
        try:
            if step_type == "check_cutoff":
                _run_check_cutoff(cfg_base, logger)
            elif step_type == "ods_task":
                # Run one ODS fetch task through the scheduler (FULL flow).
                task_code = step["task_code"]
                overlap_seconds = int(step.get("overlap_seconds", 0))
                cfg_ods = AppConfig.load(
                    {
                        "pipeline": {"flow": "FULL"},
                        "run": {"tasks": [task_code], "overlap_seconds": overlap_seconds},
                    }
                )
                scheduler = ETLScheduler(cfg_ods, logger)
                try:
                    scheduler.run_tasks([task_code])
                finally:
                    scheduler.close()
            elif step_type == "init_dws_schema":
                # Ensure the DWS schema exists (INGEST_ONLY: no API fetching).
                overlap_seconds = int(step.get("overlap_seconds", 0))
                cfg_dwd = AppConfig.load(
                    {
                        "pipeline": {"flow": "INGEST_ONLY"},
                        "run": {"tasks": ["INIT_DWS_SCHEMA"], "overlap_seconds": overlap_seconds},
                    }
                )
                scheduler = ETLScheduler(cfg_dwd, logger)
                try:
                    scheduler.run_tasks(["INIT_DWS_SCHEMA"])
                finally:
                    scheduler.close()
            elif step_type == "dwd_table":
                # Load a single DWD table from ODS (dwd.only_tables narrows it).
                dwd_table = step["dwd_table"]
                overlap_seconds = int(step.get("overlap_seconds", 0))
                cfg_dwd = AppConfig.load(
                    {
                        "pipeline": {"flow": "INGEST_ONLY"},
                        "run": {"tasks": ["DWD_LOAD_FROM_ODS"], "overlap_seconds": overlap_seconds},
                        "dwd": {"only_tables": [dwd_table]},
                    }
                )
                scheduler = ETLScheduler(cfg_dwd, logger)
                try:
                    scheduler.run_tasks(["DWD_LOAD_FROM_ODS"])
                finally:
                    scheduler.close()
            elif step_type == "dws_window":
                # Build the DWS order summary for an explicit time window.
                overlap_seconds = int(step.get("overlap_seconds", 0))
                window_start = step["window_start"]
                window_end = step["window_end"]
                cfg_dws = AppConfig.load(
                    {
                        "pipeline": {"flow": "INGEST_ONLY"},
                        "run": {
                            "tasks": ["DWS_BUILD_ORDER_SUMMARY"],
                            "overlap_seconds": overlap_seconds,
                            "window_override": {"start": window_start, "end": window_end},
                        },
                    }
                )
                scheduler = ETLScheduler(cfg_dws, logger)
                try:
                    scheduler.run_tasks(["DWS_BUILD_ORDER_SUMMARY"])
                finally:
                    scheduler.close()
            elif step_type == "ods_gap_check":
                # Delegate to scripts/check/check_ods_gaps.py as a subprocess;
                # a non-zero exit propagates as an error via check=True.
                overlap_hours = int(step.get("overlap_hours", 24))
                window_days = int(step.get("window_days", 1))
                window_hours = int(step.get("window_hours", 0))
                page_size = int(step.get("page_size", 0) or 0)
                sleep_per_window = float(step.get("sleep_per_window", 0) or 0)
                sleep_per_page = float(step.get("sleep_per_page", 0) or 0)
                tag = step.get("tag", "run_update")
                task_codes = (step.get("task_codes") or "").strip()
                # This file lives at scripts/run_update.py, so parent.parent
                # is the repository root.
                script_dir = Path(__file__).resolve().parent.parent
                script_path = script_dir / "scripts" / "check" / "check_ods_gaps.py"
                cmd = [
                    sys.executable,
                    str(script_path),
                    "--from-cutoff",
                    "--cutoff-overlap-hours",
                    str(overlap_hours),
                    "--window-days",
                    str(window_days),
                    "--tag",
                    str(tag),
                ]
                # Optional knobs are only forwarded when explicitly set.
                if window_hours > 0:
                    cmd += ["--window-hours", str(window_hours)]
                if page_size > 0:
                    cmd += ["--page-size", str(page_size)]
                if sleep_per_window > 0:
                    cmd += ["--sleep-per-window-seconds", str(sleep_per_window)]
                if sleep_per_page > 0:
                    cmd += ["--sleep-per-page-seconds", str(sleep_per_page)]
                if task_codes:
                    cmd += ["--task-codes", task_codes]
                subprocess.run(cmd, check=True, cwd=str(script_dir))
            else:
                raise ValueError(f"Unknown step type: {step_type}")
            result_queue.put({"status": "ok"})
        except Exception as exc:
            result_queue.put({"status": "error", "error": str(exc)})
|
||||
|
||||
|
||||
def _run_step_with_timeout(
    step: dict[str, str], logger: logging.Logger, timeout_sec: int
) -> dict[str, object]:
    """Run one pipeline step in a spawned subprocess under a hard timeout.

    The step may override *timeout_sec* via ``step["timeout_sec"]``. Returns a
    result dict with at least ``name``, ``status`` ("ok"/"error"/"timeout")
    and ``elapsed`` seconds; a timed-out child is terminated.
    """
    start = time_mod.monotonic()
    step_timeout = timeout_sec
    if step.get("timeout_sec"):
        try:
            step_timeout = int(step.get("timeout_sec"))
        except Exception:
            # Unparsable override: fall back to the caller's default.
            step_timeout = timeout_sec
    # spawn (not fork): child re-imports cleanly, avoiding inherited DB/log state.
    ctx = mp.get_context("spawn")
    result_queue: mp.Queue = ctx.Queue()
    proc = ctx.Process(target=_run_step_worker, args=(result_queue, step))
    proc.start()
    proc.join(timeout=step_timeout)
    elapsed = time_mod.monotonic() - start
    if proc.is_alive():
        # Still running past the limit: kill it and report a timeout.
        logger.error(
            "STEP_TIMEOUT name=%s elapsed=%.2fs limit=%ss", step["name"], elapsed, step_timeout
        )
        proc.terminate()
        proc.join(10)
        return {"name": step["name"], "status": "timeout", "elapsed": elapsed}

    # Default to "error"; the worker's queued payload overwrites the status.
    result: dict[str, object] = {"name": step["name"], "status": "error", "elapsed": elapsed}
    try:
        payload = result_queue.get_nowait()
    except Exception:
        # Child died without reporting (crash / queue empty): keep "error".
        payload = {}
    if payload:
        result.update(payload)

    if result.get("status") == "ok":
        logger.info("STEP_OK name=%s elapsed=%.2fs", step["name"], elapsed)
    else:
        logger.error(
            "STEP_FAIL name=%s elapsed=%.2fs error=%s",
            step["name"],
            elapsed,
            result.get("error"),
        )
    return result
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the one-click update script."""
    parser = argparse.ArgumentParser(description="One-click ETL update (ODS -> DWD -> DWS)")
    parser.add_argument("--overlap-seconds", type=int, default=3600, help="overlap seconds (default: 3600)")
    parser.add_argument(
        "--dws-rebuild-days",
        type=int,
        default=1,
        help="DWS 回算冗余天数(default: 1)",
    )
    parser.add_argument(
        "--dws-bootstrap-days",
        type=int,
        default=30,
        help="DWS 首次/空表时回算天数(default: 30)",
    )
    parser.add_argument("--dws-start", type=str, default="", help="DWS 回算开始日期 YYYY-MM-DD(可选)")
    parser.add_argument("--dws-end", type=str, default="", help="DWS 回算结束日期 YYYY-MM-DD(可选)")
    parser.add_argument(
        "--skip-cutoff",
        action="store_true",
        help="跳过 CHECK_CUTOFF(默认会在开始/结束各跑一次)",
    )
    parser.add_argument(
        "--skip-ods",
        action="store_true",
        help="跳过 ODS 在线抓取(仅跑 DWD/DWS)",
    )
    parser.add_argument(
        "--ods-tasks",
        type=str,
        default="",
        help="指定要跑的 ODS 任务(逗号分隔),默认跑全部 ENABLED_ODS_CODES",
    )
    parser.add_argument(
        "--check-ods-gaps",
        action="store_true",
        help="run ODS gap check after ODS load (default: off)",
    )
    parser.add_argument(
        "--check-ods-overlap-hours",
        type=int,
        default=24,
        help="gap check overlap hours from cutoff (default: 24)",
    )
    parser.add_argument(
        "--check-ods-window-days",
        type=int,
        default=1,
        help="gap check window days (default: 1)",
    )
    parser.add_argument(
        "--check-ods-window-hours",
        type=int,
        default=0,
        help="gap check window hours (default: 0)",
    )
    parser.add_argument(
        "--check-ods-page-size",
        type=int,
        default=200,
        help="gap check API page size (default: 200)",
    )
    parser.add_argument(
        "--check-ods-timeout-sec",
        type=int,
        default=1800,
        help="gap check timeout seconds (default: 1800)",
    )
    parser.add_argument(
        "--check-ods-task-codes",
        type=str,
        default="",
        help="gap check task codes (comma-separated, optional)",
    )
    parser.add_argument(
        "--check-ods-sleep-per-window-seconds",
        type=float,
        default=0,
        help="gap check sleep seconds after each window (default: 0)",
    )
    parser.add_argument(
        "--check-ods-sleep-per-page-seconds",
        type=float,
        default=0,
        help="gap check sleep seconds after each page (default: 0)",
    )
    parser.add_argument("--log-file", type=str, default="", help="log file path (default: logs/run_update_YYYYMMDD_HHMMSS.log)")
    parser.add_argument("--log-dir", type=str, default="", help="log directory (default: logs)")
    parser.add_argument("--log-level", type=str, default="INFO", help="log level (default: INFO)")
    parser.add_argument("--no-log-console", action="store_true", help="disable console logging")
    return parser


def _build_steps(args: argparse.Namespace, cfg_base, tz) -> list[dict[str, str]]:
    """Assemble the ordered step plan.

    Order: cutoff check -> ODS fetch -> optional gap check -> DWS schema init
    -> DWD loads -> per-day DWS rebuilds -> cutoff check.

    Args:
        args: Parsed CLI namespace from ``_build_arg_parser()``.
        cfg_base: Loaded application config (used for the DWS window calc).
        tz: Timezone used for DWS window boundaries.

    Returns:
        List of step descriptors consumed by ``_run_step_with_timeout``.
    """
    dws_start = _coerce_date(args.dws_start) if args.dws_start else None
    dws_end = _coerce_date(args.dws_end) if args.dws_end else None

    steps: list[dict[str, str]] = []
    if not args.skip_cutoff:
        steps.append({"name": "CHECK_CUTOFF:before", "type": "check_cutoff"})

    # ------------------------------------------------------------------ ODS (online fetch + write)
    if not args.skip_ods:
        if args.ods_tasks:
            ods_tasks = [t.strip().upper() for t in args.ods_tasks.split(",") if t.strip()]
        else:
            ods_tasks = sorted(ENABLED_ODS_CODES)
        for task_code in ods_tasks:
            steps.append(
                {
                    "name": f"ODS:{task_code}",
                    "type": "ods_task",
                    "task_code": task_code,
                    "overlap_seconds": str(args.overlap_seconds),
                }
            )

    if args.check_ods_gaps:
        steps.append(
            {
                "name": "ODS_GAP_CHECK",
                "type": "ods_gap_check",
                "overlap_hours": str(args.check_ods_overlap_hours),
                "window_days": str(args.check_ods_window_days),
                "window_hours": str(args.check_ods_window_hours),
                "page_size": str(args.check_ods_page_size),
                "sleep_per_window": str(args.check_ods_sleep_per_window_seconds),
                "sleep_per_page": str(args.check_ods_sleep_per_page_seconds),
                "timeout_sec": str(args.check_ods_timeout_sec),
                "task_codes": str(args.check_ods_task_codes or ""),
                "tag": "run_update",
            }
        )

    # ------------------------------------------------------------------ DWD (load from ODS tables)
    steps.append(
        {
            "name": "INIT_DWS_SCHEMA",
            "type": "init_dws_schema",
            "overlap_seconds": str(args.overlap_seconds),
        }
    )
    for dwd_table in DwdLoadTask.TABLE_MAP:
        steps.append(
            {
                "name": f"DWD:{dwd_table}",
                "type": "dwd_table",
                "dwd_table": dwd_table,
                "overlap_seconds": str(args.overlap_seconds),
            }
        )

    # ------------------------------------------------------------------ DWS (rebuild per daily window)
    window_start, window_end = _compute_dws_window(
        cfg=cfg_base,
        tz=tz,
        rebuild_days=int(args.dws_rebuild_days),
        bootstrap_days=int(args.dws_bootstrap_days),
        dws_start=dws_start,
        dws_end=dws_end,
    )
    for start_dt, end_dt in _iter_daily_windows(window_start, window_end):
        steps.append(
            {
                "name": f"DWS:{start_dt.date().isoformat()}",
                "type": "dws_window",
                "window_start": start_dt.strftime("%Y-%m-%d %H:%M:%S"),
                "window_end": end_dt.strftime("%Y-%m-%d %H:%M:%S"),
                "overlap_seconds": str(args.overlap_seconds),
            }
        )

    if not args.skip_cutoff:
        steps.append({"name": "CHECK_CUTOFF:after", "type": "check_cutoff"})
    return steps


def _log_summary(logger: logging.Logger, step_results: list[dict[str, object]]) -> None:
    """Log the STEP_SUMMARY counts plus per-step results, slowest first."""
    total = len(step_results)
    ok_count = sum(1 for r in step_results if r.get("status") == "ok")
    timeout_count = sum(1 for r in step_results if r.get("status") == "timeout")
    fail_count = total - ok_count - timeout_count
    logger.info(
        "STEP_SUMMARY total=%s ok=%s failed=%s timeout=%s",
        total,
        ok_count,
        fail_count,
        timeout_count,
    )
    for item in sorted(step_results, key=lambda r: float(r.get("elapsed", 0.0)), reverse=True):
        logger.info(
            "STEP_RESULT name=%s status=%s elapsed=%.2fs",
            item.get("name"),
            item.get("status"),
            item.get("elapsed", 0.0),
        )


def main() -> int:
    """One-click incremental update entry point (ODS -> DWD -> DWS).

    Builds the step plan, runs each step in its own subprocess with a hard
    timeout, and logs a summary.

    Returns:
        Always 0; per-step failures are reported via the STEP_SUMMARY /
        STEP_RESULT log lines rather than the exit code.
        NOTE(review): scheduled callers cannot detect failures from the exit
        status — consider returning non-zero on failure in a future change.
    """
    # Force UTF-8 stdout where supported (Windows consoles default to a
    # legacy codepage, which breaks the Chinese log/help text).
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass

    args = _build_arg_parser().parse_args()

    log_dir = Path(args.log_dir) if args.log_dir else (Path(__file__).resolve().parent.parent / "logs")
    log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "run_update")
    log_console = not args.no_log_console

    with configure_logging(
        "etl_update",
        log_file,
        level=args.log_level,
        console=log_console,
        tee_std=True,
    ) as logger:
        cfg_base = AppConfig.load({})
        tz = ZoneInfo(cfg_base.get("app.timezone", "Asia/Taipei"))

        steps = _build_steps(args, cfg_base, tz)

        # Propagate the logging configuration to every subprocess step so all
        # output lands in the same log file. (log_console is deliberately
        # stored as a bool; the worker only reads it for truthiness.)
        for step in steps:
            step["log_file"] = str(log_file)
            step["log_level"] = args.log_level
            step["log_console"] = log_console

        step_results: list[dict[str, object]] = []
        for step in steps:
            logger.info("STEP_START name=%s timeout=%ss", step["name"], STEP_TIMEOUT_SEC)
            result = _run_step_with_timeout(step, logger, STEP_TIMEOUT_SEC)
            step_results.append(result)

        _log_summary(logger, step_results)

        logger.info("Update done.")
        return 0
|
||||
|
||||
|
||||
# Script entry point: exit the process with main()'s status code.
if __name__ == "__main__":
    sys.exit(main())
|
||||
Reference in New Issue
Block a user