合并
This commit is contained in:
267
etl_billiards/scripts/test_db_performance.py
Normal file
267
etl_billiards/scripts/test_db_performance.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""PostgreSQL connection performance test (ASCII-only output)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, Iterable, List
|
||||
|
||||
from psycopg2.extensions import make_dsn, parse_dsn
|
||||
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from database.connection import DatabaseConnection
|
||||
|
||||
|
||||
def _load_env() -> Dict[str, str]:
|
||||
env: Dict[str, str] = {}
|
||||
try:
|
||||
from config.env_parser import _load_dotenv_values
|
||||
except Exception:
|
||||
_load_dotenv_values = None
|
||||
if _load_dotenv_values:
|
||||
try:
|
||||
env.update(_load_dotenv_values())
|
||||
except Exception:
|
||||
pass
|
||||
env.update(os.environ)
|
||||
return env
|
||||
|
||||
|
||||
def _apply_dsn_overrides(dsn: str, host: str | None, port: int | None) -> str:
|
||||
overrides = {}
|
||||
if host:
|
||||
overrides["host"] = host
|
||||
if port:
|
||||
overrides["port"] = str(port)
|
||||
if not overrides:
|
||||
return dsn
|
||||
return make_dsn(dsn, **overrides)
|
||||
|
||||
|
||||
def _build_dsn_from_env(
|
||||
host: str,
|
||||
port: int,
|
||||
user: str | None,
|
||||
password: str | None,
|
||||
dbname: str | None,
|
||||
) -> str | None:
|
||||
if not user or not dbname:
|
||||
return None
|
||||
params = {
|
||||
"host": host,
|
||||
"port": str(port),
|
||||
"user": user,
|
||||
"dbname": dbname,
|
||||
}
|
||||
if password:
|
||||
params["password"] = password
|
||||
return make_dsn("", **params)
|
||||
|
||||
|
||||
def _safe_dsn_summary(dsn: str, host: str | None, port: int | None) -> str:
|
||||
try:
|
||||
info = parse_dsn(dsn)
|
||||
except Exception:
|
||||
info = {}
|
||||
if host:
|
||||
info["host"] = host
|
||||
if port:
|
||||
info["port"] = str(port)
|
||||
info.pop("password", None)
|
||||
if not info:
|
||||
return "dsn=(hidden)"
|
||||
items = " ".join(f"{k}={info[k]}" for k in sorted(info.keys()))
|
||||
return items
|
||||
|
||||
|
||||
def _percentile(values: List[float], pct: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
ordered = sorted(values)
|
||||
if len(ordered) == 1:
|
||||
return ordered[0]
|
||||
rank = (len(ordered) - 1) * (pct / 100.0)
|
||||
low = int(math.floor(rank))
|
||||
high = int(math.ceil(rank))
|
||||
if low == high:
|
||||
return ordered[low]
|
||||
return ordered[low] + (ordered[high] - ordered[low]) * (rank - low)
|
||||
|
||||
|
||||
def _format_stats(label: str, values: Iterable[float]) -> str:
|
||||
data = list(values)
|
||||
if not data:
|
||||
return f"{label}: no samples"
|
||||
avg = statistics.mean(data)
|
||||
stdev = statistics.stdev(data) if len(data) > 1 else 0.0
|
||||
return (
|
||||
f"{label}: count={len(data)} "
|
||||
f"min={min(data):.2f}ms avg={avg:.2f}ms "
|
||||
f"p50={_percentile(data, 50):.2f}ms "
|
||||
f"p95={_percentile(data, 95):.2f}ms "
|
||||
f"max={max(data):.2f}ms stdev={stdev:.2f}ms"
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="PostgreSQL connection performance test")
|
||||
parser.add_argument("--dsn", help="Override PG_DSN/TEST_DB_DSN/.env value")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default="100.64.0.4",
|
||||
help="Override host in DSN (default: 100.64.0.4)",
|
||||
)
|
||||
parser.add_argument("--port", type=int, help="Override port in DSN")
|
||||
parser.add_argument("--user", help="User when building DSN from PG_* env")
|
||||
parser.add_argument("--password", help="Password when building DSN from PG_* env")
|
||||
parser.add_argument("--dbname", help="Database name when building DSN from PG_* env")
|
||||
parser.add_argument("--rounds", type=int, default=20, help="Measured connection rounds")
|
||||
parser.add_argument("--warmup", type=int, default=2, help="Warmup rounds (not recorded)")
|
||||
parser.add_argument("--query", default="SELECT 1", help="SQL to run after connect")
|
||||
parser.add_argument(
|
||||
"--query-repeat",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Query repetitions per connection (0 to skip)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connect-timeout",
|
||||
type=int,
|
||||
default=10,
|
||||
help="connect_timeout seconds (capped at 20, default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--statement-timeout-ms",
|
||||
type=int,
|
||||
help="Optional statement_timeout applied per connection",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sleep-ms",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Sleep between rounds in milliseconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--continue-on-error",
|
||||
action="store_true",
|
||||
help="Continue even if a round fails",
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", help="Print per-round timings")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _run_round(
|
||||
dsn: str,
|
||||
timeout: int,
|
||||
query: str,
|
||||
query_repeat: int,
|
||||
session: Dict[str, int] | None,
|
||||
) -> tuple[float, List[float]]:
|
||||
start = time.perf_counter()
|
||||
conn = DatabaseConnection(dsn, connect_timeout=timeout, session=session)
|
||||
connect_ms = (time.perf_counter() - start) * 1000.0
|
||||
query_times: List[float] = []
|
||||
try:
|
||||
for _ in range(query_repeat):
|
||||
q_start = time.perf_counter()
|
||||
conn.query(query)
|
||||
query_times.append((time.perf_counter() - q_start) * 1000.0)
|
||||
return connect_ms, query_times
|
||||
finally:
|
||||
try:
|
||||
conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
conn.close()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
if args.rounds < 0 or args.warmup < 0 or args.query_repeat < 0:
|
||||
print("rounds/warmup/query-repeat must be >= 0", file=sys.stderr)
|
||||
return 2
|
||||
env = _load_env()
|
||||
|
||||
dsn = args.dsn or env.get("PG_DSN") or env.get("TEST_DB_DSN")
|
||||
host = args.host
|
||||
port = args.port
|
||||
|
||||
if not dsn:
|
||||
user = args.user or env.get("PG_USER")
|
||||
password = args.password if args.password is not None else env.get("PG_PASSWORD")
|
||||
dbname = args.dbname or env.get("PG_NAME")
|
||||
try:
|
||||
resolved_port = port or int(env.get("PG_PORT", "5432"))
|
||||
except ValueError:
|
||||
resolved_port = port or 5432
|
||||
dsn = _build_dsn_from_env(host, resolved_port, user, password, dbname)
|
||||
if not dsn:
|
||||
print(
|
||||
"Missing DSN. Provide --dsn or set PG_DSN/TEST_DB_DSN, or PG_USER + PG_NAME.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
dsn = _apply_dsn_overrides(dsn, host, port)
|
||||
|
||||
timeout = max(1, min(int(args.connect_timeout), 20))
|
||||
session = None
|
||||
if args.statement_timeout_ms is not None:
|
||||
session = {"statement_timeout_ms": int(args.statement_timeout_ms)}
|
||||
|
||||
print("Target:", _safe_dsn_summary(dsn, host, port))
|
||||
print(
|
||||
f"Rounds: {args.rounds} (warmup {args.warmup}), "
|
||||
f"query_repeat={args.query_repeat}, timeout={timeout}s"
|
||||
)
|
||||
if args.query_repeat > 0:
|
||||
print("Query:", args.query)
|
||||
|
||||
connect_times: List[float] = []
|
||||
query_times: List[float] = []
|
||||
failures: List[str] = []
|
||||
|
||||
total = args.warmup + args.rounds
|
||||
for idx in range(total):
|
||||
is_warmup = idx < args.warmup
|
||||
try:
|
||||
c_ms, q_times = _run_round(
|
||||
dsn, timeout, args.query, args.query_repeat, session
|
||||
)
|
||||
if not is_warmup:
|
||||
connect_times.append(c_ms)
|
||||
query_times.extend(q_times)
|
||||
if args.verbose:
|
||||
tag = "warmup" if is_warmup else "sample"
|
||||
q_msg = ""
|
||||
if args.query_repeat > 0:
|
||||
q_avg = statistics.mean(q_times) if q_times else 0.0
|
||||
q_msg = f", query_avg={q_avg:.2f}ms"
|
||||
print(f"[{tag} {idx + 1}/{total}] connect={c_ms:.2f}ms{q_msg}")
|
||||
except Exception as exc:
|
||||
msg = f"round {idx + 1}: {exc}"
|
||||
failures.append(msg)
|
||||
print("Failure:", msg, file=sys.stderr)
|
||||
if not args.continue_on_error:
|
||||
break
|
||||
if args.sleep_ms > 0:
|
||||
time.sleep(args.sleep_ms / 1000.0)
|
||||
|
||||
if connect_times:
|
||||
print(_format_stats("Connect", connect_times))
|
||||
if args.query_repeat > 0:
|
||||
print(_format_stats("Query", query_times))
|
||||
if failures:
|
||||
print(f"Failures: {len(failures)}", file=sys.stderr)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user