257 lines
8.1 KiB
Python
257 lines
8.1 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""API客户端:统一封装 POST/重试/分页与列表提取逻辑。"""
|
||
from __future__ import annotations
|
||
|
||
from typing import Iterable, Sequence, Tuple
|
||
|
||
import requests
|
||
from requests.adapters import HTTPAdapter
|
||
from urllib3.util.retry import Retry
|
||
|
||
# Default request headers imitating an XHR issued by the pc.ficoo.vip web
# front end (Chrome 141 on Windows). Insertion order is kept stable so the
# headers are emitted in the same order on every request.
DEFAULT_BROWSER_HEADERS = {
    # --- content negotiation ---
    "Accept": "application/json, text/plain, */*",
    "Content-Type": "application/json",
    # --- page the emulated browser is on ---
    "Origin": "https://pc.ficoo.vip",
    "Referer": "https://pc.ficoo.vip/",
    # --- browser identity ---
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/141.0.0.0 Safari/537.36"
    ),
    "Accept-Language": "zh-CN,zh;q=0.9",
    # --- client hints and fetch metadata ---
    "sec-ch-ua": (
        '"Google Chrome";v="141", "Not?A_Brand";v="8", '
        '"Chromium";v="141"'
    ),
    "sec-ch-ua-platform": '"Windows"',
    "sec-ch-ua-mobile": "?0",
    "sec-fetch-site": "same-origin",
    "sec-fetch-mode": "cors",
    "sec-fetch-dest": "empty",
    "priority": "u=1, i",
    # --- misc ---
    "X-Requested-With": "XMLHttpRequest",
    "DNT": "1",
}
|
||
|
||
# Field names probed, in this order, when looking for the record list inside
# a response's data object (consulted after any caller-supplied list_key).
DEFAULT_LIST_KEYS: Tuple[str, ...] = (
    # generic container names
    "list", "rows", "records", "items", "dataList", "data_list",
    # member / card endpoints
    "tenantMemberInfos", "tenantMemberCardLogs", "tenantMemberCards",
    # settlement / assistant endpoints
    "settleList", "orderAssistantDetails", "assistantInfos",
    # table endpoints
    "siteTables", "taiFeeAdjustInfos", "siteTableUseDetailsList",
    # goods / coupon / delivery endpoints
    "tenantGoodsList", "packageCouponList", "queryDeliveryRecordsList",
    "goodsCategoryList", "orderGoodsList", "orderGoodsLedgers",
)
|
||
|
||
|
||
class APIClient:
    """HTTP API client that sends every request as a POST with a JSON body.

    One ``requests.Session`` is created lazily on the first request and reused
    afterwards; it carries the retry policy and the default headers. Call
    :meth:`close` (or use the client in a ``with`` statement) to release the
    session's pooled connections.
    """

    def __init__(
        self,
        base_url: str,
        token: str | None = None,
        timeout: int = 20,
        retry_max: int = 3,
        headers_extra: dict | None = None,
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            base_url: API root URL. A trailing "/" is stripped; an empty value
                is accepted here but rejected when a request is made.
            token: Access token for the ``Authorization`` header. A "Bearer "
                prefix is added when missing; an empty or whitespace-only
                token means no ``Authorization`` header is sent.
            timeout: Timeout passed to every ``requests`` call.
            retry_max: Attempts allowed per failure kind (connect, read,
                retryable status), i.e. ``retry_max - 1`` retries each.
            headers_extra: Extra headers layered over the browser defaults.
        """
        self.base_url = (base_url or "").rstrip("/")
        self.token = self._normalize_token(token)
        self.timeout = timeout
        self.retry_max = retry_max
        self.headers_extra = headers_extra or {}
        self._session: requests.Session | None = None

    # ------------------------------------------------------------------ lifecycle

    def __enter__(self) -> APIClient:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self) -> None:
        """Close the shared session if one exists.

        Safe to call repeatedly. A later request creates a new session, built
        from the current ``token``/``headers_extra`` values.
        """
        if self._session is not None:
            self._session.close()
            self._session = None

    # ------------------------------------------------------------------ HTTP basics

    def _get_session(self) -> requests.Session:
        """Return the shared session, creating and configuring it on first use.

        The session is stored only after it is fully configured, so an error
        while building the retry policy cannot leave a bare session (no
        retries, no default headers, no Authorization) cached for later calls.

        NOTE: default headers are copied into the session when it is created;
        later changes to ``token``/``headers_extra`` take effect only after
        :meth:`close`.
        """
        if self._session is not None:
            return self._session

        # retry_max counts attempts; urllib3 counts retries after the first try.
        retries = max(0, int(self.retry_max) - 1)
        retry = Retry(
            total=None,
            connect=retries,
            read=retries,
            status=retries,
            # urllib3 does not retry POST by default, and every call made by
            # this client is a POST, so it is opted in explicitly.
            # NOTE(review): assumes the endpoints are safe to repeat.
            allowed_methods=frozenset(["GET", "POST"]),
            status_forcelist=(429, 500, 502, 503, 504),
            backoff_factor=0.5,
            respect_retry_after_header=True,
            # Hand back the last response once status retries run out;
            # _post_json then raises via raise_for_status().
            raise_on_status=False,
        )
        adapter = HTTPAdapter(max_retries=retry)

        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        session.headers.update(self._build_headers())
        self._session = session
        return session

    def get(self, endpoint: str, params: dict | None = None) -> dict:
        """Legacy-named entry point; despite the name it sends a POST with a JSON body."""
        return self._post_json(endpoint, params)

    def _post_json(self, endpoint: str, payload: dict | None = None) -> dict:
        """POST ``payload`` as JSON to ``<base_url>/<endpoint>`` and decode the reply.

        Args:
            endpoint: Path relative to ``base_url``; a leading "/" is ignored.
            payload: JSON body; ``None`` sends ``{}``.

        Returns:
            The decoded JSON body (normally a dict).

        Raises:
            ValueError: ``base_url`` is empty, the body is not valid JSON
                (``resp.json()`` raises a ValueError subclass), or the body
                carries an error ``code``.
            requests.HTTPError: the final response has a 4xx/5xx status.
            requests.RequestException: transport failure after retries.
        """
        if not self.base_url:
            raise ValueError("API base_url 未配置")

        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        session = self._get_session()
        resp = session.post(url, json=payload or {}, timeout=self.timeout)
        resp.raise_for_status()
        data = resp.json()
        self._ensure_success(data)
        return data

    def _build_headers(self) -> dict:
        """Merge browser defaults, ``headers_extra`` and Authorization (in that precedence order)."""
        headers = dict(DEFAULT_BROWSER_HEADERS)
        headers.update(self.headers_extra)
        if self.token:
            # The normalized token wins over any Authorization in headers_extra.
            headers["Authorization"] = self.token
        return headers

    @staticmethod
    def _normalize_token(token: str | None) -> str | None:
        """Return ``token`` as a "Bearer ..." header value, or None when it is blank.

        Whitespace-only tokens and a bare "Bearer" prefix count as blank, so no
        malformed ``Authorization: Bearer `` header is ever produced.
        """
        if not token:
            return None
        t = str(token).strip()
        if not t or t.lower() == "bearer":
            return None
        if not t.lower().startswith("bearer "):
            t = f"Bearer {t}"
        return t

    @staticmethod
    def _ensure_success(payload: dict):
        """Raise when the API reports failure through a ``code`` field.

        A dict with a ``code`` other than 0, "0" or None raises ValueError
        carrying the code and the ``msg``/``message`` text, so callers can
        retry or log it. Non-dict payloads and dicts without ``code`` pass.
        """
        if isinstance(payload, dict) and "code" in payload:
            code = payload.get("code")
            if code not in (0, "0", None):
                msg = payload.get("msg") or payload.get("message") or ""
                raise ValueError(f"API 返回错误 code={code} msg={msg}")

    # ------------------------------------------------------------------ pagination

    def iter_paginated(
        self,
        endpoint: str,
        params: dict | None,
        page_size: int | None = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: tuple = ("data",),
        list_key: str | Sequence[str] | None = None,
        page_start: int = 1,
        page_end: int | None = None,
    ) -> Iterable[tuple[int, list, dict, dict]]:
        """Fetch pages one by one, yielding ``(page_no, records, request_params, raw_response)``.

        ``params`` is copied and never mutated. With ``page_size=None`` no
        paging fields are sent and exactly one request is made. Otherwise
        paging stops after ``page_end`` (inclusive, if given) or as soon as a
        page comes back empty or shorter than ``page_size``.

        NOTE(review): a non-positive ``page_size`` only terminates on an empty
        page or ``page_end``; callers should pass a positive size.
        """
        base_params = dict(params or {})
        page = page_start

        while True:
            page_params = dict(base_params)
            if page_size is not None:
                page_params[page_field] = page
                page_params[size_field] = page_size

            payload = self._post_json(endpoint, page_params)
            records = self._extract_list(payload, data_path, list_key)

            yield page, records, page_params, payload

            if page_size is None:
                break  # unpaged request: a single call only
            if page_end is not None and page >= page_end:
                break
            # An empty or short page means the server has nothing further.
            if not records or len(records) < page_size:
                break

            page += 1

    def get_paginated(
        self,
        endpoint: str,
        params: dict | None,
        page_size: int | None = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: tuple = ("data",),
        list_key: str | Sequence[str] | None = None,
        page_start: int = 1,
        page_end: int | None = None,
    ) -> tuple[list, list]:
        """Drain :meth:`iter_paginated` into one list.

        Returns:
            ``(records, pages_meta)`` where ``records`` concatenates every
            page's records and ``pages_meta`` holds one
            ``{"page", "request", "response"}`` dict per fetched page.
        """
        records, pages_meta = [], []

        for page_no, page_records, request_params, response in self.iter_paginated(
            endpoint=endpoint,
            params=params,
            page_size=page_size,
            page_field=page_field,
            size_field=size_field,
            data_path=data_path,
            list_key=list_key,
            page_start=page_start,
            page_end=page_end,
        ):
            records.extend(page_records)
            pages_meta.append(
                {"page": page_no, "request": request_params, "response": response}
            )

        return records, pages_meta

    # ------------------------------------------------------------------ response parsing

    @classmethod
    def _extract_list(
        cls,
        payload: dict | list,
        data_path: tuple | str,
        list_key: str | Sequence[str] | None,
    ) -> list:
        """Locate the record list in a decoded response.

        Resolution order:
          1. a list payload is returned as-is;
          2. ``data_path`` keys are followed from the payload (a plain string
             is treated as a one-key path); a missing key yields ``[]``;
          3. if that lands on a list, it is returned;
          4. if it lands on a dict: the ``list_key`` name(s), then
             ``DEFAULT_LIST_KEYS``, then the first list-valued field are tried.
        Anything else yields ``[]``.
        """
        if isinstance(payload, list):
            return payload

        # Without this, a string path would be walked one character at a time.
        path = (data_path,) if isinstance(data_path, str) else data_path

        cur: object = payload
        for key in path:
            cur = cur.get(key) if isinstance(cur, dict) else None
            if cur is None:
                break

        if isinstance(cur, list):
            return cur
        if not isinstance(cur, dict):
            return []

        if list_key:
            keys = (list_key,) if isinstance(list_key, str) else tuple(list_key)
            for k in keys:
                value = cur.get(k)
                if isinstance(value, list):
                    return value

        for k in DEFAULT_LIST_KEYS:
            value = cur.get(k)
            if isinstance(value, list):
                return value

        # Last resort: the first list-valued field, whatever its name.
        for value in cur.values():
            if isinstance(value, list):
                return value

        return []
|