268 lines
8.1 KiB
Python
268 lines
8.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""PostgreSQL connection performance test (ASCII-only output)."""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import math
|
|
import os
|
|
import statistics
|
|
import sys
|
|
import time
|
|
from typing import Dict, Iterable, List
|
|
|
|
from psycopg2.extensions import make_dsn, parse_dsn
|
|
|
|
# Make the project root (the parent of this script's directory) importable so
# that the project-local ``database`` and ``config`` packages resolve when this
# file is executed directly as a script.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    # Prepend so the project's packages are found before same-named installed ones.
    sys.path.insert(0, PROJECT_ROOT)
|
|
|
|
from database.connection import DatabaseConnection
|
|
|
|
|
|
def _load_env() -> Dict[str, str]:
    """Return ``.env`` values (if available) overlaid with ``os.environ``.

    Values come from the project's optional ``config.env_parser`` helper and
    are then overridden by the real process environment, so exported
    variables always win over the ``.env`` file.

    Loading is best-effort: a missing helper module is ignored silently, while
    any other failure to import or run it is reported on stderr (instead of
    being swallowed) and the process environment is still returned.

    Returns:
        A new dict of variable name -> string value.  Entries whose value is
        ``None`` are dropped so callers can always treat values as ``str``.
    """
    env: Dict[str, str] = {}
    try:
        from config.env_parser import _load_dotenv_values
    except ImportError:
        # The helper is optional; fall back to the process environment only.
        _load_dotenv_values = None
    except Exception as exc:
        print(f"Warning: could not import config.env_parser: {exc}", file=sys.stderr)
        _load_dotenv_values = None

    if _load_dotenv_values is not None:
        try:
            loaded = dict(_load_dotenv_values())
        except Exception as exc:
            print(f"Warning: failed to load .env values: {exc}", file=sys.stderr)
        else:
            # NOTE(review): assumes the helper may yield None for value-less keys
            # (as python-dotenv does); such entries would break int() parsing later.
            env.update({key: value for key, value in loaded.items() if value is not None})

    env.update(os.environ)
    return env
|
|
|
|
|
|
def _apply_dsn_overrides(dsn: str, host: str | None, port: int | None) -> str:
    """Return *dsn* with ``host`` and/or ``port`` replaced when they are given.

    Falsy *host*/*port* values (``None``, ``""``, ``0``) mean "no override".
    When nothing is overridden, the original *dsn* object is returned as-is
    without being re-serialized by ``make_dsn``.
    """
    requested = {"host": host, "port": str(port) if port else None}
    overrides = {key: value for key, value in requested.items() if value}
    return make_dsn(dsn, **overrides) if overrides else dsn
|
|
|
|
|
|
def _build_dsn_from_env(
    host: str,
    port: int,
    user: str | None,
    password: str | None,
    dbname: str | None,
) -> str | None:
    """Assemble a libpq DSN from individual connection settings.

    Args:
        host: Server host name or address.
        port: Server port; stringified for the DSN.
        user: Login role; required.
        password: Optional password; omitted from the DSN when empty or None.
        dbname: Database name; required.

    Returns:
        The DSN produced by ``make_dsn``, or ``None`` when *user* or *dbname*
        is missing/empty so the caller can report the problem.
    """
    if user and dbname:
        settings = dict(host=host, port=str(port), user=user, dbname=dbname)
        if password:
            settings["password"] = password
        return make_dsn("", **settings)
    return None
|
|
|
|
|
|
def _safe_dsn_summary(dsn: str, host: str | None, port: int | None) -> str:
    """Describe the connection target without leaking credentials.

    The DSN is parsed with ``parse_dsn``; if parsing fails, only the explicit
    *host*/*port* overrides are shown (the raw string is never echoed).  Every
    parameter whose name contains ``password`` -- ``password`` as well as
    libpq's ``sslpassword`` -- is removed before formatting.

    Args:
        dsn: Connection string to describe.
        host: Host to display instead of the DSN's host, if truthy.
        port: Port to display instead of the DSN's port, if truthy.

    Returns:
        Space-separated ``key=value`` pairs sorted by key, or
        ``"dsn=(hidden)"`` when nothing displayable remains.
    """
    try:
        info = dict(parse_dsn(dsn))
    except Exception:
        # Unparseable DSN: show nothing from it, it may embed secrets.
        info = {}
    if host:
        info["host"] = host
    if port:
        info["port"] = str(port)
    # Redact every credential-like key, not just ``password``.
    safe = {key: value for key, value in info.items() if "password" not in key.lower()}
    if not safe:
        return "dsn=(hidden)"
    return " ".join(f"{key}={safe[key]}" for key in sorted(safe))
|
|
|
|
|
|
def _percentile(values: List[float], pct: float) -> float:
    """Return the *pct*-th percentile of *values* by linear interpolation.

    The fractional position ``(n - 1) * pct / 100`` in the sorted samples is
    located and the two neighbouring samples are blended proportionally.  An
    empty list yields ``0.0``; a single sample is returned unchanged.  *pct*
    is assumed to lie in [0, 100] and is not validated.
    """
    if not values:
        return 0.0
    ranked = sorted(values)
    if len(ranked) == 1:
        return ranked[0]
    position = (len(ranked) - 1) * (pct / 100.0)
    below, above = math.floor(position), math.ceil(position)
    if below == above:
        # Position falls exactly on a sample: nothing to interpolate.
        return ranked[below]
    fraction = position - below
    return ranked[below] + (ranked[above] - ranked[below]) * fraction
|
|
|
|
|
|
def _format_stats(label: str, values: Iterable[float]) -> str:
    """Render a one-line summary of millisecond timings for *label*.

    Reports count, min, mean, p50, p95, max and the sample standard
    deviation (0.0 when there is a single sample), each with two decimals and
    an ``ms`` suffix.  An empty input yields ``"<label>: no samples"``.
    """
    # Materialize once: the samples are scanned several times below.
    samples = list(values)
    if not samples:
        return f"{label}: no samples"
    count = len(samples)
    mean_ms = statistics.mean(samples)
    spread_ms = 0.0 if count == 1 else statistics.stdev(samples)
    return (
        f"{label}: count={count} "
        f"min={min(samples):.2f}ms avg={mean_ms:.2f}ms "
        f"p50={_percentile(samples, 50):.2f}ms "
        f"p95={_percentile(samples, 95):.2f}ms "
        f"max={max(samples):.2f}ms stdev={spread_ms:.2f}ms"
    )
|
|
|
|
|
|
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the connection benchmark.

    Args:
        argv: Arguments to parse, without the program name.  ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing a list
            makes the parser reusable from tests or other scripts.

    Returns:
        The parsed namespace.  Range checks (non-negative counts, capping the
        connect timeout) are applied by ``main``, not here.
    """
    parser = argparse.ArgumentParser(description="PostgreSQL connection performance test")
    parser.add_argument("--dsn", help="Override PG_DSN/TEST_DB_DSN/.env value")
    # NOTE(review): because this default is non-empty, the host inside
    # --dsn / PG_DSN is always replaced unless --host is given explicitly --
    # confirm that this is intended.
    parser.add_argument(
        "--host",
        default="100.64.0.4",
        help="Override host in DSN (default: 100.64.0.4)",
    )
    parser.add_argument("--port", type=int, help="Override port in DSN")
    parser.add_argument("--user", help="User when building DSN from PG_* env")
    parser.add_argument("--password", help="Password when building DSN from PG_* env")
    parser.add_argument("--dbname", help="Database name when building DSN from PG_* env")
    parser.add_argument("--rounds", type=int, default=20, help="Measured connection rounds")
    parser.add_argument("--warmup", type=int, default=2, help="Warmup rounds (not recorded)")
    parser.add_argument("--query", default="SELECT 1", help="SQL to run after connect")
    parser.add_argument(
        "--query-repeat",
        type=int,
        default=1,
        help="Query repetitions per connection (0 to skip)",
    )
    parser.add_argument(
        "--connect-timeout",
        type=int,
        default=10,
        help="connect_timeout seconds (capped at 20, default: 10)",
    )
    parser.add_argument(
        "--statement-timeout-ms",
        type=int,
        help="Optional statement_timeout applied per connection",
    )
    parser.add_argument(
        "--sleep-ms",
        type=int,
        default=0,
        help="Sleep between rounds in milliseconds",
    )
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="Continue even if a round fails",
    )
    parser.add_argument("--verbose", action="store_true", help="Print per-round timings")
    return parser.parse_args(argv)
