# -*- coding: utf-8 -*-
import ast
import json
import re
from collections import deque
from pathlib import Path
ROOT = Path(r"C:\dev\LLTQ\ETL\feiqiu-ETL")
SQL_PATH = ROOT / "etl_billiards" / "database" / "schema_dwd_doc.sql"
DOC_DIR = Path(r"C:\dev\LLTQ\export\test-json-doc")
DWD_TASK_PATH = ROOT / "etl_billiards" / "tasks" / "dwd_load_task.py"
SCD_COLS = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}

SITEPROFILE_FIELD_PURPOSE = {
    "id": "门店 ID，用于门店维度关联。",
    "org_id": "组织/机构 ID，用于组织维度归属。",
    "shop_name": "门店名称,用于展示与查询。",
    "site_label": "门店标签(如 A/B 店),用于展示与分组。",
    "full_address": "门店详细地址,用于展示与地理信息。",
    "address": "门店地址简称/快照,用于展示。",
    "longitude": "经度,用于定位与地图展示。",
    "latitude": "纬度,用于定位与地图展示。",
    "tenant_site_region_id": "租户下门店区域 ID，用于区域维度分析。",
    "business_tel": "门店电话,用于联系信息展示。",
    "site_type": "门店类型枚举,用于门店分类。",
    "shop_status": "门店状态枚举,用于营业状态标识。",
    "tenant_id": "租户/品牌 ID，用于商户维度过滤与关联。",
    "auto_light": "是否启用自动灯控配置,用于门店设备策略。",
    "attendance_enabled": "是否启用考勤功能,用于门店考勤配置。",
    "attendance_distance": "考勤允许距离(米),用于考勤打卡限制。",
    "prod_env": "环境标识(生产/测试),用于区分配置环境。",
    "light_status": "灯控状态/开关,用于灯控设备管理。",
    "light_type": "灯控类型,用于设备类型区分。",
    "light_token": "灯控控制令牌,用于对接灯控服务。",
    "avatar": "门店头像/图片 URL，用于展示。",
    "wifi_name": "门店 WiFi 名称,用于展示与引导。",
    "wifi_password": "门店 WiFi 密码,用于展示与引导。",
    "customer_service_qrcode": "客服二维码 URL，用于引导联系。",
    "customer_service_wechat": "客服微信号,用于引导联系。",
    "fixed_pay_qrCode": "固定收款码二维码URL，用于收款引导。",
    "create_time": "门店创建时间(快照字段)。",
    "update_time": "门店更新时间(快照字段)。",
}

def _escape_sql(s: str) -> str:
    return (s or "").replace("'", "''")
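
# Illustrative behaviour of the helper above (the calls are not part of the original file):
#   _escape_sql("it's")  -> "it''s"   (single quotes doubled for SQL string literals)
#   _escape_sql(None)    -> ""        (None coerced to an empty string)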

def _first_sentence(text: str, max_len: int = 140) -> str:
    s = re.sub(r"\s+", " ", (text or "").strip())
    if not s:
        return ""
    parts = re.split(r"[。;；]\s*", s)
    s = parts[0].strip() if parts else s
    if len(s) > max_len:
        s = s[: max_len - 1] + "…"
    return s
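
# A minimal sketch of the truncation above (illustrative values):
#   _first_sentence("门店 ID，用于门店维度关联。供展示使用。")
#       -> "门店 ID，用于门店维度关联"     (text up to the first 。/;/； separator)
#   _first_sentence("x" * 200, max_len=10) -> "xxxxxxxxx…"  (max_len - 1 chars plus an ellipsis)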

def normalize_key(s: str) -> str:
    return re.sub(r"[_\-\s]", "", (s or "").lower())

def snake_to_lower_camel(s: str) -> str:
    parts = re.split(r"[_\-\s]+", s)
    if not parts:
        return s
    first = parts[0].lower()
    rest = "".join(p[:1].upper() + p[1:] for p in parts[1:] if p)
    return first + rest

def snake_to_upper_camel(s: str) -> str:
    parts = re.split(r"[_\-\s]+", s)
    return "".join(p[:1].upper() + p[1:] for p in parts if p)

def find_key_in_record(record: dict, token: str) -> str | None:
    if not isinstance(record, dict):
        return None
    if token in record:
        return token
    norm_to_key = {normalize_key(k): k for k in record.keys()}
    candidates = [
        token,
        token.lower(),
        token.upper(),
        snake_to_lower_camel(token),
        snake_to_upper_camel(token),
    ]
    # Common variants: siteProfile/siteprofile
    if normalize_key(token) == "siteprofile":
        candidates.extend(["siteProfile", "siteprofile"])
    for c in candidates:
        nk = normalize_key(c)
        if nk in norm_to_key:
            return norm_to_key[nk]
    return None
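
# Key-matching sketch for the lookup above (hypothetical record):
#   find_key_in_record({"orgId": 7, "shopName": "A"}, "org_id") -> "orgId"
# "org_id" is absent as-is, but snake_to_lower_camel("org_id") == "orgId" and the
# normalized forms ("orgid") coincide, so the actual JSON key name is returned.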

def parse_dwd_task_mappings(path: Path):
    mod = ast.parse(path.read_text(encoding="utf-8"))
    table_map = None
    fact_mappings = None
    for node in mod.body:
        if isinstance(node, ast.ClassDef) and node.name == "DwdLoadTask":
            for stmt in node.body:
                if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
                    name = stmt.targets[0].id
                    if name == "TABLE_MAP":
                        table_map = ast.literal_eval(stmt.value)
                    elif name == "FACT_MAPPINGS":
                        fact_mappings = ast.literal_eval(stmt.value)
                if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
                    name = stmt.target.id
                    if name == "TABLE_MAP":
                        table_map = ast.literal_eval(stmt.value)
                    elif name == "FACT_MAPPINGS":
                        fact_mappings = ast.literal_eval(stmt.value)
    if not isinstance(table_map, dict) or not isinstance(fact_mappings, dict):
        raise RuntimeError("Failed to parse TABLE_MAP/FACT_MAPPINGS from dwd_load_task.py")
    return table_map, fact_mappings
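
# Expected shapes, inferred from how the two mappings are consumed below
# (the concrete contents live in dwd_load_task.py and are not reproduced here):
#   TABLE_MAP:     {"billiards_dwd.<table>": "<ods_schema>.<OdsTable>", ...}
#   FACT_MAPPINGS: {"billiards_dwd.<table>": [(dwd_col, src_expr, cast), ...], ...}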

def parse_columns_from_ddl(create_sql: str):
    start = create_sql.find("(")
    end = create_sql.rfind(")")
    body = create_sql[start + 1 : end]
    cols = []
    for line in body.splitlines():
        s = line.strip().rstrip(",")
        if not s:
            continue
        if s.upper().startswith("PRIMARY KEY"):
            continue
        if s.upper().startswith("CONSTRAINT "):
            continue
        m = re.match(r"^([A-Za-z_][A-Za-z0-9_]*)\s+", s)
        if not m:
            continue
        name = m.group(1)
        if name.upper() in {"PRIMARY", "UNIQUE", "FOREIGN", "CHECK"}:
            continue
        cols.append(name.lower())
    return cols
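
# Illustrative extraction (hypothetical DDL, one column per line as the parser expects):
#   parse_columns_from_ddl("""CREATE TABLE IF NOT EXISTS dim_site (
#       id BIGINT NOT NULL,
#       shop_name TEXT,
#       PRIMARY KEY (id)
#   );""")
#   -> ["id", "shop_name"]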

def _find_best_record_list(data, required_norm_keys: set[str]):
    best = None
    best_score = -1.0
    best_path: list[str] = []
    q = deque([(data, 0, [])])
    visited = 0
    while q and visited < 25000:
        node, depth, path = q.popleft()
        visited += 1
        if depth > 10:
            continue
        if isinstance(node, list):
            if node and all(isinstance(x, dict) for x in node[:3]):
                scores = []
                for x in node[:5]:
                    keys_norm = {normalize_key(k) for k in x.keys()}
                    scores.append(len(keys_norm & required_norm_keys))
                score = sum(scores) / max(1, len(scores))
                if score > best_score:
                    best_score = score
                    best = node
                    best_path = path
                for x in node[:10]:
                    q.append((x, depth + 1, path))
            else:
                for x in node[:120]:
                    q.append((x, depth + 1, path))
        elif isinstance(node, dict):
            for k, v in list(node.items())[:160]:
                q.append((v, depth + 1, path + [str(k)]))
    node_str = ".".join(best_path) if best_path else "$"
    return best or [], node_str
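
# BFS scoring sketch (hypothetical payload): the list whose first records share the
# most normalized keys with required_norm_keys wins, and its dotted path is reported:
#   _find_best_record_list({"data": {"list": [{"orgId": 1, "shopName": "A"}]}},
#                          {"orgid", "shopname"})
#   -> ([{"orgId": 1, "shopName": "A"}], "data.list")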

def _format_example(value, max_len: int = 120) -> str:
    if value is None:
        return "NULL"
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, str):
        s = value.strip()
        if len(s) > max_len:
            s = s[: max_len - 1] + "…"
        return s
    if isinstance(value, dict):
        keys = list(value)[:6]
        mini = {k: value.get(k) for k in keys}
        rendered = json.dumps(mini, ensure_ascii=False)
        if len(value) > len(keys):
            rendered = rendered[:-1] + ", …}"
        if len(rendered) > max_len:
            rendered = rendered[: max_len - 1] + "…"
        return rendered
    if isinstance(value, list):
        if not value:
            return "[]"
        rendered = json.dumps(value[0], ensure_ascii=False)
        if len(value) > 1:
            rendered = f"[{rendered}, …] (len={len(value)})"
        else:
            rendered = f"[{rendered}]"
        if len(rendered) > max_len:
            rendered = rendered[: max_len - 1] + "…"
        return rendered
    s = str(value)
    if len(s) > max_len:
        s = s[: max_len - 1] + "…"
    return s
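
# Rendering rules above, illustrated:
#   _format_example(None)                    -> "NULL"
#   _format_example(True)                    -> "true"
#   _format_example([{"id": 1}, {"id": 2}])  -> '[{"id": 1}, …] (len=2)'
# Dicts are trimmed to their first 6 keys and reopened with ", …}" when truncated.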

def _infer_purpose(table: str, col: str, json_path: str | None) -> str:
    lcol = col.lower()
    if lcol in SCD_COLS:
        if lcol == "scd2_start_time":
            return "SCD2 开始时间(版本生效起点),用于维度慢变追踪。"
        if lcol == "scd2_end_time":
            return "SCD2 结束时间(默认 9999-12-31 表示当前版本),用于维度慢变追踪。"
        if lcol == "scd2_is_current":
            return "SCD2 当前版本标记(1=当前，0=历史),用于筛选最新维度记录。"
        if lcol == "scd2_version":
            return "SCD2 版本号(自增),用于与时间段一起避免版本重叠。"
    if json_path and json_path.startswith("siteProfile."):
        sf = json_path.split(".", 1)[1]
        return SITEPROFILE_FIELD_PURPOSE.get(sf, "门店快照字段,用于门店维度补充信息。")
    if lcol.endswith("_id"):
        return "标识类 ID 字段,用于关联/定位相关实体。"
    if lcol.endswith("_time") or lcol.endswith("time") or lcol.endswith("_date"):
        return "时间/日期字段,用于记录业务时间与统计口径对齐。"
    if any(k in lcol for k in ["amount", "money", "fee", "price", "deduct", "cost", "balance"]):
        return "金额字段,用于计费/结算/核算等金额计算。"
    if any(k in lcol for k in ["count", "num", "number", "seconds", "qty", "quantity"]):
        return "数量/时长字段,用于统计与计量。"
    if lcol.endswith("_name") or lcol.endswith("name"):
        return "名称字段,用于展示与辅助识别。"
    if lcol.endswith("_status") or lcol == "status":
        return "状态枚举字段,用于标识业务状态。"
    if lcol.startswith("is_") or lcol.startswith("can_"):
        return "布尔/开关字段,用于表示是否/可用性等业务开关。"
    # Table-level fallback
    if table.startswith("dim_"):
        return "维度字段,用于补充维度属性。"
    return "明细字段,用于记录事实取值。"

def _parse_json_extract(expr: str):
    # e.g. siteprofile->>'org_id'
    m = re.match(r"^([A-Za-z_][A-Za-z0-9_]*)\s*->>\s*'([^']+)'\s*$", expr)
    if not m:
        return None
    base = m.group(1)
    field = m.group(2)
    if normalize_key(base) == "siteprofile":
        base = "siteProfile"
    return base, field
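
# The ->> pattern above, illustrated:
#   _parse_json_extract("siteprofile->>'org_id'") -> ("siteProfile", "org_id")
#   _parse_json_extract("pay_time")               -> None  (not a JSON-extract expression)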

def build_table_comment(table: str, source_ods: str | None, source_json_base: str | None) -> str:
    table_l = table.lower()
    if table_l.startswith("dim_"):
        kind = "DWD 维度表"
    else:
        kind = "DWD 明细事实表"
    extra = "扩展字段表" if table_l.endswith("_ex") else ""
    if source_ods and source_json_base:
        src = (
            f"ODS 来源表:{source_ods}(对应 JSON：{source_json_base}.json，分析：{source_json_base}-Analysis.md），"
            f"装载/清洗逻辑参考：etl_billiards/tasks/dwd_load_task.py（DwdLoadTask）。"
        )
    else:
        src = "来源:由 ODS 清洗装载生成(详见 DWD 装载任务)。"
    return f"{kind}{('（' + extra + '）') if extra else ''}：{table_l}。{src}"

def get_source_info(table_l: str, table_map: dict) -> tuple[str | None, str | None]:
    key = f"billiards_dwd.{table_l}"
    source_ods = table_map.get(key)
    if not source_ods:
        return None, None
    json_base = source_ods.split(".")[-1]
    return source_ods, json_base

def build_column_mappings(table_l: str, cols: list[str], fact_mappings: dict) -> dict[str, tuple[str | None, str | None]]:
    # return col -> (json_path, src_expr)
    mapping_list = fact_mappings.get(f"billiards_dwd.{table_l}") or []
    explicit = {dwd_col.lower(): src_expr for dwd_col, src_expr, _cast in mapping_list}
    casts = {dwd_col.lower(): cast for dwd_col, _src_expr, cast in mapping_list}
    out: dict[str, tuple[str | None, str | None]] = {}
    for c in cols:
        if c in SCD_COLS:
            out[c] = (None, None)
            continue
        src_expr = explicit.get(c, c)
        cast = casts.get(c)
        json_path = None
        parsed = _parse_json_extract(src_expr)
        if parsed:
            base, field = parsed
            json_path = f"{base}.{field}"
        else:
            # derived: pay_date uses pay_time + cast date
            if cast == "date":
                json_path = src_expr
            else:
                json_path = src_expr
        out[c] = (json_path, src_expr)
    return out
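
# Mapping resolution sketch (the FACT_MAPPINGS entry here is hypothetical):
#   fm = {"billiards_dwd.dim_site": [("org_id", "siteprofile->>'org_id'", None)]}
#   build_column_mappings("dim_site", ["org_id", "scd2_version"], fm)
#   -> {"org_id": ("siteProfile.org_id", "siteprofile->>'org_id'"),
#       "scd2_version": (None, None)}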

def load_json_records(json_base: str, required_norm_keys: set[str]):
    json_path = DOC_DIR / f"{json_base}.json"
    data = json.loads(json_path.read_text(encoding="utf-8"))
    return _find_best_record_list(data, required_norm_keys)

def pick_example_from_record(record: dict, json_path: str | None):
    if not json_path:
        return None
    if json_path.startswith("siteProfile."):
        base_key = find_key_in_record(record, "siteProfile")
        base = record.get(base_key) if base_key else None
        if isinstance(base, dict):
            field = json_path.split(".", 1)[1]
            return base.get(field)
        return None
    # plain key
    key = find_key_in_record(record, json_path)
    if key:
        return record.get(key)
    # fallback: try match by normalized name
    nk = normalize_key(json_path)
    for k in record.keys():
        if normalize_key(k) == nk:
            return record.get(k)
    return None
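
# Example-value extraction above, illustrated (hypothetical record):
#   r = {"siteProfile": {"org_id": 42}, "tableName": "T01"}
#   pick_example_from_record(r, "siteProfile.org_id") -> 42
#   pick_example_from_record(r, "table_name")         -> "T01"  (matched via normalized key)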

def resolve_json_field_display(records: list, json_path: str | None, cast: str | None = None) -> str:
    if not json_path:
        return ""
    if json_path.startswith("siteProfile."):
        return json_path
    actual_key = None
    for r in records[:80]:
        if not isinstance(r, dict):
            continue
        k = find_key_in_record(r, json_path)
        if k:
            actual_key = k
            break
    base = actual_key or json_path
    if cast == "date":
        return f"{base}（派生：DATE({base})）"
    if cast == "boolean":
        return f"{base}（派生：BOOLEAN({base})）"
    if cast in {"numeric", "timestamptz"}:
        return f"{base}（派生：CAST({base} AS {cast})）"
    return base

def resolve_ods_source_field(records: list, src_expr: str | None, cast: str | None = None) -> str:
    if not src_expr:
        return ""
    parsed = _parse_json_extract(src_expr)
    if parsed:
        base, field = parsed
        # Normalize casing for display
        if normalize_key(base) == "siteprofile":
            base = "siteProfile"
        return f"{base}.{field}"
    # Plain field: prefer the actual JSON key name (casing/camelCase)
    actual = None
    for r in records[:80]:
        if not isinstance(r, dict):
            continue
        k = find_key_in_record(r, src_expr)
        if k:
            actual = k
            break
    base = actual or src_expr
    if cast == "date":
        return f"{base}（派生：DATE({base})）"
    if cast == "boolean":
        return f"{base}（派生：BOOLEAN({base})）"
    if cast in {"numeric", "timestamptz"}:
        return f"{base}（派生：CAST({base} AS {cast})）"
    return base
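
# Display form of the ODS source expression, illustrated (records hypothetical;
# the 派生 decoration follows the punctuation reconstructed above):
#   resolve_ods_source_field([], "siteprofile->>'org_id'")            -> "siteProfile.org_id"
#   resolve_ods_source_field([{"payTime": "…"}], "pay_time", "date")  -> "payTime（派生：DATE(payTime)）"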

def resolve_json_field_triplet(
    json_file: str | None,
    record_node: str | None,
    records: list,
    json_path: str | None,
    cast: str | None = None,
) -> str:
    if not json_file:
        json_file = ""
    node = record_node or "$"
    if not json_path:
        return f"{json_file} - 无 - 无"
    if json_path.startswith("siteProfile."):
        base_key = None
        field_key = None
        for r in records[:80]:
            if not isinstance(r, dict):
                continue
            base_key = find_key_in_record(r, "siteProfile")
            if base_key:
                base = r.get(base_key)
                if isinstance(base, dict):
                    raw_field = json_path.split(".", 1)[1]
                    # Try to match the sub-field's actual casing
                    if raw_field in base:
                        field_key = raw_field
                    else:
                        nk = normalize_key(raw_field)
                        for k in base.keys():
                            if normalize_key(k) == nk:
                                field_key = k
                                break
                break
        base_key = base_key or "siteProfile"
        field_key = field_key or json_path.split(".", 1)[1]
        node = f"{node}.{base_key}" if node else base_key
        field = field_key
    else:
        actual = None
        for r in records[:80]:
            if isinstance(r, dict):
                actual = find_key_in_record(r, json_path)
                if actual:
                    break
        field = actual or json_path
    if cast == "date":
        field = f"{field}（派生：DATE({field})）"
    elif cast == "boolean":
        field = f"{field}（派生：BOOLEAN({field})）"
    elif cast in {"numeric", "timestamptz"}:
        field = f"{field}（派生：CAST({field} AS {cast})）"
    return f"{json_file} - {node} - {field}"

def main():
    table_map, fact_mappings = parse_dwd_task_mappings(DWD_TASK_PATH)
    raw = SQL_PATH.read_text(encoding="utf-8", errors="replace")
    newline = "\r\n" if "\r\n" in raw else "\n"
    # strip all sql comments and existing COMMENT ON statements, incl. DO-block comment exec lines
    kept_lines = []
    for line in raw.splitlines(True):
        if line.lstrip().startswith("--"):
            continue
        if re.match(r"^\s*COMMENT ON\s+(TABLE|COLUMN)\s+", line, re.I):
            continue
        if "COMMENT ON COLUMN" in line or "COMMENT ON TABLE" in line:
            # remove legacy execute-format lines too
            continue
        kept_lines.append(line)
    clean = "".join(kept_lines)
    create_re = re.compile(
        r"(^\s*CREATE TABLE IF NOT EXISTS\s+(?P<table>[A-Za-z0-9_]+)\s*\([\s\S]*?\)\s*;)",
        re.M,
    )
    out_parts = []
    last = 0
    count_tables = 0
    for m in create_re.finditer(clean):
        stmt = m.group(1)
        table = m.group("table").lower()
        out_parts.append(clean[last : m.end()])
        cols = parse_columns_from_ddl(stmt)
        source_ods, json_base = get_source_info(table, table_map)
        # derive required keys
        required_norm = set()
        col_map = build_column_mappings(table, cols, fact_mappings)
        # cast map for json field display
        cast_map = {
            dwd_col.lower(): cast
            for dwd_col, _src_expr, cast in (fact_mappings.get(f"billiards_dwd.{table}") or [])
        }
        src_expr_map = {
            dwd_col.lower(): src_expr
            for dwd_col, src_expr, _cast in (fact_mappings.get(f"billiards_dwd.{table}") or [])
        }
        for c, (jp, _src) in col_map.items():
            if not jp:
                continue
            if jp.startswith("siteProfile."):
                required_norm.add(normalize_key("siteProfile"))
            else:
                required_norm.add(normalize_key(jp))
        records = []
        record_node = "$"
        if json_base and (DOC_DIR / f"{json_base}.json").exists():
            try:
                records, record_node = load_json_records(json_base, required_norm)
            except Exception:
                records = []
                record_node = "$"
        table_comment = build_table_comment(table, source_ods, json_base)
        comment_lines = [f"COMMENT ON TABLE billiards_dwd.{table} IS '{_escape_sql(table_comment)}';"]
        for c in cols:
            jp, _src = col_map.get(c, (None, None))
            if c in SCD_COLS:
                if c == "scd2_start_time":
                    ex = "2025-11-10T00:00:00+08:00"
                elif c == "scd2_end_time":
                    ex = "9999-12-31T00:00:00+00:00"
                elif c == "scd2_is_current":
                    ex = "1"
                else:
                    ex = "1"
                json_field = "无 - DWD慢变元数据 - 无"
                ods_src = "DWD慢变元数据"
            else:
                # pick example from first records
                ex_val = None
                for r in records[:80]:
                    v = pick_example_from_record(r, jp)
                    if v not in (None, ""):
                        ex_val = v
                        break
                ex = _format_example(ex_val)
                json_field = resolve_json_field_triplet(
                    f"{json_base}.json" if json_base else None,
                    record_node,
                    records,
                    jp,
                    cast_map.get(c),
                )
                src_expr = src_expr_map.get(c, jp)
                ods_src = resolve_ods_source_field(records, src_expr, cast_map.get(c))
            purpose = _first_sentence(_infer_purpose(table, c, jp), 140)
            func = purpose
            if "用于" not in func:
                func = "用于" + func.rstrip("。")
            if source_ods:
                ods_table_only = source_ods.split(".")[-1]
                ods_src_display = f"{ods_table_only} - {ods_src}"
            else:
                ods_src_display = f"无 - {ods_src}"
            comment = (
                f"【说明】{purpose}"
                f" 【示例】{ex}({func})。"
                f" 【ODS来源】{ods_src_display}"
                f" 【JSON字段】{json_field}"
            )
            comment_lines.append(
                f"COMMENT ON COLUMN billiards_dwd.{table}.{c} IS '{_escape_sql(comment)}';"
            )
        out_parts.append(newline + newline + (newline.join(comment_lines)) + newline + newline)
        last = m.end()
        count_tables += 1
    out_parts.append(clean[last:])
    result = "".join(out_parts)
    # collapse extra blank lines
    result = re.sub(r"(?:\r?\n){4,}", newline * 3, result)
    backup = SQL_PATH.with_suffix(SQL_PATH.suffix + ".bak")
    if not backup.exists():
        backup.write_text(raw, encoding="utf-8")
    SQL_PATH.write_text(result, encoding="utf-8")
    print(f"Rewrote comments for {count_tables} tables: {SQL_PATH}")

if __name__ == "__main__":
    main()