# -*- coding: utf-8 -*-
"""API client."""
from __future__ import annotations

import requests
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter


class APIClient:
    """HTTP API client built on a lazily created ``requests.Session``.

    GET requests are retried on connection/read errors and on 429/5xx
    responses. Call ``close()`` (or use the client as a context manager) to
    release the session's pooled connections when done.
    """

    def __init__(self, base_url: str, token: str | None = None, timeout: int = 20,
                 retry_max: int = 3, headers_extra: dict | None = None):
        """
        Args:
            base_url: Root URL. A trailing "/" is stripped so endpoints can be
                joined with a single separator.
            token: Sent verbatim as the ``Authorization`` header (no scheme
                prefix is added). The header is omitted when the token is falsy.
            timeout: Per-request timeout handed to ``requests``.
            retry_max: Total attempts per request; the retry policy allows
                ``retry_max - 1`` retries (never fewer than 0).
            headers_extra: Extra headers attached to every request.
        """
        self.base_url = base_url.rstrip("/")
        self.token = token
        self.timeout = timeout
        self.retry_max = retry_max
        self.headers_extra = headers_extra or {}
        self._session = None

    def _get_session(self):
        """Return the shared session, creating and configuring it on first use."""
        if self._session is None:
            self._session = requests.Session()
            # retry_max counts attempts, so the number of retries is one less.
            retries = max(0, int(self.retry_max) - 1)
            retry = Retry(
                total=None,  # no overall cap; rely on the per-category counts below
                connect=retries,
                read=retries,
                status=retries,
                allowed_methods=frozenset(["GET"]),
                status_forcelist=(429, 500, 502, 503, 504),
                backoff_factor=1.0,
                respect_retry_after_header=True,
                # Hand back the last response instead of raising MaxRetryError;
                # get() turns error statuses into HTTPError via raise_for_status().
                raise_on_status=False,
            )
            adapter = HTTPAdapter(max_retries=retry)
            self._session.mount("http://", adapter)
            self._session.mount("https://", adapter)
            if self.headers_extra:
                self._session.headers.update(self.headers_extra)
        return self._session

    def close(self) -> None:
        """Close the underlying session (if one was created) and forget it.

        A later request transparently creates a fresh session.
        """
        if self._session is not None:
            self._session.close()
            self._session = None

    def __enter__(self) -> APIClient:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release connections; exceptions propagate unchanged.
        self.close()

    def get(self, endpoint: str, params: dict | None = None) -> dict:
        """Perform a GET request and return the decoded JSON body.

        Args:
            endpoint: Path relative to ``base_url``; leading "/" is ignored.
            params: Query-string parameters.

        Raises:
            requests.HTTPError: If the final response has an error status.
        """
        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        headers = {"Authorization": self.token} if self.token else {}
        # Applied after Authorization, so headers_extra may override it.
        headers.update(self.headers_extra)
        sess = self._get_session()
        resp = sess.get(url, headers=headers, params=params, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def iter_paginated(
        self,
        endpoint: str,
        params: dict | None,
        page_size: int = 200,
        page_field: str = "pageIndex",
        size_field: str = "pageSize",
        data_path: tuple = ("data",),
        list_key: str | None = None,
        max_pages: int | None = None,
    ):
        """Fetch pages one by one, starting at page 1.

        Yields ``(page_no, records, request_params, raw_response)`` for each page.
        Iteration stops after the first page with fewer than ``page_size``
        records (including an empty page), or after ``max_pages`` pages if given.

        Raises:
            ValueError: If ``page_size`` or ``max_pages`` is less than 1.
        """
        if page_size < 1:
            # A non-positive size can never yield a "short" page, so the loop
            # could only end on an empty response (or never).
            raise ValueError(f"page_size must be >= 1, got {page_size!r}")
        if max_pages is not None and max_pages < 1:
            raise ValueError(f"max_pages must be >= 1, got {max_pages!r}")

        base_params = dict(params or {})
        page = 1
        while True:
            page_params = dict(base_params)
            page_params[page_field] = page
            page_params[size_field] = page_size
            payload = self.get(endpoint, page_params)
            records = self._extract_list(payload, data_path, list_key)
            yield page, records, page_params, payload
            # A short page (an empty one included) means there is no more data.
            if len(records) < page_size:
                break
            # Guards against servers that ignore the page parameter.
            if max_pages is not None and page >= max_pages:
                break
            page += 1

    def get_paginated(self, endpoint: str, params: dict, page_size: int = 200,
                      page_field: str = "pageIndex", size_field: str = "pageSize",
                      data_path: tuple = ("data",), list_key: str | None = None,
                      max_pages: int | None = None) -> tuple:
        """Fetch all pages and collect their records into one list.

        Returns:
            ``(records, pages_meta)``. ``records`` is the concatenation of every
            page's records. ``pages_meta`` holds one dict per page with keys
            "page", "request" (the query params sent) and "response" (raw JSON).
        """
        records, pages_meta = [], []
        for page_no, page_records, request_params, response in self.iter_paginated(
            endpoint=endpoint,
            params=params,
            page_size=page_size,
            page_field=page_field,
            size_field=size_field,
            data_path=data_path,
            list_key=list_key,
            max_pages=max_pages,
        ):
            records.extend(page_records)
            pages_meta.append(
                {"page": page_no, "request": request_params, "response": response}
            )
        return records, pages_meta

    @staticmethod
    def _extract_list(payload: dict, data_path: tuple, list_key: str | None) -> list:
        """Walk ``data_path`` through nested dicts and return the list found there.

        If ``list_key`` is given and the walk ends on a dict, that key is read
        as well. Any missing key, non-dict step, or non-list result yields ``[]``.
        """
        cur = payload
        for key in data_path:
            cur = cur.get(key) if isinstance(cur, dict) else None
            if cur is None:
                break
        if list_key and isinstance(cur, dict):
            cur = cur.get(list_key)
        return cur if isinstance(cur, list) else []