# -*- coding: utf-8 -*- """Quick utility for validating PostgreSQL connectivity.""" from __future__ import annotations import argparse import os import sys PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT) from database.connection import DatabaseConnection def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="PostgreSQL connectivity smoke test") parser.add_argument("--dsn", help="Override TEST_DB_DSN / env value") parser.add_argument( "--query", default="SELECT 1 AS ok", help="Custom SQL to run after connection (default: SELECT 1 AS ok)", ) parser.add_argument( "--timeout", type=int, default=5, help="connect_timeout seconds passed to psycopg2 (default: 5)", ) return parser.parse_args() def main() -> int: args = parse_args() dsn = args.dsn or os.environ.get("TEST_DB_DSN") if not dsn: print("❌ 未提供 DSN,请通过 --dsn 或 TEST_DB_DSN 指定连接串", file=sys.stderr) return 2 print(f"尝试连接: {dsn}") try: conn = DatabaseConnection(dsn, connect_timeout=args.timeout) except Exception as exc: # pragma: no cover - diagnostic output print("❌ 连接失败:", exc, file=sys.stderr) return 1 try: result = conn.query(args.query) print("✅ 连接成功,查询结果:") for row in result: print(row) conn.close() return 0 except Exception as exc: # pragma: no cover - diagnostic output print("⚠️ 连接成功但执行查询失败:", exc, file=sys.stderr) try: conn.close() finally: return 3 if __name__ == "__main__": raise SystemExit(main())