Database: update data-validation and write logic.

author Neo
date 2026-02-01 03:46:16 +08:00
parent 9948000b71
commit 076f5755ca
128 changed files with 494310 additions and 2819 deletions
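
The diff below extends the ODS gap-check script in two ways: the database connection is held in a small state dict so it can be re-established after OperationalError/InterfaceError, and the check can optionally compare a per-record content hash against the stored content_hash column, reported as "mismatch" rows. The hash itself comes from BaseOdsTask._compute_content_hash, which is not part of this excerpt; the sketch below assumes a sorted-key JSON plus SHA-256 canonicalization purely for illustration.

# Illustrative only: the real hash is BaseOdsTask._compute_content_hash (not shown here);
# sorted-key JSON + SHA-256 is an assumed stand-in.
import hashlib
import json

def compute_content_hash(merged_record: dict) -> str:
    canonical = json.dumps(merged_record, ensure_ascii=False, sort_keys=True, default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

api_row = {"ticketId": "42", "amount": "19.90"}
stored_hash = "..."  # content_hash column value for the same primary key
if compute_content_hash(api_row) != stored_hash:
    print("mismatch: the row exists but its stored content differs from the API payload")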


@@ -22,6 +22,7 @@ from typing import Iterable, Sequence
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from psycopg2 import InterfaceError, OperationalError
from psycopg2.extras import execute_values
PROJECT_ROOT = Path(__file__).resolve().parents[1]
@@ -32,8 +33,14 @@ from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods_tasks import ENABLED_ODS_CODES, ODS_TASK_SPECS
from tasks.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
get_value_case_insensitive,
merge_record_layers,
normalize_pk_value,
pk_tuple_from_record,
)
from utils.windowing import split_window
DEFAULT_START = "2025-07-01"
@@ -74,38 +81,7 @@ def _iter_windows(start: datetime, end: datetime, window_size: timedelta) -> Ite
def _merge_record_layers(record: dict) -> dict:
merged = record
data_part = merged.get("data")
while isinstance(data_part, dict):
merged = {**data_part, **merged}
data_part = data_part.get("data")
settle_inner = merged.get("settleList")
if isinstance(settle_inner, dict):
merged = {**settle_inner, **merged}
return merged
def _get_value_case_insensitive(record: dict | None, col: str | None):
if record is None or col is None:
return None
if col in record:
return record.get(col)
col_lower = col.lower()
for k, v in record.items():
if isinstance(k, str) and k.lower() == col_lower:
return v
return None
def _normalize_pk_value(value):
if value is None:
return None
if isinstance(value, str) and value.isdigit():
try:
return int(value)
except Exception:
return value
return value
return merge_record_layers(record)
def _chunked(seq: Sequence, size: int) -> Iterable[Sequence]:
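
In this hunk the local helpers _merge_record_layers, _get_value_case_insensitive and _normalize_pk_value lose their bodies; _merge_record_layers is reduced to return merge_record_layers(record), and the logic now comes from the shared utils.ods_record_utils module imported above. That module is not part of this excerpt, but judging from the removed code it presumably looks roughly like this:

# Presumed content of utils/ods_record_utils.py, reconstructed from the helpers
# removed above; the actual module is not shown in this diff.

def merge_record_layers(record: dict) -> dict:
    # Flatten nested "data" payloads (and an inner "settleList" dict) into one mapping,
    # letting outer keys win over inner ones.
    merged = record
    data_part = merged.get("data")
    while isinstance(data_part, dict):
        merged = {**data_part, **merged}
        data_part = data_part.get("data")
    settle_inner = merged.get("settleList")
    if isinstance(settle_inner, dict):
        merged = {**settle_inner, **merged}
    return merged

def get_value_case_insensitive(record: dict | None, col: str | None):
    # Exact key lookup first, then a case-insensitive scan.
    if record is None or col is None:
        return None
    if col in record:
        return record.get(col)
    col_lower = col.lower()
    for k, v in record.items():
        if isinstance(k, str) and k.lower() == col_lower:
            return v
    return None

def normalize_pk_value(value):
    # Coerce digit-only strings to int so API and DB key values compare equal.
    if value is None:
        return None
    if isinstance(value, str) and value.isdigit():
        try:
            return int(value)
        except Exception:
            return value
    return value

def pk_tuple_from_record(record: dict, pk_cols) -> tuple | None:
    # Primary-key tuple built from the merged record; None if any key part is empty.
    merged = merge_record_layers(record)
    values = []
    for col in pk_cols:
        val = normalize_pk_value(get_value_case_insensitive(merged, col))
        if val is None or val == "":
            return None
        values.append(val)
    return tuple(values)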
@@ -133,7 +109,24 @@ def _get_table_pk_columns(conn, table: str) -> list[str]:
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [r[0] for r in cur.fetchall()]
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _table_has_column(conn, table: str, column: str) -> bool:
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT 1
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s AND column_name = %s
LIMIT 1
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name, column))
return cur.fetchone() is not None
def _fetch_existing_pk_set(conn, table: str, pk_cols: Sequence[str], pk_values: list[tuple], chunk_size: int) -> set[tuple]:
@@ -155,6 +148,54 @@ def _fetch_existing_pk_set(conn, table: str, pk_cols: Sequence[str], pk_values:
return existing
def _fetch_existing_pk_hash_set(
conn, table: str, pk_cols: Sequence[str], pk_hash_values: list[tuple], chunk_size: int
) -> set[tuple]:
if not pk_hash_values:
return set()
select_cols = ", ".join([*(f't.\"{c}\"' for c in pk_cols), 't.\"content_hash\"'])
value_cols = ", ".join([*(f'\"{c}\"' for c in pk_cols), '\"content_hash\"'])
join_cond = " AND ".join([*(f't.\"{c}\" = v.\"{c}\"' for c in pk_cols), 't.\"content_hash\" = v.\"content_hash\"'])
sql = (
f"SELECT {select_cols} FROM {table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: set[tuple] = set()
with conn.cursor() as cur:
for chunk in _chunked(pk_hash_values, chunk_size):
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
return existing
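
_fetch_existing_pk_hash_set checks a whole batch of (pk..., content_hash) tuples in one SELECT by joining the table against a VALUES list that psycopg2's execute_values expands in place of the lone %s. A standalone sketch of that lookup pattern follows; the connection string, table and column names are made up for illustration and are not taken from the project schema.

# Illustrative sketch of the execute_values + VALUES-join lookup used above.
import psycopg2
from psycopg2.extras import execute_values

conn = psycopg2.connect("dbname=demo")  # placeholder DSN
pairs = [(1, "a1b2"), (2, "c3d4")]      # (pk, content_hash) tuples from the API side

sql = (
    'SELECT t."id", t."content_hash" FROM demo_table t '
    'JOIN (VALUES %s) AS v("id", "content_hash") '
    'ON t."id" = v."id" AND t."content_hash" = v."content_hash"'
)
with conn.cursor() as cur:
    # execute_values expands the single %s into the full VALUES list, so the whole
    # batch is checked in one round trip.
    execute_values(cur, sql, pairs, page_size=len(pairs))
    present = {tuple(r) for r in cur.fetchall()}

# Pairs not in present are either missing rows or rows whose stored content differs.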
def _init_db_state(cfg: AppConfig) -> dict:
db_conn = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
try:
db_conn.conn.rollback()
except Exception:
pass
db_conn.conn.autocommit = True
return {"db": db_conn, "conn": db_conn.conn}
def _reconnect_db(db_state: dict, cfg: AppConfig, logger: logging.Logger):
try:
db_state.get("db").close()
except Exception:
pass
db_state.update(_init_db_state(cfg))
logger.warning("DB connection reset/reconnected")
return db_state["conn"]
def _ensure_db_conn(db_state: dict, cfg: AppConfig, logger: logging.Logger):
conn = db_state.get("conn")
if conn is None or getattr(conn, "closed", 0):
return _reconnect_db(db_state, cfg, logger)
return conn
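
_init_db_state, _reconnect_db and _ensure_db_conn keep the DatabaseConnection in a mutable dict so callers can transparently swap in a fresh connection after an OperationalError or InterfaceError. The diff repeats the same try/except/retry around each query; a generic wrapper along these lines (hypothetical, not part of this commit) would capture the pattern once:

# Hypothetical helper, not in this commit: generalizes the repeated
# try/except (OperationalError, InterfaceError) retry blocks seen throughout this diff.
# It reuses _ensure_db_conn/_reconnect_db defined above and the psycopg2 exceptions
# imported at the top of the script.
def _with_db_retry(db_state: dict, cfg, logger, fn):
    conn = _ensure_db_conn(db_state, cfg, logger)
    try:
        return fn(conn)
    except (OperationalError, InterfaceError):
        conn = _reconnect_db(db_state, cfg, logger)
        return fn(conn)

# Usage, equivalent to the explicit blocks below:
# pk_cols = _with_db_retry(db_state, cfg, logger,
#                          lambda c: _get_table_pk_columns(c, spec.table_name))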
def _merge_common_params(cfg: AppConfig, task_code: str, base: dict) -> dict:
merged: dict = {}
common = cfg.get("api.params", {}) or {}
@@ -182,19 +223,22 @@ def _build_params(cfg: AppConfig, spec, store_id: int, window_start: datetime |
return _merge_common_params(cfg, spec.code, base)
def _pk_tuple_from_record(record: dict, pk_cols: Sequence[str]) -> tuple | None:
merged = _merge_record_layers(record)
def _pk_tuple_from_merged(merged: dict, pk_cols: Sequence[str]) -> tuple | None:
values = []
for col in pk_cols:
val = _normalize_pk_value(_get_value_case_insensitive(merged, col))
val = normalize_pk_value(get_value_case_insensitive(merged, col))
if val is None or val == "":
return None
values.append(val)
return tuple(values)
def _pk_tuple_from_record(record: dict, pk_cols: Sequence[str]) -> tuple | None:
return pk_tuple_from_record(record, pk_cols)
def _pk_tuple_from_ticket_candidate(value) -> tuple | None:
val = _normalize_pk_value(value)
val = normalize_pk_value(value)
if val is None or val == "":
return None
return (val,)
@@ -204,10 +248,17 @@ def _format_missing_sample(pk_cols: Sequence[str], pk_tuple: tuple) -> dict:
return {col: pk_tuple[idx] for idx, col in enumerate(pk_cols)}
def _format_mismatch_sample(pk_cols: Sequence[str], pk_tuple: tuple, content_hash: str | None) -> dict:
sample = _format_missing_sample(pk_cols, pk_tuple)
if content_hash:
sample["content_hash"] = content_hash
return sample
def _check_spec(
*,
client: APIClient,
db_conn,
db_state: dict,
cfg: AppConfig,
tz: ZoneInfo,
logger: logging.Logger,
@@ -219,6 +270,8 @@ def _check_spec(
page_size: int,
chunk_size: int,
sample_limit: int,
compare_content: bool,
content_sample_limit: int,
sleep_per_window: float,
sleep_per_page: float,
) -> dict:
@@ -231,19 +284,34 @@ def _check_spec(
"records_with_pk": 0,
"missing": 0,
"missing_samples": [],
"mismatch": 0,
"mismatch_samples": [],
"pages": 0,
"skipped_missing_pk": 0,
"errors": 0,
"error_detail": None,
}
pk_cols = _get_table_pk_columns(db_conn, spec.table_name)
db_conn = _ensure_db_conn(db_state, cfg, logger)
try:
pk_cols = _get_table_pk_columns(db_conn, spec.table_name)
except (OperationalError, InterfaceError):
db_conn = _reconnect_db(db_state, cfg, logger)
pk_cols = _get_table_pk_columns(db_conn, spec.table_name)
result["pk_columns"] = pk_cols
if not pk_cols:
result["errors"] = 1
result["error_detail"] = "no primary key columns found"
return result
try:
has_content_hash = bool(compare_content and _table_has_column(db_conn, spec.table_name, "content_hash"))
except (OperationalError, InterfaceError):
db_conn = _reconnect_db(db_state, cfg, logger)
has_content_hash = bool(compare_content and _table_has_column(db_conn, spec.table_name, "content_hash"))
result["compare_content"] = bool(compare_content)
result["content_hash_supported"] = has_content_hash
if spec.requires_window and spec.time_fields:
if not start or not end:
result["errors"] = 1
@@ -293,24 +361,33 @@ def _check_spec(
result["pages"] += 1
result["records"] += len(records)
pk_tuples: list[tuple] = []
pk_hash_tuples: list[tuple] = []
for rec in records:
if not isinstance(rec, dict):
result["skipped_missing_pk"] += 1
window_skipped += 1
continue
pk_tuple = _pk_tuple_from_record(rec, pk_cols)
merged = _merge_record_layers(rec)
pk_tuple = _pk_tuple_from_merged(merged, pk_cols)
if not pk_tuple:
result["skipped_missing_pk"] += 1
window_skipped += 1
continue
pk_tuples.append(pk_tuple)
if has_content_hash:
content_hash = BaseOdsTask._compute_content_hash(merged, include_fetched_at=False)
pk_hash_tuples.append((*pk_tuple, content_hash))
if not pk_tuples:
continue
result["records_with_pk"] += len(pk_tuples)
pk_unique = list(dict.fromkeys(pk_tuples))
existing = _fetch_existing_pk_set(db_conn, spec.table_name, pk_cols, pk_unique, chunk_size)
try:
existing = _fetch_existing_pk_set(db_conn, spec.table_name, pk_cols, pk_unique, chunk_size)
except (OperationalError, InterfaceError):
db_conn = _reconnect_db(db_state, cfg, logger)
existing = _fetch_existing_pk_set(db_conn, spec.table_name, pk_cols, pk_unique, chunk_size)
for pk_tuple in pk_unique:
if pk_tuple in existing:
continue
@@ -321,6 +398,29 @@ def _check_spec(
window_missing += 1
if len(result["missing_samples"]) < sample_limit:
result["missing_samples"].append(_format_missing_sample(pk_cols, pk_tuple))
if has_content_hash and pk_hash_tuples:
pk_hash_unique = list(dict.fromkeys(pk_hash_tuples))
try:
existing_hash = _fetch_existing_pk_hash_set(
db_conn, spec.table_name, pk_cols, pk_hash_unique, chunk_size
)
except (OperationalError, InterfaceError):
db_conn = _reconnect_db(db_state, cfg, logger)
existing_hash = _fetch_existing_pk_hash_set(
db_conn, spec.table_name, pk_cols, pk_hash_unique, chunk_size
)
for pk_hash_tuple in pk_hash_unique:
pk_tuple = pk_hash_tuple[:-1]
if pk_tuple not in existing:
continue
if pk_hash_tuple in existing_hash:
continue
result["mismatch"] += 1
if len(result["mismatch_samples"]) < content_sample_limit:
result["mismatch_samples"].append(
_format_mismatch_sample(pk_cols, pk_tuple, pk_hash_tuple[-1])
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"PAGE task=%s idx=%s page=%s records=%s missing=%s skipped=%s",
@@ -369,7 +469,7 @@ def _check_spec(
def _check_settlement_tickets(
*,
client: APIClient,
db_conn,
db_state: dict,
cfg: AppConfig,
tz: ZoneInfo,
logger: logging.Logger,
@@ -380,11 +480,18 @@ def _check_settlement_tickets(
page_size: int,
chunk_size: int,
sample_limit: int,
compare_content: bool,
content_sample_limit: int,
sleep_per_window: float,
sleep_per_page: float,
) -> dict:
table_name = "billiards_ods.settlement_ticket_details"
pk_cols = _get_table_pk_columns(db_conn, table_name)
db_conn = _ensure_db_conn(db_state, cfg, logger)
try:
pk_cols = _get_table_pk_columns(db_conn, table_name)
except (OperationalError, InterfaceError):
db_conn = _reconnect_db(db_state, cfg, logger)
pk_cols = _get_table_pk_columns(db_conn, table_name)
result = {
"task_code": "ODS_SETTLEMENT_TICKET",
"table": table_name,
@@ -394,6 +501,8 @@ def _check_settlement_tickets(
"records_with_pk": 0,
"missing": 0,
"missing_samples": [],
"mismatch": 0,
"mismatch_samples": [],
"pages": 0,
"skipped_missing_pk": 0,
"errors": 0,
@@ -476,7 +585,11 @@ def _check_settlement_tickets(
result["records_with_pk"] += len(pk_tuples)
pk_unique = list(dict.fromkeys(pk_tuples))
existing = _fetch_existing_pk_set(db_conn, table_name, pk_cols, pk_unique, chunk_size)
try:
existing = _fetch_existing_pk_set(db_conn, table_name, pk_cols, pk_unique, chunk_size)
except (OperationalError, InterfaceError):
db_conn = _reconnect_db(db_state, cfg, logger)
existing = _fetch_existing_pk_set(db_conn, table_name, pk_cols, pk_unique, chunk_size)
for pk_tuple in pk_unique:
if pk_tuple in existing:
continue
@@ -585,6 +698,8 @@ def run_gap_check(
cutoff_overlap_hours: int,
allow_small_window: bool,
logger: logging.Logger,
compare_content: bool = False,
content_sample_limit: int | None = None,
window_split_unit: str | None = None,
window_compensation_hours: int | None = None,
) -> dict:
@@ -668,6 +783,9 @@ def run_gap_check(
if windows:
start, end = windows[0][0], windows[-1][1]
if content_sample_limit is None:
content_sample_limit = sample_limit
logger.info(
"START range=%s~%s window_days=%s window_hours=%s split_unit=%s comp_hours=%s page_size=%s chunk_size=%s",
start.isoformat() if isinstance(start, datetime) else None,
@@ -690,12 +808,7 @@ def run_gap_check(
headers_extra=cfg["api"].get("headers_extra") or {},
)
db_conn = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
try:
db_conn.conn.rollback()
except Exception:
pass
db_conn.conn.autocommit = True
db_state = _init_db_state(cfg)
try:
task_filter = {t.strip().upper() for t in (task_codes or "").split(",") if t.strip()}
specs = [s for s in ODS_TASK_SPECS if s.code in ENABLED_ODS_CODES]
@@ -708,7 +821,7 @@ def run_gap_check(
continue
result = _check_spec(
client=client,
db_conn=db_conn.conn,
db_state=db_state,
cfg=cfg,
tz=tz,
logger=logger,
@@ -720,6 +833,8 @@ def run_gap_check(
page_size=page_size,
chunk_size=chunk_size,
sample_limit=sample_limit,
compare_content=compare_content,
content_sample_limit=content_sample_limit,
sleep_per_window=sleep_per_window,
sleep_per_page=sleep_per_page,
)
@@ -735,7 +850,7 @@ def run_gap_check(
if (not task_filter) or ("ODS_SETTLEMENT_TICKET" in task_filter):
ticket_result = _check_settlement_tickets(
client=client,
db_conn=db_conn.conn,
db_state=db_state,
cfg=cfg,
tz=tz,
logger=logger,
@@ -746,6 +861,8 @@ def run_gap_check(
page_size=page_size,
chunk_size=chunk_size,
sample_limit=sample_limit,
compare_content=compare_content,
content_sample_limit=content_sample_limit,
sleep_per_window=sleep_per_window,
sleep_per_page=sleep_per_page,
)
@@ -759,6 +876,7 @@ def run_gap_check(
)
total_missing = sum(int(r.get("missing") or 0) for r in results)
total_mismatch = sum(int(r.get("mismatch") or 0) for r in results)
total_errors = sum(int(r.get("errors") or 0) for r in results)
payload = {
@@ -772,16 +890,22 @@ def run_gap_check(
"page_size": page_size,
"chunk_size": chunk_size,
"sample_limit": sample_limit,
"compare_content": compare_content,
"content_sample_limit": content_sample_limit,
"store_id": store_id,
"base_url": cfg.get("api.base_url"),
"results": results,
"total_missing": total_missing,
"total_mismatch": total_mismatch,
"total_errors": total_errors,
"generated_at": datetime.now(tz).isoformat(),
}
return payload
finally:
db_conn.close()
try:
db_state.get("db").close()
except Exception:
pass
def main() -> int:
@@ -796,6 +920,13 @@ def main() -> int:
ap.add_argument("--page-size", type=int, default=200, help="API page size (default: 200)")
ap.add_argument("--chunk-size", type=int, default=500, help="DB query chunk size (default: 500)")
ap.add_argument("--sample-limit", type=int, default=50, help="max missing PK samples per table")
ap.add_argument("--compare-content", action="store_true", help="compare record content hash (mismatch detection)")
ap.add_argument(
"--content-sample-limit",
type=int,
default=None,
help="max mismatch samples per table (default: same as --sample-limit)",
)
ap.add_argument("--sleep-per-window-seconds", type=float, default=0, help="sleep seconds after each window")
ap.add_argument("--sleep-per-page-seconds", type=float, default=0, help="sleep seconds after each page")
ap.add_argument("--task-codes", default="", help="comma-separated task codes to check (optional)")
@@ -847,6 +978,8 @@ def main() -> int:
cutoff_overlap_hours=args.cutoff_overlap_hours,
allow_small_window=args.allow_small_window,
logger=logger,
compare_content=args.compare_content,
content_sample_limit=args.content_sample_limit,
window_split_unit=args.window_split_unit or None,
window_compensation_hours=args.window_compensation_hours,
)
@@ -862,8 +995,9 @@ def main() -> int:
out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
logger.info("REPORT_WRITTEN path=%s", out_path)
logger.info(
"SUMMARY missing=%s errors=%s",
"SUMMARY missing=%s mismatch=%s errors=%s",
payload.get("total_missing"),
payload.get("total_mismatch"),
payload.get("total_errors"),
)
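
With the new flags wired through, the JSON report gains total_mismatch plus per-table mismatch and mismatch_samples fields. A small sketch of consuming the report (the filename is illustrative; the field names come from this diff):

# The report filename below is an assumption; the field names appear in this commit.
import json
from pathlib import Path

payload = json.loads(Path("ods_gap_report.json").read_text(encoding="utf-8"))
print("missing:", payload["total_missing"])
print("mismatch:", payload["total_mismatch"])  # new in this commit
print("errors:", payload["total_errors"])
for r in payload["results"]:
    if r.get("mismatch"):
        print(r.get("task_code"), r.get("table"), r.get("mismatch_samples", [])[:3])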