# -*- coding: utf-8 -*-
"""PostgreSQL connection performance test (ASCII-only output)."""

from __future__ import annotations

import argparse
import math
import os
import statistics
import sys
import time
from typing import Dict, Iterable, List

from psycopg2.extensions import make_dsn, parse_dsn

PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    # Put the parent directory first on sys.path so the project packages
    # imported below (database.connection, config.env_parser) resolve.
    sys.path.insert(0, PROJECT_ROOT)

from database.connection import DatabaseConnection  # noqa: E402

# libpq connection keywords whose values are secrets and must never be printed.
_SECRET_DSN_KEYS = ("password", "sslpassword")


def _load_env() -> Dict[str, str]:
    """Return .env values (if a loader is available) overlaid with os.environ.

    The .env loader from ``config.env_parser`` is optional: if it cannot be
    imported, or raises while loading, only ``os.environ`` is used. Values
    from the process environment take precedence over .env values.
    """
    env: Dict[str, str] = {}
    try:
        from config.env_parser import _load_dotenv_values
    except Exception:
        _load_dotenv_values = None
    if _load_dotenv_values:
        try:
            env.update(_load_dotenv_values())
        except Exception:
            # Best effort: a broken .env file must not stop the benchmark.
            pass
    env.update(os.environ)
    return env


def _apply_dsn_overrides(dsn: str, host: str | None, port: int | None) -> str:
    """Return *dsn* with ``host``/``port`` replaced when they are given (truthy).

    The DSN is returned unchanged when there is nothing to override.
    """
    overrides = {}
    if host:
        overrides["host"] = host
    if port:
        overrides["port"] = str(port)
    if not overrides:
        return dsn
    return make_dsn(dsn, **overrides)


def _build_dsn_from_env(
    host: str,
    port: int,
    user: str | None,
    password: str | None,
    dbname: str | None,
) -> str | None:
    """Build a keyword/value DSN from discrete parts.

    Returns None when ``user`` or ``dbname`` is missing. The password is only
    included when it is non-empty.
    """
    if not user or not dbname:
        return None
    params = {
        "host": host,
        "port": str(port),
        "user": user,
        "dbname": dbname,
    }
    if password:
        params["password"] = password
    return make_dsn("", **params)


def _safe_dsn_summary(dsn: str, host: str | None, port: int | None) -> str:
    """Return a printable ``key=value`` summary of *dsn* with secrets removed.

    Keys are sorted. ``host``/``port`` (when given) replace the parsed values.
    Secret keywords (see ``_SECRET_DSN_KEYS``) are dropped. If the DSN cannot
    be parsed and no host/port is given, ``"dsn=(hidden)"`` is returned.
    """
    try:
        info = parse_dsn(dsn)
    except Exception:
        info = {}
    if host:
        info["host"] = host
    if port:
        info["port"] = str(port)
    for key in _SECRET_DSN_KEYS:
        info.pop(key, None)
    if not info:
        return "dsn=(hidden)"
    return " ".join(f"{k}={info[k]}" for k in sorted(info))


def _percentile(values: List[float], pct: float) -> float:
    """Return the *pct*-th percentile of *values* using linear interpolation.

    *pct* is clamped to [0, 100]. Returns 0.0 for an empty input.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    if len(ordered) == 1:
        return ordered[0]
    # Clamp so out-of-range percentiles cannot index past either end
    # (or wrap around via negative indices).
    pct = min(max(pct, 0.0), 100.0)
    rank = (len(ordered) - 1) * (pct / 100.0)
    low = int(math.floor(rank))
    high = int(math.ceil(rank))
    if low == high:
        return ordered[low]
    return ordered[low] + (ordered[high] - ordered[low]) * (rank - low)
def _format_stats(label: str, values: Iterable[float]) -> str:
    """Return a one-line summary of millisecond samples.

    The summary reports count, min, avg, p50, p95, max and sample stdev.
    Returns ``"<label>: no samples"`` for an empty input. The stdev is
    reported as 0.0 when there is a single sample.
    """
    data = list(values)
    if not data:
        return f"{label}: no samples"
    avg = statistics.mean(data)
    stdev = statistics.stdev(data) if len(data) > 1 else 0.0
    return (
        f"{label}: count={len(data)} "
        f"min={min(data):.2f}ms avg={avg:.2f}ms "
        f"p50={_percentile(data, 50):.2f}ms "
        f"p95={_percentile(data, 95):.2f}ms "
        f"max={max(data):.2f}ms stdev={stdev:.2f}ms"
    )


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the connection benchmark."""
    parser = argparse.ArgumentParser(description="PostgreSQL connection performance test")
    parser.add_argument("--dsn", help="Override PG_DSN/TEST_DB_DSN/.env value")
    parser.add_argument(
        "--host",
        default="100.64.0.4",
        help="Override host in DSN (default: 100.64.0.4)",
    )
    parser.add_argument("--port", type=int, help="Override port in DSN")
    parser.add_argument("--user", help="User when building DSN from PG_* env")
    parser.add_argument("--password", help="Password when building DSN from PG_* env")
    parser.add_argument("--dbname", help="Database name when building DSN from PG_* env")
    parser.add_argument("--rounds", type=int, default=20, help="Measured connection rounds")
    parser.add_argument("--warmup", type=int, default=2, help="Warmup rounds (not recorded)")
    parser.add_argument("--query", default="SELECT 1", help="SQL to run after connect")
    parser.add_argument(
        "--query-repeat",
        type=int,
        default=1,
        help="Query repetitions per connection (0 to skip)",
    )
    parser.add_argument(
        "--connect-timeout",
        type=int,
        default=10,
        help="connect_timeout seconds (capped at 20, default: 10)",
    )
    parser.add_argument(
        "--statement-timeout-ms",
        type=int,
        help="Optional statement_timeout applied per connection",
    )
    parser.add_argument(
        "--sleep-ms",
        type=int,
        default=0,
        help="Sleep between rounds in milliseconds",
    )
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="Continue even if a round fails",
    )
    parser.add_argument("--verbose", action="store_true", help="Print per-round timings")
    return parser.parse_args()


