68 lines
2.4 KiB
Python
68 lines
2.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Generic ODS loader that keeps raw payload + primary keys."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Iterable, Sequence
|
|
|
|
from ..base_loader import BaseLoader
|
|
|
|
|
|
class GenericODSLoader(BaseLoader):
    """Insert/update helper for ODS tables that share the same pattern.

    Each instance targets a single table. It builds one
    ``INSERT ... ON CONFLICT (...) DO UPDATE ... RETURNING`` statement at
    construction time and reuses it for every batch passed to
    :meth:`upsert_rows`.
    """

    def __init__(
        self,
        db_ops,
        table_name: str,
        columns: Sequence[str],
        conflict_columns: Sequence[str],
    ):
        """Configure the loader for one table.

        Args:
            db_ops: Database helper forwarded to ``BaseLoader.__init__``.
                ``upsert_rows`` later calls
                ``self.db.batch_upsert_with_returning``. Presumably
                ``BaseLoader`` exposes ``db_ops`` as ``self.db``; confirm in
                base_loader.
            table_name: Target table. It is interpolated verbatim into the
                SQL, so it must come from trusted configuration.
            columns: Columns to insert, in statement order. Also the keys read
                from each input row.
            conflict_columns: Column list for the ``ON CONFLICT`` target.

        Raises:
            ValueError: If ``conflict_columns`` or ``columns`` is empty.
        """
        super().__init__(db_ops)
        if not conflict_columns:
            raise ValueError("conflict_columns must not be empty for ODS loader")
        self.table_name = table_name
        self.columns = list(columns)
        if not self.columns:
            # An empty column list would render "INSERT INTO t () VALUES ()",
            # which is invalid SQL. Fail at configuration time instead of on
            # the first batch.
            raise ValueError("columns must not be empty for ODS loader")
        self.conflict_columns = list(conflict_columns)
        self._sql = self._build_sql()

    def upsert_rows(self, rows: Iterable[dict]) -> tuple[int, int, int]:
        """Insert/update the provided iterable of dictionaries.

        Args:
            rows: Dictionaries keyed by column name. The iterable is consumed
                once and materialized into a list.

        Returns:
            ``(inserted, updated, 0)``. The first two counts come straight
            from ``batch_upsert_with_returning``; the third element is always
            0 in this loader. An empty input returns ``(0, 0, 0)`` without
            touching the database.
        """
        rows = list(rows)
        if not rows:
            return (0, 0, 0)

        normalized = [self._normalize_row(row) for row in rows]
        inserted, updated = self.db.batch_upsert_with_returning(
            self._sql, normalized, page_size=self._batch_size()
        )
        return inserted, updated, 0

    def _build_sql(self) -> str:
        """Build the parameterized upsert statement for this table.

        Values are bound through pyformat placeholders (``%(col)s``).
        Table and column identifiers are interpolated as-is: they are
        neither quoted nor escaped.

        ``RETURNING (xmax = 0) AS inserted`` yields one boolean per affected
        row. It is true for a freshly inserted row and false for a row that
        went through the ``DO UPDATE`` branch.
        """
        col_list = ", ".join(self.columns)
        placeholders = ", ".join(f"%({col})s" for col in self.columns)
        conflict_clause = ", ".join(self.conflict_columns)

        conflict_set = set(self.conflict_columns)
        update_columns = [c for c in self.columns if c not in conflict_set]
        if not update_columns:
            # Every inserted column belongs to the conflict key, so there is
            # nothing to update. An empty SET list is a syntax error.
            # Assigning a conflict column to its EXCLUDED value is a no-op
            # (a conflict means the key values already match). It also keeps
            # RETURNING producing one row per input row, as in the normal case.
            update_columns = [self.conflict_columns[0]]
        set_clause = ", ".join(f"{col} = EXCLUDED.{col}" for col in update_columns)

        return (
            f"INSERT INTO {self.table_name} ({col_list}) "
            f"VALUES ({placeholders}) "
            f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {set_clause} "
            f"RETURNING (xmax = 0) AS inserted"
        )

    def _normalize_row(self, row: dict) -> dict:
        """Project ``row`` onto ``self.columns`` and prepare values for binding.

        - Keys not listed in ``self.columns`` are dropped. Configured columns
          missing from ``row`` become ``None``.
        - A ``payload`` value that is neither ``None`` nor already a ``str``
          is serialized with ``json.dumps``. ``ensure_ascii=False`` keeps
          non-ASCII text as-is. ``json.dumps`` raises ``TypeError`` for
          values it cannot serialize.
        - When ``fetched_at`` is a configured column and its value is
          ``None``, it is filled with the current timezone-aware UTC time.
        """
        normalized = {}
        for col in self.columns:
            value = row.get(col)
            if col == "payload" and value is not None and not isinstance(value, str):
                normalized[col] = json.dumps(value, ensure_ascii=False)
            else:
                normalized[col] = value

        if "fetched_at" in normalized and normalized["fetched_at"] is None:
            normalized["fetched_at"] = datetime.now(timezone.utc)

        return normalized
|