# -*- coding: utf-8 -*-
"""Database connection manager with capped connect_timeout."""

from typing import Any, Optional

import psycopg2
import psycopg2.extras

# PRD: database connect_timeout must not exceed 20 seconds.
_MAX_CONNECT_TIMEOUT_S = 20
# Lower bound so a caller-supplied 0 or negative value never reaches libpq.
_MIN_CONNECT_TIMEOUT_S = 1
_DEFAULT_CONNECT_TIMEOUT_S = 5

# (session key, SET statement) pairs for integer millisecond timeouts,
# applied in this order when the key is present and not None.
_TIMEOUT_SETTINGS = (
    ("statement_timeout_ms", "SET statement_timeout = %s"),
    ("lock_timeout_ms", "SET lock_timeout = %s"),
    ("idle_in_tx_timeout_ms", "SET idle_in_transaction_session_timeout = %s"),
)


class DatabaseConnection:
    """Wrap a psycopg2 connection with session parameters and a timeout guard.

    The connection is opened with ``autocommit`` disabled, so callers end
    each transaction explicitly with :meth:`commit` or :meth:`rollback`.
    """

    def __init__(
        self,
        dsn: str,
        session: Optional[dict] = None,
        connect_timeout: Optional[int] = None,
    ):
        """Open the connection and apply optional session parameters.

        Args:
            dsn: Connection string passed straight to ``psycopg2.connect``.
            session: Optional mapping. Recognised keys are ``timezone``
                (skipped when falsy), and ``statement_timeout_ms``,
                ``lock_timeout_ms`` and ``idle_in_tx_timeout_ms`` (each
                skipped when None and otherwise converted with ``int``).
            connect_timeout: Connect timeout in seconds. Defaults to 5 and
                is clamped to the range [1, 20].

        Raises:
            ValueError / TypeError: If ``connect_timeout`` or a timeout value
                cannot be converted with ``int``.
            psycopg2.Error: If connecting or applying a session setting fails.
                If a session setting fails, the connection is closed before
                the error is re-raised, so it does not leak.
        """
        if connect_timeout is None:
            connect_timeout = _DEFAULT_CONNECT_TIMEOUT_S
        timeout_val = max(
            _MIN_CONNECT_TIMEOUT_S,
            min(int(connect_timeout), _MAX_CONNECT_TIMEOUT_S),
        )

        self.conn = psycopg2.connect(dsn, connect_timeout=timeout_val)
        self.conn.autocommit = False

        if session:
            try:
                self._apply_session(session)
            except Exception:
                # The caller never receives this object, so close the
                # connection here rather than leaking it.
                self.close()
                raise

    def _apply_session(self, session: dict) -> None:
        """Issue SET commands for the recognised session keys, then commit.

        The commit is required. With autocommit off, these SETs run inside
        an implicit transaction, and PostgreSQL undoes a plain SET when its
        transaction is rolled back. Without the commit, the caller's first
        rollback() would silently discard every session setting.
        """
        with self.conn.cursor() as c:
            if session.get("timezone"):
                c.execute("SET TIME ZONE %s", (session["timezone"],))
            for key, statement in _TIMEOUT_SETTINGS:
                value = session.get(key)
                if value is not None:
                    c.execute(statement, (int(value),))
        self.conn.commit()

    def query(self, sql: str, args: Any = None) -> list:
        """Execute *sql* with *args* and return all rows as dict-like rows.

        Rows come from ``psycopg2.extras.RealDictCursor``, so each row maps
        column names to values.
        """
        with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as c:
            c.execute(sql, args)
            return c.fetchall()

    def execute(self, sql: str, args: Any = None) -> None:
        """Execute *sql* with *args* without fetching any rows."""
        with self.conn.cursor() as c:
            c.execute(sql, args)

    def commit(self) -> None:
        """Commit the current transaction."""
        self.conn.commit()

    def rollback(self) -> None:
        """Roll back the current transaction."""
        self.conn.rollback()

    def close(self) -> None:
        """Close the connection on a best-effort basis.

        Any exception raised while closing is deliberately suppressed, so
        this method is safe to call from cleanup paths.
        """
        try:
            self.conn.close()
        except Exception:
            pass