|
|
|
|
|
|
def _run_round(
    dsn: str,
    timeout: int,
    query: str,
    query_repeat: int,
    session: Dict[str, int] | None,
) -> tuple[float, List[float]]:
    """Open one connection, optionally run *query* on it, then close it.

    Args:
        dsn: Connection string passed to ``DatabaseConnection``.
        timeout: ``connect_timeout`` forwarded to ``DatabaseConnection``.
        query: SQL executed via ``conn.query`` after connecting.
        query_repeat: Number of times to run *query* (0 skips querying).
        session: Optional session settings forwarded to ``DatabaseConnection``.

    Returns:
        ``(connect_ms, query_times)``: milliseconds spent in the
        ``DatabaseConnection`` constructor, and one millisecond timing per
        query execution.

    Raises:
        Whatever ``DatabaseConnection`` or ``conn.query`` raise, and -- on the
        success path only -- whatever ``conn.close`` raises.  A cleanup error
        never replaces an error raised by a query.
    """
    start = time.perf_counter()
    # NOTE(review): this times the constructor only; it reflects real connect
    # latency only if DatabaseConnection connects eagerly -- confirm.
    conn = DatabaseConnection(dsn, connect_timeout=timeout, session=session)
    connect_ms = (time.perf_counter() - start) * 1000.0

    query_times: List[float] = []
    succeeded = False
    try:
        for _ in range(query_repeat):
            q_start = time.perf_counter()
            conn.query(query)
            query_times.append((time.perf_counter() - q_start) * 1000.0)
        succeeded = True
        return connect_ms, query_times
    finally:
        # Best-effort rollback: the connection is being discarded anyway, so a
        # failure here (e.g. on an already-broken connection) is not reported.
        try:
            conn.rollback()
        except Exception:
            pass
        try:
            conn.close()
        except Exception:
            # After a successful round a close failure is a real failure; while a
            # query error is propagating, don't let it mask that root cause.
            if succeeded:
                raise
