# -*- coding: utf-8 -*-
"""Wrap APIClient so paginated responses are also dumped to disk for later local cleaning."""
from __future__ import annotations

import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable, Tuple
from zoneinfo import ZoneInfo

from api.client import APIClient
from api.endpoint_routing import plan_calls
from utils.json_store import dump_json, endpoint_to_filename


class RecordingAPIClient:
    """
    Proxy for APIClient that records paginated responses to JSON files.

    Whenever iter_paginated/get_paginated is called, every page's request
    parameters and raw response are buffered and written to a JSON file in
    *output_dir*; the filename is derived from the endpoint.
    """

    def __init__(
        self,
        base_client: APIClient,
        output_dir: Path | str,
        task_code: str,
        run_id: int,
        write_pretty: bool = False,
    ):
        self.base = base_client
        self.output_dir = Path(output_dir)
        # Create the dump directory up front so later writes cannot fail on a
        # missing parent.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.task_code = task_code
        self.run_id = run_id
        self.write_pretty = write_pretty
        # Summary of the most recent dump (file path, endpoint, page/record counts).
        self.last_dump: dict[str, Any] | None = None

    # ------------------------------------------------------------------ public API
    def get_source_hint(self, endpoint: str) -> str:
        """Return the JSON dump path for this endpoint (for source_file lineage)."""
        return str(self.output_dir / endpoint_to_filename(endpoint))

    def iter_paginated(
        self,
        endpoint: str,
        params: dict | None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: tuple = ("data",),
        list_key: str | None = None,
    ) -> Iterable[Tuple[int, list, dict, dict]]:
        """Yield (page_no, records, request_params, response) from the base client.

        All pages are buffered in memory and dumped to disk once the underlying
        iterator is exhausted.  NOTE(review): if the caller abandons the
        generator early (break/close), no dump is written — confirm intended.
        """
        pages: list[dict[str, Any]] = []
        total_records = 0

        for page_no, records, request_params, response in self.base.iter_paginated(
            endpoint=endpoint,
            params=params,
            page_size=page_size,
            page_field=page_field,
            size_field=size_field,
            data_path=data_path,
            list_key=list_key,
        ):
            pages.append({"page": page_no, "request": request_params, "response": response})
            total_records += len(records)
            yield page_no, records, request_params, response

        self._dump(endpoint, params, page_size, pages, total_records)

    def get_paginated(
        self,
        endpoint: str,
        params: dict,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: tuple = ("data",),
        list_key: str | None = None,
    ) -> tuple[list, list]:
        """Collect every page eagerly and return (records, pages_meta).

        records is the concatenation of all page record lists; pages_meta holds
        one {"page", "request", "response"} dict per page.
        """
        records: list = []
        pages_meta: list = []

        for page_no, page_records, request_params, response in self.iter_paginated(
            endpoint=endpoint,
            params=params,
            page_size=page_size,
            page_field=page_field,
            size_field=size_field,
            data_path=data_path,
            list_key=list_key,
        ):
            records.extend(page_records)
            pages_meta.append({"page": page_no, "request": request_params, "response": response})

        return records, pages_meta

    # ------------------------------------------------------------------ internals
    def _dump(
        self,
        endpoint: str,
        params: dict | None,
        page_size: int,
        pages: list[dict[str, Any]],
        total_records: int,
    ) -> None:
        """Write the collected pages (plus routing lineage metadata) to JSON."""
        path = self.output_dir / endpoint_to_filename(endpoint)
        routing_calls: list[dict[str, Any]] = []
        try:
            for call in plan_calls(endpoint, params):
                routing_calls.append({"endpoint": call.endpoint, "params": call.params})
        except Exception:
            # Routing metadata is best-effort lineage info; never let it block the dump.
            routing_calls = []
        payload = {
            "task_code": self.task_code,
            "run_id": self.run_id,
            "endpoint": endpoint,
            "params": params or {},
            "endpoint_routing": {"calls": routing_calls} if routing_calls else None,
            "page_size": page_size,
            "pages": pages,
            "total_records": total_records,
            # datetime.utcnow() is deprecated and returns a naive datetime;
            # use an aware UTC timestamp and keep the conventional trailing
            # "Z" instead of "+00:00".
            "dumped_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }
        dump_json(path, payload, pretty=self.write_pretty)
        self.last_dump = {
            "file": str(path),
            "endpoint": endpoint,
            "pages": len(pages),
            "records": total_records,
        }


def _cfg_get(cfg, key: str, default=None):
|
||
if isinstance(cfg, dict):
|
||
cur = cfg
|
||
for part in key.split("."):
|
||
if not isinstance(cur, dict) or part not in cur:
|
||
return default
|
||
cur = cur[part]
|
||
return cur
|
||
getter = getattr(cfg, "get", None)
|
||
if callable(getter):
|
||
return getter(key, default)
|
||
return default


def build_recording_client(
    cfg,
    *,
    task_code: str,
    output_dir: Path | str | None = None,
    run_id: int | None = None,
    write_pretty: bool | None = None,
):
    """Build RecordingAPIClient from AppConfig or dict config."""
    # The underlying HTTP client is assembled entirely from dotted config keys.
    base_client = APIClient(
        base_url=_cfg_get(cfg, "api.base_url") or "",
        token=_cfg_get(cfg, "api.token"),
        timeout=int(_cfg_get(cfg, "api.timeout_sec", 20) or 20),
        retry_max=int(_cfg_get(cfg, "api.retries.max_attempts", 3) or 3),
        headers_extra=_cfg_get(cfg, "api.headers_extra") or {},
    )

    if run_id is None:
        run_id = int(time.time())

    if write_pretty is None:
        write_pretty = bool(_cfg_get(cfg, "io.write_pretty_json", False))

    if output_dir is None:
        # Default layout: <root>/<TASK>/<TASK>-<run_id>-<local timestamp>.
        zone_name = _cfg_get(cfg, "app.timezone", "Asia/Taipei") or "Asia/Taipei"
        stamp = datetime.now(ZoneInfo(zone_name)).strftime("%Y%m%d-%H%M%S")
        root = _cfg_get(cfg, "pipeline.fetch_root") or _cfg_get(cfg, "io.export_root") or "export/JSON"
        code = str(task_code).upper()
        output_dir = Path(root) / code / f"{code}-{run_id}-{stamp}"

    return RecordingAPIClient(
        base_client=base_client,
        output_dir=output_dir,
        task_code=str(task_code),
        run_id=int(run_id),
        write_pretty=bool(write_pretty),
    )