# File viewer header: 274 lines, 9.7 KiB, Python
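"""Unified wrapper layer for the Bailian API.

Uses the openai Python SDK (Bailian is compatible with the OpenAI protocol)
and provides both streaming and non-streaming call modes.
All AI applications call Alibaba Cloud Tongyi Qianwen (Qwen) through this client.
"""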
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import copy
|
||
import json
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import Any, AsyncGenerator
|
||
|
||
import openai
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ── 异常类 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class BailianApiError(Exception):
    """Raised when a Bailian API call fails (after retries are exhausted).

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, message: str, status_code: int | None = None):
        """Store the error message and the optional HTTP status code.

        Args:
            message: Human-readable description of the failure.
            status_code: HTTP status code associated with the failure, or
                ``None`` when no status code is available.
        """
        # HTTP status code of the failed call; None when not applicable.
        self.status_code = status_code
        super().__init__(message)
|
||
|
||
|
||
class BailianJsonParseError(Exception):
    """Raised when the content returned by the Bailian API is not valid JSON."""

    def __init__(self, message: str, raw_content: str = ""):
        """Store the error message together with the unparsed response text.

        Args:
            message: Human-readable description of the parse failure.
            raw_content: The raw response text that failed to parse
                (empty string by default).
        """
        # Keep the offending payload so callers can log or inspect it.
        self.raw_content = raw_content
        super().__init__(message)
|
||
|
||
|
||
class BailianAuthError(BailianApiError):
    """Raised when the Bailian API key is invalid (HTTP 401)."""

    def __init__(self, message: str = "API Key 无效或已过期"):
        """Create the error with a fixed status code of 401.

        Args:
            message: Human-readable description; defaults to a generic
                "invalid or expired API key" message.
        """
        # Authentication failures always carry HTTP status 401.
        super().__init__(message, 401)
|
||
|
||
|
||
# ── 客户端 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
class BailianClient:
    """Unified wrapper around the Bailian API.

    Uses an ``openai.AsyncOpenAI`` client whose ``base_url`` points at the
    Bailian endpoint, and offers two call modes: streaming (``chat_stream``)
    and non-streaming JSON (``chat_json``).
    """

    # Retry configuration used by _call_with_retry.
    MAX_RETRIES = 3  # total number of attempts per call
    BASE_INTERVAL = 1  # seconds; base of the exponential backoff

    def __init__(self, api_key: str, base_url: str, model: str):
        """Initialise the Bailian client.

        Args:
            api_key: Bailian API key (environment variable BAILIAN_API_KEY).
            base_url: Bailian API endpoint (environment variable
                BAILIAN_BASE_URL).
            model: Model identifier such as ``qwen-plus`` (environment
                variable BAILIAN_MODEL).
        """
        self.model = model
        self._client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            # The SDK retries 408/409/429/5xx and connection errors on its
            # own (twice by default). Turn that off so _call_with_retry is the
            # single retry policy; otherwise attempts multiply and 429 would
            # be retried although this class treats it as non-retryable.
            max_retries=0,
        )

    async def chat_stream(
        self,
        messages: list[dict],
        *,
        temperature: float = 0.7,
        max_tokens: int = 2000,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion, yielding text chunk by chunk.

        Used by application 1 (SSE). Only opening the stream goes through
        ``_call_with_retry``; errors raised while iterating the stream are
        not retried here.

        Args:
            messages: Message list.
            temperature: Sampling temperature, 0.7 by default.
            max_tokens: Maximum number of tokens, 2000 by default.

        Yields:
            Non-empty text chunks from the first choice of each stream event.
        """
        messages = self._inject_current_time(messages)
        response = await self._call_with_retry(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )
        async for chunk in response:
            # Skip events without choices or with empty delta content.
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    async def chat_json(
        self,
        messages: list[dict],
        *,
        temperature: float = 0.3,
        max_tokens: int = 4000,
    ) -> tuple[dict, int]:
        """Run a non-streaming completion and return the parsed JSON.

        Used by applications 2-8 for structured output. The request sets
        ``response_format={"type": "json_object"}`` to ask for valid JSON.

        Args:
            messages: Message list.
            temperature: Sampling temperature, 0.3 by default (low
                temperature for structured output).
            max_tokens: Maximum number of tokens, 4000 by default.

        Returns:
            A ``(parsed_json, tokens_used)`` tuple; ``tokens_used`` is 0 when
            the response carries no usage information.

        Raises:
            BailianJsonParseError: The response content cannot be parsed as
                JSON.
            BailianApiError: The API call failed (after retries are
                exhausted).
        """
        messages = self._inject_current_time(messages)
        response = await self._call_with_retry(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
            response_format={"type": "json_object"},
        )
        # A missing/None content is treated as an empty string, which then
        # fails JSON parsing below and surfaces as BailianJsonParseError.
        raw_content = response.choices[0].message.content or ""
        tokens_used = response.usage.total_tokens if response.usage else 0

        try:
            parsed = json.loads(raw_content)
        except (json.JSONDecodeError, TypeError) as e:
            # Log only a prefix so a huge payload does not flood the log.
            logger.error("百炼 API 返回非法 JSON: %s", raw_content[:500])
            raise BailianJsonParseError(
                f"JSON 解析失败: {e}",
                raw_content=raw_content,
            ) from e

        return parsed, tokens_used

    def _inject_current_time(self, messages: list[dict]) -> list[dict]:
        """Return a copy of ``messages`` with ``current_time`` added to the first one.

        - The input is deep-copied; the original ``messages`` is not modified.
        - The first message's ``content`` is parsed as JSON. If it is a JSON
          object, a ``current_time`` key is added to it.
        - If it is valid JSON but not an object (array, string, ...), it is
          wrapped as ``{"original_content": ..., "current_time": ...}``.
        - If it is not JSON at all, it is wrapped as
          ``{"content": ..., "current_time": ...}``.
        - All other messages are left unchanged.
        - ``current_time`` is the local naive time from ``datetime.now()``,
          formatted to the second, e.g. ``2026-03-08T14:30:00``.

        Args:
            messages: Original message list.

        Returns:
            A new message list with ``current_time`` injected; an empty list
            when ``messages`` is empty.
        """
        if not messages:
            return []

        result = copy.deepcopy(messages)
        first = result[0]
        content = first.get("content", "")
        now_str = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

        try:
            parsed = json.loads(content)
            if isinstance(parsed, dict):
                parsed["current_time"] = now_str
            else:
                # Valid JSON but not an object (e.g. array, string): wrap it.
                parsed = {"original_content": parsed, "current_time": now_str}
        except (json.JSONDecodeError, TypeError):
            # Not JSON (or not a string at all): wrap the raw content.
            parsed = {"content": content, "current_time": now_str}

        first["content"] = json.dumps(parsed, ensure_ascii=False)
        return result

    async def _call_with_retry(self, **kwargs: Any) -> Any:
        """Call ``chat.completions.create`` with exponential-backoff retries.

        Retry policy:
        - At most ``MAX_RETRIES`` attempts in total (3 by default).
        - After a failed attempt that is not the last one, sleep
          ``BASE_INTERVAL * 2 ** attempt`` seconds (1s, then 2s with the
          defaults).
        - HTTP 400 / 401 / 403 / 404 / 422 / 429: not retried, raised
          immediately (401 as ``BailianAuthError``).
        - HTTP 5xx / timeout / connection error: retried.

        Args:
            **kwargs: Arguments forwarded to the openai client call.

        Returns:
            The API response object (an async iterator when ``stream=True``
            is passed).

        Raises:
            BailianAuthError: The API key is invalid (HTTP 401).
            BailianApiError: The API call failed (non-retryable status, or
                retries exhausted).
        """
        last_error: Exception | None = None

        for attempt in range(self.MAX_RETRIES):
            try:
                # Streaming and non-streaming requests use the same SDK call;
                # the ``stream`` flag inside kwargs selects the mode.
                return await self._client.chat.completions.create(**kwargs)

            except openai.AuthenticationError as e:
                # 401: invalid API key, do not retry.
                logger.error("百炼 API 认证失败: %s", e)
                raise BailianAuthError(str(e)) from e

            except openai.BadRequestError as e:
                # 400: bad request parameters, do not retry.
                logger.error("百炼 API 请求参数错误: %s", e)
                raise BailianApiError(str(e), status_code=400) from e

            except openai.RateLimitError as e:
                # 429: rate limited, do not retry (treated as 4xx).
                logger.error("百炼 API 限流: %s", e)
                raise BailianApiError(str(e), status_code=429) from e

            except openai.PermissionDeniedError as e:
                # 403: permission denied, do not retry.
                logger.error("百炼 API 权限不足: %s", e)
                raise BailianApiError(str(e), status_code=403) from e

            except openai.NotFoundError as e:
                # 404: resource not found, do not retry.
                logger.error("百炼 API 资源不存在: %s", e)
                raise BailianApiError(str(e), status_code=404) from e

            except openai.UnprocessableEntityError as e:
                # 422: unprocessable request, do not retry.
                logger.error("百炼 API 不可处理的请求: %s", e)
                raise BailianApiError(str(e), status_code=422) from e

            except (openai.InternalServerError, openai.APIConnectionError, openai.APITimeoutError) as e:
                # 5xx / timeout / connection error: retry with backoff.
                last_error = e
                if attempt < self.MAX_RETRIES - 1:
                    wait_time = self.BASE_INTERVAL * (2 ** attempt)
                    logger.warning(
                        "百炼 API 调用失败(第 %d/%d 次),%ds 后重试: %s",
                        attempt + 1,
                        self.MAX_RETRIES,
                        wait_time,
                        e,
                    )
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(
                        "百炼 API 调用失败,已达最大重试次数 %d: %s",
                        self.MAX_RETRIES,
                        e,
                    )

        # Retries exhausted: surface the last retryable error, keeping its
        # HTTP status code when the SDK exception provides one.
        status_code = getattr(last_error, "status_code", None)
        raise BailianApiError(
            f"百炼 API 调用失败(重试 {self.MAX_RETRIES} 次后): {last_error}",
            status_code=status_code,
        ) from last_error
|