# -*- coding: utf-8 -*-
"""Database connection manager (caps the maximum connect timeout)."""

import logging
from typing import Any, Optional

import psycopg2
import psycopg2.extras

logger = logging.getLogger(__name__)

# Connect timeout (seconds) used when the caller does not supply one.
_DEFAULT_CONNECT_TIMEOUT = 5
# Production requirement: the connect timeout must not exceed 20 seconds.
_MIN_CONNECT_TIMEOUT = 1
_MAX_CONNECT_TIMEOUT = 20

# (session-dict key, PostgreSQL setting name) for the integer timeout settings,
# applied in this order after the optional time zone.
_TIMEOUT_SETTINGS = (
    ("statement_timeout_ms", "statement_timeout"),
    ("lock_timeout_ms", "lock_timeout"),
    ("idle_in_tx_timeout_ms", "idle_in_transaction_session_timeout"),
)


class DatabaseConnection:
    """Wrap a psycopg2 connection with session parameters and a timeout guard.

    The connection is opened eagerly in ``__init__`` with ``autocommit``
    disabled; callers control transactions via :meth:`commit` and
    :meth:`rollback`.
    """

    def __init__(self, dsn: str, session: Optional[dict] = None,
                 connect_timeout: Optional[int] = None):
        """Store the configuration and open the connection immediately.

        Args:
            dsn: Connection string passed to ``psycopg2.connect``.
            session: Optional session settings. Recognised keys are
                ``timezone``, ``statement_timeout_ms``, ``lock_timeout_ms``
                and ``idle_in_tx_timeout_ms``; other keys are ignored.
            connect_timeout: Connect timeout in seconds. Defaults to 5 when
                ``None`` and is clamped to the range [1, 20].

        Raises:
            psycopg2.Error: if connecting or applying session settings fails.
            ValueError / TypeError: if a timeout value cannot be converted
                with ``int()``.
        """
        self._dsn = dsn
        self._session = session or {}
        self._connect_timeout = connect_timeout
        self.conn = self._open_connection()

    def _effective_connect_timeout(self) -> int:
        """Return the connect timeout in seconds, defaulted and clamped."""
        timeout_val = (
            _DEFAULT_CONNECT_TIMEOUT
            if self._connect_timeout is None
            else self._connect_timeout
        )
        return max(_MIN_CONNECT_TIMEOUT,
                   min(int(timeout_val), _MAX_CONNECT_TIMEOUT))

    def _apply_session(self, conn) -> None:
        """Issue SET statements for the configured session parameters.

        Because ``autocommit`` is off, these statements run inside a
        transaction. They are committed explicitly: a plain ``SET`` issued in
        a transaction that is later rolled back is undone by PostgreSQL, so
        without this commit the caller's first ``rollback()`` would silently
        discard the session settings.
        """
        session = self._session
        with conn.cursor() as c:
            if session.get("timezone"):
                c.execute("SET TIME ZONE %s", (session["timezone"],))
            for key, setting in _TIMEOUT_SETTINGS:
                value = session.get(key)
                if value is not None:
                    # `setting` comes from the fixed table above, never from
                    # caller input; the value itself is bound as a parameter.
                    c.execute(f"SET {setting} = %s", (int(value),))
        conn.commit()

    def _open_connection(self):
        """Create a connection and apply session parameters.

        Returns:
            A new psycopg2 connection with ``autocommit`` disabled.

        Raises:
            psycopg2.Error: if connecting or applying session settings fails.
                A connection that was opened but failed session setup is
                closed before the error propagates, so it is not leaked.
        """
        conn = psycopg2.connect(
            self._dsn, connect_timeout=self._effective_connect_timeout()
        )
        conn.autocommit = False
        if self._session:
            try:
                self._apply_session(conn)
            except Exception:
                conn.close()
                raise
        return conn

    def query(self, sql: str, args: Any = None):
        """Execute a query and fetch all rows.

        Rows are produced by ``psycopg2.extras.RealDictCursor``, i.e. they are
        dict-like and keyed by column name. ``fetchall`` raises
        ``psycopg2.ProgrammingError`` if the statement returns no result set.
        """
        with self.conn.cursor(
            cursor_factory=psycopg2.extras.RealDictCursor
        ) as c:
            c.execute(sql, args)
            return c.fetchall()

    def execute(self, sql: str, args: Any = None) -> None:
        """Execute a SQL statement without returning rows (no commit)."""
        with self.conn.cursor() as c:
            c.execute(sql, args)

    def commit(self) -> None:
        """Commit the current transaction."""
        self.conn.commit()

    def rollback(self) -> None:
        """Roll back the current transaction."""
        self.conn.rollback()

    def close(self) -> None:
        """Close the connection on a best-effort basis; never raises."""
        try:
            self.conn.close()
        except Exception:
            # Closing is best effort, but keep a trace instead of hiding it.
            logger.debug("Ignoring error while closing database connection",
                         exc_info=True)

    def ensure_open(self) -> bool:
        """Make sure the connection is usable, reconnecting if it is closed.

        A nonzero ``conn.closed`` (closed or broken) triggers a fresh
        connection; any transaction state on the old connection is lost.

        Returns:
            True if the connection is open (possibly after reconnecting),
            False if reconnecting failed. Never raises.
        """
        try:
            if getattr(self.conn, "closed", 0):
                self.conn = self._open_connection()
        except Exception:
            logger.warning("Failed to reopen database connection",
                           exc_info=True)
            return False
        return True