def _run_round(
    dsn: str,
    timeout: int,
    query: str,
    query_repeat: int,
    session: Dict[str, int] | None,
) -> tuple[float, List[float]]:
    """Open one connection, run *query* ``query_repeat`` times, then close it.

    Returns ``(connect_ms, query_times_ms)``. ``connect_ms`` is the wall time
    spent constructing ``DatabaseConnection``. ``query_times_ms`` has one entry
    per executed query. Errors from connecting or querying propagate to the
    caller.
    """
    start = time.perf_counter()
    conn = DatabaseConnection(dsn, connect_timeout=timeout, session=session)
    connect_ms = (time.perf_counter() - start) * 1000.0
    query_times: List[float] = []
    try:
        for _ in range(query_repeat):
            q_start = time.perf_counter()
            conn.query(query)
            query_times.append((time.perf_counter() - q_start) * 1000.0)
        return connect_ms, query_times
    finally:
        try:
            # Roll back before closing; a rollback failure does not affect
            # the timings, so it is ignored.
            conn.rollback()
        except Exception:
            pass
        conn.close()


def _resolve_dsn(args: argparse.Namespace, env: Dict[str, str]) -> str | None:
    """Choose the DSN to benchmark, or return None when none can be determined.

    The DSN source is tried in this order:
    1. ``--dsn``, then ``PG_DSN``, then ``TEST_DB_DSN``.
    2. Otherwise, a DSN built from --host/--port/--user/--password/--dbname,
       falling back to PG_PORT/PG_USER/PG_PASSWORD/PG_NAME from *env*.

    The --host/--port overrides are then applied to whichever DSN was chosen.
    """
    dsn = args.dsn or env.get("PG_DSN") or env.get("TEST_DB_DSN")
    host = args.host
    port = args.port
    if not dsn:
        user = args.user or env.get("PG_USER")
        password = args.password if args.password is not None else env.get("PG_PASSWORD")
        dbname = args.dbname or env.get("PG_NAME")
        try:
            resolved_port = port or int(env.get("PG_PORT", "5432"))
        except ValueError:
            # A malformed PG_PORT falls back to the PostgreSQL default port.
            resolved_port = port or 5432
        dsn = _build_dsn_from_env(host, resolved_port, user, password, dbname)
        if not dsn:
            return None
    return _apply_dsn_overrides(dsn, host, port)


def _print_summary(
    connect_times: List[float], query_times: List[float], query_repeat: int
) -> None:
    """Print aggregate stats; prints nothing when no round succeeded."""
    if connect_times:
        print(_format_stats("Connect", connect_times))
        if query_repeat > 0:
            print(_format_stats("Query", query_times))


def main() -> int:
    """Run the benchmark.

    Returns 0 on success, 1 if any round failed, and 2 on invalid arguments
    or a missing DSN.
    """
    args = parse_args()
    if args.rounds < 0 or args.warmup < 0 or args.query_repeat < 0:
        print("rounds/warmup/query-repeat must be >= 0", file=sys.stderr)
        return 2
    # Reject before connecting: a negative statement_timeout would otherwise
    # be sent on every connection.
    if args.statement_timeout_ms is not None and args.statement_timeout_ms < 0:
        print("statement-timeout-ms must be >= 0", file=sys.stderr)
        return 2

    dsn = _resolve_dsn(args, _load_env())
    if not dsn:
        print(
            "Missing DSN. Provide --dsn or set PG_DSN/TEST_DB_DSN, or PG_USER + PG_NAME.",
            file=sys.stderr,
        )
        return 2

    timeout = max(1, min(int(args.connect_timeout), 20))
    session = None
    if args.statement_timeout_ms is not None:
        session = {"statement_timeout_ms": int(args.statement_timeout_ms)}

    print("Target:", _safe_dsn_summary(dsn, args.host, args.port))
    print(
        f"Rounds: {args.rounds} (warmup {args.warmup}), "
        f"query_repeat={args.query_repeat}, timeout={timeout}s"
    )
    if args.query_repeat > 0:
        print("Query:", args.query)

    connect_times: List[float] = []
    query_times: List[float] = []
    failures: List[str] = []
    total = args.warmup + args.rounds
    for idx in range(total):
        is_warmup = idx < args.warmup
        try:
            c_ms, q_times = _run_round(
                dsn, timeout, args.query, args.query_repeat, session
            )
            if not is_warmup:
                connect_times.append(c_ms)
                query_times.extend(q_times)
            if args.verbose:
                tag = "warmup" if is_warmup else "sample"
                q_msg = ""
                if args.query_repeat > 0:
                    q_avg = statistics.mean(q_times) if q_times else 0.0
                    q_msg = f", query_avg={q_avg:.2f}ms"
                print(f"[{tag} {idx + 1}/{total}] connect={c_ms:.2f}ms{q_msg}")
        except Exception as exc:
            msg = f"round {idx + 1}: {exc}"
            failures.append(msg)
            print("Failure:", msg, file=sys.stderr)
            if not args.continue_on_error:
                break
        # Sleep only *between* rounds; there is no reason to delay after the last.
        if args.sleep_ms > 0 and idx < total - 1:
            time.sleep(args.sleep_ms / 1000.0)

    _print_summary(connect_times, query_times, args.query_repeat)
    if failures:
        print(f"Failures: {len(failures)}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())