Database: update data validation and write logic.
@@ -2,6 +2,7 @@
"""ODS ingestion tasks."""
from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -317,6 +318,7 @@ class BaseOdsTask(BaseTask):
        db_json_cols_lower = {
            c[0].lower() for c in cols_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
        }
        needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)

        col_names = [c[0] for c in cols_info]
        quoted_cols = ", ".join(f'"{c}"' for c in col_names)
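A hedged aside on the schema introspection above: later code unpacks cols_info rows as (column_name, data_type, udt_name), so whether a table opts into the hashing path is decided purely by the presence of a content_hash column. A minimal illustration (hypothetical table, not from this repo):

    # cols_info as (name, data_type, udt_name) tuples
    cols_info = [
        ("order_id", "text", "text"),
        ("payload", "jsonb", "jsonb"),
        ("content_hash", "text", "text"),
        ("fetched_at", "timestamp with time zone", "timestamptz"),
    ]
    needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)  # True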
@@ -330,6 +332,7 @@ class BaseOdsTask(BaseTask):

        params: list[tuple] = []
        skipped = 0
        merged_records: list[dict] = []

        root_site_profile = None
        if isinstance(response_payload, dict):
@@ -345,6 +348,7 @@ class BaseOdsTask(BaseTask):
                continue

            merged_rec = self._merge_record_layers(rec)
            merged_records.append({"raw": rec, "merged": merged_rec})
            if table in {"billiards_ods.recharge_settlements", "billiards_ods.settlement_records"}:
                site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile") or root_site_profile
                if isinstance(site_profile, dict):
@@ -363,9 +367,42 @@ class BaseOdsTask(BaseTask):
                    _fill_missing("siteid", [site_profile.get("siteId"), site_profile.get("id")])
                    _fill_missing("sitename", [site_profile.get("shop_name"), site_profile.get("siteName")])

        has_fetched_at = any(c[0].lower() == "fetched_at" for c in cols_info)
        business_keys = [c for c in pk_cols if str(c).lower() != "content_hash"]
        compare_latest = bool(needs_content_hash and has_fetched_at and business_keys)
        latest_compare_hash: dict[tuple[Any, ...], str | None] = {}
        if compare_latest:
            key_values: list[tuple[Any, ...]] = []
            for item in merged_records:
                merged_rec = item["merged"]
                key = tuple(self._get_value_case_insensitive(merged_rec, k) for k in business_keys)
                if any(v is None or v == "" for v in key):
                    continue
                key_values.append(key)

            if key_values:
                with self.db.conn.cursor() as cur:
                    latest_payloads = self._fetch_latest_payloads(cur, table, business_keys, key_values)
                for key, payload in latest_payloads.items():
                    latest_compare_hash[key] = self._compute_compare_hash_from_payload(payload)

        for item in merged_records:
            rec = item["raw"]
            merged_rec = item["merged"]

            content_hash = None
            compare_hash = None
            if needs_content_hash:
                compare_hash = self._compute_content_hash(merged_rec, include_fetched_at=False)
                hash_record = dict(merged_rec)
                hash_record["fetched_at"] = now
                content_hash = self._compute_content_hash(hash_record, include_fetched_at=True)

            if pk_cols:
                missing_pk = False
                for pk in pk_cols:
                    if str(pk).lower() == "content_hash":
                        continue
                    pk_val = self._get_value_case_insensitive(merged_rec, pk)
                    if pk_val is None or pk_val == "":
                        missing_pk = True
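Why two hashes: compare_hash excludes fetched_at, so re-fetching an unchanged payload produces the same digest and the row can be skipped against the latest stored version, while content_hash folds fetched_at in, so a changed payload lands as a fresh row with a distinct key part. A minimal standalone sketch of the skip rule (hypothetical keys and digests, not the task's real data):

    # latest_compare_hash as prefetched from the DB, keyed by business key
    latest = {("A1",): "h-old"}
    incoming = [(("A1",), "h-old"), (("A1",), "h-new"), (("B2",), "h-x")]
    kept = [(k, h) for k, h in incoming if latest.get(k) != h]
    # -> [(("A1",), "h-new"), (("B2",), "h-x")]; the unchanged record is skipped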
@@ -374,6 +411,16 @@ class BaseOdsTask(BaseTask):
                    skipped += 1
                    continue

            if compare_latest and compare_hash is not None:
                key = tuple(self._get_value_case_insensitive(merged_rec, k) for k in business_keys)
                if any(v is None or v == "" for v in key):
                    skipped += 1
                    continue
                last_hash = latest_compare_hash.get(key)
                if last_hash is not None and last_hash == compare_hash:
                    skipped += 1
                    continue

            row_vals: list[Any] = []
            for (col_name, data_type, _udt) in cols_info:
                col_lower = col_name.lower()
@@ -389,6 +436,9 @@ class BaseOdsTask(BaseTask):
                if col_lower == "fetched_at":
                    row_vals.append(now)
                    continue
                if col_lower == "content_hash":
                    row_vals.append(content_hash)
                    continue

                value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
                if col_lower in db_json_cols_lower:
@@ -472,6 +522,93 @@ class BaseOdsTask(BaseTask):
            return resolver(spec.endpoint)
        return None

    @staticmethod
    def _hash_default(value):
        if isinstance(value, datetime):
            return value.isoformat()
        return str(value)

    @classmethod
    def _sanitize_record_for_hash(cls, record: dict, *, include_fetched_at: bool) -> dict:
        exclude = {
            "data",
            "payload",
            "source_file",
            "source_endpoint",
            "content_hash",
            "record_index",
        }
        if not include_fetched_at:
            exclude.add("fetched_at")

        def _strip(value):
            if isinstance(value, dict):
                cleaned = {}
                for k, v in value.items():
                    if isinstance(k, str) and k.lower() in exclude:
                        continue
                    cleaned[k] = _strip(v)
                return cleaned
            if isinstance(value, list):
                return [_strip(v) for v in value]
            return value

        return _strip(record or {})

    @classmethod
    def _compute_content_hash(cls, record: dict, *, include_fetched_at: bool) -> str:
        cleaned = cls._sanitize_record_for_hash(record, include_fetched_at=include_fetched_at)
        payload = json.dumps(
            cleaned,
            ensure_ascii=False,
            sort_keys=True,
            separators=(",", ":"),
            default=cls._hash_default,
        )
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

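The sort_keys and fixed separators arguments canonicalize the serialization, so dict key order can never change the digest; _hash_default keeps datetimes and other non-JSON types hashable. A self-contained demonstration of the same idea (illustrative, not the task code itself):

    import hashlib
    import json

    def canonical_hash(record: dict) -> str:
        # One byte stream per logical record: sorted keys, fixed separators.
        payload = json.dumps(record, ensure_ascii=False, sort_keys=True,
                             separators=(",", ":"), default=str)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    assert canonical_hash({"a": 1, "b": 2}) == canonical_hash({"b": 2, "a": 1})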
    @staticmethod
    def _compute_compare_hash_from_payload(payload: Any) -> str | None:
        if payload is None:
            return None
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except Exception:
                return None
        if not isinstance(payload, dict):
            return None
        merged = BaseOdsTask._merge_record_layers(payload)
        return BaseOdsTask._compute_content_hash(merged, include_fetched_at=False)

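A hedged reading of the branches above: the stored payload can come back as a dict (jsonb decoded by the driver) or as raw JSON text, and both shapes are normalized through the same merge-and-hash path, so the result stays comparable with the compare_hash computed for incoming records:

    # Illustrative expectation, assuming _merge_record_layers({"a": 1}) == {"a": 1}:
    # _compute_compare_hash_from_payload('{"a": 1}')
    #   == _compute_compare_hash_from_payload({"a": 1})
    #   == _compute_content_hash({"a": 1}, include_fetched_at=False)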
    @staticmethod
    def _fetch_latest_payloads(cur, table: str, business_keys: Sequence[str], key_values: Sequence[tuple]) -> dict:
        if not business_keys or not key_values:
            return {}
        keys_sql = ", ".join(f'"{k}"' for k in business_keys)
        sql = (
            f"WITH keys({keys_sql}) AS (VALUES %s) "
            f"SELECT DISTINCT ON ({keys_sql}) {keys_sql}, payload "
            f"FROM {table} t JOIN keys k USING ({keys_sql}) "
            f"ORDER BY {keys_sql}, fetched_at DESC NULLS LAST"
        )
        unique_keys = list({tuple(k) for k in key_values})
        execute_values(cur, sql, unique_keys, page_size=500)
        rows = cur.fetchall() or []
        result = {}
        if rows and isinstance(rows[0], dict):
            for r in rows:
                key = tuple(r[k] for k in business_keys)
                result[key] = r.get("payload")
            return result

        key_len = len(business_keys)
        for r in rows:
            key = tuple(r[:key_len])
            payload = r[key_len] if len(r) > key_len else None
            result[key] = payload
        return result
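execute_values is the psycopg2.extras batch helper (presumably imported in context lines this diff does not show); it replaces the single %s with the batch of key tuples. For business_keys = ["order_id"] and two keys, the statement sent is roughly (illustrative expansion, hypothetical table):

    # WITH keys("order_id") AS (VALUES ('A1'), ('A2'))
    # SELECT DISTINCT ON ("order_id") "order_id", payload
    # FROM billiards_ods.settlement_records t JOIN keys k USING ("order_id")
    # ORDER BY "order_id", fetched_at DESC NULLS LAST

DISTINCT ON with ORDER BY ... fetched_at DESC NULLS LAST keeps exactly the newest payload per business key, which is all the compare-hash map needs; the two branches over rows handle both dict-style and tuple-style cursors.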

def _int_col(name: str, *sources: str, required: bool = False) -> ColumnSpec:
    return ColumnSpec(