ETL complete

This commit is contained in:
Neo
2026-01-18 22:37:38 +08:00
parent 8da6cb6563
commit 7ca19a4a2c
159 changed files with 31225 additions and 467 deletions


@@ -7,7 +7,7 @@ import os
from datetime import datetime
from typing import Any, Iterable
from psycopg2.extras import Json
from psycopg2.extras import Json, execute_values
from .base_task import BaseTask
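
The newly imported `execute_values` builds one multi-row `VALUES` list per call instead of round-tripping per row. A minimal sketch of how it is typically used, against a hypothetical `demo_items` table (connection string, table, and data are illustrative, not from this commit):

```python
# Minimal sketch, not part of the commit: psycopg2's execute_values expands a
# single "VALUES %s" placeholder into a multi-row VALUES list.
import psycopg2
from psycopg2.extras import execute_values

conn = psycopg2.connect("dbname=demo")  # assumes a reachable local database
rows = [(1, "a"), (2, "b"), (3, "c")]
with conn.cursor() as cur:
    execute_values(
        cur,
        'INSERT INTO demo_items (id, name) VALUES %s'
        ' ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name',
        rows,
        page_size=50,  # max rows folded into one generated statement
    )
conn.commit()
```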
@@ -75,7 +75,7 @@ class ManualIngestTask(BaseTask):
return "MANUAL_INGEST"
def execute(self, cursor_data: dict | None = None) -> dict:
"""从目录读取 JSON按表定义批量入库。"""
"""从目录读取 JSON按表定义批量入库(按文件提交事务,避免长事务导致连接不稳定)"""
data_dir = (
self.config.get("manual.data_dir")
or self.config.get("pipeline.ingest_source_dir")
@@ -87,9 +87,15 @@ class ManualIngestTask(BaseTask):
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
include_files_cfg = self.config.get("manual.include_files") or []
include_files = {str(x).strip().lower() for x in include_files_cfg if str(x).strip()} if include_files_cfg else set()
for filename in sorted(os.listdir(data_dir)):
if not filename.endswith(".json"):
continue
stem = os.path.splitext(filename)[0].lower()
if include_files and stem not in include_files:
continue
filepath = os.path.join(data_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as fh:
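
The added `manual.include_files` option whitelists files by stem, compared case-insensitively against the filename without its `.json` extension. A small illustration of that matching rule, with made-up values:

```python
# Sketch of the include_files matching rule (values are hypothetical).
import os

include_files = {"users", "orders"}  # e.g. parsed from manual.include_files
for filename in ["Users.json", "orders.json", "audit_log.json", "notes.txt"]:
    if not filename.endswith(".json"):
        continue
    stem = os.path.splitext(filename)[0].lower()
    if include_files and stem not in include_files:
        continue
    print("would ingest:", filename)  # prints Users.json and orders.json
```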
@@ -113,22 +119,25 @@ class ManualIngestTask(BaseTask):
self.logger.info("Ingesting %s into %s", filename, target_table)
try:
inserted, updated = self._ingest_table(target_table, records, filename)
inserted, updated, row_errors = self._ingest_table(target_table, records, filename)
counts["inserted"] += inserted
counts["updated"] += updated
counts["fetched"] += len(records)
counts["errors"] += row_errors
# Commit once per file: keeps each transaction small and avoids a long transaction or connection failure forcing the whole run to roll back.
self.db.commit()
except Exception:
counts["errors"] += 1
self.logger.exception("Error processing %s", filename)
self.db.rollback()
try:
self.db.rollback()
except Exception:
pass
# If the connection is already closed, the remaining files cannot be processed; re-raise so the caller can handle it (reconnect / rerun).
if getattr(self.db.conn, "closed", 0):
raise
continue
try:
self.db.commit()
except Exception:
self.db.rollback()
raise
return {"status": "SUCCESS", "counts": counts}
def _match_by_filename(self, filename: str) -> str | None:
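
Committing once per file, swallowing a failed rollback, and re-raising only when the underlying connection is closed is the control flow the hunk above introduces. A standalone sketch of that pattern, with `db` standing in for the task's database wrapper (names are illustrative):

```python
# Sketch of the per-file commit/rollback pattern (illustrative names; `db`
# stands in for the task's database wrapper, `ingest_one` for the table load).
def ingest_all(db, files, ingest_one):
    errors = 0
    for path in files:
        try:
            ingest_one(path)
            db.commit()  # one commit per file keeps transactions small
        except Exception:
            errors += 1
            try:
                db.rollback()  # rollback itself can fail on a dead connection
            except Exception:
                pass
            if getattr(db.conn, "closed", 0):
                raise  # connection is gone; let the caller reconnect and rerun
    return errors
```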
@@ -211,8 +220,15 @@ class ManualIngestTask(BaseTask):
self._table_columns_cache = cache
return cols
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int]:
"""构建 INSERT/ON CONFLICT 语句并批量执行。"""
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int, int]:
"""
Build the INSERT/ON CONFLICT statement and execute it in bulk (vectorized first, in small chunks).
Design goals:
- Cap the size of each SQL statement (one oversized VALUES list can get the server backend OOM-killed or aborted);
- On failure, degrade to row-by-row execution and skip bad rows with SAVEPOINTs;
- Counting favors "keep the job running": insert/update counts are approximate (no hard dependency on RETURNING).
"""
spec = self.TABLE_SPECS.get(table)
if not spec:
raise ValueError(f"No table spec for {table}")
@@ -229,15 +245,19 @@ class ManualIngestTask(BaseTask):
pk_col_db = None
if pk_col:
pk_col_db = next((c for c in columns if c.lower() == pk_col.lower()), pk_col)
pk_index = None
if pk_col_db:
try:
pk_index = next(i for i, c in enumerate(columns_info) if c[0] == pk_col_db)
except Exception:
pk_index = None
placeholders = ", ".join(["%s"] * len(columns))
col_list = ", ".join(f'"{c}"' for c in columns)
sql = f'INSERT INTO {table} ({col_list}) VALUES ({placeholders})'
sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
if pk_col_db:
update_cols = [c for c in columns if c != pk_col_db]
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
sql += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
sql += " RETURNING (xmax = 0) AS inserted"
sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
params = []
now = datetime.now()
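
For a hypothetical table `manual_items(id, name, updated_at)` with primary key `id`, the template assembled above comes out roughly as below. Note that the previous version also appended `RETURNING (xmax = 0) AS inserted`, which lets PostgreSQL report whether an upserted row was a fresh insert (`xmax = 0`) or a conflict update; dropping it is why the new code can only approximate the insert/update split.

```python
# Sketch: the upsert template built above, for a hypothetical table
# manual_items(id, name, updated_at) with primary key "id".
columns = ["id", "name", "updated_at"]
pk = "id"
col_list = ", ".join(f'"{c}"' for c in columns)
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in columns if c != pk)
sql_prefix = (
    f'INSERT INTO manual_items ({col_list}) VALUES %s'
    f' ON CONFLICT ("{pk}") DO UPDATE SET {set_clause}'
)
print(sql_prefix)
# INSERT INTO manual_items ("id", "name", "updated_at") VALUES %s
#   ON CONFLICT ("id") DO UPDATE SET "name"=EXCLUDED."name", "updated_at"=EXCLUDED."updated_at"
# (printed as a single line; wrapped here for readability)
```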
@@ -288,19 +308,55 @@ class ManualIngestTask(BaseTask):
params.append(tuple(row_vals))
if not params:
return 0, 0
return 0, 0, 0
# Try vectorized execution first (fast); if it fails, fall back to row-by-row and skip bad rows with SAVEPOINTs.
try:
with self.db.conn.cursor() as cur:
# Chunked execution: reduces per-statement/per-transaction pressure and avoids the server aborting the connection.
affected = 0
chunk_size = int(self.config.get("manual.execute_values_page_size", 50) or 50)
chunk_size = max(1, min(chunk_size, 500))
for i in range(0, len(params), chunk_size):
chunk = params[i : i + chunk_size]
execute_values(cur, sql_prefix, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
affected += int(cur.rowcount)
# Without RETURNING we cannot split inserted vs. updated exactly, so report affected rows as inserts (approximate).
return int(affected), 0, 0
except Exception as exc:
self.logger.warning("批量入库失败准备降级逐行处理table=%s, err=%s", table, exc)
try:
self.db.rollback()
except Exception:
pass
inserted = 0
updated = 0
errors = 0
with self.db.conn.cursor() as cur:
for row in params:
cur.execute(sql, row)
flag = cur.fetchone()[0]
if flag:
inserted += 1
else:
updated += 1
return inserted, updated
cur.execute("SAVEPOINT sp_manual_ingest_row")
try:
cur.execute(sql_prefix.replace(" VALUES %s", f" VALUES ({', '.join(['%s'] * len(row))})"), row)
inserted += 1
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
except Exception as exc: # noqa: BLE001
errors += 1
try:
cur.execute("ROLLBACK TO SAVEPOINT sp_manual_ingest_row")
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
except Exception:
pass
pk_val = None
if pk_index is not None:
try:
pk_val = row[pk_index]
except Exception:
pk_val = None
self.logger.warning("跳过异常行table=%s pk=%s err=%s", table, pk_val, exc)
return inserted, updated, errors
@staticmethod
def _get_value_case_insensitive(record: dict, col: str | None):
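
The row-by-row fallback above wraps each statement in a savepoint so that one bad row aborts only its own sub-transaction rather than the whole batch. A compact sketch of that pattern, assuming an open psycopg2 cursor and a parameterized `sql` string (the function and its names are illustrative, not the task's actual method):

```python
# Sketch of the savepoint-per-row fallback (illustrative helper).
def insert_rows_skipping_errors(cur, sql, rows, logger):
    inserted, errors = 0, 0
    for row in rows:
        cur.execute("SAVEPOINT sp_row")
        try:
            cur.execute(sql, row)
            inserted += 1
            cur.execute("RELEASE SAVEPOINT sp_row")
        except Exception as exc:
            errors += 1
            # undo only this row's work; the rest of the transaction stays valid
            cur.execute("ROLLBACK TO SAVEPOINT sp_row")
            cur.execute("RELEASE SAVEPOINT sp_row")
            logger.warning("skipping bad row: %s", exc)
    return inserted, errors
```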