ETL 完成
This commit is contained in:
@@ -7,7 +7,7 @@ import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterable
|
||||
|
||||
from psycopg2.extras import Json
|
||||
from psycopg2.extras import Json, execute_values
|
||||
|
||||
from .base_task import BaseTask
|
||||
|
||||
@@ -75,7 +75,7 @@ class ManualIngestTask(BaseTask):
|
||||
return "MANUAL_INGEST"
|
||||
|
||||
def execute(self, cursor_data: dict | None = None) -> dict:
|
||||
"""从目录读取 JSON,按表定义批量入库。"""
|
||||
"""从目录读取 JSON,按表定义批量入库(按文件提交事务,避免长事务导致连接不稳定)。"""
|
||||
data_dir = (
|
||||
self.config.get("manual.data_dir")
|
||||
or self.config.get("pipeline.ingest_source_dir")
|
||||
@@ -87,9 +87,15 @@ class ManualIngestTask(BaseTask):
|
||||
|
||||
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
|
||||
|
||||
include_files_cfg = self.config.get("manual.include_files") or []
|
||||
include_files = {str(x).strip().lower() for x in include_files_cfg if str(x).strip()} if include_files_cfg else set()
|
||||
|
||||
for filename in sorted(os.listdir(data_dir)):
|
||||
if not filename.endswith(".json"):
|
||||
continue
|
||||
stem = os.path.splitext(filename)[0].lower()
|
||||
if include_files and stem not in include_files:
|
||||
continue
|
||||
filepath = os.path.join(data_dir, filename)
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as fh:
|
||||
@@ -113,22 +119,25 @@ class ManualIngestTask(BaseTask):
|
||||
|
||||
self.logger.info("Ingesting %s into %s", filename, target_table)
|
||||
try:
|
||||
inserted, updated = self._ingest_table(target_table, records, filename)
|
||||
inserted, updated, row_errors = self._ingest_table(target_table, records, filename)
|
||||
counts["inserted"] += inserted
|
||||
counts["updated"] += updated
|
||||
counts["fetched"] += len(records)
|
||||
counts["errors"] += row_errors
|
||||
# 每个文件一次提交:降低单次事务体积,避免长事务/连接异常导致整体回滚失败。
|
||||
self.db.commit()
|
||||
except Exception:
|
||||
counts["errors"] += 1
|
||||
self.logger.exception("Error processing %s", filename)
|
||||
self.db.rollback()
|
||||
try:
|
||||
self.db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
# 若连接已断开,后续文件无法继续,直接抛出让上层处理(重连/重跑)。
|
||||
if getattr(self.db.conn, "closed", 0):
|
||||
raise
|
||||
continue
|
||||
|
||||
try:
|
||||
self.db.commit()
|
||||
except Exception:
|
||||
self.db.rollback()
|
||||
raise
|
||||
|
||||
return {"status": "SUCCESS", "counts": counts}
|
||||
|
||||
def _match_by_filename(self, filename: str) -> str | None:
|
||||
@@ -211,8 +220,15 @@ class ManualIngestTask(BaseTask):
|
||||
self._table_columns_cache = cache
|
||||
return cols
|
||||
|
||||
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int]:
|
||||
"""构建 INSERT/ON CONFLICT 语句并批量执行。"""
|
||||
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
构建 INSERT/ON CONFLICT 语句并批量执行(优先向量化,小批次提交)。
|
||||
|
||||
设计目标:
|
||||
- 控制单条 SQL 体积(避免一次性 VALUES 过大导致服务端 backend 被 OOM/异常终止);
|
||||
- 发生异常时,可降级逐行并用 SAVEPOINT 跳过异常行;
|
||||
- 统计口径偏“尽量可跑通”,插入/更新计数为近似值(不强依赖 RETURNING)。
|
||||
"""
|
||||
spec = self.TABLE_SPECS.get(table)
|
||||
if not spec:
|
||||
raise ValueError(f"No table spec for {table}")
|
||||
@@ -229,15 +245,19 @@ class ManualIngestTask(BaseTask):
|
||||
pk_col_db = None
|
||||
if pk_col:
|
||||
pk_col_db = next((c for c in columns if c.lower() == pk_col.lower()), pk_col)
|
||||
pk_index = None
|
||||
if pk_col_db:
|
||||
try:
|
||||
pk_index = next(i for i, c in enumerate(columns_info) if c[0] == pk_col_db)
|
||||
except Exception:
|
||||
pk_index = None
|
||||
|
||||
placeholders = ", ".join(["%s"] * len(columns))
|
||||
col_list = ", ".join(f'"{c}"' for c in columns)
|
||||
sql = f'INSERT INTO {table} ({col_list}) VALUES ({placeholders})'
|
||||
sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
|
||||
if pk_col_db:
|
||||
update_cols = [c for c in columns if c != pk_col_db]
|
||||
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
|
||||
sql += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
|
||||
sql += " RETURNING (xmax = 0) AS inserted"
|
||||
sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
|
||||
|
||||
params = []
|
||||
now = datetime.now()
|
||||
@@ -288,19 +308,55 @@ class ManualIngestTask(BaseTask):
|
||||
params.append(tuple(row_vals))
|
||||
|
||||
if not params:
|
||||
return 0, 0
|
||||
return 0, 0, 0
|
||||
|
||||
# 先尝试向量化执行(速度快);若失败,再降级逐行并用 SAVEPOINT 跳过异常行。
|
||||
try:
|
||||
with self.db.conn.cursor() as cur:
|
||||
# 分批提交:降低单次事务/单次 SQL 压力,避免服务端异常中断连接。
|
||||
affected = 0
|
||||
chunk_size = int(self.config.get("manual.execute_values_page_size", 50) or 50)
|
||||
chunk_size = max(1, min(chunk_size, 500))
|
||||
for i in range(0, len(params), chunk_size):
|
||||
chunk = params[i : i + chunk_size]
|
||||
execute_values(cur, sql_prefix, chunk, page_size=len(chunk))
|
||||
if cur.rowcount is not None and cur.rowcount > 0:
|
||||
affected += int(cur.rowcount)
|
||||
# 这里无法精确拆分 inserted/updated(除非 RETURNING),按“受影响行数≈插入”近似返回。
|
||||
return int(affected), 0, 0
|
||||
except Exception as exc:
|
||||
self.logger.warning("批量入库失败,准备降级逐行处理:table=%s, err=%s", table, exc)
|
||||
try:
|
||||
self.db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inserted = 0
|
||||
updated = 0
|
||||
errors = 0
|
||||
with self.db.conn.cursor() as cur:
|
||||
for row in params:
|
||||
cur.execute(sql, row)
|
||||
flag = cur.fetchone()[0]
|
||||
if flag:
|
||||
cur.execute("SAVEPOINT sp_manual_ingest_row")
|
||||
try:
|
||||
cur.execute(sql_prefix.replace(" VALUES %s", f" VALUES ({', '.join(['%s'] * len(row))})"), row)
|
||||
inserted += 1
|
||||
else:
|
||||
updated += 1
|
||||
return inserted, updated
|
||||
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors += 1
|
||||
try:
|
||||
cur.execute("ROLLBACK TO SAVEPOINT sp_manual_ingest_row")
|
||||
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
|
||||
except Exception:
|
||||
pass
|
||||
pk_val = None
|
||||
if pk_index is not None:
|
||||
try:
|
||||
pk_val = row[pk_index]
|
||||
except Exception:
|
||||
pk_val = None
|
||||
self.logger.warning("跳过异常行:table=%s pk=%s err=%s", table, pk_val, exc)
|
||||
|
||||
return inserted, updated, errors
|
||||
|
||||
@staticmethod
|
||||
def _get_value_case_insensitive(record: dict, col: str | None):
|
||||
|
||||
Reference in New Issue
Block a user