Files
feiqiu-ETL/etl_billiards/scripts/test_db_performance.py
2026-01-27 22:47:05 +08:00

268 lines
8.1 KiB
Python

# -*- coding: utf-8 -*-
"""PostgreSQL connection performance test (ASCII-only output)."""
from __future__ import annotations
import argparse
import math
import os
import statistics
import sys
import time
from typing import Dict, Iterable, List
from psycopg2.extensions import make_dsn, parse_dsn
# Put the directory one level above this script at the front of sys.path so
# that the project-local imports used by this file (e.g. ``database.connection``)
# can be resolved when the script is executed directly.
# NOTE(review): presumably that parent directory is the package root -- confirm.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..")
)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from database.connection import DatabaseConnection
def _load_env() -> Dict[str, str]:
    """Collect configuration values from the project's dotenv loader and the process.

    Values returned by ``config.env_parser._load_dotenv_values`` (when that
    helper can be imported and called) are loaded first; ``os.environ`` is
    applied afterwards, so real environment variables win on key collisions.

    Returns:
        A new plain dict holding the merged values.
    """
    merged: Dict[str, str] = {}

    # The dotenv helper is optional: any import problem just disables it.
    try:
        from config.env_parser import _load_dotenv_values as dotenv_loader
    except Exception:
        dotenv_loader = None

    if dotenv_loader:
        # Best effort only -- a broken or missing dotenv source must not stop
        # the benchmark from running off plain environment variables.
        try:
            file_values = dotenv_loader()
            merged.update(file_values)
        except Exception:
            pass

    merged.update(os.environ)
    return merged
def _apply_dsn_overrides(dsn: str, host: str | None, port: int | None) -> str:
    """Return ``dsn`` with its host and/or port replaced.

    Args:
        dsn: Base connection string.
        host: Replacement host; ignored when falsy (``None`` or ``""``).
        port: Replacement port; ignored when falsy (``None`` or ``0``).

    Returns:
        ``dsn`` itself when nothing is overridden, otherwise the string built
        by ``make_dsn(dsn, **overrides)``.
    """
    candidates = (
        ("host", host),
        ("port", str(port) if port else None),
    )
    replacements = {key: value for key, value in candidates if value}
    # Skip make_dsn entirely when there is nothing to change so the caller
    # gets the input string back untouched.
    return make_dsn(dsn, **replacements) if replacements else dsn
def _build_dsn_from_env(
    host: str,
    port: int,
    user: str | None,
    password: str | None,
    dbname: str | None,
) -> str | None:
    """Assemble a DSN from individual connection settings.

    Args:
        host: Server host.
        port: Server port (stringified before being handed to ``make_dsn``).
        user: Login role; required.
        password: Optional password; left out of the DSN when falsy.
        dbname: Database name; required.

    Returns:
        The DSN produced by ``make_dsn``, or ``None`` when ``user`` or
        ``dbname`` is missing or empty.
    """
    if not (user and dbname):
        return None

    fields = {"host": host, "port": str(port), "user": user, "dbname": dbname}
    if password:
        fields["password"] = password
    return make_dsn("", **fields)
def _safe_dsn_summary(dsn: str, host: str | None, port: int | None) -> str:
    """Describe the connection target without revealing the password.

    Args:
        dsn: Connection string to summarise.
        host: Host to display in place of the DSN's own; ignored when falsy.
        port: Port to display in place of the DSN's own; ignored when falsy.

    Returns:
        Space-separated ``key=value`` pairs sorted by key with ``password``
        removed, or ``"dsn=(hidden)"`` when no field is available to show.
    """
    # A DSN that cannot be parsed is summarised from the overrides alone
    # rather than aborting the run.
    try:
        fields = parse_dsn(dsn)
    except Exception:
        fields = {}

    if host:
        fields["host"] = host
    if port:
        fields["port"] = str(port)
    if "password" in fields:
        del fields["password"]

    if not fields:
        return "dsn=(hidden)"
    return " ".join(f"{key}={fields[key]}" for key in sorted(fields))
