64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""Quick utility for validating PostgreSQL connectivity."""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
|
||
# Put the parent of this script's directory at the front of sys.path so the
# project-local ``database.connection`` import below resolves even when the
# script is executed directly rather than as part of an installed package.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
|
||
|
||
from database.connection import DatabaseConnection
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
|
||
parser = argparse.ArgumentParser(description="PostgreSQL connectivity smoke test")
|
||
parser.add_argument("--dsn", help="Override TEST_DB_DSN / env value")
|
||
parser.add_argument(
|
||
"--query",
|
||
default="SELECT 1 AS ok",
|
||
help="Custom SQL to run after connection (default: SELECT 1 AS ok)",
|
||
)
|
||
parser.add_argument(
|
||
"--timeout",
|
||
type=int,
|
||
default=5,
|
||
help="connect_timeout seconds passed to psycopg2 (default: 5)",
|
||
)
|
||
return parser.parse_args()
|
||
|
||
|
||
def _redact_dsn(dsn: str) -> str:
    """Return *dsn* with any password masked as ``***``, for display only.

    Covers both libpq connection-string forms:

    * URI userinfo: ``postgresql://user:secret@host/db`` becomes
      ``postgresql://user:***@host/db``.
    * key/value pairs (also matches URI query parameters):
      ``host=h password=secret`` becomes ``host=h password=***``.
      ``sslpassword=...`` and single-quoted values are masked too.

    The unmodified *dsn* is still what gets passed to the connection.
    """
    import re

    # scheme://user:password@... -- keep "user:", hide everything up to '@'.
    masked = re.sub(r"(://[^:/@\s]*:)[^@/?#\s]*@", r"\1***@", dsn)
    # password=value / password = 'quoted \' value' (case-insensitive key).
    masked = re.sub(
        r"(password\s*=\s*)('(?:[^'\\]|\\.)*'|\S+)",
        r"\1***",
        masked,
        flags=re.IGNORECASE,
    )
    return masked


def _close_quietly(conn: "DatabaseConnection") -> None:
    """Close *conn* as a best effort, reporting a failure on stderr instead of raising."""
    try:
        conn.close()
    except Exception as exc:  # pragma: no cover - diagnostic output
        print("⚠️ 关闭连接失败:", exc, file=sys.stderr)


def main() -> int:
    """Connect with the configured DSN, run one query and report the outcome.

    The DSN comes from ``--dsn`` or, failing that, the ``TEST_DB_DSN``
    environment variable.

    Returns:
        Process exit code: ``0`` on success, ``1`` if the connection could not
        be established, ``2`` if no DSN was supplied, ``3`` if the connection
        succeeded but running the query (or printing its rows) failed.
    """
    args = parse_args()
    dsn = args.dsn or os.environ.get("TEST_DB_DSN")
    if not dsn:
        print("❌ 未提供 DSN,请通过 --dsn 或 TEST_DB_DSN 指定连接串", file=sys.stderr)
        return 2

    # Never echo credentials: a smoke test's stdout is often captured in logs.
    print(f"尝试连接: {_redact_dsn(dsn)}")
    try:
        conn = DatabaseConnection(dsn, connect_timeout=args.timeout)
    except Exception as exc:  # pragma: no cover - diagnostic output
        print("❌ 连接失败:", exc, file=sys.stderr)
        return 1

    try:
        result = conn.query(args.query)
        print("✅ 连接成功,查询结果:")
        for row in result:
            print(row)
    except Exception as exc:  # pragma: no cover - diagnostic output
        print("⚠️ 连接成功但执行查询失败:", exc, file=sys.stderr)
        return 3
    finally:
        # Close exactly once on every path (including KeyboardInterrupt).
        # A failing close() is reported but must not mask the exit code, and
        # unlike a `return` inside `finally` it cannot swallow BaseException.
        _close_quietly(conn)
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # main() returns an int status; sys.exit raises SystemExit with it so the
    # shell sees 0/1/2/3 as the process exit code.
    sys.exit(main())
|