包含多个会话的累积代码变更: - backend: AI 聊天服务、触发器调度、认证增强、WebSocket、调度器最小间隔 - admin-web: ETL 状态页、任务管理、调度配置、登录优化 - miniprogram: 看板页面、聊天集成、UI 组件、导航更新 - etl: DWS 新任务(finance_area_daily/board_cache)、连接器增强 - tenant-admin: 项目初始化 - db: 19 个迁移脚本(etl_feiqiu 11 + zqyy_app 8) - packages/shared: 枚举和工具函数更新 - tools: 数据库工具、报表生成、健康检查 - docs: PRD/架构/部署/合约文档更新 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
319 lines
12 KiB
Python
319 lines
12 KiB
Python
"""Unified wrapper around the DashScope Application API.

Calls Bailian agent applications through ``dashscope.Application.call()``,
replacing the former generic-model API that was accessed via the openai SDK.

- call_app_stream(): streaming call for App1; a worker thread consumes the
  blocking SDK iterator and an asyncio.Queue bridges it to an async generator.
- call_app(): single-turn call for App2~8; the blocking SDK call is run via
  asyncio.to_thread() (inside _call_with_retry).
- _call_with_retry(): exponential-backoff retry. The wait before retry n is
  BASE_INTERVAL * 2**(n-1), and MAX_RETRIES bounds the total number of
  attempts, so with the defaults the waits are 1s then 2s.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
import json
import logging
import threading
from typing import Any, AsyncGenerator, Callable
|
||
|
||
import dashscope
|
||
from dashscope import Application
|
||
|
||
from app.ai.exceptions import (
|
||
DashScopeApiError,
|
||
DashScopeAuthError,
|
||
DashScopeJsonParseError,
|
||
DashScopeTimeoutError,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class DashScopeClient:
|
||
"""DashScope Application API 统一封装层。
|
||
|
||
通过 app_id 调用百炼控制台配置的智能体应用,
|
||
充分利用云端 System Prompt 和 MCP 工具。
|
||
"""
|
||
|
||
MAX_RETRIES = 3
|
||
BASE_INTERVAL = 1 # 秒
|
||
|
||
def __init__(self, api_key: str, workspace_id: str | None = None):
|
||
"""初始化。dashscope 通过全局变量设置密钥。
|
||
|
||
Args:
|
||
api_key: DashScope API Key
|
||
workspace_id: 百炼工作空间 ID(可选)
|
||
"""
|
||
dashscope.api_key = api_key
|
||
self._workspace_id = workspace_id
|
||
|
||
async def call_app_stream(
|
||
self,
|
||
app_id: str,
|
||
prompt: str,
|
||
session_id: str | None = None,
|
||
biz_params: dict | None = None,
|
||
) -> AsyncGenerator[str, None]:
|
||
"""App1 流式调用。
|
||
|
||
在线程中消费同步迭代器,通过 asyncio.Queue 桥接到 async generator。
|
||
错误通过 queue 传递给调用方。
|
||
|
||
Args:
|
||
app_id: 百炼应用 ID
|
||
prompt: 用户输入
|
||
session_id: 百炼 session_id(多轮对话)
|
||
biz_params: 业务参数(如 user_prompt_params)
|
||
|
||
Yields:
|
||
文本 chunk
|
||
"""
|
||
queue: asyncio.Queue[str | BaseException | None] = asyncio.Queue()
|
||
loop = asyncio.get_running_loop()
|
||
|
||
def _consume_in_thread() -> None:
|
||
"""在线程中消费同步迭代器,逐 chunk 放入 queue。"""
|
||
try:
|
||
call_kwargs: dict[str, Any] = {
|
||
"app_id": app_id,
|
||
"prompt": prompt,
|
||
"stream": True,
|
||
"incremental_output": True,
|
||
}
|
||
if session_id is not None:
|
||
call_kwargs["session_id"] = session_id
|
||
if biz_params is not None:
|
||
call_kwargs["biz_params"] = biz_params
|
||
if self._workspace_id is not None:
|
||
call_kwargs["workspace"] = self._workspace_id
|
||
|
||
response = Application.call(**call_kwargs)
|
||
for chunk in response:
|
||
if chunk.status_code == 200:
|
||
text = chunk.output.get("text", "")
|
||
if text:
|
||
asyncio.run_coroutine_threadsafe(
|
||
queue.put(text), loop
|
||
)
|
||
else:
|
||
# 非 200 状态码,构造异常传递给调用方
|
||
status = chunk.status_code
|
||
msg = getattr(chunk, "message", "") or f"状态码 {status}"
|
||
if status == 401:
|
||
err = DashScopeAuthError(msg)
|
||
else:
|
||
err = DashScopeApiError(msg, status_code=status)
|
||
asyncio.run_coroutine_threadsafe(
|
||
queue.put(err), loop
|
||
)
|
||
return
|
||
# 正常结束信号
|
||
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
|
||
except Exception as exc:
|
||
# 线程内未预期异常,传递给调用方
|
||
asyncio.run_coroutine_threadsafe(
|
||
queue.put(exc), loop
|
||
)
|
||
|
||
loop.run_in_executor(None, _consume_in_thread)
|
||
|
||
while True:
|
||
item = await queue.get()
|
||
if item is None:
|
||
break
|
||
if isinstance(item, BaseException):
|
||
raise item
|
||
yield item
|
||
|
||
async def call_app(
|
||
self,
|
||
app_id: str,
|
||
prompt: str,
|
||
session_id: str | None = None,
|
||
biz_params: dict | None = None,
|
||
) -> tuple[dict, int, str | None]:
|
||
"""App2~8 单轮调用。
|
||
|
||
通过 asyncio.to_thread() 包装同步 Application.call(),
|
||
解析 response.output.text 获取 JSON 内容。
|
||
非合法 JSON 触发重试(最多 3 次),不做本地修复。
|
||
|
||
Args:
|
||
app_id: 百炼应用 ID
|
||
prompt: 后端拼好的完整数据 JSON 字符串
|
||
session_id: 百炼 session_id(可选)
|
||
biz_params: 业务参数(可选)
|
||
|
||
Returns:
|
||
(parsed_json, tokens_used, new_session_id) 元组
|
||
|
||
Raises:
|
||
DashScopeApiError: API 调用失败(重试耗尽)
|
||
DashScopeJsonParseError: JSON 解析失败(重试耗尽)
|
||
"""
|
||
call_kwargs: dict[str, Any] = {
|
||
"app_id": app_id,
|
||
"prompt": prompt,
|
||
}
|
||
if session_id is not None:
|
||
call_kwargs["session_id"] = session_id
|
||
if biz_params is not None:
|
||
call_kwargs["biz_params"] = biz_params
|
||
if self._workspace_id is not None:
|
||
call_kwargs["workspace"] = self._workspace_id
|
||
|
||
# 非合法 JSON 纯重试,最多 MAX_RETRIES 次
|
||
last_json_error: DashScopeJsonParseError | None = None
|
||
for json_attempt in range(self.MAX_RETRIES):
|
||
response = await self._call_with_retry(
|
||
Application.call, **call_kwargs
|
||
)
|
||
|
||
# 提取 output.text
|
||
raw_text: str = ""
|
||
if hasattr(response, "output"):
|
||
output = response.output
|
||
if isinstance(output, dict):
|
||
raw_text = output.get("text", "")
|
||
elif hasattr(output, "text"):
|
||
raw_text = output.text or ""
|
||
|
||
# 提取 tokens_used
|
||
tokens_used = 0
|
||
if hasattr(response, "usage") and response.usage:
|
||
usage = response.usage
|
||
if isinstance(usage, dict):
|
||
# input_tokens + output_tokens
|
||
tokens_used = usage.get("input_tokens", 0) + usage.get(
|
||
"output_tokens", 0
|
||
)
|
||
elif hasattr(usage, "total_tokens"):
|
||
tokens_used = usage.total_tokens or 0
|
||
|
||
# 提取 new_session_id
|
||
new_session_id: str | None = None
|
||
if hasattr(response, "output") and isinstance(response.output, dict):
|
||
new_session_id = response.output.get("session_id")
|
||
|
||
# 解析 JSON
|
||
try:
|
||
parsed = json.loads(raw_text)
|
||
if isinstance(parsed, list):
|
||
# CHANGE 2026-03-23 | Prompt: App2 LLM 返回 list 而非 dict
|
||
# 百炼 LLM 有时直接返回 insights 数组而非包裹 dict,
|
||
# 自动包装为 {"insights": list} 避免无意义重试
|
||
logger.info(
|
||
"LLM 返回 list(长度 %d),自动包装为 {\"insights\": [...]}",
|
||
len(parsed),
|
||
)
|
||
parsed = {"insights": parsed}
|
||
if not isinstance(parsed, dict):
|
||
raise TypeError(f"期望 dict,实际 {type(parsed).__name__}")
|
||
return parsed, tokens_used, new_session_id
|
||
except (json.JSONDecodeError, TypeError) as e:
|
||
last_json_error = DashScopeJsonParseError(
|
||
f"JSON 解析失败(第 {json_attempt + 1}/{self.MAX_RETRIES} 次): {e}",
|
||
raw_content=raw_text,
|
||
)
|
||
logger.warning(
|
||
"Application API 返回非法 JSON(第 %d/%d 次): %s",
|
||
json_attempt + 1,
|
||
self.MAX_RETRIES,
|
||
raw_text[:500],
|
||
)
|
||
# 非合法 JSON 纯重试,不做本地修复
|
||
continue
|
||
|
||
# JSON 重试耗尽
|
||
raise last_json_error # type: ignore[misc]
|
||
|
||
async def _call_with_retry(self, func: Callable, **kwargs: Any) -> Any:
|
||
"""指数退避重试封装。
|
||
|
||
重试策略:
|
||
- 最多重试 MAX_RETRIES 次(默认 3 次)
|
||
- 间隔:BASE_INTERVAL × 2^(n-1),即 1s → 2s → 4s
|
||
- HTTP 4xx → 不重试,立即抛出(401 → DashScopeAuthError)
|
||
- HTTP 5xx / 超时 / 连接错误 → 重试
|
||
|
||
Args:
|
||
func: 同步调用函数(如 Application.call)
|
||
**kwargs: 传递给 func 的参数
|
||
|
||
Returns:
|
||
API 响应对象(status_code == 200)
|
||
|
||
Raises:
|
||
DashScopeAuthError: API Key 无效(HTTP 401)
|
||
DashScopeTimeoutError: 调用超时(重试耗尽)
|
||
DashScopeApiError: API 调用失败(重试耗尽)
|
||
"""
|
||
last_error: Exception | None = None
|
||
|
||
for attempt in range(self.MAX_RETRIES):
|
||
try:
|
||
response = await asyncio.to_thread(func, **kwargs)
|
||
except Exception as exc:
|
||
# 网络/连接/超时等底层异常 → 可重试
|
||
last_error = exc
|
||
if attempt < self.MAX_RETRIES - 1:
|
||
wait_time = self.BASE_INTERVAL * (2**attempt)
|
||
logger.warning(
|
||
"DashScope API 底层异常(第 %d/%d 次),%ds 后重试: %s",
|
||
attempt + 1,
|
||
self.MAX_RETRIES,
|
||
wait_time,
|
||
exc,
|
||
)
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
else:
|
||
logger.error(
|
||
"DashScope API 底层异常,已达最大重试次数 %d: %s",
|
||
self.MAX_RETRIES,
|
||
exc,
|
||
)
|
||
raise DashScopeApiError(
|
||
f"DashScope API 调用失败(重试 {self.MAX_RETRIES} 次后): {exc}",
|
||
) from exc
|
||
|
||
# Application.call() 返回 response 对象,通过 status_code 判断成功/失败
|
||
status_code = getattr(response, "status_code", None)
|
||
|
||
if status_code == 200:
|
||
return response
|
||
|
||
# 非 200:根据状态码分类处理
|
||
message = getattr(response, "message", "") or f"状态码 {status_code}"
|
||
|
||
if status_code is not None and 400 <= status_code < 500:
|
||
# 4xx:不重试,立即抛出
|
||
if status_code == 401:
|
||
raise DashScopeAuthError(message)
|
||
raise DashScopeApiError(message, status_code=status_code)
|
||
|
||
# 5xx 或其他未知状态码 → 可重试
|
||
last_error = DashScopeApiError(message, status_code=status_code)
|
||
if attempt < self.MAX_RETRIES - 1:
|
||
wait_time = self.BASE_INTERVAL * (2**attempt)
|
||
logger.warning(
|
||
"DashScope API 调用失败(第 %d/%d 次,状态码 %s),%ds 后重试: %s",
|
||
attempt + 1,
|
||
self.MAX_RETRIES,
|
||
status_code,
|
||
wait_time,
|
||
message,
|
||
)
|
||
await asyncio.sleep(wait_time)
|
||
else:
|
||
logger.error(
|
||
"DashScope API 调用失败,已达最大重试次数 %d(状态码 %s): %s",
|
||
self.MAX_RETRIES,
|
||
status_code,
|
||
message,
|
||
)
|
||
|
||
# 重试耗尽
|
||
raise last_error # type: ignore[misc]
|