def _percentile(values: List[float], pct: float) -> float:
    """Return the ``pct``-th percentile of ``values`` by linear interpolation.

    The samples are sorted and the fractional rank ``(n - 1) * pct / 100`` is
    interpolated between the two neighbouring order statistics.

    Args:
        values: Samples; need not be sorted and are not modified.
        pct: Requested percentile, within ``[0, 100]``.

    Returns:
        The interpolated percentile, or ``0.0`` when ``values`` is empty.

    Raises:
        ValueError: If ``pct`` lies outside ``[0, 100]`` (NaN included).
    """
    # Reject out-of-range requests up front: a negative rank would otherwise
    # silently index from the end of the sorted list, and a rank beyond the
    # last element would surface as an unrelated IndexError.
    if not 0.0 <= pct <= 100.0:
        raise ValueError(f"pct must be within [0, 100], got {pct!r}")
    if not values:
        return 0.0

    ordered = sorted(values)
    rank = (len(ordered) - 1) * (pct / 100.0)
    low = math.floor(rank)
    high = math.ceil(rank)
    if low == high:
        # Rank lands exactly on a sample (always true for a single sample).
        return ordered[low]
    return ordered[low] + (ordered[high] - ordered[low]) * (rank - low)
def _format_stats(label: str, values: Iterable[float]) -> str:
    """Render a one-line summary of millisecond timings.

    Args:
        label: Prefix for the line (e.g. ``"Connect"``).
        values: Timing samples in milliseconds; any iterable is accepted.

    Returns:
        ``"<label>: no samples"`` for an empty input, otherwise a line with
        count, min, avg, p50, p95, max and stdev formatted to two decimals.
    """
    # Materialise once: ``values`` may be a one-shot iterator and is read
    # several times below.
    samples = list(values)
    if len(samples) == 0:
        return f"{label}: no samples"

    mean = statistics.mean(samples)
    # Sample standard deviation is undefined for a single observation.
    spread = 0.0 if len(samples) < 2 else statistics.stdev(samples)
    lowest = min(samples)
    median = _percentile(samples, 50)
    p95 = _percentile(samples, 95)
    highest = max(samples)
    return (
        f"{label}: count={len(samples)} "
        f"min={lowest:.2f}ms avg={mean:.2f}ms "
        f"p50={median:.2f}ms "
        f"p95={p95:.2f}ms "
        f"max={highest:.2f}ms stdev={spread:.2f}ms"
    )
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point relies
            on; passing an explicit list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="PostgreSQL connection performance test")

    # Connection target.
    parser.add_argument("--dsn", help="Override PG_DSN/TEST_DB_DSN/.env value")
    parser.add_argument(
        "--host",
        default="100.64.0.4",
        help="Override host in DSN (default: 100.64.0.4)",
    )
    parser.add_argument("--port", type=int, help="Override port in DSN")
    parser.add_argument("--user", help="User when building DSN from PG_* env")
    parser.add_argument("--password", help="Password when building DSN from PG_* env")
    parser.add_argument("--dbname", help="Database name when building DSN from PG_* env")

    # Workload shape.
    parser.add_argument("--rounds", type=int, default=20, help="Measured connection rounds")
    parser.add_argument("--warmup", type=int, default=2, help="Warmup rounds (not recorded)")
    parser.add_argument("--query", default="SELECT 1", help="SQL to run after connect")
    parser.add_argument(
        "--query-repeat",
        type=int,
        default=1,
        help="Query repetitions per connection (0 to skip)",
    )

    # Timeouts and pacing.
    parser.add_argument(
        "--connect-timeout",
        type=int,
        default=10,
        help="connect_timeout seconds (capped at 20, default: 10)",
    )
    parser.add_argument(
        "--statement-timeout-ms",
        type=int,
        help="Optional statement_timeout applied per connection",
    )
    parser.add_argument(
        "--sleep-ms",
        type=int,
        default=0,
        help="Sleep between rounds in milliseconds",
    )

    # Error handling and reporting.
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="Continue even if a round fails",
    )
    parser.add_argument("--verbose", action="store_true", help="Print per-round timings")
    return parser.parse_args(argv)
def _run_round(
    dsn: str,
    timeout: int,
    query: str,
    query_repeat: int,
    session: Dict[str, int] | None,
) -> tuple[float, List[float]]:
    """Open one connection, time it, and time ``query_repeat`` query executions.

    Args:
        dsn: Connection string handed to ``DatabaseConnection``.
        timeout: Passed through as ``connect_timeout``.
        query: SQL text executed through ``conn.query``.
        query_repeat: Number of timed executions; ``0`` skips querying.
        session: Optional settings passed through as ``session``.

    Returns:
        ``(connect_ms, query_ms_list)``. ``connect_ms`` is the wall-clock time
        spent constructing ``DatabaseConnection``.
        NOTE(review): presumably the constructor opens the connection eagerly;
        confirm, otherwise the first query absorbs the connect cost.

    Raises:
        Whatever the connection constructor, ``query`` or ``close`` raises.
    """

    def elapsed_ms(since: float) -> float:
        return (time.perf_counter() - since) * 1000.0

    opened_at = time.perf_counter()
    conn = DatabaseConnection(dsn, connect_timeout=timeout, session=session)
    connect_ms = elapsed_ms(opened_at)

    timings: List[float] = []
    try:
        for _ in range(query_repeat):
            issued_at = time.perf_counter()
            conn.query(query)
            timings.append(elapsed_ms(issued_at))
    finally:
        # Rollback is best effort; the connection is closed regardless.
        try:
            conn.rollback()
        except Exception:
            pass
        conn.close()
    return connect_ms, timings
