Files
2025-11-18 02:28:47 +08:00

113 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
数据库操作批量、RETURNING支持
"""
import re
from typing import List, Dict, Tuple
import psycopg2.extras
from .connection import DatabaseConnection
class DatabaseOperations(DatabaseConnection):
    """Batch execution helpers, including UPSERT with RETURNING-based counts.

    All methods work through ``self.conn``, which is used as a psycopg2
    connection (it must provide ``cursor()``). NOTE(review): ``self.conn`` is
    presumably set up by ``DatabaseConnection``, which lives in another
    module -- confirm there.
    """

    # Savepoint that lets a failed vectorized attempt be undone before the
    # row-by-row fallback reuses the same transaction.
    _VECTORIZED_SAVEPOINT = "batch_upsert_vectorized"

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Execute ``sql`` for every row in batches, without reading results.

        Args:
            sql: Statement whose placeholders are filled from each row.
            rows: Parameter mappings, one per execution. Nothing is executed
                when this is empty.
            page_size: Forwarded to ``psycopg2.extras.execute_batch``.
        """
        if not rows:
            return
        with self.conn.cursor() as cursor:
            psycopg2.extras.execute_batch(cursor, sql, rows, page_size=page_size)

    def batch_upsert_with_returning(
        self, sql: str, rows: List[Dict], page_size: int = 1000
    ) -> Tuple[int, int]:
        """Run a batch UPSERT and count how many rows were inserted vs. updated.

        The statement's RETURNING clause must expose an "inserted" flag: the
        first column for tuple rows, or the ``inserted`` key for mapping rows
        (for example ``RETURNING (xmax = 0) AS inserted``).

        A vectorized ``execute_values`` run is attempted first; if it fails
        for any reason the rows are executed one by one instead.

        Args:
            sql: Single-row ``INSERT ... VALUES (...)`` statement. When it has
                no RETURNING clause it is simply executed in batches.
            rows: Parameter mappings, one per row.
            page_size: Batch size for the batched/vectorized execution.

        Returns:
            ``(inserted_count, updated_count)``. ``(0, 0)`` when ``rows`` is
            empty or the statement has no RETURNING clause.
        """
        if not rows:
            return (0, 0)

        with self.conn.cursor() as cursor:
            if "RETURNING" not in sql.upper():
                psycopg2.extras.execute_batch(cursor, sql, rows, page_size=page_size)
                return (0, 0)

            # Inside a transaction, a server-side error during the vectorized
            # attempt would leave the transaction aborted and make the
            # fallback fail as well. A savepoint lets us undo just that
            # attempt. Savepoints are only valid inside a transaction block,
            # so none is used in autocommit mode.
            use_savepoint = not getattr(self.conn, "autocommit", False)
            savepoint = self._VECTORIZED_SAVEPOINT
            if use_savepoint:
                cursor.execute("SAVEPOINT " + savepoint)
            try:
                counts = self._execute_with_returning_vectorized(
                    cursor, sql, rows, page_size
                )
            except Exception:
                # Deliberate best-effort: any failure of the fast path falls
                # back to the slower row-by-row execution.
                if use_savepoint:
                    cursor.execute("ROLLBACK TO SAVEPOINT " + savepoint)
                    cursor.execute("RELEASE SAVEPOINT " + savepoint)
                return self._execute_with_returning_row_by_row(cursor, sql, rows)

            if use_savepoint:
                cursor.execute("RELEASE SAVEPOINT " + savepoint)
            return counts

    @staticmethod
    def _split_values_clause(sql: str) -> Tuple[str, str, str]:
        """Split ``sql`` around its first ``VALUES (...)`` group.

        The closing parenthesis is found by balancing parentheses, so named
        placeholders such as ``%(id)s`` and nested calls such as
        ``COALESCE(%(x)s, 0)`` stay inside the group. Parentheses inside
        single- or double-quoted text are ignored.

        Returns:
            ``(prefix, template, suffix)`` where ``template`` is the
            parenthesized group and ``prefix``/``suffix`` are the text before
            the ``VALUES`` keyword and after the group.

        Raises:
            ValueError: If no balanced ``VALUES (...)`` group is found.
        """
        match = re.search(r"\bVALUES\s*\(", sql, flags=re.IGNORECASE)
        if not match:
            raise ValueError("Cannot parse VALUES clause")

        open_idx = match.end() - 1  # index of the opening parenthesis
        depth = 0
        quote = None  # the quote character currently open, if any
        for idx in range(open_idx, len(sql)):
            ch = sql[idx]
            if quote:
                if ch == quote:
                    quote = None
            elif ch in ("'", '"'):
                quote = ch
            elif ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    return (
                        sql[:match.start()],
                        sql[open_idx:idx + 1],
                        sql[idx + 1:],
                    )
        raise ValueError("Cannot parse VALUES clause")

    def _execute_with_returning_vectorized(
        self, cursor, sql: str, rows: List[Dict], page_size: int
    ) -> Tuple[int, int]:
        """Execute the UPSERT through ``execute_values`` and count the results.

        The statement's ``VALUES (...)`` group becomes the per-row template
        and is replaced by the single ``%s`` that ``execute_values`` expands.

        Raises:
            ValueError: If the VALUES clause cannot be parsed.
        """
        prefix, template, suffix = self._split_values_clause(sql)
        base_sql = prefix + "VALUES %s" + suffix
        returned = psycopg2.extras.execute_values(
            cursor, base_sql, rows, template=template, page_size=page_size, fetch=True
        )
        if not returned:
            return (0, 0)
        inserted = sum(1 for rec in returned if self._extract_inserted_flag(rec))
        # Every returned row that was not inserted counts as an update.
        return (inserted, len(returned) - inserted)

    def _execute_with_returning_row_by_row(
        self, cursor, sql: str, rows: List[Dict]
    ) -> Tuple[int, int]:
        """Execute the UPSERT one row at a time (fallback path).

        Rows for which the statement returns no record are counted as neither
        inserted nor updated, matching the vectorized path, which only counts
        returned records.
        """
        inserted = 0
        updated = 0
        for row in rows:
            cursor.execute(sql, row)
            try:
                rec = cursor.fetchone()
            except psycopg2.ProgrammingError:
                # The statement produced no result set to fetch from.
                rec = None
            if rec is None:
                continue
            if self._extract_inserted_flag(rec):
                inserted += 1
            else:
                updated += 1
        return (inserted, updated)

    @staticmethod
    def _extract_inserted_flag(rec) -> bool:
        """Return the "inserted" flag carried by one returned record.

        Tuples are read positionally (first column); dicts and other
        subscriptable rows are read through the ``"inserted"`` key. Anything
        that carries no such flag yields ``False``.
        """
        if isinstance(rec, tuple):
            return bool(rec[0]) if rec else False
        if isinstance(rec, dict):
            return bool(rec.get("inserted"))
        try:
            return bool(rec["inserted"])
        except (LookupError, TypeError):
            return False
# Backward compatibility: keep the class available under the alias Pg.
Pg = DatabaseOperations