Files
Neo-ZQYY/apps/backend/app/ai/bailian_client.py

274 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""百炼 API 统一封装层。
使用 openai Python SDK(百炼兼容 OpenAI 协议),提供流式和非流式两种调用模式。
所有 AI 应用通过此客户端统一调用阿里云通义千问。
"""
from __future__ import annotations
import asyncio
import copy
import json
import logging
from datetime import datetime
from typing import Any, AsyncGenerator
import openai
# Module-level logger used by BailianClient for API failures and retry attempts.
logger = logging.getLogger(__name__)
# ── 异常类 ──────────────────────────────────────────────────────────
class BailianApiError(Exception):
"""百炼 API 调用失败(重试耗尽后)。"""
def __init__(self, message: str, status_code: int | None = None):
super().__init__(message)
self.status_code = status_code
class BailianJsonParseError(Exception):
    """Raised when a Bailian response body cannot be parsed as the expected JSON.

    Attributes:
        raw_content: The raw response text that failed to parse
            (empty string when not supplied).
    """

    def __init__(self, message: str, raw_content: str = ""):
        self.raw_content = raw_content
        super().__init__(message)
class BailianAuthError(BailianApiError):
    """Raised when Bailian rejects the API key (HTTP 401).

    Always carries ``status_code == 401``; the message defaults to a
    generic "invalid or expired key" text.
    """

    def __init__(self, message: str = "API Key 无效或已过期"):
        super().__init__(message, status_code=401)
# ── 客户端 ──────────────────────────────────────────────────────────
class BailianClient:
    """Unified wrapper around the Bailian API (Alibaba Cloud Tongyi Qianwen).

    Uses an ``openai.AsyncOpenAI`` client whose ``base_url`` points at the
    Bailian endpoint (Bailian speaks the OpenAI-compatible protocol).
    Two call modes are provided:

    - ``chat_stream``: streaming; yields text chunks (application 1, SSE).
    - ``chat_json``: non-streaming; returns a parsed JSON dict plus the
      number of tokens used (applications 2-8, structured output).
    """

    # Retry configuration for ``_call_with_retry``.
    # NOTE: MAX_RETRIES is the total number of attempts, not the number of
    # extra retries after the first attempt.
    MAX_RETRIES = 3
    BASE_INTERVAL = 1  # seconds; back-off base, doubled after each failure

    def __init__(self, api_key: str, base_url: str, model: str):
        """Create a client bound to one Bailian endpoint and model.

        Args:
            api_key: Bailian API key (env var BAILIAN_API_KEY).
            base_url: Bailian API endpoint (env var BAILIAN_BASE_URL).
            model: Model identifier such as ``qwen-plus``
                (env var BAILIAN_MODEL).
        """
        self.model = model
        self._client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
        )

    async def chat_stream(
        self,
        messages: list[dict],
        *,
        temperature: float = 0.7,
        max_tokens: int = 2000,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion, yielding text chunks as they arrive.

        ``current_time`` is injected into the first message before the
        request is sent. Chunks with no choices or empty delta content are
        skipped. The underlying SDK stream is closed when iteration ends
        for any reason, including the consumer closing this generator early.

        Args:
            messages: Chat messages.
            temperature: Sampling temperature, default 0.7.
            max_tokens: Maximum number of tokens, default 2000.

        Yields:
            Non-empty text chunks.

        Raises:
            BailianAuthError: The API key was rejected (HTTP 401).
            BailianApiError: The request failed (after retries where
                applicable). Errors raised by the SDK while iterating an
                already-open stream are propagated unchanged.
        """
        messages = self._inject_current_time(messages)
        response = await self._call_with_retry(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )
        try:
            async for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        finally:
            # Release the HTTP connection even when the consumer stops early,
            # instead of leaving it open until the stream is garbage-collected.
            await self._close_stream(response)

    async def chat_json(
        self,
        messages: list[dict],
        *,
        temperature: float = 0.3,
        max_tokens: int = 4000,
    ) -> tuple[dict, int]:
        """Non-streaming call returning the parsed JSON object and tokens used.

        Used for structured output. Sends
        ``response_format={"type": "json_object"}`` so the model is asked to
        return a JSON object.

        Args:
            messages: Chat messages.
            temperature: Sampling temperature, default 0.3 (low for
                structured output).
            max_tokens: Maximum number of tokens, default 4000.

        Returns:
            ``(parsed_json_dict, tokens_used)``; ``tokens_used`` is 0 when
            the response carries no usage information.

        Raises:
            BailianJsonParseError: The response content is empty, is not
                valid JSON, or is valid JSON whose top level is not an object.
            BailianAuthError: The API key was rejected (HTTP 401).
            BailianApiError: The API call failed (after retries where
                applicable).
        """
        messages = self._inject_current_time(messages)
        response = await self._call_with_retry(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
            response_format={"type": "json_object"},
        )
        # Treat a missing/empty ``choices`` list like empty content so it
        # surfaces as a parse error rather than an IndexError.
        choices = response.choices
        raw_content = (choices[0].message.content or "") if choices else ""
        tokens_used = response.usage.total_tokens if response.usage else 0

        try:
            parsed = json.loads(raw_content)
        except (json.JSONDecodeError, TypeError) as e:
            logger.error("百炼 API 返回非法 JSON: %s", raw_content[:500])
            raise BailianJsonParseError(
                f"JSON 解析失败: {e}",
                raw_content=raw_content,
            ) from e

        # Guard the declared ``dict`` return type: a JSON array/string/number
        # would otherwise be returned to callers expecting an object.
        if not isinstance(parsed, dict):
            logger.error("百炼 API 返回的 JSON 不是对象: %s", raw_content[:500])
            raise BailianJsonParseError(
                f"JSON 顶层类型应为 object,实际为 {type(parsed).__name__}",
                raw_content=raw_content,
            )
        return parsed, tokens_used

    def _inject_current_time(self, messages: list[dict]) -> list[dict]:
        """Return a copy of ``messages`` with ``current_time`` added to the first one.

        - The input is deep-copied; the caller's list is never mutated.
        - If the first message's ``content`` parses as a JSON object, its
          ``current_time`` key is set (overwriting any existing value).
        - If it parses as JSON but is not an object (array, string, number),
          it is wrapped as ``{"original_content": ..., "current_time": ...}``.
        - Otherwise (invalid JSON, or a non-string value such as ``None``)
          it is wrapped as ``{"content": ..., "current_time": ...}``.
        - The first message's content is re-serialized with
          ``ensure_ascii=False``; all other messages are left unchanged.
        - ``current_time`` is the local wall-clock time in ISO 8601, to the
          second, e.g. ``2026-03-08T14:30:00``.

        Not a pure function: the result depends on the current time.

        Args:
            messages: Original chat messages.

        Returns:
            A new message list with ``current_time`` injected (``[]`` for
            empty input).
        """
        if not messages:
            return []
        result = copy.deepcopy(messages)
        first = result[0]
        content = first.get("content", "")
        now_str = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            # Not JSON (or not str/bytes): wrap the raw value.
            parsed = {"content": content, "current_time": now_str}
        else:
            if isinstance(parsed, dict):
                parsed["current_time"] = now_str
            else:
                # Valid JSON but not an object: wrap it.
                parsed = {"original_content": parsed, "current_time": now_str}
        first["content"] = json.dumps(parsed, ensure_ascii=False)
        return result

    @staticmethod
    async def _close_stream(stream: Any) -> None:
        """Close an SDK stream if it exposes a ``close()`` method.

        In openai-python v1, ``AsyncStream.close`` is a coroutine function
        and is awaited. Objects without ``close``, or whose ``close``
        returns a non-coroutine, are handled without awaiting.
        """
        close = getattr(stream, "close", None)
        if not callable(close):
            return
        outcome = close()
        if asyncio.iscoroutine(outcome):
            await outcome

    async def _call_with_retry(self, **kwargs: Any) -> Any:
        """Call ``chat.completions.create`` with exponential back-off.

        Retry policy:
        - At most ``MAX_RETRIES`` attempts in total (default 3).
        - After failed attempt k (0-based), wait ``BASE_INTERVAL * 2**k``
          seconds before the next attempt, i.e. 1s then 2s by default.
        - 5xx errors, connection errors and timeouts are retried.
        - 4xx errors (including 429) are not retried. 401 raises
          ``BailianAuthError``. Any other HTTP status error raises
          ``BailianApiError`` carrying that status code.

        Args:
            **kwargs: Forwarded unchanged to ``chat.completions.create``.

        Returns:
            The SDK response object (a stream when ``stream=True``).

        Raises:
            BailianAuthError: The API key was rejected (HTTP 401).
            BailianApiError: Non-retryable HTTP error, or retries exhausted.
        """
        last_error: Exception | None = None
        for attempt in range(self.MAX_RETRIES):
            try:
                # Same call for streaming and non-streaming requests.
                return await self._client.chat.completions.create(**kwargs)
            except openai.AuthenticationError as e:
                # 401: invalid API key; not retried.
                logger.error("百炼 API 认证失败: %s", e)
                raise BailianAuthError(str(e)) from e
            except openai.BadRequestError as e:
                # 400: invalid request parameters; not retried.
                logger.error("百炼 API 请求参数错误: %s", e)
                raise BailianApiError(str(e), status_code=400) from e
            except openai.RateLimitError as e:
                # 429: rate limited; deliberately not retried (treated as 4xx).
                logger.error("百炼 API 限流: %s", e)
                raise BailianApiError(str(e), status_code=429) from e
            except openai.PermissionDeniedError as e:
                # 403: permission denied; not retried.
                logger.error("百炼 API 权限不足: %s", e)
                raise BailianApiError(str(e), status_code=403) from e
            except openai.NotFoundError as e:
                # 404: resource not found; not retried.
                logger.error("百炼 API 资源不存在: %s", e)
                raise BailianApiError(str(e), status_code=404) from e
            except openai.UnprocessableEntityError as e:
                # 422: unprocessable request; not retried.
                logger.error("百炼 API 不可处理的请求: %s", e)
                raise BailianApiError(str(e), status_code=422) from e
            except (
                openai.InternalServerError,
                openai.APIConnectionError,
                openai.APITimeoutError,
            ) as e:
                # 5xx / connection error / timeout: retryable.
                last_error = e
                if attempt < self.MAX_RETRIES - 1:
                    wait_time = self.BASE_INTERVAL * (2 ** attempt)
                    logger.warning(
                        "百炼 API 调用失败(第 %d/%d 次),%ds 后重试: %s",
                        attempt + 1,
                        self.MAX_RETRIES,
                        wait_time,
                        e,
                    )
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(
                        "百炼 API 调用失败,已达最大重试次数 %d: %s",
                        self.MAX_RETRIES,
                        e,
                    )
            except openai.APIStatusError as e:
                # Any other HTTP status without a dedicated branch above
                # (e.g. 409). Must come after the retryable clause because
                # InternalServerError is itself an APIStatusError.
                logger.error("百炼 API 请求失败(HTTP %s): %s", e.status_code, e)
                raise BailianApiError(str(e), status_code=e.status_code) from e

        # Retries exhausted.
        # NOTE: the message says "重试 N 次" but N counts total attempts.
        status_code = getattr(last_error, "status_code", None)
        raise BailianApiError(
            f"百炼 API 调用失败(重试 {self.MAX_RETRIES} 次后): {last_error}",
            status_code=status_code,
        ) from last_error