初始提交:飞球 ETL 系统全量代码

This commit is contained in:
Neo
2026-02-13 08:05:34 +08:00
commit 3c51f5485d
441 changed files with 117631 additions and 0 deletions

38
scripts/README.md Normal file
View File

@@ -0,0 +1,38 @@
# scripts/ — 运维与工具脚本
## 子目录
| 目录 | 用途 | 典型场景 |
|------|------|----------|
| `audit/` | 仓库审计(文件清单、调用流、文档对齐分析) | `python -m scripts.audit.run_audit` |
| `check/` | 数据检查(ODS 缺口、内容哈希、完整性校验) | `python -m scripts.check.check_data_integrity` |
| `db_admin/` | 数据库管理(Excel 导入 DWS 支出/回款/提成) | `python scripts/db_admin/import_dws_excel.py --type expense` |
| `export/` | 数据导出(指数、团购、亲密度、会员明细等) | `python scripts/export/export_index_tables.py` |
| `rebuild/` | 数据重建(全量 ODS→DWD 重建) | `python scripts/rebuild/rebuild_db_and_run_ods_to_dwd.py` |
| `repair/` | 数据修复(回填、去重、hash 修复、维度修复) | `python scripts/repair/dedupe_ods_snapshots.py` |
## 根目录脚本
- `run_update.py` — 一键增量更新(ODS → DWD → DWS),适合 cron/计划任务调用
- `run_ods.bat` — Windows 批处理(ODS 建表 + 灌入示例 JSON)
## 运行方式
所有脚本在项目根目录(`C:\ZQYY\FQ-ETL`)执行:
```bash
# 审计报告生成
python -m scripts.audit.run_audit
# 一键增量更新
python scripts/run_update.py
# 数据完整性检查(需要数据库连接)
python -m scripts.check.check_data_integrity --window-start "2025-01-01" --window-end "2025-02-01"
```
## 注意事项
- 所有脚本依赖 `.env` 中的 `PG_DSN` 配置(或环境变量),下文附有最小读取示意
- `rebuild/` 下的脚本会重建 Schema,生产环境慎用
- `repair/` 下的脚本会修改数据,建议先 `--dry-run`(如支持)
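下面给出一个读取 `PG_DSN` 并建立数据库连接的最小示意(假设使用仓库依赖中的 python-dotenv 与 psycopg2;实际的配置加载逻辑以 `config/settings.py` 的 `AppConfig` 为准):

```python
# 最小示意:从 .env 读取 PG_DSN 并连接 PostgreSQL(示例,非仓库原有脚本)
import os

import psycopg2
from dotenv import load_dotenv

load_dotenv()                   # 读取项目根目录下的 .env
dsn = os.environ["PG_DSN"]      # 缺失时抛 KeyError,便于尽早暴露配置问题
with psycopg2.connect(dsn) as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        print(cur.fetchone())
```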

1
scripts/__init__.py Normal file
View File

@@ -0,0 +1 @@
# 脚本辅助工具包标记。

107
scripts/audit/__init__.py Normal file
View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""
仓库治理只读审计 — 共享数据模型
定义审计脚本各模块共用的 dataclass 和枚举类型。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
# ---------------------------------------------------------------------------
# 文件元信息
# ---------------------------------------------------------------------------
@dataclass
class FileEntry:
"""单个文件/目录的元信息。"""
rel_path: str # 相对于仓库根目录的路径
is_dir: bool # 是否为目录
size_bytes: int # 文件大小(目录为 0)
extension: str # 文件扩展名(小写,含点号)
is_empty_dir: bool # 是否为空目录
# ---------------------------------------------------------------------------
# 用途分类与处置标签
# ---------------------------------------------------------------------------
class Category(str, Enum):
"""文件用途分类。"""
CORE_CODE = "核心代码"
CONFIG = "配置"
DATABASE_DEF = "数据库定义"
TEST = "测试"
DOCS = "文档"
SCRIPTS = "脚本工具"
GUI = "GUI"
BUILD_DEPLOY = "构建与部署"
LOG_OUTPUT = "日志与输出"
TEMP_DEBUG = "临时与调试"
OTHER = "其他"
class Disposition(str, Enum):
"""处置标签。"""
KEEP = "保留"
CANDIDATE_DELETE = "候选删除"
CANDIDATE_ARCHIVE = "候选归档"
NEEDS_REVIEW = "待确认"
# ---------------------------------------------------------------------------
# 文件清单条目
# ---------------------------------------------------------------------------
@dataclass
class InventoryItem:
"""清单条目:路径 + 分类 + 处置 + 说明。"""
rel_path: str
category: Category
disposition: Disposition
description: str
# ---------------------------------------------------------------------------
# 流程树节点
# ---------------------------------------------------------------------------
@dataclass
class FlowNode:
"""流程树节点。"""
name: str # 节点名称(模块名/类名/函数名)
source_file: str # 所在源文件路径
node_type: str # 类型(entry / module / class / function)
children: list[FlowNode] = field(default_factory=list)
# ---------------------------------------------------------------------------
# 文档对齐
# ---------------------------------------------------------------------------
@dataclass
class DocMapping:
"""文档与代码的映射关系。"""
doc_path: str # 文档文件路径
doc_topic: str # 文档主题
related_code: list[str] # 关联的代码文件/模块
status: str # 状态(aligned / stale / conflict / orphan)
@dataclass
class AlignmentIssue:
"""对齐问题。"""
doc_path: str # 文档路径
issue_type: str # stale / conflict / missing
description: str # 问题描述
related_code: str # 关联代码路径
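这些共享数据模型的用法可以用一个最小示意说明(字段取值均为假设值):

```python
# 用法示意:构造一个两层流程树节点和一条清单记录(取值为假设)
from scripts.audit import Category, Disposition, FlowNode, InventoryItem

tree = FlowNode(
    name="cli.main",
    source_file="cli/main.py",
    node_type="entry",
    children=[
        FlowNode(name="config.settings", source_file="config/settings.py", node_type="module"),
    ],
)
item = InventoryItem(
    rel_path="tmp/debug_dump.py",
    category=Category.TEMP_DEBUG,
    disposition=Disposition.CANDIDATE_ARCHIVE,
    description="临时 Python 脚本,可能有参考价值",
)
print(tree.children[0].name, item.disposition.value)
```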

617
scripts/audit/doc_alignment_analyzer.py Normal file
View File

@@ -0,0 +1,617 @@
# -*- coding: utf-8 -*-
"""
文档对齐分析器 — 检查文档与代码之间的映射关系、过期点、冲突点和缺失点。
文档来源:
- docs/ 目录(.md, .txt, .csv, .json)
- 根目录 README.md
- 开发笔记/ 目录
- 各模块内的 README.md
- .kiro/steering/ 引导文件
- docs/test-json-doc/ API 响应样本
"""
from __future__ import annotations
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from scripts.audit import AlignmentIssue, DocMapping
# ---------------------------------------------------------------------------
# 常量
# ---------------------------------------------------------------------------
# 文档文件扩展名
_DOC_EXTENSIONS = {".md", ".txt", ".csv"}
# 核心代码目录——缺少文档时应报告
_CORE_CODE_DIRS = {
"tasks",
"loaders",
"orchestration",
"quality",
"models",
"utils",
"api",
"scd",
"config",
"database",
}
# ODS 表中的通用元数据列,比对时忽略
_ODS_META_COLUMNS = {"content_hash", "payload", "created_at", "updated_at", "id"}
# SQL 关键字,解析 DDL 列名时排除
_SQL_KEYWORDS = {
"primary", "key", "not", "null", "default", "unique", "check",
"references", "foreign", "constraint", "index", "create", "table",
"if", "exists", "serial", "bigserial", "true", "false",
}
# ---------------------------------------------------------------------------
# 安全读取文件(编码回退)
# ---------------------------------------------------------------------------
def _safe_read(path: Path) -> str:
"""尝试以 utf-8 → gbk → latin-1 回退读取文件内容。"""
for enc in ("utf-8", "gbk", "latin-1"):
try:
return path.read_text(encoding=enc)
except (UnicodeDecodeError, UnicodeError):
continue
return ""
# ---------------------------------------------------------------------------
# scan_docs — 扫描所有文档来源
# ---------------------------------------------------------------------------
def scan_docs(repo_root: Path) -> list[str]:
"""扫描所有文档文件路径,返回相对路径列表(已排序)。
文档来源:
1. docs/ 目录下的 .md, .txt, .csv, .json 文件
2. 根目录 README.md
3. 开发笔记/ 目录
4. 各模块内的 README.md(如 gui/README.md)
5. .kiro/steering/ 引导文件
"""
results: list[str] = []
def _rel(p: Path) -> str:
"""返回归一化的正斜杠相对路径。"""
return str(p.relative_to(repo_root)).replace("\\", "/")
# 1. docs/ 目录(递归,含 test-json-doc 下的 .json)
docs_dir = repo_root / "docs"
if docs_dir.is_dir():
for p in docs_dir.rglob("*"):
if p.is_file():
ext = p.suffix.lower()
if ext in _DOC_EXTENSIONS or ext == ".json":
results.append(_rel(p))
# 2. 根目录 README.md
root_readme = repo_root / "README.md"
if root_readme.is_file():
results.append("README.md")
# 3. 开发笔记/
dev_notes = repo_root / "开发笔记"
if dev_notes.is_dir():
for p in dev_notes.rglob("*"):
if p.is_file():
results.append(_rel(p))
# 4. 各模块内的 README.md
for child in sorted(repo_root.iterdir()):
if child.is_dir() and child.name not in ("docs", "开发笔记", ".kiro"):
readme = child / "README.md"
if readme.is_file():
results.append(_rel(readme))
# 5. .kiro/steering/
steering_dir = repo_root / ".kiro" / "steering"
if steering_dir.is_dir():
for p in sorted(steering_dir.iterdir()):
if p.is_file():
results.append(_rel(p))
return sorted(set(results))
# ---------------------------------------------------------------------------
# extract_code_references — 从文档提取代码引用
# ---------------------------------------------------------------------------
def extract_code_references(doc_path: Path) -> list[str]:
"""从文档中提取代码引用(反引号内的文件路径、类名、函数名等)。
规则:
- 提取反引号内的内容
- 跳过单字符引用
- 跳过纯数字/版本号
- 反斜杠归一化为正斜杠
- 去重
"""
if not doc_path.is_file():
return []
text = _safe_read(doc_path)
if not text:
return []
# 提取反引号内容
backtick_refs = re.findall(r"`([^`]+)`", text)
seen: set[str] = set()
results: list[str] = []
for raw in backtick_refs:
ref = raw.strip()
# 归一化反斜杠
ref = ref.replace("\\", "/")
# 跳过单字符
if len(ref) <= 1:
continue
# 跳过纯数字和版本号
if re.fullmatch(r"[\d.]+", ref):
continue
# 去重
if ref in seen:
continue
seen.add(ref)
results.append(ref)
return results
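该函数的提取行为可用一个最小示意说明(文件名与内容均为假设;模块导入路径按 run_audit.py 的约定假定为 `scripts.audit.doc_alignment_analyzer`):

```python
# 行为示意:从一小段 Markdown 中提取反引号引用(示例文件为临时生成)
from pathlib import Path

from scripts.audit.doc_alignment_analyzer import extract_code_references

demo = Path("tmp_demo_doc.md")
demo.write_text("参见 `config/settings.py` 与 `tasks.dws`,版本号 `1.2.3` 会被跳过。", encoding="utf-8")
print(extract_code_references(demo))   # 预期:['config/settings.py', 'tasks.dws']
demo.unlink()
```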
# ---------------------------------------------------------------------------
# check_reference_validity — 检查引用有效性
# ---------------------------------------------------------------------------
def check_reference_validity(ref: str, repo_root: Path) -> bool:
"""检查文档中的代码引用是否仍然有效。
检查策略:
1. 直接作为文件/目录路径检查
2. 去掉 FQ-ETL/ 前缀后检查(兼容旧文档引用)
3. 将点号路径转为文件路径检查(如 config.settings → config/settings.py)
"""
# 1. 直接路径
if (repo_root / ref).exists():
return True
# 2. 去掉旧包名前缀(兼容历史文档)
for prefix in ("FQ-ETL/", "etl_billiards/"):
if ref.startswith(prefix):
stripped = ref[len(prefix):]
if (repo_root / stripped).exists():
return True
# 3. 点号模块路径 → 文件路径
if "." in ref and "/" not in ref:
as_path = ref.replace(".", "/") + ".py"
if (repo_root / as_path).exists():
return True
# 也可能是目录(包)
as_dir = ref.replace(".", "/")
if (repo_root / as_dir).is_dir():
return True
return False
# ---------------------------------------------------------------------------
# find_undocumented_modules — 找出缺少文档的核心代码模块
# ---------------------------------------------------------------------------
def find_undocumented_modules(
repo_root: Path,
documented: set[str],
) -> list[str]:
"""找出缺少文档的核心代码模块。
只检查 _CORE_CODE_DIRS 中的 .py 文件(排除 __init__.py)。
返回已排序的相对路径列表。
"""
undocumented: list[str] = []
for core_dir in sorted(_CORE_CODE_DIRS):
dir_path = repo_root / core_dir
if not dir_path.is_dir():
continue
for py_file in dir_path.rglob("*.py"):
if py_file.name == "__init__.py":
continue
rel = str(py_file.relative_to(repo_root))
# 归一化路径分隔符
rel = rel.replace("\\", "/")
if rel not in documented:
undocumented.append(rel)
return sorted(undocumented)
# ---------------------------------------------------------------------------
# DDL / 数据字典解析辅助函数
# ---------------------------------------------------------------------------
def _parse_ddl_tables(sql: str) -> dict[str, set[str]]:
"""从 DDL SQL 中提取表名和列名。
返回 {表名: {列名集合}} 字典。
支持带 schema 前缀的表名(如 billiards_dwd.dim_member → dim_member)。
"""
tables: dict[str, set[str]] = {}
# 匹配 CREATE TABLE [IF NOT EXISTS] [schema.]table_name (
create_re = re.compile(
r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
r"(?:\w+\.)?(\w+)\s*\(",
re.IGNORECASE,
)
for match in create_re.finditer(sql):
table_name = match.group(1)
# 找到对应的括号内容
start = match.end()
depth = 1
pos = start
while pos < len(sql) and depth > 0:
if sql[pos] == "(":
depth += 1
elif sql[pos] == ")":
depth -= 1
pos += 1
body = sql[start:pos - 1]
columns: set[str] = set()
# 逐行提取列名——取每行第一个标识符
for line in body.split("\n"):
line = line.strip().rstrip(",")
if not line:
continue
# 提取第一个单词
col_match = re.match(r"(\w+)", line)
if col_match:
col_name = col_match.group(1).lower()
# 排除 SQL 关键字
if col_name not in _SQL_KEYWORDS:
columns.add(col_name)
tables[table_name] = columns
return tables
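下面用一段虚构的 DDL 片段演示解析效果(`_parse_ddl_tables` 为本模块私有函数,导入路径同样按 run_audit.py 的约定假定):

```python
# 解析行为示意:从虚构 DDL 中提取表名与列名(SQL 关键字和 schema 前缀会被剔除)
from scripts.audit.doc_alignment_analyzer import _parse_ddl_tables

demo_ddl = """
CREATE TABLE IF NOT EXISTS billiards_dwd.dim_member (
    member_id BIGINT PRIMARY KEY,
    member_name TEXT,
    created_at TIMESTAMP DEFAULT now()
);
"""
print(_parse_ddl_tables(demo_ddl))
# 预期:{'dim_member': {'member_id', 'member_name', 'created_at'}}(集合无固定顺序)
```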
def _parse_dictionary_tables(md: str) -> dict[str, set[str]]:
"""从数据字典 Markdown 中提取表名和字段名。
约定:
- 表名出现在 ## 标题中(可能带反引号)
- 字段名出现在 Markdown 表格的第一列
- 跳过表头行(含"字段"字样)和分隔行(含 ---)
"""
tables: dict[str, set[str]] = {}
current_table: str | None = None
for line in md.split("\n"):
# 匹配 ## 标题中的表名
heading = re.match(r"^##\s+`?(\w+)`?", line)
if heading:
current_table = heading.group(1)
tables[current_table] = set()
continue
if current_table is None:
continue
# 跳过分隔行
if re.match(r"^\s*\|[-\s|]+\|\s*$", line):
continue
# 解析表格行
row_match = re.match(r"^\s*\|\s*(\S+)", line)
if row_match:
field = row_match.group(1)
# 跳过表头(含"字段"字样)
if field in ("字段",):
continue
tables[current_table].add(field)
return tables
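数据字典一侧的解析也可以用一小段虚构 Markdown 验证(导入路径同上,为假定):

```python
# 解析行为示意:从虚构的数据字典 Markdown 中提取表名与字段名
from scripts.audit.doc_alignment_analyzer import _parse_dictionary_tables

demo_md = """
## `dim_member`
| 字段 | 类型 | 说明 |
|------|------|------|
| member_id | BIGINT | 会员 ID |
| member_name | TEXT | 会员姓名 |
"""
print(_parse_dictionary_tables(demo_md))
# 预期:{'dim_member': {'member_id', 'member_name'}}
```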
# ---------------------------------------------------------------------------
# check_ddl_vs_dictionary — DDL 与数据字典比对
# ---------------------------------------------------------------------------
def check_ddl_vs_dictionary(repo_root: Path) -> list[AlignmentIssue]:
"""比对 DDL 文件与数据字典文档的覆盖度。
检查:
1. DDL 中有但字典中没有的表 → missing
2. 同名表中 DDL 有但字典没有的列 → conflict
"""
issues: list[AlignmentIssue] = []
# 收集所有 DDL 表定义
ddl_tables: dict[str, set[str]] = {}
db_dir = repo_root / "database"
if db_dir.is_dir():
for sql_file in sorted(db_dir.glob("schema_*.sql")):
content = _safe_read(sql_file)
for tbl, cols in _parse_ddl_tables(content).items():
if tbl in ddl_tables:
ddl_tables[tbl] |= cols
else:
ddl_tables[tbl] = set(cols)
# 收集所有数据字典表定义
dict_tables: dict[str, set[str]] = {}
docs_dir = repo_root / "docs"
if docs_dir.is_dir():
for dict_file in sorted(docs_dir.glob("*dictionary*.md")):
content = _safe_read(dict_file)
for tbl, fields in _parse_dictionary_tables(content).items():
if tbl in dict_tables:
dict_tables[tbl] |= fields
else:
dict_tables[tbl] = set(fields)
# 比对
for tbl, ddl_cols in sorted(ddl_tables.items()):
if tbl not in dict_tables:
issues.append(AlignmentIssue(
doc_path="docs/*dictionary*.md",
issue_type="missing",
description=f"DDL 定义了表 `{tbl}`,但数据字典中未收录",
related_code=f"database/schema_*.sql ({tbl})",
))
else:
# 检查列差异
dict_cols = dict_tables[tbl]
missing_cols = ddl_cols - dict_cols
for col in sorted(missing_cols):
issues.append(AlignmentIssue(
doc_path="docs/*dictionary*.md",
issue_type="conflict",
description=f"表 `{tbl}` 的列 `{col}` 在 DDL 中存在但数据字典中缺失",
related_code=f"database/schema_*.sql ({tbl}.{col})",
))
return issues
# ---------------------------------------------------------------------------
# check_api_samples_vs_parsers — API 样本与解析器比对
# ---------------------------------------------------------------------------
def check_api_samples_vs_parsers(repo_root: Path) -> list[AlignmentIssue]:
"""比对 API 响应样本与 ODS 表结构的一致性。
策略:
1. 扫描 docs/test-json-doc/ 下的 .json 文件
2. 提取 JSON 中的顶层字段名
3. 从 ODS DDL 中查找同名表
4. 比对字段差异(忽略 ODS 元数据列)
"""
issues: list[AlignmentIssue] = []
sample_dir = repo_root / "docs" / "test-json-doc"
if not sample_dir.is_dir():
return issues
# 收集 ODS 表定义(保留全部列,比对时忽略元数据列)
ods_tables: dict[str, set[str]] = {}
db_dir = repo_root / "database"
if db_dir.is_dir():
for sql_file in sorted(db_dir.glob("schema_*ODS*.sql")):
content = _safe_read(sql_file)
for tbl, cols in _parse_ddl_tables(content).items():
ods_tables[tbl] = cols
# 逐个样本文件比对
for json_file in sorted(sample_dir.glob("*.json")):
entity_name = json_file.stem # 文件名(不含扩展名)作为实体名
# 解析 JSON 样本
try:
content = _safe_read(json_file)
data = json.loads(content)
except (json.JSONDecodeError, ValueError):
continue
# 提取顶层字段名
sample_fields: set[str] = set()
if isinstance(data, list) and data:
# 数组格式——取第一个元素的键
first = data[0]
if isinstance(first, dict):
sample_fields = set(first.keys())
elif isinstance(data, dict):
sample_fields = set(data.keys())
if not sample_fields:
continue
# 查找匹配的 ODS 表
matched_table: str | None = None
matched_cols: set[str] = set()
for tbl, cols in ods_tables.items():
# 表名包含实体名(如 test_entity 匹配 billiards_ods.test_entity)
tbl_lower = tbl.lower()
entity_lower = entity_name.lower()
if entity_lower in tbl_lower or tbl_lower == entity_lower:
matched_table = tbl
matched_cols = cols
break
if matched_table is None:
continue
# 比对:样本中有但 ODS 表中没有的字段
extra_fields = sample_fields - matched_cols
for field in sorted(extra_fields):
issues.append(AlignmentIssue(
doc_path=f"docs/test-json-doc/{json_file.name}",
issue_type="conflict",
description=(
f"API 样本字段 `{field}` 在 ODS 表 `{matched_table}` 中未定义"
),
related_code=f"database/schema_*ODS*.sql ({matched_table})",
))
return issues
# ---------------------------------------------------------------------------
# build_mappings — 构建文档与代码的映射关系
# ---------------------------------------------------------------------------
def build_mappings(
doc_paths: list[str],
repo_root: Path,
) -> list[DocMapping]:
"""为每份文档建立与代码模块的映射关系。"""
mappings: list[DocMapping] = []
for doc_rel in doc_paths:
doc_path = repo_root / doc_rel
refs = extract_code_references(doc_path)
# 确定关联代码和状态
valid_refs: list[str] = []
has_stale = False
for ref in refs:
if check_reference_validity(ref, repo_root):
valid_refs.append(ref)
else:
has_stale = True
# 推断文档主题(取文件名或第一行标题)
topic = _infer_topic(doc_path, doc_rel)
if not refs:
status = "orphan"
elif has_stale:
status = "stale"
else:
status = "aligned"
mappings.append(DocMapping(
doc_path=doc_rel,
doc_topic=topic,
related_code=valid_refs,
status=status,
))
return mappings
def _infer_topic(doc_path: Path, doc_rel: str) -> str:
"""从文档推断主题——优先取 Markdown 一级标题,否则用文件名。"""
if doc_path.is_file() and doc_path.suffix.lower() in (".md", ".txt"):
try:
text = _safe_read(doc_path)
for line in text.split("\n"):
line = line.strip()
if line.startswith("# "):
return line[2:].strip()
except Exception:
pass
return doc_rel
# ---------------------------------------------------------------------------
# render_alignment_report — 生成 Markdown 格式的文档对齐报告
# ---------------------------------------------------------------------------
def render_alignment_report(
mappings: list[DocMapping],
issues: list[AlignmentIssue],
repo_root: str,
) -> str:
"""生成 Markdown 格式的文档对齐报告。
分区:映射关系表、过期点列表、冲突点列表、缺失点列表、统计摘要。
"""
lines: list[str] = []
# --- 头部 ---
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
lines.append("# 文档对齐报告")
lines.append("")
lines.append(f"- 生成时间:{now}")
lines.append(f"- 仓库路径:`{repo_root}`")
lines.append("")
# --- 映射关系 ---
lines.append("## 映射关系")
lines.append("")
if mappings:
lines.append("| 文档路径 | 主题 | 关联代码 | 状态 |")
lines.append("|---|---|---|---|")
for m in mappings:
code_str = ", ".join(f"`{c}`" for c in m.related_code) if m.related_code else ""
lines.append(f"| `{m.doc_path}` | {m.doc_topic} | {code_str} | {m.status} |")
else:
lines.append("未发现文档映射关系。")
lines.append("")
# --- 按 issue_type 分组 ---
stale = [i for i in issues if i.issue_type == "stale"]
conflict = [i for i in issues if i.issue_type == "conflict"]
missing = [i for i in issues if i.issue_type == "missing"]
# --- 过期点 ---
lines.append("## 过期点")
lines.append("")
if stale:
lines.append("| 文档路径 | 描述 | 关联代码 |")
lines.append("|---|---|---|")
for i in stale:
lines.append(f"| `{i.doc_path}` | {i.description} | `{i.related_code}` |")
else:
lines.append("未发现过期点。")
lines.append("")
# --- 冲突点 ---
lines.append("## 冲突点")
lines.append("")
if conflict:
lines.append("| 文档路径 | 描述 | 关联代码 |")
lines.append("|---|---|---|")
for i in conflict:
lines.append(f"| `{i.doc_path}` | {i.description} | `{i.related_code}` |")
else:
lines.append("未发现冲突点。")
lines.append("")
# --- 缺失点 ---
lines.append("## 缺失点")
lines.append("")
if missing:
lines.append("| 文档路径 | 描述 | 关联代码 |")
lines.append("|---|---|---|")
for i in missing:
lines.append(f"| `{i.doc_path}` | {i.description} | `{i.related_code}` |")
else:
lines.append("未发现缺失点。")
lines.append("")
# --- 统计摘要 ---
lines.append("## 统计摘要")
lines.append("")
lines.append(f"- 文档总数:{len(mappings)}")
lines.append(f"- 过期点数量:{len(stale)}")
lines.append(f"- 冲突点数量:{len(conflict)}")
lines.append(f"- 缺失点数量:{len(missing)}")
lines.append("")
return "\n".join(lines)

618
scripts/audit/flow_analyzer.py Normal file
View File

@@ -0,0 +1,618 @@
# -*- coding: utf-8 -*-
"""
流程树分析器 — 通过静态分析 Python 源码的 import 语句和类继承关系,
构建从入口到末端模块的调用树。
仅执行只读操作:读取并解析 Python 源文件,不修改任何文件。
"""
from __future__ import annotations
import ast
import logging
import re
import sys
from datetime import datetime, timezone
from pathlib import Path
from scripts.audit import FileEntry, FlowNode
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# 项目内部包名列表(顶层目录中属于项目代码的包)
# ---------------------------------------------------------------------------
_PROJECT_PACKAGES: set[str] = {
"cli", "config", "api", "database", "tasks", "loaders",
"scd", "orchestration", "quality", "models", "utils",
"gui", "scripts",
}
# ---------------------------------------------------------------------------
# 已知的第三方包和标准库顶层模块(用于排除非项目导入)
# ---------------------------------------------------------------------------
_KNOWN_THIRD_PARTY: set[str] = {
"psycopg2", "requests", "dateutil", "python_dateutil",
"dotenv", "openpyxl", "PySide6", "flask", "pyinstaller",
"PyInstaller", "hypothesis", "pytest", "_pytest", "py",
"pluggy", "pkg_resources", "setuptools", "pip", "wheel",
"tzdata", "six", "certifi", "urllib3", "charset_normalizer",
"idna", "shiboken6",
}
def _is_project_module(module_name: str) -> bool:
"""判断模块名是否属于项目内部模块。"""
top = module_name.split(".")[0]
if top in _PROJECT_PACKAGES:
return True
return False
def _is_stdlib_or_third_party(module_name: str) -> bool:
"""判断模块名是否属于标准库或已知第三方包。"""
top = module_name.split(".")[0]
if top in _KNOWN_THIRD_PARTY:
return True
# 检查标准库
if top in sys.stdlib_module_names:
return True
return False
# ---------------------------------------------------------------------------
# 文件读取(多编码回退)
# ---------------------------------------------------------------------------
def _read_source(filepath: Path) -> str | None:
"""读取 Python 源文件内容,尝试 utf-8 → gbk → latin-1 回退。
返回文件内容字符串,读取失败时返回 None。
"""
for encoding in ("utf-8", "gbk", "latin-1"):
try:
return filepath.read_text(encoding=encoding)
except (UnicodeDecodeError, UnicodeError):
continue
except (OSError, PermissionError) as exc:
logger.warning("无法读取文件 %s: %s", filepath, exc)
return None
logger.warning("无法以任何编码读取文件 %s", filepath)
return None
# ---------------------------------------------------------------------------
# 路径 ↔ 模块名转换
# ---------------------------------------------------------------------------
def _path_to_module_name(rel_path: str) -> str:
"""将相对路径转换为 Python 模块名。
例如:
- "cli/main.py""cli.main"
- "cli/__init__.py""cli"
- "tasks/dws/assistant.py""tasks.dws.assistant"
"""
p = rel_path.replace("\\", "/")
if p.endswith("/__init__.py"):
p = p[: -len("/__init__.py")]
elif p.endswith(".py"):
p = p[:-3]
return p.replace("/", ".")
def _module_to_path(module_name: str) -> str:
"""将模块名转换为相对文件路径(优先 .py 文件)。
例如:
- "cli.main""cli/main.py"
- "cli""cli/__init__.py"
"""
return module_name.replace(".", "/") + ".py"
# ---------------------------------------------------------------------------
# parse_imports — 解析 Python 文件的 import 语句
# ---------------------------------------------------------------------------
def parse_imports(filepath: Path) -> list[str]:
"""使用 ast 模块解析 Python 文件的 import 语句,返回被导入的本地模块列表。
- 仅返回项目内部模块(排除标准库和第三方包)
- 结果去重
- 语法错误或文件不存在时返回空列表
"""
if not filepath.exists():
return []
source = _read_source(filepath)
if source is None:
return []
try:
tree = ast.parse(source, filename=str(filepath))
except SyntaxError:
logger.warning("语法错误,无法解析 %s", filepath)
return []
modules: list[str] = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.name
if _is_project_module(name) and not _is_stdlib_or_third_party(name):
modules.append(name)
elif isinstance(node, ast.ImportFrom):
if node.module and node.level == 0:
name = node.module
if _is_project_module(name) and not _is_stdlib_or_third_party(name):
modules.append(name)
# 去重并保持顺序
seen: set[str] = set()
result: list[str] = []
for m in modules:
if m not in seen:
seen.add(m)
result.append(m)
return result
# ---------------------------------------------------------------------------
# build_flow_tree — 从入口递归追踪 import 链,构建流程树
# ---------------------------------------------------------------------------
def build_flow_tree(
repo_root: Path,
entry_file: str,
_visited: set[str] | None = None,
) -> FlowNode:
"""从指定入口文件出发,递归追踪 import 链,构建流程树。
Parameters
----------
repo_root : Path
仓库根目录。
entry_file : str
入口文件的相对路径(如 "cli/main.py")。
_visited : set[str] | None
内部使用,防止循环导入导致无限递归。
Returns
-------
FlowNode
以入口文件为根的流程树。
"""
is_root = _visited is None
if _visited is None:
_visited = set()
module_name = _path_to_module_name(entry_file)
node_type = "entry" if is_root else "module"
_visited.add(entry_file)
filepath = repo_root / entry_file
children: list[FlowNode] = []
if filepath.exists():
imported_modules = parse_imports(filepath)
for mod in imported_modules:
child_path = _module_to_path(mod)
# 如果 .py 文件不存在,尝试 __init__.py
if not (repo_root / child_path).exists():
alt_path = mod.replace(".", "/") + "/__init__.py"
if (repo_root / alt_path).exists():
child_path = alt_path
if child_path not in _visited:
child_node = build_flow_tree(repo_root, child_path, _visited)
children.append(child_node)
return FlowNode(
name=module_name,
source_file=entry_file,
node_type=node_type,
children=children,
)
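一个最小的调用示意(假设 `cli/main.py` 存在,导入路径按 run_audit.py 的约定假定):

```python
# 用法示意:从 CLI 入口构建流程树,打印其直接依赖
from pathlib import Path

from scripts.audit.flow_analyzer import build_flow_tree

repo = Path(".").resolve()
tree = build_flow_tree(repo, "cli/main.py")
for child in tree.children:
    print(tree.name, "->", child.name, f"({child.source_file})")
```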
# ---------------------------------------------------------------------------
# 批处理文件解析
# ---------------------------------------------------------------------------
def _parse_bat_python_target(bat_path: Path) -> str | None:
"""从批处理文件中解析 python -m 命令的目标模块名。
返回模块名(如 "cli.main"),未找到时返回 None。
"""
if not bat_path.exists():
return None
content = _read_source(bat_path)
if content is None:
return None
# 匹配 python -m module.name 或 python3 -m module.name
pattern = re.compile(r"python[3]?\s+-m\s+([\w.]+)", re.IGNORECASE)
for line in content.splitlines():
m = pattern.search(line)
if m:
return m.group(1)
return None
# ---------------------------------------------------------------------------
# 入口点识别
# ---------------------------------------------------------------------------
def discover_entry_points(repo_root: Path) -> list[dict[str, str]]:
"""识别项目的所有入口点。
返回字典列表,每个字典包含:
- type: 入口类型(CLI / GUI / 批处理 / 运维脚本)
- file: 相对路径
- description: 简要说明
识别规则:
- cli/main.py → CLI 入口
- gui/main.py → GUI 入口
- *.bat 文件 → 解析其中的 python -m 命令
- scripts/*.py(含 if __name__ == "__main__",排除 __init__.py 和 audit/ 子目录)
"""
entries: list[dict[str, str]] = []
# CLI 入口
cli_main = repo_root / "cli" / "main.py"
if cli_main.exists():
entries.append({
"type": "CLI",
"file": "cli/main.py",
"description": "CLI 主入口 (`python -m cli.main`)",
})
# GUI 入口
gui_main = repo_root / "gui" / "main.py"
if gui_main.exists():
entries.append({
"type": "GUI",
"file": "gui/main.py",
"description": "GUI 主入口 (`python -m gui.main`)",
})
# 批处理文件
for bat in sorted(repo_root.glob("*.bat")):
target = _parse_bat_python_target(bat)
desc = f"批处理脚本"
if target:
desc += f",调用 `{target}`"
entries.append({
"type": "批处理",
"file": bat.name,
"description": desc,
})
# 运维脚本scripts/ 下的 .py 文件(排除 __init__.py 和 audit/ 子目录)
scripts_dir = repo_root / "scripts"
if scripts_dir.is_dir():
for py_file in sorted(scripts_dir.glob("*.py")):
if py_file.name == "__init__.py":
continue
# 检查是否包含 if __name__ == "__main__"
source = _read_source(py_file)
if source and '__name__' in source and '__main__' in source:
rel = py_file.relative_to(repo_root).as_posix()
entries.append({
"type": "运维脚本",
"file": rel,
"description": f"运维脚本 `{py_file.name}`",
})
return entries
# ---------------------------------------------------------------------------
# 任务类型和加载器类型区分
# ---------------------------------------------------------------------------
def classify_task_type(rel_path: str) -> str:
"""根据文件路径区分任务类型。
返回值:
- "ODS 抓取任务"
- "DWD 加载任务"
- "DWS 汇总任务"
- "校验任务"
- "Schema 初始化任务"
- "任务"(无法细分时的默认值)
"""
p = rel_path.replace("\\", "/").lower()
if "verification/" in p or "verification\\" in p:
return "校验任务"
if "dws/" in p or "dws\\" in p:
return "DWS 汇总任务"
# 文件名级别判断
basename = p.rsplit("/", 1)[-1] if "/" in p else p
if basename.startswith("ods_") or basename.startswith("ods."):
return "ODS 抓取任务"
if basename.startswith("dwd_") or basename.startswith("dwd."):
return "DWD 加载任务"
if basename.startswith("dws_"):
return "DWS 汇总任务"
if "init" in basename and "schema" in basename:
return "Schema 初始化任务"
return "任务"
def classify_loader_type(rel_path: str) -> str:
"""根据文件路径区分加载器类型。
返回值:
- "维度加载器 (SCD2)"
- "事实表加载器"
- "ODS 通用加载器"
- "加载器"(无法细分时的默认值)
"""
p = rel_path.replace("\\", "/").lower()
if "dimensions/" in p or "dimensions\\" in p:
return "维度加载器 (SCD2)"
if "facts/" in p or "facts\\" in p:
return "事实表加载器"
if "ods/" in p or "ods\\" in p:
return "ODS 通用加载器"
return "加载器"
# ---------------------------------------------------------------------------
# find_orphan_modules — 找出未被任何入口直接或间接引用的 Python 模块
# ---------------------------------------------------------------------------
def find_orphan_modules(
repo_root: Path,
all_entries: list[FileEntry],
reachable: set[str],
) -> list[str]:
"""找出未被任何入口直接或间接引用的 Python 模块。
排除规则(不视为孤立):
- __init__.py 文件
- tests/ 目录下的文件
- scripts/audit/ 目录下的文件(审计脚本自身)
- 目录条目
- 非 .py 文件
- 不属于项目包的文件
返回按路径排序的孤立模块列表。
"""
orphans: list[str] = []
for entry in all_entries:
# 跳过目录
if entry.is_dir:
continue
# 只关注 .py 文件
if entry.extension != ".py":
continue
rel = entry.rel_path.replace("\\", "/")
# 排除 __init__.py
if rel.endswith("/__init__.py") or rel == "__init__.py":
continue
# 排除测试文件
if rel.startswith("tests/") or rel.startswith("tests\\"):
continue
# 排除审计脚本自身
if rel.startswith("scripts/audit/") or rel.startswith("scripts\\audit\\"):
continue
# 只检查属于项目包的文件
top_dir = rel.split("/")[0] if "/" in rel else ""
if top_dir not in _PROJECT_PACKAGES:
continue
# 不在可达集合中 → 孤立
if rel not in reachable:
orphans.append(rel)
orphans.sort()
return orphans
# ---------------------------------------------------------------------------
# 统计辅助
# ---------------------------------------------------------------------------
def _count_nodes_by_type(trees: list[FlowNode]) -> dict[str, int]:
"""递归统计流程树中各类型节点的数量。"""
counts: dict[str, int] = {"entry": 0, "module": 0, "class": 0, "function": 0}
def _walk(node: FlowNode) -> None:
t = node.node_type
counts[t] = counts.get(t, 0) + 1
for child in node.children:
_walk(child)
for tree in trees:
_walk(tree)
return counts
def _count_tasks_and_loaders(trees: list[FlowNode]) -> tuple[int, int]:
"""统计流程树中任务模块和加载器模块的数量。"""
tasks = 0
loaders = 0
seen: set[str] = set()
def _walk(node: FlowNode) -> None:
nonlocal tasks, loaders
if node.source_file in seen:
return
seen.add(node.source_file)
sf = node.source_file.replace("\\", "/")
if sf.startswith("tasks/") and not sf.endswith("__init__.py"):
base = sf.rsplit("/", 1)[-1]
if not base.startswith("base_"):
tasks += 1
if sf.startswith("loaders/") and not sf.endswith("__init__.py"):
base = sf.rsplit("/", 1)[-1]
if not base.startswith("base_"):
loaders += 1
for child in node.children:
_walk(child)
for tree in trees:
_walk(tree)
return tasks, loaders
# ---------------------------------------------------------------------------
# 类型标注辅助
# ---------------------------------------------------------------------------
def _get_type_annotation(source_file: str) -> str:
"""根据源文件路径返回类型标注字符串(用于报告中的节点标注)。"""
sf = source_file.replace("\\", "/")
if sf.startswith("tasks/"):
return f" [{classify_task_type(sf)}]"
if sf.startswith("loaders/"):
return f" [{classify_loader_type(sf)}]"
return ""
# ---------------------------------------------------------------------------
# Mermaid 图生成
# ---------------------------------------------------------------------------
def _render_mermaid(trees: list[FlowNode]) -> str:
"""生成 Mermaid 流程图代码。"""
lines: list[str] = ["```mermaid", "graph TD"]
seen_edges: set[tuple[str, str]] = set()
node_ids: dict[str, str] = {}
counter = [0]
def _node_id(name: str) -> str:
if name not in node_ids:
node_ids[name] = f"N{counter[0]}"
counter[0] += 1
return node_ids[name]
def _walk(node: FlowNode) -> None:
nid = _node_id(node.name)
annotation = _get_type_annotation(node.source_file)
label = f"{node.name}{annotation}"
# 声明节点
lines.append(f" {nid}[\"`{label}`\"]")
for child in node.children:
cid = _node_id(child.name)
edge = (nid, cid)
if edge not in seen_edges:
seen_edges.add(edge)
lines.append(f" {nid} --> {cid}")
_walk(child)
for tree in trees:
_walk(tree)
lines.append("```")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# 缩进文本树生成
# ---------------------------------------------------------------------------
def _render_text_tree(trees: list[FlowNode]) -> str:
"""生成缩进文本形式的流程树。"""
lines: list[str] = []
seen: set[str] = set()
def _walk(node: FlowNode, depth: int) -> None:
indent = " " * depth
annotation = _get_type_annotation(node.source_file)
line = f"{indent}- `{node.name}` (`{node.source_file}`){annotation}"
lines.append(line)
key = node.source_file
if key in seen:
# 已展开过,不再递归(避免循环)
if node.children:
lines.append(f"{indent} - *(已展开)*")
return
seen.add(key)
for child in node.children:
_walk(child, depth + 1)
for tree in trees:
_walk(tree, 0)
return "\n".join(lines)
# ---------------------------------------------------------------------------
# render_flow_report — 生成 Markdown 格式的流程树报告
# ---------------------------------------------------------------------------
def render_flow_report(
trees: list[FlowNode],
orphans: list[str],
repo_root: str,
) -> str:
"""生成 Markdown 格式的流程树报告(含 Mermaid 图和缩进文本)。
报告结构:
1. 头部(时间戳、仓库路径)
2. Mermaid 流程图
3. 缩进文本树
4. 孤立模块列表
5. 统计摘要
"""
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
sections: list[str] = []
# --- 头部 ---
sections.append("# 项目流程树报告\n")
sections.append(f"- 生成时间: {timestamp}")
sections.append(f"- 仓库路径: `{repo_root}`\n")
# --- Mermaid 图 ---
sections.append("## 流程图Mermaid\n")
sections.append(_render_mermaid(trees))
sections.append("")
# --- 缩进文本树 ---
sections.append("## 流程树(缩进文本)\n")
sections.append(_render_text_tree(trees))
sections.append("")
# --- 孤立模块 ---
sections.append("## 孤立模块\n")
if orphans:
for o in orphans:
sections.append(f"- `{o}`")
else:
sections.append("未发现孤立模块。")
sections.append("")
# --- 统计摘要 ---
entry_count = sum(1 for t in trees if t.node_type == "entry")
task_count, loader_count = _count_tasks_and_loaders(trees)
orphan_count = len(orphans)
sections.append("## 统计摘要\n")
sections.append(f"| 指标 | 数量 |")
sections.append(f"|------|------|")
sections.append(f"| 入口点 | {entry_count} |")
sections.append(f"| 任务 | {task_count} |")
sections.append(f"| 加载器 | {loader_count} |")
sections.append(f"| 孤立模块 | {orphan_count} |")
sections.append("")
return "\n".join(sections)

449
scripts/audit/inventory_analyzer.py Normal file
View File

@@ -0,0 +1,449 @@
# -*- coding: utf-8 -*-
"""
文件清单分析器 — 对扫描结果进行用途分类和处置标签分配。
分类规则按优先级从高到低排列:
1. tmp/ 下所有文件 → 临时与调试 / 候选删除或候选归档
2. logs/、export/ 下的运行时产出 → 日志与输出 / 候选归档
3. *.lnk、*.rar 文件 → 其他 / 候选删除
4. 空目录 → 其他 / 候选删除
5. 核心代码目录(tasks/ 等)→ 核心代码 / 保留
6. config/ → 配置 / 保留
7. database/*.sql、database/migrations/ → 数据库定义 / 保留
8. database/*.py → 核心代码 / 保留
9. tests/ → 测试 / 保留
10. docs/ → 文档 / 保留
11. scripts/ 下的 .py 文件 → 脚本工具 / 保留
12. gui/ → GUI / 保留
13. 构建与部署文件 → 构建与部署 / 保留
14. 其余 → 其他 / 待确认
"""
from __future__ import annotations
import os
from collections import Counter
from datetime import datetime, timezone
from itertools import groupby
from scripts.audit import Category, Disposition, FileEntry, InventoryItem
# ---------------------------------------------------------------------------
# 常量
# ---------------------------------------------------------------------------
# 核心代码顶层目录
_CORE_CODE_DIRS = (
"tasks/", "loaders/", "scd/", "orchestration/",
"quality/", "models/", "utils/", "api/",
)
# 构建与部署文件名(根目录级别)
_BUILD_DEPLOY_BASENAMES = {"setup.py", "build_exe.py"}
# 构建与部署扩展名
_BUILD_DEPLOY_EXTENSIONS = {".bat", ".sh", ".ps1"}
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _top_dir(rel_path: str) -> str:
"""返回相对路径的第一级目录名(含尾部斜杠),如 'tmp/foo.py''tmp/'"""
idx = rel_path.find("/")
if idx == -1:
return ""
return rel_path[: idx + 1]
def _basename(rel_path: str) -> str:
"""返回路径的最后一段文件名。"""
return rel_path.rsplit("/", 1)[-1]
def _is_init_py(rel_path: str) -> bool:
"""判断路径是否为 __init__.py。"""
return _basename(rel_path) == "__init__.py"
# ---------------------------------------------------------------------------
# classify — 核心分类函数
# ---------------------------------------------------------------------------
def classify(entry: FileEntry) -> InventoryItem:
"""根据路径、扩展名等规则对单个文件/目录进行分类和标签分配。
规则按优先级从高到低依次匹配,首个命中的规则决定分类和处置。
"""
path = entry.rel_path
top = _top_dir(path)
ext = entry.extension.lower()
base = _basename(path)
# --- 优先级 1: tmp/ 下所有文件 ---
if top == "tmp/" or path == "tmp":
return _classify_tmp(entry)
# --- 优先级 2: logs/、export/ 下的运行时产出 ---
if top in ("logs/", "export/") or path in ("logs", "export"):
return _classify_runtime_output(entry)
# --- 优先级 3: .lnk / .rar 文件 ---
if ext in (".lnk", ".rar"):
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.CANDIDATE_DELETE,
description=f"快捷方式/压缩包文件(`{ext}`),建议删除",
)
# --- 优先级 4: 空目录 ---
if entry.is_empty_dir:
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.CANDIDATE_DELETE,
description="空目录,建议删除",
)
# --- 优先级 5: 核心代码目录 ---
if any(path.startswith(d) or path + "/" == d for d in _CORE_CODE_DIRS):
return InventoryItem(
rel_path=path,
category=Category.CORE_CODE,
disposition=Disposition.KEEP,
description=f"核心代码(`{top.rstrip('/')}`",
)
# --- 优先级 6: config/ ---
if top == "config/" or path == "config":
return InventoryItem(
rel_path=path,
category=Category.CONFIG,
disposition=Disposition.KEEP,
description="配置文件",
)
# --- 优先级 7: database/*.sql 和 database/migrations/ ---
if top == "database/" or path == "database":
return _classify_database(entry)
# --- 优先级 8: tests/ ---
if top == "tests/" or path == "tests":
return InventoryItem(
rel_path=path,
category=Category.TEST,
disposition=Disposition.KEEP,
description="测试文件",
)
# --- 优先级 9: docs/ ---
if top == "docs/" or path == "docs":
return InventoryItem(
rel_path=path,
category=Category.DOCS,
disposition=Disposition.KEEP,
description="文档",
)
# --- 优先级 10: scripts/ 下的 .py 文件 ---
if top == "scripts/" or path == "scripts":
cat = Category.SCRIPTS
if ext == ".py" or entry.is_dir:
return InventoryItem(
rel_path=path,
category=cat,
disposition=Disposition.KEEP,
description="脚本工具",
)
return InventoryItem(
rel_path=path,
category=cat,
disposition=Disposition.NEEDS_REVIEW,
description="脚本目录下的非 Python 文件,需确认用途",
)
# --- 优先级 11: gui/ ---
if top == "gui/" or path == "gui":
return InventoryItem(
rel_path=path,
category=Category.GUI,
disposition=Disposition.KEEP,
description="GUI 模块",
)
# --- 优先级 12: 构建与部署 ---
if base in _BUILD_DEPLOY_BASENAMES or ext in _BUILD_DEPLOY_EXTENSIONS:
return InventoryItem(
rel_path=path,
category=Category.BUILD_DEPLOY,
disposition=Disposition.KEEP,
description="构建与部署文件",
)
# --- 优先级 13: cli/ ---
if top == "cli/" or path == "cli":
return InventoryItem(
rel_path=path,
category=Category.CORE_CODE,
disposition=Disposition.KEEP,
description="CLI 入口模块",
)
# --- 优先级 14: 已知根目录文件 ---
if "/" not in path:
return _classify_root_file(entry)
# --- 兜底 ---
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.NEEDS_REVIEW,
description="未匹配已知规则,需人工确认用途",
)
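分类规则的效果可以用两个假设的 FileEntry 演示(模块导入路径按 run_audit.py 的约定假定为 `scripts.audit.inventory_analyzer`):

```python
# 分类行为示意(FileEntry 取值均为假设)
from scripts.audit import FileEntry
from scripts.audit.inventory_analyzer import classify

samples = [
    FileEntry(rel_path="tasks/dws/assistant.py", is_dir=False, size_bytes=2048,
              extension=".py", is_empty_dir=False),
    FileEntry(rel_path="tmp/old_dump.json", is_dir=False, size_bytes=512,
              extension=".json", is_empty_dir=False),
]
for e in samples:
    item = classify(e)
    print(e.rel_path, "->", item.category.value, "/", item.disposition.value)
# 预期:tasks/... → 核心代码 / 保留;tmp/... → 临时与调试 / 候选删除
```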
# ---------------------------------------------------------------------------
# 子分类函数
# ---------------------------------------------------------------------------
def _classify_tmp(entry: FileEntry) -> InventoryItem:
"""tmp/ 目录下的文件分类。
默认候选删除;有意义的 .py 文件标记为候选归档。
"""
ext = entry.extension.lower()
base = _basename(entry.rel_path)
# 空目录直接候选删除
if entry.is_empty_dir:
return InventoryItem(
rel_path=entry.rel_path,
category=Category.TEMP_DEBUG,
disposition=Disposition.CANDIDATE_DELETE,
description="临时目录下的空目录",
)
# .py 文件可能有参考价值 → 候选归档
if ext == ".py" and len(base) > 4:
return InventoryItem(
rel_path=entry.rel_path,
category=Category.TEMP_DEBUG,
disposition=Disposition.CANDIDATE_ARCHIVE,
description="临时 Python 脚本,可能有参考价值",
)
return InventoryItem(
rel_path=entry.rel_path,
category=Category.TEMP_DEBUG,
disposition=Disposition.CANDIDATE_DELETE,
description="临时/调试文件,建议删除",
)
def _classify_runtime_output(entry: FileEntry) -> InventoryItem:
"""logs/、export/ 目录下的运行时产出分类。
__init__.py 保留(包标记),其余候选归档。
"""
if _is_init_py(entry.rel_path):
return InventoryItem(
rel_path=entry.rel_path,
category=Category.LOG_OUTPUT,
disposition=Disposition.KEEP,
description="包初始化文件",
)
return InventoryItem(
rel_path=entry.rel_path,
category=Category.LOG_OUTPUT,
disposition=Disposition.CANDIDATE_ARCHIVE,
description="运行时产出,建议归档",
)
def _classify_database(entry: FileEntry) -> InventoryItem:
"""database/ 目录下的文件分类。"""
path = entry.rel_path
ext = entry.extension.lower()
# migrations/ 子目录
if "migrations/" in path or path.endswith("migrations"):
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.KEEP,
description="数据库迁移脚本",
)
# .sql 文件
if ext == ".sql":
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.KEEP,
description="数据库 DDL/DML 脚本",
)
# .py 文件 → 核心代码
if ext == ".py":
return InventoryItem(
rel_path=path,
category=Category.CORE_CODE,
disposition=Disposition.KEEP,
description="数据库操作模块",
)
# 目录本身
if entry.is_dir:
if entry.is_empty_dir:
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.CANDIDATE_DELETE,
description="数据库目录下的空目录",
)
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.KEEP,
description="数据库子目录",
)
# 其他文件
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.NEEDS_REVIEW,
description="数据库目录下的非标准文件,需确认",
)
def _classify_root_file(entry: FileEntry) -> InventoryItem:
"""根目录散落文件的分类。"""
ext = entry.extension.lower()
base = _basename(entry.rel_path)
# 已知构建文件
if base in _BUILD_DEPLOY_BASENAMES or ext in _BUILD_DEPLOY_EXTENSIONS:
return InventoryItem(
rel_path=entry.rel_path,
category=Category.BUILD_DEPLOY,
disposition=Disposition.KEEP,
description="构建与部署文件",
)
# 已知配置文件
if base in (
"requirements.txt", "pytest.ini", ".env", ".env.example",
".gitignore", ".flake8", "pyproject.toml",
):
return InventoryItem(
rel_path=entry.rel_path,
category=Category.CONFIG,
disposition=Disposition.KEEP,
description="项目配置文件",
)
# README
if base.lower().startswith("readme"):
return InventoryItem(
rel_path=entry.rel_path,
category=Category.DOCS,
disposition=Disposition.KEEP,
description="项目说明文档",
)
# 其他根目录文件 → 待确认
return InventoryItem(
rel_path=entry.rel_path,
category=Category.OTHER,
disposition=Disposition.NEEDS_REVIEW,
description=f"根目录散落文件(`{base}`),需确认用途",
)
# ---------------------------------------------------------------------------
# build_inventory — 批量分类
# ---------------------------------------------------------------------------
def build_inventory(entries: list[FileEntry]) -> list[InventoryItem]:
"""对所有文件条目执行分类,返回清单列表。"""
return [classify(e) for e in entries]
# ---------------------------------------------------------------------------
# render_inventory_report — Markdown 渲染
# ---------------------------------------------------------------------------
def render_inventory_report(items: list[InventoryItem], repo_root: str) -> str:
"""生成 Markdown 格式的文件清单报告。
报告结构:
- 头部:标题、生成时间、仓库路径
- 主体:按 Category 分组的表格
- 尾部:统计摘要
"""
lines: list[str] = []
# --- 头部 ---
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
lines.append("# 文件清单报告")
lines.append("")
lines.append(f"- 生成时间:{now}")
lines.append(f"- 仓库路径:`{repo_root}`")
lines.append("")
# --- 按分类分组 ---
# 保持 Category 枚举定义顺序
cat_order = {c: i for i, c in enumerate(Category)}
sorted_items = sorted(items, key=lambda it: cat_order[it.category])
for cat, group in groupby(sorted_items, key=lambda it: it.category):
group_list = list(group)
lines.append(f"## {cat.value}")
lines.append("")
lines.append("| 相对路径 | 处置标签 | 简要说明 |")
lines.append("|---|---|---|")
for item in group_list:
lines.append(
f"| `{item.rel_path}` | {item.disposition.value} | {item.description} |"
)
lines.append("")
# --- 统计摘要 ---
lines.append("## 统计摘要")
lines.append("")
# 各分类计数
cat_counter: Counter[Category] = Counter()
disp_counter: Counter[Disposition] = Counter()
for item in items:
cat_counter[item.category] += 1
disp_counter[item.disposition] += 1
lines.append("### 按用途分类")
lines.append("")
lines.append("| 分类 | 数量 |")
lines.append("|---|---|")
for cat in Category:
count = cat_counter.get(cat, 0)
if count > 0:
lines.append(f"| {cat.value} | {count} |")
lines.append("")
lines.append("### 按处置标签")
lines.append("")
lines.append("| 标签 | 数量 |")
lines.append("|---|---|")
for disp in Disposition:
count = disp_counter.get(disp, 0)
if count > 0:
lines.append(f"| {disp.value} | {count} |")
lines.append("")
lines.append(f"**总计:{len(items)} 个条目**")
lines.append("")
return "\n".join(lines)

255
scripts/audit/run_audit.py Normal file
View File

@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
"""
审计主入口 — 依次调用扫描器和三个分析器,生成三份报告到 docs/audit/。
仅在 docs/audit/ 目录下创建文件,不修改仓库中的任何现有文件。
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timezone
from pathlib import Path
from scripts.audit.scanner import scan_repo
from scripts.audit.inventory_analyzer import (
build_inventory,
render_inventory_report,
)
from scripts.audit.flow_analyzer import (
build_flow_tree,
discover_entry_points,
find_orphan_modules,
render_flow_report,
)
from scripts.audit.doc_alignment_analyzer import (
build_mappings,
check_api_samples_vs_parsers,
check_ddl_vs_dictionary,
find_undocumented_modules,
render_alignment_report,
scan_docs,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# 仓库根目录自动检测
# ---------------------------------------------------------------------------
def _detect_repo_root() -> Path:
"""从当前文件向上查找仓库根目录。
判断依据:包含 cli/ 目录或 .git/ 目录的祖先目录。
"""
current = Path(__file__).resolve().parent
for parent in (current, *current.parents):
if (parent / "cli").is_dir() or (parent / ".git").is_dir():
return parent
# 回退:假设 scripts/audit/ 在仓库根目录下
return current.parent.parent
# ---------------------------------------------------------------------------
# 报告输出目录
# ---------------------------------------------------------------------------
def _ensure_report_dir(repo_root: Path) -> Path:
"""检查并创建 docs/audit/ 目录。
如果目录已存在则直接返回;不存在则创建。
创建失败时抛出 RuntimeError(因为无法输出报告)。
"""
audit_dir = repo_root / "docs" / "audit"
if audit_dir.is_dir():
return audit_dir
try:
audit_dir.mkdir(parents=True, exist_ok=True)
except OSError as exc:
raise RuntimeError(f"无法创建报告输出目录 {audit_dir}: {exc}") from exc
logger.info("已创建报告输出目录: %s", audit_dir)
return audit_dir
# ---------------------------------------------------------------------------
# 报告头部元信息注入
# ---------------------------------------------------------------------------
_HEADER_PATTERN = re.compile(r"生成时间[:]")
_ISO_TS_PATTERN = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")
# 匹配非 ISO 格式的时间戳行,用于替换
_NON_ISO_TS_LINE = re.compile(
r"([-*]\s*生成时间[:]\s*)\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}"
)
def _inject_header(report: str, timestamp: str, repo_path: str) -> str:
"""确保报告头部包含 ISO 格式时间戳和仓库路径。
- 已有 ISO 时间戳 → 不修改
- 有非 ISO 时间戳 → 替换为 ISO 格式
- 无头部 → 在标题后注入
"""
if _HEADER_PATTERN.search(report):
# 已有头部——检查时间戳格式是否为 ISO
if _ISO_TS_PATTERN.search(report):
return report
# 非 ISO 格式 → 替换时间戳
report = _NON_ISO_TS_LINE.sub(
lambda m: m.group(1) + timestamp, report,
)
# 同时确保仓库路径使用统一值(用 lambda 避免反斜杠转义问题)
safe_path = repo_path
report = re.sub(
r"([-*]\s*仓库路径[:]\s*)`[^`]*`",
lambda m: m.group(1) + "`" + safe_path + "`",
report,
)
return report
# 无头部 → 在第一个标题行之后插入
lines = report.split("\n")
insert_idx = 1
for i, line in enumerate(lines):
if line.startswith("# "):
insert_idx = i + 1
break
header_lines = [
"",
f"- 生成时间: {timestamp}",
f"- 仓库路径: `{repo_path}`",
"",
]
lines[insert_idx:insert_idx] = header_lines
return "\n".join(lines)
# ---------------------------------------------------------------------------
# 主函数
# ---------------------------------------------------------------------------
def run_audit(repo_root: Path | None = None) -> None:
"""执行完整审计流程,生成三份报告到 docs/audit/。
Parameters
----------
repo_root : Path | None
仓库根目录。为 None 时自动检测。
"""
# 1. 确定仓库根目录
if repo_root is None:
repo_root = _detect_repo_root()
repo_root = repo_root.resolve()
repo_path_str = str(repo_root)
logger.info("审计开始 — 仓库路径: %s", repo_path_str)
# 2. 检查/创建输出目录
audit_dir = _ensure_report_dir(repo_root)
# 3. 生成 UTC 时间戳(所有报告共用)
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
# 4. 扫描仓库
logger.info("正在扫描仓库文件...")
entries = scan_repo(repo_root)
logger.info("扫描完成,共 %d 个条目", len(entries))
# 5. 文件清单报告
logger.info("正在生成文件清单报告...")
try:
inventory_items = build_inventory(entries)
inventory_report = render_inventory_report(inventory_items, repo_path_str)
inventory_report = _inject_header(inventory_report, timestamp, repo_path_str)
(audit_dir / "file_inventory.md").write_text(
inventory_report, encoding="utf-8",
)
logger.info("文件清单报告已写入: file_inventory.md")
except Exception:
logger.exception("生成文件清单报告时出错")
# 6. 流程树报告
logger.info("正在生成流程树报告...")
try:
entry_points = discover_entry_points(repo_root)
trees = []
reachable: set[str] = set()
for ep in entry_points:
ep_file = ep["file"]
# 批处理文件不构建流程树
if not ep_file.endswith(".py"):
continue
tree = build_flow_tree(repo_root, ep_file)
trees.append(tree)
# 收集可达模块
_collect_reachable(tree, reachable)
orphans = find_orphan_modules(repo_root, entries, reachable)
flow_report = render_flow_report(trees, orphans, repo_path_str)
flow_report = _inject_header(flow_report, timestamp, repo_path_str)
(audit_dir / "flow_tree.md").write_text(
flow_report, encoding="utf-8",
)
logger.info("流程树报告已写入: flow_tree.md")
except Exception:
logger.exception("生成流程树报告时出错")
# 7. 文档对齐报告
logger.info("正在生成文档对齐报告...")
try:
doc_paths = scan_docs(repo_root)
mappings = build_mappings(doc_paths, repo_root)
issues = []
issues.extend(check_ddl_vs_dictionary(repo_root))
issues.extend(check_api_samples_vs_parsers(repo_root))
# 缺失文档检测
documented: set[str] = set()
for m in mappings:
documented.update(m.related_code)
undoc_modules = find_undocumented_modules(repo_root, documented)
from scripts.audit import AlignmentIssue
for mod in undoc_modules:
issues.append(AlignmentIssue(
doc_path="",
issue_type="missing",
description=f"核心代码模块 `{mod}` 缺少对应文档",
related_code=mod,
))
alignment_report = render_alignment_report(mappings, issues, repo_path_str)
alignment_report = _inject_header(alignment_report, timestamp, repo_path_str)
(audit_dir / "doc_alignment.md").write_text(
alignment_report, encoding="utf-8",
)
logger.info("文档对齐报告已写入: doc_alignment.md")
except Exception:
logger.exception("生成文档对齐报告时出错")
logger.info("审计完成 — 报告输出目录: %s", audit_dir)
# ---------------------------------------------------------------------------
# 辅助:收集可达模块
# ---------------------------------------------------------------------------
def _collect_reachable(node, reachable: set[str]) -> None:
"""递归收集流程树中所有节点的 source_file。"""
reachable.add(node.source_file)
for child in node.children:
_collect_reachable(child, reachable)
# ---------------------------------------------------------------------------
# 入口
# ---------------------------------------------------------------------------
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
run_audit()

150
scripts/audit/scanner.py Normal file
View File

@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
"""
仓库扫描器 — 递归遍历仓库文件系统,返回结构化的文件元信息。
仅执行只读操作:读取文件元信息(大小、类型),不修改任何文件。
遇到权限错误时跳过并记录日志,不中断扫描流程。
"""
from __future__ import annotations
import fnmatch
import logging
from pathlib import Path
from scripts.audit import FileEntry
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# 排除模式
# ---------------------------------------------------------------------------
EXCLUDED_PATTERNS: list[str] = [
".git",
"__pycache__",
".pytest_cache",
"*.pyc",
".kiro",
]
# ---------------------------------------------------------------------------
# 排除匹配逻辑
# ---------------------------------------------------------------------------
def _is_excluded(name: str, patterns: list[str]) -> bool:
"""判断文件/目录名是否匹配任一排除模式。
支持两种模式:
- 精确匹配(如 ".git"、"__pycache__")
- 通配符匹配(如 "*.pyc"),使用 fnmatch 语义
"""
for pat in patterns:
if fnmatch.fnmatch(name, pat):
return True
return False
# ---------------------------------------------------------------------------
# 递归遍历
# ---------------------------------------------------------------------------
def _walk(
root: Path,
base: Path,
exclude: list[str],
results: list[FileEntry],
) -> None:
"""递归遍历 *root* 下的文件和目录,将结果追加到 *results*。
Parameters
----------
root : Path
当前要遍历的目录。
base : Path
仓库根目录,用于计算相对路径。
exclude : list[str]
排除模式列表。
results : list[FileEntry]
收集结果的列表(就地修改)。
"""
try:
children = sorted(root.iterdir(), key=lambda p: p.name)
except (PermissionError, OSError) as exc:
logger.warning("无法读取目录 %s: %s", root, exc)
return
# 用于判断当前目录是否为"空目录"(排除后无可见子项)
visible_count = 0
for child in children:
if _is_excluded(child.name, exclude):
continue
visible_count += 1
rel = child.relative_to(base).as_posix()
if child.is_dir():
# 先递归子目录,再判断该目录是否为空
sub_start = len(results)
_walk(child, base, exclude, results)
sub_end = len(results)
# 该目录下递归产生的条目数为 0 → 空目录
is_empty = (sub_end == sub_start)
results.append(FileEntry(
rel_path=rel,
is_dir=True,
size_bytes=0,
extension="",
is_empty_dir=is_empty,
))
else:
# 文件
try:
size = child.stat().st_size
except (PermissionError, OSError) as exc:
logger.warning("无法获取文件信息 %s: %s", child, exc)
continue
results.append(FileEntry(
rel_path=rel,
is_dir=False,
size_bytes=size,
extension=child.suffix.lower(),
is_empty_dir=False,
))
# 如果 root 是仓库根目录自身,不需要额外处理
# (根目录不作为条目出现在结果中)
def scan_repo(
root: Path,
exclude: list[str] | None = None,
) -> list[FileEntry]:
"""递归扫描仓库,返回所有文件和目录的元信息列表。
Parameters
----------
root : Path
仓库根目录路径。
exclude : list[str] | None
排除模式列表,默认使用 EXCLUDED_PATTERNS。
Returns
-------
list[FileEntry]
按 rel_path 排序的文件/目录元信息列表。
"""
if exclude is None:
exclude = EXCLUDED_PATTERNS
results: list[FileEntry] = []
_walk(root, root, exclude, results)
# 按相对路径排序,保证输出稳定
results.sort(key=lambda e: e.rel_path)
return results
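一个最小的调用示意(假设在仓库根目录运行):

```python
# 用法示意:扫描当前仓库并统计文件与目录数量
from pathlib import Path

from scripts.audit.scanner import scan_repo

entries = scan_repo(Path(".").resolve())
n_dirs = sum(1 for e in entries if e.is_dir)
print(f"共 {len(entries)} 个条目,其中目录 {n_dirs} 个")
```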

193
scripts/check/check_data_integrity.py Normal file
View File

@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
"""Run data integrity checks across API -> ODS -> DWD."""
from __future__ import annotations
import argparse
import sys
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from config.settings import AppConfig
from quality.integrity_service import run_history_flow, run_window_flow, write_report
from utils.logging_utils import build_log_path, configure_logging
from utils.windowing import split_window
def _parse_dt(value: str, tz: ZoneInfo) -> datetime:
dt = dtparser.parse(value)
if dt.tzinfo is None:
return dt.replace(tzinfo=tz)
return dt.astimezone(tz)
def main() -> int:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
ap = argparse.ArgumentParser(description="Data integrity checks (API -> ODS -> DWD)")
ap.add_argument("--mode", choices=["history", "window"], default="history")
ap.add_argument(
"--flow",
choices=["verify", "update_and_verify"],
default="verify",
help="verify only or update+verify (auto backfill then optional recheck)",
)
ap.add_argument("--start", default="2025-07-01", help="history start date (default: 2025-07-01)")
ap.add_argument("--end", default="", help="history end datetime (default: last ETL end)")
ap.add_argument("--window-start", default="", help="window start datetime (mode=window)")
ap.add_argument("--window-end", default="", help="window end datetime (mode=window)")
ap.add_argument("--window-split-unit", default="", help="split unit (month/none), default from config")
ap.add_argument("--window-compensation-hours", type=int, default=None, help="window compensation hours, default from config")
ap.add_argument(
"--include-dimensions",
action="store_true",
default=None,
help="include dimension tables in ODS->DWD checks",
)
ap.add_argument(
"--no-include-dimensions",
action="store_true",
help="exclude dimension tables in ODS->DWD checks",
)
ap.add_argument("--ods-task-codes", default="", help="comma-separated ODS task codes for API checks")
ap.add_argument("--compare-content", action="store_true", help="compare API vs ODS content hash")
ap.add_argument("--no-compare-content", action="store_true", help="disable content comparison even if enabled in config")
ap.add_argument("--include-mismatch", action="store_true", help="backfill mismatch records as well")
ap.add_argument("--no-include-mismatch", action="store_true", help="disable mismatch backfill")
ap.add_argument("--recheck", action="store_true", help="re-run checks after backfill")
ap.add_argument("--no-recheck", action="store_true", help="skip recheck after backfill")
ap.add_argument("--content-sample-limit", type=int, default=None, help="max mismatch samples per table")
ap.add_argument("--out", default="", help="output JSON path")
ap.add_argument("--log-file", default="", help="log file path")
ap.add_argument("--log-dir", default="", help="log directory")
ap.add_argument("--log-level", default="INFO", help="log level")
ap.add_argument("--no-log-console", action="store_true", help="disable console logging")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (Path(__file__).resolve().parent / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "data_integrity")
log_console = not args.no_log_console
with configure_logging(
"data_integrity",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
report_path = Path(args.out) if args.out else None
if args.recheck and args.no_recheck:
raise SystemExit("cannot set both --recheck and --no-recheck")
if args.include_mismatch and args.no_include_mismatch:
raise SystemExit("cannot set both --include-mismatch and --no-include-mismatch")
if args.include_dimensions and args.no_include_dimensions:
raise SystemExit("cannot set both --include-dimensions and --no-include-dimensions")
compare_content = None
if args.compare_content and args.no_compare_content:
raise SystemExit("cannot set both --compare-content and --no-compare-content")
if args.compare_content:
compare_content = True
elif args.no_compare_content:
compare_content = False
include_mismatch = cfg.get("integrity.backfill_mismatch", True)
if args.include_mismatch:
include_mismatch = True
elif args.no_include_mismatch:
include_mismatch = False
recheck_after_backfill = cfg.get("integrity.recheck_after_backfill", True)
if args.recheck:
recheck_after_backfill = True
elif args.no_recheck:
recheck_after_backfill = False
include_dimensions = cfg.get("integrity.include_dimensions", True)
if args.include_dimensions:
include_dimensions = True
elif args.no_include_dimensions:
include_dimensions = False
if args.mode == "window":
if not args.window_start or not args.window_end:
raise SystemExit("window-start and window-end are required for mode=window")
start_dt = _parse_dt(args.window_start, tz)
end_dt = _parse_dt(args.window_end, tz)
split_unit = (args.window_split_unit or cfg.get("run.window_split.unit", "month") or "month").strip()
comp_hours = args.window_compensation_hours
if comp_hours is None:
comp_hours = cfg.get("run.window_split.compensation_hours", 0)
windows = split_window(
start_dt,
end_dt,
tz=tz,
split_unit=split_unit,
compensation_hours=comp_hours,
)
if not windows:
windows = [(start_dt, end_dt)]
report, counts = run_window_flow(
cfg=cfg,
windows=windows,
include_dimensions=bool(include_dimensions),
task_codes=args.ods_task_codes,
logger=logger,
compare_content=compare_content,
content_sample_limit=args.content_sample_limit,
do_backfill=args.flow == "update_and_verify",
include_mismatch=bool(include_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(cfg.get("api.page_size") or 200),
chunk_size=500,
)
report_path = write_report(report, prefix="data_integrity_window", tz=tz, report_path=report_path)
report["report_path"] = report_path
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
else:
start_dt = _parse_dt(args.start, tz)
if args.end:
end_dt = _parse_dt(args.end, tz)
else:
end_dt = None
report, counts = run_history_flow(
cfg=cfg,
start_dt=start_dt,
end_dt=end_dt,
include_dimensions=bool(include_dimensions),
task_codes=args.ods_task_codes,
logger=logger,
compare_content=compare_content,
content_sample_limit=args.content_sample_limit,
do_backfill=args.flow == "update_and_verify",
include_mismatch=bool(include_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(cfg.get("api.page_size") or 200),
chunk_size=500,
)
report_path = write_report(report, prefix="data_integrity_history", tz=tz, report_path=report_path)
report["report_path"] = report_path
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
logger.info(
"SUMMARY missing=%s mismatch=%s errors=%s",
counts.get("missing"),
counts.get("mismatch"),
counts.get("errors"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
import sys
sys.path.insert(0, '.')
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
config = AppConfig.load()
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)
# 检查DWD层服务记录分布
print("=== DWD层服务记录分析 ===")
print()
# 1. 总体统计
sql1 = """
SELECT
COUNT(*) as total_records,
COUNT(DISTINCT tenant_member_id) as unique_members,
COUNT(DISTINCT site_assistant_id) as unique_assistants,
COUNT(DISTINCT (tenant_member_id, site_assistant_id)) as unique_pairs
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
"""
r = dict(db.query(sql1)[0])
print("总体统计:")
print(f" 总服务记录数: {r['total_records']}")
print(f" 唯一会员数: {r['unique_members']}")
print(f" 唯一助教数: {r['unique_assistants']}")
print(f" 唯一客户-助教对: {r['unique_pairs']}")
# 2. 助教服务会员数分布
print()
print("助教服务会员数分布 (Top 10):")
sql2 = """
SELECT site_assistant_id, COUNT(DISTINCT tenant_member_id) as member_count
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
GROUP BY site_assistant_id
ORDER BY member_count DESC
LIMIT 10
"""
for row in db.query(sql2):
r = dict(row)
print(f" 助教 {r['site_assistant_id']}: 服务 {r['member_count']} 个会员")
# 3. 每个客户-助教对的服务次数分布
print()
print("客户-助教对 服务次数分布 (Top 10):")
sql3 = """
SELECT tenant_member_id, site_assistant_id, COUNT(*) as service_count
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
GROUP BY tenant_member_id, site_assistant_id
ORDER BY service_count DESC
LIMIT 10
"""
for row in db.query(sql3):
r = dict(row)
print(f" 会员 {r['tenant_member_id']} - 助教 {r['site_assistant_id']}: {r['service_count']} 次服务")
# 4. 近60天的数据
print()
print("=== 近60天数据 ===")
sql4 = """
SELECT
COUNT(*) as total_records,
COUNT(DISTINCT tenant_member_id) as unique_members,
COUNT(DISTINCT site_assistant_id) as unique_assistants,
COUNT(DISTINCT (tenant_member_id, site_assistant_id)) as unique_pairs
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
AND last_use_time >= NOW() - INTERVAL '60 days'
"""
r4 = dict(db.query(sql4)[0])
print(f" 总服务记录数: {r4['total_records']}")
print(f" 唯一会员数: {r4['unique_members']}")
print(f" 唯一助教数: {r4['unique_assistants']}")
print(f" 唯一客户-助教对: {r4['unique_pairs']}")
db_conn.close()

View File

@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
"""
Validate that ODS payload content matches stored content_hash.
Usage:
PYTHONPATH=. python -m scripts.check.check_ods_content_hash
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --schema billiards_ods
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence
from psycopg2.extras import RealDictCursor
PROJECT_ROOT = Path(__file__).resolve().parents[2]  # scripts/check/ 上两级为仓库根目录
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _fetch_row_count(conn, schema: str, table: str) -> int:
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _iter_rows(
conn,
schema: str,
table: str,
select_cols: Sequence[str],
batch_size: int,
) -> Iterable[dict]:
cols_sql = ", ".join(f'"{c}"' for c in select_cols)
sql = f'SELECT {cols_sql} FROM "{schema}"."{table}"'
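# 使用命名游标(服务端游标)流式读取itersize 控制每批拉取行数,避免一次性载入全表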
with conn.cursor(name=f"ods_hash_{table}", cursor_factory=RealDictCursor) as cur:
cur.itersize = max(1, int(batch_size or 500))
cur.execute(sql)
for row in cur:
yield row
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_content_hash_check_{ts}.json"
def _print_progress(
table_label: str,
processed: int,
total: int,
mismatched: int,
missing_hash: int,
invalid_payload: int,
) -> None:
if total:
msg = (
f"[{table_label}] checked {processed}/{total} "
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
else:
msg = (
f"[{table_label}] checked {processed} "
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
print(msg, flush=True)
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="Validate ODS payload vs content_hash consistency")
ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
ap.add_argument("--sample-limit", type=int, default=5, help="sample mismatch rows per table")
ap.add_argument("--out", default="", help="output report JSON path")
args = ap.parse_args()
cfg = AppConfig.load({})
db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
conn = db.conn
tables = _fetch_tables(conn, args.schema)
if args.tables.strip():
whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
tables = [t for t in tables if t in whitelist]
report = {
"schema": args.schema,
"tables": [],
"summary": {
"total_tables": 0,
"checked_tables": 0,
"total_rows": 0,
"checked_rows": 0,
"mismatch_rows": 0,
"missing_hash_rows": 0,
"invalid_payload_rows": 0,
},
}
for table in tables:
table_label = f"{args.schema}.{table}"
cols = _fetch_columns(conn, args.schema, table)
cols_lower = {c.lower() for c in cols}
if "payload" not in cols_lower or "content_hash" not in cols_lower:
print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
continue
total = _fetch_row_count(conn, args.schema, table)
pk_cols = _fetch_pk_columns(conn, args.schema, table)
select_cols = ["content_hash", "payload", *pk_cols]
processed = 0
mismatched = 0
missing_hash = 0
invalid_payload = 0
samples: list[dict[str, Any]] = []
print(f"[{table_label}] start: total_rows={total}", flush=True)
for row in _iter_rows(conn, args.schema, table, select_cols, args.batch_size):
processed += 1
content_hash = row.get("content_hash")
payload = row.get("payload")
recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
row_mismatch = False
if not content_hash:
missing_hash += 1
mismatched += 1
row_mismatch = True
elif not recomputed:
invalid_payload += 1
mismatched += 1
row_mismatch = True
elif content_hash != recomputed:
mismatched += 1
row_mismatch = True
if row_mismatch and len(samples) < max(0, int(args.sample_limit or 0)):
sample = {k: row.get(k) for k in pk_cols}
sample["content_hash"] = content_hash
sample["recomputed_hash"] = recomputed
samples.append(sample)
if args.progress_every and processed % int(args.progress_every) == 0:
_print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)
if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
_print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)
report["tables"].append(
{
"table": table_label,
"total_rows": total,
"checked_rows": processed,
"mismatch_rows": mismatched,
"missing_hash_rows": missing_hash,
"invalid_payload_rows": invalid_payload,
"sample_mismatches": samples,
}
)
report["summary"]["checked_tables"] += 1
report["summary"]["total_rows"] += total
report["summary"]["checked_rows"] += processed
report["summary"]["mismatch_rows"] += mismatched
report["summary"]["missing_hash_rows"] += missing_hash
report["summary"]["invalid_payload_rows"] += invalid_payload
report["summary"]["total_tables"] = len(tables)
out_path = _build_report_path(args.out)
out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"[REPORT] {out_path}", flush=True)
return 0
if __name__ == "__main__":
raise SystemExit(main())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""
ODS JSON 字段核对脚本:对照当前数据库中的 ODS 表字段,检查示例 JSON默认目录 export/test-json-doc
是否包含同名键,并输出每表未命中的字段,便于补充映射或确认确实无源字段。
使用方法:
set PG_DSN=postgresql://... # 如 .env 中配置
python -m scripts.check.check_ods_json_vs_table
"""
from __future__ import annotations
import json
import os
import pathlib
from typing import Dict, Iterable, Set, Tuple
import psycopg2
from tasks.manual_ingest_task import ManualIngestTask
def _flatten_keys(obj, prefix: str = "") -> Set[str]:
"""递归展开 JSON 所有键路径,返回形如 data.assistantInfos.id 的集合。列表不保留索引,仅继续向下展开。"""
keys: Set[str] = set()
if isinstance(obj, dict):
for k, v in obj.items():
new_prefix = f"{prefix}.{k}" if prefix else k
keys.add(new_prefix)
keys |= _flatten_keys(v, new_prefix)
elif isinstance(obj, list):
for item in obj:
keys |= _flatten_keys(item, prefix)
return keys
def _load_json_keys(path: pathlib.Path) -> Tuple[Set[str], dict[str, Set[str]]]:
"""读取单个 JSON 文件并返回展开后的键集合以及末段->路径列表映射,若文件不存在或无法解析则返回空集合。"""
if not path.exists():
return set(), {}
data = json.loads(path.read_text(encoding="utf-8"))
paths = _flatten_keys(data)
last_map: dict[str, Set[str]] = {}
for p in paths:
last = p.split(".")[-1].lower()
last_map.setdefault(last, set()).add(p)
return paths, last_map
def _load_ods_columns(dsn: str) -> Dict[str, Set[str]]:
"""从数据库读取 billiards_ods.* 的列名集合,按表返回。"""
conn = psycopg2.connect(dsn)
cur = conn.cursor()
cur.execute(
"""
SELECT table_name, column_name
FROM information_schema.columns
WHERE table_schema='billiards_ods'
ORDER BY table_name, ordinal_position
"""
)
result: Dict[str, Set[str]] = {}
for table, col in cur.fetchall():
result.setdefault(table, set()).add(col.lower())
cur.close()
conn.close()
return result
def main() -> None:
"""主流程:遍历 FILE_MAPPING 中的 ODS 表,检查 JSON 键覆盖情况并打印报告。"""
dsn = os.environ.get("PG_DSN")
json_dir = pathlib.Path(os.environ.get("JSON_DOC_DIR", "export/test-json-doc"))
ods_cols_map = _load_ods_columns(dsn)
print(f"使用 JSON 目录: {json_dir}")
print(f"连接 DSN: {dsn}")
print("=" * 80)
for keywords, ods_table in ManualIngestTask.FILE_MAPPING:
table = ods_table.split(".")[-1]
cols = ods_cols_map.get(table, set())
file_name = f"{keywords[0]}.json"
file_path = json_dir / file_name
keys_full, path_map = _load_json_keys(file_path)
key_last_parts = set(path_map.keys())
missing: Set[str] = set()
extra_keys: Set[str] = set()
present: Set[str] = set()
for col in sorted(cols):
if col in key_last_parts:
present.add(col)
else:
missing.add(col)
for k in key_last_parts:
if k not in cols:
extra_keys.add(k)
print(f"[{table}] 文件={file_name} 列数={len(cols)} JSON键(末段)覆盖={len(present)}/{len(cols)}")
if missing:
print(" 未命中列:", ", ".join(sorted(missing)))
else:
print(" 未命中列: 无")
if extra_keys:
extras = []
for k in sorted(extra_keys):
paths = ", ".join(sorted(path_map.get(k, [])))
extras.append(f"{k} ({paths})")
print(" JSON 仅有(表无此列):", "; ".join(extras))
else:
print(" JSON 仅有(表无此列): 无")
print("-" * 80)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""验证DWS配置数据"""
import os
from pathlib import Path
from dotenv import load_dotenv
import psycopg2
def main():
load_dotenv(Path(__file__).parent.parent / ".env")
dsn = os.getenv("PG_DSN")
conn = psycopg2.connect(dsn)
tables = [
"cfg_performance_tier",
"cfg_assistant_level_price",
"cfg_bonus_rules",
"cfg_area_category",
"cfg_skill_type"
]
print("DWS 配置表数据统计:")
print("-" * 40)
with conn.cursor() as cur:
for t in tables:
cur.execute(f"SELECT COUNT(*) FROM billiards_dws.{t}")
cnt = cur.fetchone()[0]
print(f"{t}: {cnt}")
conn.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,605 @@
# -*- coding: utf-8 -*-
"""
DWS Excel导入脚本
功能说明:
支持三类Excel数据的导入
1. 支出结构dws_finance_expense_summary
2. 平台结算dws_platform_settlement
3. 充值提成dws_assistant_recharge_commission
导入规范:
- 字段定义:按照目标表字段要求
- 时间粒度:支出按月,平台结算按日,充值提成按月
- 门店维度使用配置的site_id
- 去重规则按import_batch_no去重
- 校验规则:金额字段非负,日期格式校验
使用方式:
python import_dws_excel.py --type expense --file expenses.xlsx
python import_dws_excel.py --type platform --file platform_settlement.xlsx
python import_dws_excel.py --type commission --file recharge_commission.xlsx
作者ETL团队
创建日期2026-02-01
"""
import argparse
import os
import sys
import uuid
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
try:
import pandas as pd
except ImportError:
print("请安装 pandas: pip install pandas openpyxl")
sys.exit(1)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
# =============================================================================
# 常量定义
# =============================================================================
# 支出类型枚举
EXPENSE_TYPES = {
'房租': 'RENT',
'水电费': 'UTILITY',
'物业费': 'PROPERTY',
'工资': 'SALARY',
'报销': 'REIMBURSE',
'平台服务费': 'PLATFORM_FEE',
'其他': 'OTHER',
}
# 支出大类映射
EXPENSE_CATEGORIES = {
'RENT': 'FIXED_COST',
'UTILITY': 'VARIABLE_COST',
'PROPERTY': 'FIXED_COST',
'SALARY': 'FIXED_COST',
'REIMBURSE': 'VARIABLE_COST',
'PLATFORM_FEE': 'VARIABLE_COST',
'OTHER': 'OTHER',
}
# 平台类型枚举
PLATFORM_TYPES = {
'美团': 'MEITUAN',
'抖音': 'DOUYIN',
'大众点评': 'DIANPING',
'其他': 'OTHER',
}
# =============================================================================
# 导入基类
# =============================================================================
class BaseImporter:
"""导入基类"""
def __init__(self, config: AppConfig, db: DatabaseOperations):
self.config = config
self.db = db
self.site_id = config.get("app.store_id")
self.tenant_id = config.get("app.tenant_id", self.site_id)
self.batch_no = self._generate_batch_no()
def _generate_batch_no(self) -> str:
"""生成导入批次号"""
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
unique_id = str(uuid.uuid4())[:8]
return f"{timestamp}_{unique_id}"
def _safe_decimal(self, value: Any, default: Decimal = Decimal('0')) -> Decimal:
"""安全转换为Decimal"""
if value is None or pd.isna(value):
return default
try:
return Decimal(str(value))
except (ValueError, InvalidOperation):
return default
def _safe_date(self, value: Any) -> Optional[date]:
"""安全转换为日期"""
if value is None or pd.isna(value):
return None
if isinstance(value, datetime):
return value.date()
if isinstance(value, date):
return value
try:
return pd.to_datetime(value).date()
except Exception:
return None
def _safe_month(self, value: Any) -> Optional[date]:
"""安全转换为月份(月第一天)"""
dt = self._safe_date(value)
if dt:
return dt.replace(day=1)
return None
def import_file(self, file_path: str) -> Dict[str, Any]:
"""导入文件"""
raise NotImplementedError
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
"""校验行数据,返回错误列表"""
return []
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""转换行数据"""
raise NotImplementedError
def insert_records(self, records: List[Dict[str, Any]]) -> int:
"""插入记录"""
raise NotImplementedError
# =============================================================================
# 支出导入
# =============================================================================
class ExpenseImporter(BaseImporter):
"""
支出导入
Excel格式要求
- 月份: 2026-01 或 2026/01/01 格式
- 支出类型: 房租/水电费/物业费/工资/报销/平台服务费/其他
- 金额: 数字
- 备注: 可选
"""
TARGET_TABLE = "billiards_dws.dws_finance_expense_summary"
REQUIRED_COLUMNS = ['月份', '支出类型', '金额']
OPTIONAL_COLUMNS = ['明细', '备注']
def import_file(self, file_path: str) -> Dict[str, Any]:
"""导入支出Excel"""
print(f"开始导入支出文件: {file_path}")
# 读取Excel
df = pd.read_excel(file_path)
# 校验必要列
missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
if missing_cols:
return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}
# 处理数据
records = []
errors = []
for idx, row in df.iterrows():
row_dict = row.to_dict()
row_errors = self.validate_row(row_dict, idx + 2) # Excel行号从2开始
if row_errors:
errors.extend(row_errors)
continue
record = self.transform_row(row_dict)
records.append(record)
if errors:
print(f"校验错误: {len(errors)}")
for err in errors[:10]:
print(f" - {err}")
# 插入数据
inserted = 0
if records:
inserted = self.insert_records(records)
return {
"status": "SUCCESS" if not errors else "PARTIAL",
"batch_no": self.batch_no,
"total_rows": len(df),
"inserted": inserted,
"errors": len(errors),
"error_messages": errors[:10]
}
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
errors = []
# 校验月份
month = self._safe_month(row.get('月份'))
if not month:
errors.append(f"{row_idx}: 月份格式错误")
# 校验支出类型
expense_type = row.get('支出类型', '').strip()
if expense_type not in EXPENSE_TYPES:
errors.append(f"{row_idx}: 支出类型无效 '{expense_type}'")
# 校验金额
amount = self._safe_decimal(row.get('金额'))
if amount < 0:
errors.append(f"{row_idx}: 金额不能为负数")
return errors
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
expense_type_name = row.get('支出类型', '').strip()
expense_type_code = EXPENSE_TYPES.get(expense_type_name, 'OTHER')
expense_category = EXPENSE_CATEGORIES.get(expense_type_code, 'OTHER')
return {
'site_id': self.site_id,
'tenant_id': self.tenant_id,
'expense_month': self._safe_month(row.get('月份')),
'expense_type_code': expense_type_code,
'expense_type_name': expense_type_name,
'expense_category': expense_category,
'expense_amount': self._safe_decimal(row.get('金额')),
'expense_detail': row.get('明细'),
'import_batch_no': self.batch_no,
'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
'import_time': datetime.now(),
'import_user': os.getenv('USERNAME', 'system'),
'remark': row.get('备注'),
}
def insert_records(self, records: List[Dict[str, Any]]) -> int:
columns = [
'site_id', 'tenant_id', 'expense_month', 'expense_type_code',
'expense_type_name', 'expense_category', 'expense_amount',
'expense_detail', 'import_batch_no', 'import_file_name',
'import_time', 'import_user', 'remark'
]
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"
inserted = 0
with self.db.conn.cursor() as cur:
for record in records:
values = [record.get(col) for col in columns]
cur.execute(sql, values)
inserted += cur.rowcount
self.db.commit()
return inserted
# =============================================================================
# 平台结算导入
# =============================================================================
class PlatformSettlementImporter(BaseImporter):
"""
平台结算导入
Excel格式要求
- 回款日期: 日期格式
- 平台类型: 美团/抖音/大众点评/其他
- 平台订单号: 字符串
- 订单原始金额: 数字
- 佣金: 数字
- 服务费: 数字
- 回款金额: 数字
- 备注: 可选
"""
TARGET_TABLE = "billiards_dws.dws_platform_settlement"
REQUIRED_COLUMNS = ['回款日期', '平台类型', '回款金额']
OPTIONAL_COLUMNS = ['平台订单号', '订单原始金额', '佣金', '服务费', '关联订单ID', '备注']
def import_file(self, file_path: str) -> Dict[str, Any]:
print(f"开始导入平台结算文件: {file_path}")
df = pd.read_excel(file_path)
missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
if missing_cols:
return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}
records = []
errors = []
for idx, row in df.iterrows():
row_dict = row.to_dict()
row_errors = self.validate_row(row_dict, idx + 2)
if row_errors:
errors.extend(row_errors)
continue
record = self.transform_row(row_dict)
records.append(record)
if errors:
print(f"校验错误: {len(errors)}")
for err in errors[:10]:
print(f" - {err}")
inserted = 0
if records:
inserted = self.insert_records(records)
return {
"status": "SUCCESS" if not errors else "PARTIAL",
"batch_no": self.batch_no,
"total_rows": len(df),
"inserted": inserted,
"errors": len(errors),
}
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
errors = []
settlement_date = self._safe_date(row.get('回款日期'))
if not settlement_date:
errors.append(f"{row_idx}: 回款日期格式错误")
platform_type = row.get('平台类型', '').strip()
if platform_type not in PLATFORM_TYPES:
errors.append(f"{row_idx}: 平台类型无效 '{platform_type}'")
amount = self._safe_decimal(row.get('回款金额'))
if amount < 0:
errors.append(f"{row_idx}: 回款金额不能为负数")
return errors
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
platform_name = row.get('平台类型', '').strip()
platform_type = PLATFORM_TYPES.get(platform_name, 'OTHER')
return {
'site_id': self.site_id,
'tenant_id': self.tenant_id,
'settlement_date': self._safe_date(row.get('回款日期')),
'platform_type': platform_type,
'platform_name': platform_name,
'platform_order_no': row.get('平台订单号'),
'order_settle_id': row.get('关联订单ID'),
'settlement_amount': self._safe_decimal(row.get('回款金额')),
'commission_amount': self._safe_decimal(row.get('佣金')),
'service_fee': self._safe_decimal(row.get('服务费')),
'gross_amount': self._safe_decimal(row.get('订单原始金额')),
'import_batch_no': self.batch_no,
'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
'import_time': datetime.now(),
'import_user': os.getenv('USERNAME', 'system'),
'remark': row.get('备注'),
}
def insert_records(self, records: List[Dict[str, Any]]) -> int:
columns = [
'site_id', 'tenant_id', 'settlement_date', 'platform_type',
'platform_name', 'platform_order_no', 'order_settle_id',
'settlement_amount', 'commission_amount', 'service_fee',
'gross_amount', 'import_batch_no', 'import_file_name',
'import_time', 'import_user', 'remark'
]
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"
inserted = 0
with self.db.conn.cursor() as cur:
for record in records:
values = [record.get(col) for col in columns]
cur.execute(sql, values)
inserted += cur.rowcount
self.db.commit()
return inserted
# =============================================================================
# 充值提成导入
# =============================================================================
class RechargeCommissionImporter(BaseImporter):
"""
充值提成导入
Excel格式要求
- 月份: 2026-01 格式
- 助教ID: 数字
- 助教花名: 字符串
- 充值订单金额: 数字
- 提成金额: 数字
- 充值订单号: 可选
- 备注: 可选
"""
TARGET_TABLE = "billiards_dws.dws_assistant_recharge_commission"
REQUIRED_COLUMNS = ['月份', '助教ID', '提成金额']
OPTIONAL_COLUMNS = ['助教花名', '充值订单金额', '充值订单ID', '充值订单号', '备注']
def import_file(self, file_path: str) -> Dict[str, Any]:
print(f"开始导入充值提成文件: {file_path}")
df = pd.read_excel(file_path)
missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
if missing_cols:
return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}
records = []
errors = []
for idx, row in df.iterrows():
row_dict = row.to_dict()
row_errors = self.validate_row(row_dict, idx + 2)
if row_errors:
errors.extend(row_errors)
continue
record = self.transform_row(row_dict)
records.append(record)
if errors:
print(f"校验错误: {len(errors)}")
for err in errors[:10]:
print(f" - {err}")
inserted = 0
if records:
inserted = self.insert_records(records)
return {
"status": "SUCCESS" if not errors else "PARTIAL",
"batch_no": self.batch_no,
"total_rows": len(df),
"inserted": inserted,
"errors": len(errors),
}
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
errors = []
month = self._safe_month(row.get('月份'))
if not month:
errors.append(f"{row_idx}: 月份格式错误")
assistant_id = row.get('助教ID')
if assistant_id is None or pd.isna(assistant_id):
errors.append(f"{row_idx}: 助教ID不能为空")
amount = self._safe_decimal(row.get('提成金额'))
if amount < 0:
errors.append(f"{row_idx}: 提成金额不能为负数")
return errors
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
recharge_amount = self._safe_decimal(row.get('充值订单金额'))
commission_amount = self._safe_decimal(row.get('提成金额'))
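# 提成比例 = 提成金额 / 充值订单金额;充值金额为 0 时置空,避免除零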
commission_ratio = commission_amount / recharge_amount if recharge_amount > 0 else None
return {
'site_id': self.site_id,
'tenant_id': self.tenant_id,
'assistant_id': int(row.get('助教ID')),
'assistant_nickname': row.get('助教花名'),
'commission_month': self._safe_month(row.get('月份')),
'recharge_order_id': row.get('充值订单ID'),
'recharge_order_no': row.get('充值订单号'),
'recharge_amount': recharge_amount,
'commission_amount': commission_amount,
'commission_ratio': commission_ratio,
'import_batch_no': self.batch_no,
'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
'import_time': datetime.now(),
'import_user': os.getenv('USERNAME', 'system'),
'remark': row.get('备注'),
}
def insert_records(self, records: List[Dict[str, Any]]) -> int:
columns = [
'site_id', 'tenant_id', 'assistant_id', 'assistant_nickname',
'commission_month', 'recharge_order_id', 'recharge_order_no',
'recharge_amount', 'commission_amount', 'commission_ratio',
'import_batch_no', 'import_file_name', 'import_time',
'import_user', 'remark'
]
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"
inserted = 0
with self.db.conn.cursor() as cur:
for record in records:
values = [record.get(col) for col in columns]
cur.execute(sql, values)
inserted += cur.rowcount
self.db.commit()
return inserted
# =============================================================================
# 主函数
# =============================================================================
def main():
parser = argparse.ArgumentParser(description='DWS Excel导入工具')
parser.add_argument(
'--type', '-t',
choices=['expense', 'platform', 'commission'],
required=True,
help='导入类型: expense(支出), platform(平台结算), commission(充值提成)'
)
parser.add_argument(
'--file', '-f',
required=True,
help='Excel文件路径'
)
args = parser.parse_args()
# 检查文件
if not os.path.exists(args.file):
print(f"文件不存在: {args.file}")
sys.exit(1)
# 加载配置
config = AppConfig.load()
dsn = config["db"]["dsn"]
db_conn = DatabaseConnection(dsn=dsn)
db = DatabaseOperations(db_conn)
try:
# 选择导入器
if args.type == 'expense':
importer = ExpenseImporter(config, db)
elif args.type == 'platform':
importer = PlatformSettlementImporter(config, db)
elif args.type == 'commission':
importer = RechargeCommissionImporter(config, db)
else:
print(f"未知的导入类型: {args.type}")
sys.exit(1)
# 执行导入
result = importer.import_file(args.file)
# 输出结果
print("\n" + "=" * 50)
print("导入结果:")
print(f" 状态: {result.get('status')}")
print(f" 批次号: {result.get('batch_no')}")
print(f" 总行数: {result.get('total_rows')}")
print(f" 插入行数: {result.get('inserted')}")
print(f" 错误行数: {result.get('errors')}")
if result.get('status') == 'ERROR':
print(f" 错误信息: {result.get('message')}")
sys.exit(1)
except Exception as e:
print(f"导入失败: {e}")
db_conn.rollback()
sys.exit(1)
finally:
db_conn.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
"""
一键重建 ETL 相关 Schema并执行 ODS → DWD。
本脚本面向“离线示例 JSON 回放”的开发/运维场景,使用当前项目内的任务实现:
1) 可选DROP 并重建 schema`etl_admin` / `billiards_ods` / `billiards_dwd`
2) 执行 `INIT_ODS_SCHEMA`:创建 `etl_admin` 元数据表 + 执行 `schema_ODS_doc.sql`(内部会做轻量清洗)
3) 执行 `INIT_DWD_SCHEMA`:执行 `schema_dwd_doc.sql`
4) 执行 `MANUAL_INGEST`:从本地 JSON 目录灌入 ODS
5) 执行 `DWD_LOAD_FROM_ODS`:从 ODS 装载到 DWD
用法(推荐):
python -m scripts.rebuild.rebuild_db_and_run_ods_to_dwd ^
--dsn "postgresql://user:pwd@host:5432/db" ^
--store-id 1 ^
--json-dir "export/test-json-doc" ^
--drop-schemas
环境变量(可选):
PG_DSN、STORE_ID、INGEST_SOURCE_DIR
日志:
默认同时输出到控制台与文件;文件路径为 `io.log_root/rebuild_db_<时间戳>.log`。
"""
from __future__ import annotations
import argparse
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import psycopg2
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from tasks.dwd.dwd_load_task import DwdLoadTask
from tasks.utility.init_dwd_schema_task import InitDwdSchemaTask
from tasks.utility.init_schema_task import InitOdsSchemaTask
from tasks.utility.manual_ingest_task import ManualIngestTask
DEFAULT_JSON_DIR = "export/test-json-doc"
@dataclass(frozen=True)
class RunArgs:
"""脚本参数对象(用于减少散落的参数传递)。"""
dsn: str
store_id: int
json_dir: str
drop_schemas: bool
terminate_own_sessions: bool
demo: bool
only_files: list[str]
only_dwd_tables: list[str]
stop_after: str | None
def _attach_file_logger(log_root: str | Path, filename: str, logger: logging.Logger) -> logging.Handler | None:
"""
给 root logger 附加文件日志处理器UTF-8
说明:
- 使用 root logger 是为了覆盖项目中不同命名的 logger包含第三方/子模块)。
- 若创建失败仅记录 warning不中断主流程。
返回值:
创建成功返回 handler调用方负责 removeHandler/close失败返回 None。
"""
log_dir = Path(log_root)
try:
log_dir.mkdir(parents=True, exist_ok=True)
except Exception as exc: # noqa: BLE001
logger.warning("创建日志目录失败:%s%s", log_dir, exc)
return None
log_path = log_dir / filename
try:
handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
except Exception as exc: # noqa: BLE001
logger.warning("创建文件日志失败:%s%s", log_path, exc)
return None
handler.setLevel(logging.INFO)
handler.setFormatter(
logging.Formatter(
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
logging.getLogger().addHandler(handler)
logger.info("文件日志已启用:%s", log_path)
return handler
def _parse_args() -> RunArgs:
"""解析命令行/环境变量参数。"""
parser = argparse.ArgumentParser(description="重建 Schema 并执行 ODS→DWD离线 JSON 回放)")
parser.add_argument("--dsn", default=os.environ.get("PG_DSN"), help="PostgreSQL DSN默认读取 PG_DSN")
parser.add_argument(
"--store-id",
type=int,
default=int(os.environ.get("STORE_ID") or 1),
help="门店/租户 store_id默认读取 STORE_ID否则为 1",
)
parser.add_argument(
"--json-dir",
default=os.environ.get("INGEST_SOURCE_DIR") or DEFAULT_JSON_DIR,
help=f"示例 JSON 目录(默认 {DEFAULT_JSON_DIR},也可读 INGEST_SOURCE_DIR",
)
parser.add_argument(
"--drop-schemas",
action=argparse.BooleanOptionalAction,
default=True,
help="是否先 DROP 并重建 etl_admin/billiards_ods/billiards_dwd默认",
)
parser.add_argument(
"--terminate-own-sessions",
action=argparse.BooleanOptionalAction,
default=True,
help="执行 DROP 前是否终止当前用户的 idle-in-transaction 会话(默认:是)",
)
parser.add_argument(
"--demo",
action=argparse.BooleanOptionalAction,
default=False,
help="运行最小 Demo仅导入 member_profiles 并生成 dim_member/dim_member_ex",
)
parser.add_argument(
"--only-files",
default="",
help="仅处理指定 JSON 文件(逗号分隔,不含 .json例如member_profiles,settlement_records",
)
parser.add_argument(
"--only-dwd-tables",
default="",
help="仅处理指定 DWD 表逗号分隔支持完整名或表名例如billiards_dwd.dim_member,dim_member_ex",
)
parser.add_argument(
"--stop-after",
default="",
help="在指定阶段后停止可选DROP_SCHEMAS/INIT_ODS_SCHEMA/INIT_DWD_SCHEMA/MANUAL_INGEST/DWD_LOAD_FROM_ODS/BASIC_VALIDATE",
)
args = parser.parse_args()
if not args.dsn:
raise SystemExit("缺少 DSN请传入 --dsn 或设置环境变量 PG_DSN")
only_files = [x.strip().lower() for x in str(args.only_files or "").split(",") if x.strip()]
only_dwd_tables = [x.strip().lower() for x in str(args.only_dwd_tables or "").split(",") if x.strip()]
stop_after = str(args.stop_after or "").strip().upper() or None
return RunArgs(
dsn=args.dsn,
store_id=args.store_id,
json_dir=str(args.json_dir),
drop_schemas=bool(args.drop_schemas),
terminate_own_sessions=bool(args.terminate_own_sessions),
demo=bool(args.demo),
only_files=only_files,
only_dwd_tables=only_dwd_tables,
stop_after=stop_after,
)
def _build_config(args: RunArgs) -> AppConfig:
"""构建本次执行所需的最小配置覆盖。"""
manual_cfg: dict[str, Any] = {}
dwd_cfg: dict[str, Any] = {}
if args.demo:
manual_cfg["include_files"] = ["member_profiles"]
dwd_cfg["only_tables"] = ["billiards_dwd.dim_member", "billiards_dwd.dim_member_ex"]
if args.only_files:
manual_cfg["include_files"] = args.only_files
if args.only_dwd_tables:
dwd_cfg["only_tables"] = args.only_dwd_tables
overrides: dict[str, Any] = {
"app": {"store_id": args.store_id},
"pipeline": {"flow": "INGEST_ONLY", "ingest_source_dir": args.json_dir},
"manual": manual_cfg,
"dwd": dwd_cfg,
# 离线回放/建仓可能耗时较长,关闭 statement_timeout避免被默认 30s 中断。
# 同时关闭 lock_timeout避免 DROP/DDL 因锁等待稍久就直接失败。
"db": {"dsn": args.dsn, "session": {"statement_timeout_ms": 0, "lock_timeout_ms": 0}},
}
return AppConfig.load(overrides)
def _drop_schemas(db: DatabaseOperations, logger: logging.Logger) -> None:
"""删除并重建 ETL 相关 schema具备破坏性请谨慎"""
with db.conn.cursor() as cur:
# 避免因为其他会话持锁而无限等待;若确实被占用,提示用户先释放/终止阻塞会话。
cur.execute("SET lock_timeout TO '5s'")
for schema in ("billiards_dwd", "billiards_ods", "etl_admin"):
logger.info("DROP SCHEMA IF EXISTS %s CASCADE ...", schema)
cur.execute(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE;')
def _terminate_own_idle_in_tx(db: DatabaseOperations, logger: logging.Logger) -> int:
"""终止当前用户在本库中处于 idle-in-transaction 的会话,避免阻塞 DROP/DDL。"""
with db.conn.cursor() as cur:
cur.execute(
"""
SELECT pid
FROM pg_stat_activity
WHERE datname = current_database()
AND usename = current_user
AND pid <> pg_backend_pid()
AND state = 'idle in transaction'
"""
)
pids = [r[0] for r in cur.fetchall()]
killed = 0
for pid in pids:
cur.execute("SELECT pg_terminate_backend(%s)", (pid,))
ok = bool(cur.fetchone()[0])
logger.info("终止会话 pid=%s ok=%s", pid, ok)
killed += 1 if ok else 0
return killed
def _run_task(task, logger: logging.Logger) -> dict:
"""统一运行任务并打印关键结果。"""
result = task.execute(None)
logger.info("%s: status=%s counts=%s", task.get_task_code(), result.get("status"), result.get("counts"))
return result
def _basic_validate(db: DatabaseOperations, logger: logging.Logger) -> None:
"""做最基础的可用性校验schema 存在、关键表行数可查询。"""
checks = [
("billiards_ods", "member_profiles"),
("billiards_ods", "settlement_records"),
("billiards_dwd", "dim_member"),
("billiards_dwd", "dwd_settlement_head"),
]
for schema, table in checks:
try:
rows = db.query(f'SELECT COUNT(1) AS cnt FROM "{schema}"."{table}"')
logger.info("校验行数:%s.%s = %s", schema, table, (rows[0] or {}).get("cnt") if rows else None)
except Exception as exc: # noqa: BLE001
logger.warning("校验失败:%s.%s%s", schema, table, exc)
def _connect_db_with_retry(cfg: AppConfig, logger: logging.Logger) -> DatabaseConnection:
"""创建数据库连接(带重试),避免短暂网络抖动导致脚本直接失败。"""
dsn = cfg["db"]["dsn"]
session = cfg["db"].get("session")
connect_timeout = cfg["db"].get("connect_timeout_sec")
backoffs = [1, 2, 4, 8, 16]
last_exc: Exception | None = None
for attempt, wait_sec in enumerate([0] + backoffs, start=1):
if wait_sec:
time.sleep(wait_sec)
try:
return DatabaseConnection(dsn=dsn, session=session, connect_timeout=connect_timeout)
except Exception as exc: # noqa: BLE001
last_exc = exc
logger.warning("数据库连接失败(第 %s 次):%s", attempt, exc)
raise last_exc or RuntimeError("数据库连接失败")
def _is_connection_error(exc: Exception) -> bool:
"""判断是否为连接断开/服务端异常导致的可重试错误。"""
return isinstance(exc, (psycopg2.OperationalError, psycopg2.InterfaceError))
def _run_stage_with_reconnect(
cfg: AppConfig,
logger: logging.Logger,
stage_name: str,
fn,
max_attempts: int = 3,
) -> dict | None:
"""
运行单个阶段:失败(尤其是连接断开)时自动重连并重试。
fn: (db_ops) -> dict | None
"""
last_exc: Exception | None = None
for attempt in range(1, max_attempts + 1):
db_conn = _connect_db_with_retry(cfg, logger)
db_ops = DatabaseOperations(db_conn)
try:
logger.info("阶段开始:%s(第 %s/%s 次)", stage_name, attempt, max_attempts)
result = fn(db_ops)
logger.info("阶段完成:%s", stage_name)
return result
except Exception as exc: # noqa: BLE001
last_exc = exc
logger.exception("阶段失败:%s(第 %s/%s 次):%s", stage_name, attempt, max_attempts, exc)
# 连接类错误允许重试;非连接错误直接抛出,避免掩盖逻辑问题。
if not _is_connection_error(exc):
raise
time.sleep(min(2**attempt, 10))
finally:
try:
db_ops.close() # type: ignore[attr-defined]
except Exception:
pass
try:
db_conn.close()
except Exception:
pass
raise last_exc or RuntimeError(f"阶段失败:{stage_name}")
def main() -> int:
"""脚本主入口:按顺序重建并跑通 ODS→DWD。"""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("fq_etl.rebuild_db")
args = _parse_args()
cfg = _build_config(args)
# 默认启用文件日志,便于事后追溯(即便运行失败也应尽早落盘)。
file_handler = _attach_file_logger(
log_root=cfg["io"]["log_root"],
filename=time.strftime("rebuild_db_%Y%m%d-%H%M%S.log"),
logger=logger,
)
try:
json_dir = Path(args.json_dir)
if not json_dir.exists():
logger.error("示例 JSON 目录不存在:%s", json_dir)
return 2
def stage_drop(db_ops: DatabaseOperations):
if not args.drop_schemas:
return None
if args.terminate_own_sessions:
killed = _terminate_own_idle_in_tx(db_ops, logger)
if killed:
db_ops.commit()
_drop_schemas(db_ops, logger)
db_ops.commit()
return None
def stage_init_ods(db_ops: DatabaseOperations):
return _run_task(InitOdsSchemaTask(cfg, db_ops, None, logger), logger)
def stage_init_dwd(db_ops: DatabaseOperations):
return _run_task(InitDwdSchemaTask(cfg, db_ops, None, logger), logger)
def stage_manual_ingest(db_ops: DatabaseOperations):
logger.info("开始执行MANUAL_INGESTjson_dir=%s", json_dir)
return _run_task(ManualIngestTask(cfg, db_ops, None, logger), logger)
def stage_dwd_load(db_ops: DatabaseOperations):
logger.info("开始执行DWD_LOAD_FROM_ODS")
return _run_task(DwdLoadTask(cfg, db_ops, None, logger), logger)
_run_stage_with_reconnect(cfg, logger, "DROP_SCHEMAS", stage_drop, max_attempts=3)
if args.stop_after == "DROP_SCHEMAS":
return 0
_run_stage_with_reconnect(cfg, logger, "INIT_ODS_SCHEMA", stage_init_ods, max_attempts=3)
if args.stop_after == "INIT_ODS_SCHEMA":
return 0
_run_stage_with_reconnect(cfg, logger, "INIT_DWD_SCHEMA", stage_init_dwd, max_attempts=3)
if args.stop_after == "INIT_DWD_SCHEMA":
return 0
_run_stage_with_reconnect(cfg, logger, "MANUAL_INGEST", stage_manual_ingest, max_attempts=5)
if args.stop_after == "MANUAL_INGEST":
return 0
_run_stage_with_reconnect(cfg, logger, "DWD_LOAD_FROM_ODS", stage_dwd_load, max_attempts=5)
if args.stop_after == "DWD_LOAD_FROM_ODS":
return 0
# 校验阶段复用一条新连接即可
_run_stage_with_reconnect(
cfg,
logger,
"BASIC_VALIDATE",
lambda db_ops: _basic_validate(db_ops, logger),
max_attempts=3,
)
if args.stop_after == "BASIC_VALIDATE":
return 0
return 0
finally:
if file_handler is not None:
try:
logging.getLogger().removeHandler(file_handler)
except Exception:
pass
try:
file_handler.close()
except Exception:
pass
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,717 @@
# -*- coding: utf-8 -*-
"""
补全丢失的 ODS 数据
通过运行数据完整性检查,找出 API 与 ODS 之间的差异,
然后重新从 API 获取丢失的数据并插入 ODS。
用法:
python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19
python -m scripts.backfill_missing_data --from-report reports/ods_gap_check_xxx.json
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from psycopg2.extras import Json, execute_values
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.recording_client import build_recording_client
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
get_value_case_insensitive,
merge_record_layers,
normalize_pk_value,
pk_tuple_from_record,
)
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
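# 无时间部分的日期is_end=True 时取当日 23:59:59否则取 00:00:00再统一转换到目标时区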
raw = (value or "").strip()
if not raw:
raise ValueError("empty datetime")
has_time = any(ch in raw for ch in (":", "T"))
dt = dtparser.parse(raw)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=tz)
else:
dt = dt.astimezone(tz)
if not has_time:
dt = dt.replace(
hour=23 if is_end else 0,
minute=59 if is_end else 0,
second=59 if is_end else 0,
microsecond=0
)
return dt
def _get_spec(code: str) -> Optional[OdsTaskSpec]:
"""根据任务代码获取 ODS 任务规格"""
for spec in ODS_TASK_SPECS:
if spec.code == code:
return spec
return None
def _merge_record_layers(record: dict) -> dict:
"""Flatten nested data layers into a single dict."""
return merge_record_layers(record)
def _get_value_case_insensitive(record: dict | None, col: str | None):
"""Fetch value without case sensitivity."""
return get_value_case_insensitive(record, col)
def _normalize_pk_value(value):
"""Normalize PK value."""
return normalize_pk_value(value)
def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
"""Extract PK tuple from record."""
return pk_tuple_from_record(record, pk_cols)
def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
"""获取表的主键列"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [r[0] for r in cur.fetchall()]
if include_content_hash:
return cols
return [c for c in cols if c.lower() != "content_hash"]
def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
"""获取表的所有列信息"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
def _fetch_existing_pk_set(
conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
"""获取已存在的 PK 集合"""
if not pk_values:
return set()
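# 将待查主键分块拼为 VALUES 列表,与目标表 JOIN一次性筛出库中已存在的主键组合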
select_cols = ", ".join(f't."{c}"' for c in pk_cols)
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
sql = (
f"SELECT {select_cols} FROM {table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: Set[Tuple] = set()
with conn.cursor() as cur:
for i in range(0, len(pk_values), chunk_size):
chunk = pk_values[i:i + chunk_size]
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
return existing
def _cast_value(value, data_type: str):
"""类型转换"""
if value is None:
return None
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, (str, datetime)) else None
return value
def _normalize_scalar(value):
"""规范化标量值"""
if value == "" or value == "{}" or value == "[]":
return None
return value
class MissingDataBackfiller:
"""丢失数据补全器"""
def __init__(
self,
cfg: AppConfig,
logger: logging.Logger,
dry_run: bool = False,
):
self.cfg = cfg
self.logger = logger
self.dry_run = dry_run
self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
self.store_id = int(cfg.get("app.store_id") or 0)
# API 客户端
self.api = build_recording_client(cfg, task_code="BACKFILL_MISSING_DATA")
# 数据库连接DatabaseConnection 构造时已设置 autocommit=False
self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
def close(self):
"""关闭连接"""
if self.db:
self.db.close()
def _ensure_db(self):
"""确保数据库连接可用"""
if self.db and getattr(self.db, "conn", None) is not None:
if getattr(self.db.conn, "closed", 0) == 0:
return
self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))
def backfill_from_gap_check(
self,
*,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
include_mismatch: bool = False,
page_size: int = 200,
chunk_size: int = 500,
content_sample_limit: int | None = None,
) -> Dict[str, Any]:
"""
运行 gap check 并补全丢失数据
Returns:
补全结果统计
"""
self.logger.info("数据补全开始 起始=%s 结束=%s", start.isoformat(), end.isoformat())
# 计算窗口大小
total_seconds = max(0, int((end - start).total_seconds()))
if total_seconds >= 86400:
window_days = max(1, total_seconds // 86400)
window_hours = 0
else:
window_days = 0
window_hours = max(1, total_seconds // 3600 or 1)
# 运行 gap check
self.logger.info("正在执行缺失检查...")
gap_result = run_gap_check(
cfg=self.cfg,
start=start,
end=end,
window_days=window_days,
window_hours=window_hours,
page_size=page_size,
chunk_size=chunk_size,
sample_limit=10000, # 获取所有丢失样本
sleep_per_window=0,
sleep_per_page=0,
task_codes=task_codes or "",
from_cutoff=False,
cutoff_overlap_hours=24,
allow_small_window=True,
logger=self.logger,
compare_content=include_mismatch,
content_sample_limit=content_sample_limit or 10000,
)
total_missing = gap_result.get("total_missing", 0)
total_mismatch = gap_result.get("total_mismatch", 0)
if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
self.logger.info("Data complete: no missing/mismatch records")
return {"backfilled": 0, "errors": 0, "details": []}
if include_mismatch:
self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
else:
self.logger.info("Missing check done missing=%s", total_missing)
results = []
total_backfilled = 0
total_errors = 0
for task_result in gap_result.get("results", []):
task_code = task_result.get("task_code")
missing = task_result.get("missing", 0)
missing_samples = task_result.get("missing_samples", [])
mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
target_samples = list(missing_samples) + list(mismatch_samples)
if missing == 0 and mismatch == 0:
continue
self.logger.info(
"Start backfill task task=%s missing=%s mismatch=%s samples=%s",
task_code, missing, mismatch, len(target_samples)
)
try:
backfilled = self._backfill_task(
task_code=task_code,
table=task_result.get("table"),
pk_columns=task_result.get("pk_columns", []),
pk_samples=target_samples,
start=start,
end=end,
page_size=page_size,
chunk_size=chunk_size,
)
results.append({
"task_code": task_code,
"missing": missing,
"mismatch": mismatch,
"backfilled": backfilled,
"error": None,
})
total_backfilled += backfilled
except Exception as exc:
self.logger.exception("补全失败 任务=%s", task_code)
results.append({
"task_code": task_code,
"missing": missing,
"mismatch": mismatch,
"backfilled": 0,
"error": str(exc),
})
total_errors += 1
self.logger.info(
"数据补全完成 总缺失=%s 已补全=%s 错误数=%s",
total_missing, total_backfilled, total_errors
)
return {
"total_missing": total_missing,
"total_mismatch": total_mismatch,
"backfilled": total_backfilled,
"errors": total_errors,
"details": results,
}
def _backfill_task(
self,
*,
task_code: str,
table: str,
pk_columns: List[str],
pk_samples: List[Dict],
start: datetime,
end: datetime,
page_size: int,
chunk_size: int,
) -> int:
"""补全单个任务的丢失数据"""
self._ensure_db()
spec = _get_spec(task_code)
if not spec:
self.logger.warning("未找到任务规格 任务=%s", task_code)
return 0
if not pk_columns:
pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)
conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
if not conflict_columns:
conflict_columns = pk_columns
if not pk_columns:
self.logger.warning("未找到主键列 任务=%s 表=%s", task_code, table)
return 0
# 提取丢失的 PK 值
missing_pks: Set[Tuple] = set()
for sample in pk_samples:
pk_tuple = tuple(sample.get(col) for col in pk_columns)
if all(v is not None for v in pk_tuple):
missing_pks.add(pk_tuple)
if not missing_pks:
self.logger.info("无缺失主键 任务=%s", task_code)
return 0
self.logger.info(
"开始获取数据 任务=%s 缺失主键数=%s",
task_code, len(missing_pks)
)
# 从 API 获取数据并过滤出丢失的记录
params = self._build_params(spec, start, end)
backfilled = 0
cols_info = _get_table_columns(self.db.conn, table)
db_json_cols_lower = {
c[0].lower() for c in cols_info
if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
col_names = [c[0] for c in cols_info]
# 结束只读事务,避免长时间 API 拉取导致 idle_in_tx 超时
try:
self.db.conn.commit()
except Exception:
self.db.conn.rollback()
try:
for page_no, records, _, response_payload in self.api.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
# 过滤出丢失的记录
records_to_insert = []
for rec in records:
if not isinstance(rec, dict):
continue
pk_tuple = _pk_tuple_from_record(rec, pk_columns)
if pk_tuple and pk_tuple in missing_pks:
records_to_insert.append(rec)
if not records_to_insert:
continue
# 插入丢失的记录
if self.dry_run:
backfilled += len(records_to_insert)
self.logger.info(
"模拟运行 任务=%s 页=%s 将插入=%s",
task_code, page_no, len(records_to_insert)
)
else:
inserted = self._insert_records(
table=table,
records=records_to_insert,
cols_info=cols_info,
pk_columns=pk_columns,
conflict_columns=conflict_columns,
db_json_cols_lower=db_json_cols_lower,
)
backfilled += inserted
# 避免长事务阻塞与 idle_in_tx 超时
self.db.conn.commit()
self.logger.info(
"已插入 任务=%s 页=%s 数量=%s",
task_code, page_no, inserted
)
if not self.dry_run:
self.db.conn.commit()
self.logger.info("任务补全完成 任务=%s 已补全=%s", task_code, backfilled)
return backfilled
except Exception:
self.db.conn.rollback()
raise
def _build_params(
self,
spec: OdsTaskSpec,
start: datetime,
end: datetime,
) -> Dict:
"""构建 API 请求参数"""
base: Dict[str, Any] = {}
if spec.include_site_id:
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
base["siteId"] = [self.store_id]
else:
base["siteId"] = self.store_id
if spec.requires_window and spec.time_fields:
start_key, end_key = spec.time_fields
base[start_key] = TypeParser.format_timestamp(start, self.tz)
base[end_key] = TypeParser.format_timestamp(end, self.tz)
# 合并公共参数
common = self.cfg.get("api.params", {}) or {}
if isinstance(common, dict):
merged = {**common, **base}
else:
merged = base
merged.update(spec.extra_params or {})
return merged
def _insert_records(
self,
*,
table: str,
records: List[Dict],
cols_info: List[Tuple[str, str, str]],
pk_columns: List[str],
conflict_columns: List[str],
db_json_cols_lower: Set[str],
) -> int:
"""插入记录到数据库"""
if not records:
return 0
col_names = [c[0] for c in cols_info]
needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
quoted_cols = ", ".join(f'"{c}"' for c in col_names)
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
conflict_cols = conflict_columns or pk_columns
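# 若存在冲突列则追加 ON CONFLICT DO NOTHING已存在记录直接跳过避免补插产生重复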
if conflict_cols:
pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
now = datetime.now(self.tz)
json_dump = lambda v: json.dumps(v, ensure_ascii=False)
params: List[Tuple] = []
for rec in records:
merged_rec = _merge_record_layers(rec)
# 检查 PK
if pk_columns:
missing_pk = False
for pk in pk_columns:
if str(pk).lower() == "content_hash":
continue
pk_val = _get_value_case_insensitive(merged_rec, pk)
if pk_val is None or pk_val == "":
missing_pk = True
break
if missing_pk:
continue
content_hash = None
if needs_content_hash:
content_hash = BaseOdsTask._compute_content_hash(
merged_rec, include_fetched_at=False
)
row_vals: List[Any] = []
for (col_name, data_type, _udt) in cols_info:
col_lower = col_name.lower()
if col_lower == "payload":
row_vals.append(Json(rec, dumps=json_dump))
continue
if col_lower == "source_file":
row_vals.append("backfill")
continue
if col_lower == "source_endpoint":
row_vals.append("backfill")
continue
if col_lower == "fetched_at":
row_vals.append(now)
continue
if col_lower == "content_hash":
row_vals.append(content_hash)
continue
value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
if col_lower in db_json_cols_lower:
row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
continue
row_vals.append(_cast_value(value, data_type))
params.append(tuple(row_vals))
if not params:
return 0
inserted = 0
with self.db.conn.cursor() as cur:
for i in range(0, len(params), 200):
chunk = params[i:i + 200]
execute_values(cur, sql, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
inserted += int(cur.rowcount)
return inserted
def run_backfill(
*,
cfg: AppConfig,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
include_mismatch: bool = False,
dry_run: bool = False,
page_size: int = 200,
chunk_size: int = 500,
content_sample_limit: int | None = None,
logger: logging.Logger,
) -> Dict[str, Any]:
"""
运行数据补全
Args:
cfg: 应用配置
start: 开始时间
end: 结束时间
task_codes: 指定任务代码(逗号分隔)
dry_run: 是否仅预览
page_size: API 分页大小
chunk_size: 数据库批量大小
logger: 日志记录器
Returns:
补全结果
"""
backfiller = MissingDataBackfiller(cfg, logger, dry_run)
try:
return backfiller.backfill_from_gap_check(
start=start,
end=end,
task_codes=task_codes,
include_mismatch=include_mismatch,
page_size=page_size,
chunk_size=chunk_size,
content_sample_limit=content_sample_limit,
)
finally:
backfiller.close()
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="补全丢失的 ODS 数据")
ap.add_argument("--start", default="2025-07-01", help="开始日期 (默认: 2025-07-01)")
ap.add_argument("--end", default="", help="结束日期 (默认: 当前时间)")
ap.add_argument("--task-codes", default="", help="指定任务代码(逗号分隔,留空=全部)")
ap.add_argument("--include-mismatch", action="store_true", help="同时补全内容不一致的记录")
ap.add_argument("--content-sample-limit", type=int, default=None, help="不一致样本上限 (默认: 10000)")
ap.add_argument("--dry-run", action="store_true", help="仅预览,不实际写入")
ap.add_argument("--page-size", type=int, default=200, help="API 分页大小 (默认: 200)")
ap.add_argument("--chunk-size", type=int, default=500, help="数据库批量大小 (默认: 500)")
ap.add_argument("--log-file", default="", help="日志文件路径")
ap.add_argument("--log-dir", default="", help="日志目录")
ap.add_argument("--log-level", default="INFO", help="日志级别 (默认: INFO)")
ap.add_argument("--no-log-console", action="store_true", help="禁用控制台日志")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
log_console = not args.no_log_console
with configure_logging(
"backfill_missing",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
start = _parse_dt(args.start, tz)
end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)
result = run_backfill(
cfg=cfg,
start=start,
end=end,
task_codes=args.task_codes or None,
include_mismatch=args.include_mismatch,
dry_run=args.dry_run,
page_size=args.page_size,
chunk_size=args.chunk_size,
content_sample_limit=args.content_sample_limit,
logger=logger,
)
logger.info("=" * 60)
logger.info("补全完成!")
logger.info(" 总丢失: %s", result.get("total_missing", 0))
if args.include_mismatch:
logger.info(" 总不一致: %s", result.get("total_mismatch", 0))
logger.info(" 已补全: %s", result.get("backfilled", 0))
logger.info(" 错误数: %s", result.get("errors", 0))
logger.info("=" * 60)
# 输出详细结果
for detail in result.get("details", []):
if detail.get("error"):
logger.error(
" %s: 丢失=%s 不一致=%s 补全=%s 错误=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("mismatch", 0),
detail.get("backfilled"),
detail.get("error"),
)
elif detail.get("backfilled", 0) > 0:
logger.info(
" %s: 丢失=%s 不一致=%s 补全=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("mismatch", 0),
detail.get("backfilled"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-
"""
Deduplicate ODS snapshots by (business PK, content_hash).
Keep the latest row by fetched_at (tie-breaker: ctid desc).
Usage:
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --schema billiards_ods
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Sequence
import psycopg2
PROJECT_ROOT = Path(__file__).resolve().parents[2]  # repo root is two levels above scripts/repair/
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _quote_ident(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
return [r[0] for r in cur.fetchall()]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_snapshot_dedupe_{ts}.json"
def _print_progress(
table_label: str,
deleted: int,
total: int,
errors: int,
) -> None:
if total:
msg = f"[{table_label}] deleted {deleted}/{total} errors={errors}"
else:
msg = f"[{table_label}] deleted {deleted} errors={errors}"
print(msg, flush=True)
def _count_duplicates(conn, schema: str, table: str, key_cols: Sequence[str]) -> int:
keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
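# Rank rows within each (business PK + content_hash) group by fetched_at DESC; rn > 1 marks a duplicate snapshot.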
sql = f"""
SELECT COUNT(*) FROM (
SELECT 1
FROM (
SELECT ROW_NUMBER() OVER (
PARTITION BY {keys_sql}
ORDER BY fetched_at DESC NULLS LAST, ctid DESC
) AS rn
FROM {table_sql}
) t
WHERE rn > 1
) s
"""
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _delete_duplicate_batch(
conn,
schema: str,
table: str,
key_cols: Sequence[str],
batch_size: int,
) -> int:
keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
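# Delete duplicates by ctid in LIMIT-sized batches so each pass stays small.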
sql = f"""
WITH dupes AS (
SELECT ctid
FROM (
SELECT ctid,
ROW_NUMBER() OVER (
PARTITION BY {keys_sql}
ORDER BY fetched_at DESC NULLS LAST, ctid DESC
) AS rn
FROM {table_sql}
) s
WHERE rn > 1
LIMIT %s
)
DELETE FROM {table_sql} t
USING dupes d
WHERE t.ctid = d.ctid
RETURNING 1
"""
with conn.cursor() as cur:
cur.execute(sql, (int(batch_size),))
rows = cur.fetchall()
return len(rows or [])
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="Deduplicate ODS snapshot rows by PK+content_hash")
ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
ap.add_argument("--batch-size", type=int, default=1000, help="delete batch size")
ap.add_argument("--progress-every", type=int, default=100, help="print progress every N deletions")
ap.add_argument("--out", default="", help="output report JSON path")
ap.add_argument("--dry-run", action="store_true", help="only compute duplicate counts")
args = ap.parse_args()
cfg = AppConfig.load({})
db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
try:
db.conn.rollback()
except Exception:
pass
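# Autocommit so every delete batch is committed as soon as it completes.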
db.conn.autocommit = True
tables = _fetch_tables(db.conn, args.schema)
if args.tables.strip():
whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
tables = [t for t in tables if t in whitelist]
report = {
"schema": args.schema,
"tables": [],
"summary": {
"total_tables": len(tables),
"checked_tables": 0,
"total_duplicates": 0,
"deleted_rows": 0,
"error_rows": 0,
"skipped_tables": 0,
},
}
for table in tables:
table_label = f"{args.schema}.{table}"
cols = _fetch_columns(db.conn, args.schema, table)
cols_lower = {c.lower() for c in cols}
if "content_hash" not in cols_lower or "fetched_at" not in cols_lower:
print(f"[{table_label}] skip: missing content_hash/fetched_at", flush=True)
report["summary"]["skipped_tables"] += 1
continue
key_cols = _fetch_pk_columns(db.conn, args.schema, table)
if not key_cols:
print(f"[{table_label}] skip: missing primary key", flush=True)
report["summary"]["skipped_tables"] += 1
continue
total_dupes = _count_duplicates(db.conn, args.schema, table, key_cols)
print(f"[{table_label}] duplicates={total_dupes}", flush=True)
deleted = 0
errors = 0
if not args.dry_run and total_dupes:
while True:
try:
batch_deleted = _delete_duplicate_batch(
db.conn,
args.schema,
table,
key_cols,
args.batch_size,
)
except psycopg2.Error:
errors += 1
break
if batch_deleted <= 0:
break
deleted += batch_deleted
if args.progress_every and deleted % int(args.progress_every) == 0:
_print_progress(table_label, deleted, total_dupes, errors)
if deleted and (not args.progress_every or deleted % int(args.progress_every) != 0):
_print_progress(table_label, deleted, total_dupes, errors)
report["tables"].append(
{
"table": table_label,
"duplicate_rows": total_dupes,
"deleted_rows": deleted,
"error_rows": errors,
}
)
report["summary"]["checked_tables"] += 1
report["summary"]["total_duplicates"] += total_dupes
report["summary"]["deleted_rows"] += deleted
report["summary"]["error_rows"] += errors
out_path = _build_report_path(args.out)
out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"[REPORT] {out_path}", flush=True)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""修复 dim_assistant 表中的 user_id 字段"""
import sys
sys.path.insert(0, '.')
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
config = AppConfig.load({})
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)
print("=== 修复 dim_assistant.user_id ===")
# 方案:从 ODS 表更新 DWD 表的 user_id
# 通过 id (ODS) = assistant_id (DWD) 关联
# 1. 先检查当前状态
print("\n修复前:")
sql_before = """
SELECT
COUNT(*) as total,
COUNT(CASE WHEN user_id > 0 THEN 1 END) as has_user_id
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
"""
r = dict(db.query(sql_before)[0])
print(f" 总记录: {r['total']}, 有user_id: {r['has_user_id']}")
# 2. 执行更新
print("\n执行更新...")
update_sql = """
UPDATE billiards_dwd.dim_assistant d
SET user_id = o.user_id
FROM (
SELECT DISTINCT ON (id) id, user_id
FROM billiards_ods.assistant_accounts_master
WHERE user_id > 0
ORDER BY id, fetched_at DESC
) o
WHERE d.assistant_id = o.id
AND (d.user_id IS NULL OR d.user_id = 0)
"""
with db_conn.conn.cursor() as cur:
cur.execute(update_sql)
updated = cur.rowcount
print(f" 更新了 {updated} 条记录")
db_conn.conn.commit()
# 3. 检查修复后状态
print("\n修复后:")
r2 = dict(db.query(sql_before)[0])
print(f" 总记录: {r2['total']}, 有user_id: {r2['has_user_id']}")
# 4. 显示样本数据
print("\n样本数据:")
sql_sample = """
SELECT assistant_id, user_id, assistant_no, nickname
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
ORDER BY assistant_no::int
LIMIT 10
"""
for row in db.query(sql_sample):
r = dict(row)
print(f" assistant_id={r['assistant_id']}, user_id={r['user_id']}, no={r['assistant_no']}, nickname={r['nickname']}")
# 5. 验证与服务日志的关联
print("\n验证与服务日志的关联:")
sql_verify = """
SELECT
COUNT(DISTINCT s.user_id) as service_unique_users,
COUNT(DISTINCT CASE WHEN d.assistant_id IS NOT NULL THEN s.user_id END) as matched_users
FROM billiards_dwd.dwd_assistant_service_log s
LEFT JOIN billiards_dwd.dim_assistant d
ON s.user_id = d.user_id AND d.scd2_is_current = 1
WHERE s.is_delete = 0 AND s.user_id > 0
"""
r3 = dict(db.query(sql_verify)[0])
print(f" 服务日志唯一user_id: {r3['service_unique_users']}")
print(f" 能匹配到dim_assistant: {r3['matched_users']}")
match_rate = r3['matched_users'] / r3['service_unique_users'] * 100 if r3['service_unique_users'] > 0 else 0
print(f" 匹配率: {match_rate:.1f}%")
db_conn.close()
print("\n完成!")

View File

@@ -0,0 +1,302 @@
# -*- coding: utf-8 -*-
"""
Repair ODS content_hash values by recomputing from payload.
Usage:
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --schema billiards_ods
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence
import psycopg2
from psycopg2.extras import RealDictCursor
PROJECT_ROOT = Path(__file__).resolve().parents[2]  # repo root is two levels above scripts/repair/
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _fetch_row_count(conn, schema: str, table: str) -> int:
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _iter_rows(
conn,
schema: str,
table: str,
select_cols: Sequence[str],
batch_size: int,
) -> Iterable[dict]:
cols_sql = ", ".join("ctid" if c == "ctid" else f'"{c}"' for c in select_cols)
sql = f'SELECT {cols_sql} FROM "{schema}"."{table}"'
with conn.cursor(name=f"ods_hash_fix_{table}", cursor_factory=RealDictCursor) as cur:
cur.itersize = max(1, int(batch_size or 500))
cur.execute(sql)
for row in cur:
yield row
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_content_hash_repair_{ts}.json"
def _print_progress(
table_label: str,
processed: int,
total: int,
updated: int,
skipped: int,
conflicts: int,
errors: int,
missing_hash: int,
invalid_payload: int,
) -> None:
if total:
msg = (
f"[{table_label}] checked {processed}/{total} "
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
else:
msg = (
f"[{table_label}] checked {processed} "
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
print(msg, flush=True)
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="Repair ODS content_hash using payload")
ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
ap.add_argument("--sample-limit", type=int, default=10, help="sample conflicts per table")
ap.add_argument("--out", default="", help="output report JSON path")
ap.add_argument("--dry-run", action="store_true", help="only compute stats, do not update")
args = ap.parse_args()
cfg = AppConfig.load({})
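# Two connections: db_read streams rows through a server-side cursor, db_write applies fixes with autocommit.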
db_read = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
db_write = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
try:
db_write.conn.rollback()
except Exception:
pass
db_write.conn.autocommit = True
tables = _fetch_tables(db_read.conn, args.schema)
if args.tables.strip():
whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
tables = [t for t in tables if t in whitelist]
report = {
"schema": args.schema,
"tables": [],
"summary": {
"total_tables": len(tables),
"checked_tables": 0,
"total_rows": 0,
"checked_rows": 0,
"updated_rows": 0,
"skipped_rows": 0,
"conflict_rows": 0,
"error_rows": 0,
"missing_hash_rows": 0,
"invalid_payload_rows": 0,
},
}
for table in tables:
table_label = f"{args.schema}.{table}"
cols = _fetch_columns(db_read.conn, args.schema, table)
cols_lower = {c.lower() for c in cols}
if "payload" not in cols_lower or "content_hash" not in cols_lower:
print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
continue
total = _fetch_row_count(db_read.conn, args.schema, table)
pk_cols = _fetch_pk_columns(db_read.conn, args.schema, table)
select_cols = ["ctid", "content_hash", "payload", *pk_cols]
processed = 0
updated = 0
skipped = 0
conflicts = 0
errors = 0
missing_hash = 0
invalid_payload = 0
samples: list[dict[str, Any]] = []
print(f"[{table_label}] start: total_rows={total}", flush=True)
for row in _iter_rows(db_read.conn, args.schema, table, select_cols, args.batch_size):
processed += 1
content_hash = row.get("content_hash")
payload = row.get("payload")
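# Recompute the comparison hash from payload and compare it with the stored content_hash.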
recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
row_ctid = row.get("ctid")
if not content_hash:
missing_hash += 1
if not recomputed:
invalid_payload += 1
if not recomputed:
skipped += 1
elif content_hash == recomputed:
skipped += 1
else:
if args.dry_run:
updated += 1
else:
try:
with db_write.conn.cursor() as cur:
cur.execute(
f'UPDATE "{args.schema}"."{table}" SET content_hash = %s WHERE ctid = %s',
(recomputed, row_ctid),
)
updated += 1
except psycopg2.errors.UniqueViolation:
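# The recomputed hash collides with an existing row (likely because content_hash is part of the PK); record a conflict sample instead of updating.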
conflicts += 1
if len(samples) < max(0, int(args.sample_limit or 0)):
sample = {k: row.get(k) for k in pk_cols}
sample["content_hash"] = content_hash
sample["recomputed_hash"] = recomputed
samples.append(sample)
except psycopg2.Error:
errors += 1
if args.progress_every and processed % int(args.progress_every) == 0:
_print_progress(
table_label,
processed,
total,
updated,
skipped,
conflicts,
errors,
missing_hash,
invalid_payload,
)
if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
_print_progress(
table_label,
processed,
total,
updated,
skipped,
conflicts,
errors,
missing_hash,
invalid_payload,
)
report["tables"].append(
{
"table": table_label,
"total_rows": total,
"checked_rows": processed,
"updated_rows": updated,
"skipped_rows": skipped,
"conflict_rows": conflicts,
"error_rows": errors,
"missing_hash_rows": missing_hash,
"invalid_payload_rows": invalid_payload,
"conflict_samples": samples,
}
)
report["summary"]["checked_tables"] += 1
report["summary"]["total_rows"] += total
report["summary"]["checked_rows"] += processed
report["summary"]["updated_rows"] += updated
report["summary"]["skipped_rows"] += skipped
report["summary"]["conflict_rows"] += conflicts
report["summary"]["error_rows"] += errors
report["summary"]["missing_hash_rows"] += missing_hash
report["summary"]["invalid_payload_rows"] += invalid_payload
out_path = _build_report_path(args.out)
out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"[REPORT] {out_path}", flush=True)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Create performance indexes for integrity verification and run ANALYZE.
Usage:
python -m scripts.tune_integrity_indexes
python -m scripts.tune_integrity_indexes --dry-run
"""
from __future__ import annotations
import argparse
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Sequence, Set, Tuple
import psycopg2
from psycopg2 import sql
from config.settings import AppConfig
TIME_CANDIDATES = (
"pay_time",
"create_time",
"start_use_time",
"scd2_start_time",
"calc_time",
"order_date",
"fetched_at",
)
@dataclass(frozen=True)
class IndexPlan:
schema: str
table: str
index_name: str
columns: Tuple[str, ...]
def _short_index_name(table: str, tag: str, columns: Sequence[str]) -> str:
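# PostgreSQL truncates identifiers longer than 63 characters; fall back to an md5-suffixed short name.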
raw = f"idx_{table}_{tag}_{'_'.join(columns)}"
if len(raw) <= 63:
return raw
digest = hashlib.md5(raw.encode("utf-8")).hexdigest()[:8]
shortened = f"idx_{table}_{tag}_{digest}"
return shortened[:63]
def _load_table_columns(cur, schema: str, table: str) -> Set[str]:
cur.execute(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
""",
(schema, table),
)
return {r[0] for r in cur.fetchall()}
def _load_pk_columns(cur, schema: str, table: str) -> List[str]:
cur.execute(
"""
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
AND tc.table_name = kcu.table_name
WHERE tc.table_schema = %s
AND tc.table_name = %s
AND tc.constraint_type = 'PRIMARY KEY'
ORDER BY kcu.ordinal_position
""",
(schema, table),
)
return [r[0] for r in cur.fetchall()]
def _load_tables(cur, schema: str) -> List[str]:
cur.execute(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s
AND table_type = 'BASE TABLE'
ORDER BY table_name
""",
(schema,),
)
return [r[0] for r in cur.fetchall()]
def _plan_indexes(cur, schema: str, table: str) -> List[IndexPlan]:
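# ODS tables: fetched_at (and fetched_at + PK) indexes; DWD tables: PK + scd2_is_current plus time-column (and time + PK) indexes.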
plans: List[IndexPlan] = []
cols = _load_table_columns(cur, schema, table)
pk_cols = _load_pk_columns(cur, schema, table)
if schema == "billiards_ods":
if "fetched_at" in cols:
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "fetched_at", ("fetched_at",)),
columns=("fetched_at",),
)
)
if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
comp_cols = ("fetched_at", *pk_cols)
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "fetched_pk", comp_cols),
columns=comp_cols,
)
)
if schema == "billiards_dwd":
if pk_cols and "scd2_is_current" in cols and len(pk_cols) <= 4:
comp_cols = (*pk_cols, "scd2_is_current")
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "pk_current", comp_cols),
columns=comp_cols,
)
)
for tcol in TIME_CANDIDATES:
if tcol in cols:
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "time", (tcol,)),
columns=(tcol,),
)
)
if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
comp_cols = (tcol, *pk_cols)
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "time_pk", comp_cols),
columns=comp_cols,
)
)
# 按索引名去重
dedup: Dict[str, IndexPlan] = {}
for p in plans:
dedup[p.index_name] = p
return list(dedup.values())
def _create_index(cur, plan: IndexPlan) -> None:
stmt = sql.SQL("CREATE INDEX IF NOT EXISTS {idx} ON {sch}.{tbl} ({cols})").format(
idx=sql.Identifier(plan.index_name),
sch=sql.Identifier(plan.schema),
tbl=sql.Identifier(plan.table),
cols=sql.SQL(", ").join(sql.Identifier(c) for c in plan.columns),
)
cur.execute(stmt)
def _analyze_table(cur, schema: str, table: str) -> None:
stmt = sql.SQL("ANALYZE {sch}.{tbl}").format(
sch=sql.Identifier(schema),
tbl=sql.Identifier(table),
)
cur.execute(stmt)
def main() -> int:
ap = argparse.ArgumentParser(description="Tune indexes for integrity verification.")
ap.add_argument("--dry-run", action="store_true", help="Print planned SQL only.")
ap.add_argument(
"--skip-analyze",
action="store_true",
help="Create indexes but skip ANALYZE.",
)
args = ap.parse_args()
cfg = AppConfig.load({})
dsn = cfg.get("db.dsn")
timeout_sec = int(cfg.get("db.connect_timeout_sec", 10) or 10)
with psycopg2.connect(dsn, connect_timeout=timeout_sec) as conn:
conn.autocommit = False
with conn.cursor() as cur:
all_plans: List[IndexPlan] = []
for schema in ("billiards_ods", "billiards_dwd"):
for table in _load_tables(cur, schema):
all_plans.extend(_plan_indexes(cur, schema, table))
touched_tables: Set[Tuple[str, str]] = set()
print(f"planned indexes: {len(all_plans)}")
for plan in all_plans:
cols = ", ".join(plan.columns)
print(f"[INDEX] {plan.schema}.{plan.table} ({cols}) -> {plan.index_name}")
if not args.dry_run:
_create_index(cur, plan)
touched_tables.add((plan.schema, plan.table))
if not args.skip_analyze:
if args.dry_run:
for schema, table in sorted({(p.schema, p.table) for p in all_plans}):
print(f"[ANALYZE] {schema}.{table}")
else:
for schema, table in sorted(touched_tables):
_analyze_table(cur, schema, table)
print(f"[ANALYZE] {schema}.{table}")
if args.dry_run:
conn.rollback()
print("dry-run complete; transaction rolled back")
else:
conn.commit()
print("index tuning complete")
return 0
if __name__ == "__main__":
raise SystemExit(main())

26
scripts/run_ods.bat Normal file
View File

@@ -0,0 +1,26 @@
@echo off
REM -*- coding: utf-8 -*-
REM 说明:一键重建 ODS执行 INIT_ODS_SCHEMA,并灌入示例 JSON执行 MANUAL_INGEST
setlocal
cd /d "%~dp0\.."
REM 如果需要覆盖示例目录,可修改下面的 INGEST_DIR
set "INGEST_DIR=export\\test-json-doc"
echo [INIT_ODS_SCHEMA] 准备执行,源目录=%INGEST_DIR%
python -m cli.main --tasks INIT_ODS_SCHEMA --pipeline-flow INGEST_ONLY --ingest-source "%INGEST_DIR%"
if errorlevel 1 (
echo INIT_ODS_SCHEMA 失败,退出
exit /b 1
)
echo [MANUAL_INGEST] 准备执行,源目录=%INGEST_DIR%
python -m cli.main --tasks MANUAL_INGEST --pipeline-flow INGEST_ONLY --ingest-source "%INGEST_DIR%"
if errorlevel 1 (
echo MANUAL_INGEST 失败,退出
exit /b 1
)
echo 全部完成。
endlocal

516
scripts/run_update.py Normal file
View File

@@ -0,0 +1,516 @@
# -*- coding: utf-8 -*-
"""
一键增量更新脚本ODS -> DWD -> DWS
用法:
python scripts/run_update.py
"""
from __future__ import annotations
import argparse
import logging
import multiprocessing as mp
import subprocess
import sys
import time as time_mod
from datetime import date, datetime, time, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
# 确保直接以 "python scripts/run_update.py" 运行时能导入项目根目录下的包
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from orchestration.scheduler import ETLScheduler
from tasks.utility.check_cutoff_task import CheckCutoffTask
from tasks.dwd.dwd_load_task import DwdLoadTask
from tasks.ods.ods_tasks import ENABLED_ODS_CODES
from utils.logging_utils import build_log_path, configure_logging
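# 单步默认超时(秒);个别步骤(如 ODS_GAP_CHECK可通过 step["timeout_sec"] 覆盖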
STEP_TIMEOUT_SEC = 120
def _coerce_date(s: str) -> date:
s = (s or "").strip()
if not s:
raise ValueError("empty date")
if len(s) >= 10:
s = s[:10]
return date.fromisoformat(s)
def _compute_dws_window(
*,
cfg: AppConfig,
tz: ZoneInfo,
rebuild_days: int,
bootstrap_days: int,
dws_start: date | None,
dws_end: date | None,
) -> tuple[datetime, datetime]:
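# 未显式指定 dws_start 时,取 dws_order_summary 最新 order_date 再回退 rebuild_days 天;表为空则从当天回退 bootstrap_days 天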
if dws_start and dws_end and dws_end < dws_start:
raise ValueError("dws_end must be >= dws_start")
store_id = int(cfg.get("app.store_id"))
dsn = cfg["db"]["dsn"]
session = cfg["db"].get("session")
conn = DatabaseConnection(dsn=dsn, session=session)
try:
if dws_start is None:
row = conn.query(
"SELECT MAX(order_date) AS mx FROM billiards_dws.dws_order_summary WHERE site_id=%s",
(store_id,),
)
mx = (row[0] or {}).get("mx") if row else None
if isinstance(mx, date):
dws_start = mx - timedelta(days=max(0, int(rebuild_days)))
else:
dws_start = (datetime.now(tz).date()) - timedelta(days=max(1, int(bootstrap_days)))
if dws_end is None:
dws_end = datetime.now(tz).date()
finally:
conn.close()
start_dt = datetime.combine(dws_start, time.min).replace(tzinfo=tz)
# end_dt 取到当天 23:59:59避免只按"当前时刻"的 date() 截断导致少跑一天)
end_dt = datetime.combine(dws_end, time.max).replace(tzinfo=tz)
return start_dt, end_dt
def _run_check_cutoff(cfg: AppConfig, logger: logging.Logger):
dsn = cfg["db"]["dsn"]
session = cfg["db"].get("session")
db_conn = DatabaseConnection(dsn=dsn, session=session)
db_ops = DatabaseOperations(db_conn)
api = APIClient(
base_url=cfg["api"]["base_url"],
token=cfg["api"]["token"],
timeout=cfg["api"]["timeout_sec"],
retry_max=cfg["api"]["retries"]["max_attempts"],
headers_extra=cfg["api"].get("headers_extra"),
)
try:
CheckCutoffTask(cfg, db_ops, api, logger).execute(None)
finally:
db_conn.close()
def _iter_daily_windows(window_start: datetime, window_end: datetime) -> list[tuple[datetime, datetime]]:
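# 将整体窗口按自然日切分为 [00:00:00, 23:59:59] 的子窗口,首尾两天裁剪到 window_start / window_end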
if window_start > window_end:
return []
tz = window_start.tzinfo
windows: list[tuple[datetime, datetime]] = []
cur = window_start
while cur <= window_end:
day_start = datetime.combine(cur.date(), time.min).replace(tzinfo=tz)
day_end = datetime.combine(cur.date(), time.max).replace(tzinfo=tz)
if day_start < window_start:
day_start = window_start
if day_end > window_end:
day_end = window_end
windows.append((day_start, day_end))
next_day = cur.date() + timedelta(days=1)
cur = datetime.combine(next_day, time.min).replace(tzinfo=tz)
return windows
def _run_step_worker(result_queue: "mp.Queue[dict[str, str]]", step: dict[str, str]) -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
log_file = step.get("log_file") or ""
log_level = step.get("log_level") or "INFO"
log_console = bool(step.get("log_console", True))
log_path = Path(log_file) if log_file else None
with configure_logging(
"etl_update",
log_path,
level=log_level,
console=log_console,
tee_std=True,
) as logger:
cfg_base = AppConfig.load({})
step_type = step.get("type", "")
try:
if step_type == "check_cutoff":
_run_check_cutoff(cfg_base, logger)
elif step_type == "ods_task":
task_code = step["task_code"]
overlap_seconds = int(step.get("overlap_seconds", 0))
cfg_ods = AppConfig.load(
{
"pipeline": {"flow": "FULL"},
"run": {"tasks": [task_code], "overlap_seconds": overlap_seconds},
}
)
scheduler = ETLScheduler(cfg_ods, logger)
try:
scheduler.run_tasks([task_code])
finally:
scheduler.close()
elif step_type == "init_dws_schema":
overlap_seconds = int(step.get("overlap_seconds", 0))
cfg_dwd = AppConfig.load(
{
"pipeline": {"flow": "INGEST_ONLY"},
"run": {"tasks": ["INIT_DWS_SCHEMA"], "overlap_seconds": overlap_seconds},
}
)
scheduler = ETLScheduler(cfg_dwd, logger)
try:
scheduler.run_tasks(["INIT_DWS_SCHEMA"])
finally:
scheduler.close()
elif step_type == "dwd_table":
dwd_table = step["dwd_table"]
overlap_seconds = int(step.get("overlap_seconds", 0))
cfg_dwd = AppConfig.load(
{
"pipeline": {"flow": "INGEST_ONLY"},
"run": {"tasks": ["DWD_LOAD_FROM_ODS"], "overlap_seconds": overlap_seconds},
"dwd": {"only_tables": [dwd_table]},
}
)
scheduler = ETLScheduler(cfg_dwd, logger)
try:
scheduler.run_tasks(["DWD_LOAD_FROM_ODS"])
finally:
scheduler.close()
elif step_type == "dws_window":
overlap_seconds = int(step.get("overlap_seconds", 0))
window_start = step["window_start"]
window_end = step["window_end"]
cfg_dws = AppConfig.load(
{
"pipeline": {"flow": "INGEST_ONLY"},
"run": {
"tasks": ["DWS_BUILD_ORDER_SUMMARY"],
"overlap_seconds": overlap_seconds,
"window_override": {"start": window_start, "end": window_end},
},
}
)
scheduler = ETLScheduler(cfg_dws, logger)
try:
scheduler.run_tasks(["DWS_BUILD_ORDER_SUMMARY"])
finally:
scheduler.close()
elif step_type == "ods_gap_check":
overlap_hours = int(step.get("overlap_hours", 24))
window_days = int(step.get("window_days", 1))
window_hours = int(step.get("window_hours", 0))
page_size = int(step.get("page_size", 0) or 0)
sleep_per_window = float(step.get("sleep_per_window", 0) or 0)
sleep_per_page = float(step.get("sleep_per_page", 0) or 0)
tag = step.get("tag", "run_update")
task_codes = (step.get("task_codes") or "").strip()
project_root = Path(__file__).resolve().parent.parent
script_path = project_root / "scripts" / "check" / "check_ods_gaps.py"
cmd = [
sys.executable,
str(script_path),
"--from-cutoff",
"--cutoff-overlap-hours",
str(overlap_hours),
"--window-days",
str(window_days),
"--tag",
str(tag),
]
if window_hours > 0:
cmd += ["--window-hours", str(window_hours)]
if page_size > 0:
cmd += ["--page-size", str(page_size)]
if sleep_per_window > 0:
cmd += ["--sleep-per-window-seconds", str(sleep_per_window)]
if sleep_per_page > 0:
cmd += ["--sleep-per-page-seconds", str(sleep_per_page)]
if task_codes:
cmd += ["--task-codes", task_codes]
subprocess.run(cmd, check=True, cwd=str(project_root))
else:
raise ValueError(f"Unknown step type: {step_type}")
result_queue.put({"status": "ok"})
except Exception as exc:
result_queue.put({"status": "error", "error": str(exc)})
def _run_step_with_timeout(
step: dict[str, str], logger: logging.Logger, timeout_sec: int
) -> dict[str, object]:
start = time_mod.monotonic()
step_timeout = timeout_sec
if step.get("timeout_sec"):
try:
step_timeout = int(step.get("timeout_sec"))
except Exception:
step_timeout = timeout_sec
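# 每个步骤在独立的 spawn 子进程中执行,超时即 terminate避免单步卡死拖垮整个更新流程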
ctx = mp.get_context("spawn")
result_queue: mp.Queue = ctx.Queue()
proc = ctx.Process(target=_run_step_worker, args=(result_queue, step))
proc.start()
proc.join(timeout=step_timeout)
elapsed = time_mod.monotonic() - start
if proc.is_alive():
logger.error(
"STEP_TIMEOUT name=%s elapsed=%.2fs limit=%ss", step["name"], elapsed, step_timeout
)
proc.terminate()
proc.join(10)
return {"name": step["name"], "status": "timeout", "elapsed": elapsed}
result: dict[str, object] = {"name": step["name"], "status": "error", "elapsed": elapsed}
try:
payload = result_queue.get_nowait()
except Exception:
payload = {}
if payload:
result.update(payload)
if result.get("status") == "ok":
logger.info("STEP_OK name=%s elapsed=%.2fs", step["name"], elapsed)
else:
logger.error(
"STEP_FAIL name=%s elapsed=%.2fs error=%s",
step["name"],
elapsed,
result.get("error"),
)
return result
def main() -> int:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
parser = argparse.ArgumentParser(description="One-click ETL update (ODS -> DWD -> DWS)")
parser.add_argument("--overlap-seconds", type=int, default=3600, help="overlap seconds (default: 3600)")
parser.add_argument(
"--dws-rebuild-days",
type=int,
default=1,
help="DWS 回算冗余天数default: 1",
)
parser.add_argument(
"--dws-bootstrap-days",
type=int,
default=30,
help="DWS 首次/空表时回算天数default: 30",
)
parser.add_argument("--dws-start", type=str, default="", help="DWS 回算开始日期 YYYY-MM-DD可选")
parser.add_argument("--dws-end", type=str, default="", help="DWS 回算结束日期 YYYY-MM-DD可选")
parser.add_argument(
"--skip-cutoff",
action="store_true",
help="跳过 CHECK_CUTOFF默认会在开始/结束各跑一次)",
)
parser.add_argument(
"--skip-ods",
action="store_true",
help="跳过 ODS 在线抓取(仅跑 DWD/DWS",
)
parser.add_argument(
"--ods-tasks",
type=str,
default="",
help="指定要跑的 ODS 任务(逗号分隔),默认跑全部 ENABLED_ODS_CODES",
)
parser.add_argument(
"--check-ods-gaps",
action="store_true",
help="run ODS gap check after ODS load (default: off)",
)
parser.add_argument(
"--check-ods-overlap-hours",
type=int,
default=24,
help="gap check overlap hours from cutoff (default: 24)",
)
parser.add_argument(
"--check-ods-window-days",
type=int,
default=1,
help="gap check window days (default: 1)",
)
parser.add_argument(
"--check-ods-window-hours",
type=int,
default=0,
help="gap check window hours (default: 0)",
)
parser.add_argument(
"--check-ods-page-size",
type=int,
default=200,
help="gap check API page size (default: 200)",
)
parser.add_argument(
"--check-ods-timeout-sec",
type=int,
default=1800,
help="gap check timeout seconds (default: 1800)",
)
parser.add_argument(
"--check-ods-task-codes",
type=str,
default="",
help="gap check task codes (comma-separated, optional)",
)
parser.add_argument(
"--check-ods-sleep-per-window-seconds",
type=float,
default=0,
help="gap check sleep seconds after each window (default: 0)",
)
parser.add_argument(
"--check-ods-sleep-per-page-seconds",
type=float,
default=0,
help="gap check sleep seconds after each page (default: 0)",
)
parser.add_argument("--log-file", type=str, default="", help="log file path (default: logs/run_update_YYYYMMDD_HHMMSS.log)")
parser.add_argument("--log-dir", type=str, default="", help="log directory (default: logs)")
parser.add_argument("--log-level", type=str, default="INFO", help="log level (default: INFO)")
parser.add_argument("--no-log-console", action="store_true", help="disable console logging")
args = parser.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (Path(__file__).resolve().parent.parent / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "run_update")
log_console = not args.no_log_console
with configure_logging(
"etl_update",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg_base = AppConfig.load({})
tz = ZoneInfo(cfg_base.get("app.timezone", "Asia/Taipei"))
dws_start = _coerce_date(args.dws_start) if args.dws_start else None
dws_end = _coerce_date(args.dws_end) if args.dws_end else None
steps: list[dict[str, str]] = []
if not args.skip_cutoff:
steps.append({"name": "CHECK_CUTOFF:before", "type": "check_cutoff"})
# ------------------------------------------------------------------ ODS在线抓取 + 写入)
if not args.skip_ods:
if args.ods_tasks:
ods_tasks = [t.strip().upper() for t in args.ods_tasks.split(",") if t.strip()]
else:
ods_tasks = sorted(ENABLED_ODS_CODES)
for task_code in ods_tasks:
steps.append(
{
"name": f"ODS:{task_code}",
"type": "ods_task",
"task_code": task_code,
"overlap_seconds": str(args.overlap_seconds),
}
)
if args.check_ods_gaps:
steps.append(
{
"name": "ODS_GAP_CHECK",
"type": "ods_gap_check",
"overlap_hours": str(args.check_ods_overlap_hours),
"window_days": str(args.check_ods_window_days),
"window_hours": str(args.check_ods_window_hours),
"page_size": str(args.check_ods_page_size),
"sleep_per_window": str(args.check_ods_sleep_per_window_seconds),
"sleep_per_page": str(args.check_ods_sleep_per_page_seconds),
"timeout_sec": str(args.check_ods_timeout_sec),
"task_codes": str(args.check_ods_task_codes or ""),
"tag": "run_update",
}
)
# ------------------------------------------------------------------ DWD从 ODS 表装载)
steps.append(
{
"name": "INIT_DWS_SCHEMA",
"type": "init_dws_schema",
"overlap_seconds": str(args.overlap_seconds),
}
)
for dwd_table in DwdLoadTask.TABLE_MAP.keys():
steps.append(
{
"name": f"DWD:{dwd_table}",
"type": "dwd_table",
"dwd_table": dwd_table,
"overlap_seconds": str(args.overlap_seconds),
}
)
# ------------------------------------------------------------------ DWS按日期窗口重建
window_start, window_end = _compute_dws_window(
cfg=cfg_base,
tz=tz,
rebuild_days=int(args.dws_rebuild_days),
bootstrap_days=int(args.dws_bootstrap_days),
dws_start=dws_start,
dws_end=dws_end,
)
for start_dt, end_dt in _iter_daily_windows(window_start, window_end):
steps.append(
{
"name": f"DWS:{start_dt.date().isoformat()}",
"type": "dws_window",
"window_start": start_dt.strftime("%Y-%m-%d %H:%M:%S"),
"window_end": end_dt.strftime("%Y-%m-%d %H:%M:%S"),
"overlap_seconds": str(args.overlap_seconds),
}
)
if not args.skip_cutoff:
steps.append({"name": "CHECK_CUTOFF:after", "type": "check_cutoff"})
for step in steps:
step["log_file"] = str(log_file)
step["log_level"] = args.log_level
step["log_console"] = log_console
step_results: list[dict[str, object]] = []
for step in steps:
logger.info("STEP_START name=%s timeout=%ss", step["name"], STEP_TIMEOUT_SEC)
result = _run_step_with_timeout(step, logger, STEP_TIMEOUT_SEC)
step_results.append(result)
total = len(step_results)
ok_count = sum(1 for r in step_results if r.get("status") == "ok")
timeout_count = sum(1 for r in step_results if r.get("status") == "timeout")
fail_count = total - ok_count - timeout_count
logger.info(
"STEP_SUMMARY total=%s ok=%s failed=%s timeout=%s",
total,
ok_count,
fail_count,
timeout_count,
)
for item in sorted(step_results, key=lambda r: float(r.get("elapsed", 0.0)), reverse=True):
logger.info(
"STEP_RESULT name=%s status=%s elapsed=%.2fs",
item.get("name"),
item.get("status"),
item.get("elapsed", 0.0),
)
logger.info("Update done.")
return 0
if __name__ == "__main__":
raise SystemExit(main())