Files
Neo-ZQYY/apps/backend/app/services/chat_service.py

686 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
CHAT 模块业务逻辑层。
封装对话管理、消息持久化、referenceCard 组装、标题生成等核心逻辑。
路由层(xcx_chat.py)调用本服务完成 CHAT-1/2/3/4 端点的业务处理。
表依赖:
- biz.ai_conversations — 对话会话(含 context_type/context_id/title/last_message 扩展字段)
- biz.ai_messages — 消息记录(含 reference_card 扩展字段)
- fdw_etl.v_dim_member — 会员信息(通过 ETL 直连)
- fdw_etl.v_dws_member_consumption_summary / v_dwd_assistant_service_log — 消费指标
⚠️ P5 PRD 合规:
- app_id 固定为 'app1_chat'
- 用户消息发送时即写入 ai_messages(role=user)
- 流式完成后完整 assistant 回复写入 ai_messages(role=assistant, 含 tokens_used)
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime
from decimal import Decimal
from typing import Any
from fastapi import HTTPException, status
from app.ai.bailian_client import BailianClient
from app.database import get_connection
from app.services import fdw_queries
logger = logging.getLogger(__name__)

# app_id written to new conversations and used to filter the history list.
APP_ID: str = "app1_chat"

# Conversation reuse window, in days, for customer/coach entry points.
_REUSE_DAYS: int = 3
class ChatService:
    """Business logic for the CHAT module.

    Covers the conversation history list (CHAT-1), conversation
    reuse/creation, the message list (CHAT-2), synchronous message sending
    (CHAT-3), referenceCard assembly and title generation.

    Database access goes through ``get_connection``. Connections are opened
    per operation and closed before the operation returns.
    """

    # Number of most recent messages sent to the model as context.
    _HISTORY_LIMIT = 20
    # Maximum length of the preview stored in ai_conversations.last_message.
    _LAST_MESSAGE_MAX_CHARS = 100
    # Maximum length of a title derived from the first user message.
    _TITLE_MAX_CHARS = 20

    # ------------------------------------------------------------------
    # CHAT-1: conversation history list
    # ------------------------------------------------------------------

    def get_chat_history(
        self,
        user_id: int,
        site_id: int,
        page: int,
        page_size: int,
    ) -> tuple[list[dict], int]:
        """Return one page of the user's conversations as ``(items, total)``.

        Only conversations of this user/site with ``app_id == APP_ID`` are
        listed. They are ordered by latest activity: ``last_message_at``,
        falling back to ``created_at``.

        For ``customer`` conversations, the customer's nickname is looked up
        through ``fdw_queries.get_member_info``. If that lookup fails, a
        warning is logged and the name stays ``None``.

        Args:
            user_id: Current user id (stored as text in ai_conversations).
            site_id: Current site id.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Number of items per page.

        Returns:
            ``(items, total)``. Each item has ``id``, ``title``,
            ``customer_name``, ``last_message``, ``timestamp`` (ISO string or
            ``""``) and ``unread_count`` (always 0). ``total`` is the number
            of matching conversations.
        """
        # Clamp so a page below 1 cannot produce a negative OFFSET, which
        # PostgreSQL rejects.
        offset = max(page - 1, 0) * page_size
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM biz.ai_conversations
                    WHERE user_id = %s AND site_id = %s AND app_id = %s
                    """,
                    (str(user_id), site_id, APP_ID),
                )
                total = cur.fetchone()[0]

                # id DESC breaks ties so OFFSET paging is deterministic.
                cur.execute(
                    """
                    SELECT id, title, context_type, context_id,
                           last_message, last_message_at, created_at
                    FROM biz.ai_conversations
                    WHERE user_id = %s AND site_id = %s AND app_id = %s
                    ORDER BY COALESCE(last_message_at, created_at) DESC, id DESC
                    LIMIT %s OFFSET %s
                    """,
                    (str(user_id), site_id, APP_ID, page_size, offset),
                )
                columns = [desc[0] for desc in cur.description]
                raw_items = [dict(zip(columns, row)) for row in cur.fetchall()]

                # Conversations without a custom title may need their first
                # user message as a title. Fetch those in one query rather
                # than one query (and connection) per conversation.
                untitled_ids = [
                    item["id"]
                    for item in raw_items
                    if not (item.get("title") and item["title"].strip())
                ]
                first_messages = self._get_first_messages(cur, untitled_ids)
        finally:
            conn.close()

        name_map = self._get_customer_names(site_id, raw_items)

        items: list[dict] = []
        for item in raw_items:
            customer_id = self._customer_context_id(item)
            customer_name = name_map.get(customer_id) if customer_id is not None else None
            title = self.generate_title(
                title=item.get("title"),
                customer_name=customer_name,
                conversation_id=item["id"],
                # "" (not None) stops generate_title from querying again;
                # a missing first message yields the default title.
                first_message=first_messages.get(item["id"], ""),
            )
            ts = item.get("last_message_at") or item.get("created_at")
            items.append({
                "id": item["id"],
                "title": title,
                "customer_name": customer_name,
                "last_message": item.get("last_message"),
                "timestamp": self._to_iso(ts) if ts else "",
                "unread_count": 0,
            })
        return items, total

    @staticmethod
    def _customer_context_id(item: dict) -> int | None:
        """Return the int customer id for a ``customer`` conversation, else None."""
        if item.get("context_type") != "customer" or not item.get("context_id"):
            return None
        try:
            return int(item["context_id"])
        except (ValueError, TypeError):
            return None

    def _get_customer_names(self, site_id: int, raw_items: list[dict]) -> dict[int, str]:
        """Map customer id -> nickname for the customer conversations in *raw_items*.

        This is best-effort: an FDW failure is logged and an empty mapping is
        returned.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order.
        customer_ids = list(dict.fromkeys(
            cid
            for cid in (self._customer_context_id(item) for item in raw_items)
            if cid is not None
        ))
        if not customer_ids:
            return {}

        name_map: dict[int, str] = {}
        try:
            biz_conn = get_connection()
            try:
                info_map = fdw_queries.get_member_info(biz_conn, site_id, customer_ids)
                for mid, info in info_map.items():
                    name_map[mid] = info.get("nickname") or ""
            finally:
                biz_conn.close()
        except Exception:
            logger.warning("查询客户姓名失败,降级为空", exc_info=True)
        return name_map

    @staticmethod
    def _get_first_messages(cur: Any, conversation_ids: list) -> dict:
        """Return ``{conversation_id: first user message content}`` via *cur*.

        Conversations without any user message are absent from the result.
        No query is issued for an empty id list.
        """
        if not conversation_ids:
            return {}
        cur.execute(
            """
            SELECT DISTINCT ON (conversation_id) conversation_id, content
            FROM biz.ai_messages
            WHERE conversation_id = ANY(%s) AND role = 'user'
            ORDER BY conversation_id, created_at ASC, id ASC
            """,
            (list(conversation_ids),),
        )
        return dict(cur.fetchall())

    # ------------------------------------------------------------------
    # Conversation reuse / creation
    # ------------------------------------------------------------------

    def get_or_create_session(
        self,
        user_id: int,
        site_id: int,
        context_type: str,
        context_id: str | None,
    ) -> int:
        """Find a reusable conversation for the entry context, or create one.

        Only conversations of this user/site with ``app_id == APP_ID`` are
        considered for reuse. The rules are:

        - ``task``: the latest conversation for the same task id is always
          reused (no time limit).
        - ``customer`` / ``coach``: the latest conversation for the same
          context id is reused if its last activity is within
          ``_REUSE_DAYS`` days. Last activity is ``last_message_at``, or
          ``created_at`` when nothing was sent yet. Otherwise a new
          conversation is created.
        - ``general`` and any unrecognized type: always create a new one.

        Returns:
            The conversation id (chat_id).
        """
        if context_type == "task":
            query = """
                SELECT id FROM biz.ai_conversations
                WHERE user_id = %s AND site_id = %s AND app_id = %s
                  AND context_type = 'task' AND context_id = %s
                ORDER BY created_at DESC, id DESC
                LIMIT 1
            """
            params: tuple = (str(user_id), site_id, APP_ID, context_id)
        elif context_type in ("customer", "coach"):
            # COALESCE lets a freshly created (still empty) conversation be
            # reused instead of spawning another empty one on every entry.
            query = """
                SELECT id FROM biz.ai_conversations
                WHERE user_id = %s AND site_id = %s AND app_id = %s
                  AND context_type = %s AND context_id = %s
                  AND COALESCE(last_message_at, created_at)
                      >= NOW() - make_interval(days => %s)
                ORDER BY COALESCE(last_message_at, created_at) DESC, id DESC
                LIMIT 1
            """
            params = (
                str(user_id), site_id, APP_ID, context_type, context_id, _REUSE_DAYS,
            )
        else:
            # "general" and unknown types: always a new conversation, so no
            # lookup connection is needed.
            return self._create_session(user_id, site_id, context_type, context_id)

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(query, params)
                row = cur.fetchone()
        finally:
            conn.close()

        if row:
            return row[0]
        return self._create_session(user_id, site_id, context_type, context_id)

    def _create_session(
        self,
        user_id: int,
        site_id: int,
        context_type: str,
        context_id: str | None,
    ) -> int:
        """Insert a new conversation row and return its id.

        The user's nickname from auth.users is copied onto the row (``""`` if
        absent). The transaction is rolled back and the error re-raised on
        failure.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT nickname FROM auth.users WHERE id = %s",
                    (user_id,),
                )
                row = cur.fetchone()
                nickname = row[0] if row and row[0] else ""
                cur.execute(
                    """
                    INSERT INTO biz.ai_conversations
                        (user_id, nickname, app_id, site_id, context_type, context_id)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    RETURNING id
                    """,
                    (str(user_id), nickname, APP_ID, site_id, context_type, context_id),
                )
                new_id = cur.fetchone()[0]
            conn.commit()
            return new_id
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # CHAT-2: message list
    # ------------------------------------------------------------------

    def get_messages(
        self,
        chat_id: int,
        user_id: int,
        site_id: int,
        page: int,
        page_size: int,
    ) -> tuple[list[dict], int, int]:
        """Return one page of a conversation's messages as ``(messages, total, chat_id)``.

        Ownership is verified first, so the ``HTTPException`` from
        ``_verify_ownership`` propagates. Messages are ordered oldest first,
        with ``id`` as a tie-breaker so paging is deterministic. A page below
        1 is treated as 1.
        """
        self._verify_ownership(chat_id, user_id, site_id)
        offset = max(page - 1, 0) * page_size
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT COUNT(*) FROM biz.ai_messages WHERE conversation_id = %s",
                    (chat_id,),
                )
                total = cur.fetchone()[0]
                cur.execute(
                    """
                    SELECT id, role, content, created_at, reference_card
                    FROM biz.ai_messages
                    WHERE conversation_id = %s
                    ORDER BY created_at ASC, id ASC
                    LIMIT %s OFFSET %s
                    """,
                    (chat_id, page_size, offset),
                )
                columns = [desc[0] for desc in cur.description]
                rows = cur.fetchall()
        finally:
            conn.close()

        messages = []
        for row in rows:
            item = dict(zip(columns, row))
            messages.append({
                "id": item["id"],
                "role": item["role"],
                "content": item["content"],
                "created_at": self._to_iso(item["created_at"]),
                "reference_card": self._parse_reference_card(item.get("reference_card")),
            })
        return messages, total, chat_id

    @staticmethod
    def _parse_reference_card(raw: Any) -> Any:
        """Normalize a stored reference_card value.

        The column may arrive already decoded (the original code notes that
        psycopg2 decodes jsonb) or as a JSON string. Strings are decoded, and
        undecodable strings become ``None``. Any other value is returned
        unchanged.
        """
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except (json.JSONDecodeError, TypeError):
                return None
        return raw

    # ------------------------------------------------------------------
    # CHAT-3: send message (synchronous reply)
    # ------------------------------------------------------------------

    async def send_message_sync(
        self,
        chat_id: int,
        content: str,
        user_id: int,
        site_id: int,
    ) -> dict:
        """Store a user message, obtain an AI reply, store it, and return both.

        Steps:

        1. Verify chat ownership (``HTTPException`` propagates).
        2. Persist the user message immediately (P5 PRD: write on send).
        3. Call the AI.
        4. Persist the reply.
        5. Update the conversation's last_message / last_message_at.

        If the AI call raises, the error is logged and a fixed apology text is
        stored and returned as the reply instead.
        """
        self._verify_ownership(chat_id, user_id, site_id)

        user_msg_id, user_created_at = self._save_message(chat_id, "user", content)

        ai_reply_text: str
        tokens_used: int | None = None
        try:
            ai_reply_text, tokens_used = await self._call_ai(chat_id, content, user_id, site_id)
        except Exception as e:
            logger.error("AI 服务调用失败: %s", e, exc_info=True)
            ai_reply_text = "抱歉AI 助手暂时无法回复,请稍后重试"

        ai_msg_id, ai_created_at = self._save_message(
            chat_id, "assistant", ai_reply_text, tokens_used=tokens_used,
        )

        self._update_session_metadata(chat_id, ai_reply_text)

        return {
            "user_message": {
                "id": user_msg_id,
                "content": content,
                "created_at": user_created_at,
            },
            "ai_reply": {
                "id": ai_msg_id,
                "content": ai_reply_text,
                "created_at": ai_created_at,
            },
        }

    # ------------------------------------------------------------------
    # referenceCard assembly
    # ------------------------------------------------------------------

    def build_reference_card(
        self,
        customer_id: int,
        site_id: int,
    ) -> dict | None:
        """Assemble a customer referenceCard from FDW metrics.

        The card reports the member nickname, balance, last-30-day
        consumption and last-30-day visit count. Per DWD-DOC rules,
        consumption uses the items_sum basis (``ledger_amount``).

        Each metric is fetched independently. A failed metric is logged and
        shown as ``""`` on the card. Any other failure (connection, member
        lookup) is logged and ``None`` is returned, so the message itself is
        unaffected.

        Returns:
            The card dict, or ``None`` if the member is unknown or the card
            could not be built.
        """
        try:
            biz_conn = get_connection()
            try:
                info_map = fdw_queries.get_member_info(biz_conn, site_id, [customer_id])
                if customer_id not in info_map:
                    return None
                member_name = info_map[customer_id].get("nickname") or "未知客户"

                balance: Decimal | None = None
                try:
                    balance_map = fdw_queries.get_member_balance(biz_conn, site_id, [customer_id])
                    balance = balance_map.get(customer_id)
                except Exception:
                    logger.warning("referenceCard: 查询余额失败", exc_info=True)

                consume_30d: Decimal | None = None
                try:
                    consume_30d = self._get_consumption_30d(biz_conn, site_id, customer_id)
                except Exception:
                    logger.warning("referenceCard: 查询近30天消费失败", exc_info=True)

                visit_count: int | None = None
                try:
                    visit_count = self._get_visit_count_30d(biz_conn, site_id, customer_id)
                except Exception:
                    logger.warning("referenceCard: 查询到店次数失败", exc_info=True)
            finally:
                biz_conn.close()

            balance_str = f"¥{balance:,.2f}" if balance is not None else ""
            consume_str = f"¥{consume_30d:,.2f}" if consume_30d is not None else ""
            visit_str = f"{visit_count}" if visit_count is not None else ""

            return {
                "type": "customer",
                "title": f"{member_name} — 消费概览",
                # NOTE(review): there appears to be no separator between the
                # balance part and the 30-day part; confirm intended text.
                "summary": f"余额 {balance_str}近30天消费 {consume_str}",
                "data": {
                    "余额": balance_str,
                    "近30天消费": consume_str,
                    "到店次数": visit_str,
                },
            }
        except Exception:
            logger.warning("referenceCard 组装失败,降级为 null", exc_info=True)
            return None

    # ------------------------------------------------------------------
    # Title generation
    # ------------------------------------------------------------------

    def generate_title(
        self,
        title: str | None = None,
        customer_name: str | None = None,
        conversation_id: int | None = None,
        first_message: str | None = None,
    ) -> str:
        """Build a non-empty conversation title.

        Candidates are tried in priority order, each stripped of whitespace:

        1. The custom *title*.
        2. The *customer_name*.
        3. The first user message, truncated to ``_TITLE_MAX_CHARS``
           characters.

        When *first_message* is ``None`` and *conversation_id* is given, the
        first message is loaded from the database. Falls back to "新对话".
        """
        if title and title.strip():
            return title.strip()
        if customer_name and customer_name.strip():
            return customer_name.strip()
        if first_message is None and conversation_id is not None:
            first_message = self._get_first_message(conversation_id)
        if first_message and first_message.strip():
            return first_message.strip()[: self._TITLE_MAX_CHARS]
        return "新对话"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_iso(value: Any) -> str:
        """Return ``value.isoformat()`` for datetimes, else ``str(value)``."""
        return value.isoformat() if isinstance(value, datetime) else str(value)

    def _verify_ownership(self, chat_id: int, user_id: int, site_id: int) -> None:
        """Ensure the conversation exists on this site and belongs to *user_id*.

        Raises:
            HTTPException: 404 if no such conversation exists on the site;
                403 if it belongs to another user.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT user_id FROM biz.ai_conversations
                    WHERE id = %s AND site_id = %s
                    """,
                    (chat_id, site_id),
                )
                row = cur.fetchone()
        finally:
            conn.close()

        if not row:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="对话不存在",
            )
        # user_id is stored as text, so compare string forms.
        if str(row[0]) != str(user_id):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="无权访问此对话",
            )

    def _save_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        tokens_used: int | None = None,
        reference_card: dict | None = None,
    ) -> tuple[int, str]:
        """Insert a message row and return ``(message_id, created_at_iso)``.

        A falsy *reference_card* is stored as NULL; otherwise it is stored as
        a JSON string. The transaction is rolled back and the error re-raised
        on failure.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_messages
                        (conversation_id, role, content, tokens_used, reference_card)
                    VALUES (%s, %s, %s, %s, %s)
                    RETURNING id, created_at
                    """,
                    (
                        conversation_id,
                        role,
                        content,
                        tokens_used,
                        json.dumps(reference_card, ensure_ascii=False) if reference_card else None,
                    ),
                )
                msg_id, created_at = cur.fetchone()
            conn.commit()
            return msg_id, self._to_iso(created_at)
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _update_session_metadata(self, chat_id: int, last_message: str) -> None:
        """Set last_message (truncated to ``_LAST_MESSAGE_MAX_CHARS``) and bump last_message_at."""
        truncated = last_message[: self._LAST_MESSAGE_MAX_CHARS]
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE biz.ai_conversations
                    SET last_message = %s, last_message_at = NOW()
                    WHERE id = %s
                    """,
                    (truncated, chat_id),
                )
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _get_first_message(self, conversation_id: int) -> str | None:
        """Return the content of the conversation's first user message, or None."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT content FROM biz.ai_messages
                    WHERE conversation_id = %s AND role = 'user'
                    ORDER BY created_at ASC, id ASC
                    LIMIT 1
                    """,
                    (conversation_id,),
                )
                row = cur.fetchone()
                return row[0] if row else None
        finally:
            conn.close()

    async def _call_ai(
        self,
        chat_id: int,
        content: str,
        user_id: int,
        site_id: int,
    ) -> tuple[str, int | None]:
        """Ask the Bailian model for a reply and return ``(reply_text, tokens_used)``.

        The most recent ``_HISTORY_LIMIT`` messages of the conversation are
        sent as context, oldest first, behind a default system prompt. This
        includes the user message, which the caller must already have
        persisted; *content* itself is not read here.

        The reply is collected from ``chat_stream``, which does not report
        token usage, so ``tokens_used`` is estimated as the reply's character
        count.
        """
        bailian = _get_bailian_client()

        # Let the database apply the limit instead of fetching the whole
        # history and slicing it in Python.
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT role, content
                    FROM (
                        SELECT id, role, content, created_at
                        FROM biz.ai_messages
                        WHERE conversation_id = %s
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    ) AS recent
                    ORDER BY created_at ASC, id ASC
                    """,
                    (chat_id, self._HISTORY_LIMIT),
                )
                history = cur.fetchall()
        finally:
            conn.close()

        messages: list[dict] = [
            {"role": role, "content": msg_content} for role, msg_content in history
        ]

        if not messages or messages[0]["role"] != "system":
            messages.insert(0, {
                "role": "system",
                "content": json.dumps(
                    {"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。"},
                    ensure_ascii=False,
                ),
            })

        full_parts: list[str] = []
        async for chunk in bailian.chat_stream(messages):
            full_parts.append(chunk)
        reply = "".join(full_parts)

        estimated_tokens = len(reply)
        return reply, estimated_tokens

    @staticmethod
    def _get_consumption_30d(conn: Any, site_id: int, member_id: int) -> Decimal | None:
        """Sum the member's ``ledger_amount`` over the last 30 days.

        Per DWD-DOC rule 1, this uses ``ledger_amount`` (the items_sum basis)
        and never consume_money. Non-deleted rows of
        app.v_dwd_assistant_service_log are queried through
        ``fdw_queries._fdw_context``. The SUM is COALESCEd to 0, so a row is
        normally returned.
        """
        with fdw_queries._fdw_context(conn, site_id) as cur:
            cur.execute(
                """
                SELECT COALESCE(SUM(ledger_amount), 0)
                FROM app.v_dwd_assistant_service_log
                WHERE tenant_member_id = %s
                  AND is_delete = 0
                  AND create_time >= (CURRENT_DATE - INTERVAL '30 days')::timestamptz
                """,
                (member_id,),
            )
            row = cur.fetchone()
            return Decimal(str(row[0])) if row and row[0] is not None else None

    @staticmethod
    def _get_visit_count_30d(conn: Any, site_id: int, member_id: int) -> int | None:
        """Count distinct calendar days with non-deleted service-log rows in the last 30 days."""
        with fdw_queries._fdw_context(conn, site_id) as cur:
            cur.execute(
                """
                SELECT COUNT(DISTINCT create_time::date)
                FROM app.v_dwd_assistant_service_log
                WHERE tenant_member_id = %s
                  AND is_delete = 0
                  AND create_time >= (CURRENT_DATE - INTERVAL '30 days')::timestamptz
                """,
                (member_id,),
            )
            row = cur.fetchone()
            return int(row[0]) if row and row[0] is not None else None
# ── 模块级辅助函数 ──────────────────────────────────────────────
def _get_bailian_client() -> BailianClient:
    """Build a BailianClient from environment configuration.

    Reads BAILIAN_API_KEY, BAILIAN_BASE_URL and BAILIAN_MODEL; all three must
    be present and non-empty.

    Raises:
        RuntimeError: if any of the three variables is missing or empty.
    """
    env = os.environ
    settings = {
        "api_key": env.get("BAILIAN_API_KEY"),
        "base_url": env.get("BAILIAN_BASE_URL"),
        "model": env.get("BAILIAN_MODEL"),
    }
    if not all(settings.values()):
        raise RuntimeError(
            "百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
        )
    return BailianClient(**settings)