# -*- coding: utf-8 -*-
"""Apply the PRD-aligned warehouse schema (ODS/DWD/DWS) to PostgreSQL."""
from __future__ import annotations

import argparse
import contextlib
import os
import sys
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    # Make the project package importable when run as a plain script.
    sys.path.insert(0, str(PROJECT_ROOT))

from database.connection import DatabaseConnection  # noqa: E402

# Bounds applied to the connect timeout before it is handed to the connection.
_MIN_TIMEOUT = 1
_MAX_TIMEOUT = 20
# Kept as a string so argparse runs it through ``type=int`` like any other
# string default.
_DEFAULT_TIMEOUT = "10"


def parse_args() -> argparse.Namespace:
    """Parse command-line options, falling back to environment variables.

    ``--dsn`` defaults to ``PG_DSN``; ``--timeout`` defaults to
    ``PG_CONNECT_TIMEOUT`` (or 10 when that variable is unset or blank);
    ``--file`` defaults to ``database/schema_v2.sql`` under the project root.
    """
    parser = argparse.ArgumentParser(
        description="Create/upgrade warehouse schemas using schema_v2.sql"
    )
    parser.add_argument(
        "--dsn",
        help="PostgreSQL DSN (fallback to PG_DSN env)",
        default=os.environ.get("PG_DSN"),
    )
    parser.add_argument(
        "--file",
        help="Path to schema SQL",
        default=str(PROJECT_ROOT / "database" / "schema_v2.sql"),
    )
    parser.add_argument(
        "--timeout",
        type=int,
        # argparse converts a *string* default with ``type`` only when the
        # option is absent, so a malformed PG_CONNECT_TIMEOUT now produces a
        # clean usage error (exit 2) rather than a ValueError traceback.
        default=os.environ.get("PG_CONNECT_TIMEOUT") or _DEFAULT_TIMEOUT,
        help="connect_timeout seconds (capped at 20, default 10)",
    )
    return parser.parse_args()


def apply_schema(dsn: str, sql_path: Path, timeout: int) -> None:
    """Execute the schema script at *sql_path* against *dsn* in one transaction.

    The whole file is sent in a single ``cursor.execute`` call and committed
    once; on any failure the transaction is rolled back (best effort) and the
    original error is re-raised. The connection is always closed.

    Args:
        dsn: PostgreSQL connection string.
        sql_path: Path to the schema SQL file, read as UTF-8.
        timeout: Connect timeout in seconds; clamped to [1, 20].

    Raises:
        FileNotFoundError: If *sql_path* is not an existing regular file.
        ValueError: If the schema file is empty or whitespace-only.
        Exception: Whatever the database layer raises while connecting,
            executing, or committing.
    """
    if not sql_path.is_file():
        raise FileNotFoundError(f"Schema file not found: {sql_path}")
    sql_text = sql_path.read_text(encoding="utf-8")
    if not sql_text.strip():
        # Fail early with a clear message instead of opening a connection
        # just to run a no-op script.
        raise ValueError(f"Schema file is empty: {sql_path}")

    timeout_val = max(_MIN_TIMEOUT, min(timeout, _MAX_TIMEOUT))
    conn = DatabaseConnection(dsn, connect_timeout=timeout_val)
    try:
        with conn.conn.cursor() as cur:
            cur.execute(sql_text)
        conn.commit()
    except Exception:
        # Best-effort rollback: if the connection is already broken, a failing
        # rollback must not replace the real cause re-raised below.
        with contextlib.suppress(Exception):
            conn.rollback()
        raise
    finally:
        conn.close()


def main() -> int:
    """Command-line entry point; return the process exit status.

    Returns:
        0 on success, 1 if applying the schema fails, 2 if no DSN is set.
        (argparse itself exits with status 2 on invalid options.)
    """
    args = parse_args()
    if not args.dsn:
        print("Missing DSN. Set PG_DSN or pass --dsn.", file=sys.stderr)
        return 2
    try:
        apply_schema(args.dsn, Path(args.file), args.timeout)
    except Exception as exc:  # pragma: no cover - utility script
        print(f"Schema apply failed: {exc}", file=sys.stderr)
        return 1
    print("Schema applied successfully.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())