init: 项目初始提交 - NeoZQYY Monorepo 完整代码

This commit is contained in:
Neo
2026-02-15 14:58:14 +08:00
commit ded6dfb9d8
769 changed files with 182616 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
# scripts/ — Operations and Tooling Scripts
## Subdirectories
| Directory | Purpose | Typical usage |
|------|------|----------|
| `audit/` | Repository audit (file inventory, call flow, documentation-alignment analysis) | `python -m scripts.audit.run_audit` |
| `check/` | Data checks (ODS gaps, content hashes, integrity validation) | `python -m scripts.check.check_data_integrity` |
| `db_admin/` | Database administration (Excel import into DWS expense/repayment/commission) | `python scripts/db_admin/import_dws_excel.py --type expense` |
| `export/` | Data export (index, group-buy, intimacy, member detail tables, etc.) | `python scripts/export/export_index_tables.py` |
| `rebuild/` | Data rebuild (full ODS→DWD rebuild) | `python scripts/rebuild/rebuild_db_and_run_ods_to_dwd.py` |
| `repair/` | Data repair (backfill, deduplication, hash repair, dimension repair) | `python scripts/repair/dedupe_ods_snapshots.py` |
## Root-level scripts
- `run_update.py` — one-shot incremental update (ODS → DWD → DWS), suitable for cron / scheduled-task invocation
- `run_ods.bat` — Windows batch script: ODS table creation + sample JSON load
- `compare_ddl_db.py` — compares DDL files against the actual database table structures (supports `--all` to compare all four schemas)
- `validate_bd_manual.py` — validates the BD_Manual documentation set (coverage, format, naming conventions)
## How to run
All scripts are executed from the project root (`C:\ZQYY\FQ-ETL`):
```bash
# Generate the audit reports
python -m scripts.audit.run_audit
# One-shot incremental update
python scripts/run_update.py
# Data integrity check (requires a database connection)
python -m scripts.check.check_data_integrity --window-start "2025-01-01" --window-end "2025-02-01"
```
## Notes
- All scripts depend on the `PG_DSN` setting in `.env` (or the equivalent environment variable)
- Scripts under `rebuild/` recreate schemas; use with caution in production
- Scripts under `repair/` modify data; run with `--dry-run` first where supported
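A minimal connection sketch (an assumption, not documented here: the scripts resolve `PG_DSN` through `python-dotenv` and connect with `psycopg2`; the snippet is illustrative and not part of the repo):
```python
# Illustrative only: rough sketch of how PG_DSN might be resolved and used.
import os

import psycopg2
from dotenv import load_dotenv

load_dotenv()                      # picks up .env from the project root
dsn = os.environ["PG_DSN"]         # fail fast if the DSN is not configured

with psycopg2.connect(dsn) as conn, conn.cursor() as cur:
    cur.execute("SELECT 1")
    print(cur.fetchone())
```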

View File

@@ -0,0 +1 @@
# Package marker for the scripts tooling package.

View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""
Repository-governance read-only audit — shared data models.
Defines the dataclasses and enum types shared by the audit script modules.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
# ---------------------------------------------------------------------------
# File metadata
# ---------------------------------------------------------------------------
@dataclass
class FileEntry:
    """Metadata for a single file or directory."""
    rel_path: str        # path relative to the repository root
    is_dir: bool         # whether this entry is a directory
    size_bytes: int      # file size in bytes (0 for directories)
    extension: str       # lowercase file extension, including the dot
    is_empty_dir: bool   # whether this is an empty directory
# ---------------------------------------------------------------------------
# Purpose categories and disposition labels
# ---------------------------------------------------------------------------
class Category(str, Enum):
    """File purpose category."""
CORE_CODE = "核心代码"
CONFIG = "配置"
DATABASE_DEF = "数据库定义"
TEST = "测试"
DOCS = "文档"
SCRIPTS = "脚本工具"
GUI = "GUI"
BUILD_DEPLOY = "构建与部署"
LOG_OUTPUT = "日志与输出"
TEMP_DEBUG = "临时与调试"
OTHER = "其他"
class Disposition(str, Enum):
    """Disposition label."""
KEEP = "保留"
CANDIDATE_DELETE = "候选删除"
CANDIDATE_ARCHIVE = "候选归档"
NEEDS_REVIEW = "待确认"
# ---------------------------------------------------------------------------
# Inventory entry
# ---------------------------------------------------------------------------
@dataclass
class InventoryItem:
    """Inventory entry: path + category + disposition + description."""
rel_path: str
category: Category
disposition: Disposition
description: str
# ---------------------------------------------------------------------------
# Flow tree node
# ---------------------------------------------------------------------------
@dataclass
class FlowNode:
    """Node in the flow tree."""
    name: str            # node name (module/class/function name)
    source_file: str     # path of the source file the node belongs to
    node_type: str       # node type: entry / module / class / function
children: list[FlowNode] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Documentation alignment
# ---------------------------------------------------------------------------
@dataclass
class DocMapping:
    """Mapping between a document and code."""
    doc_path: str            # documentation file path
    doc_topic: str           # document topic
    related_code: list[str]  # related code files/modules
    status: str              # status: aligned / stale / conflict / orphan
@dataclass
class AlignmentIssue:
    """Alignment issue."""
    doc_path: str        # documentation path
    issue_type: str      # stale / conflict / missing
    description: str     # issue description
    related_code: str    # related code path
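A minimal usage sketch of these shared models (assuming they are re-exported from the `scripts.audit` package, as the analyzer modules' imports suggest; the paths are illustrative):
```python
# Illustrative sketch; the file paths below are hypothetical.
from scripts.audit import Category, Disposition, FileEntry, FlowNode, InventoryItem

entry = FileEntry(rel_path="tasks/dwd/fact_order.py", is_dir=False,
                  size_bytes=2048, extension=".py", is_empty_dir=False)
item = InventoryItem(rel_path=entry.rel_path, category=Category.CORE_CODE,
                     disposition=Disposition.KEEP, description="core ETL task")
root = FlowNode(name="cli.main", source_file="cli/main.py", node_type="entry",
                children=[FlowNode("config.settings", "config/settings.py", "module")])
print(item.category.value, item.disposition.value, len(root.children))
```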

View File

@@ -0,0 +1,608 @@
# -*- coding: utf-8 -*-
"""
Documentation alignment analyzer — checks the mapping between documentation and code,
and reports stale, conflicting, and missing points.
Documentation sources:
- the docs/ directory (.md, .txt, .csv, .json)
- the root-level README.md
- README.md files inside individual modules
- .kiro/steering/ steering files
- docs/test-json-doc/ API response samples
"""
from __future__ import annotations
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from scripts.audit import AlignmentIssue, DocMapping
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Documentation file extensions
_DOC_EXTENSIONS = {".md", ".txt", ".csv"}
# Core code directories — report them when documentation is missing
_CORE_CODE_DIRS = {
"tasks",
"loaders",
"orchestration",
"quality",
"models",
"utils",
"api",
"scd",
"config",
"database",
}
# Common metadata columns in ODS tables, ignored during comparison
_ODS_META_COLUMNS = {"content_hash", "payload", "created_at", "updated_at", "id"}
# SQL keywords excluded when parsing column names from DDL
_SQL_KEYWORDS = {
"primary", "key", "not", "null", "default", "unique", "check",
"references", "foreign", "constraint", "index", "create", "table",
"if", "exists", "serial", "bigserial", "true", "false",
}
# ---------------------------------------------------------------------------
# Safe file reading (encoding fallback)
# ---------------------------------------------------------------------------
def _safe_read(path: Path) -> str:
    """Read a file's content, falling back through utf-8 → gbk → latin-1."""
for enc in ("utf-8", "gbk", "latin-1"):
try:
return path.read_text(encoding=enc)
except (UnicodeDecodeError, UnicodeError):
continue
return ""
# ---------------------------------------------------------------------------
# scan_docs — scan all documentation sources
# ---------------------------------------------------------------------------
def scan_docs(repo_root: Path) -> list[str]:
    """Scan all documentation files and return a sorted list of relative paths.
    Documentation sources:
    1. .md, .txt, .csv and .json files under docs/
    2. the root-level README.md
    3. README.md inside each module (e.g. gui/README.md)
    4. .kiro/steering/ steering files
    """
results: list[str] = []
    def _rel(p: Path) -> str:
        """Return the normalized forward-slash relative path."""
return str(p.relative_to(repo_root)).replace("\\", "/")
# 1. docs/ 目录(递归,含 test-json-doc 下的 .json
docs_dir = repo_root / "docs"
if docs_dir.is_dir():
for p in docs_dir.rglob("*"):
if p.is_file():
ext = p.suffix.lower()
if ext in _DOC_EXTENSIONS or ext == ".json":
results.append(_rel(p))
# 2. 根目录 README.md
root_readme = repo_root / "README.md"
if root_readme.is_file():
results.append("README.md")
# 3. 各模块内的 README.md
for child in sorted(repo_root.iterdir()):
if child.is_dir() and child.name not in ("docs", ".kiro"):
readme = child / "README.md"
if readme.is_file():
results.append(_rel(readme))
# 4. .kiro/steering/
steering_dir = repo_root / ".kiro" / "steering"
if steering_dir.is_dir():
for p in sorted(steering_dir.iterdir()):
if p.is_file():
results.append(_rel(p))
return sorted(set(results))
# ---------------------------------------------------------------------------
# extract_code_references — extract code references from documents
# ---------------------------------------------------------------------------
def extract_code_references(doc_path: Path) -> list[str]:
    """Extract code references from a document (file paths, class names, function names inside backticks).
    Rules:
    - extract the content of backticked spans
    - skip single-character references
    - skip pure numbers / version strings
    - normalize backslashes to forward slashes
    - deduplicate
    """
if not doc_path.is_file():
return []
text = _safe_read(doc_path)
if not text:
return []
# 提取反引号内容
backtick_refs = re.findall(r"`([^`]+)`", text)
seen: set[str] = set()
results: list[str] = []
for raw in backtick_refs:
ref = raw.strip()
        # normalize backslashes to forward slashes
        ref = ref.replace("\\", "/")
        # skip single-character references
        if len(ref) <= 1:
            continue
        # skip pure numbers and version strings
        if re.fullmatch(r"[\d.]+", ref):
            continue
        # deduplicate
if ref in seen:
continue
seen.add(ref)
results.append(ref)
return results
# ---------------------------------------------------------------------------
# check_reference_validity — check whether a reference is still valid
# ---------------------------------------------------------------------------
def check_reference_validity(ref: str, repo_root: Path) -> bool:
    """Check whether a code reference found in the docs is still valid.
    Strategy:
    1. check it directly as a file/directory path
    2. check again after stripping the FQ-ETL/ prefix (legacy doc references)
    3. convert a dotted module path to a file path and check (e.g. config.settings → config/settings.py)
    """
# 1. 直接路径
if (repo_root / ref).exists():
return True
# 2. 去掉旧包名前缀(兼容历史文档)
for prefix in ("FQ-ETL/", "etl_billiards/"):
if ref.startswith(prefix):
stripped = ref[len(prefix):]
if (repo_root / stripped).exists():
return True
# 3. 点号模块路径 → 文件路径
if "." in ref and "/" not in ref:
as_path = ref.replace(".", "/") + ".py"
if (repo_root / as_path).exists():
return True
# 也可能是目录(包)
as_dir = ref.replace(".", "/")
if (repo_root / as_dir).is_dir():
return True
return False
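# Illustrative examples (hypothetical references, assuming the targets exist under repo_root):
#   check_reference_validity("config/settings.py", repo_root)  -> True (direct path)
#   check_reference_validity("config.settings", repo_root)     -> True (dotted module -> config/settings.py)
#   check_reference_validity("FQ-ETL/cli/main.py", repo_root)  -> True (legacy prefix stripped first)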
# ---------------------------------------------------------------------------
# find_undocumented_modules — find core code modules lacking documentation
# ---------------------------------------------------------------------------
def find_undocumented_modules(
repo_root: Path,
documented: set[str],
) -> list[str]:
    """Find core code modules that lack documentation.
    Only .py files under _CORE_CODE_DIRS are checked (excluding __init__.py).
    Returns a sorted list of relative paths.
    """
undocumented: list[str] = []
for core_dir in sorted(_CORE_CODE_DIRS):
dir_path = repo_root / core_dir
if not dir_path.is_dir():
continue
for py_file in dir_path.rglob("*.py"):
if py_file.name == "__init__.py":
continue
rel = str(py_file.relative_to(repo_root))
# 归一化路径分隔符
rel = rel.replace("\\", "/")
if rel not in documented:
undocumented.append(rel)
return sorted(undocumented)
# ---------------------------------------------------------------------------
# DDL / data dictionary parsing helpers
# ---------------------------------------------------------------------------
def _parse_ddl_tables(sql: str) -> dict[str, set[str]]:
    """Extract table and column names from DDL SQL.
    Returns a {table_name: {column names}} dict.
    Schema-qualified table names are supported (e.g. billiards_dwd.dim_member → dim_member).
    """
tables: dict[str, set[str]] = {}
# 匹配 CREATE TABLE [IF NOT EXISTS] [schema.]table_name (
create_re = re.compile(
r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
r"(?:\w+\.)?(\w+)\s*\(",
re.IGNORECASE,
)
for match in create_re.finditer(sql):
table_name = match.group(1)
# 找到对应的括号内容
start = match.end()
depth = 1
pos = start
while pos < len(sql) and depth > 0:
if sql[pos] == "(":
depth += 1
elif sql[pos] == ")":
depth -= 1
pos += 1
body = sql[start:pos - 1]
columns: set[str] = set()
# 逐行提取列名——取每行第一个标识符
for line in body.split("\n"):
line = line.strip().rstrip(",")
if not line:
continue
# 提取第一个单词
col_match = re.match(r"(\w+)", line)
if col_match:
col_name = col_match.group(1).lower()
# 排除 SQL 关键字
if col_name not in _SQL_KEYWORDS:
columns.add(col_name)
tables[table_name] = columns
return tables
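# Worked example (illustrative input, not executed at import time):
#   _parse_ddl_tables("CREATE TABLE IF NOT EXISTS billiards_dwd.dim_member (\n"
#                     "    member_id BIGINT PRIMARY KEY,\n    member_name TEXT NOT NULL\n);")
#   -> {"dim_member": {"member_id", "member_name"}}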
def _parse_dictionary_tables(md: str) -> dict[str, set[str]]:
    """Extract table and field names from a data dictionary Markdown file.
    Conventions:
    - table names appear in ## headings (possibly backticked)
    - field names appear in the first column of Markdown tables
    - header rows (containing the literal "字段") and separator rows (---) are skipped
    """
tables: dict[str, set[str]] = {}
current_table: str | None = None
for line in md.split("\n"):
# 匹配 ## 标题中的表名
heading = re.match(r"^##\s+`?(\w+)`?", line)
if heading:
current_table = heading.group(1)
tables[current_table] = set()
continue
if current_table is None:
continue
# 跳过分隔行
if re.match(r"^\s*\|[-\s|]+\|\s*$", line):
continue
# 解析表格行
row_match = re.match(r"^\s*\|\s*(\S+)", line)
if row_match:
field = row_match.group(1)
# 跳过表头(含"字段"字样)
if field in ("字段",):
continue
tables[current_table].add(field)
return tables
# ---------------------------------------------------------------------------
# check_ddl_vs_dictionary — compare DDL against the data dictionary
# ---------------------------------------------------------------------------
def check_ddl_vs_dictionary(repo_root: Path) -> list[AlignmentIssue]:
    """Compare coverage between the DDL files and the data dictionary docs.
    Checks:
    1. tables present in the DDL but missing from the dictionary → missing
    2. columns of a shared table present in the DDL but missing from the dictionary → conflict
    """
issues: list[AlignmentIssue] = []
# 收集所有 DDL 表定义
ddl_tables: dict[str, set[str]] = {}
db_dir = repo_root / "database"
if db_dir.is_dir():
for sql_file in sorted(db_dir.glob("schema_*.sql")):
content = _safe_read(sql_file)
for tbl, cols in _parse_ddl_tables(content).items():
if tbl in ddl_tables:
ddl_tables[tbl] |= cols
else:
ddl_tables[tbl] = set(cols)
# 收集所有数据字典表定义
dict_tables: dict[str, set[str]] = {}
docs_dir = repo_root / "docs"
if docs_dir.is_dir():
for dict_file in sorted(docs_dir.glob("*dictionary*.md")):
content = _safe_read(dict_file)
for tbl, fields in _parse_dictionary_tables(content).items():
if tbl in dict_tables:
dict_tables[tbl] |= fields
else:
dict_tables[tbl] = set(fields)
# 比对
for tbl, ddl_cols in sorted(ddl_tables.items()):
if tbl not in dict_tables:
issues.append(AlignmentIssue(
doc_path="docs/*dictionary*.md",
issue_type="missing",
description=f"DDL 定义了表 `{tbl}`,但数据字典中未收录",
related_code=f"database/schema_*.sql ({tbl})",
))
else:
# 检查列差异
dict_cols = dict_tables[tbl]
missing_cols = ddl_cols - dict_cols
for col in sorted(missing_cols):
issues.append(AlignmentIssue(
doc_path="docs/*dictionary*.md",
issue_type="conflict",
description=f"表 `{tbl}` 的列 `{col}` 在 DDL 中存在但数据字典中缺失",
related_code=f"database/schema_*.sql ({tbl}.{col})",
))
return issues
# ---------------------------------------------------------------------------
# check_api_samples_vs_parsers — compare API samples against the ODS structures
# ---------------------------------------------------------------------------
def check_api_samples_vs_parsers(repo_root: Path) -> list[AlignmentIssue]:
    """Check consistency between API response samples and the ODS table structures.
    Strategy:
    1. scan the .json files under docs/test-json-doc/
    2. extract the top-level field names from each JSON sample
    3. look up a table with a matching name in the ODS DDL
    4. compare the field differences (ignoring ODS metadata columns)
    """
issues: list[AlignmentIssue] = []
sample_dir = repo_root / "docs" / "test-json-doc"
if not sample_dir.is_dir():
return issues
# 收集 ODS 表定义(保留全部列,比对时忽略元数据列)
ods_tables: dict[str, set[str]] = {}
db_dir = repo_root / "database"
if db_dir.is_dir():
for sql_file in sorted(db_dir.glob("schema_*ODS*.sql")):
content = _safe_read(sql_file)
for tbl, cols in _parse_ddl_tables(content).items():
ods_tables[tbl] = cols
    # compare each sample file in turn
    for json_file in sorted(sample_dir.glob("*.json")):
        entity_name = json_file.stem  # file name (without extension) used as the entity name
        # parse the JSON sample
try:
content = _safe_read(json_file)
data = json.loads(content)
except (json.JSONDecodeError, ValueError):
continue
# 提取顶层字段名
sample_fields: set[str] = set()
if isinstance(data, list) and data:
# 数组格式——取第一个元素的键
first = data[0]
if isinstance(first, dict):
sample_fields = set(first.keys())
elif isinstance(data, dict):
sample_fields = set(data.keys())
if not sample_fields:
continue
# 查找匹配的 ODS 表
matched_table: str | None = None
matched_cols: set[str] = set()
for tbl, cols in ods_tables.items():
# 表名包含实体名(如 test_entity 匹配 billiards_ods.test_entity
tbl_lower = tbl.lower()
entity_lower = entity_name.lower()
if entity_lower in tbl_lower or tbl_lower == entity_lower:
matched_table = tbl
matched_cols = cols
break
if matched_table is None:
continue
# 比对:样本中有但 ODS 表中没有的字段
extra_fields = sample_fields - matched_cols
for field in sorted(extra_fields):
issues.append(AlignmentIssue(
doc_path=f"docs/test-json-doc/{json_file.name}",
issue_type="conflict",
description=(
f"API 样本字段 `{field}` 在 ODS 表 `{matched_table}` 中未定义"
),
related_code=f"database/schema_*ODS*.sql ({matched_table})",
))
return issues
# ---------------------------------------------------------------------------
# build_mappings — build the mapping between documents and code
# ---------------------------------------------------------------------------
def build_mappings(
doc_paths: list[str],
repo_root: Path,
) -> list[DocMapping]:
    """Build the mapping between each document and its related code modules."""
mappings: list[DocMapping] = []
for doc_rel in doc_paths:
doc_path = repo_root / doc_rel
refs = extract_code_references(doc_path)
# 确定关联代码和状态
valid_refs: list[str] = []
has_stale = False
for ref in refs:
if check_reference_validity(ref, repo_root):
valid_refs.append(ref)
else:
has_stale = True
# 推断文档主题(取文件名或第一行标题)
topic = _infer_topic(doc_path, doc_rel)
if not refs:
status = "orphan"
elif has_stale:
status = "stale"
else:
status = "aligned"
mappings.append(DocMapping(
doc_path=doc_rel,
doc_topic=topic,
related_code=valid_refs,
status=status,
))
return mappings
def _infer_topic(doc_path: Path, doc_rel: str) -> str:
    """Infer the document topic — prefer the first Markdown level-1 heading, else use the file name."""
if doc_path.is_file() and doc_path.suffix.lower() in (".md", ".txt"):
try:
text = _safe_read(doc_path)
for line in text.split("\n"):
line = line.strip()
if line.startswith("# "):
return line[2:].strip()
except Exception:
pass
return doc_rel
# ---------------------------------------------------------------------------
# render_alignment_report — render the documentation alignment report as Markdown
# ---------------------------------------------------------------------------
def render_alignment_report(
mappings: list[DocMapping],
issues: list[AlignmentIssue],
repo_root: str,
) -> str:
    """Render the documentation alignment report in Markdown.
    Sections: mapping table, stale points, conflict points, missing points, summary statistics.
    """
lines: list[str] = []
# --- 头部 ---
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
lines.append("# 文档对齐报告")
lines.append("")
lines.append(f"- 生成时间:{now}")
lines.append(f"- 仓库路径:`{repo_root}`")
lines.append("")
# --- 映射关系 ---
lines.append("## 映射关系")
lines.append("")
if mappings:
lines.append("| 文档路径 | 主题 | 关联代码 | 状态 |")
lines.append("|---|---|---|---|")
for m in mappings:
code_str = ", ".join(f"`{c}`" for c in m.related_code) if m.related_code else ""
lines.append(f"| `{m.doc_path}` | {m.doc_topic} | {code_str} | {m.status} |")
else:
lines.append("未发现文档映射关系。")
lines.append("")
# --- 按 issue_type 分组 ---
stale = [i for i in issues if i.issue_type == "stale"]
conflict = [i for i in issues if i.issue_type == "conflict"]
missing = [i for i in issues if i.issue_type == "missing"]
# --- 过期点 ---
lines.append("## 过期点")
lines.append("")
if stale:
lines.append("| 文档路径 | 描述 | 关联代码 |")
lines.append("|---|---|---|")
for i in stale:
lines.append(f"| `{i.doc_path}` | {i.description} | `{i.related_code}` |")
else:
lines.append("未发现过期点。")
lines.append("")
# --- 冲突点 ---
lines.append("## 冲突点")
lines.append("")
if conflict:
lines.append("| 文档路径 | 描述 | 关联代码 |")
lines.append("|---|---|---|")
for i in conflict:
lines.append(f"| `{i.doc_path}` | {i.description} | `{i.related_code}` |")
else:
lines.append("未发现冲突点。")
lines.append("")
# --- 缺失点 ---
lines.append("## 缺失点")
lines.append("")
if missing:
lines.append("| 文档路径 | 描述 | 关联代码 |")
lines.append("|---|---|---|")
for i in missing:
lines.append(f"| `{i.doc_path}` | {i.description} | `{i.related_code}` |")
else:
lines.append("未发现缺失点。")
lines.append("")
# --- 统计摘要 ---
lines.append("## 统计摘要")
lines.append("")
lines.append(f"- 文档总数:{len(mappings)}")
lines.append(f"- 过期点数量:{len(stale)}")
lines.append(f"- 冲突点数量:{len(conflict)}")
lines.append(f"- 缺失点数量:{len(missing)}")
lines.append("")
return "\n".join(lines)
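A minimal sketch of driving the alignment analyzer on its own, using the functions defined above (assuming execution from the repository root, as the scripts README prescribes):
```python
from pathlib import Path

from scripts.audit.doc_alignment_analyzer import (
    build_mappings, check_ddl_vs_dictionary, render_alignment_report, scan_docs,
)

repo_root = Path(".").resolve()          # assumes the current directory is the repo root
docs = scan_docs(repo_root)
mappings = build_mappings(docs, repo_root)
issues = check_ddl_vs_dictionary(repo_root)
print(render_alignment_report(mappings, issues, str(repo_root))[:400])
```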

View File

@@ -0,0 +1,618 @@
# -*- coding: utf-8 -*-
"""
Flow tree analyzer — builds the call tree from entry points down to leaf modules
by statically analyzing the import statements and class inheritance in the Python sources.
Read-only: it reads and parses Python source files and modifies nothing.
"""
from __future__ import annotations
import ast
import logging
import re
import sys
from datetime import datetime, timezone
from pathlib import Path
from scripts.audit import FileEntry, FlowNode
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Project-internal package names (top-level directories that contain project code)
# ---------------------------------------------------------------------------
_PROJECT_PACKAGES: set[str] = {
"cli", "config", "api", "database", "tasks", "loaders",
"scd", "orchestration", "quality", "models", "utils",
"gui", "scripts",
}
# ---------------------------------------------------------------------------
# Known third-party packages and standard-library top-level modules (used to filter out non-project imports)
# ---------------------------------------------------------------------------
_KNOWN_THIRD_PARTY: set[str] = {
"psycopg2", "requests", "dateutil", "python_dateutil",
"dotenv", "openpyxl", "PySide6", "flask", "pyinstaller",
"PyInstaller", "hypothesis", "pytest", "_pytest", "py",
"pluggy", "pkg_resources", "setuptools", "pip", "wheel",
"tzdata", "six", "certifi", "urllib3", "charset_normalizer",
"idna", "shiboken6",
}
def _is_project_module(module_name: str) -> bool:
    """Return True if the module name belongs to the project's own packages."""
top = module_name.split(".")[0]
if top in _PROJECT_PACKAGES:
return True
return False
def _is_stdlib_or_third_party(module_name: str) -> bool:
    """Return True if the module name is a stdlib module or a known third-party package."""
top = module_name.split(".")[0]
if top in _KNOWN_THIRD_PARTY:
return True
# 检查标准库
if top in sys.stdlib_module_names:
return True
return False
# ---------------------------------------------------------------------------
# File reading (multi-encoding fallback)
# ---------------------------------------------------------------------------
def _read_source(filepath: Path) -> str | None:
    """Read a Python source file, trying utf-8 → gbk → latin-1 in turn.
    Returns the file content, or None if it cannot be read.
    """
for encoding in ("utf-8", "gbk", "latin-1"):
try:
return filepath.read_text(encoding=encoding)
except (UnicodeDecodeError, UnicodeError):
continue
except (OSError, PermissionError) as exc:
logger.warning("无法读取文件 %s: %s", filepath, exc)
return None
logger.warning("无法以任何编码读取文件 %s", filepath)
return None
# ---------------------------------------------------------------------------
# Path ↔ module-name conversion
# ---------------------------------------------------------------------------
def _path_to_module_name(rel_path: str) -> str:
    """Convert a relative path to a Python module name.
    For example:
    - "cli/main.py" → "cli.main"
    - "cli/__init__.py" → "cli"
    - "tasks/dws/assistant.py" → "tasks.dws.assistant"
"""
p = rel_path.replace("\\", "/")
if p.endswith("/__init__.py"):
p = p[: -len("/__init__.py")]
elif p.endswith(".py"):
p = p[:-3]
return p.replace("/", ".")
def _module_to_path(module_name: str) -> str:
    """Convert a module name to a relative file path (preferring a .py file).
    For example:
    - "cli.main" → "cli/main.py"
    - "cli" → "cli/__init__.py"
"""
return module_name.replace(".", "/") + ".py"
# ---------------------------------------------------------------------------
# parse_imports — parse the import statements of a Python file
# ---------------------------------------------------------------------------
def parse_imports(filepath: Path) -> list[str]:
    """Parse a Python file's import statements with ast and return the imported project-local modules.
    - only project-internal modules are returned (stdlib and third-party packages are excluded)
    - the result is deduplicated
    - returns an empty list on syntax errors or if the file does not exist
    """
if not filepath.exists():
return []
source = _read_source(filepath)
if source is None:
return []
try:
tree = ast.parse(source, filename=str(filepath))
except SyntaxError:
logger.warning("语法错误,无法解析 %s", filepath)
return []
modules: list[str] = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.name
if _is_project_module(name) and not _is_stdlib_or_third_party(name):
modules.append(name)
elif isinstance(node, ast.ImportFrom):
if node.module and node.level == 0:
name = node.module
if _is_project_module(name) and not _is_stdlib_or_third_party(name):
modules.append(name)
# 去重并保持顺序
seen: set[str] = set()
result: list[str] = []
for m in modules:
if m not in seen:
seen.add(m)
result.append(m)
return result
# ---------------------------------------------------------------------------
# build_flow_tree — recursively follow the import chain from an entry point to build the flow tree
# ---------------------------------------------------------------------------
def build_flow_tree(
repo_root: Path,
entry_file: str,
_visited: set[str] | None = None,
) -> FlowNode:
    """Build the flow tree by recursively following the import chain from an entry file.
    Parameters
    ----------
    repo_root : Path
        Repository root directory.
    entry_file : str
        Relative path of the entry file (e.g. "cli/main.py").
    _visited : set[str] | None
        Internal use only; prevents infinite recursion on circular imports.
    Returns
    -------
    FlowNode
        The flow tree rooted at the entry file.
    """
is_root = _visited is None
if _visited is None:
_visited = set()
module_name = _path_to_module_name(entry_file)
node_type = "entry" if is_root else "module"
_visited.add(entry_file)
filepath = repo_root / entry_file
children: list[FlowNode] = []
if filepath.exists():
imported_modules = parse_imports(filepath)
for mod in imported_modules:
child_path = _module_to_path(mod)
# 如果 .py 文件不存在,尝试 __init__.py
if not (repo_root / child_path).exists():
alt_path = mod.replace(".", "/") + "/__init__.py"
if (repo_root / alt_path).exists():
child_path = alt_path
if child_path not in _visited:
child_node = build_flow_tree(repo_root, child_path, _visited)
children.append(child_node)
return FlowNode(
name=module_name,
source_file=entry_file,
node_type=node_type,
children=children,
)
# ---------------------------------------------------------------------------
# Batch (.bat) file parsing
# ---------------------------------------------------------------------------
def _parse_bat_python_target(bat_path: Path) -> str | None:
    """Parse a batch file for the target module of its `python -m` command.
    Returns the module name (e.g. "cli.main"), or None if none is found.
    """
if not bat_path.exists():
return None
content = _read_source(bat_path)
if content is None:
return None
# 匹配 python -m module.name 或 python3 -m module.name
pattern = re.compile(r"python[3]?\s+-m\s+([\w.]+)", re.IGNORECASE)
for line in content.splitlines():
m = pattern.search(line)
if m:
return m.group(1)
return None
# ---------------------------------------------------------------------------
# Entry point discovery
# ---------------------------------------------------------------------------
def discover_entry_points(repo_root: Path) -> list[dict[str, str]]:
    """Discover all entry points of the project.
    Returns a list of dicts, each with:
    - type: entry type (CLI / GUI / batch / ops script)
    - file: relative path
    - description: short description
    Discovery rules:
    - cli/main.py → CLI entry point
    - gui/main.py → GUI entry point
    - *.bat files → parse the `python -m` command inside
    - scripts/*.py (containing `if __name__ == "__main__"`, excluding __init__.py and the audit/ subdirectory)
    """
entries: list[dict[str, str]] = []
# CLI 入口
cli_main = repo_root / "cli" / "main.py"
if cli_main.exists():
entries.append({
"type": "CLI",
"file": "cli/main.py",
"description": "CLI 主入口 (`python -m cli.main`)",
})
# GUI 入口
gui_main = repo_root / "gui" / "main.py"
if gui_main.exists():
entries.append({
"type": "GUI",
"file": "gui/main.py",
"description": "GUI 主入口 (`python -m gui.main`)",
})
# 批处理文件
for bat in sorted(repo_root.glob("*.bat")):
target = _parse_bat_python_target(bat)
desc = f"批处理脚本"
if target:
desc += f",调用 `{target}`"
entries.append({
"type": "批处理",
"file": bat.name,
"description": desc,
})
    # ops scripts: .py files under scripts/ (excluding __init__.py and the audit/ subdirectory)
scripts_dir = repo_root / "scripts"
if scripts_dir.is_dir():
for py_file in sorted(scripts_dir.glob("*.py")):
if py_file.name == "__init__.py":
continue
# 检查是否包含 if __name__ == "__main__"
source = _read_source(py_file)
if source and '__name__' in source and '__main__' in source:
rel = py_file.relative_to(repo_root).as_posix()
entries.append({
"type": "运维脚本",
"file": rel,
"description": f"运维脚本 `{py_file.name}`",
})
return entries
# ---------------------------------------------------------------------------
# Task type and loader type classification
# ---------------------------------------------------------------------------
def classify_task_type(rel_path: str) -> str:
    """Classify the task type from the file path.
    Possible return values:
    - "ODS 抓取任务"
    - "DWD 加载任务"
    - "DWS 汇总任务"
    - "校验任务"
    - "Schema 初始化任务"
    - "任务" (the default when no finer classification applies)
    """
p = rel_path.replace("\\", "/").lower()
if "verification/" in p or "verification\\" in p:
return "校验任务"
if "dws/" in p or "dws\\" in p:
return "DWS 汇总任务"
# 文件名级别判断
basename = p.rsplit("/", 1)[-1] if "/" in p else p
if basename.startswith("ods_") or basename.startswith("ods."):
return "ODS 抓取任务"
if basename.startswith("dwd_") or basename.startswith("dwd."):
return "DWD 加载任务"
if basename.startswith("dws_"):
return "DWS 汇总任务"
if "init" in basename and "schema" in basename:
return "Schema 初始化任务"
return "任务"
def classify_loader_type(rel_path: str) -> str:
    """Classify the loader type from the file path.
    Possible return values:
    - "维度加载器 (SCD2)"
    - "事实表加载器"
    - "ODS 通用加载器"
    - "加载器" (the default when no finer classification applies)
    """
p = rel_path.replace("\\", "/").lower()
if "dimensions/" in p or "dimensions\\" in p:
return "维度加载器 (SCD2)"
if "facts/" in p or "facts\\" in p:
return "事实表加载器"
if "ods/" in p or "ods\\" in p:
return "ODS 通用加载器"
return "加载器"
# ---------------------------------------------------------------------------
# find_orphan_modules — find Python modules not reachable from any entry point
# ---------------------------------------------------------------------------
def find_orphan_modules(
repo_root: Path,
all_entries: list[FileEntry],
reachable: set[str],
) -> list[str]:
    """Find Python modules that are not referenced, directly or indirectly, by any entry point.
    Exclusions (not treated as orphaned):
    - __init__.py files
    - files under tests/
    - files under scripts/audit/ (the audit scripts themselves)
    - directory entries
    - non-.py files
    - files that do not belong to a project package
    Returns the orphaned modules sorted by path.
    """
    orphans: list[str] = []
    for entry in all_entries:
        # skip directories
        if entry.is_dir:
            continue
        # only consider .py files
        if entry.extension != ".py":
            continue
        rel = entry.rel_path.replace("\\", "/")
        # exclude __init__.py
        if rel.endswith("/__init__.py") or rel == "__init__.py":
            continue
        # exclude test files
        if rel.startswith("tests/") or rel.startswith("tests\\"):
            continue
        # exclude the audit scripts themselves
        if rel.startswith("scripts/audit/") or rel.startswith("scripts\\audit\\"):
            continue
        # only check files that belong to project packages
        top_dir = rel.split("/")[0] if "/" in rel else ""
        if top_dir not in _PROJECT_PACKAGES:
            continue
        # not in the reachable set → orphaned
        if rel not in reachable:
            orphans.append(rel)
orphans.sort()
return orphans
# ---------------------------------------------------------------------------
# Statistics helpers
# ---------------------------------------------------------------------------
def _count_nodes_by_type(trees: list[FlowNode]) -> dict[str, int]:
    """Recursively count the nodes of each type in the flow trees."""
counts: dict[str, int] = {"entry": 0, "module": 0, "class": 0, "function": 0}
def _walk(node: FlowNode) -> None:
t = node.node_type
counts[t] = counts.get(t, 0) + 1
for child in node.children:
_walk(child)
for tree in trees:
_walk(tree)
return counts
def _count_tasks_and_loaders(trees: list[FlowNode]) -> tuple[int, int]:
    """Count the task modules and loader modules in the flow trees."""
tasks = 0
loaders = 0
seen: set[str] = set()
def _walk(node: FlowNode) -> None:
nonlocal tasks, loaders
if node.source_file in seen:
return
seen.add(node.source_file)
sf = node.source_file.replace("\\", "/")
if sf.startswith("tasks/") and not sf.endswith("__init__.py"):
base = sf.rsplit("/", 1)[-1]
if not base.startswith("base_"):
tasks += 1
if sf.startswith("loaders/") and not sf.endswith("__init__.py"):
base = sf.rsplit("/", 1)[-1]
if not base.startswith("base_"):
loaders += 1
for child in node.children:
_walk(child)
for tree in trees:
_walk(tree)
return tasks, loaders
# ---------------------------------------------------------------------------
# Type-annotation helpers
# ---------------------------------------------------------------------------
def _get_type_annotation(source_file: str) -> str:
    """Return a type-annotation string for the source file path (used to label nodes in the reports)."""
sf = source_file.replace("\\", "/")
if sf.startswith("tasks/"):
return f" [{classify_task_type(sf)}]"
if sf.startswith("loaders/"):
return f" [{classify_loader_type(sf)}]"
return ""
# ---------------------------------------------------------------------------
# Mermaid diagram rendering
# ---------------------------------------------------------------------------
def _render_mermaid(trees: list[FlowNode]) -> str:
    """Render the Mermaid flow-chart code."""
lines: list[str] = ["```mermaid", "graph TD"]
seen_edges: set[tuple[str, str]] = set()
node_ids: dict[str, str] = {}
counter = [0]
def _node_id(name: str) -> str:
if name not in node_ids:
node_ids[name] = f"N{counter[0]}"
counter[0] += 1
return node_ids[name]
def _walk(node: FlowNode) -> None:
nid = _node_id(node.name)
annotation = _get_type_annotation(node.source_file)
label = f"{node.name}{annotation}"
# 声明节点
lines.append(f" {nid}[\"`{label}`\"]")
for child in node.children:
cid = _node_id(child.name)
edge = (nid, cid)
if edge not in seen_edges:
seen_edges.add(edge)
lines.append(f" {nid} --> {cid}")
_walk(child)
for tree in trees:
_walk(tree)
lines.append("```")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Indented text-tree rendering
# ---------------------------------------------------------------------------
def _render_text_tree(trees: list[FlowNode]) -> str:
    """Render the flow tree as indented text."""
lines: list[str] = []
seen: set[str] = set()
def _walk(node: FlowNode, depth: int) -> None:
indent = " " * depth
annotation = _get_type_annotation(node.source_file)
line = f"{indent}- `{node.name}` (`{node.source_file}`){annotation}"
lines.append(line)
key = node.source_file
if key in seen:
# 已展开过,不再递归(避免循环)
if node.children:
lines.append(f"{indent} - *(已展开)*")
return
seen.add(key)
for child in node.children:
_walk(child, depth + 1)
for tree in trees:
_walk(tree, 0)
return "\n".join(lines)
# ---------------------------------------------------------------------------
# render_flow_report — render the flow tree report as Markdown
# ---------------------------------------------------------------------------
def render_flow_report(
trees: list[FlowNode],
orphans: list[str],
repo_root: str,
) -> str:
    """Render the flow tree report in Markdown (Mermaid diagram plus indented text tree).
    Report structure:
    1. header (timestamp, repository path)
    2. Mermaid flow chart
    3. indented text tree
    4. orphaned-module list
    5. summary statistics
    """
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
sections: list[str] = []
# --- 头部 ---
sections.append("# 项目流程树报告\n")
sections.append(f"- 生成时间: {timestamp}")
sections.append(f"- 仓库路径: `{repo_root}`\n")
# --- Mermaid 图 ---
sections.append("## 流程图Mermaid\n")
sections.append(_render_mermaid(trees))
sections.append("")
# --- 缩进文本树 ---
sections.append("## 流程树(缩进文本)\n")
sections.append(_render_text_tree(trees))
sections.append("")
# --- 孤立模块 ---
sections.append("## 孤立模块\n")
if orphans:
for o in orphans:
sections.append(f"- `{o}`")
else:
sections.append("未发现孤立模块。")
sections.append("")
# --- 统计摘要 ---
entry_count = sum(1 for t in trees if t.node_type == "entry")
task_count, loader_count = _count_tasks_and_loaders(trees)
orphan_count = len(orphans)
sections.append("## 统计摘要\n")
sections.append(f"| 指标 | 数量 |")
sections.append(f"|------|------|")
sections.append(f"| 入口点 | {entry_count} |")
sections.append(f"| 任务 | {task_count} |")
sections.append(f"| 加载器 | {loader_count} |")
sections.append(f"| 孤立模块 | {orphan_count} |")
sections.append("")
return "\n".join(sections)
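A minimal sketch of running the flow analysis standalone with the functions above (assuming execution from the repository root):
```python
from pathlib import Path

from scripts.audit.flow_analyzer import (
    build_flow_tree, discover_entry_points, render_flow_report,
)

repo_root = Path(".").resolve()
trees = []
for ep in discover_entry_points(repo_root):
    if ep["file"].endswith(".py"):       # .bat entries have no import chain to trace
        trees.append(build_flow_tree(repo_root, ep["file"]))
print(render_flow_report(trees, orphans=[], repo_root=str(repo_root))[:400])
```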

View File

@@ -0,0 +1,449 @@
# -*- coding: utf-8 -*-
"""
File inventory analyzer — assigns a purpose category and a disposition label to scan results.
Classification rules, highest priority first:
1. everything under tmp/ → temp & debug / candidate delete or candidate archive
2. runtime output under logs/ and export/ → logs & output / candidate archive
3. *.lnk and *.rar files → other / candidate delete
4. empty directories → other / candidate delete
5. core code directories (tasks/ etc.) → core code / keep
6. config/ → config / keep
7. database/*.sql and database/migrations/ → database definitions / keep
8. database/*.py → core code / keep
9. tests/ → tests / keep
10. docs/ → docs / keep
11. .py files under scripts/ → script tooling / keep
12. gui/ → GUI / keep
13. build & deployment files → build & deploy / keep
14. everything else → other / needs review
"""
from __future__ import annotations
import os
from collections import Counter
from datetime import datetime, timezone
from itertools import groupby
from scripts.audit import Category, Disposition, FileEntry, InventoryItem
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Core code top-level directories
_CORE_CODE_DIRS = (
"tasks/", "loaders/", "scd/", "orchestration/",
"quality/", "models/", "utils/", "api/",
)
# Build & deployment file names (root level)
_BUILD_DEPLOY_BASENAMES = {"setup.py", "build_exe.py"}
# Build & deployment extensions
_BUILD_DEPLOY_EXTENSIONS = {".bat", ".sh", ".ps1"}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _top_dir(rel_path: str) -> str:
    """Return the first-level directory of a relative path (with trailing slash), e.g. 'tmp/foo.py' → 'tmp/'."""
idx = rel_path.find("/")
if idx == -1:
return ""
return rel_path[: idx + 1]
def _basename(rel_path: str) -> str:
    """Return the last path segment (the base name)."""
return rel_path.rsplit("/", 1)[-1]
def _is_init_py(rel_path: str) -> bool:
    """Return True if the path is an __init__.py file."""
return _basename(rel_path) == "__init__.py"
# ---------------------------------------------------------------------------
# classify — the core classification function
# ---------------------------------------------------------------------------
def classify(entry: FileEntry) -> InventoryItem:
    """Classify a single file/directory and assign a disposition label based on path, extension, etc.
    Rules are matched from highest to lowest priority; the first match decides category and disposition.
    """
path = entry.rel_path
top = _top_dir(path)
ext = entry.extension.lower()
base = _basename(path)
# --- 优先级 1: tmp/ 下所有文件 ---
if top == "tmp/" or path == "tmp":
return _classify_tmp(entry)
# --- 优先级 2: logs/、export/ 下的运行时产出 ---
if top in ("logs/", "export/") or path in ("logs", "export"):
return _classify_runtime_output(entry)
# --- 优先级 3: .lnk / .rar 文件 ---
if ext in (".lnk", ".rar"):
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.CANDIDATE_DELETE,
description=f"快捷方式/压缩包文件(`{ext}`),建议删除",
)
# --- 优先级 4: 空目录 ---
if entry.is_empty_dir:
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.CANDIDATE_DELETE,
description="空目录,建议删除",
)
# --- 优先级 5: 核心代码目录 ---
if any(path.startswith(d) or path + "/" == d for d in _CORE_CODE_DIRS):
return InventoryItem(
rel_path=path,
category=Category.CORE_CODE,
disposition=Disposition.KEEP,
description=f"核心代码(`{top.rstrip('/')}`",
)
# --- 优先级 6: config/ ---
if top == "config/" or path == "config":
return InventoryItem(
rel_path=path,
category=Category.CONFIG,
disposition=Disposition.KEEP,
description="配置文件",
)
# --- 优先级 7: database/*.sql 和 database/migrations/ ---
if top == "database/" or path == "database":
return _classify_database(entry)
# --- 优先级 8: tests/ ---
if top == "tests/" or path == "tests":
return InventoryItem(
rel_path=path,
category=Category.TEST,
disposition=Disposition.KEEP,
description="测试文件",
)
# --- 优先级 9: docs/ ---
if top == "docs/" or path == "docs":
return InventoryItem(
rel_path=path,
category=Category.DOCS,
disposition=Disposition.KEEP,
description="文档",
)
# --- 优先级 10: scripts/ 下的 .py 文件 ---
if top == "scripts/" or path == "scripts":
cat = Category.SCRIPTS
if ext == ".py" or entry.is_dir:
return InventoryItem(
rel_path=path,
category=cat,
disposition=Disposition.KEEP,
description="脚本工具",
)
return InventoryItem(
rel_path=path,
category=cat,
disposition=Disposition.NEEDS_REVIEW,
description="脚本目录下的非 Python 文件,需确认用途",
)
# --- 优先级 11: gui/ ---
if top == "gui/" or path == "gui":
return InventoryItem(
rel_path=path,
category=Category.GUI,
disposition=Disposition.KEEP,
description="GUI 模块",
)
# --- 优先级 12: 构建与部署 ---
if base in _BUILD_DEPLOY_BASENAMES or ext in _BUILD_DEPLOY_EXTENSIONS:
return InventoryItem(
rel_path=path,
category=Category.BUILD_DEPLOY,
disposition=Disposition.KEEP,
description="构建与部署文件",
)
# --- 优先级 13: cli/ ---
if top == "cli/" or path == "cli":
return InventoryItem(
rel_path=path,
category=Category.CORE_CODE,
disposition=Disposition.KEEP,
description="CLI 入口模块",
)
# --- 优先级 14: 已知根目录文件 ---
if "/" not in path:
return _classify_root_file(entry)
# --- 兜底 ---
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.NEEDS_REVIEW,
description="未匹配已知规则,需人工确认用途",
)
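# Illustrative examples (hypothetical paths):
#   classify(FileEntry("tmp/debug_fix.py", False, 1200, ".py", False))
#       -> Category.TEMP_DEBUG / Disposition.CANDIDATE_ARCHIVE  (priority 1; .py kept for reference)
#   classify(FileEntry("tasks/dwd/fact_order.py", False, 3400, ".py", False))
#       -> Category.CORE_CODE / Disposition.KEEP                (priority 5; core code directory)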
# ---------------------------------------------------------------------------
# Sub-classification helpers
# ---------------------------------------------------------------------------
def _classify_tmp(entry: FileEntry) -> InventoryItem:
    """Classify files under tmp/.
    Default is candidate-delete; meaningful .py files are marked candidate-archive.
    """
ext = entry.extension.lower()
base = _basename(entry.rel_path)
# 空目录直接候选删除
if entry.is_empty_dir:
return InventoryItem(
rel_path=entry.rel_path,
category=Category.TEMP_DEBUG,
disposition=Disposition.CANDIDATE_DELETE,
description="临时目录下的空目录",
)
# .py 文件可能有参考价值 → 候选归档
if ext == ".py" and len(base) > 4:
return InventoryItem(
rel_path=entry.rel_path,
category=Category.TEMP_DEBUG,
disposition=Disposition.CANDIDATE_ARCHIVE,
description="临时 Python 脚本,可能有参考价值",
)
return InventoryItem(
rel_path=entry.rel_path,
category=Category.TEMP_DEBUG,
disposition=Disposition.CANDIDATE_DELETE,
description="临时/调试文件,建议删除",
)
def _classify_runtime_output(entry: FileEntry) -> InventoryItem:
    """Classify runtime output under logs/ and export/.
    __init__.py is kept (package marker); everything else is candidate-archive.
    """
if _is_init_py(entry.rel_path):
return InventoryItem(
rel_path=entry.rel_path,
category=Category.LOG_OUTPUT,
disposition=Disposition.KEEP,
description="包初始化文件",
)
return InventoryItem(
rel_path=entry.rel_path,
category=Category.LOG_OUTPUT,
disposition=Disposition.CANDIDATE_ARCHIVE,
description="运行时产出,建议归档",
)
def _classify_database(entry: FileEntry) -> InventoryItem:
    """Classify files under database/."""
path = entry.rel_path
ext = entry.extension.lower()
# migrations/ 子目录
if "migrations/" in path or path.endswith("migrations"):
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.KEEP,
description="数据库迁移脚本",
)
# .sql 文件
if ext == ".sql":
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.KEEP,
description="数据库 DDL/DML 脚本",
)
# .py 文件 → 核心代码
if ext == ".py":
return InventoryItem(
rel_path=path,
category=Category.CORE_CODE,
disposition=Disposition.KEEP,
description="数据库操作模块",
)
# 目录本身
if entry.is_dir:
if entry.is_empty_dir:
return InventoryItem(
rel_path=path,
category=Category.OTHER,
disposition=Disposition.CANDIDATE_DELETE,
description="数据库目录下的空目录",
)
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.KEEP,
description="数据库子目录",
)
# 其他文件
return InventoryItem(
rel_path=path,
category=Category.DATABASE_DEF,
disposition=Disposition.NEEDS_REVIEW,
description="数据库目录下的非标准文件,需确认",
)
def _classify_root_file(entry: FileEntry) -> InventoryItem:
    """Classify loose files in the repository root."""
ext = entry.extension.lower()
base = _basename(entry.rel_path)
# 已知构建文件
if base in _BUILD_DEPLOY_BASENAMES or ext in _BUILD_DEPLOY_EXTENSIONS:
return InventoryItem(
rel_path=entry.rel_path,
category=Category.BUILD_DEPLOY,
disposition=Disposition.KEEP,
description="构建与部署文件",
)
# 已知配置文件
if base in (
"requirements.txt", "pytest.ini", ".env", ".env.example",
".gitignore", ".flake8", "pyproject.toml",
):
return InventoryItem(
rel_path=entry.rel_path,
category=Category.CONFIG,
disposition=Disposition.KEEP,
description="项目配置文件",
)
# README
if base.lower().startswith("readme"):
return InventoryItem(
rel_path=entry.rel_path,
category=Category.DOCS,
disposition=Disposition.KEEP,
description="项目说明文档",
)
# 其他根目录文件 → 待确认
return InventoryItem(
rel_path=entry.rel_path,
category=Category.OTHER,
disposition=Disposition.NEEDS_REVIEW,
description=f"根目录散落文件(`{base}`),需确认用途",
)
# ---------------------------------------------------------------------------
# build_inventory — classify in bulk
# ---------------------------------------------------------------------------
def build_inventory(entries: list[FileEntry]) -> list[InventoryItem]:
    """Classify every file entry and return the inventory list."""
return [classify(e) for e in entries]
# ---------------------------------------------------------------------------
# render_inventory_report — Markdown rendering
# ---------------------------------------------------------------------------
def render_inventory_report(items: list[InventoryItem], repo_root: str) -> str:
    """Render the file inventory report in Markdown.
    Report structure:
    - header: title, generation time, repository path
    - body: tables grouped by Category
    - footer: summary statistics
    """
lines: list[str] = []
# --- 头部 ---
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
lines.append("# 文件清单报告")
lines.append("")
lines.append(f"- 生成时间:{now}")
lines.append(f"- 仓库路径:`{repo_root}`")
lines.append("")
# --- 按分类分组 ---
# 保持 Category 枚举定义顺序
cat_order = {c: i for i, c in enumerate(Category)}
sorted_items = sorted(items, key=lambda it: cat_order[it.category])
for cat, group in groupby(sorted_items, key=lambda it: it.category):
group_list = list(group)
lines.append(f"## {cat.value}")
lines.append("")
lines.append("| 相对路径 | 处置标签 | 简要说明 |")
lines.append("|---|---|---|")
for item in group_list:
lines.append(
f"| `{item.rel_path}` | {item.disposition.value} | {item.description} |"
)
lines.append("")
# --- 统计摘要 ---
lines.append("## 统计摘要")
lines.append("")
# 各分类计数
cat_counter: Counter[Category] = Counter()
disp_counter: Counter[Disposition] = Counter()
for item in items:
cat_counter[item.category] += 1
disp_counter[item.disposition] += 1
lines.append("### 按用途分类")
lines.append("")
lines.append("| 分类 | 数量 |")
lines.append("|---|---|")
for cat in Category:
count = cat_counter.get(cat, 0)
if count > 0:
lines.append(f"| {cat.value} | {count} |")
lines.append("")
lines.append("### 按处置标签")
lines.append("")
lines.append("| 标签 | 数量 |")
lines.append("|---|---|")
for disp in Disposition:
count = disp_counter.get(disp, 0)
if count > 0:
lines.append(f"| {disp.value} | {count} |")
lines.append("")
lines.append(f"**总计:{len(items)} 个条目**")
lines.append("")
return "\n".join(lines)
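A minimal sketch of classifying a few hand-built entries with the functions above (the paths are illustrative):
```python
from scripts.audit import FileEntry
from scripts.audit.inventory_analyzer import build_inventory, render_inventory_report

entries = [
    FileEntry("tmp/old_probe.py", False, 900, ".py", False),
    FileEntry("database/schema_dwd.sql", False, 4096, ".sql", False),
    FileEntry("logs/run_20250101.log", False, 120, ".log", False),
]
items = build_inventory(entries)
print(render_inventory_report(items, repo_root="C:/ZQYY/FQ-ETL"))
```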

View File

@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
"""
Audit entry point — invokes the scanner and the three analyzers in turn and writes three reports to docs/audit/repo/.
Only creates files under docs/audit/repo/; no existing file in the repository is modified.
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timezone
from pathlib import Path
from scripts.audit.scanner import scan_repo
from scripts.audit.inventory_analyzer import (
build_inventory,
render_inventory_report,
)
from scripts.audit.flow_analyzer import (
build_flow_tree,
discover_entry_points,
find_orphan_modules,
render_flow_report,
)
from scripts.audit.doc_alignment_analyzer import (
build_mappings,
check_api_samples_vs_parsers,
check_ddl_vs_dictionary,
find_undocumented_modules,
render_alignment_report,
scan_docs,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Repository root auto-detection
# ---------------------------------------------------------------------------
def _detect_repo_root() -> Path:
    """Walk upward from this file to locate the repository root.
    Criterion: an ancestor directory containing a cli/ or .git/ directory.
    """
current = Path(__file__).resolve().parent
for parent in (current, *current.parents):
if (parent / "cli").is_dir() or (parent / ".git").is_dir():
return parent
# 回退:假设 scripts/audit/ 在仓库根目录下
return current.parent.parent
# ---------------------------------------------------------------------------
# Report output directory
# ---------------------------------------------------------------------------
def _ensure_report_dir(repo_root: Path) -> Path:
    """Check for, and create, the docs/audit/repo/ directory.
    Returns immediately if the directory already exists; creates it otherwise.
    Raises RuntimeError if creation fails, since no report could be written.
    """
audit_dir = repo_root / "docs" / "audit" / "repo"
if audit_dir.is_dir():
return audit_dir
try:
audit_dir.mkdir(parents=True, exist_ok=True)
except OSError as exc:
raise RuntimeError(f"无法创建报告输出目录 {audit_dir}: {exc}") from exc
logger.info("已创建报告输出目录: %s", audit_dir)
return audit_dir
# ---------------------------------------------------------------------------
# Report header metadata injection
# ---------------------------------------------------------------------------
_HEADER_PATTERN = re.compile(r"生成时间[:]")
_ISO_TS_PATTERN = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")
# matches non-ISO timestamp lines, used for replacement
_NON_ISO_TS_LINE = re.compile(
r"([-*]\s*生成时间[:]\s*)\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}"
)
def _inject_header(report: str, timestamp: str, repo_path: str) -> str:
    """Ensure the report header contains an ISO timestamp and the repository path.
    - ISO timestamp already present → leave unchanged
    - non-ISO timestamp present → replace it with the ISO form
    - no header → inject one after the title
    """
if _HEADER_PATTERN.search(report):
# 已有头部——检查时间戳格式是否为 ISO
if _ISO_TS_PATTERN.search(report):
return report
# 非 ISO 格式 → 替换时间戳
report = _NON_ISO_TS_LINE.sub(
lambda m: m.group(1) + timestamp, report,
)
# 同时确保仓库路径使用统一值(用 lambda 避免反斜杠转义问题)
safe_path = repo_path
report = re.sub(
r"([-*]\s*仓库路径[:]\s*)`[^`]*`",
lambda m: m.group(1) + "`" + safe_path + "`",
report,
)
return report
# 无头部 → 在第一个标题行之后插入
lines = report.split("\n")
insert_idx = 1
for i, line in enumerate(lines):
if line.startswith("# "):
insert_idx = i + 1
break
header_lines = [
"",
f"- 生成时间: {timestamp}",
f"- 仓库路径: `{repo_path}`",
"",
]
lines[insert_idx:insert_idx] = header_lines
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def run_audit(repo_root: Path | None = None) -> None:
    """Run the full audit flow and write the three reports to docs/audit/repo/.
    Parameters
    ----------
    repo_root : Path | None
        Repository root directory; auto-detected when None.
    """
    # 1. determine the repository root
if repo_root is None:
repo_root = _detect_repo_root()
repo_root = repo_root.resolve()
repo_path_str = str(repo_root)
logger.info("审计开始 — 仓库路径: %s", repo_path_str)
    # 2. check/create the output directory
audit_dir = _ensure_report_dir(repo_root)
    # 3. generate the UTC timestamp (shared by all reports)
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    # 4. scan the repository
logger.info("正在扫描仓库文件...")
entries = scan_repo(repo_root)
logger.info("扫描完成,共 %d 个条目", len(entries))
    # 5. file inventory report
logger.info("正在生成文件清单报告...")
try:
inventory_items = build_inventory(entries)
inventory_report = render_inventory_report(inventory_items, repo_path_str)
inventory_report = _inject_header(inventory_report, timestamp, repo_path_str)
(audit_dir / "file_inventory.md").write_text(
inventory_report, encoding="utf-8",
)
logger.info("文件清单报告已写入: file_inventory.md")
except Exception:
logger.exception("生成文件清单报告时出错")
    # 6. flow tree report
logger.info("正在生成流程树报告...")
try:
entry_points = discover_entry_points(repo_root)
trees = []
reachable: set[str] = set()
for ep in entry_points:
ep_file = ep["file"]
# 批处理文件不构建流程树
if not ep_file.endswith(".py"):
continue
tree = build_flow_tree(repo_root, ep_file)
trees.append(tree)
# 收集可达模块
_collect_reachable(tree, reachable)
orphans = find_orphan_modules(repo_root, entries, reachable)
flow_report = render_flow_report(trees, orphans, repo_path_str)
flow_report = _inject_header(flow_report, timestamp, repo_path_str)
(audit_dir / "flow_tree.md").write_text(
flow_report, encoding="utf-8",
)
logger.info("流程树报告已写入: flow_tree.md")
except Exception:
logger.exception("生成流程树报告时出错")
    # 7. documentation alignment report
logger.info("正在生成文档对齐报告...")
try:
doc_paths = scan_docs(repo_root)
mappings = build_mappings(doc_paths, repo_root)
issues = []
issues.extend(check_ddl_vs_dictionary(repo_root))
issues.extend(check_api_samples_vs_parsers(repo_root))
# 缺失文档检测
documented: set[str] = set()
for m in mappings:
documented.update(m.related_code)
undoc_modules = find_undocumented_modules(repo_root, documented)
from scripts.audit import AlignmentIssue
for mod in undoc_modules:
issues.append(AlignmentIssue(
doc_path="",
issue_type="missing",
description=f"核心代码模块 `{mod}` 缺少对应文档",
related_code=mod,
))
alignment_report = render_alignment_report(mappings, issues, repo_path_str)
alignment_report = _inject_header(alignment_report, timestamp, repo_path_str)
(audit_dir / "doc_alignment.md").write_text(
alignment_report, encoding="utf-8",
)
logger.info("文档对齐报告已写入: doc_alignment.md")
except Exception:
logger.exception("生成文档对齐报告时出错")
logger.info("审计完成 — 报告输出目录: %s", audit_dir)
# ---------------------------------------------------------------------------
# Helper: collect reachable modules
# ---------------------------------------------------------------------------
def _collect_reachable(node, reachable: set[str]) -> None:
    """Recursively collect the source_file of every node in the flow tree."""
reachable.add(node.source_file)
for child in node.children:
_collect_reachable(child, reachable)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
run_audit()
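A minimal sketch of invoking the audit programmatically instead of via `python -m scripts.audit.run_audit` (the path is the project root named in the scripts README):
```python
import logging
from pathlib import Path

from scripts.audit.run_audit import run_audit

logging.basicConfig(level=logging.INFO)
run_audit(Path(r"C:\ZQYY\FQ-ETL"))   # writes the three reports under docs/audit/repo/
```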

View File

@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
"""
Repository scanner — recursively walks the repository file system and returns structured file metadata.
Read-only: it reads file metadata (size, type) and modifies nothing.
Permission errors are skipped and logged without aborting the scan.
"""
from __future__ import annotations
import fnmatch
import logging
from pathlib import Path
from scripts.audit import FileEntry
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Exclusion patterns
# ---------------------------------------------------------------------------
EXCLUDED_PATTERNS: list[str] = [
".git",
"__pycache__",
".pytest_cache",
"*.pyc",
".kiro",
]
# ---------------------------------------------------------------------------
# Exclusion matching
# ---------------------------------------------------------------------------
def _is_excluded(name: str, patterns: list[str]) -> bool:
    """Return True if the file/directory name matches any exclusion pattern.
    Two kinds of pattern are supported:
    - exact match (e.g. ".git", "__pycache__")
    - wildcard match (e.g. "*.pyc"), with fnmatch semantics
    """
for pat in patterns:
if fnmatch.fnmatch(name, pat):
return True
return False
# ---------------------------------------------------------------------------
# Recursive walk
# ---------------------------------------------------------------------------
def _walk(
root: Path,
base: Path,
exclude: list[str],
results: list[FileEntry],
) -> None:
    """Recursively walk the files and directories under *root*, appending results to *results*.
    Parameters
    ----------
    root : Path
        Directory currently being walked.
    base : Path
        Repository root, used to compute relative paths.
    exclude : list[str]
        Exclusion patterns.
    results : list[FileEntry]
        List collecting the results (modified in place).
    """
try:
children = sorted(root.iterdir(), key=lambda p: p.name)
except (PermissionError, OSError) as exc:
logger.warning("无法读取目录 %s: %s", root, exc)
return
    # used to decide whether this directory is "empty" (no visible children after exclusions)
visible_count = 0
for child in children:
if _is_excluded(child.name, exclude):
continue
visible_count += 1
rel = child.relative_to(base).as_posix()
if child.is_dir():
            # recurse into the subdirectory first, then decide whether it is empty
            sub_start = len(results)
            _walk(child, base, exclude, results)
            sub_end = len(results)
            # zero entries produced by the recursion → empty directory
is_empty = (sub_end == sub_start)
results.append(FileEntry(
rel_path=rel,
is_dir=True,
size_bytes=0,
extension="",
is_empty_dir=is_empty,
))
else:
# 文件
try:
size = child.stat().st_size
except (PermissionError, OSError) as exc:
logger.warning("无法获取文件信息 %s: %s", child, exc)
continue
results.append(FileEntry(
rel_path=rel,
is_dir=False,
size_bytes=size,
extension=child.suffix.lower(),
is_empty_dir=False,
))
    # If root is the repository root itself, nothing extra is needed
    # (the root directory does not appear as an entry in the results).
def scan_repo(
root: Path,
exclude: list[str] | None = None,
) -> list[FileEntry]:
    """Recursively scan the repository and return metadata for all files and directories.
    Parameters
    ----------
    root : Path
        Repository root path.
    exclude : list[str] | None
        Exclusion patterns; defaults to EXCLUDED_PATTERNS.
    Returns
    -------
    list[FileEntry]
        File/directory metadata sorted by rel_path.
    """
if exclude is None:
exclude = EXCLUDED_PATTERNS
results: list[FileEntry] = []
_walk(root, root, exclude, results)
    # sort by relative path so the output is stable
results.sort(key=lambda e: e.rel_path)
return results
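A minimal usage sketch of the scanner (the extra `.venv` exclusion is an assumed local convention, not something the repository mandates):
```python
from pathlib import Path

from scripts.audit.scanner import EXCLUDED_PATTERNS, scan_repo

repo_root = Path(".").resolve()
entries = scan_repo(repo_root, exclude=EXCLUDED_PATTERNS + [".venv"])
empty_dirs = [e.rel_path for e in entries if e.is_empty_dir]
print(f"{len(entries)} entries scanned, {len(empty_dirs)} empty directories")
```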

View File

@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
"""Run data integrity checks across API -> ODS -> DWD."""
from __future__ import annotations
import argparse
import sys
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from config.settings import AppConfig
from quality.integrity_service import run_history_flow, run_window_flow, write_report
from utils.logging_utils import build_log_path, configure_logging
from utils.windowing import split_window
def _parse_dt(value: str, tz: ZoneInfo) -> datetime:
dt = dtparser.parse(value)
if dt.tzinfo is None:
return dt.replace(tzinfo=tz)
return dt.astimezone(tz)
def main() -> int:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
ap = argparse.ArgumentParser(description="Data integrity checks (API -> ODS -> DWD)")
ap.add_argument("--mode", choices=["history", "window"], default="history")
ap.add_argument(
"--flow",
choices=["verify", "update_and_verify"],
default="verify",
help="verify only or update+verify (auto backfill then optional recheck)",
)
ap.add_argument("--start", default="2025-07-01", help="history start date (default: 2025-07-01)")
ap.add_argument("--end", default="", help="history end datetime (default: last ETL end)")
ap.add_argument("--window-start", default="", help="window start datetime (mode=window)")
ap.add_argument("--window-end", default="", help="window end datetime (mode=window)")
ap.add_argument("--window-split-unit", default="", help="split unit (month/none), default from config")
ap.add_argument("--window-compensation-hours", type=int, default=None, help="window compensation hours, default from config")
ap.add_argument(
"--include-dimensions",
action="store_true",
default=None,
help="include dimension tables in ODS->DWD checks",
)
ap.add_argument(
"--no-include-dimensions",
action="store_true",
help="exclude dimension tables in ODS->DWD checks",
)
ap.add_argument("--ods-task-codes", default="", help="comma-separated ODS task codes for API checks")
ap.add_argument("--compare-content", action="store_true", help="compare API vs ODS content hash")
ap.add_argument("--no-compare-content", action="store_true", help="disable content comparison even if enabled in config")
ap.add_argument("--include-mismatch", action="store_true", help="backfill mismatch records as well")
ap.add_argument("--no-include-mismatch", action="store_true", help="disable mismatch backfill")
ap.add_argument("--recheck", action="store_true", help="re-run checks after backfill")
ap.add_argument("--no-recheck", action="store_true", help="skip recheck after backfill")
ap.add_argument("--content-sample-limit", type=int, default=None, help="max mismatch samples per table")
ap.add_argument("--out", default="", help="output JSON path")
ap.add_argument("--log-file", default="", help="log file path")
ap.add_argument("--log-dir", default="", help="log directory")
ap.add_argument("--log-level", default="INFO", help="log level")
ap.add_argument("--no-log-console", action="store_true", help="disable console logging")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (Path(__file__).resolve().parent / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "data_integrity")
log_console = not args.no_log_console
with configure_logging(
"data_integrity",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Shanghai"))
report_path = Path(args.out) if args.out else None
if args.recheck and args.no_recheck:
raise SystemExit("cannot set both --recheck and --no-recheck")
if args.include_mismatch and args.no_include_mismatch:
raise SystemExit("cannot set both --include-mismatch and --no-include-mismatch")
if args.include_dimensions and args.no_include_dimensions:
raise SystemExit("cannot set both --include-dimensions and --no-include-dimensions")
compare_content = None
if args.compare_content and args.no_compare_content:
raise SystemExit("cannot set both --compare-content and --no-compare-content")
if args.compare_content:
compare_content = True
elif args.no_compare_content:
compare_content = False
include_mismatch = cfg.get("integrity.backfill_mismatch", True)
if args.include_mismatch:
include_mismatch = True
elif args.no_include_mismatch:
include_mismatch = False
recheck_after_backfill = cfg.get("integrity.recheck_after_backfill", True)
if args.recheck:
recheck_after_backfill = True
elif args.no_recheck:
recheck_after_backfill = False
include_dimensions = cfg.get("integrity.include_dimensions", True)
if args.include_dimensions:
include_dimensions = True
elif args.no_include_dimensions:
include_dimensions = False
if args.mode == "window":
if not args.window_start or not args.window_end:
raise SystemExit("window-start and window-end are required for mode=window")
start_dt = _parse_dt(args.window_start, tz)
end_dt = _parse_dt(args.window_end, tz)
split_unit = (args.window_split_unit or cfg.get("run.window_split.unit", "month") or "month").strip()
comp_hours = args.window_compensation_hours
if comp_hours is None:
comp_hours = cfg.get("run.window_split.compensation_hours", 0)
windows = split_window(
start_dt,
end_dt,
tz=tz,
split_unit=split_unit,
compensation_hours=comp_hours,
)
if not windows:
windows = [(start_dt, end_dt)]
report, counts = run_window_flow(
cfg=cfg,
windows=windows,
include_dimensions=bool(include_dimensions),
task_codes=args.ods_task_codes,
logger=logger,
compare_content=compare_content,
content_sample_limit=args.content_sample_limit,
do_backfill=args.flow == "update_and_verify",
include_mismatch=bool(include_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(cfg.get("api.page_size") or 200),
chunk_size=500,
)
report_path = write_report(report, prefix="data_integrity_window", tz=tz, report_path=report_path)
report["report_path"] = report_path
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
else:
start_dt = _parse_dt(args.start, tz)
if args.end:
end_dt = _parse_dt(args.end, tz)
else:
end_dt = None
report, counts = run_history_flow(
cfg=cfg,
start_dt=start_dt,
end_dt=end_dt,
include_dimensions=bool(include_dimensions),
task_codes=args.ods_task_codes,
logger=logger,
compare_content=compare_content,
content_sample_limit=args.content_sample_limit,
do_backfill=args.flow == "update_and_verify",
include_mismatch=bool(include_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(cfg.get("api.page_size") or 200),
chunk_size=500,
)
report_path = write_report(report, prefix="data_integrity_history", tz=tz, report_path=report_path)
report["report_path"] = report_path
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
logger.info(
"SUMMARY missing=%s mismatch=%s errors=%s",
counts.get("missing"),
counts.get("mismatch"),
counts.get("errors"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
import sys
sys.path.insert(0, '.')
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
config = AppConfig.load()
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)
# 检查DWD层服务记录分布
print("=== DWD层服务记录分析 ===")
print()
# 1. 总体统计
sql1 = """
SELECT
COUNT(*) as total_records,
COUNT(DISTINCT tenant_member_id) as unique_members,
COUNT(DISTINCT site_assistant_id) as unique_assistants,
COUNT(DISTINCT (tenant_member_id, site_assistant_id)) as unique_pairs
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
"""
r = dict(db.query(sql1)[0])
print("总体统计:")
print(f" 总服务记录数: {r['total_records']}")
print(f" 唯一会员数: {r['unique_members']}")
print(f" 唯一助教数: {r['unique_assistants']}")
print(f" 唯一客户-助教对: {r['unique_pairs']}")
# 2. 助教服务会员数分布
print()
print("助教服务会员数分布 (Top 10):")
sql2 = """
SELECT site_assistant_id, COUNT(DISTINCT tenant_member_id) as member_count
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
GROUP BY site_assistant_id
ORDER BY member_count DESC
LIMIT 10
"""
for row in db.query(sql2):
r = dict(row)
print(f" 助教 {r['site_assistant_id']}: 服务 {r['member_count']} 个会员")
# 3. 每个客户-助教对的服务次数分布
print()
print("客户-助教对 服务次数分布 (Top 10):")
sql3 = """
SELECT tenant_member_id, site_assistant_id, COUNT(*) as service_count
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
GROUP BY tenant_member_id, site_assistant_id
ORDER BY service_count DESC
LIMIT 10
"""
for row in db.query(sql3):
r = dict(row)
print(f" 会员 {r['tenant_member_id']} - 助教 {r['site_assistant_id']}: {r['service_count']} 次服务")
# 4. 近60天的数据
print()
print("=== 近60天数据 ===")
sql4 = """
SELECT
COUNT(*) as total_records,
COUNT(DISTINCT tenant_member_id) as unique_members,
COUNT(DISTINCT site_assistant_id) as unique_assistants,
COUNT(DISTINCT (tenant_member_id, site_assistant_id)) as unique_pairs
FROM billiards_dwd.dwd_assistant_service_log
WHERE tenant_member_id > 0 AND is_delete = 0
AND last_use_time >= NOW() - INTERVAL '60 days'
"""
r4 = dict(db.query(sql4)[0])
print(f" 总服务记录数: {r4['total_records']}")
print(f" 唯一会员数: {r4['unique_members']}")
print(f" 唯一助教数: {r4['unique_assistants']}")
print(f" 唯一客户-助教对: {r4['unique_pairs']}")
db_conn.close()

View File

@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
"""
Validate that ODS payload content matches stored content_hash.
Usage:
PYTHONPATH=. python -m scripts.check.check_ods_content_hash
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --schema billiards_ods
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence
from psycopg2.extras import RealDictCursor
PROJECT_ROOT = Path(__file__).resolve().parents[2]  # 本文件位于 scripts/check/,需上溯两级才是仓库根目录
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _fetch_row_count(conn, schema: str, table: str) -> int:
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _iter_rows(
conn,
schema: str,
table: str,
select_cols: Sequence[str],
batch_size: int,
) -> Iterable[dict]:
cols_sql = ", ".join(f'"{c}"' for c in select_cols)
sql = f'SELECT {cols_sql} FROM "{schema}"."{table}"'
with conn.cursor(name=f"ods_hash_{table}", cursor_factory=RealDictCursor) as cur:
cur.itersize = max(1, int(batch_size or 500))
cur.execute(sql)
for row in cur:
yield row
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_content_hash_check_{ts}.json"
def _print_progress(
table_label: str,
processed: int,
total: int,
mismatched: int,
missing_hash: int,
invalid_payload: int,
) -> None:
if total:
msg = (
f"[{table_label}] checked {processed}/{total} "
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
else:
msg = (
f"[{table_label}] checked {processed} "
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
print(msg, flush=True)
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="Validate ODS payload vs content_hash consistency")
ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
ap.add_argument("--sample-limit", type=int, default=5, help="sample mismatch rows per table")
ap.add_argument("--out", default="", help="output report JSON path")
args = ap.parse_args()
cfg = AppConfig.load({})
db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
conn = db.conn
tables = _fetch_tables(conn, args.schema)
if args.tables.strip():
whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
tables = [t for t in tables if t in whitelist]
report = {
"schema": args.schema,
"tables": [],
"summary": {
"total_tables": 0,
"checked_tables": 0,
"total_rows": 0,
"checked_rows": 0,
"mismatch_rows": 0,
"missing_hash_rows": 0,
"invalid_payload_rows": 0,
},
}
for table in tables:
table_label = f"{args.schema}.{table}"
cols = _fetch_columns(conn, args.schema, table)
cols_lower = {c.lower() for c in cols}
if "payload" not in cols_lower or "content_hash" not in cols_lower:
print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
continue
total = _fetch_row_count(conn, args.schema, table)
pk_cols = _fetch_pk_columns(conn, args.schema, table)
select_cols = ["content_hash", "payload", *pk_cols]
processed = 0
mismatched = 0
missing_hash = 0
invalid_payload = 0
samples: list[dict[str, Any]] = []
print(f"[{table_label}] start: total_rows={total}", flush=True)
for row in _iter_rows(conn, args.schema, table, select_cols, args.batch_size):
processed += 1
content_hash = row.get("content_hash")
payload = row.get("payload")
recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
row_mismatch = False
if not content_hash:
missing_hash += 1
mismatched += 1
row_mismatch = True
elif not recomputed:
invalid_payload += 1
mismatched += 1
row_mismatch = True
elif content_hash != recomputed:
mismatched += 1
row_mismatch = True
if row_mismatch and len(samples) < max(0, int(args.sample_limit or 0)):
sample = {k: row.get(k) for k in pk_cols}
sample["content_hash"] = content_hash
sample["recomputed_hash"] = recomputed
samples.append(sample)
if args.progress_every and processed % int(args.progress_every) == 0:
_print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)
if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
_print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)
report["tables"].append(
{
"table": table_label,
"total_rows": total,
"checked_rows": processed,
"mismatch_rows": mismatched,
"missing_hash_rows": missing_hash,
"invalid_payload_rows": invalid_payload,
"sample_mismatches": samples,
}
)
report["summary"]["checked_tables"] += 1
report["summary"]["total_rows"] += total
report["summary"]["checked_rows"] += processed
report["summary"]["mismatch_rows"] += mismatched
report["summary"]["missing_hash_rows"] += missing_hash
report["summary"]["invalid_payload_rows"] += invalid_payload
report["summary"]["total_tables"] = len(tables)
out_path = _build_report_path(args.out)
out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"[REPORT] {out_path}", flush=True)
return 0
if __name__ == "__main__":
raise SystemExit(main())

File diff suppressed because it is too large

View File

@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""
ODS JSON 字段核对脚本:对照当前数据库中的 ODS 表字段,检查示例 JSON(默认目录 export/test-json-doc
是否包含同名键,并输出每表未命中的字段,便于补充映射或确认确实无源字段。
使用方法:
set PG_DSN=postgresql://... # 如 .env 中配置
python -m scripts.check.check_ods_json_vs_table
"""
from __future__ import annotations
import json
import os
import pathlib
from typing import Dict, Set, Tuple
import psycopg2
from tasks.manual_ingest_task import ManualIngestTask
def _flatten_keys(obj, prefix: str = "") -> Set[str]:
"""递归展开 JSON 所有键路径,返回形如 data.assistantInfos.id 的集合。列表不保留索引,仅继续向下展开。"""
keys: Set[str] = set()
if isinstance(obj, dict):
for k, v in obj.items():
new_prefix = f"{prefix}.{k}" if prefix else k
keys.add(new_prefix)
keys |= _flatten_keys(v, new_prefix)
elif isinstance(obj, list):
for item in obj:
keys |= _flatten_keys(item, prefix)
return keys
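# 用法示例(仅作说明,输入为假设的样例 JSON
#   _flatten_keys({"data": {"assistantInfos": [{"id": 1}]}})
#   -> {"data", "data.assistantInfos", "data.assistantInfos.id"}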
def _load_json_keys(path: pathlib.Path) -> Tuple[Set[str], dict[str, Set[str]]]:
"""读取单个 JSON 文件并返回展开后的键集合以及末段->路径列表映射,若文件不存在或无法解析则返回空集合。"""
if not path.exists():
return set(), {}
data = json.loads(path.read_text(encoding="utf-8"))
paths = _flatten_keys(data)
last_map: dict[str, Set[str]] = {}
for p in paths:
last = p.split(".")[-1].lower()
last_map.setdefault(last, set()).add(p)
return paths, last_map
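# 用法示例(仅作说明,文件内容为假设样例):若 JSON 为 {"data": {"memberId": 1}},则返回
#   paths = {"data", "data.memberId"}last_map = {"data": {"data"}, "memberid": {"data.memberId"}}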
def _load_ods_columns(dsn: str) -> Dict[str, Set[str]]:
"""从数据库读取 billiards_ods.* 的列名集合,按表返回。"""
conn = psycopg2.connect(dsn)
cur = conn.cursor()
cur.execute(
"""
SELECT table_name, column_name
FROM information_schema.columns
WHERE table_schema='billiards_ods'
ORDER BY table_name, ordinal_position
"""
)
result: Dict[str, Set[str]] = {}
for table, col in cur.fetchall():
result.setdefault(table, set()).add(col.lower())
cur.close()
conn.close()
return result
def main() -> None:
"""主流程:遍历 FILE_MAPPING 中的 ODS 表,检查 JSON 键覆盖情况并打印报告。"""
dsn = os.environ.get("PG_DSN")
if not dsn:
    raise SystemExit("未设置 PG_DSN 环境变量:请先在 .env 或当前 shell 中配置后再运行")
json_dir = pathlib.Path(os.environ.get("JSON_DOC_DIR", "export/test-json-doc"))
ods_cols_map = _load_ods_columns(dsn)
print(f"使用 JSON 目录: {json_dir}")
print(f"连接 DSN: {dsn}")
print("=" * 80)
for keywords, ods_table in ManualIngestTask.FILE_MAPPING:
table = ods_table.split(".")[-1]
cols = ods_cols_map.get(table, set())
file_name = f"{keywords[0]}.json"
file_path = json_dir / file_name
keys_full, path_map = _load_json_keys(file_path)
key_last_parts = set(path_map.keys())
missing: Set[str] = set()
extra_keys: Set[str] = set()
present: Set[str] = set()
for col in sorted(cols):
if col in key_last_parts:
present.add(col)
else:
missing.add(col)
for k in key_last_parts:
if k not in cols:
extra_keys.add(k)
print(f"[{table}] 文件={file_name} 列数={len(cols)} JSON键(末段)覆盖={len(present)}/{len(cols)}")
if missing:
print(" 未命中列:", ", ".join(sorted(missing)))
else:
print(" 未命中列: 无")
if extra_keys:
extras = []
for k in sorted(extra_keys):
paths = ", ".join(sorted(path_map.get(k, [])))
extras.append(f"{k} ({paths})")
print(" JSON 仅有(表无此列):", "; ".join(extras))
else:
print(" JSON 仅有(表无此列): 无")
print("-" * 80)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""验证DWS配置数据"""
import os
from pathlib import Path
from dotenv import load_dotenv
import psycopg2
def main():
load_dotenv(Path(__file__).parent.parent / ".env")
dsn = os.getenv("PG_DSN")
conn = psycopg2.connect(dsn)
tables = [
"cfg_performance_tier",
"cfg_assistant_level_price",
"cfg_bonus_rules",
"cfg_area_category",
"cfg_skill_type"
]
print("DWS 配置表数据统计:")
print("-" * 40)
with conn.cursor() as cur:
for t in tables:
cur.execute(f"SELECT COUNT(*) FROM billiards_dws.{t}")
cnt = cur.fetchone()[0]
print(f"{t}: {cnt}")
conn.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
"""
比对 JSON 样本字段 vs API 参考文档(.md)字段。
找出 JSON 中存在但 .md 文档"四、响应字段详解"中缺失的字段。
特殊处理:
- settlement_records / recharge_settlements: 从 settleList 内层提取字段
siteProfile 子字段不提取ODS 中存为 siteprofile jsonb 列)
- stock_goods_category_tree: 从 goodsCategoryList 内层提取字段
- 嵌套对象siteProfile, tableProfile作为整体字段名
"""
import json
import os
import re
import sys
SAMPLES_DIR = os.path.join("docs", "api-reference", "samples")
DOCS_DIR = os.path.join("docs", "api-reference")
# 结构包装器字段(不应出现在比对中)
WRAPPER_FIELDS = {"settleList", "siteProfile", "tableProfile",
"goodsCategoryList", "data", "code", "msg",
"settlelist", "siteprofile", "tableprofile",
"goodscategorylist"}
# 表头关键字(跳过)— 注意 "type" 不能放这里,因为有些表有 type 业务字段
CROSS_REF_HEADERS = {"字段名", "类型", "示例值", "说明", "field", "example", "description"}
def extract_json_fields(table_name: str) -> set:
"""从 JSON 样本提取所有字段名(小写)"""
path = os.path.join(SAMPLES_DIR, f"{table_name}.json")
if not os.path.exists(path):
return set()
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
# settlement_records / recharge_settlements: settleList 内层
if table_name in ("settlement_records", "recharge_settlements"):
settle = data.get("settleList", {})
if isinstance(settle, list):
settle = settle[0] if settle else {}
fields = set()
for k in settle.keys():
kl = k.lower()
if kl in {"siteprofile"}:
fields.add(kl) # 作为整体 jsonb 列
continue
fields.add(kl)
return fields
# stock_goods_category_tree: goodsCategoryList 内层
if table_name == "stock_goods_category_tree":
cat_list = data.get("goodsCategoryList", [])
if cat_list:
return {k.lower() for k in cat_list[0].keys()
if k.lower() not in WRAPPER_FIELDS}
return set()
# role_area_association: roleAreaRelations 内层
if table_name == "role_area_association":
rel_list = data.get("roleAreaRelations", [])
if rel_list:
return {k.lower() for k in rel_list[0].keys()
if k.lower() not in WRAPPER_FIELDS}
return set()
# 通用:顶层字段
fields = set()
for k in data.keys():
kl = k.lower()
if kl in WRAPPER_FIELDS:
# 嵌套对象作为整体
if kl in ("siteprofile", "tableprofile"):
fields.add(kl)
continue
fields.add(kl)
return fields
def extract_md_fields(table_name: str) -> set:
"""从 .md 文档的"四、响应字段详解"章节提取字段名(小写)"""
md_path = os.path.join(DOCS_DIR, f"{table_name}.md")
if not os.path.exists(md_path):
return set()
with open(md_path, "r", encoding="utf-8") as f:
lines = f.readlines()
fields = set()
in_section = False
in_siteprofile = False
field_pattern = re.compile(r'^\|\s*`([^`]+)`\s*\|')
siteprofile_header = re.compile(r'^###.*siteProfile', re.IGNORECASE)
for line in lines:
s = line.strip()
if s.startswith("## 四、") and "响应字段" in s:
in_section = True
in_siteprofile = False
continue
if in_section and s.startswith("## ") and not s.startswith("## 四"):
break
if not in_section:
continue
# siteProfile 子章节处理
if table_name in ("settlement_records", "recharge_settlements"):
if siteprofile_header.search(s):
in_siteprofile = True
continue
if s.startswith("### ") and in_siteprofile:
if not siteprofile_header.search(s):
in_siteprofile = False
m = field_pattern.match(s)
if m:
raw = m.group(1).strip()
if raw.lower() in {h.lower() for h in CROSS_REF_HEADERS}:
continue
if table_name in ("settlement_records", "recharge_settlements"):
if in_siteprofile:
continue
if raw.startswith("siteProfile."):
continue
if raw.lower() in WRAPPER_FIELDS and raw.lower() not in ("siteprofile", "tableprofile"):
continue
fields.add(raw.lower())
return fields
def main():
samples = sorted([
f.replace(".json", "")
for f in os.listdir(SAMPLES_DIR)
if f.endswith(".json")
])
results = []
for table in samples:
json_fields = extract_json_fields(table)
md_fields = extract_md_fields(table)
# JSON 中有但 .md 中没有的
json_only = json_fields - md_fields
# .md 中有但 JSON 中没有的(可能是条件性字段,仅供参考)
md_only = md_fields - json_fields
results.append({
"table": table,
"json_count": len(json_fields),
"md_count": len(md_fields),
"json_only": sorted(json_only),
"md_only": sorted(md_only),
})
# 输出
print("=" * 80)
print("JSON 样本 vs .md 文档 字段比对报告")
print("=" * 80)
issues = 0
for r in results:
if r["json_only"]:
issues += 1
print(f"\n{r['table']} — JSON={r['json_count']}, MD={r['md_count']}")
print(f" JSON 中有但 .md 缺失 ({len(r['json_only'])} 个):")
for f in r["json_only"]:
print(f" - {f}")
if r["md_only"]:
print(f" .md 中有但 JSON 无 ({len(r['md_only'])} 个,可能是条件性字段):")
for f in r["md_only"]:
print(f" - {f}")
else:
status = "" if not r["md_only"] else "⚠️"
extra = ""
if r["md_only"]:
extra = f" (.md 多 {len(r['md_only'])} 个条件性字段)"
print(f"\n{status} {r['table']} — JSON={r['json_count']}, MD={r['md_count']}{extra}")
print(f"\n{'=' * 80}")
print(f"总计: {len(results)} 个表, {issues} 个有 JSON→MD 缺失")
# 输出 JSON 格式供后续处理
out_path = os.path.join("docs", "reports", "json_vs_md_gaps.json")
with open(out_path, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"\n详细结果已写入: {out_path}")
if __name__ == "__main__":
main()
# AI_CHANGELOG:
# - 日期: 2026-02-14
# - Prompt: P20260214-044500 — "md文档和json数据不对应全面排查"
# - 直接原因: 用户要求全面排查 JSON 样本与 .md 文档的字段一致性
# - 变更摘要: 新建脚本,从 JSON 样本提取字段与 .md 文档"响应字段详解"章节比对;
# 修复 3 个 bugtype 过滤、siteProfile/tableProfile 例外、roleAreaRelations 包装器)
# - 风险与验证: 纯分析脚本,无运行时影响;运行 `python scripts/check_json_vs_md.py` 验证输出

View File

@@ -0,0 +1,381 @@
# -*- coding: utf-8 -*-
"""
比对 API 参考文档的 JSON 字段与 ODS 数据库表列,生成对比报告和 ALTER SQL。
支持 camelCase → snake_case 归一化匹配。
用法: python scripts/compare_api_ods.py
需要: psycopg2, python-dotenv
"""
import os, re, json, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dotenv import load_dotenv
import psycopg2
load_dotenv()
PG_DSN = os.getenv("PG_DSN")
ENDPOINTS_DIR = os.path.join("docs", "api-reference", "endpoints")
REGISTRY_FILE = os.path.join("docs", "api-reference", "api_registry.json")
# ODS 元数据列ETL 框架自动添加,不属于 API 字段)
ODS_META_COLUMNS = {
"source_file", "source_endpoint", "fetched_at", "payload", "content_hash"
}
# JSON 类型 → 推荐 PG 类型映射
TYPE_MAP = {
"int": "bigint",
"float": "numeric(18,2)",
"string": "text",
"bool": "boolean",
"list": "jsonb",
"dict": "jsonb",
"object": "jsonb",
"array": "jsonb",
}
def camel_to_snake(name):
"""将 camelCase/PascalCase 转为 snake_case 小写"""
# 处理连续大写如 ABCDef → abc_def
s1 = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', name)
s2 = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', s1)
return s2.lower()
def normalize_field_name(name):
"""统一字段名camelCase → snake_case → 全小写"""
return camel_to_snake(name).replace(".", "_").strip("_")
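# 归一化示例(仅作说明,字段名为假设样例):
#   camel_to_snake("siteProfile")          -> "site_profile"
#   camel_to_snake("ABCDef")               -> "abc_def"
#   normalize_field_name("tableFeeAmount") -> "table_fee_amount"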
def parse_api_fields(md_path):
"""从 API 文档 md 中解析响应字段表,返回 {原始字段名: json_type}
跳过嵌套对象的子字段(如 siteProfile.xxx"""
fields = {}
with open(md_path, "r", encoding="utf-8") as f:
content = f.read()
# 格式: | # | 字段名 | 类型 | 示例值 |
pattern = r"\|\s*\d+\s*\|\s*`([^`]+)`\s*\|\s*(\w+)\s*\|"
for m in re.finditer(pattern, content):
field_name = m.group(1).strip()
field_type = m.group(2).strip().lower()
# 跳过嵌套子字段(如 siteProfile.address
if "." in field_name:
continue
fields[field_name] = field_type
return fields
def get_ods_columns(cursor, table_name):
"""查询 ODS 表的列信息,返回 {column_name: data_type}"""
cursor.execute("""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = 'billiards_ods' AND table_name = %s
ORDER BY ordinal_position
""", (table_name,))
cols = {}
for row in cursor.fetchall():
cols[row[0]] = row[1]
return cols
def suggest_pg_type(json_type):
"""根据 JSON 类型推荐 PG 类型"""
return TYPE_MAP.get(json_type, "text")
def compare_table(api_fields, ods_columns, table_name):
"""比对单张表,使用归一化名称匹配。
返回 (truly_missing, extra_in_ods, matched_pairs, case_matched)
- truly_missing: API 有但 ODS 确实没有的字段 {api_name: json_type}
- extra_in_ods: ODS 有但 API 没有的列 {col_name: pg_type}
- matched_pairs: 精确匹配的字段 [(api_name, ods_name)]
- case_matched: 通过归一化匹配的字段 [(api_name, ods_name)]
"""
# 排除 ODS 元数据列
ods_biz = {k: v for k, v in ods_columns.items() if k not in ODS_META_COLUMNS}
# 建立归一化索引
# api: normalized → (original_name, type)
api_norm = {}
for name, typ in api_fields.items():
norm = normalize_field_name(name)
api_norm[norm] = (name, typ)
# ods: normalized → (original_name, type)
ods_norm = {}
for name, typ in ods_biz.items():
norm = name.lower() # ODS 列名已经是小写
ods_norm[norm] = (name, typ)
matched_pairs = []
case_matched = []
api_matched_norms = set()
ods_matched_norms = set()
# 第一轮精确匹配API 字段名 == ODS 列名)
for api_name, api_type in api_fields.items():
if api_name in ods_biz:
matched_pairs.append((api_name, api_name))
api_matched_norms.add(normalize_field_name(api_name))
ods_matched_norms.add(api_name)
# 第二轮归一化匹配camelCase → snake_case
for norm_name, (api_name, api_type) in api_norm.items():
if norm_name in api_matched_norms:
continue
if norm_name in ods_norm:
ods_name = ods_norm[norm_name][0]
if ods_name not in ods_matched_norms:
case_matched.append((api_name, ods_name))
api_matched_norms.add(norm_name)
ods_matched_norms.add(ods_name)
# 第三轮:尝试去掉下划线的纯小写匹配
for norm_name, (api_name, api_type) in api_norm.items():
if norm_name in api_matched_norms:
continue
flat = norm_name.replace("_", "")
for ods_col, (ods_name, ods_type) in ods_norm.items():
if ods_name in ods_matched_norms:
continue
if ods_col.replace("_", "") == flat:
case_matched.append((api_name, ods_name))
api_matched_norms.add(norm_name)
ods_matched_norms.add(ods_name)
break
# 计算真正缺失和多余
truly_missing = {}
for norm_name, (api_name, api_type) in api_norm.items():
if norm_name not in api_matched_norms:
truly_missing[api_name] = api_type
extra_in_ods = {}
for ods_name, ods_type in ods_biz.items():
if ods_name not in ods_matched_norms:
extra_in_ods[ods_name] = ods_type
return truly_missing, extra_in_ods, matched_pairs, case_matched
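# 匹配示例(仅作说明,字段为假设样例)API {"siteId": "int"} vs ODS {"site_id": "bigint"}
# -> truly_missing={}, extra_in_ods={}, matched_pairs=[], case_matched=[("siteId", "site_id")]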
def generate_alter_sql(table_name, missing_fields):
"""生成 ALTER TABLE ADD COLUMN SQL列名用 snake_case"""
sqls = []
for field_name, json_type in sorted(missing_fields.items()):
pg_type = suggest_pg_type(json_type)
col_name = normalize_field_name(field_name)
sqls.append(
f"ALTER TABLE billiards_ods.{table_name} ADD COLUMN IF NOT EXISTS "
f"{col_name} {pg_type}; -- API 字段: {field_name}"
)
return sqls
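# 生成示例(仅作说明,表名/字段为假设样例)generate_alter_sql("orders", {"payAmount": "float"})
# -> ['ALTER TABLE billiards_ods.orders ADD COLUMN IF NOT EXISTS pay_amount numeric(18,2); -- API 字段: payAmount']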
def main():
# 加载 API 注册表
with open(REGISTRY_FILE, "r", encoding="utf-8") as f:
registry = json.load(f)
# 建立 id → ods_table 映射
api_to_ods = {}
api_names = {}
for entry in registry:
if entry.get("ods_table") and not entry.get("skip"):
api_to_ods[entry["id"]] = entry["ods_table"]
api_names[entry["id"]] = entry.get("name_zh", entry["id"])
conn = psycopg2.connect(PG_DSN)
cursor = conn.cursor()
results = []
all_alter_sqls = []
for api_id, ods_table in sorted(api_to_ods.items()):
md_path = os.path.join(ENDPOINTS_DIR, f"{api_id}.md")
if not os.path.exists(md_path):
results.append({
"api_id": api_id, "name_zh": api_names.get(api_id, ""),
"ods_table": ods_table, "status": "NO_DOC",
"api_fields": 0, "ods_cols": 0,
})
continue
api_fields = parse_api_fields(md_path)
ods_columns = get_ods_columns(cursor, ods_table)
if not ods_columns:
results.append({
"api_id": api_id, "name_zh": api_names.get(api_id, ""),
"ods_table": ods_table, "status": "NO_TABLE",
"api_fields": len(api_fields), "ods_cols": 0,
})
continue
missing, extra, matched, case_matched = compare_table(
api_fields, ods_columns, ods_table
)
alter_sqls = generate_alter_sql(ods_table, missing)
all_alter_sqls.extend(alter_sqls)
ods_biz_count = len({k: v for k, v in ods_columns.items()
if k not in ODS_META_COLUMNS})
status = "OK" if not missing else "DRIFT"
results.append({
"api_id": api_id,
"name_zh": api_names.get(api_id, ""),
"ods_table": ods_table,
"status": status,
"api_fields": len(api_fields),
"ods_cols": ods_biz_count,
"exact_match": len(matched),
"case_match": len(case_matched),
"total_match": len(matched) + len(case_matched),
"missing_in_ods": missing,
"extra_in_ods": extra,
"case_matched_pairs": case_matched,
})
cursor.close()
conn.close()
# ── 输出 JSON 报告 ──
report_json = os.path.join("docs", "reports", "api_ods_comparison.json")
os.makedirs(os.path.dirname(report_json), exist_ok=True)
# 序列化时把 tuple 转 list
json_results = []
for r in results:
jr = dict(r)
if "case_matched_pairs" in jr:
jr["case_matched_pairs"] = [list(p) for p in jr["case_matched_pairs"]]
if "missing_in_ods" in jr:
jr["missing_in_ods"] = dict(jr["missing_in_ods"])
if "extra_in_ods" in jr:
jr["extra_in_ods"] = dict(jr["extra_in_ods"])
json_results.append(jr)
with open(report_json, "w", encoding="utf-8") as f:
json.dump(json_results, f, ensure_ascii=False, indent=2)
# ── 输出 Markdown 报告 ──
report_md = os.path.join("docs", "reports", "api_ods_comparison.md")
with open(report_md, "w", encoding="utf-8") as f:
f.write("# API JSON 字段 vs ODS 表列 对比报告\n\n")
f.write("> 自动生成于 2026-02-13 | 数据来源:数据库实际表结构 + API 参考文档\n")
f.write("> 比对逻辑camelCase → snake_case 归一化匹配 + 去下划线纯小写兜底\n\n")
# 汇总
ok_count = sum(1 for r in results if r["status"] == "OK")
drift_count = sum(1 for r in results if r["status"] == "DRIFT")
total_missing = sum(len(r.get("missing_in_ods", {})) for r in results)
total_extra = sum(len(r.get("extra_in_ods", {})) for r in results)
f.write("## 汇总\n\n")
f.write("| 指标 | 值 |\n|------|----|")
f.write(f"\n| 比对表数 | {len(results)} |")
f.write(f"\n| 完全一致(含大小写归一化) | {ok_count} |")
f.write(f"\n| 存在差异 | {drift_count} |")
f.write(f"\n| ODS 缺失字段总数 | {total_missing} |")
f.write(f"\n| ODS 多余列总数 | {total_extra} |")
f.write(f"\n| 生成 ALTER SQL 数 | {len(all_alter_sqls)} |\n\n")
# 总览表
f.write("## 逐表对比总览\n\n")
f.write("| # | API ID | 中文名 | ODS 表 | 状态 | API字段 | ODS列 | 精确匹配 | 大小写匹配 | ODS缺失 | ODS多余 |\n")
f.write("|---|--------|--------|--------|------|---------|-------|----------|-----------|---------|--------|\n")
for i, r in enumerate(results, 1):
missing_count = len(r.get("missing_in_ods", {}))
extra_count = len(r.get("extra_in_ods", {}))
exact = r.get("exact_match", 0)
case = r.get("case_match", 0)
icon = "" if r["status"] == "OK" else "⚠️" if r["status"] == "DRIFT" else ""
f.write(f"| {i} | {r['api_id']} | {r.get('name_zh','')} | {r['ods_table']} | "
f"{icon} | {r['api_fields']} | {r['ods_cols']} | {exact} | {case} | "
f"{missing_count} | {extra_count} |\n")
# 差异详情
has_drift = any(r["status"] == "DRIFT" for r in results)
if has_drift:
f.write("\n## 差异详情\n\n")
for r in results:
if r["status"] != "DRIFT":
continue
f.write(f"### {r.get('name_zh','')}`{r['ods_table']}`\n\n")
missing = r.get("missing_in_ods", {})
extra = r.get("extra_in_ods", {})
case_pairs = r.get("case_matched_pairs", [])
if case_pairs:
f.write("**大小写归一化匹配(已自动对齐,无需操作):**\n\n")
f.write("| API 字段名 (camelCase) | ODS 列名 (lowercase) |\n")
f.write("|----------------------|---------------------|\n")
for api_n, ods_n in sorted(case_pairs):
f.write(f"| `{api_n}` | `{ods_n}` |\n")
f.write("\n")
if missing:
f.write("**ODS 真正缺失的字段(需要 ADD COLUMN**\n\n")
f.write("| 字段名 | JSON 类型 | 建议 PG 列名 | 建议 PG 类型 |\n")
f.write("|--------|-----------|-------------|-------------|\n")
for fname, ftype in sorted(missing.items()):
f.write(f"| `{fname}` | {ftype} | `{normalize_field_name(fname)}` | {suggest_pg_type(ftype)} |\n")
f.write("\n")
if extra:
f.write("**ODS 多余的列API 中不存在):**\n\n")
f.write("| 列名 | PG 类型 | 可能原因 |\n")
f.write("|------|---------|--------|\n")
for cname, ctype in sorted(extra.items()):
f.write(f"| `{cname}` | {ctype} | ETL 自行添加 / 历史遗留 / API 新版已移除 |\n")
f.write("\n")
# ── 输出 ALTER SQL ──
sql_path = os.path.join("database", "migrations", "20260213_align_ods_with_api.sql")
os.makedirs(os.path.dirname(sql_path), exist_ok=True)
with open(sql_path, "w", encoding="utf-8") as f:
f.write("-- ============================================================\n")
f.write("-- ODS 表与 API JSON 字段对齐迁移\n")
f.write("-- 自动生成于 2026-02-13\n")
f.write("-- 基于: docs/api-reference/ 文档 vs billiards_ods 实际表结构\n")
f.write("-- 比对逻辑: camelCase → snake_case 归一化后再比较\n")
f.write("-- ============================================================\n\n")
if all_alter_sqls:
f.write("BEGIN;\n\n")
current_table = ""
for sql in all_alter_sqls:
# 提取表名做分组注释
tbl = sql.split("billiards_ods.")[1].split(" ")[0]
if tbl != current_table:
if current_table:
f.write("\n")
f.write(f"-- ── {tbl} ──\n")
current_table = tbl
f.write(sql + "\n")
f.write("\nCOMMIT;\n")
else:
f.write("-- 无需变更,所有 ODS 表已与 API JSON 字段对齐。\n")
print(f"[完成] 比对 {len(results)} 张表")
print(f" - 完全一致: {ok_count}")
print(f" - 存在差异: {drift_count}")
print(f" - ODS 缺失字段: {total_missing}")
print(f" - ODS 多余列: {total_extra}")
print(f" - ALTER SQL: {len(all_alter_sqls)}")
print(f" - 报告: {report_md}")
print(f" - JSON: {report_json}")
print(f" - SQL: {sql_path}")
if __name__ == "__main__":
main()
# AI_CHANGELOG:
# - 日期: 2026-02-13
# - Prompt: P20260213-210000 — "用新梳理的API返回的JSON文档比对数据库ODS层"
# - 直接原因: 用户要求比对 API 参考文档与 ODS 实际表结构,生成对比报告和 ALTER SQL
# - 变更摘要: 新建比对脚本,支持 camelCase→snake_case 归一化匹配,输出 MD/JSON 报告和迁移 SQL
# - 风险与验证: 纯分析脚本不修改数据库验证python scripts/compare_api_ods.py 检查输出

View File

@@ -0,0 +1,461 @@
# -*- coding: utf-8 -*-
"""
API 参考文档 vs ODS 实际表结构 对比脚本 (v2)
从 docs/api-reference/*.md 的 JSON 样例中提取字段,
查询 PostgreSQL billiards_ods 的实际列,
输出差异报告 JSON 和 Markdown + ALTER SQL。
用法: python scripts/compare_api_ods_v2.py
"""
import json
import os
import re
import sys
from datetime import datetime
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
from dotenv import load_dotenv
load_dotenv(os.path.join(ROOT, ".env"))
import psycopg2
# ODS 元列ETL 管理列,不来自 API
ODS_META_COLS = {
"source_file", "source_endpoint", "fetched_at",
"payload", "content_hash",
}
def load_registry():
"""加载 API 注册表"""
path = os.path.join(ROOT, "docs", "api-reference", "api_registry.json")
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def extract_fields_from_md(md_path, api_id):
"""
从 md 文件的 JSON 样例(五、响应样例)中提取所有字段名(小写)。
对 settlement_records / recharge_settlements 等嵌套结构,
提取 settleList 内层字段 + siteProfile 字段。
"""
with open(md_path, "r", encoding="utf-8") as f:
content = f.read()
# 提取所有 ```json ... ``` 代码块
json_blocks = re.findall(r'```json\s*\n(.*?)\n```', content, re.DOTALL)
if not json_blocks:
return None, None, "无 JSON 样例"
# 找到最大的 JSON 对象(响应样例通常是最大的)
sample_json = None
for block in json_blocks:
try:
parsed = json.loads(block)
if isinstance(parsed, dict):
if sample_json is None or len(str(parsed)) > len(str(sample_json)):
sample_json = parsed
except json.JSONDecodeError:
continue
if sample_json is None:
return None, None, "无法解析 JSON 样例"
fields = set()
has_nested = False
# settlement_records / recharge_settlements 嵌套结构:
# { "siteProfile": {...}, "settleList": {...} }
if "siteProfile" in sample_json and "settleList" in sample_json:
has_nested = True
sl = sample_json.get("settleList", {})
if isinstance(sl, dict):
for k in sl:
fields.add(k.lower())
return fields, has_nested, None
# CHANGE: stock_goods_category_tree 特殊结构处理
# intent: goodsCategoryList 是数组包装ODS 存储的是展平后的分类节点字段
# assumptions: 外层 total/goodsCategoryList 不是 ODS 列
if "goodsCategoryList" in sample_json and isinstance(sample_json["goodsCategoryList"], list):
has_nested = True
arr = sample_json["goodsCategoryList"]
if arr and isinstance(arr[0], dict):
_extract_flat(arr[0], fields)
return fields, has_nested, None
for k in sample_json:
fields.add(k.lower())
return fields, has_nested, None
def _extract_flat(obj, fields):
"""递归提取字典的标量字段名(跳过数组/嵌套对象值,但保留键名)"""
if not isinstance(obj, dict):
return
for k, v in obj.items():
fields.add(k.lower())
def get_all_ods_columns(conn):
"""查询所有 ODS 表的列信息"""
cur = conn.cursor()
cur.execute("""
SELECT table_name, column_name, data_type, ordinal_position
FROM information_schema.columns
WHERE table_schema = 'billiards_ods'
ORDER BY table_name, ordinal_position
""")
rows = cur.fetchall()
cur.close()
tables = {}
for table_name, col_name, data_type, pos in rows:
if table_name not in tables:
tables[table_name] = {}
tables[table_name][col_name] = {
"data_type": data_type,
"ordinal_position": pos,
}
return tables
def guess_pg_type(name):
"""根据字段名猜测 PostgreSQL 类型(用于 ALTER TABLE ADD COLUMN"""
n = name.lower()
if n == "id" or n.endswith("_id") or n.endswith("id"):
return "bigint"
money_kw = ["amount", "money", "price", "cost", "fee", "discount",
"deduct", "balance", "charge", "sale", "refund",
"promotion", "adjust", "rounding", "prepay", "income",
"royalty", "grade", "point", "stock", "num"]
for kw in money_kw:
if kw in n:
return "numeric(18,2)"
if "time" in n or "date" in n:
return "timestamp without time zone"
if n.startswith("is_") or (n.startswith("is") and len(n) > 2 and n[2].isupper()):
return "boolean"
if n.startswith("able_") or n.startswith("can"):
return "boolean"
int_kw = ["status", "type", "sort", "count", "seconds", "level",
"channel", "method", "way", "enabled", "switch", "delete",
"first", "single", "trash", "confirm", "clock", "cycle",
"delay", "free", "virtual", "online", "show", "audit",
"freeze", "send", "required", "scene", "range", "tag",
"on", "minutes", "number", "duration"]
for kw in int_kw:
if kw in n:
return "integer"
return "text"
def compare_one(api_entry, md_path, ods_tables):
"""比较单个 API 与其 ODS 表"""
api_id = api_entry["id"]
ods_table = api_entry.get("ods_table")
name_zh = api_entry.get("name_zh", "")
result = {
"api_id": api_id,
"name_zh": name_zh,
"ods_table": ods_table,
}
if not ods_table:
result["status"] = "skip"
result["reason"] = "无对应 ODS 表ods_table=null"
return result
if api_entry.get("skip"):
result["status"] = "skip"
result["reason"] = "接口标记为 skip暂不可用"
return result
# 提取 API JSON 样例字段
api_fields, has_nested, err = extract_fields_from_md(md_path, api_id)
if err:
result["status"] = "error"
result["reason"] = err
return result
# 获取 ODS 表列
if ods_table not in ods_tables:
result["status"] = "error"
result["reason"] = f"ODS 表 {ods_table} 不存在"
return result
ods_cols = ods_tables[ods_table]
ods_biz_cols = {c for c in ods_cols if c not in ODS_META_COLS}
# 比较
api_lower = {f.lower() for f in api_fields}
ods_lower = {c.lower() for c in ods_biz_cols}
# API 有但 ODS 没有的字段
api_only = sorted(api_lower - ods_lower)
# ODS 有但 API 没有的字段(非元列)
ods_only = sorted(ods_lower - api_lower)
# 两边都有的字段
matched = sorted(api_lower & ods_lower)
result["status"] = "ok" if not api_only else "drift"
result["has_nested_structure"] = has_nested
result["api_field_count"] = len(api_lower)
result["ods_biz_col_count"] = len(ods_biz_cols)
result["ods_total_col_count"] = len(ods_cols)
result["matched_count"] = len(matched)
result["api_only"] = api_only
result["api_only_count"] = len(api_only)
result["ods_only"] = ods_only
result["ods_only_count"] = len(ods_only)
result["matched"] = matched
return result
def generate_alter_sql(results, ods_tables):
"""生成 ALTER TABLE SQL 语句"""
sqls = []
for r in results:
if r.get("status") != "drift" or not r.get("api_only"):
continue
table = r["ods_table"]
for field in r["api_only"]:
pg_type = guess_pg_type(field)
sqls.append(
f"ALTER TABLE billiards_ods.{table} "
f"ADD COLUMN IF NOT EXISTS {field} {pg_type};"
)
return sqls
def generate_markdown_report(results, alter_sqls):
"""生成 Markdown 报告"""
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
lines = [
"# API 参考文档 vs ODS 实际表结构 对比报告 (v2)",
"",
f"> 生成时间:{now}",
"> 数据来源:`docs/api-reference/*.md` JSON 样例 vs `billiards_ods` 实际列",
"",
"---",
"",
"## 一、汇总",
"",
"| API 接口 | 中文名 | ODS 表 | 状态 | API 字段数 | ODS 业务列数 | 匹配 | API 独有 | ODS 独有 |",
"|----------|--------|--------|------|-----------|-------------|------|---------|---------|",
]
total_api_only = 0
total_ods_only = 0
ok_count = 0
drift_count = 0
skip_count = 0
error_count = 0
for r in results:
status = r.get("status", "?")
if status == "skip":
skip_count += 1
lines.append(
f"| {r['api_id']} | {r['name_zh']} | {r.get('ods_table', '-')} "
f"| ⏭️ 跳过 | - | - | - | - | - |"
)
continue
if status == "error":
error_count += 1
lines.append(
f"| {r['api_id']} | {r['name_zh']} | {r.get('ods_table', '-')} "
f"| ❌ 错误 | - | - | - | - | - |"
)
continue
api_only_n = r.get("api_only_count", 0)
ods_only_n = r.get("ods_only_count", 0)
total_api_only += api_only_n
total_ods_only += ods_only_n
if status == "ok":
ok_count += 1
badge = "✅ 对齐"
else:
drift_count += 1
badge = "⚠️ 漂移"
lines.append(
f"| {r['api_id']} | {r['name_zh']} | {r['ods_table']} "
f"| {badge} | {r['api_field_count']} | {r['ods_biz_col_count']} "
f"| {r['matched_count']} | {api_only_n} | {ods_only_n} |"
)
lines.extend([
"",
f"**统计**:对齐 {ok_count} / 漂移 {drift_count} / 跳过 {skip_count} / 错误 {error_count}",
f"**API 独有字段总计**{total_api_only}(需要 ALTER TABLE ADD COLUMN",
f"**ODS 独有列总计**{total_ods_only}API 中不存在,可能是历史遗留或 ETL 派生列)",
"",
])
# 详情:每个漂移表的字段差异
drift_results = [r for r in results if r.get("status") == "drift"]
if drift_results:
lines.extend(["---", "", "## 二、漂移详情", ""])
for r in drift_results:
lines.extend([
f"### {r['api_id']}{r['name_zh']})→ `{r['ods_table']}`",
"",
])
if r["api_only"]:
lines.append("**API 有 / ODS 缺**")
for f in r["api_only"]:
pg_type = guess_pg_type(f)
lines.append(f"- `{f}` → 建议类型 `{pg_type}`")
lines.append("")
if r["ods_only"]:
lines.append("**ODS 有 / API 无**(非元列):")
for f in r["ods_only"]:
lines.append(f"- `{f}`")
lines.append("")
# ODS 独有列详情(所有表)
ods_only_results = [r for r in results if r.get("ods_only") and r.get("status") in ("ok", "drift")]
if ods_only_results:
lines.extend(["---", "", "## 三、ODS 独有列详情API 中不存在)", ""])
for r in ods_only_results:
if not r["ods_only"]:
continue
lines.extend([
f"### `{r['ods_table']}`{r['name_zh']}",
"",
"| 列名 | 说明 |",
"|------|------|",
])
for f in r["ods_only"]:
lines.append(f"| `{f}` | ODS 独有API JSON 样例中不存在 |")
lines.append("")
# ALTER SQL
if alter_sqls:
lines.extend([
"---", "",
"## 四、ALTER SQL对齐 ODS 表结构)", "",
"```sql",
"-- 自动生成的 ALTER TABLE 语句",
f"-- 生成时间:{now}",
"-- 注意:类型为根据字段名猜测,请人工复核后执行",
"",
])
lines.extend(alter_sqls)
lines.extend(["", "```", ""])
return "\n".join(lines)
def main():
dsn = os.environ.get("PG_DSN")
if not dsn:
print("错误:未设置 PG_DSN 环境变量", file=sys.stderr)
sys.exit(1)
print("连接数据库...")
conn = psycopg2.connect(dsn)
print("查询 ODS 表结构...")
ods_tables = get_all_ods_columns(conn)
print(f"{len(ods_tables)} 张 ODS 表")
print("加载 API 注册表...")
registry = load_registry()
print(f"{len(registry)} 个 API 端点")
results = []
for entry in registry:
api_id = entry["id"]
ods_table = entry.get("ods_table")
md_path = os.path.join(ROOT, "docs", "api-reference", f"{api_id}.md")
if not os.path.exists(md_path):
results.append({
"api_id": api_id,
"name_zh": entry.get("name_zh", ""),
"ods_table": ods_table,
"status": "error",
"reason": f"文档不存在: {md_path}",
})
continue
r = compare_one(entry, md_path, ods_tables)
results.append(r)
status_icon = {"ok": "", "drift": "⚠️", "skip": "⏭️", "error": ""}.get(r["status"], "?")
extra = ""
if r.get("api_only_count"):
extra = f" (API独有: {r['api_only_count']})"
if r.get("ods_only_count"):
extra += f" (ODS独有: {r['ods_only_count']})"
print(f" {status_icon} {api_id}{ods_table or '-'}{extra}")
conn.close()
# 生成 ALTER SQL
alter_sqls = generate_alter_sql(results, ods_tables)
# 输出 JSON 报告
json_path = os.path.join(ROOT, "docs", "reports", "api_ods_comparison_v2.json")
os.makedirs(os.path.dirname(json_path), exist_ok=True)
with open(json_path, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"\nJSON 报告: {json_path}")
# 输出 Markdown 报告
md_report = generate_markdown_report(results, alter_sqls)
md_path = os.path.join(ROOT, "docs", "reports", "api_ods_comparison_v2.md")
with open(md_path, "w", encoding="utf-8") as f:
f.write(md_report)
print(f"Markdown 报告: {md_path}")
# 输出 ALTER SQL 文件
if alter_sqls:
sql_path = os.path.join(ROOT, "database", "migrations",
"20260213_align_ods_with_api_v2.sql")
os.makedirs(os.path.dirname(sql_path), exist_ok=True)
with open(sql_path, "w", encoding="utf-8") as f:
f.write("-- API vs ODS 对齐迁移脚本 (v2)\n")
f.write(f"-- 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write("-- 注意:类型为根据字段名猜测,请人工复核后执行\n\n")
f.write("BEGIN;\n\n")
for sql in alter_sqls:
f.write(sql + "\n")
f.write("\nCOMMIT;\n")
print(f"ALTER SQL: {sql_path}")
else:
print("无需 ALTER SQL所有表已对齐")
# 统计
ok_n = sum(1 for r in results if r.get("status") == "ok")
drift_n = sum(1 for r in results if r.get("status") == "drift")
skip_n = sum(1 for r in results if r.get("status") == "skip")
err_n = sum(1 for r in results if r.get("status") == "error")
print(f"\n汇总:对齐 {ok_n} / 漂移 {drift_n} / 跳过 {skip_n} / 错误 {err_n}")
print(f"ALTER SQL 语句数:{len(alter_sqls)}")
if __name__ == "__main__":
main()
# ──────────────────────────────────────────────
# AI_CHANGELOG:
# - 日期: 2026-02-13
# Prompt: P20260213-223000 — 用 API 参考文档比对数据库 ODS 实际表结构(重做,不依赖 DDL
# 直接原因: 前次比对脚本 stock_goods_category_tree 嵌套结构解析 bug需重写脚本
# 变更摘要: 完整重写脚本,从 api-reference/*.md JSON 样例提取字段,查询 PG billiards_ods 实际列,
# 处理三种特殊结构(标准/settleList 嵌套/goodsCategoryList 数组包装),输出 JSON+MD 报告
# 风险与验证: 纯分析脚本,不修改数据库;验证方式:运行脚本确认 "对齐 22 / 漂移 0"
# ──────────────────────────────────────────────

View File

@@ -0,0 +1,822 @@
#!/usr/bin/env python3
"""DDL 与数据库实际表结构对比脚本。
# AI_CHANGELOG [2026-02-13] 修复列名以 UNIQUE/CHECK 开头被误判为约束行的 bug新增 CREATE VIEW 解析支持(视图仅检查存在性)
解析 database/schema_*.sql 中的 CREATE TABLE 语句,
查询 information_schema.columns 获取数据库实际结构,
逐表逐字段对比并输出差异报告。
用法:
python scripts/compare_ddl_db.py --pg-dsn "postgresql://..." --schema billiards_ods --ddl-path database/schema_ODS_doc.sql
python scripts/compare_ddl_db.py --schema billiards_dwd --ddl-path database/schema_dwd_doc.sql # 从 .env 读取 PG_DSN
"""
from __future__ import annotations
import argparse
import os
import re
import sys
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional
class DiffKind(str, Enum):
"""差异分类枚举。"""
MISSING_TABLE = "MISSING_TABLE" # DDL 缺表数据库有DDL 没有)
EXTRA_TABLE = "EXTRA_TABLE" # DDL 多表DDL 有,数据库没有)
MISSING_COLUMN = "MISSING_COLUMN" # DDL 缺字段
EXTRA_COLUMN = "EXTRA_COLUMN" # DDL 多字段
TYPE_MISMATCH = "TYPE_MISMATCH" # 字段类型不一致
NULLABLE_MISMATCH = "NULLABLE_MISMATCH" # 可空约束不一致
@dataclass
class SchemaDiff:
"""单条差异记录。"""
kind: DiffKind
table: str
column: Optional[str] = None
ddl_value: Optional[str] = None
db_value: Optional[str] = None
def __str__(self) -> str:
parts = [f"[{self.kind.value}] {self.table}"]
if self.column:
parts.append(f".{self.column}")
if self.ddl_value is not None or self.db_value is not None:
parts.append(f" DDL={self.ddl_value} DB={self.db_value}")
return "".join(parts)
# ---------------------------------------------------------------------------
# DDL 列定义
# ---------------------------------------------------------------------------
@dataclass
class ColumnDef:
"""从 DDL 解析出的单个字段定义。"""
name: str
data_type: str # 标准化后的类型字符串
nullable: bool = True
is_pk: bool = False
default: Optional[str] = None
@dataclass
class TableDef:
"""从 DDL 解析出的单张表定义。"""
name: str # 不含 schema 前缀的表名(小写)
columns: dict[str, ColumnDef] = field(default_factory=dict)
pk_columns: list[str] = field(default_factory=list)
is_view: bool = False # 视图标记,跳过列级对比
# ---------------------------------------------------------------------------
# 类型标准化:将 DDL 类型和 information_schema 类型映射到统一表示
# ---------------------------------------------------------------------------
# PostgreSQL information_schema.data_type → 简写映射
_PG_TYPE_MAP: dict[str, str] = {
"bigint": "bigint",
"integer": "integer",
"smallint": "smallint",
"boolean": "boolean",
"text": "text",
"jsonb": "jsonb",
"json": "json",
"date": "date",
"bytea": "bytea",
"double precision": "double precision",
"real": "real",
"uuid": "uuid",
"timestamp without time zone": "timestamp",
"timestamp with time zone": "timestamptz",
"time without time zone": "time",
"time with time zone": "timetz",
"character varying": "varchar",
"character": "char",
"ARRAY": "array",
"USER-DEFINED": "user-defined",
}
def normalize_type(raw: str) -> str:
"""将 DDL 或 information_schema 中的类型字符串标准化为可比较的形式。
规则:
- 全部小写
- BIGINT / INT8 → bigint
- INTEGER / INT / INT4 → integer
- SMALLINT / INT2 → smallint
- BOOLEAN / BOOL → boolean
- VARCHAR(n) / CHARACTER VARYING(n) → varchar(n)
- CHAR(n) / CHARACTER(n) → char(n)
- NUMERIC(p,s) / DECIMAL(p,s) → numeric(p,s)
- SERIAL → integerserial 本质是 integer + sequence
- BIGSERIAL → bigint
- TIMESTAMP → timestamp
- TIMESTAMPTZ / TIMESTAMP WITH TIME ZONE → timestamptz
- TEXT → text
- JSONB → jsonb
"""
t = raw.strip().lower()
# 去掉多余空格
t = re.sub(r"\s+", " ", t)
# serial 家族 → 底层整数类型
if t == "bigserial":
return "bigint"
if t in ("serial", "serial4"):
return "integer"
if t == "smallserial":
return "smallint"
# 带精度的 numeric / decimal
m = re.match(r"(?:numeric|decimal)\s*\((\d+)\s*,\s*(\d+)\)", t)
if m:
return f"numeric({m.group(1)},{m.group(2)})"
m = re.match(r"(?:numeric|decimal)\s*\((\d+)\)", t)
if m:
return f"numeric({m.group(1)})"
if t in ("numeric", "decimal"):
return "numeric"
# varchar / character varying
m = re.match(r"(?:varchar|character varying)\s*\((\d+)\)", t)
if m:
return f"varchar({m.group(1)})"
if t in ("varchar", "character varying"):
return "varchar"
# char / character
m = re.match(r"(?:char|character)\s*\((\d+)\)", t)
if m:
return f"char({m.group(1)})"
if t in ("char", "character"):
return "char(1)"
# timestamp 家族
if t in ("timestamptz", "timestamp with time zone"):
return "timestamptz"
if t in ("timestamp", "timestamp without time zone"):
return "timestamp"
# 整数别名
if t in ("int8", "bigint"):
return "bigint"
if t in ("int", "int4", "integer"):
return "integer"
if t in ("int2", "smallint"):
return "smallint"
# 布尔
if t in ("bool", "boolean"):
return "boolean"
# information_schema 映射
if t in _PG_TYPE_MAP:
return _PG_TYPE_MAP[t]
return t
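# 标准化示例(仅作说明,按上述规则推演):
#   normalize_type("BIGSERIAL")                 -> "bigint"
#   normalize_type("CHARACTER VARYING(50)")     -> "varchar(50)"
#   normalize_type("timestamp with time zone")  -> "timestamptz"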
# ---------------------------------------------------------------------------
# DDL 解析器
# ---------------------------------------------------------------------------
# 匹配 CREATE TABLE [IF NOT EXISTS] [schema.]table_name (
_CREATE_TABLE_RE = re.compile(
r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
r"(?:(\w+)\.)?(\w+)\s*\(",
re.IGNORECASE,
)
# 匹配 DROP TABLE [IF EXISTS] [schema.]table_name [CASCADE];
_DROP_TABLE_RE = re.compile(
r"DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(?:\w+\.)?(\w+)",
re.IGNORECASE,
)
# 匹配 CREATE [OR REPLACE] VIEW [schema.]view_name AS SELECT ...
_CREATE_VIEW_RE = re.compile(
r"CREATE\s+(?:OR\s+REPLACE\s+)?VIEW\s+"
r"(?:(\w+)\.)?(\w+)\s+AS\s+",
re.IGNORECASE,
)
def _strip_sql_comments(sql: str) -> str:
"""移除 SQL 单行注释(-- ...)和块注释(/* ... */)。"""
# 块注释
sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
# 单行注释
sql = re.sub(r"--[^\n]*", "", sql)
return sql
def _find_matching_paren(text: str, start: int) -> int:
"""从 start 位置(应为 '(')开始,找到匹配的 ')' 位置。
处理嵌套括号和字符串字面量中的括号。
"""
depth = 0
in_string = False
string_char = ""
i = start
while i < len(text):
ch = text[i]
if in_string:
if ch == string_char:
# 检查转义
if i + 1 < len(text) and text[i + 1] == string_char:
i += 2
continue
in_string = False
else:
if ch in ("'", '"'):
in_string = True
string_char = ch
elif ch == "(":
depth += 1
elif ch == ")":
depth -= 1
if depth == 0:
return i
i += 1
return -1
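# 用法示例(仅作说明,输入为假设样例)_find_matching_paren("t (a, f(x))", 2) -> 10即与下标 2 处 '(' 配对的 ')' 位置)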
def _parse_column_line(line: str) -> Optional[ColumnDef]:
"""解析单行字段定义,返回 ColumnDef 或 None如果是约束行"""
line = line.strip().rstrip(",")
if not line:
return None
upper = line.upper()
# 跳过表级约束行
# 注意:需要区分约束行(如 "UNIQUE (...)")和以约束关键字开头的列名
# (如 "unique_customers INTEGER"、"check_status INT"
# 约束行的关键字后面紧跟空格+左括号或直接左括号,而列名后面跟下划线或字母
if re.match(
r"(?:PRIMARY\s+KEY|UNIQUE|CHECK|FOREIGN\s+KEY|EXCLUDE)"
r"(?:\s*\(|\s+(?![\w]))",
upper,
) or upper.startswith("CONSTRAINT"):
return None
# 字段名 类型 [约束...]
# 字段名可能被双引号包裹
m = re.match(r'(?:"([^"]+)"|(\w+))\s+(.+)', line)
if not m:
return None
col_name = (m.group(1) or m.group(2)).lower()
rest = m.group(3).strip()
# 提取类型:取到第一个(位置最靠前的)已知约束关键字或行尾
# 类型可能包含括号,如 NUMERIC(18,2)、VARCHAR(50)
type_end_keywords = [
"NOT NULL", "NULL", "DEFAULT", "PRIMARY KEY", "UNIQUE",
"REFERENCES", "CHECK", "CONSTRAINT", "GENERATED",
]
type_str = rest
constraint_part = ""
# 找所有关键字中位置最靠前的
best_idx = len(rest)
for kw in type_end_keywords:
idx = rest.upper().find(kw)
if idx > 0 and idx < best_idx:
candidate = rest[:idx].strip()
if candidate:
best_idx = idx
if best_idx < len(rest):
type_str = rest[:best_idx].strip()
constraint_part = rest[best_idx:]
# 去掉类型末尾的逗号
type_str = type_str.rstrip(",").strip()
nullable = True
if "NOT NULL" in constraint_part.upper():
nullable = False
is_pk = "PRIMARY KEY" in constraint_part.upper()
# 提取 DEFAULT 值
default_val = None
dm = re.search(r"DEFAULT\s+(.+?)(?:\s+(?:NOT\s+NULL|NULL|PRIMARY|UNIQUE|REFERENCES|CHECK|CONSTRAINT|,|$))",
constraint_part, re.IGNORECASE)
if dm:
default_val = dm.group(1).strip().rstrip(",")
return ColumnDef(
name=col_name,
data_type=normalize_type(type_str),
nullable=nullable,
is_pk=is_pk,
default=default_val,
)
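# 判定示例(仅作说明,行内容为假设样例):
#   _parse_column_line('unique_customers INTEGER NOT NULL,')
#     -> ColumnDef(name="unique_customers", data_type="integer", nullable=False)
#   _parse_column_line('UNIQUE (tenant_id, site_id),')
#     -> None识别为表级约束行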
def _extract_pk_from_body(body: str) -> list[str]:
"""从 CREATE TABLE 体中提取表级 PRIMARY KEY 约束的列名列表。"""
# PRIMARY KEY (col1, col2, ...)
# 也可能是 CONSTRAINT xxx PRIMARY KEY (col1, col2)
m = re.search(r"PRIMARY\s+KEY\s*\(([^)]+)\)", body, re.IGNORECASE)
if not m:
return []
cols_str = m.group(1)
return [c.strip().strip('"').lower() for c in cols_str.split(",")]
def parse_ddl(sql_text: str, target_schema: Optional[str] = None) -> dict[str, TableDef]:
"""解析 DDL 文本,提取所有 CREATE TABLE 定义。
Args:
sql_text: 完整的 SQL DDL 文本
target_schema: 如果指定,只保留该 schema 下的表(或无 schema 前缀的表)
Returns:
{表名(小写): TableDef} 字典
"""
# 去掉 SQL 注释后逐条解析 CREATE TABLEDROP TABLE 语句在此忽略,以 CREATE 定义为准)
cleaned = _strip_sql_comments(sql_text)
tables: dict[str, TableDef] = {}
# 逐个匹配 CREATE TABLE
for m in _CREATE_TABLE_RE.finditer(cleaned):
schema_part = m.group(1)
table_name = m.group(2).lower()
# schema 过滤
if target_schema:
ts = target_schema.lower()
if schema_part and schema_part.lower() != ts:
continue
# 无 schema 前缀的表也接受DWD DDL 中 SET search_path 后不带前缀)
# 找到 CREATE TABLE ... ( 的左括号位置
paren_start = m.end() - 1 # m.end() 指向 '(' 后一位
paren_end = _find_matching_paren(cleaned, paren_start)
if paren_end < 0:
continue
body = cleaned[paren_start + 1: paren_end]
# 按行解析字段
table_def = TableDef(name=table_name)
# 提取表级 PRIMARY KEY
pk_cols = _extract_pk_from_body(body)
# 逐行解析
for raw_line in body.split("\n"):
col = _parse_column_line(raw_line)
if col:
table_def.columns[col.name] = col
# 合并表级 PK 信息
if pk_cols:
table_def.pk_columns = pk_cols
for pk_col in pk_cols:
if pk_col in table_def.columns:
table_def.columns[pk_col].is_pk = True
# PK 隐含 NOT NULL
table_def.columns[pk_col].nullable = False
# 合并内联 PK
inline_pk = [c.name for c in table_def.columns.values() if c.is_pk]
if inline_pk and not table_def.pk_columns:
table_def.pk_columns = inline_pk
for pk_col in inline_pk:
table_def.columns[pk_col].nullable = False
tables[table_name] = table_def
# 解析 CREATE VIEW仅标记视图存在列信息由数据库侧提供
for m in _CREATE_VIEW_RE.finditer(cleaned):
schema_part = m.group(1)
view_name = m.group(2).lower()
if target_schema:
ts = target_schema.lower()
if schema_part and schema_part.lower() != ts:
continue
if view_name not in tables:
# 视图仅标记存在,不解析列(列由底层表决定)
tables[view_name] = TableDef(name=view_name)
# 标记为视图,跳过列级对比
tables[view_name].is_view = True
return tables
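# 解析示例(仅作说明,表名/列名为假设样例;注意列定义需逐行出现才能被逐行解析):
#   ddl = "CREATE TABLE billiards_ods.demo (\n  id BIGINT PRIMARY KEY,\n  name TEXT NOT NULL\n);"
#   parse_ddl(ddl, target_schema="billiards_ods")
#   -> {"demo": TableDef(columns={"id": bigint/主键/NOT NULL, "name": text/NOT NULL}, pk_columns=["id"])}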
# ---------------------------------------------------------------------------
# 数据库 schema 读取
# ---------------------------------------------------------------------------
@dataclass
class DbColumnInfo:
"""从 information_schema 查询到的字段信息。"""
name: str
data_type: str # 标准化后
nullable: bool
is_pk: bool = False
def fetch_db_schema(pg_dsn: str, schema_name: str) -> dict[str, TableDef]:
"""从数据库 information_schema 查询指定 schema 的所有表和字段。
Returns:
{表名(小写): TableDef} 字典
"""
import psycopg2
conn = psycopg2.connect(pg_dsn)
try:
with conn.cursor() as cur:
# 检查 schema 是否存在
cur.execute(
"SELECT 1 FROM information_schema.schemata WHERE schema_name = %s",
(schema_name,),
)
if not cur.fetchone():
print(f"⚠ schema '{schema_name}' 在数据库中不存在,跳过", file=sys.stderr)
return {}
# 查询所有列信息
cur.execute("""
SELECT
c.table_name,
c.column_name,
c.data_type,
c.is_nullable,
c.character_maximum_length,
c.numeric_precision,
c.numeric_scale,
c.udt_name
FROM information_schema.columns c
WHERE c.table_schema = %s
ORDER BY c.table_name, c.ordinal_position
""", (schema_name,))
rows = cur.fetchall()
# 查询主键信息
cur.execute("""
SELECT
tc.table_name,
kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.table_schema = %s
AND tc.constraint_type = 'PRIMARY KEY'
ORDER BY tc.table_name, kcu.ordinal_position
""", (schema_name,))
pk_rows = cur.fetchall()
finally:
conn.close()
# 构建 PK 映射: {table_name: [col1, col2, ...]}
pk_map: dict[str, list[str]] = {}
for tbl, col in pk_rows:
pk_map.setdefault(tbl.lower(), []).append(col.lower())
# 构建 TableDef
tables: dict[str, TableDef] = {}
for tbl, col_name, data_type, is_nullable, char_max_len, num_prec, num_scale, udt_name in rows:
tbl_lower = tbl.lower()
col_lower = col_name.lower()
if tbl_lower not in tables:
tables[tbl_lower] = TableDef(
name=tbl_lower,
pk_columns=pk_map.get(tbl_lower, []),
)
# 构建精确类型字符串
type_str = _build_db_type_string(data_type, char_max_len, num_prec, num_scale, udt_name)
is_pk = col_lower in pk_map.get(tbl_lower, [])
nullable = is_nullable == "YES"
tables[tbl_lower].columns[col_lower] = ColumnDef(
name=col_lower,
data_type=normalize_type(type_str),
nullable=nullable,
is_pk=is_pk,
)
return tables
def _build_db_type_string(
data_type: str,
char_max_len: Optional[int],
num_prec: Optional[int],
num_scale: Optional[int],
udt_name: str,
) -> str:
"""根据 information_schema 字段构建可比较的类型字符串。"""
dt = data_type.lower()
# character varying → varchar(n)
if dt == "character varying":
if char_max_len:
return f"varchar({char_max_len})"
return "varchar"
# character → char(n)
if dt == "character":
if char_max_len:
return f"char({char_max_len})"
return "char(1)"
# numeric → numeric(p,s)
if dt == "numeric":
if num_prec is not None and num_scale is not None:
return f"numeric({num_prec},{num_scale})"
if num_prec is not None:
return f"numeric({num_prec})"
return "numeric"
# USER-DEFINED → 使用 udt_name如 jsonb, geometry 等)
if dt == "user-defined":
return udt_name.lower()
# ARRAY → 使用 udt_name 去掉前缀 _
if dt == "array":
base = udt_name.lstrip("_").lower()
return f"{base}[]"
return dt
# ---------------------------------------------------------------------------
# 对比逻辑
# ---------------------------------------------------------------------------
def compare_tables(
ddl_tables: dict[str, TableDef],
db_tables: dict[str, TableDef],
) -> list[SchemaDiff]:
"""对比 DDL 定义与数据库实际结构,返回差异列表。
差异分类:
- MISSING_TABLE: 数据库有但 DDL 没有
- EXTRA_TABLE: DDL 有但数据库没有
- MISSING_COLUMN: 数据库有但 DDL 没有的字段
- EXTRA_COLUMN: DDL 有但数据库没有的字段
- TYPE_MISMATCH: 字段类型不一致
- NULLABLE_MISMATCH: 可空约束不一致
"""
diffs: list[SchemaDiff] = []
all_tables = sorted(set(ddl_tables.keys()) | set(db_tables.keys()))
for tbl in all_tables:
in_ddl = tbl in ddl_tables
in_db = tbl in db_tables
if in_db and not in_ddl:
diffs.append(SchemaDiff(kind=DiffKind.MISSING_TABLE, table=tbl))
continue
if in_ddl and not in_db:
diffs.append(SchemaDiff(kind=DiffKind.EXTRA_TABLE, table=tbl))
continue
# 两边都有,逐字段对比
# 视图仅检查存在性,跳过列级对比
ddl_def = ddl_tables[tbl]
if getattr(ddl_def, 'is_view', False):
continue
ddl_cols = ddl_def.columns
db_cols = db_tables[tbl].columns
all_cols = sorted(set(ddl_cols.keys()) | set(db_cols.keys()))
for col in all_cols:
col_in_ddl = col in ddl_cols
col_in_db = col in db_cols
if col_in_db and not col_in_ddl:
diffs.append(SchemaDiff(
kind=DiffKind.MISSING_COLUMN,
table=tbl,
column=col,
db_value=db_cols[col].data_type,
))
continue
if col_in_ddl and not col_in_db:
diffs.append(SchemaDiff(
kind=DiffKind.EXTRA_COLUMN,
table=tbl,
column=col,
ddl_value=ddl_cols[col].data_type,
))
continue
# 两边都有,比较类型
ddl_type = ddl_cols[col].data_type
db_type = db_cols[col].data_type
# 视图列从 DDL 解析时类型为 unknown跳过类型比较
if ddl_type != db_type and ddl_type != "unknown":
diffs.append(SchemaDiff(
kind=DiffKind.TYPE_MISMATCH,
table=tbl,
column=col,
ddl_value=ddl_type,
db_value=db_type,
))
# 比较可空性(视图列跳过)
ddl_nullable = ddl_cols[col].nullable
db_nullable = db_cols[col].nullable
if ddl_nullable != db_nullable and ddl_type != "unknown":
diffs.append(SchemaDiff(
kind=DiffKind.NULLABLE_MISMATCH,
table=tbl,
column=col,
ddl_value="NULL" if ddl_nullable else "NOT NULL",
db_value="NULL" if db_nullable else "NOT NULL",
))
return diffs
def compare_schema(ddl_path: str, schema_name: str, pg_dsn: str) -> list[SchemaDiff]:
"""对比 DDL 文件与数据库 schema 的完整流程。
Args:
ddl_path: DDL 文件路径
schema_name: 数据库 schema 名称
pg_dsn: PostgreSQL 连接字符串
Returns:
差异列表
"""
path = Path(ddl_path)
if not path.exists():
print(f"✗ DDL 文件不存在: {ddl_path}", file=sys.stderr)
return []
sql_text = path.read_text(encoding="utf-8")
ddl_tables = parse_ddl(sql_text, target_schema=schema_name)
if not ddl_tables:
print(f"⚠ DDL 文件中未解析到任何表: {ddl_path}", file=sys.stderr)
db_tables = fetch_db_schema(pg_dsn, schema_name)
return compare_tables(ddl_tables, db_tables)
# ---------------------------------------------------------------------------
# 报告输出
# ---------------------------------------------------------------------------
def print_report(diffs: list[SchemaDiff], schema_name: str, ddl_path: str) -> None:
"""按表分组输出差异报告到控制台。"""
if not diffs:
print(f"\n{schema_name} ({ddl_path}): 无差异")
return
print(f"\n{'='*60}")
print(f" 差异报告: {schema_name}{ddl_path}")
print(f"{len(diffs)} 项差异")
print(f"{'='*60}")
# 按表分组
by_table: dict[str, list[SchemaDiff]] = {}
for d in diffs:
by_table.setdefault(d.table, []).append(d)
for tbl in sorted(by_table.keys()):
items = by_table[tbl]
print(f"\n{tbl}")
for d in items:
icon = {
DiffKind.MISSING_TABLE: "🔴 DDL 缺表",
DiffKind.EXTRA_TABLE: "🟡 DDL 多表",
DiffKind.MISSING_COLUMN: "🔴 DDL 缺字段",
DiffKind.EXTRA_COLUMN: "🟡 DDL 多字段",
DiffKind.TYPE_MISMATCH: "🟠 类型不一致",
DiffKind.NULLABLE_MISMATCH: "🔵 可空不一致",
}.get(d.kind, d.kind.value)
if d.column:
detail = f" {icon}: {d.column}"
else:
detail = f" {icon}"
if d.ddl_value is not None or d.db_value is not None:
detail += f" (DDL={d.ddl_value}, DB={d.db_value})"
print(detail)
print()
# ---------------------------------------------------------------------------
# CLI 入口
# ---------------------------------------------------------------------------
# 预定义的 schema → DDL 文件映射
DEFAULT_SCHEMA_MAP: dict[str, str] = {
"billiards_ods": "database/schema_ODS_doc.sql",
"billiards_dwd": "database/schema_dwd_doc.sql",
"billiards_dws": "database/schema_dws.sql",
"etl_admin": "database/schema_etl_admin.sql",
}
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description="对比 DDL 文件与数据库实际表结构",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 对比单个 schema
python scripts/compare_ddl_db.py --schema billiards_ods --ddl-path database/schema_ODS_doc.sql
# 对比所有预定义 schema从 .env 读取 PG_DSN
python scripts/compare_ddl_db.py --all
# 指定连接字符串
python scripts/compare_ddl_db.py --all --pg-dsn "postgresql://user:pass@host/db"
""",
)
parser.add_argument("--pg-dsn", help="PostgreSQL 连接字符串(默认从 PG_DSN 环境变量或 .env 读取)")
parser.add_argument("--schema", help="要对比的 schema 名称")
parser.add_argument("--ddl-path", help="DDL 文件路径")
parser.add_argument("--all", action="store_true", help="对比所有预定义 schema")
args = parser.parse_args(argv)
# 加载 .env
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
pg_dsn = args.pg_dsn or os.environ.get("PG_DSN")
if not pg_dsn:
print("✗ 未提供 PG_DSN请通过 --pg-dsn 参数或 PG_DSN 环境变量指定", file=sys.stderr)
return 1
# 确定要对比的 schema 列表
pairs: list[tuple[str, str]] = []
if args.all:
for schema, ddl in DEFAULT_SCHEMA_MAP.items():
pairs.append((schema, ddl))
elif args.schema and args.ddl_path:
pairs.append((args.schema, args.ddl_path))
elif args.schema:
# 尝试从预定义映射中查找
ddl = DEFAULT_SCHEMA_MAP.get(args.schema)
if ddl:
pairs.append((args.schema, ddl))
else:
print(f"✗ 未知 schema '{args.schema}',请通过 --ddl-path 指定 DDL 文件", file=sys.stderr)
return 1
else:
parser.print_help()
return 1
total_diffs = 0
for schema_name, ddl_path in pairs:
if not Path(ddl_path).exists():
print(f"⚠ DDL 文件不存在,跳过: {ddl_path}", file=sys.stderr)
continue
try:
diffs = compare_schema(ddl_path, schema_name, pg_dsn)
except Exception as e:
print(f"✗ 对比 {schema_name} 时出错: {e}", file=sys.stderr)
continue
print_report(diffs, schema_name, ddl_path)
total_diffs += len(diffs)
if total_diffs > 0:
print(f"共发现 {total_diffs} 项差异")
return 1
print("所有 schema 对比通过,无差异 ✓")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
"""
比对 ODS 数据库实际列 vs docs/api-reference/summary/*.md 文档中的响应字段。
改进版:
1. 只提取"响应字段详解"章节的字段(排除请求参数)
2. 同时用 camelCase 原名和 snake_case 转换名做双向匹配
3. 对 ODS 连写小写列名(如 siteid也尝试匹配 camelCase如 siteId
用法: python scripts/compare_ods_vs_summary_v2.py
"""
import os, re, sys, json
from pathlib import Path
from dotenv import load_dotenv
import psycopg2
load_dotenv()
SUMMARY_DIR = Path("docs/api-reference/summary")
ODS_SCHEMA = "billiards_ods"
META_COLS = {"source_file", "source_endpoint", "fetched_at", "payload", "content_hash"}
# CHANGE P20260214-170000: 从全局黑名单移除 start_time/end_time/starttime/endtime
# intent: 这些字段在部分 API 中是请求参数,但在 assistant_accounts_master、
# group_buy_packages、member_stored_value_cards 中是真正的响应业务字段。
# 全局过滤会导致误报"ODS有/MD无"。
# assumptions: 请求参数的 startTime/endTime 不会出现在"响应字段详解"章节中
# extract_response_fields 已限定只提取该章节),因此无需在此处过滤。
# 请求参数(不应出现在 ODS 列比对中)
# 注意start_time/end_time 不在此列表中——它们在多张表中是响应业务字段,
# 而作为请求参数时已被 extract_response_fields 的章节限定逻辑排除。
REQUEST_PARAMS = {
"page", "limit",
"rangestarttime", "rangeendtime", "range_start_time", "range_end_time",
"startpaytime", "endpaytime", "start_pay_time", "end_pay_time",
"siteid_param", "settletype_param", "paymentmethod_param",
"isfirst_param", "goodssalestype", "goods_sales_type",
"issalesbind", "is_sales_bind", "existsgoodsstock", "exists_goods_stock",
"goodssecondcategoryid_param", "goodsstate_param",
"querytype", "query_type", "issalemanuser", "is_sale_man_user",
"couponusestatus", "coupon_use_status",
"total", # 分页 total 不是业务字段
}
# CHANGE P20260214-210000: 添加包装器/容器字段忽略列表
# intent: 某些 API 响应中的顶层字段是数组/对象容器(如 goodsCategoryList),
# ODS 穿透存储其子元素而非容器本身;MD 文档中记录了容器字段,但 ODS 无对应列
# assumptions: 这些字段在 ODS 中不建列,其子元素已被展开存储
WRAPPER_FIELDS = {
"goodscategorylist", # stock_goods_category_tree: 分类树的上级数组节点
}
DSN = os.getenv("PG_DSN") or os.getenv("DATABASE_URL")
if not DSN:
print("ERROR: 需要设置 PG_DSN 或 DATABASE_URL 环境变量", file=sys.stderr)
sys.exit(1)
def get_ods_columns(conn):
cur = conn.cursor()
cur.execute("""
SELECT table_name, column_name
FROM information_schema.columns
WHERE table_schema = %s
ORDER BY table_name, ordinal_position
""", (ODS_SCHEMA,))
result = {}
for table_name, col_name in cur.fetchall():
result.setdefault(table_name, set()).add(col_name)
cur.close()
return result
def camel_to_snake(name):
"""camelCase / PascalCase → snake_case"""
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def extract_response_fields(md_path: Path) -> set:
"""
只提取"四、响应字段详解"章节中的字段名。
排除请求参数和 siteProfile 子字段。
"""
text = md_path.read_text(encoding="utf-8")
fields = set()
# 找到"响应字段详解"章节的起始位置
response_start = None
for pattern in [
r'##\s*四、响应字段详解',
r'##\s*四、.*响应字段',
r'##\s*响应字段详解',
r'###\s*4\.',
]:
m = re.search(pattern, text)
if m:
response_start = m.start()
break
if response_start is None:
# 回退:提取所有表格字段
response_text = text
else:
# 找到下一个同级章节(## 五、或 ## 五 或文件结尾)
next_section = re.search(r'\n##\s*(五|六|七|八|九|十|5|6|7|8|9)', text[response_start + 10:])
if next_section:
response_text = text[response_start:response_start + 10 + next_section.start()]
else:
response_text = text[response_start:]
# 从响应字段章节提取表格中的字段名
# 匹配 | `fieldName` | 或 | fieldName | 格式
table_pattern = re.compile(
r'^\|\s*`?([a-zA-Z_][a-zA-Z0-9_]*)`?\s*\|',
re.MULTILINE
)
# CHANGE P20260214-200000: 用分隔行检测替代 skip_words 硬编码
# intent: skip_words 方式会误杀与表头词同名的业务字段(如 remark、type、note),
# 改为利用 Markdown 表格固定结构(表头行 → 分隔行 → 数据行)来跳过表头
# assumptions: 所有 summary MD 文档的表格均遵循标准 Markdown 格式,
# 分隔行匹配 |---...| 模式,分隔行的前一行即为表头行
separator_pattern = re.compile(r'^\|[\s\-:|]+\|', re.MULTILINE)
lines = response_text.split('\n')
# 标记哪些行是表头行(分隔行的前一行)
header_lines = set()
for i, line in enumerate(lines):
if separator_pattern.match(line) and i > 0:
header_lines.add(i - 1)
# 跟踪是否在 siteProfile/tableProfile 子字段展开区域中
# CHANGE P20260214-210000: 修复 siteProfile 子节跳过逻辑
# intent: 之前的逻辑会跳过整个 siteProfile 子节(包括 siteProfile 字段本身),
# 但 siteProfile 作为 object/jsonb 字段应该被提取,只需跳过其展开的子字段
# assumptions: siteProfile/tableProfile 子节标题后紧跟的表格中,第一行是 siteProfile 字段本身
# (应保留),后续行是展开的子字段(应跳过)。
# 如果子节只有一行siteProfile 本身),则不跳过任何内容。
in_site_profile = False
site_profile_field_seen = False
for i, line in enumerate(lines):
# 检测 siteProfile/tableProfile 子节标题
if re.search(r'siteProfile|门店信息快照|tableProfile|台桌信息快照', line, re.IGNORECASE):
if '###' in line or '####' in line:
in_site_profile = True
site_profile_field_seen = False
continue
# 检测离开 siteProfile 子节(遇到下一个同级或更高级标题)
if in_site_profile and re.match(r'\s*#{2,4}\s+', line):
if not re.search(r'siteProfile|tableProfile|门店信息快照|台桌信息快照', line, re.IGNORECASE):
in_site_profile = False
site_profile_field_seen = False
# 在 siteProfile 子节中:保留 siteProfile/tableProfile 字段本身,跳过展开的子字段
if in_site_profile:
m_check = table_pattern.match(line)
if m_check:
field_name = m_check.group(1).strip().lower()
if field_name in ('siteprofile', 'tableprofile') and not site_profile_field_seen:
# 这是 siteProfile/tableProfile 字段本身,保留(不跳过)
site_profile_field_seen = True
# 不 continue让下面的提取逻辑处理
else:
# 这是展开的子字段,跳过
continue
elif i not in header_lines and not separator_pattern.match(line):
# 非表格行(空行、标题等),不跳过
pass
# 跳过表头行(分隔行的前一行)和分隔行本身
if i in header_lines or separator_pattern.match(line):
continue
m = table_pattern.match(line)
if m:
field = m.group(1).strip()
if not field.startswith('---'):
fields.add(field)
return fields
def match_fields(md_fields: set, ods_cols: set):
"""
智能匹配 MD 字段和 ODS 列。
返回 (matched, md_only, ods_only)
"""
matched = set()
md_remaining = set()
ods_remaining = set(ods_cols)
# 构建 ODS 列的查找索引
ods_lower = {c.lower(): c for c in ods_cols}
# 也构建去下划线版本 → 原名映射
ods_no_underscore = {}
for c in ods_cols:
key = c.lower().replace("_", "")
ods_no_underscore.setdefault(key, c)
for field in md_fields:
field_lower = field.lower()
field_snake = camel_to_snake(field).lower()
field_no_sep = field_lower.replace("_", "")
found = False
# 1. 精确匹配(小写)
if field_lower in ods_lower:
matched.add((field, ods_lower[field_lower]))
ods_remaining.discard(ods_lower[field_lower])
found = True
# 2. snake_case 匹配
elif field_snake in ods_lower:
matched.add((field, ods_lower[field_snake]))
ods_remaining.discard(ods_lower[field_snake])
found = True
# 3. 去下划线匹配(处理 camelCase vs 连写小写)
elif field_no_sep in ods_no_underscore:
matched.add((field, ods_no_underscore[field_no_sep]))
ods_remaining.discard(ods_no_underscore[field_no_sep])
found = True
if not found:
md_remaining.add(field)
return matched, md_remaining, ods_remaining
def is_request_param(field: str) -> bool:
"""判断字段是否为请求参数"""
f = field.lower().replace("_", "")
return f in {p.replace("_", "") for p in REQUEST_PARAMS}
def main():
conn = psycopg2.connect(DSN)
ods_tables = get_ods_columns(conn)
conn.close()
md_files = sorted(SUMMARY_DIR.glob("*.md"))
report = []
for md_path in md_files:
table_name = md_path.stem
md_fields_raw = extract_response_fields(md_path)
# 过滤请求参数和包装器字段
md_fields = {f for f in md_fields_raw
if not is_request_param(f)
and f.lower() not in WRAPPER_FIELDS}
if table_name not in ods_tables:
report.append({
"table": table_name,
"status": "NO_ODS_TABLE",
"md_fields_count": len(md_fields),
"note": "summary 文档存在但 ODS 中无对应表"
})
continue
ods_cols = ods_tables[table_name] - META_COLS
matched, md_only, ods_only = match_fields(md_fields, ods_cols)
if md_only or ods_only:
report.append({
"table": table_name,
"status": "DIFF",
"ods_count": len(ods_cols),
"md_count": len(md_fields),
"matched": len(matched),
"md_only": sorted(md_only),
"ods_only": sorted(ods_only),
})
else:
report.append({
"table": table_name,
"status": "MATCH",
"ods_count": len(ods_cols),
"md_count": len(md_fields),
"matched": len(matched),
})
# 检查 ODS 中有但 summary 中没有的表
md_table_names = {p.stem for p in md_files}
for t in sorted(ods_tables.keys()):
if t not in md_table_names:
report.append({
"table": t,
"status": "NO_MD_FILE",
"ods_count": len(ods_tables[t] - META_COLS),
"note": "ODS 表存在但无对应 summary 文档"
})
# 输出
print(f"\n{'='*70}")
print(f"ODS vs Summary 字段比对报告 (v2 — 仅响应字段,智能匹配)")
print(f"ODS 表数: {len(ods_tables)} | Summary 文档数: {len(md_files)}")
print(f"{'='*70}\n")
match_count = sum(1 for r in report if r["status"] == "MATCH")
diff_count = sum(1 for r in report if r["status"] == "DIFF")
no_ods = sum(1 for r in report if r["status"] == "NO_ODS_TABLE")
print(f"完全匹配: {match_count} | 有差异: {diff_count} | 无ODS表: {no_ods}\n")
for entry in report:
if entry["status"] == "MATCH":
print(f"{entry['table']} — 完全匹配 (匹配:{entry['matched']} ODS:{entry['ods_count']} MD:{entry['md_count']})")
elif entry["status"] == "DIFF":
print(f"\n{entry['table']} — 有差异 (匹配:{entry['matched']} ODS:{entry['ods_count']} MD:{entry['md_count']})")
if entry["md_only"]:
print(f" 📄 MD有/ODS无 ({len(entry['md_only'])}): {', '.join(entry['md_only'])}")
if entry["ods_only"]:
print(f" 🗄️ ODS有/MD无 ({len(entry['ods_only'])}): {', '.join(entry['ods_only'])}")
elif entry["status"] == "NO_ODS_TABLE":
print(f"\n ⚠️ {entry['table']}{entry['note']} (MD字段数: {entry['md_fields_count']})")
elif entry["status"] == "NO_MD_FILE":
print(f"\n ⚠️ {entry['table']}{entry['note']} (ODS字段数: {entry['ods_count']})")
# JSON 输出
json_path = Path("docs/reports/ods_vs_summary_comparison_v2.json")
json_path.parent.mkdir(parents=True, exist_ok=True)
with open(json_path, "w", encoding="utf-8") as f:
json.dump(report, f, ensure_ascii=False, indent=2)
print(f"\n📁 JSON 报告: {json_path}")
if __name__ == "__main__":
main()
# AI_CHANGELOG:
# - 日期: 2026-02-14
# Prompt: P20260214-150000 — ODS 数据库结构 vs summary MD 文档字段比对
# 直接原因: 用户要求通过查询 billiards_ods schema 与 25 个 summary MD 文档进行字段比对
# 变更摘要: 新建 v2 比对脚本,改进点:(1) 仅提取"响应字段详解"章节,排除请求参数;
# (2) 三重匹配(精确/camelCase→snake_case/去下划线);(3) 跳过 siteProfile 子字段
# 风险与验证: 纯分析脚本,无运行时影响;验证:python scripts/compare_ods_vs_summary_v2.py
#
# - 日期: 2026-02-14
# Prompt: P20260214-170000 — assistant_accounts_master 的 start_time/end_time 误报修复
# 直接原因: REQUEST_PARAMS 全局黑名单包含 start_time/end_time但这些字段在 3 张表中是响应业务字段,
# 且仅对 MD 侧过滤,未对 ODS 侧过滤,导致假差异
# 变更摘要: 从 REQUEST_PARAMS 移除 start_time/end_time/starttime/endtime 4 个值,
# 添加 CHANGE 标记注释说明原因
# 风险与验证: 验证:python scripts/compare_ods_vs_summary_v2.py确认 assistant_accounts_master、
# member_stored_value_cards 变为完全匹配;group_buy_packages 不再误报 start_time/end_time
#
# - 日期: 2026-02-14
# Prompt: P20260214-190000 — goods_stock_movements 的 remark 字段误报修复
# 直接原因: skip_words 集合包含 'remark'(本意过滤表头词),但 remark 在 goods_stock_movements、
# member_balance_changes、store_goods_master 中是真实业务字段名,导致被误过滤为表头词
# 变更摘要: 从 skip_words 移除 'remark' 和 'note',添加 CHANGE 标记注释
# 风险与验证: 验证:python scripts/compare_ods_vs_summary_v2.py完全匹配从 12→14
# goods_stock_movements(19/19)、member_balance_changes(28/28) 变为完全匹配
#
# - 日期: 2026-02-14
# Prompt: P20260214-200000 — group_buy_packages 的 type 字段误报修复
# 直接原因: skip_words 硬编码方式无法区分表头词和同名业务字段type/remark/note 等),
# 根本原因是过滤策略错误——应该用 Markdown 表格结构(分隔行检测)来跳过表头行
# 变更摘要: 用分隔行检测separator_pattern + header_lines替代 skip_words 硬编码,
# 彻底消除"表头词 vs 业务字段同名"的误过滤问题
# 风险与验证: 验证:python scripts/compare_ods_vs_summary_v2.py
# group_buy_packages 的 type 正确匹配(匹配 39ODS有/MD无 不再包含 type)
#
# - 日期: 2026-02-14
# Prompt: P20260214-210000 — siteProfile 误跳过 + goodsCategoryList 包装器字段忽略
# 直接原因: (1) siteProfile 子节跳过逻辑会跳过 siteProfile 字段本身,但它在 table_fee_transactions、
# platform_coupon_redemption_records 等表中是 object/jsonb 字段,应被提取;
# (2) goodsCategoryList 是 stock_goods_category_tree 的上级数组容器节点ODS 穿透存储子元素
# 变更摘要: (1) 重写 siteProfile 子节跳过逻辑,保留 siteProfile/tableProfile 字段本身,只跳过展开的子字段
# (2) 新增 WRAPPER_FIELDS 忽略列表,过滤 goodsCategoryList
# 风险与验证: 验证:python scripts/compare_ods_vs_summary_v2.py完全匹配从 14→17

View File

@@ -0,0 +1,605 @@
# -*- coding: utf-8 -*-
"""
DWS Excel导入脚本
功能说明:
支持三类Excel数据的导入
1. 支出结构dws_finance_expense_summary
2. 平台结算dws_platform_settlement
3. 充值提成dws_assistant_recharge_commission
导入规范:
- 字段定义:按照目标表字段要求
- 时间粒度:支出按月,平台结算按日,充值提成按月
- 门店维度使用配置的site_id
- 去重规则按import_batch_no去重
- 校验规则:金额字段非负,日期格式校验
使用方式:
python import_dws_excel.py --type expense --file expenses.xlsx
python import_dws_excel.py --type platform --file platform_settlement.xlsx
python import_dws_excel.py --type commission --file recharge_commission.xlsx
作者ETL团队
创建日期2026-02-01
"""
import argparse
import os
import sys
import uuid
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
try:
import pandas as pd
except ImportError:
print("请安装 pandas: pip install pandas openpyxl")
sys.exit(1)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
# =============================================================================
# 常量定义
# =============================================================================
# 支出类型枚举
EXPENSE_TYPES = {
'房租': 'RENT',
'水电费': 'UTILITY',
'物业费': 'PROPERTY',
'工资': 'SALARY',
'报销': 'REIMBURSE',
'平台服务费': 'PLATFORM_FEE',
'其他': 'OTHER',
}
# 支出大类映射
EXPENSE_CATEGORIES = {
'RENT': 'FIXED_COST',
'UTILITY': 'VARIABLE_COST',
'PROPERTY': 'FIXED_COST',
'SALARY': 'FIXED_COST',
'REIMBURSE': 'VARIABLE_COST',
'PLATFORM_FEE': 'VARIABLE_COST',
'OTHER': 'OTHER',
}
# 平台类型枚举
PLATFORM_TYPES = {
'美团': 'MEITUAN',
'抖音': 'DOUYIN',
'大众点评': 'DIANPING',
'其他': 'OTHER',
}
# =============================================================================
# 导入基类
# =============================================================================
class BaseImporter:
"""导入基类"""
    def __init__(self, config: AppConfig, db: DatabaseConnection):
self.config = config
self.db = db
self.site_id = config.get("app.store_id")
self.tenant_id = config.get("app.tenant_id", self.site_id)
self.batch_no = self._generate_batch_no()
def _generate_batch_no(self) -> str:
"""生成导入批次号"""
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
unique_id = str(uuid.uuid4())[:8]
return f"{timestamp}_{unique_id}"
def _safe_decimal(self, value: Any, default: Decimal = Decimal('0')) -> Decimal:
"""安全转换为Decimal"""
if value is None or pd.isna(value):
return default
try:
return Decimal(str(value))
except (ValueError, InvalidOperation):
return default
def _safe_date(self, value: Any) -> Optional[date]:
"""安全转换为日期"""
if value is None or pd.isna(value):
return None
if isinstance(value, datetime):
return value.date()
if isinstance(value, date):
return value
try:
return pd.to_datetime(value).date()
        except Exception:  # 日期格式无法解析时返回 None
return None
def _safe_month(self, value: Any) -> Optional[date]:
"""安全转换为月份(月第一天)"""
dt = self._safe_date(value)
if dt:
return dt.replace(day=1)
return None
def import_file(self, file_path: str) -> Dict[str, Any]:
"""导入文件"""
raise NotImplementedError
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
"""校验行数据,返回错误列表"""
return []
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""转换行数据"""
raise NotImplementedError
def insert_records(self, records: List[Dict[str, Any]]) -> int:
"""插入记录"""
raise NotImplementedError
# =============================================================================
# 支出导入
# =============================================================================
class ExpenseImporter(BaseImporter):
"""
支出导入
Excel格式要求
- 月份: 2026-01 或 2026/01/01 格式
- 支出类型: 房租/水电费/物业费/工资/报销/平台服务费/其他
- 金额: 数字
- 备注: 可选
"""
TARGET_TABLE = "billiards_dws.dws_finance_expense_summary"
REQUIRED_COLUMNS = ['月份', '支出类型', '金额']
OPTIONAL_COLUMNS = ['明细', '备注']
def import_file(self, file_path: str) -> Dict[str, Any]:
"""导入支出Excel"""
print(f"开始导入支出文件: {file_path}")
# 读取Excel
df = pd.read_excel(file_path)
# 校验必要列
missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
if missing_cols:
return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}
# 处理数据
records = []
errors = []
for idx, row in df.iterrows():
row_dict = row.to_dict()
row_errors = self.validate_row(row_dict, idx + 2) # Excel行号从2开始
if row_errors:
errors.extend(row_errors)
continue
record = self.transform_row(row_dict)
records.append(record)
if errors:
print(f"校验错误: {len(errors)}")
for err in errors[:10]:
print(f" - {err}")
# 插入数据
inserted = 0
if records:
inserted = self.insert_records(records)
return {
"status": "SUCCESS" if not errors else "PARTIAL",
"batch_no": self.batch_no,
"total_rows": len(df),
"inserted": inserted,
"errors": len(errors),
"error_messages": errors[:10]
}
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
errors = []
# 校验月份
month = self._safe_month(row.get('月份'))
if not month:
errors.append(f"{row_idx}: 月份格式错误")
        # 校验支出类型(空单元格在 pandas 中为 NaN需先安全转为字符串再比对
        raw_type = row.get('支出类型')
        expense_type = '' if raw_type is None or pd.isna(raw_type) else str(raw_type).strip()
        if expense_type not in EXPENSE_TYPES:
            errors.append(f"第{row_idx}行: 支出类型无效 '{expense_type}'")
# 校验金额
amount = self._safe_decimal(row.get('金额'))
if amount < 0:
errors.append(f"{row_idx}: 金额不能为负数")
return errors
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
expense_type_name = row.get('支出类型', '').strip()
expense_type_code = EXPENSE_TYPES.get(expense_type_name, 'OTHER')
expense_category = EXPENSE_CATEGORIES.get(expense_type_code, 'OTHER')
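        # 例如:'房租' → ('RENT', 'FIXED_COST'),'平台服务费' → ('PLATFORM_FEE', 'VARIABLE_COST')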
return {
'site_id': self.site_id,
'tenant_id': self.tenant_id,
'expense_month': self._safe_month(row.get('月份')),
'expense_type_code': expense_type_code,
'expense_type_name': expense_type_name,
'expense_category': expense_category,
'expense_amount': self._safe_decimal(row.get('金额')),
'expense_detail': row.get('明细'),
'import_batch_no': self.batch_no,
'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
'import_time': datetime.now(),
'import_user': os.getenv('USERNAME', 'system'),
'remark': row.get('备注'),
}
def insert_records(self, records: List[Dict[str, Any]]) -> int:
columns = [
'site_id', 'tenant_id', 'expense_month', 'expense_type_code',
'expense_type_name', 'expense_category', 'expense_amount',
'expense_detail', 'import_batch_no', 'import_file_name',
'import_time', 'import_user', 'remark'
]
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"
inserted = 0
with self.db.conn.cursor() as cur:
for record in records:
values = [record.get(col) for col in columns]
cur.execute(sql, values)
inserted += cur.rowcount
self.db.commit()
return inserted
# =============================================================================
# 平台结算导入
# =============================================================================
class PlatformSettlementImporter(BaseImporter):
"""
平台结算导入
Excel格式要求
- 回款日期: 日期格式
- 平台类型: 美团/抖音/大众点评/其他
- 平台订单号: 字符串
- 订单原始金额: 数字
- 佣金: 数字
- 服务费: 数字
- 回款金额: 数字
- 备注: 可选
"""
TARGET_TABLE = "billiards_dws.dws_platform_settlement"
REQUIRED_COLUMNS = ['回款日期', '平台类型', '回款金额']
OPTIONAL_COLUMNS = ['平台订单号', '订单原始金额', '佣金', '服务费', '关联订单ID', '备注']
def import_file(self, file_path: str) -> Dict[str, Any]:
print(f"开始导入平台结算文件: {file_path}")
df = pd.read_excel(file_path)
missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
if missing_cols:
return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}
records = []
errors = []
for idx, row in df.iterrows():
row_dict = row.to_dict()
row_errors = self.validate_row(row_dict, idx + 2)
if row_errors:
errors.extend(row_errors)
continue
record = self.transform_row(row_dict)
records.append(record)
if errors:
print(f"校验错误: {len(errors)}")
for err in errors[:10]:
print(f" - {err}")
inserted = 0
if records:
inserted = self.insert_records(records)
return {
"status": "SUCCESS" if not errors else "PARTIAL",
"batch_no": self.batch_no,
"total_rows": len(df),
"inserted": inserted,
"errors": len(errors),
}
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
errors = []
settlement_date = self._safe_date(row.get('回款日期'))
if not settlement_date:
errors.append(f"{row_idx}: 回款日期格式错误")
        raw_type = row.get('平台类型')
        platform_type = '' if raw_type is None or pd.isna(raw_type) else str(raw_type).strip()
        if platform_type not in PLATFORM_TYPES:
            errors.append(f"第{row_idx}行: 平台类型无效 '{platform_type}'")
amount = self._safe_decimal(row.get('回款金额'))
if amount < 0:
errors.append(f"{row_idx}: 回款金额不能为负数")
return errors
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
platform_name = row.get('平台类型', '').strip()
platform_type = PLATFORM_TYPES.get(platform_name, 'OTHER')
return {
'site_id': self.site_id,
'tenant_id': self.tenant_id,
'settlement_date': self._safe_date(row.get('回款日期')),
'platform_type': platform_type,
'platform_name': platform_name,
'platform_order_no': row.get('平台订单号'),
'order_settle_id': row.get('关联订单ID'),
'settlement_amount': self._safe_decimal(row.get('回款金额')),
'commission_amount': self._safe_decimal(row.get('佣金')),
'service_fee': self._safe_decimal(row.get('服务费')),
'gross_amount': self._safe_decimal(row.get('订单原始金额')),
'import_batch_no': self.batch_no,
'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
'import_time': datetime.now(),
'import_user': os.getenv('USERNAME', 'system'),
'remark': row.get('备注'),
}
def insert_records(self, records: List[Dict[str, Any]]) -> int:
columns = [
'site_id', 'tenant_id', 'settlement_date', 'platform_type',
'platform_name', 'platform_order_no', 'order_settle_id',
'settlement_amount', 'commission_amount', 'service_fee',
'gross_amount', 'import_batch_no', 'import_file_name',
'import_time', 'import_user', 'remark'
]
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"
inserted = 0
with self.db.conn.cursor() as cur:
for record in records:
values = [record.get(col) for col in columns]
cur.execute(sql, values)
inserted += cur.rowcount
self.db.commit()
return inserted
# =============================================================================
# 充值提成导入
# =============================================================================
class RechargeCommissionImporter(BaseImporter):
"""
充值提成导入
Excel格式要求
- 月份: 2026-01 格式
- 助教ID: 数字
- 助教花名: 字符串
- 充值订单金额: 数字
- 提成金额: 数字
- 充值订单号: 可选
- 备注: 可选
"""
TARGET_TABLE = "billiards_dws.dws_assistant_recharge_commission"
REQUIRED_COLUMNS = ['月份', '助教ID', '提成金额']
OPTIONAL_COLUMNS = ['助教花名', '充值订单金额', '充值订单ID', '充值订单号', '备注']
def import_file(self, file_path: str) -> Dict[str, Any]:
print(f"开始导入充值提成文件: {file_path}")
df = pd.read_excel(file_path)
missing_cols = [c for c in self.REQUIRED_COLUMNS if c not in df.columns]
if missing_cols:
return {"status": "ERROR", "message": f"缺少必要列: {missing_cols}"}
records = []
errors = []
for idx, row in df.iterrows():
row_dict = row.to_dict()
row_errors = self.validate_row(row_dict, idx + 2)
if row_errors:
errors.extend(row_errors)
continue
record = self.transform_row(row_dict)
records.append(record)
if errors:
print(f"校验错误: {len(errors)}")
for err in errors[:10]:
print(f" - {err}")
inserted = 0
if records:
inserted = self.insert_records(records)
return {
"status": "SUCCESS" if not errors else "PARTIAL",
"batch_no": self.batch_no,
"total_rows": len(df),
"inserted": inserted,
"errors": len(errors),
}
def validate_row(self, row: Dict[str, Any], row_idx: int) -> List[str]:
errors = []
month = self._safe_month(row.get('月份'))
if not month:
errors.append(f"{row_idx}: 月份格式错误")
assistant_id = row.get('助教ID')
if assistant_id is None or pd.isna(assistant_id):
errors.append(f"{row_idx}: 助教ID不能为空")
amount = self._safe_decimal(row.get('提成金额'))
if amount < 0:
errors.append(f"{row_idx}: 提成金额不能为负数")
return errors
def transform_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
recharge_amount = self._safe_decimal(row.get('充值订单金额'))
commission_amount = self._safe_decimal(row.get('提成金额'))
commission_ratio = commission_amount / recharge_amount if recharge_amount > 0 else None
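        # 例如(示例数值为虚构):充值订单金额 1000、提成金额 80 → commission_ratio = 0.08;充值金额为 0 或缺失时记为 None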
return {
'site_id': self.site_id,
'tenant_id': self.tenant_id,
'assistant_id': int(row.get('助教ID')),
'assistant_nickname': row.get('助教花名'),
'commission_month': self._safe_month(row.get('月份')),
'recharge_order_id': row.get('充值订单ID'),
'recharge_order_no': row.get('充值订单号'),
'recharge_amount': recharge_amount,
'commission_amount': commission_amount,
'commission_ratio': commission_ratio,
'import_batch_no': self.batch_no,
'import_file_name': os.path.basename(str(row.get('_file_path', ''))),
'import_time': datetime.now(),
'import_user': os.getenv('USERNAME', 'system'),
'remark': row.get('备注'),
}
def insert_records(self, records: List[Dict[str, Any]]) -> int:
columns = [
'site_id', 'tenant_id', 'assistant_id', 'assistant_nickname',
'commission_month', 'recharge_order_id', 'recharge_order_no',
'recharge_amount', 'commission_amount', 'commission_ratio',
'import_batch_no', 'import_file_name', 'import_time',
'import_user', 'remark'
]
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
sql = f"INSERT INTO {self.TARGET_TABLE} ({cols_str}) VALUES ({placeholders})"
inserted = 0
with self.db.conn.cursor() as cur:
for record in records:
values = [record.get(col) for col in columns]
cur.execute(sql, values)
inserted += cur.rowcount
self.db.commit()
return inserted
# =============================================================================
# 主函数
# =============================================================================
def main():
parser = argparse.ArgumentParser(description='DWS Excel导入工具')
parser.add_argument(
'--type', '-t',
choices=['expense', 'platform', 'commission'],
required=True,
help='导入类型: expense(支出), platform(平台结算), commission(充值提成)'
)
parser.add_argument(
'--file', '-f',
required=True,
help='Excel文件路径'
)
args = parser.parse_args()
# 检查文件
if not os.path.exists(args.file):
print(f"文件不存在: {args.file}")
sys.exit(1)
# 加载配置
config = AppConfig.load()
    dsn = config.config["db"]["dsn"]
db_conn = DatabaseConnection(dsn=dsn)
db = DatabaseOperations(db_conn)
try:
# 选择导入器
if args.type == 'expense':
importer = ExpenseImporter(config, db)
elif args.type == 'platform':
importer = PlatformSettlementImporter(config, db)
elif args.type == 'commission':
importer = RechargeCommissionImporter(config, db)
else:
print(f"未知的导入类型: {args.type}")
sys.exit(1)
# 执行导入
result = importer.import_file(args.file)
# 输出结果
print("\n" + "=" * 50)
print("导入结果:")
print(f" 状态: {result.get('status')}")
print(f" 批次号: {result.get('batch_no')}")
print(f" 总行数: {result.get('total_rows')}")
print(f" 插入行数: {result.get('inserted')}")
print(f" 错误行数: {result.get('errors')}")
if result.get('status') == 'ERROR':
print(f" 错误信息: {result.get('message')}")
sys.exit(1)
except Exception as e:
print(f"导入失败: {e}")
db_conn.rollback()
sys.exit(1)
finally:
db_conn.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""Export cfg_index_parameters table to CSV."""
from __future__ import annotations
import argparse
import csv
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
FIELDS = [
"param_id",
"index_type",
"param_name",
"param_value",
"description",
"effective_from",
"effective_to",
"created_at",
"updated_at",
]
def _fetch_rows(db: DatabaseOperations, index_type: Optional[str]) -> List[Dict[str, Any]]:
base_sql = """
SELECT
param_id,
index_type,
param_name,
param_value,
description,
effective_from,
effective_to,
created_at,
updated_at
FROM billiards_dws.cfg_index_parameters
"""
args: List[Any] = []
if index_type:
base_sql += " WHERE index_type = %s"
args.append(index_type)
base_sql += " ORDER BY index_type, param_name, effective_from, param_id"
rows = db.query(base_sql, args if args else None)
return [dict(r) for r in (rows or [])]
def _write_csv(rows: List[Dict[str, Any]], out_csv: Path) -> None:
out_csv.parent.mkdir(parents=True, exist_ok=True)
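    # utf-8-sig prepends a BOM so that Excel detects UTF-8 and renders non-ASCII headers correctly.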
with out_csv.open("w", newline="", encoding="utf-8-sig") as f:
writer = csv.DictWriter(f, fieldnames=FIELDS)
writer.writeheader()
for row in rows:
writer.writerow({k: row.get(k) for k in FIELDS})
def main() -> None:
parser = argparse.ArgumentParser(description="Export cfg_index_parameters to CSV.")
parser.add_argument(
"--index-type",
default=None,
help="Optional index type filter (e.g. RECALL, INTIMACY, NCI, WBI).",
)
parser.add_argument(
"--output-csv",
default=os.path.join(ROOT, "docs", "cfg_index_parameters.csv"),
help="Output CSV path.",
)
args = parser.parse_args()
config = AppConfig.load()
db_conn = DatabaseConnection(config.config["db"]["dsn"])
db = DatabaseOperations(db_conn)
try:
rows = _fetch_rows(db, args.index_type)
out_csv = Path(args.output_csv)
_write_csv(rows, out_csv)
print(f"rows={len(rows)}")
print(f"csv={out_csv}")
finally:
db_conn.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,423 @@
# -*- coding: utf-8 -*-
"""Export groupbuy orders that used assistant services."""
from __future__ import annotations
import argparse
import csv
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
def _as_int(v: Any) -> Optional[int]:
if v is None or str(v).strip() == "":
return None
return int(v)
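# Site id resolution below falls back in order: --site-id CLI argument,
# then app.store_id from config, then the most frequent site_id in dwd_settlement_head.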
def _resolve_site_id(config: AppConfig, db: DatabaseOperations, cli_site_id: Optional[int]) -> int:
if cli_site_id is not None:
return int(cli_site_id)
from_cfg = _as_int(config.get("app.store_id"))
if from_cfg is not None:
return from_cfg
rows = db.query(
"""
SELECT site_id
FROM billiards_dwd.dwd_settlement_head
WHERE site_id IS NOT NULL
GROUP BY site_id
ORDER BY COUNT(*) DESC
LIMIT 1
"""
)
if rows:
return int(dict(rows[0])["site_id"])
raise RuntimeError("Unable to resolve site_id; pass --site-id explicitly.")
FIELD_ORDER: List[str] = [
"site_id",
"order_settle_id",
"order_trade_no",
"pay_time",
"settle_type",
"member_id",
"member_name",
"member_phone",
"table_id",
"table_name",
"table_area_name",
"settle_consume_money",
"settle_pay_amount",
"settle_coupon_amount",
"pl_coupon_sale_amount",
"groupbuy_item_count",
"groupbuy_pay_amount",
"groupbuy_ledger_amount",
"groupbuy_coupon_money",
"coupon_codes",
"groupbuy_items",
"assistant_service_count",
"assistant_count",
"assistant_nicknames",
"assistant_skills",
"assistant_real_use_seconds",
"assistant_projected_income",
"assistant_real_service_money",
]
ZH_HEADER_MAP: Dict[str, str] = {
"site_id": "门店ID",
"order_settle_id": "结账单ID",
"order_trade_no": "订单交易号",
"pay_time": "结账时间",
"settle_type": "结账类型",
"member_id": "会员ID",
"member_name": "会员姓名",
"member_phone": "会员手机号",
"table_id": "台桌ID",
"table_name": "台桌名称",
"table_area_name": "台区名称",
"settle_consume_money": "结算消费金额",
"settle_pay_amount": "结算实付金额",
"settle_coupon_amount": "结算团购抵扣金额",
"pl_coupon_sale_amount": "平台团购实付金额",
"groupbuy_item_count": "团购核销条目数",
"groupbuy_pay_amount": "团购实付合计",
"groupbuy_ledger_amount": "团购标价合计",
"groupbuy_coupon_money": "团购券面额合计",
"coupon_codes": "团购券码列表",
"groupbuy_items": "团购项目列表",
"assistant_service_count": "助教服务条目数",
"assistant_count": "助教人数",
"assistant_nicknames": "助教昵称列表",
"assistant_skills": "助教技能列表",
"assistant_real_use_seconds": "助教实际服务秒数",
"assistant_projected_income": "助教预计收入合计",
"assistant_real_service_money": "助教实收服务费合计",
}
def _fetch_rows_current(
db: DatabaseOperations,
site_id: int,
start_date: Optional[str],
end_date: Optional[str],
) -> List[Dict[str, Any]]:
sql = """
WITH gb AS (
SELECT
site_id,
order_settle_id,
COUNT(*) AS groupbuy_item_count,
ROUND(SUM(COALESCE(ledger_unit_price, 0))::numeric, 2) AS groupbuy_pay_amount,
ROUND(SUM(COALESCE(ledger_amount, 0))::numeric, 2) AS groupbuy_ledger_amount,
ROUND(SUM(COALESCE(coupon_money, 0))::numeric, 2) AS groupbuy_coupon_money,
STRING_AGG(DISTINCT NULLIF(coupon_code, ''), '?' ORDER BY NULLIF(coupon_code, '')) AS coupon_codes,
STRING_AGG(DISTINCT NULLIF(ledger_name, ''), '?' ORDER BY NULLIF(ledger_name, '')) AS groupbuy_items
FROM billiards_dwd.dwd_groupbuy_redemption
WHERE site_id = %s
AND is_delete = 0
GROUP BY site_id, order_settle_id
),
asv AS (
SELECT
site_id,
order_settle_id,
COUNT(*) AS assistant_service_count,
COUNT(DISTINCT NULLIF(assistant_no, '')) AS assistant_count,
STRING_AGG(DISTINCT NULLIF(nickname, ''), '?' ORDER BY NULLIF(nickname, '')) AS assistant_nicknames,
STRING_AGG(DISTINCT NULLIF(skill_name, ''), '?' ORDER BY NULLIF(skill_name, '')) AS assistant_skills,
ROUND(SUM(COALESCE(real_use_seconds, 0))::numeric, 0) AS assistant_real_use_seconds,
ROUND(SUM(COALESCE(projected_income, 0))::numeric, 2) AS assistant_projected_income,
ROUND(SUM(COALESCE(real_service_money, 0))::numeric, 2) AS assistant_real_service_money
FROM billiards_dwd.dwd_assistant_service_log
WHERE site_id = %s
AND is_delete = 0
GROUP BY site_id, order_settle_id
)
SELECT
sh.site_id,
sh.order_settle_id,
sh.order_trade_no,
sh.pay_time,
sh.settle_type,
sh.member_id,
COALESCE(dm.nickname, sh.member_name) AS member_name,
COALESCE(dm.mobile, sh.member_phone) AS member_phone,
sh.table_id,
dt.table_name,
dt.site_table_area_name AS table_area_name,
ROUND(COALESCE(sh.consume_money, 0)::numeric, 2) AS settle_consume_money,
ROUND(COALESCE(sh.pay_amount, 0)::numeric, 2) AS settle_pay_amount,
ROUND(COALESCE(sh.coupon_amount, 0)::numeric, 2) AS settle_coupon_amount,
ROUND(COALESCE(sh.pl_coupon_sale_amount, 0)::numeric, 2) AS pl_coupon_sale_amount,
gb.groupbuy_item_count,
gb.groupbuy_pay_amount,
gb.groupbuy_ledger_amount,
gb.groupbuy_coupon_money,
gb.coupon_codes,
gb.groupbuy_items,
asv.assistant_service_count,
asv.assistant_count,
asv.assistant_nicknames,
asv.assistant_skills,
asv.assistant_real_use_seconds,
asv.assistant_projected_income,
asv.assistant_real_service_money
FROM gb
JOIN asv
ON asv.site_id = gb.site_id
AND asv.order_settle_id = gb.order_settle_id
LEFT JOIN billiards_dwd.dwd_settlement_head sh
ON sh.site_id = gb.site_id
AND sh.order_settle_id = gb.order_settle_id
LEFT JOIN billiards_dwd.dim_member dm
ON dm.register_site_id = sh.site_id
AND dm.member_id = sh.member_id
AND dm.scd2_is_current = 1
LEFT JOIN billiards_dwd.dim_table dt
ON dt.site_id = sh.site_id
AND dt.table_id = sh.table_id
AND dt.scd2_is_current = 1
WHERE (%s::date IS NULL OR sh.pay_time::date >= %s::date)
AND (%s::date IS NULL OR sh.pay_time::date <= %s::date)
ORDER BY sh.pay_time DESC, sh.order_settle_id DESC
"""
rows = db.query(
sql,
(
site_id,
site_id,
start_date,
start_date,
end_date,
end_date,
),
)
return [dict(r) for r in (rows or [])]
def _fetch_rows_optimized(
db: DatabaseOperations,
site_id: int,
start_date: Optional[str],
end_date: Optional[str],
) -> List[Dict[str, Any]]:
"""
Optimized export strategy:
- Deduplicate groupbuy rows by (order_settle_id, coupon_key) to handle retry noise.
- Deduplicate assistant rows by assistant_service_id.
- Keep output schema identical to current export for direct comparison.
"""
sql = """
WITH gb_raw AS (
SELECT
redemption_id,
site_id,
order_settle_id,
order_coupon_id,
coupon_code,
ledger_name,
COALESCE(ledger_unit_price, 0) AS ledger_unit_price,
COALESCE(ledger_amount, 0) AS ledger_amount,
COALESCE(coupon_money, 0) AS coupon_money,
create_time,
COALESCE(NULLIF(coupon_code, ''), CAST(order_coupon_id AS varchar), CAST(redemption_id AS varchar)) AS coupon_key,
ROW_NUMBER() OVER (
PARTITION BY site_id, order_settle_id,
COALESCE(NULLIF(coupon_code, ''), CAST(order_coupon_id AS varchar), CAST(redemption_id AS varchar))
ORDER BY create_time DESC NULLS LAST, redemption_id DESC
) AS rn
FROM billiards_dwd.dwd_groupbuy_redemption
WHERE site_id = %s
AND is_delete = 0
),
gb AS (
SELECT
site_id,
order_settle_id,
COUNT(*) AS groupbuy_item_count,
ROUND(SUM(ledger_unit_price)::numeric, 2) AS groupbuy_pay_amount,
ROUND(SUM(ledger_amount)::numeric, 2) AS groupbuy_ledger_amount,
ROUND(SUM(coupon_money)::numeric, 2) AS groupbuy_coupon_money,
STRING_AGG(DISTINCT NULLIF(coupon_code, ''), '?' ORDER BY NULLIF(coupon_code, '')) AS coupon_codes,
STRING_AGG(DISTINCT NULLIF(ledger_name, ''), '?' ORDER BY NULLIF(ledger_name, '')) AS groupbuy_items
FROM gb_raw
WHERE rn = 1
GROUP BY site_id, order_settle_id
),
asv_raw AS (
SELECT DISTINCT ON (assistant_service_id)
assistant_service_id,
site_id,
order_settle_id,
assistant_no,
nickname,
skill_name,
COALESCE(real_use_seconds, 0) AS real_use_seconds,
COALESCE(projected_income, 0) AS projected_income,
COALESCE(real_service_money, 0) AS real_service_money
FROM billiards_dwd.dwd_assistant_service_log
WHERE site_id = %s
AND is_delete = 0
ORDER BY assistant_service_id
),
asv AS (
SELECT
site_id,
order_settle_id,
COUNT(*) AS assistant_service_count,
COUNT(DISTINCT NULLIF(assistant_no, '')) AS assistant_count,
STRING_AGG(DISTINCT NULLIF(nickname, ''), '?' ORDER BY NULLIF(nickname, '')) AS assistant_nicknames,
STRING_AGG(DISTINCT NULLIF(skill_name, ''), '?' ORDER BY NULLIF(skill_name, '')) AS assistant_skills,
ROUND(SUM(real_use_seconds)::numeric, 0) AS assistant_real_use_seconds,
ROUND(SUM(projected_income)::numeric, 2) AS assistant_projected_income,
ROUND(SUM(real_service_money)::numeric, 2) AS assistant_real_service_money
FROM asv_raw
GROUP BY site_id, order_settle_id
)
SELECT
sh.site_id,
sh.order_settle_id,
sh.order_trade_no,
sh.pay_time,
sh.settle_type,
sh.member_id,
COALESCE(dm.nickname, sh.member_name) AS member_name,
COALESCE(dm.mobile, sh.member_phone) AS member_phone,
sh.table_id,
dt.table_name,
dt.site_table_area_name AS table_area_name,
ROUND(COALESCE(sh.consume_money, 0)::numeric, 2) AS settle_consume_money,
ROUND(COALESCE(sh.pay_amount, 0)::numeric, 2) AS settle_pay_amount,
ROUND(COALESCE(sh.coupon_amount, 0)::numeric, 2) AS settle_coupon_amount,
ROUND(COALESCE(sh.pl_coupon_sale_amount, 0)::numeric, 2) AS pl_coupon_sale_amount,
gb.groupbuy_item_count,
gb.groupbuy_pay_amount,
gb.groupbuy_ledger_amount,
gb.groupbuy_coupon_money,
gb.coupon_codes,
gb.groupbuy_items,
asv.assistant_service_count,
asv.assistant_count,
asv.assistant_nicknames,
asv.assistant_skills,
asv.assistant_real_use_seconds,
asv.assistant_projected_income,
asv.assistant_real_service_money
FROM gb
JOIN asv
ON asv.site_id = gb.site_id
AND asv.order_settle_id = gb.order_settle_id
LEFT JOIN billiards_dwd.dwd_settlement_head sh
ON sh.site_id = gb.site_id
AND sh.order_settle_id = gb.order_settle_id
LEFT JOIN billiards_dwd.dim_member dm
ON dm.register_site_id = sh.site_id
AND dm.member_id = sh.member_id
AND dm.scd2_is_current = 1
LEFT JOIN billiards_dwd.dim_table dt
ON dt.site_id = sh.site_id
AND dt.table_id = sh.table_id
AND dt.scd2_is_current = 1
WHERE (%s::date IS NULL OR sh.pay_time::date >= %s::date)
AND (%s::date IS NULL OR sh.pay_time::date <= %s::date)
ORDER BY sh.pay_time DESC, sh.order_settle_id DESC
"""
rows = db.query(
sql,
(
site_id,
site_id,
start_date,
start_date,
end_date,
end_date,
),
)
return [dict(r) for r in (rows or [])]
def _write_csv(
rows: List[Dict[str, Any]],
out_csv: Path,
fields: Sequence[str],
header_map: Optional[Dict[str, str]] = None,
) -> None:
out_csv.parent.mkdir(parents=True, exist_ok=True)
if header_map:
file_headers = [header_map.get(f, f) for f in fields]
else:
file_headers = list(fields)
with out_csv.open("w", newline="", encoding="utf-8-sig") as f:
writer = csv.writer(f)
writer.writerow(file_headers)
for row in rows:
writer.writerow([row.get(k) for k in fields])
def main() -> None:
parser = argparse.ArgumentParser(
description="Export groupbuy orders that used assistant services."
)
parser.add_argument("--site-id", type=int, default=None, help="Site id to export")
parser.add_argument("--start-date", default=None, help="Filter start date: YYYY-MM-DD")
parser.add_argument("--end-date", default=None, help="Filter end date: YYYY-MM-DD")
parser.add_argument(
"--scheme",
choices=["current", "optimized"],
default="current",
help="Export scheme",
)
parser.add_argument(
"--header-lang",
choices=["zh", "en"],
default="zh",
help="CSV header language",
)
parser.add_argument(
"--output-csv",
default=os.path.join(ROOT, "docs", "groupbuy_orders_with_assistant_service.csv"),
help="Output CSV path",
)
args = parser.parse_args()
config = AppConfig.load()
db_conn = DatabaseConnection(config.config["db"]["dsn"])
db = DatabaseOperations(db_conn)
try:
site_id = _resolve_site_id(config, db, args.site_id)
if args.scheme == "optimized":
rows = _fetch_rows_optimized(db, site_id, args.start_date, args.end_date)
else:
rows = _fetch_rows_current(db, site_id, args.start_date, args.end_date)
finally:
db_conn.close()
out_csv = Path(args.output_csv)
header_map = ZH_HEADER_MAP if args.header_lang == "zh" else None
_write_csv(rows, out_csv, fields=FIELD_ORDER, header_map=header_map)
print(f"site_id={site_id}")
print(f"scheme={args.scheme}")
print(f"rows={len(rows)}")
print(f"csv={out_csv}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
"""Export index tables to markdown for quick review."""
import os
import sys
from datetime import datetime
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
def _fmt(value, digits=2):
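    # e.g. _fmt(None) -> "-", _fmt(3.14159) -> "3.14", _fmt(7, 0) -> "7"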
if value is None:
return "-"
if isinstance(value, (int, float)):
return f"{value:.{digits}f}"
return str(value)
def _fetch(db: DatabaseOperations, sql: str):
return [dict(r) for r in (db.query(sql) or [])]
def build_markdown(db: DatabaseOperations) -> str:
lines = []
lines.append("# Index Tables")
lines.append("")
lines.append(f"Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
lines.append("")
# 老客挽回指数WBI
wbi_sql = """
SELECT
COALESCE(m.nickname, CONCAT('member_', r.member_id)) AS member_name,
r.display_score,
r.raw_score,
r.t_v,
r.visits_14d,
r.sv_balance
FROM billiards_dws.dws_member_winback_index r
LEFT JOIN billiards_dwd.dim_member m
ON r.member_id = m.member_id AND m.scd2_is_current = 1
ORDER BY r.display_score DESC NULLS LAST
"""
wbi_rows = _fetch(db, wbi_sql)
lines.append("## 1) WBI")
lines.append("")
lines.append("| member_name | wbi | raw_score | t_v | visits_14d | sv_balance |")
lines.append("|---|---:|---:|---:|---:|---:|")
for r in wbi_rows:
lines.append(
f"| {r.get('member_name') or '-'} | {_fmt(r.get('display_score'))} | {_fmt(r.get('raw_score'), 4)} | "
f"{_fmt(r.get('t_v'))} | {_fmt(r.get('visits_14d'), 0)} | {_fmt(r.get('sv_balance'))} |"
)
lines.append("")
lines.append(f"Total rows: {len(wbi_rows)}")
lines.append("")
# 新客转化指数NCI
nci_sql = """
SELECT
COALESCE(m.nickname, CONCAT('member_', r.member_id)) AS member_name,
r.display_score,
r.display_score_welcome,
r.display_score_convert,
r.raw_score,
r.raw_score_welcome,
r.raw_score_convert,
r.t_v,
r.visits_14d
FROM billiards_dws.dws_member_newconv_index r
LEFT JOIN billiards_dwd.dim_member m
ON r.member_id = m.member_id AND m.scd2_is_current = 1
ORDER BY r.display_score DESC NULLS LAST
"""
nci_rows = _fetch(db, nci_sql)
lines.append("## 2) NCI")
lines.append("")
lines.append("| member_name | nci | welcome | convert | raw_total | raw_welcome | raw_convert | t_v | visits_14d |")
lines.append("|---|---:|---:|---:|---:|---:|---:|---:|---:|")
for r in nci_rows:
lines.append(
f"| {r.get('member_name') or '-'} | {_fmt(r.get('display_score'))} | {_fmt(r.get('display_score_welcome'))} | "
f"{_fmt(r.get('display_score_convert'))} | {_fmt(r.get('raw_score'), 4)} | {_fmt(r.get('raw_score_welcome'), 4)} | "
f"{_fmt(r.get('raw_score_convert'), 4)} | {_fmt(r.get('t_v'))} | {_fmt(r.get('visits_14d'), 0)} |"
)
lines.append("")
lines.append(f"Total rows: {len(nci_rows)}")
lines.append("")
# 亲密指数
intimacy_sql = """
SELECT
COALESCE(a.nickname, CONCAT('assistant_', i.assistant_id)) AS assistant_name,
COALESCE(m.nickname, CONCAT('member_', i.member_id)) AS member_name,
i.display_score,
i.session_count,
i.attributed_recharge_amount
FROM billiards_dws.dws_member_assistant_intimacy i
LEFT JOIN billiards_dwd.dim_member m
ON i.member_id = m.member_id AND m.scd2_is_current = 1
LEFT JOIN billiards_dwd.dim_assistant a
ON i.assistant_id = a.assistant_id AND a.scd2_is_current = 1
ORDER BY i.display_score DESC NULLS LAST, i.session_count DESC
"""
intimacy_rows = _fetch(db, intimacy_sql)
lines.append("## 3) Intimacy")
lines.append("")
lines.append("| assistant | member | intimacy | sessions | recharge_amount |")
lines.append("|---|---|---:|---:|---:|")
for r in intimacy_rows:
lines.append(
f"| {r.get('assistant_name') or '-'} | {r.get('member_name') or '-'} | {_fmt(r.get('display_score'))} | "
f"{_fmt(r.get('session_count'), 0)} | {_fmt(r.get('attributed_recharge_amount'))} |"
)
lines.append("")
lines.append(f"Total rows: {len(intimacy_rows)}")
return "\n".join(lines)
def main() -> None:
config = AppConfig.load()
db_conn = DatabaseConnection(config.config["db"]["dsn"])
db = DatabaseOperations(db_conn)
try:
markdown = build_markdown(db)
finally:
db_conn.close()
output_path = os.path.join(ROOT, "docs", "index_tables.md")
with open(output_path, "w", encoding="utf-8-sig") as f:
f.write(markdown)
print(f"Exported to {output_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,475 @@
# -*- coding: utf-8 -*-
"""Export full intimacy JSON with member visits and card balances."""
from __future__ import annotations
import argparse
import json
import os
import sys
from datetime import date, datetime
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
def _as_int(v: Any) -> Optional[int]:
if v is None:
return None
s = str(v).strip()
if not s:
return None
return int(s)
def _to_float(v: Any, default: float = 0.0) -> float:
if v is None:
return default
if isinstance(v, Decimal):
return float(v)
if isinstance(v, (int, float)):
return float(v)
s = str(v).strip()
if not s:
return default
return float(s)
def _fmt_dt(v: Any) -> Optional[str]:
if v is None:
return None
if isinstance(v, datetime):
return v.isoformat()
if isinstance(v, date):
return v.isoformat()
return str(v)
def _resolve_site_id(config: AppConfig, db: DatabaseOperations, cli_site_id: Optional[int]) -> int:
if cli_site_id is not None:
return int(cli_site_id)
from_cfg = _as_int(config.get("app.store_id")) or _as_int(config.get("app.default_site_id"))
if from_cfg is not None:
return from_cfg
rows = db.query(
"""
SELECT site_id
FROM billiards_dws.dws_member_assistant_intimacy
WHERE site_id IS NOT NULL
GROUP BY site_id
ORDER BY COUNT(*) DESC
LIMIT 1
"""
)
if rows:
return int(dict(rows[0])["site_id"])
raise RuntimeError("Unable to resolve site_id; pass --site-id explicitly.")
def _fetch_pairs(db: DatabaseOperations, site_id: int) -> List[Dict[str, Any]]:
sql = """
SELECT
i.site_id,
i.tenant_id,
i.member_id,
i.assistant_id,
i.session_count,
i.total_duration_minutes,
i.basic_session_count,
i.incentive_session_count,
i.days_since_last_session,
i.attributed_recharge_count,
i.attributed_recharge_amount,
i.score_frequency,
i.score_recency,
i.score_recharge,
i.score_duration,
i.burst_multiplier,
i.raw_score,
i.display_score,
i.calc_time,
COALESCE(m.nickname, CONCAT('member_', i.member_id::text)) AS member_nickname,
COALESCE(a.nickname, CONCAT('assistant_', i.assistant_id::text)) AS assistant_nickname
FROM billiards_dws.dws_member_assistant_intimacy i
LEFT JOIN billiards_dwd.dim_member m
ON i.member_id = m.member_id
AND m.scd2_is_current = 1
LEFT JOIN billiards_dwd.dim_assistant a
ON i.assistant_id = a.assistant_id
AND a.scd2_is_current = 1
WHERE i.site_id = %s
ORDER BY i.display_score DESC NULLS LAST, i.session_count DESC, i.member_id, i.assistant_id
"""
rows = db.query(sql, (site_id,))
return [dict(r) for r in (rows or [])]
def _fetch_member_cards(
db: DatabaseOperations,
site_id: int,
member_ids: List[int],
) -> Dict[int, Dict[str, Any]]:
if not member_ids:
return {}
member_ids_str = ",".join(str(int(x)) for x in sorted(set(member_ids)))
sql = f"""
SELECT
tenant_member_id AS member_id,
member_card_id,
card_type_id,
member_card_grade_code,
member_card_grade_code_name,
member_card_type_name,
member_name,
member_mobile,
balance,
principal_balance,
status,
start_time,
end_time,
last_consume_time
FROM billiards_dwd.dim_member_card_account
WHERE register_site_id = %s
AND scd2_is_current = 1
AND COALESCE(is_delete, 0) = 0
AND tenant_member_id IN ({member_ids_str})
ORDER BY tenant_member_id, balance DESC NULLS LAST, member_card_id
"""
rows = db.query(sql, (site_id,)) or []
result: Dict[int, Dict[str, Any]] = {}
for r in rows:
d = dict(r)
mid = int(d["member_id"])
balance = _to_float(d.get("balance"), 0.0)
card = {
"member_card_id": _as_int(d.get("member_card_id")),
"card_type_id": _as_int(d.get("card_type_id")),
"member_card_grade_code": _as_int(d.get("member_card_grade_code")),
"member_card_grade_code_name": d.get("member_card_grade_code_name"),
"member_card_type_name": d.get("member_card_type_name"),
"member_name": d.get("member_name"),
"member_mobile": d.get("member_mobile"),
"balance": round(balance, 2),
"principal_balance": round(_to_float(d.get("principal_balance"), 0.0), 2),
"status": _as_int(d.get("status")),
"start_time": _fmt_dt(d.get("start_time")),
"end_time": _fmt_dt(d.get("end_time")),
"last_consume_time": _fmt_dt(d.get("last_consume_time")),
}
bucket = result.setdefault(
mid,
{
"member_id": mid,
"cards_all": [],
"cards_balance_ge_10": [],
"total_card_balance_all": 0.0,
},
)
bucket["cards_all"].append(card)
bucket["total_card_balance_all"] = round(bucket["total_card_balance_all"] + balance, 2)
if balance >= 10.0:
bucket["cards_balance_ge_10"].append(card)
return result
def _fetch_visit_rows(
db: DatabaseOperations,
site_id: int,
member_ids: List[int],
) -> Dict[Tuple[int, int], Dict[str, Any]]:
if not member_ids:
return {}
member_ids_str = ",".join(str(int(x)) for x in sorted(set(member_ids)))
sql = f"""
SELECT
member_id,
order_settle_id,
visit_date,
visit_time,
table_name,
area_name,
area_category,
table_duration_min,
assistant_duration_min,
table_fee,
goods_amount,
assistant_amount,
total_consume,
total_discount,
actual_pay,
cash_pay,
cash_card_pay,
gift_card_pay,
groupbuy_pay
FROM billiards_dws.dws_member_visit_detail
WHERE site_id = %s
AND member_id IN ({member_ids_str})
ORDER BY member_id, visit_time DESC, order_settle_id DESC
"""
rows = db.query(sql, (site_id,)) or []
result: Dict[Tuple[int, int], Dict[str, Any]] = {}
for r in rows:
d = dict(r)
key = (int(d["member_id"]), int(d["order_settle_id"]))
result[key] = {
"member_id": int(d["member_id"]),
"order_settle_id": int(d["order_settle_id"]),
"visit_date": _fmt_dt(d.get("visit_date")),
"visit_time": _fmt_dt(d.get("visit_time")),
"table_name": d.get("table_name"),
"area_name": d.get("area_name"),
"area_category": d.get("area_category"),
"table_duration_min": _as_int(d.get("table_duration_min")) or 0,
"assistant_duration_min_total": _as_int(d.get("assistant_duration_min")) or 0,
"table_fee": round(_to_float(d.get("table_fee"), 0.0), 2),
"goods_amount": round(_to_float(d.get("goods_amount"), 0.0), 2),
"assistant_amount": round(_to_float(d.get("assistant_amount"), 0.0), 2),
"total_consume": round(_to_float(d.get("total_consume"), 0.0), 2),
"total_discount": round(_to_float(d.get("total_discount"), 0.0), 2),
"actual_pay": round(_to_float(d.get("actual_pay"), 0.0), 2),
"cash_pay": round(_to_float(d.get("cash_pay"), 0.0), 2),
"cash_card_pay": round(_to_float(d.get("cash_card_pay"), 0.0), 2),
"gift_card_pay": round(_to_float(d.get("gift_card_pay"), 0.0), 2),
"groupbuy_pay": round(_to_float(d.get("groupbuy_pay"), 0.0), 2),
}
return result
def _fetch_assistant_service_rows(
db: DatabaseOperations,
site_id: int,
member_ids: List[int],
) -> Dict[Tuple[int, int], List[Dict[str, Any]]]:
if not member_ids:
return {}
member_ids_str = ",".join(str(int(x)) for x in sorted(set(member_ids)))
sql = f"""
SELECT
s.tenant_member_id AS member_id,
s.order_settle_id,
d.assistant_id,
COALESCE(d.nickname, s.nickname) AS assistant_nickname,
SUM(COALESCE(s.income_seconds, 0)) / 60.0 AS duration_min,
SUM(COALESCE(s.ledger_amount, 0)) AS amount
FROM billiards_dwd.dwd_assistant_service_log s
JOIN billiards_dwd.dim_assistant d
ON s.user_id = d.user_id
AND d.scd2_is_current = 1
WHERE s.site_id = %s
AND s.is_delete = 0
AND s.tenant_member_id IN ({member_ids_str})
AND s.order_settle_id IS NOT NULL
GROUP BY
s.tenant_member_id,
s.order_settle_id,
d.assistant_id,
COALESCE(d.nickname, s.nickname)
ORDER BY s.tenant_member_id, s.order_settle_id
"""
rows = db.query(sql, (site_id,)) or []
result: Dict[Tuple[int, int], List[Dict[str, Any]]] = {}
for r in rows:
d = dict(r)
key = (int(d["member_id"]), int(d["order_settle_id"]))
rec = {
"assistant_id": int(d["assistant_id"]),
"assistant_nickname": d.get("assistant_nickname"),
"duration_min": round(_to_float(d.get("duration_min"), 0.0), 2),
"amount": round(_to_float(d.get("amount"), 0.0), 2),
}
result.setdefault(key, []).append(rec)
return result
def _pk_key(assistant_nickname: Optional[str], member_nickname: Optional[str]) -> str:
a = (assistant_nickname or "").strip() or "assistant_unknown"
m = (member_nickname or "").strip() or "member_unknown"
return f"{a}__{m}"
def build_export_payload(db: DatabaseOperations, site_id: int) -> Dict[str, Any]:
pairs = _fetch_pairs(db, site_id)
member_ids = sorted({int(p["member_id"]) for p in pairs})
cards_by_member = _fetch_member_cards(db, site_id, member_ids)
visits_by_key = _fetch_visit_rows(db, site_id, member_ids)
service_by_key = _fetch_assistant_service_rows(db, site_id, member_ids)
visits_by_member: Dict[int, List[Tuple[Tuple[int, int], Dict[str, Any]]]] = {}
for k, v in visits_by_key.items():
visits_by_member.setdefault(k[0], []).append((k, v))
data_by_pk: Dict[str, Dict[str, Any]] = {}
collisions: List[str] = []
for p in pairs:
member_id = int(p["member_id"])
assistant_id = int(p["assistant_id"])
assistant_nickname = p.get("assistant_nickname")
member_nickname = p.get("member_nickname")
visit_items: List[Dict[str, Any]] = []
for key, visit in visits_by_member.get(member_id, []):
service_list = service_by_key.get(key, [])
if not service_list:
continue
matched = [x for x in service_list if x["assistant_id"] == assistant_id]
if not matched:
continue
matched_duration = round(sum(x["duration_min"] for x in matched), 2)
matched_amount = round(sum(x["amount"] for x in matched), 2)
matched_nicknames = sorted({x.get("assistant_nickname") for x in matched if x.get("assistant_nickname")})
visit_items.append(
{
"order_settle_id": visit.get("order_settle_id"),
"visit_date": visit.get("visit_date"),
"visit_time": visit.get("visit_time"),
"table_name": visit.get("table_name"),
"area_name": visit.get("area_name"),
"area_category": visit.get("area_category"),
"table_duration_min": visit.get("table_duration_min"),
"assistant_duration_min_total": visit.get("assistant_duration_min_total"),
"table_fee": visit.get("table_fee"),
"goods_amount": visit.get("goods_amount"),
"assistant_amount": visit.get("assistant_amount"),
"total_consume": visit.get("total_consume"),
"total_discount": visit.get("total_discount"),
"actual_pay": visit.get("actual_pay"),
"cash_pay": visit.get("cash_pay"),
"cash_card_pay": visit.get("cash_card_pay"),
"gift_card_pay": visit.get("gift_card_pay"),
"groupbuy_pay": visit.get("groupbuy_pay"),
"target_assistant_nickname": ", ".join(matched_nicknames) if matched_nicknames else p.get("assistant_nickname"),
"target_assistant_duration_min": matched_duration,
"target_assistant_amount": matched_amount,
}
)
visit_items.sort(
key=lambda x: (x.get("visit_time") or "", x.get("order_settle_id") or 0),
reverse=True,
)
member_cards = cards_by_member.get(
member_id,
{
"member_id": member_id,
"cards_all": [],
"cards_balance_ge_10": [],
"total_card_balance_all": 0.0,
},
)
pk = _pk_key(assistant_nickname, member_nickname)
item = {
"primary_key": {
"assistant_nickname": assistant_nickname,
"member_nickname": member_nickname,
},
"intimacy": {
"display_score": round(_to_float(p.get("display_score"), 0.0), 2),
"raw_score": round(_to_float(p.get("raw_score"), 0.0), 6),
"session_count": _as_int(p.get("session_count")) or 0,
"total_duration_minutes": _as_int(p.get("total_duration_minutes")) or 0,
"basic_session_count": _as_int(p.get("basic_session_count")) or 0,
"incentive_session_count": _as_int(p.get("incentive_session_count")) or 0,
"days_since_last_session": _as_int(p.get("days_since_last_session")),
"attributed_recharge_count": _as_int(p.get("attributed_recharge_count")) or 0,
"attributed_recharge_amount": round(_to_float(p.get("attributed_recharge_amount"), 0.0), 2),
"score_frequency": round(_to_float(p.get("score_frequency"), 0.0), 4),
"score_recency": round(_to_float(p.get("score_recency"), 0.0), 4),
"score_recharge": round(_to_float(p.get("score_recharge"), 0.0), 4),
"score_duration": round(_to_float(p.get("score_duration"), 0.0), 4),
"burst_multiplier": round(_to_float(p.get("burst_multiplier"), 1.0), 4),
"calc_time": _fmt_dt(p.get("calc_time")),
},
"member_cards": {
"cards_balance_ge_10": member_cards.get("cards_balance_ge_10", []),
"total_card_balance_all": round(_to_float(member_cards.get("total_card_balance_all"), 0.0), 2),
},
"visit_consumptions": visit_items,
}
if pk in data_by_pk:
collisions.append(pk)
existing = data_by_pk[pk]
existing["collision_items"] = existing.get("collision_items", [])
existing["collision_items"].append(item)
else:
data_by_pk[pk] = item
payload = {
"meta": {
"site_id": site_id,
"generated_at": datetime.now().isoformat(),
"pair_count": len(pairs),
"primary_key_count": len(data_by_pk),
"member_count": len(member_ids),
"primary_key_rule": "assistant_nickname + member_nickname",
"collision_count": len(collisions),
},
"data": data_by_pk,
}
return payload
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Export full intimacy JSON")
parser.add_argument("--site-id", type=int, default=None, help="site_id, defaults to app.store_id")
parser.add_argument(
"--output",
default="tmp/intimacy_full_export.json",
help="output JSON file path",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
config = AppConfig.load()
db_conn = DatabaseConnection(config.config["db"]["dsn"])
db = DatabaseOperations(db_conn)
try:
site_id = _resolve_site_id(config, db, args.site_id)
payload = build_export_payload(db, site_id)
finally:
db_conn.close()
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(
json.dumps(payload, ensure_ascii=False, indent=2),
encoding="utf-8",
)
print(f"Exported intimacy JSON: {output_path}")
print(f"pair_count={payload['meta']['pair_count']}, member_count={payload['meta']['member_count']}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,720 @@
# -*- coding: utf-8 -*-
"""Export 60-day member visit detail with WBI/NCI scores."""
from __future__ import annotations
import argparse
import csv
import math
import os
import sys
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
FIELDS = [
"site_id",
"member_id",
"member_nickname",
"visit_time",
"consume_amount",
"sv_balance",
"assistant_nicknames",
"wbi_score",
"nci_score",
]
def _as_int(v: Any) -> Optional[int]:
if v is None or str(v).strip() == "":
return None
return int(v)
def _as_float(v: Any, default: float = 0.0) -> float:
if v is None or str(v).strip() == "":
return default
return float(v)
def _resolve_site_id(config: AppConfig, db: DatabaseOperations, cli_site_id: Optional[int]) -> int:
if cli_site_id is not None:
return int(cli_site_id)
from_cfg = _as_int(config.get("app.store_id")) or _as_int(config.get("app.default_site_id"))
if from_cfg is not None:
return from_cfg
rows = db.query(
"""
SELECT site_id
FROM billiards_dwd.dwd_settlement_head
WHERE site_id IS NOT NULL
GROUP BY site_id
ORDER BY COUNT(*) DESC
LIMIT 1
"""
)
if rows:
return int(dict(rows[0])["site_id"])
raise RuntimeError("Unable to resolve site_id; pass --site-id explicitly.")
def _visit_condition_sql() -> str:
return """
(
s.settle_type = 1
OR (
s.settle_type = 3
AND EXISTS (
SELECT 1
FROM billiards_dwd.dwd_assistant_service_log asl
JOIN billiards_dws.cfg_skill_type st
ON asl.skill_id = st.skill_id
AND st.course_type_code = 'BONUS'
AND st.is_active = TRUE
WHERE asl.order_settle_id = s.order_settle_id
AND asl.site_id = s.site_id
AND asl.tenant_member_id = s.member_id
AND asl.is_delete = 0
)
)
)
"""
def _fetch_visit_rows_base(
db: DatabaseOperations,
site_id: int,
start_time: datetime,
end_time: datetime,
) -> List[Dict[str, Any]]:
sql = f"""
WITH visit_raw AS (
SELECT
s.site_id,
COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) AS member_id,
s.order_settle_id,
s.pay_time AS visit_time,
COALESCE(s.consume_money, 0) AS consume_amount
FROM billiards_dwd.dwd_settlement_head s
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON s.member_card_account_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = s.site_id
WHERE s.site_id = %s
AND s.pay_time >= %s
AND s.pay_time < %s
AND {_visit_condition_sql()}
AND COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) > 0
),
assistant_agg AS (
SELECT
asl.order_settle_id,
STRING_AGG(DISTINCT NULLIF(asl.nickname, ''), '?' ORDER BY NULLIF(asl.nickname, '')) AS assistant_nicknames
FROM billiards_dwd.dwd_assistant_service_log asl
WHERE asl.site_id = %s
AND asl.is_delete = 0
GROUP BY asl.order_settle_id
),
member_balance AS (
SELECT
mca.register_site_id AS site_id,
mca.tenant_member_id AS member_id,
SUM(
CASE
WHEN mca.card_type_id = 2793249295533893 THEN COALESCE(mca.balance, 0)
ELSE 0
END
) AS sv_balance
FROM billiards_dwd.dim_member_card_account mca
WHERE mca.register_site_id = %s
AND mca.scd2_is_current = 1
GROUP BY mca.register_site_id, mca.tenant_member_id
),
member_name AS (
SELECT member_id, nickname
FROM billiards_dwd.dim_member
WHERE register_site_id = %s
AND scd2_is_current = 1
)
SELECT
vr.site_id,
vr.member_id,
COALESCE(mn.nickname, CONCAT('member_', vr.member_id::text)) AS member_nickname,
vr.visit_time,
ROUND(vr.consume_amount::numeric, 2) AS consume_amount,
ROUND(COALESCE(mb.sv_balance, 0)::numeric, 2) AS sv_balance,
aa.assistant_nicknames
FROM visit_raw vr
LEFT JOIN assistant_agg aa
ON aa.order_settle_id = vr.order_settle_id
LEFT JOIN member_balance mb
ON mb.site_id = vr.site_id
AND mb.member_id = vr.member_id
LEFT JOIN member_name mn
ON mn.member_id = vr.member_id
ORDER BY vr.visit_time DESC, vr.order_settle_id DESC
"""
rows = db.query(sql, (site_id, start_time, end_time, site_id, site_id, site_id))
return [dict(r) for r in (rows or [])]
def _fetch_current_score_maps(
db: DatabaseOperations,
site_id: int,
) -> Tuple[Dict[int, float], Dict[int, float]]:
wbi_rows = db.query(
"""
SELECT member_id, display_score AS wbi_score
FROM billiards_dws.dws_member_winback_index
WHERE site_id = %s
""",
(site_id,),
)
nci_rows = db.query(
"""
SELECT member_id, display_score AS nci_score
FROM billiards_dws.dws_member_newconv_index
WHERE site_id = %s
""",
(site_id,),
)
wbi_map = {
int(dict(r)["member_id"]): round(float(dict(r)["wbi_score"]), 2)
for r in (wbi_rows or [])
if dict(r).get("wbi_score") is not None
}
nci_map = {
int(dict(r)["member_id"]): round(float(dict(r)["nci_score"]), 2)
for r in (nci_rows or [])
if dict(r).get("nci_score") is not None
}
return wbi_map, nci_map
def _load_wbi_params(db: DatabaseOperations) -> Dict[str, float]:
sql = """
SELECT param_name, param_value
FROM (
SELECT
param_name,
param_value,
ROW_NUMBER() OVER (
PARTITION BY param_name
ORDER BY effective_from DESC, updated_at DESC, created_at DESC
) AS rn
FROM billiards_dws.cfg_index_parameters
WHERE index_type = 'WBI'
AND effective_from <= CURRENT_DATE
) t
WHERE rn = 1
"""
rows = db.query(sql)
params: Dict[str, float] = {}
for row in (rows or []):
d = dict(row)
params[str(d["param_name"])] = float(d["param_value"])
return params
def _fetch_wbi_member_rows(db: DatabaseOperations, site_id: int) -> Dict[int, Dict[str, Any]]:
rows = db.query(
"""
SELECT
member_id,
status,
segment,
t_v,
interval_count,
overdue_old,
drop_old,
recharge_old,
value_old,
raw_score,
display_score
FROM billiards_dws.dws_member_winback_index
WHERE site_id = %s
""",
(site_id,),
)
result: Dict[int, Dict[str, Any]] = {}
for row in (rows or []):
d = dict(row)
mid = int(d["member_id"])
result[mid] = d
return result
def _fetch_member_interval_samples(
db: DatabaseOperations,
site_id: int,
member_ids: List[int],
base_date: date,
visit_lookback_days: int,
recency_days: int,
) -> Dict[int, List[Tuple[float, int]]]:
if not member_ids:
return {}
member_ids_str = ",".join(str(m) for m in member_ids)
start_date = base_date - timedelta(days=visit_lookback_days)
sql = f"""
WITH visit_source AS (
SELECT
COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) AS member_id,
DATE(s.pay_time) AS visit_date
FROM billiards_dwd.dwd_settlement_head s
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON s.member_card_account_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = s.site_id
WHERE s.site_id = %s
AND s.pay_time >= %s
AND s.pay_time < %s + INTERVAL '1 day'
AND {_visit_condition_sql()}
AND COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) IN ({member_ids_str})
),
visit_dedup AS (
SELECT member_id, visit_date
FROM visit_source
GROUP BY member_id, visit_date
)
SELECT member_id, visit_date
FROM visit_dedup
ORDER BY member_id, visit_date
"""
rows = db.query(sql, (site_id, start_date, base_date))
member_dates: Dict[int, List[date]] = {}
for row in (rows or []):
d = dict(row)
mid = int(d["member_id"])
vdt = d["visit_date"]
if vdt is None:
continue
member_dates.setdefault(mid, []).append(vdt)
result: Dict[int, List[Tuple[float, int]]] = {}
for mid, dates in member_dates.items():
samples: List[Tuple[float, int]] = []
for i in range(1, len(dates)):
interval = (dates[i] - dates[i - 1]).days
interval_capped = float(min(recency_days, interval))
age_days = max(0, (base_date - dates[i]).days)
samples.append((interval_capped, age_days))
result[mid] = samples
return result
def _weighted_cdf(
samples: List[Tuple[float, int]],
t_v: float,
halflife_days: float,
blend_min_samples: int = 8,
) -> float:
if not samples:
return 0.5
if halflife_days <= 0:
p_eq = sum(1.0 for x, _ in samples if x <= t_v) / len(samples)
return p_eq
ln2 = math.log(2.0)
weights: List[float] = []
indicators: List[float] = []
for interval, age_days in samples:
w = math.exp(-ln2 * float(age_days) / halflife_days)
weights.append(w)
indicators.append(1.0 if interval <= t_v else 0.0)
w_sum = sum(weights)
if w_sum <= 0:
p_w = 0.5
else:
p_w = sum(w * ind for w, ind in zip(weights, indicators)) / w_sum
p_eq = sum(indicators) / len(indicators)
m = len(samples)
lam = min(1.0, float(m) / float(max(1, blend_min_samples)))
p = lam * p_w + (1.0 - lam) * p_eq
return max(0.0, min(1.0, p))
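# _weighted_cdf 手算示例(样本为虚构,半衰期取 30 天,t_v = 20):
#   samples = [(5.0, 0), (40.0, 30)]
#   权重 w = exp(-ln2 * age / 30)   -> [1.0, 0.5]
#   指示 1(interval <= t_v)        -> [1, 0]
#   p_w = 1.0 / 1.5 ≈ 0.667,p_eq = 0.5
#   m = 2,lam = min(1, 2 / 8) = 0.25
#   p = 0.25 * 0.667 + 0.75 * 0.5 ≈ 0.54(近期样本权重更高,小样本时向等权估计收缩)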
def _calculate_percentiles(scores: List[float], lower: int, upper: int) -> Tuple[float, float]:
if not scores:
return 0.0, 0.0
sorted_scores = sorted(scores)
n = len(sorted_scores)
lower_idx = max(0, int(n * lower / 100) - 1)
upper_idx = min(n - 1, int(n * upper / 100))
return sorted_scores[lower_idx], sorted_scores[upper_idx]
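# 示例:n = 100、lower = 5、upper = 95 时,lower_idx = 4(第 5 小值)、upper_idx = 95(第 96 小值),
# 即取 P5 / P95 作为后续 winsorize 的上下界。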
def _winsorize(value: float, lower: float, upper: float) -> float:
return min(max(value, lower), upper)
def _normalize_to_display(value: float, min_val: float, max_val: float, compression_mode: str) -> float:
if compression_mode == "log1p":
value = math.log1p(value)
min_val = math.log1p(min_val)
max_val = math.log1p(max_val)
elif compression_mode == "asinh":
value = math.asinh(value)
min_val = math.asinh(min_val)
max_val = math.asinh(max_val)
eps = 1e-6
rng = max_val - min_val
if rng < eps:
return 5.0
score = 10.0 * (value - min_val) / rng
return max(0.0, min(10.0, score))
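# 归一化示例(数值为虚构):value = 30、min_val = 0、max_val = 100 时,
#   compression_mode = "none"  -> 10 * 30 / 100 = 3.0
#   compression_mode = "log1p" -> 10 * log1p(30) / log1p(100) ≈ 7.4(压缩长尾,抬升中低区间)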
def _compression_mode_from_param(params: Dict[str, float]) -> str:
mode = int(params.get("compression_mode", 0))
if mode == 1:
return "log1p"
if mode == 2:
return "asinh"
return "none"
def _build_wbi_optimized_map(
db: DatabaseOperations,
site_id: int,
base_date: date,
half_life_days: float,
) -> Dict[int, Optional[float]]:
params = _load_wbi_params(db)
w_over = float(params.get("w_over", 2.0))
w_drop = float(params.get("w_drop", 1.0))
w_re = float(params.get("w_re", 0.4))
w_value = float(params.get("w_value", 1.2))
overdue_alpha = float(params.get("overdue_alpha", 2.0))
percentile_lower = int(params.get("percentile_lower", 5))
percentile_upper = int(params.get("percentile_upper", 95))
recency_days = int(params.get("lookback_days_recency", 60))
visit_lookback_days = int(params.get("visit_lookback_days", 180))
member_rows = _fetch_wbi_member_rows(db, site_id)
member_ids_for_calc = [
mid
for mid, row in member_rows.items()
if row.get("segment") == "OLD" and row.get("raw_score") is not None
]
interval_samples = _fetch_member_interval_samples(
db=db,
site_id=site_id,
member_ids=member_ids_for_calc,
base_date=base_date,
visit_lookback_days=visit_lookback_days,
recency_days=recency_days,
)
raw_new_map: Dict[int, float] = {}
for mid in member_ids_for_calc:
row = member_rows[mid]
t_v = _as_float(row.get("t_v"), recency_days)
overdue_old = _as_float(row.get("overdue_old"))
drop_old = _as_float(row.get("drop_old"))
recharge_old = _as_float(row.get("recharge_old"))
value_old = _as_float(row.get("value_old"))
raw_old = _as_float(row.get("raw_score"))
pre_old = (
w_over * overdue_old
+ w_drop * drop_old
+ w_re * recharge_old
+ w_value * value_old
)
if pre_old <= 1e-9:
suppression = 1.0
else:
suppression = max(0.0, min(1.0, raw_old / pre_old))
p_weighted = _weighted_cdf(
samples=interval_samples.get(mid, []),
t_v=t_v,
halflife_days=half_life_days,
)
overdue_new = math.pow(p_weighted, overdue_alpha)
pre_new = (
w_over * overdue_new
+ w_drop * drop_old
+ w_re * recharge_old
+ w_value * value_old
)
raw_new = max(0.0, pre_new * suppression)
raw_new_map[mid] = raw_new
if not raw_new_map:
return {mid: _as_float(row.get("display_score")) for mid, row in member_rows.items()}
scores = list(raw_new_map.values())
q_l, q_u = _calculate_percentiles(scores, percentile_lower, percentile_upper)
compression_mode = _compression_mode_from_param(params)
display_new_map: Dict[int, Optional[float]] = {}
for mid, raw_score in raw_new_map.items():
clipped = _winsorize(raw_score, q_l, q_u)
display = _normalize_to_display(clipped, q_l, q_u, compression_mode=compression_mode)
display_new_map[mid] = round(display, 2)
    # 保留未重新计算的会员(如 STOP_HIGH_BALANCE)的当前展示分数。
result: Dict[int, Optional[float]] = {}
for mid, row in member_rows.items():
if mid in display_new_map:
result[mid] = display_new_map[mid]
else:
current = row.get("display_score")
result[mid] = None if current is None else round(float(current), 2)
return result
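# 优化版 WBI 重算思路(对上方实现的归纳):只把 Overdue 分量替换为时间加权版本
# (overdue_new = p_weighted ** overdue_alpha),drop / recharge / value 沿用库中旧值;
# suppression = raw_old / pre_old 保留原先对 raw 分数的抑制比例;
# 最后按 percentile_lower / percentile_upper(默认 5 / 95)winsorize 并归一化到 0~10 展示分。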
def _attach_scores(
base_rows: List[Dict[str, Any]],
wbi_map: Dict[int, Optional[float]],
nci_map: Dict[int, float],
) -> List[Dict[str, Any]]:
result: List[Dict[str, Any]] = []
for row in base_rows:
mid = int(row["member_id"])
new_row = {
"site_id": row.get("site_id"),
"member_id": row.get("member_id"),
"member_nickname": row.get("member_nickname"),
"visit_time": row.get("visit_time"),
"consume_amount": row.get("consume_amount"),
"sv_balance": row.get("sv_balance"),
"assistant_nicknames": row.get("assistant_nicknames"),
"wbi_score": wbi_map.get(mid),
"nci_score": nci_map.get(mid),
}
result.append(new_row)
return result
def _write_csv(rows: List[Dict[str, Any]], out_csv: Path) -> None:
out_csv.parent.mkdir(parents=True, exist_ok=True)
with out_csv.open("w", newline="", encoding="utf-8-sig") as f:
writer = csv.DictWriter(f, fieldnames=FIELDS)
writer.writeheader()
for row in rows:
writer.writerow({k: row.get(k) for k in FIELDS})
def _write_preview_md(rows: List[Dict[str, Any]], out_md: Path, limit: int = 200) -> None:
out_md.parent.mkdir(parents=True, exist_ok=True)
lines = [
"|" + "|".join(FIELDS) + "|",
"|" + "|".join(["---"] * len(FIELDS)) + "|",
]
for row in rows[:limit]:
cells = ["" if row.get(c) is None else str(row.get(c)) for c in FIELDS]
lines.append("|" + "|".join(cells) + "|")
out_md.write_text("\n".join(lines), encoding="utf-8-sig")
def _diff_and_write_report(
current_rows: List[Dict[str, Any]],
optimized_rows: List[Dict[str, Any]],
out_md: Path,
) -> None:
def _to_map(rows: List[Dict[str, Any]]) -> Dict[Tuple[Any, Any, Any], Dict[str, Any]]:
result: Dict[Tuple[Any, Any, Any], Dict[str, Any]] = {}
for r in rows:
key = (r.get("site_id"), r.get("member_id"), r.get("visit_time"))
result[key] = r
return result
cur_map = _to_map(current_rows)
opt_map = _to_map(optimized_rows)
cur_keys = set(cur_map.keys())
opt_keys = set(opt_map.keys())
common_keys = sorted(cur_keys & opt_keys)
changed_rows = 0
changed_wbi_rows = 0
changed_nci_rows = 0
changed_member_ids = set()
member_wbi_deltas: Dict[int, List[float]] = {}
for k in common_keys:
c = cur_map[k]
o = opt_map[k]
wbi_c = c.get("wbi_score")
wbi_o = o.get("wbi_score")
nci_c = c.get("nci_score")
nci_o = o.get("nci_score")
row_changed = (wbi_c != wbi_o) or (nci_c != nci_o)
if row_changed:
changed_rows += 1
mid = int(c["member_id"])
changed_member_ids.add(mid)
if wbi_c != wbi_o:
changed_wbi_rows += 1
if wbi_c is not None and wbi_o is not None:
member_wbi_deltas.setdefault(mid, []).append(float(wbi_o) - float(wbi_c))
if nci_c != nci_o:
changed_nci_rows += 1
member_delta_summary: List[Tuple[int, float, int]] = []
for mid, ds in member_wbi_deltas.items():
if not ds:
continue
avg_delta = sum(ds) / len(ds)
member_delta_summary.append((mid, avg_delta, len(ds)))
member_delta_summary.sort(key=lambda x: abs(x[1]), reverse=True)
lines = [
"# visit_60d_member_detail_with_indices当前版 vs 优化版",
"",
"## 对比概览",
f"- 当前行数: `{len(current_rows)}`",
f"- 优化行数: `{len(optimized_rows)}`",
f"- 共同主键行数(site_id,member_id,visit_time): `{len(common_keys)}`",
f"- 仅当前有: `{len(cur_keys - opt_keys)}`",
f"- 仅优化有: `{len(opt_keys - cur_keys)}`",
f"- 分数发生变化的行: `{changed_rows}`",
f"- WBI变化行: `{changed_wbi_rows}`",
f"- NCI变化行: `{changed_nci_rows}`",
f"- 涉及会员数: `{len(changed_member_ids)}`",
"",
"## 经营解读",
"- 本次优化只改 WBI把 Overdue 从等权历史替换为时间加权CDF近期样本权重更高",
"- NCI保持不变用于避免把两类策略老客挽回/新客转化)混在一次改动里。",
"- 若变化主要出现在近期行为变化快的会员,通常更符合一线“近期状态优先”的经营直觉。",
"",
"## WBI变化最大会员(按平均分差绝对值)",
"|member_id|avg_delta(optimized-current)|visit_rows|",
"|---|---:|---:|",
]
for mid, avg_delta, cnt in member_delta_summary[:20]:
lines.append(f"|{mid}|{avg_delta:.2f}|{cnt}|")
if len(member_delta_summary) == 0:
lines.append("|(none)|0.00|0|")
out_md.parent.mkdir(parents=True, exist_ok=True)
out_md.write_text("\n".join(lines), encoding="utf-8-sig")
def main() -> None:
parser = argparse.ArgumentParser(description="Export 60-day member visit detail with WBI/NCI scores.")
parser.add_argument("--site-id", type=int, default=None, help="Site id to export")
parser.add_argument("--days", type=int, default=60, help="Lookback days (default: 60)")
parser.add_argument(
"--scheme",
choices=["current", "optimized", "both"],
default="current",
help="Export scheme",
)
parser.add_argument(
"--wbi-interval-halflife-days",
type=float,
default=30.0,
help="Half-life days for weighted CDF in optimized WBI",
)
parser.add_argument(
"--output-csv",
default=os.path.join(ROOT, "docs", "visit_60d_member_detail_with_indices.csv"),
help="Output CSV path (used by current/optimized single scheme)",
)
parser.add_argument(
"--output-preview-md",
default=os.path.join(ROOT, "docs", "visit_60d_member_detail_with_indices_preview.md"),
help="Output preview markdown path (used by current/optimized single scheme)",
)
parser.add_argument(
"--output-csv-current",
default=os.path.join(ROOT, "docs", "visit_60d_member_detail_with_indices_current.csv"),
help="Output CSV path for current scheme when --scheme both",
)
parser.add_argument(
"--output-csv-optimized",
default=os.path.join(ROOT, "docs", "visit_60d_member_detail_with_indices_optimized.csv"),
help="Output CSV path for optimized scheme when --scheme both",
)
parser.add_argument(
"--output-compare-md",
default=os.path.join(ROOT, "docs", "visit_60d_member_detail_with_indices_compare.md"),
help="Output compare markdown path when --scheme both",
)
parser.add_argument("--preview-limit", type=int, default=200, help="Preview markdown row limit")
args = parser.parse_args()
config = AppConfig.load()
db_conn = DatabaseConnection(config.config["db"]["dsn"])
db = DatabaseOperations(db_conn)
try:
site_id = _resolve_site_id(config, db, args.site_id)
now = datetime.now()
start_time = now - timedelta(days=max(1, int(args.days)))
end_time = now
base_rows = _fetch_visit_rows_base(db, site_id, start_time, end_time)
wbi_current_map, nci_current_map = _fetch_current_score_maps(db, site_id)
if args.scheme == "current":
rows = _attach_scores(base_rows, wbi_current_map, nci_current_map)
out_csv = Path(args.output_csv)
out_md = Path(args.output_preview_md)
_write_csv(rows, out_csv)
_write_preview_md(rows, out_md, limit=max(1, int(args.preview_limit)))
print(f"site_id={site_id}")
print("scheme=current")
print(f"rows={len(rows)}")
print(f"csv={out_csv}")
print(f"preview={out_md}")
return
wbi_optimized_map = _build_wbi_optimized_map(
db=db,
site_id=site_id,
base_date=end_time.date(),
half_life_days=max(1.0, float(args.wbi_interval_halflife_days)),
)
if args.scheme == "optimized":
rows = _attach_scores(base_rows, wbi_optimized_map, nci_current_map)
out_csv = Path(args.output_csv)
out_md = Path(args.output_preview_md)
_write_csv(rows, out_csv)
_write_preview_md(rows, out_md, limit=max(1, int(args.preview_limit)))
print(f"site_id={site_id}")
print("scheme=optimized")
print(f"rows={len(rows)}")
print(f"csv={out_csv}")
print(f"preview={out_md}")
return
current_rows = _attach_scores(base_rows, wbi_current_map, nci_current_map)
optimized_rows = _attach_scores(base_rows, wbi_optimized_map, nci_current_map)
out_cur = Path(args.output_csv_current)
out_opt = Path(args.output_csv_optimized)
out_cmp = Path(args.output_compare_md)
_write_csv(current_rows, out_cur)
_write_csv(optimized_rows, out_opt)
_diff_and_write_report(current_rows, optimized_rows, out_cmp)
print(f"site_id={site_id}")
print("scheme=both")
print(f"rows={len(current_rows)}")
print(f"csv_current={out_cur}")
print(f"csv_optimized={out_opt}")
print(f"compare={out_cmp}")
finally:
db_conn.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,634 @@
# -*- coding: utf-8 -*-
"""
全量 API JSON 刷新 + 字段分析 + MD 文档完善 + 对比报告(v2)
时间范围:2026-01-01 00:00:00 ~ 2026-02-13 00:00:00,每接口 100 条
改进点(相比 v1):
- siteProfile/tableProfile 等嵌套对象:MD 中已记录为 object 则不展开子字段
- 请求参数与响应字段分开对比
- 只对比顶层业务字段
- 真正缺失的新字段才补充到 MD
用法:python scripts/full_api_refresh_v2.py
"""
import json
import os
import re
import sys
import time
from datetime import datetime
import requests
# ── 配置 ──────────────────────────────────────────────────────────────────
API_BASE = "https://pc.ficoo.vip/apiprod/admin/v1/"
API_TOKEN = os.environ.get("API_TOKEN", "")
if not API_TOKEN:
env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
if os.path.exists(env_path):
with open(env_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("API_TOKEN="):
API_TOKEN = line.split("=", 1)[1].strip()
break
SITE_ID = 2790685415443269
START_TIME = "2026-01-01 00:00:00"
END_TIME = "2026-02-13 00:00:00"
LIMIT = 100
SAMPLES_DIR = os.path.join("docs", "api-reference", "samples")
DOCS_DIR = os.path.join("docs", "api-reference")
REPORT_DIR = os.path.join("docs", "reports")
REGISTRY_PATH = os.path.join("docs", "api-reference", "api_registry.json")
HEADERS = {
"Authorization": f"Bearer {API_TOKEN}",
"Content-Type": "application/json",
}
# 已知的嵌套对象字段名(MD 中记录为 object,不展开子字段)
KNOWN_NESTED_OBJECTS = {
"siteProfile", "tableProfile", "settleList",
"goodsStockWarningInfo", "goodsCategoryList",
}
def load_registry():
with open(REGISTRY_PATH, "r", encoding="utf-8") as f:
return json.load(f)
def call_api(module, action, body):
url = f"{API_BASE}{module}/{action}"
try:
resp = requests.post(url, json=body, headers=HEADERS, timeout=30)
resp.raise_for_status()
return resp.json()
except Exception as e:
print(f" ❌ 请求失败: {e}")
return None
def build_body(entry):
body = dict(entry.get("body") or {})
if entry.get("time_range") and entry.get("time_keys"):
keys = entry["time_keys"]
if len(keys) >= 2:
body[keys[0]] = START_TIME
body[keys[1]] = END_TIME
if entry.get("pagination"):
body[entry["pagination"].get("page_key", "page")] = 1
body[entry["pagination"].get("limit_key", "limit")] = LIMIT
return body
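# build_body 示例(注册表条目为虚构,仅示意字段含义):
#   entry = {"body": {"siteId": SITE_ID}, "time_range": True,
#            "time_keys": ["startTime", "endTime"],
#            "pagination": {"page_key": "page", "limit_key": "limit"}}
#   结果:在原 body 上追加 startTime=START_TIME、endTime=END_TIME、page=1、limit=LIMIT。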
def unwrap_records(raw_json, entry):
"""从原始 API 响应中提取业务记录列表"""
if raw_json is None:
return []
data = raw_json.get("data")
if data is None:
return []
table_name = entry["id"]
data_path = entry.get("data_path", "")
# tenant_member_balance_overview: data 本身就是汇总对象
if table_name == "tenant_member_balance_overview":
if isinstance(data, dict):
return [data]
return []
# 按 data_path 解析
if data_path and data_path.startswith("data."):
path_parts = data_path.split(".")[1:]
current = data
for part in path_parts:
if isinstance(current, dict):
current = current.get(part)
else:
current = None
break
if isinstance(current, list):
return current
# fallback
if isinstance(data, dict):
for k, v in data.items():
if isinstance(v, list) and k.lower() not in ("total",):
return v
if isinstance(data, list):
return data
return []
def get_top_level_fields(record):
"""只提取顶层字段名和类型(不递归展开嵌套对象)"""
fields = {}
if not isinstance(record, dict):
return fields
for k, v in record.items():
if isinstance(v, dict):
fields[k] = "object"
elif isinstance(v, list):
fields[k] = "array"
elif isinstance(v, bool):
fields[k] = "boolean"
elif isinstance(v, int):
fields[k] = "integer"
elif isinstance(v, float):
fields[k] = "number"
elif v is None:
fields[k] = "null"
else:
fields[k] = "string"
return fields
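# 示例(记录内容为虚构):
#   get_top_level_fields({"id": 1, "name": "A", "price": 9.9, "ok": True,
#                         "tags": [], "siteProfile": {}, "remark": None})
#   -> {"id": "integer", "name": "string", "price": "number", "ok": "boolean",
#       "tags": "array", "siteProfile": "object", "remark": "null"}
# 注意 bool 判断在 int 之前,避免 True/False 被归为 integer。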
def get_nested_fields(record, parent_key):
"""提取指定嵌套对象的子字段"""
obj = record.get(parent_key)
if not isinstance(obj, dict):
return {}
fields = {}
for k, v in obj.items():
path = f"{parent_key}.{k}"
if isinstance(v, dict):
fields[path] = "object"
elif isinstance(v, list):
fields[path] = "array"
elif isinstance(v, bool):
fields[path] = "boolean"
elif isinstance(v, int):
fields[path] = "integer"
elif isinstance(v, float):
fields[path] = "number"
elif v is None:
fields[path] = "null"
else:
fields[path] = "string"
return fields
def select_top5_richest(records):
"""从所有记录中选出字段数最多的前 5 条"""
if not records:
return []
scored = []
for i, rec in enumerate(records):
if not isinstance(rec, dict):
continue
field_count = len(rec)
json_len = len(json.dumps(rec, ensure_ascii=False))
scored.append((field_count, json_len, i, rec))
scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
return [item[3] for item in scored[:5]]
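# 排序键为 (顶层字段数, JSON 序列化长度),两者均降序,因此字段最全、内容最丰富的记录排在前面,
# 仅保留前 5 条作为文档样本。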
def collect_all_top_fields(records):
"""遍历所有记录,收集所有顶层字段(含类型、出现次数、示例值)"""
all_fields = {}
for rec in records:
if not isinstance(rec, dict):
continue
fields = get_top_level_fields(rec)
for name, typ in fields.items():
if name not in all_fields:
all_fields[name] = {"type": typ, "count": 0, "example": None}
all_fields[name]["count"] += 1
if all_fields[name]["example"] is None:
val = rec.get(name)
if val is not None and val != "" and val != 0 and not isinstance(val, (dict, list)):
ex = str(val)
if len(ex) > 80:
ex = ex[:77] + "..."
all_fields[name]["example"] = ex
return all_fields
def collect_nested_fields(records, parent_key):
"""遍历所有记录,收集指定嵌套对象的子字段"""
all_fields = {}
for rec in records:
if not isinstance(rec, dict):
continue
fields = get_nested_fields(rec, parent_key)
for path, typ in fields.items():
if path not in all_fields:
all_fields[path] = {"type": typ, "count": 0, "example": None}
all_fields[path]["count"] += 1
if all_fields[path]["example"] is None:
obj = rec.get(parent_key, {})
k = path.split(".")[-1]
val = obj.get(k) if isinstance(obj, dict) else None
if val is not None and val != "" and val != 0 and not isinstance(val, (dict, list)):
ex = str(val)
if len(ex) > 80:
ex = ex[:77] + "..."
all_fields[path]["example"] = ex
return all_fields
def extract_md_response_fields(table_name):
"""从 MD 文档的响应字段章节提取字段名(排除请求参数)"""
md_path = os.path.join(DOCS_DIR, f"{table_name}.md")
if not os.path.exists(md_path):
return set(), set(), ""
with open(md_path, "r", encoding="utf-8") as f:
content = f.read()
response_fields = set()
nested_fields = set() # siteProfile.xxx 等嵌套字段
field_pattern = re.compile(r'^\|\s*`([^`]+)`\s*\|', re.MULTILINE)
header_fields = {"字段名", "类型", "示例值", "说明", "field", "example",
"description", "type", "路径", "参数", "必填", "属性", ""}
# 找到"四、响应字段"章节的范围
in_response = False
lines = content.split("\n")
response_start = None
response_end = len(lines)
for i, line in enumerate(lines):
s = line.strip()
if ("## 四" in s or "## 4" in s) and "响应字段" in s:
in_response = True
response_start = i
continue
if in_response and s.startswith("## ") and "响应字段" not in s:
response_end = i
break
if response_start is None:
# 没有明确的响应字段章节,尝试从整个文档提取
for m in field_pattern.finditer(content):
raw = m.group(1).strip()
if raw.lower() in {h.lower() for h in header_fields}:
continue
if "." in raw:
nested_fields.add(raw)
else:
response_fields.add(raw)
return response_fields, nested_fields, content
# 只从响应字段章节提取
response_section = "\n".join(lines[response_start:response_end])
for m in field_pattern.finditer(response_section):
raw = m.group(1).strip()
if raw.lower() in {h.lower() for h in header_fields}:
continue
if "." in raw:
nested_fields.add(raw)
else:
response_fields.add(raw)
return response_fields, nested_fields, content
def compare_fields(json_fields, md_fields, md_nested_fields, table_name):
"""对比 JSON 字段与 MD 字段,返回缺失和多余"""
json_names = set(json_fields.keys())
md_names = set(md_fields) if isinstance(md_fields, set) else set(md_fields)
# JSON 有但 MD 没有的顶层字段
missing_in_md = []
for name in sorted(json_names - md_names):
        # 跳过已知嵌套对象(如果 MD 中已记录为 object)
if name in KNOWN_NESTED_OBJECTS and name in md_names:
continue
info = json_fields[name]
missing_in_md.append((name, info))
# MD 有但 JSON 没有的字段
extra_in_md = sorted(md_names - json_names)
return missing_in_md, extra_in_md
def save_top5_sample(table_name, top5):
"""保存前 5 条最全记录作为 JSON 样本"""
sample_path = os.path.join(SAMPLES_DIR, f"{table_name}.json")
with open(sample_path, "w", encoding="utf-8") as f:
json.dump(top5, f, ensure_ascii=False, indent=2)
return sample_path
def update_md_with_missing_fields(table_name, missing_fields, md_content):
"""将真正缺失的字段补充到 MD 文档的响应字段章节末尾"""
if not missing_fields:
return False
md_path = os.path.join(DOCS_DIR, f"{table_name}.md")
if not os.path.exists(md_path):
return False
lines = md_content.split("\n")
# 找到响应字段章节的最后一个表格行
insert_idx = None
in_response = False
last_table_row = None
for i, line in enumerate(lines):
s = line.strip()
if ("## 四" in s or "## 4" in s) and "响应字段" in s:
in_response = True
continue
if in_response and s.startswith("## ") and "响应字段" not in s:
insert_idx = last_table_row
break
if in_response and s.startswith("|") and "---" not in s:
# 检查是否是表头行
if not any(h in s for h in ["字段名", "字段", "类型", "说明"]):
last_table_row = i
elif last_table_row is None:
last_table_row = i
if insert_idx is None and last_table_row is not None:
insert_idx = last_table_row
if insert_idx is None:
return False
new_rows = []
for name, info in missing_fields:
typ = info["type"]
example = info["example"] or ""
count = info["count"]
new_rows.append(
f"| `{name}` | {typ} | {example} | "
f"(新发现字段,{count}/{LIMIT} 条记录中出现) |"
)
for row in reversed(new_rows):
lines.insert(insert_idx + 1, row)
with open(md_path, "w", encoding="utf-8") as f:
f.write("\n".join(lines))
return True
def generate_report(results):
"""生成最终的 JSON vs MD 对比报告"""
lines = []
lines.append("# API JSON 字段 vs MD 文档对比报告")
lines.append("")
lines.append(f"生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} (Asia/Shanghai)")
lines.append(f"数据范围:{START_TIME} ~ {END_TIME}")
lines.append(f"每接口获取:{LIMIT}")
lines.append("")
# 汇总
ok = sum(1 for r in results if r["status"] == "ok")
gap = sum(1 for r in results if r["status"] == "gap")
skip = sum(1 for r in results if r["status"] == "skipped")
err = sum(1 for r in results if r["status"] == "error")
lines.append("## 汇总")
lines.append("")
lines.append("| 状态 | 数量 |")
lines.append("|------|------|")
lines.append(f"| ✅ 完全一致 | {ok} |")
lines.append(f"| ⚠️ 有新字段(已补充) | {gap} |")
lines.append(f"| ⏭️ 跳过 | {skip} |")
lines.append(f"| 💥 错误 | {err} |")
lines.append(f"| 合计 | {len(results)} |")
lines.append("")
# 各接口详情
lines.append("## 各接口详情")
lines.append("")
for r in results:
icon = {"ok": "", "gap": "⚠️", "skipped": "⏭️", "error": "💥"}.get(r["status"], "")
lines.append(f"### {r['table']} ({r.get('name_zh', '')})")
lines.append("")
lines.append(f"| 项目 | 值 |")
lines.append(f"|------|-----|")
lines.append(f"| 状态 | {icon} {r['status']} |")
lines.append(f"| 获取记录数 | {r['record_count']} |")
lines.append(f"| JSON 顶层字段数 | {r['json_field_count']} |")
lines.append(f"| MD 响应字段数 | {r['md_field_count']} |")
lines.append(f"| 数据路径 | `{r.get('data_path', 'N/A')}` |")
if r.get("top5_field_counts"):
lines.append(f"| 前5条最全记录字段数 | {r['top5_field_counts']} |")
lines.append("")
if r.get("missing_in_md"):
lines.append("新发现字段(已补充到 MD")
lines.append("")
lines.append("| 字段名 | 类型 | 示例 | 出现次数 |")
lines.append("|--------|------|------|----------|")
for name, info in r["missing_in_md"]:
lines.append(f"| `{name}` | {info['type']} | {info.get('example', '')} | {info['count']} |")
lines.append("")
if r.get("extra_in_md"):
lines.append(f"MD 中有但本次 JSON 未出现的字段(可能为条件性字段):`{'`, `'.join(r['extra_in_md'])}`")
lines.append("")
# 嵌套对象子字段汇总
if r.get("nested_summary"):
for parent, count in r["nested_summary"].items():
lines.append(f"嵌套对象 `{parent}` 含 {count} 个子字段MD 中已记录为 object不逐字段展开")
lines.append("")
# 附录siteProfile 通用字段参考
lines.append("## 附录siteProfile 通用字段参考")
lines.append("")
lines.append("以下字段在大多数接口的 `siteProfile` 嵌套对象中出现,为门店信息快照(冗余),各接口结构一致:")
lines.append("")
lines.append("| 字段 | 类型 | 说明 |")
lines.append("|------|------|------|")
lines.append("| `id` | integer | 门店 ID |")
lines.append("| `org_id` | integer | 组织 ID |")
lines.append("| `shop_name` | string | 门店名称 |")
lines.append("| `avatar` | string | 门店头像 URL |")
lines.append("| `business_tel` | string | 门店电话 |")
lines.append("| `full_address` | string | 完整地址 |")
lines.append("| `address` | string | 简短地址 |")
lines.append("| `longitude` | number | 经度 |")
lines.append("| `latitude` | number | 纬度 |")
lines.append("| `tenant_site_region_id` | integer | 区域 ID |")
lines.append("| `tenant_id` | integer | 租户 ID |")
lines.append("| `auto_light` | integer | 自动开灯 |")
lines.append("| `attendance_distance` | integer | 考勤距离 |")
lines.append("| `attendance_enabled` | integer | 考勤启用 |")
lines.append("| `wifi_name` | string | WiFi 名称 |")
lines.append("| `wifi_password` | string | WiFi 密码 |")
lines.append("| `customer_service_qrcode` | string | 客服二维码 |")
lines.append("| `customer_service_wechat` | string | 客服微信 |")
lines.append("| `fixed_pay_qrCode` | string | 固定支付二维码 |")
lines.append("| `prod_env` | integer | 生产环境标识 |")
lines.append("| `light_status` | integer | 灯光状态 |")
lines.append("| `light_type` | integer | 灯光类型 |")
lines.append("| `light_token` | string | 灯光控制 token |")
lines.append("| `site_type` | integer | 门店类型 |")
lines.append("| `site_label` | string | 门店标签 |")
lines.append("| `shop_status` | integer | 门店状态 |")
lines.append("")
return "\n".join(lines)
def main():
registry = load_registry()
print(f"加载 API 注册表: {len(registry)} 个端点")
print(f"时间范围: {START_TIME} ~ {END_TIME}")
print(f"每接口获取: {LIMIT}")
print("=" * 80)
results = []
for entry in registry:
table_name = entry["id"]
name_zh = entry.get("name_zh", "")
module = entry["module"]
action = entry["action"]
skip = entry.get("skip", False)
print(f"\n{'' * 60}")
print(f"[{table_name}] {name_zh}{module}/{action}")
if skip:
print(" ⏭️ 跳过")
results.append({
"table": table_name, "name_zh": name_zh,
"status": "skipped", "record_count": 0,
"json_field_count": 0, "md_field_count": 0,
"data_path": entry.get("data_path"),
})
continue
        # 使用已有的 raw JSON(上一步已获取)
raw_path = os.path.join(SAMPLES_DIR, f"{table_name}_raw.json")
if os.path.exists(raw_path):
with open(raw_path, "r", encoding="utf-8") as f:
raw = json.load(f)
print(f" 使用已缓存的原始响应")
else:
body = build_body(entry)
print(f" 请求: POST {module}/{action}")
raw = call_api(module, action, body)
if raw:
with open(raw_path, "w", encoding="utf-8") as f:
json.dump(raw, f, ensure_ascii=False, indent=2)
if raw is None:
results.append({
"table": table_name, "name_zh": name_zh,
"status": "error", "record_count": 0,
"json_field_count": 0, "md_field_count": 0,
"data_path": entry.get("data_path"),
})
continue
records = unwrap_records(raw, entry)
print(f" 记录数: {len(records)}")
if not records:
results.append({
"table": table_name, "name_zh": name_zh,
"status": "ok", "record_count": 0,
"json_field_count": 0, "md_field_count": 0,
"data_path": entry.get("data_path"),
})
continue
# 选出字段最全的前 5 条
top5 = select_top5_richest(records)
top5_counts = [len(r) for r in top5]
print(f" 前 5 条最全记录顶层字段数: {top5_counts}")
# 保存前 5 条样本
save_top5_sample(table_name, top5)
# 收集所有顶层字段
json_fields = collect_all_top_fields(records)
print(f" JSON 顶层字段数: {len(json_fields)}")
# 收集嵌套对象子字段(仅用于报告,不用于对比)
nested_summary = {}
for name, info in json_fields.items():
if info["type"] == "object" and name in KNOWN_NESTED_OBJECTS:
nested = collect_nested_fields(records, name)
nested_summary[name] = len(nested)
# 提取 MD 响应字段
md_fields, md_nested, md_content = extract_md_response_fields(table_name)
print(f" MD 响应字段数: {len(md_fields)}")
# 对比
missing_in_md, extra_in_md = compare_fields(json_fields, md_fields, md_nested, table_name)
        # 过滤掉已知嵌套对象(MD 中已记录为 object)
real_missing = [(n, i) for n, i in missing_in_md
if n not in KNOWN_NESTED_OBJECTS or n not in md_fields]
status = "ok" if not real_missing else "gap"
if real_missing:
print(f" ⚠️ 发现 {len(real_missing)} 个新字段:")
for name, info in real_missing:
print(f" + {name} ({info['type']}, {info['count']}次)")
# 补充到 MD
updated = update_md_with_missing_fields(table_name, real_missing, md_content)
if updated:
print(f" 📝 已补充到 MD 文档")
else:
print(f" ✅ 字段完全覆盖")
if extra_in_md:
print(f" MD 多 {len(extra_in_md)} 个条件性字段")
results.append({
"table": table_name, "name_zh": name_zh,
"status": status,
"record_count": len(records),
"json_field_count": len(json_fields),
"md_field_count": len(md_fields),
"data_path": entry.get("data_path"),
"missing_in_md": real_missing,
"extra_in_md": extra_in_md,
"top5_field_counts": top5_counts,
"nested_summary": nested_summary,
})
# ── 生成报告 ──
print(f"\n{'=' * 80}")
print("生成对比报告...")
report = generate_report(results)
os.makedirs(REPORT_DIR, exist_ok=True)
report_path = os.path.join(REPORT_DIR, "api_json_vs_md_report_20260214.md")
with open(report_path, "w", encoding="utf-8") as f:
f.write(report)
print(f"报告: {report_path}")
# JSON 详细结果
json_path = os.path.join(REPORT_DIR, "api_refresh_detail_20260214.json")
serializable = []
for r in results:
sr = dict(r)
if "missing_in_md" in sr and sr["missing_in_md"]:
sr["missing_in_md"] = [(n, {"type": i["type"], "count": i["count"]})
for n, i in sr["missing_in_md"]]
serializable.append(sr)
with open(json_path, "w", encoding="utf-8") as f:
json.dump(serializable, f, ensure_ascii=False, indent=2)
# 汇总
ok = sum(1 for r in results if r["status"] == "ok")
gap = sum(1 for r in results if r["status"] == "gap")
skip = sum(1 for r in results if r["status"] == "skipped")
err = sum(1 for r in results if r["status"] == "error")
print(f"\n汇总: ✅ {ok} | ⚠️ {gap} | ⏭️ {skip} | 💥 {err}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,488 @@
#!/usr/bin/env python3
"""审计一览表生成脚本 — 解析模块
从 docs/audit/changes/ 目录扫描审计源记录 Markdown 文件,
提取结构化信息(日期、标题、修改文件、风险等级、变更类型、影响模块)。
"""
from __future__ import annotations
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
# ---------------------------------------------------------------------------
# 常量
# ---------------------------------------------------------------------------
# 文件名格式:YYYY-MM-DD__slug.md
_FILENAME_RE = re.compile(r"^(\d{4}-\d{2}-\d{2})__(.+)\.md$")
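# 文件名示例(slug 为虚构):"2026-02-14__dedupe-ods-snapshots.md"
#   -> date = "2026-02-14",slug = "dedupe-ods-snapshots"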
# 文件路径 → 功能模块映射(按最长前缀优先匹配)
MODULE_MAP: dict[str, str] = {
"api/": "API 层",
"tasks/ods": "ODS 层",
"tasks/dwd": "DWD 层",
"tasks/dws": "DWS 层",
"tasks/index": "指数算法",
"loaders/": "数据装载",
"database/": "数据库",
"orchestration/": "调度",
"config/": "配置",
"cli/": "CLI",
"models/": "模型",
"scd/": "SCD2",
"docs/": "文档",
"scripts/": "脚本工具",
"tests/": "测试",
"quality/": "质量校验",
"gui/": "GUI",
"utils/": "工具库",
}
# 按前缀长度降序排列,确保最长前缀优先匹配
_SORTED_PREFIXES: list[tuple[str, str]] = sorted(
MODULE_MAP.items(), key=lambda kv: len(kv[0]), reverse=True
)
# 所有合法模块名称(含兜底"其他")
VALID_MODULES: frozenset[str] = frozenset(MODULE_MAP.values()) | {"其他"}
# ---------------------------------------------------------------------------
# 数据类
# ---------------------------------------------------------------------------
@dataclass
class AuditEntry:
"""从单个审计源记录文件解析出的结构化数据"""
    date: str  # YYYY-MM-DD,从文件名提取
slug: str # 文件名中 __ 后的标识符
title: str # Markdown 一级标题
filename: str # 源文件名(不含路径)
changed_files: list[str] = field(default_factory=list) # 修改的文件路径列表
modules: set[str] = field(default_factory=set) # 影响的功能模块集合
risk_level: str = "未知" # 风险等级:高/中/低/极低
change_type: str = "功能" # 变更类型bugfix/功能/文档/重构/清理
# ---------------------------------------------------------------------------
# 模块分类
# ---------------------------------------------------------------------------
def classify_module(filepath: str) -> str:
"""根据 MODULE_MAP 将文件路径映射到功能模块。
匹配规则:按前缀长度降序逐一比较,首个命中即返回。
无任何前缀命中时返回 "其他"
"""
# 统一为正斜杠,去除前导 ./ 或 /
normalized = filepath.replace("\\", "/").lstrip("./")
for prefix, module_name in _SORTED_PREFIXES:
if normalized.startswith(prefix):
return module_name
return "其他"
# ---------------------------------------------------------------------------
# 解析辅助函数
# ---------------------------------------------------------------------------
def _extract_title(content: str) -> str | None:
"""从 Markdown 内容中提取第一个一级标题(# ...)。"""
for line in content.splitlines():
stripped = line.strip()
if stripped.startswith("# "):
return stripped[2:].strip()
return None
# 匹配"修改文件清单"/"文件清单"/"Changed"/"变更范围"/"变更摘要" 等章节标题
_FILE_SECTION_RE = re.compile(
r"^##\s+.*(修改文件|文件清单|Changed|变更范围|变更摘要).*$",
re.IGNORECASE,
)
# 从表格行提取文件路径:| `path` | ... 或 | path | ...
_TABLE_FILE_RE = re.compile(
r"^\|\s*`?([^`|]+?)`?\s*\|"
)
# 从列表行提取文件路径:- path 或 - `path`(忽略纯描述行)
_LIST_FILE_RE = re.compile(
r"^[-*]\s+`?([^\s`(]+\.[a-zA-Z0-9_]+)`?"
)
# 从含 → 的行提取源路径和目标路径
_ARROW_PATH_RE = re.compile(
r"`([^`]+?)`\s*→\s*`([^`]+?)`"
)
# 子章节标题(### ...),用于在文件清单章节内继续扫描
_SUB_HEADING_RE = re.compile(r"^###\s+")
def _extract_changed_files(content: str) -> list[str]:
"""从审计文件内容中提取修改文件路径列表。
扫描策略:
1. 找到"修改文件清单"/"文件清单"/"Changed"/"变更范围"等二级章节
2. 在该章节内解析表格行和列表行中的文件路径
3. 遇到下一个同级(##)章节时停止
"""
lines = content.splitlines()
results: list[str] = []
in_section = False
for line in lines:
stripped = line.strip()
if _FILE_SECTION_RE.match(stripped):
in_section = True
continue
# 遇到下一个二级章节,退出扫描
if in_section and stripped.startswith("## ") and not _FILE_SECTION_RE.match(stripped):
break
if not in_section:
continue
# 跳过表头分隔行
if re.match(r"^\|[-\s|:]+\|$", stripped):
continue
# 跳过子章节标题(### 新增文件 等),但继续扫描
if _SUB_HEADING_RE.match(stripped):
continue
# 尝试表格行
m = _TABLE_FILE_RE.match(stripped)
if m:
path = m.group(1).strip()
# 排除表头行("文件"、"文件/对象" 等)
if path and not re.match(r"^(文件|File|路径|对象)", path, re.IGNORECASE):
results.append(path)
continue
# 尝试含 → 的移动/重命名行(提取源和目标路径)
m_arrow = _ARROW_PATH_RE.search(stripped)
if m_arrow:
src, dst = m_arrow.group(1).strip(), m_arrow.group(2).strip()
if "/" in src:
results.append(src)
if "/" in dst:
results.append(dst)
continue
# 尝试列表行
m = _LIST_FILE_RE.match(stripped)
if m:
path = m.group(1).strip()
if path and "/" in path:
results.append(path)
continue
return results
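# 示例(审计记录片段为虚构):
#   ## 修改文件清单
#   | 文件 | 说明 |
#   |------|------|
#   | `api/routes.py` | 新增导出端点 |
#   - tasks/dws/build_index.py
# 解析结果为 ["api/routes.py", "tasks/dws/build_index.py"];
# 表头行("文件")与分隔行被排除,含 → 的移动/重命名行会同时记录源路径与目标路径。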
# 风险等级关键词(按优先级排列)
_RISK_KEYWORDS: list[tuple[str, str]] = [
("极低", "极低"),
("", ""),
("", ""),
("", ""),
]
# 匹配风险相关章节标题
_RISK_SECTION_RE = re.compile(
r"^##\s+.*(风险|Risk).*$", re.IGNORECASE
)
def _extract_risk_level(content: str) -> str:
"""从审计文件内容中提取风险等级。
扫描策略(按优先级):
1. 头部元数据行:`- 风险等级:低` 或 `- 风险:极低`
2. 风险相关二级章节内的关键词
3. 兜底:全文搜索含"风险"的行
"""
lines = content.splitlines()
    # 策略 1:头部元数据(通常在前 15 行内)
    _meta_risk_re = re.compile(r"^-\s*风险[等级]*[::]\s*(.+)$")
for line in lines[:15]:
m = _meta_risk_re.match(line.strip())
if m:
val = m.group(1)
if "极低" in val:
return "极低"
if "" in val:
return ""
if "" in val:
return ""
if "" in val:
return ""
# 策略 2风险相关二级章节
in_section = False
section_text = ""
for line in lines:
stripped = line.strip()
if _RISK_SECTION_RE.match(stripped):
in_section = True
continue
if in_section and stripped.startswith("## "):
break
if in_section:
section_text += stripped + " "
    # 策略 3:兜底,全文搜索含"风险"的行
if not section_text:
for line in lines:
if "风险" in line:
section_text += line.strip() + " "
if not section_text:
return "未知"
# 按优先级匹配:先检查"极低",再检查独立的"高/中/低"
if "极低" in section_text:
return "极低"
if re.search(r"风险[:]\s*高|高风险", section_text):
return ""
if re.search(r"风险[:]\s*中|中等风险", section_text):
return ""
# "纯文档" 等描述中含"低"但不含"极低"时匹配为"低"
if re.search(r"风险[:]\s*低|低风险|风险.*低", section_text):
return ""
# 推断:描述中含"纯文档/无运行时影响/纯分析"等表述视为极低
if re.search(r"纯文档|无运行时影响|纯分析|无逻辑改动|无代码", section_text):
return "极低"
return "未知"
# 变更类型推断关键词
_CHANGE_TYPE_PATTERNS: list[tuple[str, str]] = [
("bugfix", "bugfix"),
("bug", "bugfix"),
("修复", "bugfix"),
("重构", "重构"),
("清理", "清理"),
("纯文档", "文档"),
("无逻辑改动", "文档"),
("文档", "文档"),
]
def _infer_change_type(content: str) -> str:
"""从审计文件内容推断变更类型。
按优先级扫描关键词,首个命中即返回。
默认返回 "功能"
"""
lower = content.lower()
for keyword, ctype in _CHANGE_TYPE_PATTERNS:
if keyword in lower:
return ctype
return "功能"
# ---------------------------------------------------------------------------
# 核心解析函数
# ---------------------------------------------------------------------------
def parse_audit_file(filepath: str | Path) -> AuditEntry | None:
"""解析单个审计源记录文件,返回 AuditEntry。
文件名必须符合 YYYY-MM-DD__slug.md 格式,否则返回 None 并打印警告。
"""
filepath = Path(filepath)
filename = filepath.name
# 校验文件名格式
m = _FILENAME_RE.match(filename)
if not m:
print(f"[警告] 文件名格式不符,已跳过:{filename}")
return None
date_str = m.group(1)
slug = m.group(2)
# 读取文件内容
try:
content = filepath.read_text(encoding="utf-8")
except (UnicodeDecodeError, OSError) as exc:
print(f"[警告] 无法读取文件,已跳过:{filename}{exc}")
return None
# 提取标题(缺失时用 slug 兜底)
title = _extract_title(content) or slug
# 提取修改文件列表
changed_files = _extract_changed_files(content)
# 推导影响模块
if changed_files:
modules = {classify_module(f) for f in changed_files}
else:
modules = {"其他"}
# 提取风险等级
risk_level = _extract_risk_level(content)
# 推断变更类型
change_type = _infer_change_type(content)
return AuditEntry(
date=date_str,
slug=slug,
title=title,
filename=filename,
changed_files=changed_files,
modules=modules,
risk_level=risk_level,
change_type=change_type,
)
def scan_audit_dir(dirpath: str | Path) -> list[AuditEntry]:
"""扫描审计目录,返回按日期倒序排列的 AuditEntry 列表。
跳过非 .md 文件和格式不合规的文件。
目录为空或不存在时返回空列表。
"""
dirpath = Path(dirpath)
if not dirpath.is_dir():
return []
entries: list[AuditEntry] = []
for child in sorted(dirpath.iterdir()):
if not child.is_file() or child.suffix != ".md":
continue
entry = parse_audit_file(child)
if entry is not None:
entries.append(entry)
# 按日期倒序
entries.sort(key=lambda e: e.date, reverse=True)
return entries
# ---------------------------------------------------------------------------
# 渲染函数
# ---------------------------------------------------------------------------
def render_timeline_table(entries: list[AuditEntry]) -> str:
"""按时间倒序生成 Markdown 表格。
输入的 entries 应已按日期倒序排列(由 scan_audit_dir 保证)。
空列表时返回"暂无审计记录"提示。
"""
if not entries:
return "> 暂无审计记录\n"
lines: list[str] = [
"| 日期 | 需求摘要 | 变更类型 | 影响模块 | 风险 | 详情 |",
"|------|----------|----------|----------|------|------|",
]
for e in entries:
modules_str = ", ".join(sorted(e.modules))
link = f"[链接](changes/{e.filename})"
lines.append(
f"| {e.date} | {e.title} | {e.change_type} | {modules_str} | {e.risk_level} | {link} |"
)
return "\n".join(lines) + "\n"
def render_module_index(entries: list[AuditEntry]) -> str:
"""按模块分组生成 Markdown 章节。
每个模块一个三级标题 + 表格,模块按字母序排列。
空列表时返回"暂无审计记录"提示。
"""
if not entries:
return "> 暂无审计记录\n"
# 按模块分组
module_entries: dict[str, list[AuditEntry]] = {}
for e in entries:
for mod in e.modules:
module_entries.setdefault(mod, []).append(e)
sections: list[str] = []
for mod in sorted(module_entries.keys()):
mod_list = module_entries[mod]
section_lines: list[str] = [
f"### {mod}",
"",
"| 日期 | 需求摘要 | 变更类型 | 风险 | 详情 |",
"|------|----------|----------|------|------|",
]
for e in mod_list:
link = f"[链接](changes/{e.filename})"
section_lines.append(
f"| {e.date} | {e.title} | {e.change_type} | {e.risk_level} | {link} |"
)
sections.append("\n".join(section_lines) + "\n")
return "\n".join(sections)
def render_dashboard(entries: list[AuditEntry]) -> str:
"""组合时间线和模块索引生成完整 dashboard Markdown 文档。
包含:标题、生成时间戳、时间线视图、模块索引视图。
"""
from datetime import datetime
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
parts: list[str] = [
"# 审计一览表",
"",
f"> 自动生成于 {timestamp},请勿手动编辑。",
"",
"## 时间线视图",
"",
render_timeline_table(entries),
"## 模块索引",
"",
render_module_index(entries),
]
return "\n".join(parts)
# ---------------------------------------------------------------------------
# 主入口
# ---------------------------------------------------------------------------
def main() -> None:
"""扫描审计源记录 → 解析 → 渲染 → 写入 audit_dashboard.md。"""
audit_dir = Path("docs/audit/changes")
output_path = Path("docs/audit/audit_dashboard.md")
# 扫描并解析
entries = scan_audit_dir(audit_dir)
# 渲染完整 dashboard
content = render_dashboard(entries)
# 确保输出目录存在
output_path.parent.mkdir(parents=True, exist_ok=True)
# 写入文件
output_path.write_text(content, encoding="utf-8")
# 输出摘要
print(f"已解析 {len(entries)} 条审计记录")
print(f"输出文件:{output_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,983 @@
{
"assistant_accounts_master": [
"id",
"tenant_id",
"site_id",
"assistant_no",
"nickname",
"real_name",
"mobile",
"team_id",
"team_name",
"user_id",
"level",
"assistant_status",
"work_status",
"leave_status",
"entry_time",
"resign_time",
"start_time",
"end_time",
"create_time",
"update_time",
"order_trade_no",
"staff_id",
"staff_profile_id",
"system_role_id",
"avatar",
"birth_date",
"gender",
"height",
"weight",
"job_num",
"show_status",
"show_sort",
"sum_grade",
"assistant_grade",
"get_grade_times",
"introduce",
"video_introduction_url",
"group_id",
"group_name",
"shop_name",
"charge_way",
"entry_type",
"allow_cx",
"is_guaranteed",
"salary_grant_enabled",
"light_status",
"online_status",
"is_delete",
"cx_unit_price",
"pd_unit_price",
"last_table_id",
"last_table_name",
"person_org_id",
"serial_number",
"is_team_leader",
"criticism_status",
"last_update_name",
"ding_talk_synced",
"site_light_cfg_id",
"light_equipment_id",
"entry_sign_status",
"resign_sign_status",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash"
],
"assistant_cancellation_records": [
"id",
"siteid",
"siteprofile",
"assistantname",
"assistantabolishamount",
"assistanton",
"pdchargeminutes",
"tableareaid",
"tablearea",
"tableid",
"tablename",
"trashreason",
"createtime",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"tenant_id"
],
"assistant_service_records": [
"id",
"tenant_id",
"site_id",
"siteprofile",
"site_table_id",
"order_settle_id",
"order_trade_no",
"order_pay_id",
"order_assistant_id",
"order_assistant_type",
"assistantname",
"assistantno",
"assistant_level",
"levelname",
"site_assistant_id",
"skill_id",
"skillname",
"system_member_id",
"tablename",
"tenant_member_id",
"user_id",
"assistant_team_id",
"nickname",
"ledger_name",
"ledger_group_name",
"ledger_amount",
"ledger_count",
"ledger_unit_price",
"ledger_status",
"ledger_start_time",
"ledger_end_time",
"manual_discount_amount",
"member_discount_amount",
"coupon_deduct_money",
"service_money",
"projected_income",
"real_use_seconds",
"income_seconds",
"start_use_time",
"last_use_time",
"create_time",
"is_single_order",
"is_delete",
"is_trash",
"trash_reason",
"trash_applicant_id",
"trash_applicant_name",
"operator_id",
"operator_name",
"salesman_name",
"salesman_org_id",
"salesman_user_id",
"person_org_id",
"add_clock",
"returns_clock",
"composite_grade",
"composite_grade_time",
"skill_grade",
"service_grade",
"sum_grade",
"grade_status",
"get_grade_times",
"is_not_responding",
"is_confirm",
"payload",
"source_file",
"source_endpoint",
"fetched_at",
"content_hash",
"assistantteamname",
"real_service_money"
],
"goods_stock_movements": [
"sitegoodsstockid",
"tenantid",
"siteid",
"sitegoodsid",
"goodsname",
"goodscategoryid",
"goodssecondcategoryid",
"unit",
"price",
"stocktype",
"changenum",
"startnum",
"endnum",
"changenuma",
"startnuma",
"endnuma",
"remark",
"operatorname",
"createtime",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash"
],
"goods_stock_summary": [
"sitegoodsid",
"goodsname",
"goodsunit",
"goodscategoryid",
"goodscategorysecondid",
"categoryname",
"rangestartstock",
"rangeendstock",
"rangein",
"rangeout",
"rangesale",
"rangesalemoney",
"rangeinventory",
"currentstock",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash"
],
"group_buy_packages": [
"id",
"package_id",
"package_name",
"selling_price",
"coupon_money",
"date_type",
"date_info",
"start_time",
"end_time",
"start_clock",
"end_clock",
"add_start_clock",
"add_end_clock",
"duration",
"usable_count",
"usable_range",
"table_area_id",
"table_area_name",
"table_area_id_list",
"tenant_table_area_id",
"tenant_table_area_id_list",
"site_id",
"site_name",
"tenant_id",
"card_type_ids",
"group_type",
"system_group_type",
"type",
"effective_status",
"is_enabled",
"is_delete",
"max_selectable_categories",
"area_tag_type",
"creator_name",
"create_time",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"is_first_limit",
"sort",
"tenantcouponsaleorderitemid"
],
"group_buy_redemption_records": [
"id",
"tenant_id",
"site_id",
"sitename",
"table_id",
"tablename",
"tableareaname",
"tenant_table_area_id",
"order_trade_no",
"order_settle_id",
"order_pay_id",
"order_coupon_id",
"order_coupon_channel",
"coupon_code",
"coupon_money",
"coupon_origin_id",
"ledger_name",
"ledger_group_name",
"ledger_amount",
"ledger_count",
"ledger_unit_price",
"ledger_status",
"table_charge_seconds",
"promotion_activity_id",
"promotion_coupon_id",
"promotion_seconds",
"offer_type",
"assistant_promotion_money",
"assistant_service_promotion_money",
"table_service_promotion_money",
"goods_promotion_money",
"recharge_promotion_money",
"reward_promotion_money",
"goodsoptionprice",
"salesman_name",
"sales_man_org_id",
"salesman_role_id",
"salesman_user_id",
"operator_id",
"operator_name",
"is_single_order",
"is_delete",
"create_time",
"payload",
"source_file",
"source_endpoint",
"fetched_at",
"content_hash",
"assistant_service_share_money",
"assistant_share_money",
"coupon_sale_id",
"good_service_share_money",
"goods_share_money",
"member_discount_money",
"recharge_share_money",
"table_service_share_money",
"table_share_money"
],
"member_balance_changes": [
"tenant_id",
"site_id",
"register_site_id",
"registersitename",
"paysitename",
"id",
"tenant_member_id",
"tenant_member_card_id",
"system_member_id",
"membername",
"membermobile",
"card_type_id",
"membercardtypename",
"account_data",
"before",
"after",
"refund_amount",
"from_type",
"payment_method",
"relate_id",
"remark",
"operator_id",
"operator_name",
"is_delete",
"create_time",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"principal_after",
"principal_before",
"principal_data"
],
"member_profiles": [
"tenant_id",
"register_site_id",
"site_name",
"id",
"system_member_id",
"member_card_grade_code",
"member_card_grade_name",
"mobile",
"nickname",
"point",
"growth_value",
"referrer_member_id",
"status",
"user_status",
"create_time",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"pay_money_sum",
"person_tenant_org_id",
"person_tenant_org_name",
"recharge_money_sum",
"register_source"
],
"member_stored_value_cards": [
"tenant_id",
"tenant_member_id",
"system_member_id",
"register_site_id",
"site_name",
"id",
"member_card_grade_code",
"member_card_grade_code_name",
"member_card_type_name",
"member_name",
"member_mobile",
"card_type_id",
"card_no",
"card_physics_type",
"balance",
"denomination",
"table_discount",
"goods_discount",
"assistant_discount",
"assistant_reward_discount",
"table_service_discount",
"assistant_service_discount",
"coupon_discount",
"goods_service_discount",
"assistant_discount_sub_switch",
"table_discount_sub_switch",
"goods_discount_sub_switch",
"assistant_reward_discount_sub_switch",
"table_service_deduct_radio",
"assistant_service_deduct_radio",
"goods_service_deduct_radio",
"assistant_deduct_radio",
"table_deduct_radio",
"goods_deduct_radio",
"coupon_deduct_radio",
"assistant_reward_deduct_radio",
"tablecarddeduct",
"tableservicecarddeduct",
"goodscardeduct",
"goodsservicecarddeduct",
"assistantcarddeduct",
"assistantservicecarddeduct",
"assistantrewardcarddeduct",
"cardsettlededuct",
"couponcarddeduct",
"deliveryfeededuct",
"use_scene",
"able_cross_site",
"is_allow_give",
"is_allow_order_deduct",
"is_delete",
"bind_password",
"goods_discount_range_type",
"goodscategoryid",
"tableareaid",
"effect_site_id",
"start_time",
"end_time",
"disable_start_time",
"disable_end_time",
"last_consume_time",
"create_time",
"status",
"sort",
"tenantavatar",
"tenantname",
"pdassisnatlevel",
"cxassisnatlevel",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"able_share_member_discount",
"electricity_deduct_radio",
"electricity_discount",
"electricitycarddeduct",
"member_grade",
"principal_balance",
"rechargefreezebalance"
],
"payment_transactions": [
"id",
"site_id",
"siteprofile",
"relate_type",
"relate_id",
"pay_amount",
"pay_status",
"pay_time",
"create_time",
"payment_method",
"online_pay_channel",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"tenant_id"
],
"platform_coupon_redemption_records": [
"id",
"verify_id",
"certificate_id",
"coupon_code",
"coupon_name",
"coupon_channel",
"groupon_type",
"group_package_id",
"sale_price",
"coupon_money",
"coupon_free_time",
"coupon_cover",
"coupon_remark",
"use_status",
"consume_time",
"create_time",
"deal_id",
"channel_deal_id",
"site_id",
"site_order_id",
"table_id",
"tenant_id",
"operator_id",
"operator_name",
"is_delete",
"siteprofile",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash"
],
"recharge_settlements": [
"id",
"tenantid",
"siteid",
"sitename",
"balanceamount",
"cardamount",
"cashamount",
"couponamount",
"createtime",
"memberid",
"membername",
"tenantmembercardid",
"membercardtypename",
"memberphone",
"tableid",
"consumemoney",
"onlineamount",
"operatorid",
"operatorname",
"revokeorderid",
"revokeordername",
"revoketime",
"payamount",
"pointamount",
"refundamount",
"settlename",
"settlerelateid",
"settlestatus",
"settletype",
"paytime",
"roundingamount",
"paymentmethod",
"adjustamount",
"assistantcxmoney",
"assistantpdmoney",
"couponsaleamount",
"memberdiscountamount",
"tablechargemoney",
"goodsmoney",
"realgoodsmoney",
"servicemoney",
"prepaymoney",
"salesmanname",
"orderremark",
"salesmanuserid",
"canberevoked",
"pointdiscountprice",
"pointdiscountcost",
"activitydiscount",
"serialnumber",
"assistantmanualdiscount",
"allcoupondiscount",
"goodspromotionmoney",
"assistantpromotionmoney",
"isusecoupon",
"isusediscount",
"isactivity",
"isbindmember",
"isfirst",
"rechargecardamount",
"giftcardamount",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"electricityadjustmoney",
"electricitymoney",
"mervousalesamount",
"plcouponsaleamount",
"realelectricitymoney"
],
"refund_transactions": [
"id",
"tenant_id",
"tenantname",
"site_id",
"siteprofile",
"relate_type",
"relate_id",
"pay_sn",
"pay_amount",
"refund_amount",
"round_amount",
"pay_status",
"pay_time",
"create_time",
"payment_method",
"pay_terminal",
"pay_config_id",
"online_pay_channel",
"online_pay_type",
"channel_fee",
"channel_payer_id",
"channel_pay_no",
"member_id",
"member_card_id",
"cashier_point_id",
"operator_id",
"action_type",
"check_status",
"is_revoke",
"is_delete",
"balance_frozen_amount",
"card_frozen_amount",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash"
],
"settlement_records": [
"id",
"tenantid",
"siteid",
"sitename",
"balanceamount",
"cardamount",
"cashamount",
"couponamount",
"createtime",
"memberid",
"membername",
"tenantmembercardid",
"membercardtypename",
"memberphone",
"tableid",
"consumemoney",
"onlineamount",
"operatorid",
"operatorname",
"revokeorderid",
"revokeordername",
"revoketime",
"payamount",
"pointamount",
"refundamount",
"settlename",
"settlerelateid",
"settlestatus",
"settletype",
"paytime",
"roundingamount",
"paymentmethod",
"adjustamount",
"assistantcxmoney",
"assistantpdmoney",
"couponsaleamount",
"memberdiscountamount",
"tablechargemoney",
"goodsmoney",
"realgoodsmoney",
"servicemoney",
"prepaymoney",
"salesmanname",
"orderremark",
"salesmanuserid",
"canberevoked",
"pointdiscountprice",
"pointdiscountcost",
"activitydiscount",
"serialnumber",
"assistantmanualdiscount",
"allcoupondiscount",
"goodspromotionmoney",
"assistantpromotionmoney",
"isusecoupon",
"isusediscount",
"isactivity",
"isbindmember",
"isfirst",
"rechargecardamount",
"giftcardamount",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"electricityadjustmoney",
"electricitymoney",
"mervousalesamount",
"plcouponsaleamount",
"realelectricitymoney"
],
"site_tables_master": [
"id",
"site_id",
"sitename",
"appletQrCodeUrl",
"areaname",
"audit_status",
"charge_free",
"create_time",
"delay_lights_time",
"is_online_reservation",
"is_rest_area",
"light_status",
"only_allow_groupon",
"order_delay_time",
"self_table",
"show_status",
"site_table_area_id",
"tablestatusname",
"table_cloth_use_cycle",
"table_cloth_use_time",
"table_name",
"table_price",
"table_status",
"temporary_light_second",
"virtual_table",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"order_id"
],
"stock_goods_category_tree": [
"id",
"tenant_id",
"category_name",
"alias_name",
"pid",
"business_name",
"tenant_goods_business_id",
"open_salesman",
"categoryboxes",
"sort",
"is_warehousing",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash"
],
"store_goods_master": [
"id",
"tenant_id",
"site_id",
"sitename",
"tenant_goods_id",
"goods_name",
"goods_bar_code",
"goods_category_id",
"goods_second_category_id",
"onecategoryname",
"twocategoryname",
"unit",
"sale_price",
"cost_price",
"cost_price_type",
"min_discount_price",
"safe_stock",
"stock",
"stock_a",
"sale_num",
"total_purchase_cost",
"total_sales",
"average_monthly_sales",
"batch_stock_quantity",
"days_available",
"provisional_total_cost",
"enable_status",
"audit_status",
"goods_state",
"is_delete",
"is_warehousing",
"able_discount",
"able_site_transfer",
"forbid_sell_status",
"freeze",
"send_state",
"custom_label_type",
"option_required",
"sale_channel",
"sort",
"remark",
"pinyin_initial",
"goods_cover",
"create_time",
"update_time",
"payload",
"source_file",
"source_endpoint",
"fetched_at",
"content_hash",
"commodity_code",
"not_sale"
],
"store_goods_sales_records": [
"id",
"tenant_id",
"site_id",
"siteid",
"sitename",
"site_goods_id",
"tenant_goods_id",
"order_settle_id",
"order_trade_no",
"order_goods_id",
"ordergoodsid",
"order_pay_id",
"order_coupon_id",
"ledger_name",
"ledger_group_name",
"ledger_amount",
"ledger_count",
"ledger_unit_price",
"ledger_status",
"discount_money",
"discount_price",
"coupon_deduct_money",
"member_discount_amount",
"option_coupon_deduct_money",
"option_member_discount_money",
"point_discount_money",
"point_discount_money_cost",
"real_goods_money",
"cost_money",
"push_money",
"sales_type",
"is_single_order",
"is_delete",
"goods_remark",
"option_price",
"option_value_name",
"member_coupon_id",
"package_coupon_id",
"sales_man_org_id",
"salesman_name",
"salesman_role_id",
"salesman_user_id",
"operator_id",
"operator_name",
"opensalesman",
"returns_number",
"site_table_id",
"tenant_goods_business_id",
"tenant_goods_category_id",
"create_time",
"payload",
"source_file",
"source_endpoint",
"fetched_at",
"content_hash",
"coupon_share_money"
],
"table_fee_discount_records": [
"id",
"tenant_id",
"site_id",
"siteprofile",
"site_table_id",
"tableprofile",
"tenant_table_area_id",
"adjust_type",
"ledger_amount",
"ledger_count",
"ledger_name",
"ledger_status",
"applicant_id",
"applicant_name",
"operator_id",
"operator_name",
"order_settle_id",
"order_trade_no",
"is_delete",
"create_time",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
"content_hash",
"area_type_id",
"charge_free",
"site_table_area_id",
"site_table_area_name",
"sitename",
"table_name",
"table_price",
"tenant_name"
],
"table_fee_transactions": [
"id",
"tenant_id",
"site_id",
"siteprofile",
"site_table_id",
"site_table_area_id",
"site_table_area_name",
"tenant_table_area_id",
"order_trade_no",
"order_pay_id",
"order_settle_id",
"ledger_name",
"ledger_amount",
"ledger_count",
"ledger_unit_price",
"ledger_status",
"ledger_start_time",
"ledger_end_time",
"start_use_time",
"last_use_time",
"real_table_use_seconds",
"real_table_charge_money",
"add_clock_seconds",
"adjust_amount",
"coupon_promotion_amount",
"member_discount_amount",
"used_card_amount",
"mgmt_fee",
"service_money",
"fee_total",
"is_single_order",
"is_delete",
"member_id",
"operator_id",
"operator_name",
"salesman_name",
"salesman_org_id",
"salesman_user_id",
"create_time",
"payload",
"source_file",
"source_endpoint",
"fetched_at",
"content_hash",
"activity_discount_amount",
"order_consumption_type",
"real_service_money"
],
"tenant_goods_master": [
"id",
"tenant_id",
"goods_name",
"goods_bar_code",
"goods_category_id",
"goods_second_category_id",
"categoryname",
"unit",
"goods_number",
"out_goods_id",
"goods_state",
"sale_channel",
"able_discount",
"able_site_transfer",
"is_delete",
"is_warehousing",
"isinsite",
"cost_price",
"cost_price_type",
"market_price",
"min_discount_price",
"common_sale_royalty",
"point_sale_royalty",
"pinyin_initial",
"commoditycode",
"commodity_code",
"goods_cover",
"supplier_id",
"remark_name",
"create_time",
"update_time",
"payload",
"source_file",
"source_endpoint",
"fetched_at",
"content_hash",
"not_sale"
]
}

View File

@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
"""
一键重建 ETL 相关 Schema,并执行 ODS → DWD。
本脚本面向“离线示例 JSON 回放”的开发/运维场景,使用当前项目内的任务实现:
1) (可选)DROP 并重建 schema:`etl_admin` / `billiards_ods` / `billiards_dwd`
2) 执行 `INIT_ODS_SCHEMA`:创建 `etl_admin` 元数据表 + 执行 `schema_ODS_doc.sql`(内部会做轻量清洗)
3) 执行 `INIT_DWD_SCHEMA`:执行 `schema_dwd_doc.sql`
4) 执行 `MANUAL_INGEST`:从本地 JSON 目录灌入 ODS
5) 执行 `DWD_LOAD_FROM_ODS`:从 ODS 装载到 DWD
用法(推荐):
python -m scripts.rebuild.rebuild_db_and_run_ods_to_dwd ^
--dsn "postgresql://user:pwd@host:5432/db" ^
--store-id 1 ^
--json-dir "export/test-json-doc" ^
--drop-schemas
环境变量(可选):
PG_DSN、STORE_ID、INGEST_SOURCE_DIR
日志:
默认同时输出到控制台与文件;文件路径为 `io.log_root/rebuild_db_<时间戳>.log`。
"""
from __future__ import annotations
import argparse
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import psycopg2
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from tasks.dwd.dwd_load_task import DwdLoadTask
from tasks.utility.init_dwd_schema_task import InitDwdSchemaTask
from tasks.utility.init_schema_task import InitOdsSchemaTask
from tasks.utility.manual_ingest_task import ManualIngestTask
DEFAULT_JSON_DIR = "export/test-json-doc"
@dataclass(frozen=True)
class RunArgs:
"""脚本参数对象(用于减少散落的参数传递)。"""
dsn: str
store_id: int
json_dir: str
drop_schemas: bool
terminate_own_sessions: bool
demo: bool
only_files: list[str]
only_dwd_tables: list[str]
stop_after: str | None
def _attach_file_logger(log_root: str | Path, filename: str, logger: logging.Logger) -> logging.Handler | None:
"""
给 root logger 附加文件日志处理器(UTF-8)。
说明:
- 使用 root logger 是为了覆盖项目中不同命名的 logger(包含第三方/子模块)。
- 若创建失败仅记录 warning,不中断主流程。
返回值:
创建成功返回 handler(调用方负责 removeHandler/close);失败返回 None。
"""
log_dir = Path(log_root)
try:
log_dir.mkdir(parents=True, exist_ok=True)
except Exception as exc: # noqa: BLE001
logger.warning("创建日志目录失败:%s%s", log_dir, exc)
return None
log_path = log_dir / filename
try:
handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
except Exception as exc: # noqa: BLE001
logger.warning("创建文件日志失败:%s%s", log_path, exc)
return None
handler.setLevel(logging.INFO)
handler.setFormatter(
logging.Formatter(
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
logging.getLogger().addHandler(handler)
logger.info("文件日志已启用:%s", log_path)
return handler
def _parse_args() -> RunArgs:
"""解析命令行/环境变量参数。"""
parser = argparse.ArgumentParser(description="重建 Schema 并执行 ODS→DWD离线 JSON 回放)")
parser.add_argument("--dsn", default=os.environ.get("PG_DSN"), help="PostgreSQL DSN默认读取 PG_DSN")
parser.add_argument(
"--store-id",
type=int,
default=int(os.environ.get("STORE_ID") or 1),
help="门店/租户 store_id默认读取 STORE_ID否则为 1",
)
parser.add_argument(
"--json-dir",
default=os.environ.get("INGEST_SOURCE_DIR") or DEFAULT_JSON_DIR,
help=f"示例 JSON 目录(默认 {DEFAULT_JSON_DIR},也可读 INGEST_SOURCE_DIR",
)
parser.add_argument(
"--drop-schemas",
action=argparse.BooleanOptionalAction,
default=True,
help="是否先 DROP 并重建 etl_admin/billiards_ods/billiards_dwd默认",
)
parser.add_argument(
"--terminate-own-sessions",
action=argparse.BooleanOptionalAction,
default=True,
help="执行 DROP 前是否终止当前用户的 idle-in-transaction 会话(默认:是)",
)
parser.add_argument(
"--demo",
action=argparse.BooleanOptionalAction,
default=False,
help="运行最小 Demo仅导入 member_profiles 并生成 dim_member/dim_member_ex",
)
parser.add_argument(
"--only-files",
default="",
help="仅处理指定 JSON 文件(逗号分隔,不含 .json例如member_profiles,settlement_records",
)
parser.add_argument(
"--only-dwd-tables",
default="",
help="仅处理指定 DWD 表逗号分隔支持完整名或表名例如billiards_dwd.dim_member,dim_member_ex",
)
parser.add_argument(
"--stop-after",
default="",
help="在指定阶段后停止可选DROP_SCHEMAS/INIT_ODS_SCHEMA/INIT_DWD_SCHEMA/MANUAL_INGEST/DWD_LOAD_FROM_ODS/BASIC_VALIDATE",
)
args = parser.parse_args()
if not args.dsn:
raise SystemExit("缺少 DSN请传入 --dsn 或设置环境变量 PG_DSN")
only_files = [x.strip().lower() for x in str(args.only_files or "").split(",") if x.strip()]
only_dwd_tables = [x.strip().lower() for x in str(args.only_dwd_tables or "").split(",") if x.strip()]
stop_after = str(args.stop_after or "").strip().upper() or None
return RunArgs(
dsn=args.dsn,
store_id=args.store_id,
json_dir=str(args.json_dir),
drop_schemas=bool(args.drop_schemas),
terminate_own_sessions=bool(args.terminate_own_sessions),
demo=bool(args.demo),
only_files=only_files,
only_dwd_tables=only_dwd_tables,
stop_after=stop_after,
)
def _build_config(args: RunArgs) -> AppConfig:
"""构建本次执行所需的最小配置覆盖。"""
manual_cfg: dict[str, Any] = {}
dwd_cfg: dict[str, Any] = {}
if args.demo:
manual_cfg["include_files"] = ["member_profiles"]
dwd_cfg["only_tables"] = ["billiards_dwd.dim_member", "billiards_dwd.dim_member_ex"]
if args.only_files:
manual_cfg["include_files"] = args.only_files
if args.only_dwd_tables:
dwd_cfg["only_tables"] = args.only_dwd_tables
overrides: dict[str, Any] = {
"app": {"store_id": args.store_id},
"pipeline": {"flow": "INGEST_ONLY", "ingest_source_dir": args.json_dir},
"manual": manual_cfg,
"dwd": dwd_cfg,
# 离线回放/建仓可能耗时较长,关闭 statement_timeout,避免被默认 30s 中断;
# 同时关闭 lock_timeout,避免 DROP/DDL 因锁等待稍久就直接失败。
"db": {"dsn": args.dsn, "session": {"statement_timeout_ms": 0, "lock_timeout_ms": 0}},
}
return AppConfig.load(overrides)
def _drop_schemas(db: DatabaseOperations, logger: logging.Logger) -> None:
"""删除并重建 ETL 相关 schema具备破坏性请谨慎"""
with db.conn.cursor() as cur:
# 避免因为其他会话持锁而无限等待;若确实被占用,提示用户先释放/终止阻塞会话。
cur.execute("SET lock_timeout TO '5s'")
for schema in ("billiards_dwd", "billiards_ods", "etl_admin"):
logger.info("DROP SCHEMA IF EXISTS %s CASCADE ...", schema)
cur.execute(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE;')
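# 等价 SQL 示例(仅供人工排查时对照;与上面循环逐条执行的语句一致,具备破坏性,
# 生产环境执行前请确认没有业务会话持有相关对象的锁):
#
#   SET lock_timeout TO '5s';
#   DROP SCHEMA IF EXISTS "billiards_dwd" CASCADE;
#   DROP SCHEMA IF EXISTS "billiards_ods" CASCADE;
#   DROP SCHEMA IF EXISTS "etl_admin" CASCADE;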
def _terminate_own_idle_in_tx(db: DatabaseOperations, logger: logging.Logger) -> int:
"""终止当前用户在本库中处于 idle-in-transaction 的会话,避免阻塞 DROP/DDL。"""
with db.conn.cursor() as cur:
cur.execute(
"""
SELECT pid
FROM pg_stat_activity
WHERE datname = current_database()
AND usename = current_user
AND pid <> pg_backend_pid()
AND state = 'idle in transaction'
"""
)
pids = [r[0] for r in cur.fetchall()]
killed = 0
for pid in pids:
cur.execute("SELECT pg_terminate_backend(%s)", (pid,))
ok = bool(cur.fetchone()[0])
logger.info("终止会话 pid=%s ok=%s", pid, ok)
killed += 1 if ok else 0
return killed
def _run_task(task, logger: logging.Logger) -> dict:
"""统一运行任务并打印关键结果。"""
result = task.execute(None)
logger.info("%s: status=%s counts=%s", task.get_task_code(), result.get("status"), result.get("counts"))
return result
def _basic_validate(db: DatabaseOperations, logger: logging.Logger) -> None:
"""做最基础的可用性校验schema 存在、关键表行数可查询。"""
checks = [
("billiards_ods", "member_profiles"),
("billiards_ods", "settlement_records"),
("billiards_dwd", "dim_member"),
("billiards_dwd", "dwd_settlement_head"),
]
for schema, table in checks:
try:
rows = db.query(f'SELECT COUNT(1) AS cnt FROM "{schema}"."{table}"')
logger.info("校验行数:%s.%s = %s", schema, table, (rows[0] or {}).get("cnt") if rows else None)
except Exception as exc: # noqa: BLE001
logger.warning("校验失败:%s.%s%s", schema, table, exc)
def _connect_db_with_retry(cfg: AppConfig, logger: logging.Logger) -> DatabaseConnection:
"""创建数据库连接(带重试),避免短暂网络抖动导致脚本直接失败。"""
dsn = cfg["db"]["dsn"]
session = cfg["db"].get("session")
connect_timeout = cfg["db"].get("connect_timeout_sec")
backoffs = [1, 2, 4, 8, 16]
last_exc: Exception | None = None
for attempt, wait_sec in enumerate([0] + backoffs, start=1):
if wait_sec:
time.sleep(wait_sec)
try:
return DatabaseConnection(dsn=dsn, session=session, connect_timeout=connect_timeout)
except Exception as exc: # noqa: BLE001
last_exc = exc
logger.warning("数据库连接失败(第 %s 次):%s", attempt, exc)
raise last_exc or RuntimeError("数据库连接失败")
def _is_connection_error(exc: Exception) -> bool:
"""判断是否为连接断开/服务端异常导致的可重试错误。"""
return isinstance(exc, (psycopg2.OperationalError, psycopg2.InterfaceError))
def _run_stage_with_reconnect(
cfg: AppConfig,
logger: logging.Logger,
stage_name: str,
fn,
max_attempts: int = 3,
) -> dict | None:
"""
运行单个阶段:失败(尤其是连接断开)时自动重连并重试。
fn: (db_ops) -> dict | None
"""
last_exc: Exception | None = None
for attempt in range(1, max_attempts + 1):
db_conn = _connect_db_with_retry(cfg, logger)
db_ops = DatabaseOperations(db_conn)
try:
logger.info("阶段开始:%s(第 %s/%s 次)", stage_name, attempt, max_attempts)
result = fn(db_ops)
logger.info("阶段完成:%s", stage_name)
return result
except Exception as exc: # noqa: BLE001
last_exc = exc
logger.exception("阶段失败:%s(第 %s/%s 次):%s", stage_name, attempt, max_attempts, exc)
# 连接类错误允许重试;非连接错误直接抛出,避免掩盖逻辑问题。
if not _is_connection_error(exc):
raise
time.sleep(min(2**attempt, 10))
finally:
try:
db_ops.close() # type: ignore[attr-defined]
except Exception:
pass
try:
db_conn.close()
except Exception:
pass
raise last_exc or RuntimeError(f"阶段失败:{stage_name}")
def main() -> int:
"""脚本主入口:按顺序重建并跑通 ODS→DWD。"""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("fq_etl.rebuild_db")
args = _parse_args()
cfg = _build_config(args)
# 默认启用文件日志,便于事后追溯(即便运行失败也应尽早落盘)。
file_handler = _attach_file_logger(
log_root=cfg["io"]["log_root"],
filename=time.strftime("rebuild_db_%Y%m%d-%H%M%S.log"),
logger=logger,
)
try:
json_dir = Path(args.json_dir)
if not json_dir.exists():
logger.error("示例 JSON 目录不存在:%s", json_dir)
return 2
def stage_drop(db_ops: DatabaseOperations):
if not args.drop_schemas:
return None
if args.terminate_own_sessions:
killed = _terminate_own_idle_in_tx(db_ops, logger)
if killed:
db_ops.commit()
_drop_schemas(db_ops, logger)
db_ops.commit()
return None
def stage_init_ods(db_ops: DatabaseOperations):
return _run_task(InitOdsSchemaTask(cfg, db_ops, None, logger), logger)
def stage_init_dwd(db_ops: DatabaseOperations):
return _run_task(InitDwdSchemaTask(cfg, db_ops, None, logger), logger)
def stage_manual_ingest(db_ops: DatabaseOperations):
logger.info("开始执行MANUAL_INGESTjson_dir=%s", json_dir)
return _run_task(ManualIngestTask(cfg, db_ops, None, logger), logger)
def stage_dwd_load(db_ops: DatabaseOperations):
logger.info("开始执行DWD_LOAD_FROM_ODS")
return _run_task(DwdLoadTask(cfg, db_ops, None, logger), logger)
_run_stage_with_reconnect(cfg, logger, "DROP_SCHEMAS", stage_drop, max_attempts=3)
if args.stop_after == "DROP_SCHEMAS":
return 0
_run_stage_with_reconnect(cfg, logger, "INIT_ODS_SCHEMA", stage_init_ods, max_attempts=3)
if args.stop_after == "INIT_ODS_SCHEMA":
return 0
_run_stage_with_reconnect(cfg, logger, "INIT_DWD_SCHEMA", stage_init_dwd, max_attempts=3)
if args.stop_after == "INIT_DWD_SCHEMA":
return 0
_run_stage_with_reconnect(cfg, logger, "MANUAL_INGEST", stage_manual_ingest, max_attempts=5)
if args.stop_after == "MANUAL_INGEST":
return 0
_run_stage_with_reconnect(cfg, logger, "DWD_LOAD_FROM_ODS", stage_dwd_load, max_attempts=5)
if args.stop_after == "DWD_LOAD_FROM_ODS":
return 0
# 校验阶段复用一条新连接即可
_run_stage_with_reconnect(
cfg,
logger,
"BASIC_VALIDATE",
lambda db_ops: _basic_validate(db_ops, logger),
max_attempts=3,
)
if args.stop_after == "BASIC_VALIDATE":
return 0
return 0
finally:
if file_handler is not None:
try:
logging.getLogger().removeHandler(file_handler)
except Exception:
pass
try:
file_handler.close()
except Exception:
pass
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,523 @@
# -*- coding: utf-8 -*-
"""
重新获取全部 API 接口的 JSON 数据(最多 100 条),
遍历所有记录提取最全字段集合,
与 .md 文档比对并输出差异报告。
时间范围:2026-01-01 00:00:00 ~ 2026-02-13 00:00:00
用法:python scripts/refresh_json_and_audit.py
"""
import json
import os
import re
import sys
import time
import requests
# ── 配置 ──────────────────────────────────────────────────────────────────
API_BASE = "https://pc.ficoo.vip/apiprod/admin/v1/"
API_TOKEN = os.environ.get("API_TOKEN", "")
if not API_TOKEN:
env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
if os.path.exists(env_path):
with open(env_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("API_TOKEN="):
API_TOKEN = line.split("=", 1)[1].strip()
break
SITE_ID = 2790685415443269
START_TIME = "2026-01-01 00:00:00"
END_TIME = "2026-02-13 00:00:00"
LIMIT = 100
SAMPLES_DIR = os.path.join("docs", "api-reference", "samples")
DOCS_DIR = os.path.join("docs", "api-reference")
REPORT_DIR = os.path.join("docs", "reports")
HEADERS = {
"Authorization": f"Bearer {API_TOKEN}",
"Content-Type": "application/json",
}
REGISTRY_PATH = os.path.join("docs", "api-reference", "api_registry.json")
WRAPPER_FIELDS = {"settleList", "siteProfile", "tableProfile",
"goodsCategoryList", "data", "code", "msg",
"settlelist", "siteprofile", "tableprofile",
"goodscategorylist"}
CROSS_REF_HEADERS = {"字段名", "类型", "示例值", "说明", "field", "example",
"description"}
# 每个接口实际返回的列表字段名(从调试中获得)
ACTUAL_LIST_KEY = {
"assistant_accounts_master": "assistantInfos",
"assistant_service_records": "orderAssistantDetails",
"assistant_cancellation_records": "abolitionAssistants",
"table_fee_transactions": "siteTableUseDetailsList",
"table_fee_discount_records": "taiFeeAdjustInfos",
"tenant_goods_master": "tenantGoodsList",
"store_goods_sales_records": "orderGoodsLedgers",
"store_goods_master": "orderGoodsList",
"goods_stock_movements": "queryDeliveryRecordsList",
"member_profiles": "tenantMemberInfos",
"member_stored_value_cards": "tenantMemberCards",
"member_balance_changes": "tenantMemberCardLogs",
"group_buy_packages": "packageCouponList",
"group_buy_redemption_records": "siteTableUseDetailsList",
"site_tables_master": "siteTables",
# 以下使用 "list" 或特殊路径
"payment_transactions": "list",
"refund_transactions": "list",
"platform_coupon_redemption_records": "list",
"goods_stock_summary": "list",
"settlement_records": "settleList",
"recharge_settlements": "settleList",
}
def load_registry():
with open(REGISTRY_PATH, "r", encoding="utf-8") as f:
return json.load(f)
def call_api(module, action, body):
url = f"{API_BASE}{module}/{action}"
try:
resp = requests.post(url, json=body, headers=HEADERS, timeout=30)
resp.raise_for_status()
return resp.json()
except Exception as e:
print(f" ❌ 请求失败: {e}")
return None
def unwrap_records(raw_json, table_name):
"""从原始 API 响应中提取业务记录列表"""
if raw_json is None:
return []
data = raw_json.get("data")
if data is None:
return []
# ── 特殊表:stock_goods_category_tree ──
if table_name == "stock_goods_category_tree":
if isinstance(data, dict):
cats = data.get("goodsCategoryList", [])
return cats if isinstance(cats, list) else []
return []
# ── 特殊表:role_area_association ──
if table_name == "role_area_association":
if isinstance(data, dict):
rels = data.get("roleAreaRelations", [])
return rels if isinstance(rels, list) else []
return []
# ── 特殊表:tenant_member_balance_overview ──
# 返回的是汇总对象 + rechargeCardList/giveCardList
if table_name == "tenant_member_balance_overview":
if isinstance(data, dict):
# 合并顶层标量字段 + 列表中的字段
records = [data] # 顶层作为一条记录
for list_key in ("rechargeCardList", "giveCardList"):
items = data.get(list_key, [])
if isinstance(items, list):
records.extend(items)
return records
return []
# ── settlement_records / recharge_settlements ──
# data.settleList 是列表,每个元素内部有 settleList 子对象
if table_name in ("settlement_records", "recharge_settlements"):
if isinstance(data, dict):
settle_list = data.get("settleList", [])
if isinstance(settle_list, list):
return settle_list
return []
# ── 通用:data 是 dict,从中找列表字段 ──
if isinstance(data, dict):
list_key = ACTUAL_LIST_KEY.get(table_name, "list")
items = data.get(list_key, [])
if isinstance(items, list):
return items
# fallback: 找第一个列表字段
for k, v in data.items():
if isinstance(v, list) and k != "total":
return v
return []
if isinstance(data, list):
return data
return []
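# 提取路径示例(响应结构为假设构造,仅演示 unwrap_records 的行为,并非真实接口返回):
#
#   raw = {"code": 0, "msg": "ok",
#          "data": {"tenantMemberInfos": [{"id": 1}, {"id": 2}], "total": 2}}
#   unwrap_records(raw, "member_profiles")  # -> [{"id": 1}, {"id": 2}]
#
# member_profiles 的列表键来自 ACTUAL_LIST_KEY;未登记的表回退到 "list" 或第一个列表字段。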
def extract_all_fields(records, table_name):
"""从多条记录中提取所有唯一字段名(小写)"""
all_fields = set()
for record in records:
if not isinstance(record, dict):
continue
# settlement_records / recharge_settlements: 内层 settleList 展开
if table_name in ("settlement_records", "recharge_settlements"):
settle = record.get("settleList", record)
if isinstance(settle, list):
settle = settle[0] if settle else {}
if isinstance(settle, dict):
for k in settle.keys():
kl = k.lower()
if kl == "siteprofile":
all_fields.add("siteprofile")
elif kl in WRAPPER_FIELDS:
continue
else:
all_fields.add(kl)
continue
# tenant_member_balance_overview: 特殊处理
if table_name == "tenant_member_balance_overview":
for k in record.keys():
kl = k.lower()
# 跳过嵌套列表键名本身
if kl in ("rechargecardlist", "givecardlist"):
continue
all_fields.add(kl)
continue
# 通用
for k in record.keys():
kl = k.lower()
if kl in WRAPPER_FIELDS:
if kl in ("siteprofile", "tableprofile"):
all_fields.add(kl)
continue
all_fields.add(kl)
return all_fields
def extract_md_fields(table_name):
"""从 .md 文档的"四、响应字段详解"章节提取字段名(小写)"""
md_path = os.path.join(DOCS_DIR, f"{table_name}.md")
if not os.path.exists(md_path):
return set()
with open(md_path, "r", encoding="utf-8") as f:
lines = f.readlines()
fields = set()
in_section = False
in_siteprofile = False
field_pattern = re.compile(r'^\|\s*`([^`]+)`\s*\|')
siteprofile_header = re.compile(r'^###.*siteProfile', re.IGNORECASE)
for line in lines:
s = line.strip()
if s.startswith("## 四、") and "响应字段" in s:
in_section = True
in_siteprofile = False
continue
if in_section and s.startswith("## ") and not s.startswith("## 四"):
break
if not in_section:
continue
if table_name in ("settlement_records", "recharge_settlements"):
if siteprofile_header.search(s):
in_siteprofile = True
continue
if s.startswith("### ") and in_siteprofile:
if not siteprofile_header.search(s):
in_siteprofile = False
m = field_pattern.match(s)
if m:
raw = m.group(1).strip()
if raw.lower() in {h.lower() for h in CROSS_REF_HEADERS}:
continue
if table_name in ("settlement_records", "recharge_settlements"):
if in_siteprofile:
continue
if raw.startswith("siteProfile."):
continue
if raw.lower() in WRAPPER_FIELDS and raw.lower() not in (
"siteprofile", "tableprofile"):
continue
fields.add(raw.lower())
return fields
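# 匹配格式示例(表格行为假设构造,仅演示 field_pattern 能识别的 .md 写法):
#
#   | `memberName` | string | 张三 | 会员姓名 |
#
# 上述行会被提取为小写字段名 "membername";表头词(字段名/类型/示例值/说明等)与
# WRAPPER_FIELDS 中的包装键不会计入结果。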
def build_body(entry):
body = dict(entry.get("body") or {})
if entry.get("time_range") and entry.get("time_keys"):
keys = entry["time_keys"]
if len(keys) >= 2:
body[keys[0]] = START_TIME
body[keys[1]] = END_TIME
if entry.get("pagination"):
body[entry["pagination"].get("page_key", "page")] = 1
body[entry["pagination"].get("limit_key", "limit")] = LIMIT
return body
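# 组装结果示例(registry 条目为假设构造,仅演示 build_body 的行为):
#
#   entry = {"body": {"siteId": 123}, "time_range": True,
#            "time_keys": ["startTime", "endTime"],
#            "pagination": {"page_key": "page", "limit_key": "limit"}}
#   build_body(entry)
#   # -> {"siteId": 123, "startTime": START_TIME, "endTime": END_TIME, "page": 1, "limit": 100}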
def save_sample(table_name, records):
"""保存第一条记录作为 JSON 样本"""
sample_path = os.path.join(SAMPLES_DIR, f"{table_name}.json")
if records and isinstance(records[0], dict):
with open(sample_path, "w", encoding="utf-8") as f:
json.dump(records[0], f, ensure_ascii=False, indent=2)
return sample_path
def discover_actual_data_path(raw_json, table_name):
"""发现 API 实际返回的数据路径"""
data = raw_json.get("data") if raw_json else None
if data is None:
return None
# 特殊表
if table_name == "stock_goods_category_tree":
return "data.goodsCategoryList"
if table_name == "role_area_association":
return "data.roleAreaRelations"
if table_name == "tenant_member_balance_overview":
return "data" # 顶层汇总对象
if table_name in ("settlement_records", "recharge_settlements"):
return "data.settleList"
if isinstance(data, dict):
list_key = ACTUAL_LIST_KEY.get(table_name)
if list_key and list_key in data:
return f"data.{list_key}"
# fallback
for k, v in data.items():
if isinstance(v, list) and k.lower() != "total":
return f"data.{k}"
return None
def update_md_data_path(table_name, actual_path):
"""在 .md 文档的接口概述表格中更新/添加实际数据路径"""
md_path = os.path.join(DOCS_DIR, f"{table_name}.md")
if not os.path.exists(md_path):
return False
with open(md_path, "r", encoding="utf-8") as f:
content = f.read()
# 检查是否已有"数据路径"或"响应数据路径"行
if "数据路径" in content or "data_path" in content.lower():
# 尝试更新已有行
pattern = re.compile(
r'(\|\s*(?:数据路径|响应数据路径|data_path)\s*\|\s*)`[^`]*`(\s*\|)',
re.IGNORECASE
)
if pattern.search(content):
new_content = pattern.sub(
rf'\g<1>`{actual_path}`\g<2>', content
)
if new_content != content:
with open(md_path, "w", encoding="utf-8") as f:
f.write(new_content)
return True
return False # 已经是最新值
# 没有数据路径行,在接口概述表格末尾添加
# 找到"## 一、接口概述"后的表格最后一行(以 | 开头)
lines = content.split("\n")
insert_idx = None
in_overview = False
last_table_row = None
for i, line in enumerate(lines):
s = line.strip()
if "## 一、" in s and "接口概述" in s:
in_overview = True
continue
if in_overview and s.startswith("## "):
break
if in_overview and s.startswith("|") and "---" not in s:
last_table_row = i
if last_table_row is not None:
new_line = f"| 响应数据路径 | `{actual_path}` |"
lines.insert(last_table_row + 1, new_line)
with open(md_path, "w", encoding="utf-8") as f:
f.write("\n".join(lines))
return True
return False
def main():
registry = load_registry()
print(f"加载 API 注册表: {len(registry)} 个端点")
print(f"时间范围: {START_TIME} ~ {END_TIME}")
print(f"每接口获取: {LIMIT}")
print("=" * 80)
results = []
all_gaps = []
registry_updates = {} # table_name -> actual_data_path
for entry in registry:
table_name = entry["id"]
name_zh = entry.get("name_zh", "")
module = entry["module"]
action = entry["action"]
skip = entry.get("skip", False)
print(f"\n{'' * 60}")
print(f"[{table_name}] {name_zh}{module}/{action}")
if skip:
print(" ⏭️ 跳过(标记为 skip")
results.append({
"table": table_name,
"status": "skipped",
"record_count": 0,
"json_field_count": 0,
"md_field_count": 0,
"json_fields": [],
"md_fields": [],
"json_only": [],
"md_only": [],
"actual_data_path": None,
})
continue
body = build_body(entry)
print(f" 请求: POST {module}/{action}")
raw = call_api(module, action, body)
if raw is None:
results.append({
"table": table_name,
"status": "error",
"record_count": 0,
"json_field_count": 0,
"md_field_count": 0,
"json_fields": [],
"md_fields": [],
"json_only": [],
"md_only": [],
"actual_data_path": None,
})
continue
# 发现实际数据路径
actual_path = discover_actual_data_path(raw, table_name)
old_path = entry.get("data_path", "")
if actual_path and actual_path != old_path:
print(f" 📍 数据路径: {old_path}{actual_path}")
registry_updates[table_name] = actual_path
else:
print(f" 📍 数据路径: {actual_path or old_path}")
records = unwrap_records(raw, table_name)
print(f" 获取记录数: {len(records)}")
# 保存样本(第一条)
save_sample(table_name, records)
# 遍历所有记录提取全字段
json_fields = extract_all_fields(records, table_name)
md_fields = extract_md_fields(table_name)
json_only = json_fields - md_fields
md_only = md_fields - json_fields
status = "ok"
if json_only:
status = "gap"
print(f" ❌ JSON 有但 .md 缺失 ({len(json_only)} 个): {sorted(json_only)}")
all_gaps.append((table_name, name_zh, sorted(json_only)))
else:
if md_only:
print(f" ⚠️ .md 多 {len(md_only)} 个条件性字段")
else:
print(f" ✅ 完全一致 ({len(json_fields)} 个字段)")
# 更新 .md 文档中的数据路径
if actual_path:
updated = update_md_data_path(table_name, actual_path)
if updated:
print(f" 📝 已更新 .md 文档数据路径")
results.append({
"table": table_name,
"status": status,
"record_count": len(records),
"json_field_count": len(json_fields),
"md_field_count": len(md_fields),
"json_fields": sorted(json_fields),
"md_fields": sorted(md_fields),
"json_only": sorted(json_only),
"md_only": sorted(md_only),
"actual_data_path": actual_path,
})
time.sleep(0.3)
# ── 更新 api_registry.json 中的 data_path ──
if registry_updates:
print(f"\n{'' * 60}")
print(f"更新 api_registry.json 中 {len(registry_updates)} 个 data_path...")
for entry in registry:
tid = entry["id"]
if tid in registry_updates:
entry["data_path"] = registry_updates[tid]
with open(REGISTRY_PATH, "w", encoding="utf-8") as f:
json.dump(registry, f, ensure_ascii=False, indent=2)
print(" ✅ api_registry.json 已更新")
# ── 汇总 ──
print(f"\n{'=' * 80}")
print("汇总报告")
print(f"{'=' * 80}")
gap_count = sum(1 for r in results if r["status"] == "gap")
ok_count = sum(1 for r in results if r["status"] == "ok")
skip_count = sum(1 for r in results if r["status"] == "skipped")
err_count = sum(1 for r in results if r["status"] == "error")
print(f" 完全一致: {ok_count}")
print(f" 有缺失: {gap_count}")
print(f" 跳过: {skip_count}")
print(f" 错误: {err_count}")
if all_gaps:
print(f"\n需要补充到 .md 文档的字段:")
for table, name_zh, fields in all_gaps:
print(f" {table} ({name_zh}): {fields}")
# 保存详细结果
out_path = os.path.join(REPORT_DIR, "json_refresh_audit.json")
os.makedirs(REPORT_DIR, exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"\n详细结果已写入: {out_path}")
if __name__ == "__main__":
main()
# AI_CHANGELOG:
# - 日期: 2026-02-14
# - Prompt: P20260214-060000 — 全量 JSON 刷新 + MD 文档补全 + 数据路径修正
# - 直接原因: 旧 JSON 样本仅含单条记录,缺少条件性字段;需重新获取 100 条数据并遍历提取最全字段
# - 变更摘要: 新建脚本,实现:(1) 调用全部 24 个 API 端点获取 100 条数据 (2) 遍历所有记录提取字段并集
# (3) 与 .md 文档比对找出缺失字段 (4) 更新 JSON 样本和 api_registry.json data_path (5) 更新 .md 文档响应数据路径行
# - 风险与验证: 脚本需要有效的 API_TOKEN 和网络连接;验证:运行后检查 json_refresh_audit.json 中 24/24 通过

View File

@@ -0,0 +1,717 @@
# -*- coding: utf-8 -*-
"""
补全丢失的 ODS 数据
通过运行数据完整性检查,找出 API 与 ODS 之间的差异,
然后重新从 API 获取丢失的数据并插入 ODS。
用法:
python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19
python -m scripts.backfill_missing_data --from-report reports/ods_gap_check_xxx.json
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from psycopg2.extras import Json, execute_values
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.recording_client import build_recording_client
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
get_value_case_insensitive,
merge_record_layers,
normalize_pk_value,
pk_tuple_from_record,
)
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
raw = (value or "").strip()
if not raw:
raise ValueError("empty datetime")
has_time = any(ch in raw for ch in (":", "T"))
dt = dtparser.parse(raw)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=tz)
else:
dt = dt.astimezone(tz)
if not has_time:
dt = dt.replace(
hour=23 if is_end else 0,
minute=59 if is_end else 0,
second=59 if is_end else 0,
microsecond=0
)
return dt
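# 行为示例(基于上面的实现,时区以 Asia/Shanghai 为例):
#   _parse_dt("2025-07-01", tz)               # -> 2025-07-01 00:00:00+08:00
#   _parse_dt("2025-07-01", tz, is_end=True)  # -> 2025-07-01 23:59:59+08:00
#   _parse_dt("2025-07-01 12:30", tz)         # 含时间串则不做起止补齐,仅统一时区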
def _get_spec(code: str) -> Optional[OdsTaskSpec]:
"""根据任务代码获取 ODS 任务规格"""
for spec in ODS_TASK_SPECS:
if spec.code == code:
return spec
return None
def _merge_record_layers(record: dict) -> dict:
"""Flatten nested data layers into a single dict."""
return merge_record_layers(record)
def _get_value_case_insensitive(record: dict | None, col: str | None):
"""Fetch value without case sensitivity."""
return get_value_case_insensitive(record, col)
def _normalize_pk_value(value):
"""Normalize PK value."""
return normalize_pk_value(value)
def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
"""Extract PK tuple from record."""
return pk_tuple_from_record(record, pk_cols)
def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
"""获取表的主键列"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [r[0] for r in cur.fetchall()]
if include_content_hash:
return cols
return [c for c in cols if c.lower() != "content_hash"]
def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
"""获取表的所有列信息"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
def _fetch_existing_pk_set(
conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
"""获取已存在的 PK 集合"""
if not pk_values:
return set()
select_cols = ", ".join(f't."{c}"' for c in pk_cols)
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
sql = (
f"SELECT {select_cols} FROM {table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: Set[Tuple] = set()
with conn.cursor() as cur:
for i in range(0, len(pk_values), chunk_size):
chunk = pk_values[i:i + chunk_size]
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
return existing
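# 生成 SQL 示例(以 pk_cols=["id", "tenant_id"]、表名 billiards_ods.some_table 为例,
# 均为假设取值;VALUES 占位符由 psycopg2 的 execute_values 按分块展开):
#
#   SELECT t."id", t."tenant_id" FROM billiards_ods.some_table t
#   JOIN (VALUES %s) AS v("id", "tenant_id")
#     ON t."id" = v."id" AND t."tenant_id" = v."tenant_id"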
def _cast_value(value, data_type: str):
"""类型转换"""
if value is None:
return None
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, (str, datetime)) else None
return value
def _normalize_scalar(value):
"""规范化标量值"""
if value == "" or value == "{}" or value == "[]":
return None
return value
class MissingDataBackfiller:
"""丢失数据补全器"""
def __init__(
self,
cfg: AppConfig,
logger: logging.Logger,
dry_run: bool = False,
):
self.cfg = cfg
self.logger = logger
self.dry_run = dry_run
self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Shanghai"))
self.store_id = int(cfg.get("app.store_id") or 0)
# API 客户端
self.api = build_recording_client(cfg, task_code="BACKFILL_MISSING_DATA")
# 数据库连接:DatabaseConnection 构造时已设置 autocommit=False
self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
def close(self):
"""关闭连接"""
if self.db:
self.db.close()
def _ensure_db(self):
"""确保数据库连接可用"""
if self.db and getattr(self.db, "conn", None) is not None:
if getattr(self.db.conn, "closed", 0) == 0:
return
self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))
def backfill_from_gap_check(
self,
*,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
include_mismatch: bool = False,
page_size: int = 200,
chunk_size: int = 500,
content_sample_limit: int | None = None,
) -> Dict[str, Any]:
"""
运行 gap check 并补全丢失数据
Returns:
补全结果统计
"""
self.logger.info("数据补全开始 起始=%s 结束=%s", start.isoformat(), end.isoformat())
# 计算窗口大小
total_seconds = max(0, int((end - start).total_seconds()))
if total_seconds >= 86400:
window_days = max(1, total_seconds // 86400)
window_hours = 0
else:
window_days = 0
window_hours = max(1, total_seconds // 3600 or 1)
# 运行 gap check
self.logger.info("正在执行缺失检查...")
gap_result = run_gap_check(
cfg=self.cfg,
start=start,
end=end,
window_days=window_days,
window_hours=window_hours,
page_size=page_size,
chunk_size=chunk_size,
sample_limit=10000,  # 样本上限取较大值,尽量覆盖全部丢失记录
sleep_per_window=0,
sleep_per_page=0,
task_codes=task_codes or "",
from_cutoff=False,
cutoff_overlap_hours=24,
allow_small_window=True,
logger=self.logger,
compare_content=include_mismatch,
content_sample_limit=content_sample_limit or 10000,
)
total_missing = gap_result.get("total_missing", 0)
total_mismatch = gap_result.get("total_mismatch", 0)
if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
self.logger.info("Data complete: no missing/mismatch records")
return {"backfilled": 0, "errors": 0, "details": []}
if include_mismatch:
self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
else:
self.logger.info("Missing check done missing=%s", total_missing)
results = []
total_backfilled = 0
total_errors = 0
for task_result in gap_result.get("results", []):
task_code = task_result.get("task_code")
missing = task_result.get("missing", 0)
missing_samples = task_result.get("missing_samples", [])
mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
target_samples = list(missing_samples) + list(mismatch_samples)
if missing == 0 and mismatch == 0:
continue
self.logger.info(
"Start backfill task task=%s missing=%s mismatch=%s samples=%s",
task_code, missing, mismatch, len(target_samples)
)
try:
backfilled = self._backfill_task(
task_code=task_code,
table=task_result.get("table"),
pk_columns=task_result.get("pk_columns", []),
pk_samples=target_samples,
start=start,
end=end,
page_size=page_size,
chunk_size=chunk_size,
)
results.append({
"task_code": task_code,
"missing": missing,
"mismatch": mismatch,
"backfilled": backfilled,
"error": None,
})
total_backfilled += backfilled
except Exception as exc:
self.logger.exception("补全失败 任务=%s", task_code)
results.append({
"task_code": task_code,
"missing": missing,
"mismatch": mismatch,
"backfilled": 0,
"error": str(exc),
})
total_errors += 1
self.logger.info(
"数据补全完成 总缺失=%s 已补全=%s 错误数=%s",
total_missing, total_backfilled, total_errors
)
return {
"total_missing": total_missing,
"total_mismatch": total_mismatch,
"backfilled": total_backfilled,
"errors": total_errors,
"details": results,
}
def _backfill_task(
self,
*,
task_code: str,
table: str,
pk_columns: List[str],
pk_samples: List[Dict],
start: datetime,
end: datetime,
page_size: int,
chunk_size: int,
) -> int:
"""补全单个任务的丢失数据"""
self._ensure_db()
spec = _get_spec(task_code)
if not spec:
self.logger.warning("未找到任务规格 任务=%s", task_code)
return 0
if not pk_columns:
pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)
conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
if not conflict_columns:
conflict_columns = pk_columns
if not pk_columns:
self.logger.warning("未找到主键列 任务=%s 表=%s", task_code, table)
return 0
# 提取丢失的 PK 值
missing_pks: Set[Tuple] = set()
for sample in pk_samples:
pk_tuple = tuple(sample.get(col) for col in pk_columns)
if all(v is not None for v in pk_tuple):
missing_pks.add(pk_tuple)
if not missing_pks:
self.logger.info("无缺失主键 任务=%s", task_code)
return 0
self.logger.info(
"开始获取数据 任务=%s 缺失主键数=%s",
task_code, len(missing_pks)
)
# 从 API 获取数据并过滤出丢失的记录
params = self._build_params(spec, start, end)
backfilled = 0
cols_info = _get_table_columns(self.db.conn, table)
db_json_cols_lower = {
c[0].lower() for c in cols_info
if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
col_names = [c[0] for c in cols_info]
# 结束只读事务,避免长时间 API 拉取导致 idle_in_tx 超时
try:
self.db.conn.commit()
except Exception:
self.db.conn.rollback()
try:
for page_no, records, _, response_payload in self.api.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
# 过滤出丢失的记录
records_to_insert = []
for rec in records:
if not isinstance(rec, dict):
continue
pk_tuple = _pk_tuple_from_record(rec, pk_columns)
if pk_tuple and pk_tuple in missing_pks:
records_to_insert.append(rec)
if not records_to_insert:
continue
# 插入丢失的记录
if self.dry_run:
backfilled += len(records_to_insert)
self.logger.info(
"模拟运行 任务=%s 页=%s 将插入=%s",
task_code, page_no, len(records_to_insert)
)
else:
inserted = self._insert_records(
table=table,
records=records_to_insert,
cols_info=cols_info,
pk_columns=pk_columns,
conflict_columns=conflict_columns,
db_json_cols_lower=db_json_cols_lower,
)
backfilled += inserted
# 避免长事务阻塞与 idle_in_tx 超时
self.db.conn.commit()
self.logger.info(
"已插入 任务=%s 页=%s 数量=%s",
task_code, page_no, inserted
)
if not self.dry_run:
self.db.conn.commit()
self.logger.info("任务补全完成 任务=%s 已补全=%s", task_code, backfilled)
return backfilled
except Exception:
self.db.conn.rollback()
raise
def _build_params(
self,
spec: OdsTaskSpec,
start: datetime,
end: datetime,
) -> Dict:
"""构建 API 请求参数"""
base: Dict[str, Any] = {}
if spec.include_site_id:
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
base["siteId"] = [self.store_id]
else:
base["siteId"] = self.store_id
if spec.requires_window and spec.time_fields:
start_key, end_key = spec.time_fields
base[start_key] = TypeParser.format_timestamp(start, self.tz)
base[end_key] = TypeParser.format_timestamp(end, self.tz)
# 合并公共参数
common = self.cfg.get("api.params", {}) or {}
if isinstance(common, dict):
merged = {**common, **base}
else:
merged = base
merged.update(spec.extra_params or {})
return merged
def _insert_records(
self,
*,
table: str,
records: List[Dict],
cols_info: List[Tuple[str, str, str]],
pk_columns: List[str],
conflict_columns: List[str],
db_json_cols_lower: Set[str],
) -> int:
"""插入记录到数据库"""
if not records:
return 0
col_names = [c[0] for c in cols_info]
needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
quoted_cols = ", ".join(f'"{c}"' for c in col_names)
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
conflict_cols = conflict_columns or pk_columns
if conflict_cols:
pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
now = datetime.now(self.tz)
json_dump = lambda v: json.dumps(v, ensure_ascii=False)
params: List[Tuple] = []
for rec in records:
merged_rec = _merge_record_layers(rec)
# 检查 PK
if pk_columns:
missing_pk = False
for pk in pk_columns:
if str(pk).lower() == "content_hash":
continue
pk_val = _get_value_case_insensitive(merged_rec, pk)
if pk_val is None or pk_val == "":
missing_pk = True
break
if missing_pk:
continue
content_hash = None
if needs_content_hash:
content_hash = BaseOdsTask._compute_content_hash(
merged_rec, include_fetched_at=False
)
row_vals: List[Any] = []
for (col_name, data_type, _udt) in cols_info:
col_lower = col_name.lower()
if col_lower == "payload":
row_vals.append(Json(rec, dumps=json_dump))
continue
if col_lower == "source_file":
row_vals.append("backfill")
continue
if col_lower == "source_endpoint":
row_vals.append("backfill")
continue
if col_lower == "fetched_at":
row_vals.append(now)
continue
if col_lower == "content_hash":
row_vals.append(content_hash)
continue
value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
if col_lower in db_json_cols_lower:
row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
continue
row_vals.append(_cast_value(value, data_type))
params.append(tuple(row_vals))
if not params:
return 0
inserted = 0
with self.db.conn.cursor() as cur:
for i in range(0, len(params), 200):
chunk = params[i:i + 200]
execute_values(cur, sql, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
inserted += int(cur.rowcount)
return inserted
def run_backfill(
*,
cfg: AppConfig,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
include_mismatch: bool = False,
dry_run: bool = False,
page_size: int = 200,
chunk_size: int = 500,
content_sample_limit: int | None = None,
logger: logging.Logger,
) -> Dict[str, Any]:
"""
运行数据补全
Args:
cfg: 应用配置
start: 开始时间
end: 结束时间
task_codes: 指定任务代码(逗号分隔)
dry_run: 是否仅预览
page_size: API 分页大小
chunk_size: 数据库批量大小
logger: 日志记录器
Returns:
补全结果
"""
backfiller = MissingDataBackfiller(cfg, logger, dry_run)
try:
return backfiller.backfill_from_gap_check(
start=start,
end=end,
task_codes=task_codes,
include_mismatch=include_mismatch,
page_size=page_size,
chunk_size=chunk_size,
content_sample_limit=content_sample_limit,
)
finally:
backfiller.close()
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="补全丢失的 ODS 数据")
ap.add_argument("--start", default="2025-07-01", help="开始日期 (默认: 2025-07-01)")
ap.add_argument("--end", default="", help="结束日期 (默认: 当前时间)")
ap.add_argument("--task-codes", default="", help="指定任务代码(逗号分隔,留空=全部)")
ap.add_argument("--include-mismatch", action="store_true", help="同时补全内容不一致的记录")
ap.add_argument("--content-sample-limit", type=int, default=None, help="不一致样本上限 (默认: 10000)")
ap.add_argument("--dry-run", action="store_true", help="仅预览,不实际写入")
ap.add_argument("--page-size", type=int, default=200, help="API 分页大小 (默认: 200)")
ap.add_argument("--chunk-size", type=int, default=500, help="数据库批量大小 (默认: 500)")
ap.add_argument("--log-file", default="", help="日志文件路径")
ap.add_argument("--log-dir", default="", help="日志目录")
ap.add_argument("--log-level", default="INFO", help="日志级别 (默认: INFO)")
ap.add_argument("--no-log-console", action="store_true", help="禁用控制台日志")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
log_console = not args.no_log_console
with configure_logging(
"backfill_missing",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Shanghai"))
start = _parse_dt(args.start, tz)
end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)
result = run_backfill(
cfg=cfg,
start=start,
end=end,
task_codes=args.task_codes or None,
include_mismatch=args.include_mismatch,
dry_run=args.dry_run,
page_size=args.page_size,
chunk_size=args.chunk_size,
content_sample_limit=args.content_sample_limit,
logger=logger,
)
logger.info("=" * 60)
logger.info("补全完成!")
logger.info(" 总丢失: %s", result.get("total_missing", 0))
if args.include_mismatch:
logger.info(" 总不一致: %s", result.get("total_mismatch", 0))
logger.info(" 已补全: %s", result.get("backfilled", 0))
logger.info(" 错误数: %s", result.get("errors", 0))
logger.info("=" * 60)
# 输出详细结果
for detail in result.get("details", []):
if detail.get("error"):
logger.error(
" %s: 丢失=%s 不一致=%s 补全=%s 错误=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("mismatch", 0),
detail.get("backfilled"),
detail.get("error"),
)
elif detail.get("backfilled", 0) > 0:
logger.info(
" %s: 丢失=%s 不一致=%s 补全=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("mismatch", 0),
detail.get("backfilled"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-
"""
Deduplicate ODS snapshots by (business PK, content_hash).
Keep the latest row by fetched_at (tie-breaker: ctid desc).
Usage:
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --schema billiards_ods
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Iterable, Sequence
import psycopg2
PROJECT_ROOT = Path(__file__).resolve().parents[2]  # scripts/repair/ 上两级为仓库根目录
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _quote_ident(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
return [r[0] for r in cur.fetchall()]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
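# ODS 表主键通常包含 content_hash(用于保存同一业务键的多版本快照);
# 此处剔除 content_hash 得到业务主键,后续再与 content_hash 组合作为去重分组键。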
return [c for c in cols if c.lower() != "content_hash"]
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_snapshot_dedupe_{ts}.json"
def _print_progress(
table_label: str,
deleted: int,
total: int,
errors: int,
) -> None:
if total:
msg = f"[{table_label}] deleted {deleted}/{total} errors={errors}"
else:
msg = f"[{table_label}] deleted {deleted} errors={errors}"
print(msg, flush=True)
def _count_duplicates(conn, schema: str, table: str, key_cols: Sequence[str]) -> int:
keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
sql = f"""
SELECT COUNT(*) FROM (
SELECT 1
FROM (
SELECT ROW_NUMBER() OVER (
PARTITION BY {keys_sql}
ORDER BY fetched_at DESC NULLS LAST, ctid DESC
) AS rn
FROM {table_sql}
) t
WHERE rn > 1
) s
"""
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _delete_duplicate_batch(
conn,
schema: str,
table: str,
key_cols: Sequence[str],
batch_size: int,
) -> int:
keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
sql = f"""
WITH dupes AS (
SELECT ctid
FROM (
SELECT ctid,
ROW_NUMBER() OVER (
PARTITION BY {keys_sql}
ORDER BY fetched_at DESC NULLS LAST, ctid DESC
) AS rn
FROM {table_sql}
) s
WHERE rn > 1
LIMIT %s
)
DELETE FROM {table_sql} t
USING dupes d
WHERE t.ctid = d.ctid
RETURNING 1
"""
with conn.cursor() as cur:
cur.execute(sql, (int(batch_size),))
rows = cur.fetchall()
return len(rows or [])
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="Deduplicate ODS snapshot rows by PK+content_hash")
ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
ap.add_argument("--batch-size", type=int, default=1000, help="delete batch size")
ap.add_argument("--progress-every", type=int, default=100, help="print progress every N deletions")
ap.add_argument("--out", default="", help="output report JSON path")
ap.add_argument("--dry-run", action="store_true", help="only compute duplicate counts")
args = ap.parse_args()
cfg = AppConfig.load({})
db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
try:
db.conn.rollback()
except Exception:
pass
db.conn.autocommit = True
tables = _fetch_tables(db.conn, args.schema)
if args.tables.strip():
whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
tables = [t for t in tables if t in whitelist]
report = {
"schema": args.schema,
"tables": [],
"summary": {
"total_tables": len(tables),
"checked_tables": 0,
"total_duplicates": 0,
"deleted_rows": 0,
"error_rows": 0,
"skipped_tables": 0,
},
}
for table in tables:
table_label = f"{args.schema}.{table}"
cols = _fetch_columns(db.conn, args.schema, table)
cols_lower = {c.lower() for c in cols}
if "content_hash" not in cols_lower or "fetched_at" not in cols_lower:
print(f"[{table_label}] skip: missing content_hash/fetched_at", flush=True)
report["summary"]["skipped_tables"] += 1
continue
key_cols = _fetch_pk_columns(db.conn, args.schema, table)
if not key_cols:
print(f"[{table_label}] skip: missing primary key", flush=True)
report["summary"]["skipped_tables"] += 1
continue
total_dupes = _count_duplicates(db.conn, args.schema, table, key_cols)
print(f"[{table_label}] duplicates={total_dupes}", flush=True)
deleted = 0
errors = 0
if not args.dry_run and total_dupes:
while True:
try:
batch_deleted = _delete_duplicate_batch(
db.conn,
args.schema,
table,
key_cols,
args.batch_size,
)
except psycopg2.Error:
errors += 1
break
if batch_deleted <= 0:
break
deleted += batch_deleted
if args.progress_every and deleted % int(args.progress_every) == 0:
_print_progress(table_label, deleted, total_dupes, errors)
if deleted and (not args.progress_every or deleted % int(args.progress_every) != 0):
_print_progress(table_label, deleted, total_dupes, errors)
report["tables"].append(
{
"table": table_label,
"duplicate_rows": total_dupes,
"deleted_rows": deleted,
"error_rows": errors,
}
)
report["summary"]["checked_tables"] += 1
report["summary"]["total_duplicates"] += total_dupes
report["summary"]["deleted_rows"] += deleted
report["summary"]["error_rows"] += errors
out_path = _build_report_path(args.out)
out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"[REPORT] {out_path}", flush=True)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""修复 dim_assistant 表中的 user_id 字段"""
import sys
sys.path.insert(0, '.')
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
config = AppConfig.load()
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)
print("=== 修复 dim_assistant.user_id ===")
# 方案:从 ODS 表更新 DWD 表的 user_id
# 通过 id (ODS) = assistant_id (DWD) 关联
# 1. 先检查当前状态
print("\n修复前:")
sql_before = """
SELECT
COUNT(*) as total,
COUNT(CASE WHEN user_id > 0 THEN 1 END) as has_user_id
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
"""
r = dict(db.query(sql_before)[0])
print(f" 总记录: {r['total']}, 有user_id: {r['has_user_id']}")
# 2. 执行更新
print("\n执行更新...")
update_sql = """
UPDATE billiards_dwd.dim_assistant d
SET user_id = o.user_id
FROM (
SELECT DISTINCT ON (id) id, user_id
FROM billiards_ods.assistant_accounts_master
WHERE user_id > 0
ORDER BY id, fetched_at DESC
) o
WHERE d.assistant_id = o.id
AND (d.user_id IS NULL OR d.user_id = 0)
"""
with db_conn.conn.cursor() as cur:
cur.execute(update_sql)
updated = cur.rowcount
print(f" 更新了 {updated} 条记录")
db_conn.conn.commit()
# 3. 检查修复后状态
print("\n修复后:")
r2 = dict(db.query(sql_before)[0])
print(f" 总记录: {r2['total']}, 有user_id: {r2['has_user_id']}")
# 4. 显示样本数据
print("\n样本数据:")
sql_sample = """
SELECT assistant_id, user_id, assistant_no, nickname
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
ORDER BY assistant_no::int
LIMIT 10
"""
for row in db.query(sql_sample):
r = dict(row)
print(f" assistant_id={r['assistant_id']}, user_id={r['user_id']}, no={r['assistant_no']}, nickname={r['nickname']}")
# 5. 验证与服务日志的关联
print("\n验证与服务日志的关联:")
sql_verify = """
SELECT
COUNT(DISTINCT s.user_id) as service_unique_users,
COUNT(DISTINCT CASE WHEN d.assistant_id IS NOT NULL THEN s.user_id END) as matched_users
FROM billiards_dwd.dwd_assistant_service_log s
LEFT JOIN billiards_dwd.dim_assistant d
ON s.user_id = d.user_id AND d.scd2_is_current = 1
WHERE s.is_delete = 0 AND s.user_id > 0
"""
r3 = dict(db.query(sql_verify)[0])
print(f" 服务日志唯一user_id: {r3['service_unique_users']}")
print(f" 能匹配到dim_assistant: {r3['matched_users']}")
match_rate = r3['matched_users'] / r3['service_unique_users'] * 100 if r3['service_unique_users'] > 0 else 0
print(f" 匹配率: {match_rate:.1f}%")
db_conn.close()
print("\n完成!")

View File

@@ -0,0 +1,302 @@
# -*- coding: utf-8 -*-
"""
Repair ODS content_hash values by recomputing from payload.
Usage:
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --schema billiards_ods
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence
import psycopg2
from psycopg2.extras import RealDictCursor
PROJECT_ROOT = Path(__file__).resolve().parents[2]  # scripts/repair/ 上两级为仓库根目录
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _fetch_row_count(conn, schema: str, table: str) -> int:
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _iter_rows(
conn,
schema: str,
table: str,
select_cols: Sequence[str],
batch_size: int,
) -> Iterable[dict]:
cols_sql = ", ".join("ctid" if c == "ctid" else f'"{c}"' for c in select_cols)
sql = f'SELECT {cols_sql} FROM "{schema}"."{table}"'
with conn.cursor(name=f"ods_hash_fix_{table}", cursor_factory=RealDictCursor) as cur:
cur.itersize = max(1, int(batch_size or 500))
cur.execute(sql)
for row in cur:
yield row
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_content_hash_repair_{ts}.json"
def _print_progress(
table_label: str,
processed: int,
total: int,
updated: int,
skipped: int,
conflicts: int,
errors: int,
missing_hash: int,
invalid_payload: int,
) -> None:
if total:
msg = (
f"[{table_label}] checked {processed}/{total} "
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
else:
msg = (
f"[{table_label}] checked {processed} "
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
print(msg, flush=True)
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="Repair ODS content_hash using payload")
ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
ap.add_argument("--sample-limit", type=int, default=10, help="sample conflicts per table")
ap.add_argument("--out", default="", help="output report JSON path")
ap.add_argument("--dry-run", action="store_true", help="only compute stats, do not update")
args = ap.parse_args()
cfg = AppConfig.load({})
db_read = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
db_write = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
try:
db_write.conn.rollback()
except Exception:
pass
db_write.conn.autocommit = True
tables = _fetch_tables(db_read.conn, args.schema)
if args.tables.strip():
whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
tables = [t for t in tables if t in whitelist]
report = {
"schema": args.schema,
"tables": [],
"summary": {
"total_tables": len(tables),
"checked_tables": 0,
"total_rows": 0,
"checked_rows": 0,
"updated_rows": 0,
"skipped_rows": 0,
"conflict_rows": 0,
"error_rows": 0,
"missing_hash_rows": 0,
"invalid_payload_rows": 0,
},
}
for table in tables:
table_label = f"{args.schema}.{table}"
cols = _fetch_columns(db_read.conn, args.schema, table)
cols_lower = {c.lower() for c in cols}
if "payload" not in cols_lower or "content_hash" not in cols_lower:
print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
continue
total = _fetch_row_count(db_read.conn, args.schema, table)
pk_cols = _fetch_pk_columns(db_read.conn, args.schema, table)
select_cols = ["ctid", "content_hash", "payload", *pk_cols]
processed = 0
updated = 0
skipped = 0
conflicts = 0
errors = 0
missing_hash = 0
invalid_payload = 0
samples: list[dict[str, Any]] = []
print(f"[{table_label}] start: total_rows={total}", flush=True)
for row in _iter_rows(db_read.conn, args.schema, table, select_cols, args.batch_size):
processed += 1
content_hash = row.get("content_hash")
payload = row.get("payload")
recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
row_ctid = row.get("ctid")
if not content_hash:
missing_hash += 1
if not recomputed:
invalid_payload += 1
if not recomputed:
skipped += 1
elif content_hash == recomputed:
skipped += 1
else:
if args.dry_run:
updated += 1
else:
try:
with db_write.conn.cursor() as cur:
cur.execute(
f'UPDATE "{args.schema}"."{table}" SET content_hash = %s WHERE ctid = %s',
(recomputed, row_ctid),
)
updated += 1
except psycopg2.errors.UniqueViolation:
conflicts += 1
if len(samples) < max(0, int(args.sample_limit or 0)):
sample = {k: row.get(k) for k in pk_cols}
sample["content_hash"] = content_hash
sample["recomputed_hash"] = recomputed
samples.append(sample)
except psycopg2.Error:
errors += 1
if args.progress_every and processed % int(args.progress_every) == 0:
_print_progress(
table_label,
processed,
total,
updated,
skipped,
conflicts,
errors,
missing_hash,
invalid_payload,
)
if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
_print_progress(
table_label,
processed,
total,
updated,
skipped,
conflicts,
errors,
missing_hash,
invalid_payload,
)
report["tables"].append(
{
"table": table_label,
"total_rows": total,
"checked_rows": processed,
"updated_rows": updated,
"skipped_rows": skipped,
"conflict_rows": conflicts,
"error_rows": errors,
"missing_hash_rows": missing_hash,
"invalid_payload_rows": invalid_payload,
"conflict_samples": samples,
}
)
report["summary"]["checked_tables"] += 1
report["summary"]["total_rows"] += total
report["summary"]["checked_rows"] += processed
report["summary"]["updated_rows"] += updated
report["summary"]["skipped_rows"] += skipped
report["summary"]["conflict_rows"] += conflicts
report["summary"]["error_rows"] += errors
report["summary"]["missing_hash_rows"] += missing_hash
report["summary"]["invalid_payload_rows"] += invalid_payload
out_path = _build_report_path(args.out)
out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"[REPORT] {out_path}", flush=True)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Create performance indexes for integrity verification and run ANALYZE.
Usage:
python -m scripts.tune_integrity_indexes
python -m scripts.tune_integrity_indexes --dry-run
"""
from __future__ import annotations
import argparse
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Sequence, Set, Tuple
import psycopg2
from psycopg2 import sql
from config.settings import AppConfig
TIME_CANDIDATES = (
"pay_time",
"create_time",
"start_use_time",
"scd2_start_time",
"calc_time",
"order_date",
"fetched_at",
)
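# 说明:billiards_dwd 层会为每个实际存在的候选时间列各建一个单列索引;
# 若表主键列数不超过 3,还会生成 (时间列, 主键列...) 组合索引,详见 _plan_indexes。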
@dataclass(frozen=True)
class IndexPlan:
schema: str
table: str
index_name: str
columns: Tuple[str, ...]
def _short_index_name(table: str, tag: str, columns: Sequence[str]) -> str:
raw = f"idx_{table}_{tag}_{'_'.join(columns)}"
if len(raw) <= 63:
return raw
digest = hashlib.md5(raw.encode("utf-8")).hexdigest()[:8]
shortened = f"idx_{table}_{tag}_{digest}"
return shortened[:63]
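# 截断示例(表名与 digest 均为假设值,仅作说明):
#   _short_index_name("dwd_member_balance_change_detail", "time_pk",
#                     ("pay_time", "order_id", "site_id", "tenant_id"))
#   原始拼接名超过 63 字符时,返回 "idx_dwd_member_balance_change_detail_time_pk_" + md5 前 8 位,
#   并整体截断到 PostgreSQL 标识符上限 63 字符。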
def _load_table_columns(cur, schema: str, table: str) -> Set[str]:
cur.execute(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
""",
(schema, table),
)
return {r[0] for r in cur.fetchall()}
def _load_pk_columns(cur, schema: str, table: str) -> List[str]:
cur.execute(
"""
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
AND tc.table_name = kcu.table_name
WHERE tc.table_schema = %s
AND tc.table_name = %s
AND tc.constraint_type = 'PRIMARY KEY'
ORDER BY kcu.ordinal_position
""",
(schema, table),
)
return [r[0] for r in cur.fetchall()]
def _load_tables(cur, schema: str) -> List[str]:
cur.execute(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s
AND table_type = 'BASE TABLE'
ORDER BY table_name
""",
(schema,),
)
return [r[0] for r in cur.fetchall()]
def _plan_indexes(cur, schema: str, table: str) -> List[IndexPlan]:
plans: List[IndexPlan] = []
cols = _load_table_columns(cur, schema, table)
pk_cols = _load_pk_columns(cur, schema, table)
if schema == "billiards_ods":
if "fetched_at" in cols:
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "fetched_at", ("fetched_at",)),
columns=("fetched_at",),
)
)
if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
comp_cols = ("fetched_at", *pk_cols)
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "fetched_pk", comp_cols),
columns=comp_cols,
)
)
if schema == "billiards_dwd":
if pk_cols and "scd2_is_current" in cols and len(pk_cols) <= 4:
comp_cols = (*pk_cols, "scd2_is_current")
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "pk_current", comp_cols),
columns=comp_cols,
)
)
for tcol in TIME_CANDIDATES:
if tcol in cols:
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "time", (tcol,)),
columns=(tcol,),
)
)
if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
comp_cols = (tcol, *pk_cols)
plans.append(
IndexPlan(
schema=schema,
table=table,
index_name=_short_index_name(table, "time_pk", comp_cols),
columns=comp_cols,
)
)
# 按索引名去重
dedup: Dict[str, IndexPlan] = {}
for p in plans:
dedup[p.index_name] = p
return list(dedup.values())
def _create_index(cur, plan: IndexPlan) -> None:
stmt = sql.SQL("CREATE INDEX IF NOT EXISTS {idx} ON {sch}.{tbl} ({cols})").format(
idx=sql.Identifier(plan.index_name),
sch=sql.Identifier(plan.schema),
tbl=sql.Identifier(plan.table),
cols=sql.SQL(", ").join(sql.Identifier(c) for c in plan.columns),
)
cur.execute(stmt)
def _analyze_table(cur, schema: str, table: str) -> None:
stmt = sql.SQL("ANALYZE {sch}.{tbl}").format(
sch=sql.Identifier(schema),
tbl=sql.Identifier(table),
)
cur.execute(stmt)
def main() -> int:
ap = argparse.ArgumentParser(description="Tune indexes for integrity verification.")
ap.add_argument("--dry-run", action="store_true", help="Print planned SQL only.")
ap.add_argument(
"--skip-analyze",
action="store_true",
help="Create indexes but skip ANALYZE.",
)
args = ap.parse_args()
cfg = AppConfig.load({})
dsn = cfg.get("db.dsn")
timeout_sec = int(cfg.get("db.connect_timeout_sec", 10) or 10)
with psycopg2.connect(dsn, connect_timeout=timeout_sec) as conn:
conn.autocommit = False
with conn.cursor() as cur:
all_plans: List[IndexPlan] = []
for schema in ("billiards_ods", "billiards_dwd"):
for table in _load_tables(cur, schema):
all_plans.extend(_plan_indexes(cur, schema, table))
touched_tables: Set[Tuple[str, str]] = set()
print(f"planned indexes: {len(all_plans)}")
for plan in all_plans:
cols = ", ".join(plan.columns)
print(f"[INDEX] {plan.schema}.{plan.table} ({cols}) -> {plan.index_name}")
if not args.dry_run:
_create_index(cur, plan)
touched_tables.add((plan.schema, plan.table))
if not args.skip_analyze:
if args.dry_run:
for schema, table in sorted({(p.schema, p.table) for p in all_plans}):
print(f"[ANALYZE] {schema}.{table}")
else:
for schema, table in sorted(touched_tables):
_analyze_table(cur, schema, table)
print(f"[ANALYZE] {schema}.{table}")
if args.dry_run:
conn.rollback()
print("dry-run complete; transaction rolled back")
else:
conn.commit()
print("index tuning complete")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""
v3 比对脚本 — 直接从 JSON 样本提取字段,与硬编码的 ODS 列比对。
ODS 列数据来自 information_schema.columns WHERE table_schema = 'billiards_ods'
"""
import json
import os
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "..", "docs", "api-reference", "samples")
REPORT_DIR = os.path.join(os.path.dirname(__file__), "..", "docs", "reports")
ODS_META = {"source_file", "source_endpoint", "fetched_at", "payload", "content_hash"}
NESTED_OBJECTS = {"siteprofile", "tableprofile"}
# 22 张需要比对的表
TABLES = [
"assistant_accounts_master", "settlement_records", "assistant_service_records",
"assistant_cancellation_records", "table_fee_transactions", "table_fee_discount_records",
"payment_transactions", "refund_transactions", "platform_coupon_redemption_records",
"tenant_goods_master", "store_goods_sales_records", "store_goods_master",
"stock_goods_category_tree", "goods_stock_movements", "member_profiles",
"member_stored_value_cards", "recharge_settlements", "member_balance_changes",
"group_buy_packages", "group_buy_redemption_records", "goods_stock_summary",
"site_tables_master",
]
def load_json(table):
path = os.path.join(SAMPLES_DIR, f"{table}.json")
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def extract_fields(table):
data = load_json(table)
# settlement_records / recharge_settlements: 取 settleList 内层
if table in ("settlement_records", "recharge_settlements"):
record = data.get("settleList", {})
if isinstance(record, list):
record = record[0] if record else {}
fields = {k.lower() for k in record.keys()}
# 加上 siteProfile(顶层嵌套对象)
if "siteProfile" in data:
fields.add("siteprofile")
return fields
# stock_goods_category_tree: 取 goodsCategoryList 数组元素
if table == "stock_goods_category_tree":
cat_list = data.get("goodsCategoryList", [])
if cat_list:
return {k.lower() for k in cat_list[0].keys()}
return set()
# 通用:顶层 keys
fields = set()
for k, v in data.items():
kl = k.lower()
if kl in NESTED_OBJECTS:
fields.add(kl) # 嵌套对象作为单列
else:
fields.add(kl)
return fields
def main():
# ODS 列清单由 ods_columns.json 提供(预先从 information_schema.columns 导出),
# 这样脚本无需连接数据库即可独立运行;API 字段则直接取自 JSON 样本。
ods_cols_path = os.path.join(os.path.dirname(__file__), "ods_columns.json")
with open(ods_cols_path, "r", encoding="utf-8") as f:
ods_all = json.load(f)
results = []
for table in TABLES:
api_fields = extract_fields(table)
ods_cols = set(ods_all.get(table, [])) - ODS_META
matched = sorted(api_fields & ods_cols)
api_only = sorted(api_fields - ods_cols)
ods_only = sorted(ods_cols - api_fields)
results.append({
"table": table,
"api_count": len(api_fields),
"ods_count": len(ods_cols),
"matched": len(matched),
"api_only": api_only,
"ods_only": ods_only,
})
status = "✓ 完全对齐" if not api_only and not ods_only else ""
print(f"{table}: API={len(api_fields)} ODS={len(ods_cols)} 匹配={len(matched)} API独有={len(api_only)} ODS独有={len(ods_only)} {status}")
if api_only:
print(f" API独有: {api_only}")
if ods_only:
print(f" ODS独有: {ods_only}")
# 写 JSON 报告
os.makedirs(REPORT_DIR, exist_ok=True)
out = os.path.join(REPORT_DIR, "api_ods_comparison_v3.json")
with open(out, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"\nJSON 报告: {out}")
if __name__ == "__main__":
main()
# ──────────────────────────────────────────────────────────────────
# AI_CHANGELOG:
# - 日期: 2026-02-14
# Prompt: P20260214-000000 — "还是不准。现在拆解任务,所有表,每个表当作一个任务进行比对。"
# 直接原因: v2 比对脚本结果不准确,需从 JSON 样本直接提取字段与数据库实际列精确比对
# 变更摘要: 新建脚本,读取 samples/*.json 提取 API 字段,读取 ods_columns.json 获取 ODS 列,
# 处理 settleList 嵌套/goodsCategoryList 数组/siteProfile 嵌套对象等特殊结构,逐表输出比对结果
# 风险与验证: 纯分析脚本,不修改数据库;验证方式:运行脚本确认输出与 v3 报告一致
# ──────────────────────────────────────────────────────────────────

View File

@@ -0,0 +1,465 @@
# -*- coding: utf-8 -*-
"""
v3-fixed: API 参考文档 (.md) 响应字段详解 vs ODS 实际列 — 精确比对
核心改进(相对 v3):
1. 仅从"四、响应字段详解"章节提取字段(排除请求参数、跨表关联等章节)
2. 对 settlement_records / recharge_settlements 特殊处理:
- settleList 内层字段 → 直接比对 ODS 列
- siteProfile → ODS 中存为 siteprofile jsonb 单列(不展开子字段)
3. 对 table_fee_discount_records / payment_transactions 等含 siteProfile/tableProfile 的表:
- siteProfile/tableProfile 作为嵌套对象 → ODS 中存为 jsonb 单列
4. 对 stock_goods_category_tree,goodsCategoryList/categoryBoxes 是结构包装器,不是业务字段
5. JSON 样本作为补充来源(union)
CHANGE P20260214-003000: 完全重写字段提取逻辑
intent: 精确限定提取范围到"响应字段详解"章节,避免误提取请求参数和跨表关联字段
assumptions: 所有 .md 文档均以"## 四、响应字段详解"开始响应字段章节,以"## 五、"结束
edge cases: settlement_records/recharge_settlements 的 siteProfile 子字段不应与 ODS 列比对
"""
import json
import os
import re
from datetime import datetime
DOCS_DIR = os.path.join(os.path.dirname(__file__), "..", "docs", "api-reference")
SAMPLES_DIR = os.path.join(DOCS_DIR, "samples")
REPORT_DIR = os.path.join(os.path.dirname(__file__), "..", "docs", "reports")
ODS_META = {"source_file", "source_endpoint", "fetched_at", "payload", "content_hash"}
TABLES = [
"assistant_accounts_master", "settlement_records", "assistant_service_records",
"assistant_cancellation_records", "table_fee_transactions", "table_fee_discount_records",
"payment_transactions", "refund_transactions", "platform_coupon_redemption_records",
"tenant_goods_master", "store_goods_sales_records", "store_goods_master",
"stock_goods_category_tree", "goods_stock_movements", "member_profiles",
"member_stored_value_cards", "recharge_settlements", "member_balance_changes",
"group_buy_packages", "group_buy_redemption_records", "goods_stock_summary",
"site_tables_master",
]
# 这些字段在 API JSON 中是嵌套对象,ODS 中存为 jsonb 单列
NESTED_OBJECTS = {"siteprofile", "tableprofile"}
# 这些字段是结构包装器,不是业务字段
# 注意:categoryboxes 虽然是嵌套数组,但 ODS 中确实有 categoryboxes 列(jsonb),所以不排除
WRAPPER_FIELDS = {"goodscategorylist", "total"}
# 跨表关联章节中常见的"本表字段"列标题
CROSS_REF_HEADERS = {"本表字段", "关联表字段", "关联表", "参数", "字段"}
def extract_response_fields_from_md(table_name: str) -> tuple[set[str], list[str]]:
"""
从 API 参考文档中精确提取"响应字段详解"章节的字段名。
返回: (fields_set_lowercase, debug_messages)
提取策略:
- 找到"## 四、响应字段详解"章节
- 在该章节内提取所有 Markdown 表格第一列的反引号字段名
- 遇到"## 五、"或更高级别标题时停止
- 对 settlement_records / recharge_settlements
- siteProfile 子字段(带 siteProfile. 前缀的)→ 不提取(ODS 中存为 siteprofile jsonb)
- settleList 内层字段 → 正常提取
- 对含 siteProfile/tableProfile 的表,这些作为顶层字段名提取(ODS 中是 jsonb 列)
"""
md_path = os.path.join(DOCS_DIR, f"{table_name}.md")
debug = []
if not os.path.exists(md_path):
debug.append(f"[WARN] 文档不存在: {md_path}")
return set(), debug
with open(md_path, "r", encoding="utf-8") as f:
lines = f.readlines()
fields = set()
in_response_section = False
in_siteprofile_subsection = False
field_pattern = re.compile(r'^\|\s*`([^`]+)`\s*\|')
# 用于检测 siteProfile 子章节(如 "### A. siteProfile" 或 "### 4.1 门店信息快照(siteProfile)")
siteprofile_header = re.compile(r'^###.*siteProfile', re.IGNORECASE)
for line in lines:
stripped = line.strip()
# 检测进入"响应字段详解"章节
if stripped.startswith("## 四、") and "响应字段" in stripped:
in_response_section = True
in_siteprofile_subsection = False
continue
# 检测离开(遇到下一个 ## 级别标题)
if in_response_section and stripped.startswith("## ") and not stripped.startswith("## 四"):
break
if not in_response_section:
continue
# 检测 siteProfile 子章节(仅对 settlement_records / recharge_settlements)
if table_name in ("settlement_records", "recharge_settlements"):
if siteprofile_header.search(stripped):
in_siteprofile_subsection = True
continue
# 遇到下一个 ### 标题,退出 siteProfile 子章节
if stripped.startswith("### ") and in_siteprofile_subsection:
if not siteprofile_header.search(stripped):
in_siteprofile_subsection = False
# 提取字段名
m = field_pattern.match(stripped)
if m:
raw_field = m.group(1).strip()
# 跳过表头行
if raw_field in CROSS_REF_HEADERS:
continue
# 对 settlement_records / recharge_settlements:跳过 siteProfile 子字段
if table_name in ("settlement_records", "recharge_settlements"):
if in_siteprofile_subsection:
# siteProfile 子字段不提取(ODS 中存为 siteprofile jsonb)
continue
# 带 siteProfile. 前缀的也跳过
if raw_field.startswith("siteProfile."):
continue
# 跳过结构包装器字段
if raw_field.lower() in WRAPPER_FIELDS:
continue
fields.add(raw_field.lower())
debug.append(f"从 .md 提取 {len(fields)} 个响应字段")
return fields, debug
def extract_fields_from_json(table_name: str) -> tuple[set[str], list[str]]:
"""从 JSON 样本提取字段(作为补充)"""
path = os.path.join(SAMPLES_DIR, f"{table_name}.json")
debug = []
if not os.path.exists(path):
debug.append("[INFO] 无 JSON 样本")
return set(), debug
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
# settlement_records / recharge_settlements: 提取 settleList 内层字段
if table_name in ("settlement_records", "recharge_settlements"):
settle = data.get("settleList", {})
if isinstance(settle, list):
settle = settle[0] if settle else {}
fields = {k.lower() for k in settle.keys()}
# siteProfile 作为整体(ODS 中不存 siteProfile 的子字段,但可能有 siteprofile jsonb 列)
# 不添加 siteProfile 的子字段
debug.append(f"从 JSON settleList 提取 {len(fields)} 个字段")
return fields, debug
# stock_goods_category_tree: 提取 goodsCategoryList 内层字段
if table_name == "stock_goods_category_tree":
cat_list = data.get("goodsCategoryList", [])
if cat_list:
fields = set()
for k in cat_list[0].keys():
kl = k.lower()
if kl not in WRAPPER_FIELDS:
fields.add(kl)
debug.append(f"从 JSON goodsCategoryList 提取 {len(fields)} 个字段")
return fields, debug
return set(), debug
# 通用:提取顶层字段
fields = set()
for k in data.keys():
kl = k.lower()
# siteProfile/tableProfile 作为整体保留(ODS 中是 jsonb 列)
if kl in NESTED_OBJECTS:
fields.add(kl)
elif kl not in WRAPPER_FIELDS:
fields.add(kl)
debug.append(f"从 JSON 提取 {len(fields)} 个字段")
return fields, debug
def classify_ods_only(table_name: str, field: str) -> str:
"""对 ODS 独有字段进行分类说明"""
# table_fee_discount_records 的展开字段
if table_name == "table_fee_discount_records" and field in (
"area_type_id", "charge_free", "site_table_area_id", "site_table_area_name",
"sitename", "table_name", "table_price", "tenant_name"
):
return "从 tableProfile/siteProfile 嵌套对象展开的字段"
# site_tables_master 的 order_id
if table_name == "site_tables_master" and field == "order_id":
return "ODS 后续版本新增字段(当前使用中的台桌关联订单 ID"
# tenant_id 在某些表中是 ODS 额外添加的
if field == "tenant_id" and table_name in (
"assistant_cancellation_records", "payment_transactions"
):
return "ODS 额外添加的租户 ID 字段API 响应中不含ETL 入库时补充)"
# API 后续版本新增字段(文档快照未覆盖)
api_version_fields = {
"assistant_service_records": {
"assistantteamname": "API 后续版本新增(助教团队名称)",
"real_service_money": "API 后续版本新增(实际服务金额)",
},
"table_fee_transactions": {
"activity_discount_amount": "API 后续版本新增(活动折扣金额)",
"order_consumption_type": "API 后续版本新增(订单消费类型)",
"real_service_money": "API 后续版本新增(实际服务金额)",
},
"tenant_goods_master": {
"not_sale": "API 后续版本新增(是否禁售标记)",
},
"store_goods_sales_records": {
"coupon_share_money": "API 后续版本新增(优惠券分摊金额)",
},
"store_goods_master": {
"commodity_code": "API 后续版本新增(商品编码)",
"not_sale": "API 后续版本新增(是否禁售标记)",
},
"member_profiles": {
"pay_money_sum": "API 后续版本新增(累计消费金额)",
"person_tenant_org_id": "API 后续版本新增(人事组织 ID",
"person_tenant_org_name": "API 后续版本新增(人事组织名称)",
"recharge_money_sum": "API 后续版本新增(累计充值金额)",
"register_source": "API 后续版本新增(注册来源)",
},
"member_stored_value_cards": {
"able_share_member_discount": "API 后续版本新增(是否共享会员折扣)",
"electricity_deduct_radio": "API 后续版本新增(电费抵扣比例)",
"electricity_discount": "API 后续版本新增(电费折扣)",
"electricitycarddeduct": "API 后续版本新增(电费卡扣金额)",
"member_grade": "API 后续版本新增(会员等级)",
"principal_balance": "API 后续版本新增(本金余额)",
"rechargefreezebalance": "API 后续版本新增(充值冻结余额)",
},
"member_balance_changes": {
"principal_after": "API 后续版本新增(变动后本金)",
"principal_before": "API 后续版本新增(变动前本金)",
"principal_data": "API 后续版本新增(本金明细数据)",
},
"group_buy_packages": {
"is_first_limit": "API 后续版本新增(是否限首单)",
"sort": "API 后续版本新增(排序序号)",
"tenantcouponsaleorderitemid": "API 后续版本新增(租户券销售订单项 ID",
},
"group_buy_redemption_records": {
"assistant_service_share_money": "API 后续版本新增(助教服务分摊金额)",
"assistant_share_money": "API 后续版本新增(助教分摊金额)",
"coupon_sale_id": "API 后续版本新增(券销售 ID",
"good_service_share_money": "API 后续版本新增(商品服务分摊金额)",
"goods_share_money": "API 后续版本新增(商品分摊金额)",
"member_discount_money": "API 后续版本新增(会员折扣金额)",
"recharge_share_money": "API 后续版本新增(充值分摊金额)",
"table_service_share_money": "API 后续版本新增(台费服务分摊金额)",
"table_share_money": "API 后续版本新增(台费分摊金额)",
},
}
table_fields = api_version_fields.get(table_name, {})
if field in table_fields:
return table_fields[field]
return "ODS 独有(待确认来源)"
def main():
ods_cols_path = os.path.join(os.path.dirname(__file__), "ods_columns.json")
with open(ods_cols_path, "r", encoding="utf-8") as f:
ods_all = json.load(f)
results = []
total_api_only = 0
total_ods_only = 0
all_debug = {}
for table in TABLES:
debug_lines = [f"\n{'='*60}", f"表: {table}", f"{'='*60}"]
# 从文档提取字段(主要来源)
md_fields, md_debug = extract_response_fields_from_md(table)
debug_lines.extend(md_debug)
# 从 JSON 样本提取字段(补充)
json_fields, json_debug = extract_fields_from_json(table)
debug_lines.extend(json_debug)
# 合并:文档字段 JSON 样本字段
api_fields = md_fields | json_fields
# 特殊处理:settlement_records / recharge_settlements
# ODS 中有 siteprofile 列但不展开子字段;也有 settlelist jsonb 列
# API 文档中 siteProfile 子字段已被排除,但需要确保 siteprofile 作为整体列被考虑
if table in ("settlement_records", "recharge_settlements"):
# 不把 siteprofile 加入 api_fields(因为 ODS 中 siteprofile 不是从 API 直接映射的列名)
# settlelist 也是 ODS 的 jsonb 列,不在 API 字段中
pass
# 特殊处理:含 siteProfile/tableProfile 的表
# 这些在 API 中是嵌套对象,ODS 中存为 jsonb 列
# 确保 api_fields 中包含 siteprofile/tableprofile如果 ODS 有这些列)
ods_cols = set(ods_all.get(table, [])) - ODS_META
ods_cols_lower = set()
ods_case_map = {}
for c in ods_cols:
cl = c.lower()
ods_cols_lower.add(cl)
ods_case_map[cl] = c
# 如果 ODS 有 siteprofile/tableprofile 列,且 API 文档中有 siteProfile/tableProfile 字段
for nested in NESTED_OBJECTS:
if nested in ods_cols_lower and nested not in api_fields:
# 检查 API 文档/JSON 中是否有这个嵌套对象
# 对于 settlement_records/recharge_settlements,siteProfile 确实存在于 API 响应中
# 对于 payment_transactions 等,siteProfile 也存在
api_fields.add(nested)
debug_lines.append(f" 补充嵌套对象字段: {nested}")
matched = sorted(api_fields & ods_cols_lower)
api_only = sorted(api_fields - ods_cols_lower)
ods_only = sorted(ods_cols_lower - api_fields)
# 对 ODS 独有字段分类
ods_only_classified = []
for f in ods_only:
reason = classify_ods_only(table, f)
ods_only_classified.append({"field": f, "ods_original": ods_case_map.get(f, f), "reason": reason})
total_api_only += len(api_only)
total_ods_only += len(ods_only)
result = {
"table": table,
"api_count": len(api_fields),
"ods_count": len(ods_cols_lower),
"matched": len(matched),
"matched_fields": matched,
"api_only": api_only,
"ods_only": ods_only_classified,
"api_only_count": len(api_only),
"ods_only_count": len(ods_only),
"md_fields_count": len(md_fields),
"json_fields_count": len(json_fields),
}
results.append(result)
status = "✓ 完全对齐" if not api_only and not ods_only else ""
print(f"{table}: API={len(api_fields)}(md={len(md_fields)},json={len(json_fields)}) "
f"ODS={len(ods_cols_lower)} 匹配={len(matched)} "
f"API独有={len(api_only)} ODS独有={len(ods_only)} {status}")
if api_only:
print(f" API独有: {api_only}")
if ods_only:
for item in ods_only_classified:
print(f" ODS独有: {item['ods_original']}{item['reason']}")
all_debug[table] = debug_lines
print(f"\n{'='*60}")
print(f"总计: API独有={total_api_only}, ODS独有={total_ods_only}")
print(f"{'='*60}")
# 写 JSON 报告
os.makedirs(REPORT_DIR, exist_ok=True)
json_out = os.path.join(REPORT_DIR, "api_ods_comparison_v3_fixed.json")
with open(json_out, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"\nJSON 报告: {json_out}")
# 写 Markdown 报告
md_out = os.path.join(REPORT_DIR, "api_ods_comparison_v3_fixed.md")
write_md_report(results, md_out, total_api_only, total_ods_only)
print(f"MD 报告: {md_out}")
def write_md_report(results, path, total_api_only, total_ods_only):
now = datetime.now().strftime("%Y-%m-%d %H:%M")
lines = [
f"# API 响应字段 vs ODS 表结构比对报告v3-fixed",
f"",
f"> 生成时间:{now}Asia/Shanghai",
f"> 数据来源API 参考文档docs/api-reference/*.md+ JSON 样本 + PostgreSQL information_schema",
f'> 比对方法:从文档"响应字段详解"章节精确提取字段,与 ODS 实际列比对(排除 meta 列)',
f"",
f"## 汇总",
f"",
f"| 指标 | 值 |",
f"|------|-----|",
f"| 比对表数 | {len(results)} |",
f"| API 独有字段总数 | {total_api_only} |",
f"| ODS 独有字段总数 | {total_ods_only} |",
f"| 完全对齐表数 | {sum(1 for r in results if r['api_only_count'] == 0 and r['ods_only_count'] == 0)} |",
f"",
f"## 逐表比对",
f"",
]
for r in results:
status = "✅ 完全对齐" if r["api_only_count"] == 0 and r["ods_only_count"] == 0 else "⚠️ 有差异"
lines.append(f"### {r['table']}{status}")
lines.append(f"")
lines.append(f"| 指标 | 值 |")
lines.append(f"|------|-----|")
lines.append(f"| API 字段数 | {r['api_count']}(文档={r['md_fields_count']}JSON={r['json_fields_count']} |")
lines.append(f"| ODS 列数(排除 meta | {r['ods_count']} |")
lines.append(f"| 匹配 | {r['matched']} |")
lines.append(f"| API 独有 | {r['api_only_count']} |")
lines.append(f"| ODS 独有 | {r['ods_only_count']} |")
lines.append(f"")
if r["api_only"]:
lines.append(f"**API 独有字段ODS 中缺失):**")
lines.append(f"")
for f in r["api_only"]:
lines.append(f"- `{f}`")
lines.append(f"")
if r["ods_only"]:
lines.append(f"**ODS 独有字段API 文档中未出现):**")
lines.append(f"")
lines.append(f"| ODS 列名 | 分类说明 |")
lines.append(f"|----------|----------|")
for item in r["ods_only"]:
lines.append(f"| `{item['ods_original']}` | {item['reason']} |")
lines.append(f"")
lines.append(f"---")
lines.append(f"")
# AI_CHANGELOG
lines.extend([
f"<!--",
f"AI_CHANGELOG:",
f"- 日期: 2026-02-14",
f"- Prompt: P20260214-003000 — v3 比对不准确,重写为 v3-fixed",
f"- 直接原因: v3 仅从 JSON 样本提取字段导致遗漏v3-fixed 从 .md 文档响应字段详解章节精确提取",
f"- 变更摘要: 新建 v3-fixed 报告,精确限定提取范围,排除请求参数和跨表关联字段",
f"- 风险与验证: 纯分析报告,无运行时影响;验证方式:抽查 assistant_accounts_master 的 last_update_name 是否正确识别为匹配",
f"-->",
])
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines))
if __name__ == "__main__":
main()
# AI_CHANGELOG:
# - 日期: 2026-02-14
# - Prompt: P20260214-003000 — "还是不准,比如assistant_accounts_master的last_update_name命名,Json里就有,再仔细比对下"
# - 直接原因: v3 仅从 JSON 样本提取字段导致遗漏条件性字段;需改用 .md 文档响应字段详解章节作为主要来源
# - 变更摘要: 完全重写脚本,精确限定提取范围到"四、响应字段详解"章节,排除请求参数和跨表关联;
# 对 settlement_records/recharge_settlements 的 siteProfile 子字段不提取;对所有 ODS 独有字段分类说明
# - 风险与验证: 纯分析脚本,无运行时影响;验证:确认 assistant_accounts_master 62:62 完全对齐,last_update_name 正确匹配
#
# - 日期: 2026-02-14
# - Prompt: P20260214-030000 — 上下文传递续接,执行 settlelist 删除后的收尾工作
# - 直接原因: settlelist 列已从 ODS 删除,classify_ods_only 中的 settlelist 特殊分类不再需要
# - 变更摘要: 移除 classify_ods_only 函数中 settlelist 的特殊分类逻辑
# - 风险与验证: 纯分析脚本;验证:重新运行脚本确认 ODS 独有=47,settlement_records 和 recharge_settlements 完全对齐
#
# - 日期: 2026-02-14
# - Prompt: P20260214-070000 — ODS 清理与文档标注(5 项任务)
# - 直接原因: option_name(store_goods_sales_records)和 able_site_transfer(member_stored_value_cards)已从 ODS 删除
# - 变更摘要: 从 classify_ods_only 的 api_version_fields 字典中移除 option_name 和 able_site_transfer 条目
# - 风险与验证: 纯分析脚本;验证:重新运行脚本确认两表 ODS 独有数减少

View File

@@ -0,0 +1,26 @@
@echo off
REM -*- coding: utf-8 -*-
REM 说明:一键重建 ODS(执行 INIT_ODS_SCHEMA)并灌入示例 JSON(执行 MANUAL_INGEST)
setlocal
cd /d "%~dp0\.."
REM 如果需要覆盖示例目录,可修改下面的 INGEST_DIR
set "INGEST_DIR=export\\test-json-doc"
echo [INIT_ODS_SCHEMA] 准备执行,源目录=%INGEST_DIR%
python -m cli.main --tasks INIT_ODS_SCHEMA --pipeline-flow INGEST_ONLY --ingest-source "%INGEST_DIR%"
if errorlevel 1 (
echo INIT_ODS_SCHEMA 失败,退出
exit /b 1
)
echo [MANUAL_INGEST] 准备执行,源目录=%INGEST_DIR%
python -m cli.main --tasks MANUAL_INGEST --pipeline-flow INGEST_ONLY --ingest-source "%INGEST_DIR%"
if errorlevel 1 (
echo MANUAL_INGEST 失败,退出
exit /b 1
)
echo 全部完成。
endlocal

View File

@@ -0,0 +1,516 @@
# -*- coding: utf-8 -*-
"""
一键增量更新脚本(ODS -> DWD -> DWS)
用法:
python scripts/run_update.py
"""
from __future__ import annotations
import argparse
import logging
import multiprocessing as mp
import subprocess
import sys
import time as time_mod
from datetime import date, datetime, time, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from orchestration.scheduler import ETLScheduler
from tasks.utility.check_cutoff_task import CheckCutoffTask
from tasks.dwd.dwd_load_task import DwdLoadTask
from tasks.ods.ods_tasks import ENABLED_ODS_CODES
from utils.logging_utils import build_log_path, configure_logging
STEP_TIMEOUT_SEC = 120
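# 每个子步骤的默认超时秒数;个别步骤(如 ODS 缺口检查)会通过 step["timeout_sec"] 单独覆盖。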
def _coerce_date(s: str) -> date:
s = (s or "").strip()
if not s:
raise ValueError("empty date")
if len(s) >= 10:
s = s[:10]
return date.fromisoformat(s)
def _compute_dws_window(
*,
cfg: AppConfig,
tz: ZoneInfo,
rebuild_days: int,
bootstrap_days: int,
dws_start: date | None,
dws_end: date | None,
) -> tuple[datetime, datetime]:
if dws_start and dws_end and dws_end < dws_start:
raise ValueError("dws_end must be >= dws_start")
store_id = int(cfg.get("app.store_id"))
dsn = cfg["db"]["dsn"]
session = cfg["db"].get("session")
conn = DatabaseConnection(dsn=dsn, session=session)
try:
if dws_start is None:
row = conn.query(
"SELECT MAX(order_date) AS mx FROM billiards_dws.dws_order_summary WHERE site_id=%s",
(store_id,),
)
mx = (row[0] or {}).get("mx") if row else None
if isinstance(mx, date):
dws_start = mx - timedelta(days=max(0, int(rebuild_days)))
else:
dws_start = (datetime.now(tz).date()) - timedelta(days=max(1, int(bootstrap_days)))
if dws_end is None:
dws_end = datetime.now(tz).date()
finally:
conn.close()
start_dt = datetime.combine(dws_start, time.min).replace(tzinfo=tz)
# end_dt 取到当天 23:59:59避免只跑到“当前时刻”的 date() 导致少一天
end_dt = datetime.combine(dws_end, time.max).replace(tzinfo=tz)
return start_dt, end_dt
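# 窗口计算示例(日期为假设值):若 dws_order_summary 中 MAX(order_date)=2025-02-10 且 rebuild_days=1,
# 则窗口为 2025-02-09 00:00:00 至 运行当日 23:59:59.999999;若汇总表为空,则回退 bootstrap_days 天起算。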
def _run_check_cutoff(cfg: AppConfig, logger: logging.Logger):
dsn = cfg["db"]["dsn"]
session = cfg["db"].get("session")
db_conn = DatabaseConnection(dsn=dsn, session=session)
db_ops = DatabaseOperations(db_conn)
api = APIClient(
base_url=cfg["api"]["base_url"],
token=cfg["api"]["token"],
timeout=cfg["api"]["timeout_sec"],
retry_max=cfg["api"]["retries"]["max_attempts"],
headers_extra=cfg["api"].get("headers_extra"),
)
try:
CheckCutoffTask(cfg, db_ops, api, logger).execute(None)
finally:
db_conn.close()
def _iter_daily_windows(window_start: datetime, window_end: datetime) -> list[tuple[datetime, datetime]]:
if window_start > window_end:
return []
tz = window_start.tzinfo
windows: list[tuple[datetime, datetime]] = []
cur = window_start
while cur <= window_end:
day_start = datetime.combine(cur.date(), time.min).replace(tzinfo=tz)
day_end = datetime.combine(cur.date(), time.max).replace(tzinfo=tz)
if day_start < window_start:
day_start = window_start
if day_end > window_end:
day_end = window_end
windows.append((day_start, day_end))
next_day = cur.date() + timedelta(days=1)
cur = datetime.combine(next_day, time.min).replace(tzinfo=tz)
return windows
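# 切分示例(时间为假设值):window_start=2025-01-01 08:00、window_end=2025-01-03 12:00 会得到三个窗口:
# 01-01 08:00~23:59:59.999999、01-02 全天、01-03 00:00~12:00。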
def _run_step_worker(result_queue: "mp.Queue[dict[str, str]]", step: dict[str, str]) -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
log_file = step.get("log_file") or ""
log_level = step.get("log_level") or "INFO"
log_console = bool(step.get("log_console", True))
log_path = Path(log_file) if log_file else None
with configure_logging(
"etl_update",
log_path,
level=log_level,
console=log_console,
tee_std=True,
) as logger:
cfg_base = AppConfig.load({})
step_type = step.get("type", "")
try:
if step_type == "check_cutoff":
_run_check_cutoff(cfg_base, logger)
elif step_type == "ods_task":
task_code = step["task_code"]
overlap_seconds = int(step.get("overlap_seconds", 0))
cfg_ods = AppConfig.load(
{
"pipeline": {"flow": "FULL"},
"run": {"tasks": [task_code], "overlap_seconds": overlap_seconds},
}
)
scheduler = ETLScheduler(cfg_ods, logger)
try:
scheduler.run_tasks([task_code])
finally:
scheduler.close()
elif step_type == "init_dws_schema":
overlap_seconds = int(step.get("overlap_seconds", 0))
cfg_dwd = AppConfig.load(
{
"pipeline": {"flow": "INGEST_ONLY"},
"run": {"tasks": ["INIT_DWS_SCHEMA"], "overlap_seconds": overlap_seconds},
}
)
scheduler = ETLScheduler(cfg_dwd, logger)
try:
scheduler.run_tasks(["INIT_DWS_SCHEMA"])
finally:
scheduler.close()
elif step_type == "dwd_table":
dwd_table = step["dwd_table"]
overlap_seconds = int(step.get("overlap_seconds", 0))
cfg_dwd = AppConfig.load(
{
"pipeline": {"flow": "INGEST_ONLY"},
"run": {"tasks": ["DWD_LOAD_FROM_ODS"], "overlap_seconds": overlap_seconds},
"dwd": {"only_tables": [dwd_table]},
}
)
scheduler = ETLScheduler(cfg_dwd, logger)
try:
scheduler.run_tasks(["DWD_LOAD_FROM_ODS"])
finally:
scheduler.close()
elif step_type == "dws_window":
overlap_seconds = int(step.get("overlap_seconds", 0))
window_start = step["window_start"]
window_end = step["window_end"]
cfg_dws = AppConfig.load(
{
"pipeline": {"flow": "INGEST_ONLY"},
"run": {
"tasks": ["DWS_BUILD_ORDER_SUMMARY"],
"overlap_seconds": overlap_seconds,
"window_override": {"start": window_start, "end": window_end},
},
}
)
scheduler = ETLScheduler(cfg_dws, logger)
try:
scheduler.run_tasks(["DWS_BUILD_ORDER_SUMMARY"])
finally:
scheduler.close()
elif step_type == "ods_gap_check":
overlap_hours = int(step.get("overlap_hours", 24))
window_days = int(step.get("window_days", 1))
window_hours = int(step.get("window_hours", 0))
page_size = int(step.get("page_size", 0) or 0)
sleep_per_window = float(step.get("sleep_per_window", 0) or 0)
sleep_per_page = float(step.get("sleep_per_page", 0) or 0)
tag = step.get("tag", "run_update")
task_codes = (step.get("task_codes") or "").strip()
script_dir = Path(__file__).resolve().parent.parent
script_path = script_dir / "scripts" / "check" / "check_ods_gaps.py"
cmd = [
sys.executable,
str(script_path),
"--from-cutoff",
"--cutoff-overlap-hours",
str(overlap_hours),
"--window-days",
str(window_days),
"--tag",
str(tag),
]
if window_hours > 0:
cmd += ["--window-hours", str(window_hours)]
if page_size > 0:
cmd += ["--page-size", str(page_size)]
if sleep_per_window > 0:
cmd += ["--sleep-per-window-seconds", str(sleep_per_window)]
if sleep_per_page > 0:
cmd += ["--sleep-per-page-seconds", str(sleep_per_page)]
if task_codes:
cmd += ["--task-codes", task_codes]
subprocess.run(cmd, check=True, cwd=str(script_dir))
else:
raise ValueError(f"Unknown step type: {step_type}")
result_queue.put({"status": "ok"})
except Exception as exc:
result_queue.put({"status": "error", "error": str(exc)})
def _run_step_with_timeout(
step: dict[str, str], logger: logging.Logger, timeout_sec: int
) -> dict[str, object]:
start = time_mod.monotonic()
step_timeout = timeout_sec
if step.get("timeout_sec"):
try:
step_timeout = int(step.get("timeout_sec"))
except Exception:
step_timeout = timeout_sec
ctx = mp.get_context("spawn")
result_queue: mp.Queue = ctx.Queue()
proc = ctx.Process(target=_run_step_worker, args=(result_queue, step))
proc.start()
proc.join(timeout=step_timeout)
elapsed = time_mod.monotonic() - start
if proc.is_alive():
logger.error(
"STEP_TIMEOUT name=%s elapsed=%.2fs limit=%ss", step["name"], elapsed, step_timeout
)
proc.terminate()
proc.join(10)
return {"name": step["name"], "status": "timeout", "elapsed": elapsed}
result: dict[str, object] = {"name": step["name"], "status": "error", "elapsed": elapsed}
try:
payload = result_queue.get_nowait()
except Exception:
payload = {}
if payload:
result.update(payload)
if result.get("status") == "ok":
logger.info("STEP_OK name=%s elapsed=%.2fs", step["name"], elapsed)
else:
logger.error(
"STEP_FAIL name=%s elapsed=%.2fs error=%s",
step["name"],
elapsed,
result.get("error"),
)
return result
def main() -> int:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
parser = argparse.ArgumentParser(description="One-click ETL update (ODS -> DWD -> DWS)")
parser.add_argument("--overlap-seconds", type=int, default=3600, help="overlap seconds (default: 3600)")
parser.add_argument(
"--dws-rebuild-days",
type=int,
default=1,
help="DWS 回算冗余天数default: 1",
)
parser.add_argument(
"--dws-bootstrap-days",
type=int,
default=30,
help="DWS 首次/空表时回算天数default: 30",
)
parser.add_argument("--dws-start", type=str, default="", help="DWS 回算开始日期 YYYY-MM-DD可选")
parser.add_argument("--dws-end", type=str, default="", help="DWS 回算结束日期 YYYY-MM-DD可选")
parser.add_argument(
"--skip-cutoff",
action="store_true",
help="跳过 CHECK_CUTOFF默认会在开始/结束各跑一次)",
)
parser.add_argument(
"--skip-ods",
action="store_true",
help="跳过 ODS 在线抓取(仅跑 DWD/DWS",
)
parser.add_argument(
"--ods-tasks",
type=str,
default="",
help="指定要跑的 ODS 任务(逗号分隔),默认跑全部 ENABLED_ODS_CODES",
)
parser.add_argument(
"--check-ods-gaps",
action="store_true",
help="run ODS gap check after ODS load (default: off)",
)
parser.add_argument(
"--check-ods-overlap-hours",
type=int,
default=24,
help="gap check overlap hours from cutoff (default: 24)",
)
parser.add_argument(
"--check-ods-window-days",
type=int,
default=1,
help="gap check window days (default: 1)",
)
parser.add_argument(
"--check-ods-window-hours",
type=int,
default=0,
help="gap check window hours (default: 0)",
)
parser.add_argument(
"--check-ods-page-size",
type=int,
default=200,
help="gap check API page size (default: 200)",
)
parser.add_argument(
"--check-ods-timeout-sec",
type=int,
default=1800,
help="gap check timeout seconds (default: 1800)",
)
parser.add_argument(
"--check-ods-task-codes",
type=str,
default="",
help="gap check task codes (comma-separated, optional)",
)
parser.add_argument(
"--check-ods-sleep-per-window-seconds",
type=float,
default=0,
help="gap check sleep seconds after each window (default: 0)",
)
parser.add_argument(
"--check-ods-sleep-per-page-seconds",
type=float,
default=0,
help="gap check sleep seconds after each page (default: 0)",
)
parser.add_argument("--log-file", type=str, default="", help="log file path (default: logs/run_update_YYYYMMDD_HHMMSS.log)")
parser.add_argument("--log-dir", type=str, default="", help="log directory (default: logs)")
parser.add_argument("--log-level", type=str, default="INFO", help="log level (default: INFO)")
parser.add_argument("--no-log-console", action="store_true", help="disable console logging")
args = parser.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (Path(__file__).resolve().parent.parent / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "run_update")
log_console = not args.no_log_console
with configure_logging(
"etl_update",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg_base = AppConfig.load({})
tz = ZoneInfo(cfg_base.get("app.timezone", "Asia/Shanghai"))
dws_start = _coerce_date(args.dws_start) if args.dws_start else None
dws_end = _coerce_date(args.dws_end) if args.dws_end else None
steps: list[dict[str, str]] = []
if not args.skip_cutoff:
steps.append({"name": "CHECK_CUTOFF:before", "type": "check_cutoff"})
# ------------------------------------------------------------------ ODS(在线抓取 + 写入)
if not args.skip_ods:
if args.ods_tasks:
ods_tasks = [t.strip().upper() for t in args.ods_tasks.split(",") if t.strip()]
else:
ods_tasks = sorted(ENABLED_ODS_CODES)
for task_code in ods_tasks:
steps.append(
{
"name": f"ODS:{task_code}",
"type": "ods_task",
"task_code": task_code,
"overlap_seconds": str(args.overlap_seconds),
}
)
if args.check_ods_gaps:
steps.append(
{
"name": "ODS_GAP_CHECK",
"type": "ods_gap_check",
"overlap_hours": str(args.check_ods_overlap_hours),
"window_days": str(args.check_ods_window_days),
"window_hours": str(args.check_ods_window_hours),
"page_size": str(args.check_ods_page_size),
"sleep_per_window": str(args.check_ods_sleep_per_window_seconds),
"sleep_per_page": str(args.check_ods_sleep_per_page_seconds),
"timeout_sec": str(args.check_ods_timeout_sec),
"task_codes": str(args.check_ods_task_codes or ""),
"tag": "run_update",
}
)
# ------------------------------------------------------------------ DWD(从 ODS 表装载)
steps.append(
{
"name": "INIT_DWS_SCHEMA",
"type": "init_dws_schema",
"overlap_seconds": str(args.overlap_seconds),
}
)
for dwd_table in DwdLoadTask.TABLE_MAP.keys():
steps.append(
{
"name": f"DWD:{dwd_table}",
"type": "dwd_table",
"dwd_table": dwd_table,
"overlap_seconds": str(args.overlap_seconds),
}
)
# ------------------------------------------------------------------ DWS(按日期窗口重建)
window_start, window_end = _compute_dws_window(
cfg=cfg_base,
tz=tz,
rebuild_days=int(args.dws_rebuild_days),
bootstrap_days=int(args.dws_bootstrap_days),
dws_start=dws_start,
dws_end=dws_end,
)
for start_dt, end_dt in _iter_daily_windows(window_start, window_end):
steps.append(
{
"name": f"DWS:{start_dt.date().isoformat()}",
"type": "dws_window",
"window_start": start_dt.strftime("%Y-%m-%d %H:%M:%S"),
"window_end": end_dt.strftime("%Y-%m-%d %H:%M:%S"),
"overlap_seconds": str(args.overlap_seconds),
}
)
if not args.skip_cutoff:
steps.append({"name": "CHECK_CUTOFF:after", "type": "check_cutoff"})
for step in steps:
step["log_file"] = str(log_file)
step["log_level"] = args.log_level
step["log_console"] = log_console
step_results: list[dict[str, object]] = []
for step in steps:
logger.info("STEP_START name=%s timeout=%ss", step["name"], STEP_TIMEOUT_SEC)
result = _run_step_with_timeout(step, logger, STEP_TIMEOUT_SEC)
step_results.append(result)
total = len(step_results)
ok_count = sum(1 for r in step_results if r.get("status") == "ok")
timeout_count = sum(1 for r in step_results if r.get("status") == "timeout")
fail_count = total - ok_count - timeout_count
logger.info(
"STEP_SUMMARY total=%s ok=%s failed=%s timeout=%s",
total,
ok_count,
fail_count,
timeout_count,
)
for item in sorted(step_results, key=lambda r: float(r.get("elapsed", 0.0)), reverse=True):
logger.info(
"STEP_RESULT name=%s status=%s elapsed=%.2fs",
item.get("name"),
item.get("status"),
item.get("elapsed", 0.0),
)
logger.info("Update done.")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,488 @@
#!/usr/bin/env python3
"""
BD_Manual 文档体系验证脚本。
# AI_CHANGELOG [2026-02-13] 新增:验证 Property 1/4/5/6/7/8/9/10,支持 --pg-dsn 参数
验证 docs/database/ 下的目录结构、文档覆盖率、格式完整性和命名规范。
需要连接 PostgreSQL 获取 billiards_ods schema 的表清单作为基准。
用法:
python scripts/validate_bd_manual.py --pg-dsn "postgresql://user:pass@host/db"
python scripts/validate_bd_manual.py # 从 PG_DSN 环境变量或 .env 读取
"""
from __future__ import annotations
import argparse
import os
import re
import sys
from pathlib import Path
from dataclasses import dataclass, field
# ---------------------------------------------------------------------------
# 常量
# ---------------------------------------------------------------------------
BD_MANUAL_ROOT = Path("docs/database")
ODS_MAIN_DIR = BD_MANUAL_ROOT / "ODS" / "main"
ODS_MAPPINGS_DIR = BD_MANUAL_ROOT / "ODS" / "mappings"
ODS_DICT_PATH = Path("docs/database/overview/ods_tables_dictionary.md")
# 四个数据层,每层都应有 main/ 和 changes/
DATA_LAYERS = ["ODS", "DWD", "DWS", "ETL_Admin"]
# ODS 文档必须包含的章节标题(Property 5)
ODS_DOC_REQUIRED_SECTIONS = [
"表信息",
"字段说明",
"使用说明",
"可回溯性",
]
# ODS 文档"表信息"表格中必须出现的属性关键词
ODS_DOC_TABLE_INFO_KEYS = ["Schema", "表名", "主键", "数据来源", "说明"]
# ODS 文档必须提及的 ETL 元数据字段
ODS_DOC_ETL_META_FIELDS = [
"content_hash",
"source_file",
"source_endpoint",
"fetched_at",
"payload",
]
# 映射文档必须包含的章节/关键内容Property 8
MAPPING_DOC_REQUIRED_SECTIONS = [
"端点信息",
"字段映射",
"ETL 补充字段",
]
# 映射文档"端点信息"表格中必须出现的属性关键词
MAPPING_DOC_ENDPOINT_KEYS = ["接口路径", "ODS 对应表", "JSON 数据路径"]
# ---------------------------------------------------------------------------
# 数据结构
# ---------------------------------------------------------------------------
@dataclass
class CheckResult:
"""单条验证结果。"""
property_id: str # 如 "Property 1"
description: str
passed: bool
details: list[str] = field(default_factory=list) # 失败时的具体说明
# ---------------------------------------------------------------------------
# 数据库查询:获取 ODS 表清单
# ---------------------------------------------------------------------------
def fetch_ods_tables(pg_dsn: str) -> list[str]:
"""从 billiards_ods schema 获取所有用户表名(排除系统表)。"""
import psycopg2
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'billiards_ods'
AND table_type = 'BASE TABLE'
ORDER BY table_name;
"""
with psycopg2.connect(pg_dsn) as conn:
with conn.cursor() as cur:
cur.execute(sql)
return [row[0] for row in cur.fetchall()]
# ---------------------------------------------------------------------------
# Property 1: 数据层目录结构一致性
# ---------------------------------------------------------------------------
def check_directory_structure() -> CheckResult:
"""ODS/DWD/DWS/ETL_Admin 各层都应有 main/ 和 changes/ 子目录。"""
missing: list[str] = []
for layer in DATA_LAYERS:
for sub in ("main", "changes"):
p = BD_MANUAL_ROOT / layer / sub
if not p.is_dir():
missing.append(str(p))
return CheckResult(
property_id="Property 1",
description="数据层目录结构一致性main/ + changes/",
passed=len(missing) == 0,
details=[f"缺失目录: {d}" for d in missing],
)
# ---------------------------------------------------------------------------
# Property 4: ODS 表级文档覆盖率
# ---------------------------------------------------------------------------
def check_ods_doc_coverage(ods_tables: list[str]) -> CheckResult:
"""billiards_ods 中每张表都应有 BD_manual_{表名}.md。"""
missing: list[str] = []
for tbl in ods_tables:
expected = ODS_MAIN_DIR / f"BD_manual_{tbl}.md"
if not expected.is_file():
missing.append(tbl)
return CheckResult(
property_id="Property 4",
description="ODS 表级文档覆盖率",
passed=len(missing) == 0,
details=[f"缺失文档: BD_manual_{t}.md" for t in missing],
)
# ---------------------------------------------------------------------------
# Property 5: ODS 表级文档格式完整性
# ---------------------------------------------------------------------------
def _check_single_ods_doc(filepath: Path) -> list[str]:
"""检查单份 ODS 文档是否包含必要章节和内容,返回问题列表。"""
issues: list[str] = []
name = filepath.name
try:
content = filepath.read_text(encoding="utf-8")
except Exception as e:
return [f"{name}: 无法读取 ({e})"]
# 检查必要章节
for section in ODS_DOC_REQUIRED_SECTIONS:
# 匹配 ## 章节标题(允许前后有空格)
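# 例如可匹配「## 表信息」或「## 一、表信息」等二级标题(示例标题,非真实文档内容)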
pattern = rf"^##\s+.*{re.escape(section)}"
if not re.search(pattern, content, re.MULTILINE):
issues.append(f"{name}: 缺少「{section}」章节")
# 检查"表信息"表格中的关键属性
for key in ODS_DOC_TABLE_INFO_KEYS:
if key not in content:
issues.append(f"{name}: 表信息缺少「{key}」属性")
# 检查 ETL 元数据字段是否被提及
meta_missing = [f for f in ODS_DOC_ETL_META_FIELDS if f not in content]
if meta_missing:
issues.append(f"{name}: 未提及 ETL 元数据字段: {', '.join(meta_missing)}")
return issues
def check_ods_doc_format() -> CheckResult:
"""每份 ODS 文档应包含表信息、字段说明、使用说明、可回溯性、ETL 元数据字段。"""
all_issues: list[str] = []
if not ODS_MAIN_DIR.is_dir():
return CheckResult(
property_id="Property 5",
description="ODS 表级文档格式完整性",
passed=False,
details=["ODS/main/ 目录不存在"],
)
for f in sorted(ODS_MAIN_DIR.glob("BD_manual_*.md")):
all_issues.extend(_check_single_ods_doc(f))
return CheckResult(
property_id="Property 5",
description="ODS 表级文档格式完整性",
passed=len(all_issues) == 0,
details=all_issues,
)
# ---------------------------------------------------------------------------
# Property 6: ODS 表级文档命名规范
# ---------------------------------------------------------------------------
def check_ods_doc_naming() -> CheckResult:
"""ODS/main/ 下的文件名应匹配 BD_manual_{表名}.md。"""
bad: list[str] = []
if not ODS_MAIN_DIR.is_dir():
return CheckResult(
property_id="Property 6",
description="ODS 表级文档命名规范",
passed=False,
details=["ODS/main/ 目录不存在"],
)
pattern = re.compile(r"^BD_manual_[a-z][a-z0-9_]*\.md$")
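# 例如 BD_manual_member_profile.md 符合规范BD_Manual_MemberProfile.md 则不符合
# (文件名仅为假设示例,用于说明命名模式)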
for f in sorted(ODS_MAIN_DIR.iterdir()):
if f.suffix == ".md" and not pattern.match(f.name):
bad.append(f.name)
return CheckResult(
property_id="Property 6",
description="ODS 表级文档命名规范BD_manual_{表名}.md",
passed=len(bad) == 0,
details=[f"命名不规范: {n}" for n in bad],
)
# ---------------------------------------------------------------------------
# Property 7: 映射文档覆盖率
# ---------------------------------------------------------------------------
def check_mapping_doc_coverage(ods_tables: list[str]) -> CheckResult:
"""每个有 ODS 表的 API 端点都应有映射文档。
策略:遍历 ODS 表,检查 mappings/ 下是否存在至少一个
mapping_*_{表名}.md 文件。
"""
missing: list[str] = []
if not ODS_MAPPINGS_DIR.is_dir():
return CheckResult(
property_id="Property 7",
description="映射文档覆盖率",
passed=False,
details=["ODS/mappings/ 目录不存在"],
)
existing_mappings = {f.name for f in ODS_MAPPINGS_DIR.glob("mapping_*.md")}
for tbl in ods_tables:
# 查找 mapping_*_{表名}.md
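# 示例:表 member_profile 可由 mapping_MemberList_member_profile.md 覆盖(文件名为假设示例);
# 注意 endswith 匹配较宽松,若某表名恰为另一表名的后缀,可能被误判为已覆盖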
found = any(
name.endswith(f"_{tbl}.md") and name.startswith("mapping_")
for name in existing_mappings
)
if not found:
missing.append(tbl)
return CheckResult(
property_id="Property 7",
description="映射文档覆盖率(每张 ODS 表至少一份映射文档)",
passed=len(missing) == 0,
details=[f"缺失映射文档: mapping_*_{t}.md" for t in missing],
)
# ---------------------------------------------------------------------------
# Property 8: 映射文档内容完整性
# ---------------------------------------------------------------------------
def _check_single_mapping_doc(filepath: Path) -> list[str]:
"""检查单份映射文档是否包含必要章节和内容。"""
issues: list[str] = []
name = filepath.name
try:
content = filepath.read_text(encoding="utf-8")
except Exception as e:
return [f"{name}: 无法读取 ({e})"]
# 检查必要章节
for section in MAPPING_DOC_REQUIRED_SECTIONS:
pattern = rf"^##\s+.*{re.escape(section)}"
if not re.search(pattern, content, re.MULTILINE):
issues.append(f"{name}: 缺少「{section}」章节")
# 检查端点信息表格中的关键属性
for key in MAPPING_DOC_ENDPOINT_KEYS:
if key not in content:
issues.append(f"{name}: 端点信息缺少「{key}」属性")
# 检查 ETL 补充字段是否被提及
etl_missing = [f for f in ODS_DOC_ETL_META_FIELDS if f not in content]
if etl_missing:
issues.append(f"{name}: 未提及 ETL 补充字段: {', '.join(etl_missing)}")
return issues
def check_mapping_doc_content() -> CheckResult:
"""每份映射文档应包含端点路径、ODS 表名、JSON 数据路径、字段映射表、ETL 补充字段。"""
all_issues: list[str] = []
if not ODS_MAPPINGS_DIR.is_dir():
return CheckResult(
property_id="Property 8",
description="映射文档内容完整性",
passed=False,
details=["ODS/mappings/ 目录不存在"],
)
for f in sorted(ODS_MAPPINGS_DIR.glob("mapping_*.md")):
all_issues.extend(_check_single_mapping_doc(f))
return CheckResult(
property_id="Property 8",
description="映射文档内容完整性",
passed=len(all_issues) == 0,
details=all_issues,
)
# ---------------------------------------------------------------------------
# Property 9: 映射文档命名规范
# ---------------------------------------------------------------------------
def check_mapping_doc_naming() -> CheckResult:
"""映射文档文件名应匹配 mapping_{API端点名}_{ODS表名}.md。"""
bad: list[str] = []
if not ODS_MAPPINGS_DIR.is_dir():
return CheckResult(
property_id="Property 9",
description="映射文档命名规范",
passed=False,
details=["ODS/mappings/ 目录不存在"],
)
# mapping_{EndpointName}_{table_name}.md
# 端点名PascalCase字母数字表名snake_case
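# 例如 mapping_MemberList_member_profile.md 符合mapping_member_list_member_profile.md 不符合
# (文件名仅为假设示例)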
pattern = re.compile(r"^mapping_[A-Z][A-Za-z0-9]+_[a-z][a-z0-9_]*\.md$")
for f in sorted(ODS_MAPPINGS_DIR.iterdir()):
if f.suffix == ".md" and f.name.startswith("mapping_"):
if not pattern.match(f.name):
bad.append(f.name)
return CheckResult(
property_id="Property 9",
description="映射文档命名规范mapping_{API端点名}_{ODS表名}.md",
passed=len(bad) == 0,
details=[f"命名不规范: {n}" for n in bad],
)
# ---------------------------------------------------------------------------
# Property 10: ODS 数据字典覆盖率
# ---------------------------------------------------------------------------
def check_ods_dictionary_coverage(ods_tables: list[str]) -> CheckResult:
"""数据字典中应包含所有 ODS 表条目。"""
if not ODS_DICT_PATH.is_file():
return CheckResult(
property_id="Property 10",
description="ODS 数据字典覆盖率",
passed=False,
details=[f"数据字典文件不存在: {ODS_DICT_PATH}"],
)
try:
content = ODS_DICT_PATH.read_text(encoding="utf-8")
except Exception as e:
return CheckResult(
property_id="Property 10",
description="ODS 数据字典覆盖率",
passed=False,
details=[f"无法读取数据字典: {e}"],
)
missing: list[str] = []
for tbl in ods_tables:
# 在字典内容中查找表名(反引号包裹或直接出现)
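# 注意:此处为简单子串匹配,若某表名是另一表名的子串,可能被误判为已收录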
if tbl not in content:
missing.append(tbl)
return CheckResult(
property_id="Property 10",
description="ODS 数据字典覆盖率",
passed=len(missing) == 0,
details=[f"数据字典缺失条目: {t}" for t in missing],
)
# ---------------------------------------------------------------------------
# 报告输出
# ---------------------------------------------------------------------------
def print_report(results: list[CheckResult]) -> None:
"""打印验证报告。"""
print("=" * 60)
print("BD_Manual 文档体系验证报告")
print("=" * 60)
passed_count = sum(1 for r in results if r.passed)
total = len(results)
for r in results:
status = "✓ PASS" if r.passed else "✗ FAIL"
print(f"\n[{status}] {r.property_id}: {r.description}")
if not r.passed:
for d in r.details[:20]: # 最多显示 20 条
print(f" - {d}")
if len(r.details) > 20:
print(f" ... 还有 {len(r.details) - 20} 条问题")
print("\n" + "-" * 60)
print(f"结果: {passed_count}/{total} 项通过")
if passed_count < total:
print("存在未通过的验证项,请检查上述详情。")
else:
print("所有验证项均通过 ✓")
print("=" * 60)
# ---------------------------------------------------------------------------
# 主入口
# ---------------------------------------------------------------------------
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description="验证 BD_Manual 文档体系的覆盖率、格式和命名规范",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 从 .env 或 PG_DSN 环境变量读取连接字符串
python scripts/validate_bd_manual.py
# 指定连接字符串
python scripts/validate_bd_manual.py --pg-dsn "postgresql://user:pass@host/db"
""",
)
parser.add_argument(
"--pg-dsn",
help="PostgreSQL 连接字符串(默认从 PG_DSN 环境变量或 .env 读取)",
)
args = parser.parse_args(argv)
# 加载 .env
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
pg_dsn = args.pg_dsn or os.environ.get("PG_DSN")
if not pg_dsn:
print(
"✗ 未提供 PG_DSN请通过 --pg-dsn 参数或 PG_DSN 环境变量指定",
file=sys.stderr,
)
return 1
# 获取 ODS 表清单
try:
ods_tables = fetch_ods_tables(pg_dsn)
except Exception as e:
print(f"✗ 连接数据库失败: {e}", file=sys.stderr)
return 1
if not ods_tables:
print("⚠ billiards_ods schema 中未找到任何表", file=sys.stderr)
return 1
print(f"从数据库获取到 {len(ods_tables)} 张 ODS 表\n")
# 运行所有验证
results: list[CheckResult] = [
check_directory_structure(), # Property 1
check_ods_doc_coverage(ods_tables), # Property 4
check_ods_doc_format(), # Property 5
check_ods_doc_naming(), # Property 6
check_mapping_doc_coverage(ods_tables), # Property 7
check_mapping_doc_content(), # Property 8
check_mapping_doc_naming(), # Property 9
check_ods_dictionary_coverage(ods_tables), # Property 10
]
print_report(results)
# 任一验证失败则返回非零退出码
if any(not r.passed for r in results):
return 1
return 0
if __name__ == "__main__":
sys.exit(main())