130 lines
4.2 KiB
Python
130 lines
4.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""API客户端"""
|
|
import requests
|
|
from urllib3.util.retry import Retry
|
|
from requests.adapters import HTTPAdapter
|
|
|
|
class APIClient:
|
|
"""HTTP API客户端"""
|
|
|
|
def __init__(self, base_url: str, token: str = None, timeout: int = 20,
|
|
retry_max: int = 3, headers_extra: dict = None):
|
|
self.base_url = base_url.rstrip("/")
|
|
self.token = token
|
|
self.timeout = timeout
|
|
self.retry_max = retry_max
|
|
self.headers_extra = headers_extra or {}
|
|
self._session = None
|
|
|
|
def _get_session(self):
|
|
"""获取或创建会话"""
|
|
if self._session is None:
|
|
self._session = requests.Session()
|
|
|
|
retries = max(0, int(self.retry_max) - 1)
|
|
retry = Retry(
|
|
total=None,
|
|
connect=retries,
|
|
read=retries,
|
|
status=retries,
|
|
allowed_methods=frozenset(["GET"]),
|
|
status_forcelist=(429, 500, 502, 503, 504),
|
|
backoff_factor=1.0,
|
|
respect_retry_after_header=True,
|
|
raise_on_status=False,
|
|
)
|
|
|
|
adapter = HTTPAdapter(max_retries=retry)
|
|
self._session.mount("http://", adapter)
|
|
self._session.mount("https://", adapter)
|
|
|
|
if self.headers_extra:
|
|
self._session.headers.update(self.headers_extra)
|
|
|
|
return self._session
|
|
|
|
def get(self, endpoint: str, params: dict = None) -> dict:
|
|
"""执行GET请求"""
|
|
url = f"{self.base_url}/{endpoint.lstrip('/')}"
|
|
headers = {"Authorization": self.token} if self.token else {}
|
|
headers.update(self.headers_extra)
|
|
|
|
sess = self._get_session()
|
|
resp = sess.get(url, headers=headers, params=params, timeout=self.timeout)
|
|
resp.raise_for_status()
|
|
return resp.json()
|
|
|
|
def iter_paginated(
|
|
self,
|
|
endpoint: str,
|
|
params: dict | None,
|
|
page_size: int = 200,
|
|
page_field: str = "pageIndex",
|
|
size_field: str = "pageSize",
|
|
data_path: tuple = ("data",),
|
|
list_key: str | None = None,
|
|
):
|
|
"""分页迭代器:逐页拉取数据并产出 (page_no, records, request_params, raw_response)。"""
|
|
base_params = dict(params or {})
|
|
page = 1
|
|
|
|
while True:
|
|
page_params = dict(base_params)
|
|
page_params[page_field] = page
|
|
page_params[size_field] = page_size
|
|
|
|
payload = self.get(endpoint, page_params)
|
|
records = self._extract_list(payload, data_path, list_key)
|
|
|
|
yield page, records, page_params, payload
|
|
|
|
if len(records) < page_size:
|
|
break
|
|
|
|
if len(records) == 0:
|
|
break
|
|
|
|
page += 1
|
|
|
|
def get_paginated(self, endpoint: str, params: dict, page_size: int = 200,
|
|
page_field: str = "pageIndex", size_field: str = "pageSize",
|
|
data_path: tuple = ("data",), list_key: str = None) -> tuple:
|
|
"""分页获取数据并将所有记录汇总在一个列表中。"""
|
|
records, pages_meta = [], []
|
|
|
|
for page_no, page_records, request_params, response in self.iter_paginated(
|
|
endpoint=endpoint,
|
|
params=params,
|
|
page_size=page_size,
|
|
page_field=page_field,
|
|
size_field=size_field,
|
|
data_path=data_path,
|
|
list_key=list_key,
|
|
):
|
|
records.extend(page_records)
|
|
pages_meta.append(
|
|
{"page": page_no, "request": request_params, "response": response}
|
|
)
|
|
|
|
return records, pages_meta
|
|
|
|
@staticmethod
|
|
def _extract_list(payload: dict, data_path: tuple, list_key: str | None):
|
|
"""辅助函数:根据 data_path/list_key 提取列表结构。"""
|
|
cur = payload
|
|
for key in data_path:
|
|
if isinstance(cur, dict):
|
|
cur = cur.get(key)
|
|
else:
|
|
cur = None
|
|
if cur is None:
|
|
break
|
|
|
|
if list_key and isinstance(cur, dict):
|
|
cur = cur.get(list_key)
|
|
|
|
if not isinstance(cur, list):
|
|
cur = []
|
|
|
|
return cur
|