|
|
|
|
|
|
def _resolve_dsn(args: argparse.Namespace, env: Dict[str, str]) -> str | None:
    """Pick the DSN to benchmark and apply any ``--host``/``--port`` overrides.

    Precedence: ``--dsn``, then ``PG_DSN``, then ``TEST_DB_DSN``.  Otherwise a
    DSN is built from ``--user``/``PG_USER``, ``--password``/``PG_PASSWORD``,
    ``--dbname``/``PG_NAME`` and ``--port``/``PG_PORT`` (falling back to 5432).

    Returns:
        The final DSN, or ``None`` when none is configured and the user or
        database name needed to build one is missing.
    """
    host = args.host
    port = args.port
    dsn = args.dsn or env.get("PG_DSN") or env.get("TEST_DB_DSN")
    if not dsn:
        user = args.user or env.get("PG_USER")
        password = args.password if args.password is not None else env.get("PG_PASSWORD")
        dbname = args.dbname or env.get("PG_NAME")
        # NOTE(review): PG_HOST is never consulted; the host always comes from
        # --host, which has a non-empty default.  Confirm this is intended.
        try:
            resolved_port = port or int(env.get("PG_PORT", "5432"))
        except (TypeError, ValueError):
            # Malformed or None-valued PG_PORT: use PostgreSQL's default port.
            resolved_port = port or 5432
        dsn = _build_dsn_from_env(host, resolved_port, user, password, dbname)
        if not dsn:
            return None
    return _apply_dsn_overrides(dsn, host, port)


def main() -> int:
    """Run the connection benchmark and print a summary.

    Returns:
        Process exit status: 0 when every round succeeded, 1 when any round
        (warm-up included) failed, 2 for invalid arguments or a missing DSN,
        and 130 when interrupted with Ctrl-C (partial statistics are still
        printed).
    """
    args = parse_args()
    if args.rounds < 0 or args.warmup < 0 or args.query_repeat < 0:
        print("rounds/warmup/query-repeat must be >= 0", file=sys.stderr)
        return 2

    env = _load_env()
    dsn = _resolve_dsn(args, env)
    if not dsn:
        print(
            "Missing DSN. Provide --dsn or set PG_DSN/TEST_DB_DSN, or PG_USER + PG_NAME.",
            file=sys.stderr,
        )
        return 2

    # Clamp to [1, 20] seconds so an unreachable host cannot stall a round for long.
    timeout = max(1, min(args.connect_timeout, 20))
    session = None
    if args.statement_timeout_ms is not None:
        session = {"statement_timeout_ms": int(args.statement_timeout_ms)}

    print("Target:", _safe_dsn_summary(dsn, args.host, args.port))
    print(
        f"Rounds: {args.rounds} (warmup {args.warmup}), "
        f"query_repeat={args.query_repeat}, timeout={timeout}s"
    )
    if args.query_repeat > 0:
        print("Query:", args.query)

    connect_times: List[float] = []
    query_times: List[float] = []
    failures: List[str] = []
    interrupted = False

    total = args.warmup + args.rounds
    try:
        for idx in range(total):
            is_warmup = idx < args.warmup
            try:
                c_ms, q_times = _run_round(
                    dsn, timeout, args.query, args.query_repeat, session
                )
            except Exception as exc:
                msg = f"round {idx + 1}: {exc}"
                failures.append(msg)
                print("Failure:", msg, file=sys.stderr)
                if not args.continue_on_error:
                    break
            else:
                if not is_warmup:
                    connect_times.append(c_ms)
                    query_times.extend(q_times)
                if args.verbose:
                    tag = "warmup" if is_warmup else "sample"
                    q_msg = ""
                    if args.query_repeat > 0:
                        q_avg = statistics.mean(q_times) if q_times else 0.0
                        q_msg = f", query_avg={q_avg:.2f}ms"
                    print(f"[{tag} {idx + 1}/{total}] connect={c_ms:.2f}ms{q_msg}")
            # Sleep only *between* rounds; waiting after the last one is wasted time.
            if args.sleep_ms > 0 and idx + 1 < total:
                time.sleep(args.sleep_ms / 1000.0)
    except KeyboardInterrupt:
        # Stop early but keep what was measured so far.
        interrupted = True
        print("Interrupted; reporting partial results.", file=sys.stderr)

    # _format_stats prints "no samples" for empty input, so a run in which every
    # round failed is still summarised explicitly instead of printing nothing.
    print(_format_stats("Connect", connect_times))
    if args.query_repeat > 0:
        print(_format_stats("Query", query_times))
    if failures:
        print(f"Failures: {len(failures)}", file=sys.stderr)
    if interrupted:
        return 130
    return 1 if failures else 0
|
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|