Files
feiqiu-ETL/etl_billiards/database/operations.py
2025-11-20 01:27:33 +08:00

100 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""数据库批量操作"""
import psycopg2.extras
import re
class DatabaseOperations:
"""数据库批量操作封装"""
def __init__(self, connection):
self._connection = connection
self.conn = connection.conn
def batch_execute(self, sql: str, rows: list, page_size: int = 1000):
"""批量执行SQL"""
if not rows:
return
with self.conn.cursor() as c:
psycopg2.extras.execute_batch(c, sql, rows, page_size=page_size)
def batch_upsert_with_returning(self, sql: str, rows: list,
page_size: int = 1000) -> tuple:
"""批量UPSERT并返回插入/更新计数"""
if not rows:
return (0, 0)
use_returning = "RETURNING" in sql.upper()
with self.conn.cursor() as c:
if not use_returning:
psycopg2.extras.execute_batch(c, sql, rows, page_size=page_size)
return (0, 0)
# 尝试向量化执行
try:
m = re.search(r"VALUES\s*\((.*?)\)", sql, flags=re.IGNORECASE | re.DOTALL)
if m:
tpl = "(" + m.group(1) + ")"
base_sql = sql[:m.start()] + "VALUES %s" + sql[m.end():]
ret = psycopg2.extras.execute_values(
c, base_sql, rows, template=tpl, page_size=page_size, fetch=True
)
if not ret:
return (0, 0)
inserted = sum(1 for rec in ret if self._is_inserted(rec))
return (inserted, len(ret) - inserted)
except Exception:
pass
# 回退:逐行执行
inserted = 0
updated = 0
for r in rows:
c.execute(sql, r)
try:
rec = c.fetchone()
except Exception:
rec = None
if self._is_inserted(rec):
inserted += 1
else:
updated += 1
return (inserted, updated)
@staticmethod
def _is_inserted(rec) -> bool:
"""判断是否为插入操作"""
if rec is None:
return False
if isinstance(rec, tuple):
return bool(rec[0])
if isinstance(rec, dict):
return bool(rec.get("inserted"))
return False
# --- pass-through helpers -------------------------------------------------
def commit(self):
"""提交事务(委托给底层连接)"""
self._connection.commit()
def rollback(self):
"""回滚事务(委托给底层连接)"""
self._connection.rollback()
def query(self, sql: str, args=None):
"""执行查询并返回结果"""
return self._connection.query(sql, args)
def execute(self, sql: str, args=None):
"""执行任意 SQL"""
self._connection.execute(sql, args)
def cursor(self):
"""暴露原生 cursor供特殊操作使用"""
return self.conn.cursor()