# -*- coding: utf-8 -*-
"""Generic ODS loader that keeps raw payload + primary keys."""
from __future__ import annotations

import json
from datetime import datetime, timezone
from typing import Iterable, Sequence

from ..base_loader import BaseLoader


class GenericODSLoader(BaseLoader):
    """Insert/update helper for ODS tables that share the same pattern.

    The upsert statement is built once at construction time from the table
    name, the column list and the conflict (unique-key) columns, and is then
    reused for every batch passed to :meth:`upsert_rows`.
    """

    def __init__(
        self,
        db_ops,
        table_name: str,
        columns: Sequence[str],
        conflict_columns: Sequence[str],
    ):
        """Store the table layout and pre-build the upsert statement.

        Args:
            db_ops: Database helper forwarded to ``BaseLoader.__init__``.
            table_name: Target table; interpolated verbatim into the SQL.
            columns: Columns to insert; also the keys read from each row.
            conflict_columns: Columns listed in the ``ON CONFLICT`` target.

        Raises:
            ValueError: If ``conflict_columns`` or ``columns`` is empty.
        """
        super().__init__(db_ops)
        if not conflict_columns:
            raise ValueError("conflict_columns must not be empty for ODS loader")
        if not columns:
            # Fail fast: an empty column list would only surface later as a
            # malformed "INSERT INTO t () VALUES ()" statement.
            raise ValueError("columns must not be empty for ODS loader")
        self.table_name = table_name
        self.columns = list(columns)
        self.conflict_columns = list(conflict_columns)
        self._sql = self._build_sql()

    def upsert_rows(self, rows: Iterable[dict]) -> tuple[int, int, int]:
        """Insert/update the provided iterable of dictionaries.

        Args:
            rows: Row dictionaries; keys not listed in ``self.columns`` are
                ignored and missing keys are sent as ``None``.

        Returns:
            ``(inserted, updated, 0)``. The first two values come from
            ``self.db.batch_upsert_with_returning``; the third is always 0
            here (presumably a skipped/error slot shared with other loaders
            -- confirm against ``BaseLoader``).
        """
        rows = list(rows)
        if not rows:
            return (0, 0, 0)

        normalized = [self._normalize_row(row) for row in rows]
        # NOTE(review): self.db and self._batch_size() are not defined in this
        # file; presumably provided by BaseLoader.
        inserted, updated = self.db.batch_upsert_with_returning(
            self._sql, normalized, page_size=self._batch_size()
        )
        return inserted, updated, 0

    def _build_sql(self) -> str:
        """Return the ``INSERT ... ON CONFLICT ... DO UPDATE`` statement.

        Every non-conflict column is overwritten from ``EXCLUDED`` on
        conflict. When all columns are conflict columns there is nothing to
        overwrite, so the conflict columns are assigned to themselves: this
        keeps the statement valid (an empty ``SET`` list is a syntax error)
        and still yields a ``RETURNING`` row for conflicting input.
        """
        col_list = ", ".join(self.columns)
        placeholders = ", ".join(f"%({col})s" for col in self.columns)
        conflict_clause = ", ".join(self.conflict_columns)

        conflict_keys = set(self.conflict_columns)
        update_columns = [c for c in self.columns if c not in conflict_keys]
        if not update_columns:
            # Key-only table: fall back to a no-op assignment so that
            # "DO UPDATE SET" is never followed by an empty list.
            update_columns = list(self.conflict_columns)
        set_clause = ", ".join(f"{col} = EXCLUDED.{col}" for col in update_columns)

        # "(xmax = 0) AS inserted" is the usual PostgreSQL idiom for telling a
        # fresh insert apart from an update of an existing row.
        return (
            f"INSERT INTO {self.table_name} ({col_list}) "
            f"VALUES ({placeholders}) "
            f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {set_clause} "
            f"RETURNING (xmax = 0) AS inserted"
        )

    def _normalize_row(self, row: dict) -> dict:
        """Project ``row`` onto ``self.columns`` and coerce special columns.

        * ``payload``: any non-``None``, non-string value is serialized with
          ``json.dumps(..., ensure_ascii=False)``.
        * ``fetched_at``: when the column is configured and the value is
          ``None`` (or missing), the current UTC time is filled in.
        """
        normalized = {}
        for col in self.columns:
            value = row.get(col)
            if col == "payload" and value is not None and not isinstance(value, str):
                normalized[col] = json.dumps(value, ensure_ascii=False)
            else:
                normalized[col] = value

        if "fetched_at" in normalized and normalized["fetched_at"] is None:
            normalized["fetched_at"] = datetime.now(timezone.utc)

        return normalized