def main() -> int:
    """Run the connection benchmark and report timing statistics.

    The DSN is taken from ``--dsn``, then ``PG_DSN``, then ``TEST_DB_DSN``;
    failing that it is assembled from ``--user/--password/--dbname`` and the
    ``PG_USER``/``PG_PASSWORD``/``PG_NAME``/``PG_PORT`` values. ``--host``
    (which has a default) and ``--port`` are then applied on top of the DSN.

    Returns:
        ``0`` when every round succeeded, ``1`` when at least one round
        failed, ``2`` for invalid arguments or a missing DSN.
    """
    args = parse_args()
    if any(count < 0 for count in (args.rounds, args.warmup, args.query_repeat)):
        print("rounds/warmup/query-repeat must be >= 0", file=sys.stderr)
        return 2

    env = _load_env()
    host, port = args.host, args.port
    dsn = args.dsn or env.get("PG_DSN") or env.get("TEST_DB_DSN")

    if not dsn:
        # No ready-made DSN: build one from the individual settings.
        if port:
            effective_port = port
        else:
            try:
                effective_port = int(env.get("PG_PORT", "5432"))
            except ValueError:
                # Unparseable PG_PORT falls back to the PostgreSQL default.
                effective_port = 5432
        cli_password = args.password
        dsn = _build_dsn_from_env(
            host,
            effective_port,
            args.user or env.get("PG_USER"),
            env.get("PG_PASSWORD") if cli_password is None else cli_password,
            args.dbname or env.get("PG_NAME"),
        )
        if not dsn:
            print(
                "Missing DSN. Provide --dsn or set PG_DSN/TEST_DB_DSN, or PG_USER + PG_NAME.",
                file=sys.stderr,
            )
            return 2

    dsn = _apply_dsn_overrides(dsn, host, port)
    # Clamp the connect timeout into the 1..20 second window.
    timeout = min(20, max(1, int(args.connect_timeout)))
    session = (
        None
        if args.statement_timeout_ms is None
        else {"statement_timeout_ms": int(args.statement_timeout_ms)}
    )

    print("Target:", _safe_dsn_summary(dsn, host, port))
    print(
        f"Rounds: {args.rounds} (warmup {args.warmup}), "
        f"query_repeat={args.query_repeat}, timeout={timeout}s"
    )
    if args.query_repeat > 0:
        print("Query:", args.query)

    connect_samples: List[float] = []
    query_samples: List[float] = []
    failures: List[str] = []
    total_rounds = args.warmup + args.rounds

    for round_no in range(1, total_rounds + 1):
        # The first ``warmup`` rounds run normally but are not recorded.
        warming_up = round_no <= args.warmup
        try:
            connect_ms, round_query_ms = _run_round(
                dsn, timeout, args.query, args.query_repeat, session
            )
            if not warming_up:
                connect_samples.append(connect_ms)
                query_samples.extend(round_query_ms)
            if args.verbose:
                phase = "warmup" if warming_up else "sample"
                query_note = ""
                if args.query_repeat > 0:
                    mean_ms = statistics.mean(round_query_ms) if round_query_ms else 0.0
                    query_note = f", query_avg={mean_ms:.2f}ms"
                print(
                    f"[{phase} {round_no}/{total_rounds}] "
                    f"connect={connect_ms:.2f}ms{query_note}"
                )
        except Exception as exc:
            detail = f"round {round_no}: {exc}"
            failures.append(detail)
            print("Failure:", detail, file=sys.stderr)
            if not args.continue_on_error:
                break
        if args.sleep_ms > 0:
            time.sleep(args.sleep_ms / 1000.0)

    if connect_samples:
        print(_format_stats("Connect", connect_samples))
    if args.query_repeat > 0:
        print(_format_stats("Query", query_samples))
    if failures:
        print(f"Failures: {len(failures)}", file=sys.stderr)
        return 1
    return 0
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())