Files
Neo-ZQYY/scripts/ops/test_chat_e2e.py

226 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
RNS1.4 CHAT 模块端到端测试脚本
用法: python scripts/ops/test_chat_e2e.py
前置条件:
- 后端服务已启动 (uvicorn app.main:app)
- .env 中配置了 TEST_USER_TOKEN 和 APP_DB_DSN
- test_zqyy_app 数据库可访问
环境变量:
BACKEND_URL — 后端地址,默认 http://localhost:8000
TEST_USER_TOKEN — 测试用户 JWT token
APP_DB_DSN — 业务数据库连接串(指向 test_zqyy_app
"""
# CHANGE 2026-03-20 | RNS1.4 T13.1: CHAT 端到端测试脚本
import json
import os
import sys
import time
from pathlib import Path
# Load the repository-root .env, i.e. parents[2] of scripts/ops/<this file>.
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parents[2] / ".env")

BACKEND_URL = os.environ.get("BACKEND_URL", "http://localhost:8000").rstrip("/")
TOKEN, DB_DSN = (os.environ.get(key, "") for key in ("TEST_USER_TOKEN", "APP_DB_DSN"))

# Fail fast with exit status 1 on the first missing required variable,
# checked in this order: TEST_USER_TOKEN, then APP_DB_DSN.
for _value, _message in (
    (TOKEN, "❌ 缺少 TEST_USER_TOKEN 环境变量"),
    (DB_DSN, "❌ 缺少 APP_DB_DSN 环境变量"),
):
    if not _value:
        print(_message)
        sys.exit(1)

# Third-party clients are imported only after the configuration checks above,
# so a missing variable is reported before any import error.
import httpx
import psycopg2

# Request headers shared by every API call.
HEADERS = {"Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json"}
# Accumulated check outcomes; appended to by record().
results: list[dict] = []
def record(name: str, passed: bool, detail: str = "") -> None:
    """Print one check's outcome and append it to the module-level ``results`` list.

    Args:
        name: Human-readable label of the check.
        passed: Whether the check succeeded.
        detail: Optional extra context; shown in parentheses when non-empty.
    """
    status = "✅ PASS" if passed else "❌ FAIL"
    suffix = f" ({detail})" if detail else ""
    # Separate status and label; without the space the output read "✅ PASSCHAT-1 ...".
    print(f" {status} {name}{suffix}")
    results.append({"test": name, "passed": passed, "detail": detail})
def _mask_dsn(dsn: str) -> str:
    """Return *dsn* with any embedded password replaced by ``***`` so it is safe to print.

    Covers both libpq DSN styles: URL form (``scheme://user:secret@host/db``) and
    keyword/value form (``host=... password=secret``).
    """
    import re

    masked = re.sub(r"(://[^/:@\s]*:)[^@\s]*@", r"\1***@", dsn)
    return re.sub(
        r"(\bpassword\s*=\s*)(?:'[^']*'|\S+)", r"\1***", masked, flags=re.IGNORECASE
    )


def _iter_sse_events(lines):
    """Group raw Server-Sent-Events lines into ``(event_type, data)`` pairs.

    An event ends at a blank line. Multiple ``data:`` lines are joined with ``"\\n"``.
    An event without an ``event:`` field has type ``"message"``, and the type is
    reset after every dispatched event. A final event that is not followed by a
    blank line is still yielded, so an abruptly closed stream keeps its last
    (e.g. ``done``) event.
    """
    event_type = ""
    data_lines = []
    for line in lines:
        if not line:
            if data_lines:
                yield event_type or "message", "\n".join(data_lines)
            event_type, data_lines = "", []
        elif line.startswith("event:"):
            event_type = line[6:].strip()
        elif line.startswith("data:"):
            data_lines.append(line[5:].strip())
        # Other fields (id:, retry:) and ":" comment lines are ignored.
    if data_lines:
        yield event_type or "message", "\n".join(data_lines)


def _check_history() -> None:
    """CHAT-1: the history endpoint returns 200 and an ``items`` array."""
    print("── CHAT-1: GET /api/xcx/chat/history ──")
    try:
        r = httpx.get(f"{BACKEND_URL}/api/xcx/chat/history", headers=HEADERS, timeout=10)
        record("CHAT-1 状态码 200", r.status_code == 200, f"got {r.status_code}")
        if r.status_code == 200:
            data = r.json()
            record("CHAT-1 返回 items 数组", isinstance(data.get("items"), list))
    except Exception as e:
        record("CHAT-1 请求", False, str(e))


def _open_general_chat():
    """CHAT-2b: query messages via the general-context entry point; return its chatId or None."""
    print("\n── CHAT-2b: GET /api/xcx/chat/messages?contextType=general ──")
    chat_id = None
    try:
        r = httpx.get(
            f"{BACKEND_URL}/api/xcx/chat/messages",
            params={"contextType": "general", "contextId": ""},
            headers=HEADERS, timeout=10,
        )
        record("CHAT-2b 状态码 200", r.status_code == 200, f"got {r.status_code}")
        if r.status_code == 200:
            chat_id = r.json().get("chatId")
            record("CHAT-2b 返回 chatId", chat_id is not None, f"chatId={chat_id}")
    except Exception as e:
        record("CHAT-2b 请求", False, str(e))
    return chat_id


def _check_messages_by_id(chat_id) -> None:
    """CHAT-2a: listing messages by chatId returns 200."""
    print(f"\n── CHAT-2a: GET /api/xcx/chat/{chat_id}/messages ──")
    try:
        r = httpx.get(
            f"{BACKEND_URL}/api/xcx/chat/{chat_id}/messages",
            headers=HEADERS, timeout=10,
        )
        record("CHAT-2a 状态码 200", r.status_code == 200, f"got {r.status_code}")
    except Exception as e:
        record("CHAT-2a 请求", False, str(e))


def _check_send_message(chat_id) -> None:
    """CHAT-3: a synchronous send returns the user message and a non-empty AI reply."""
    print(f"\n── CHAT-3: POST /api/xcx/chat/{chat_id}/messages ──")
    test_content = "你好,这是一条端到端测试消息,请简短回复。"
    try:
        r = httpx.post(
            f"{BACKEND_URL}/api/xcx/chat/{chat_id}/messages",
            json={"content": test_content},
            headers=HEADERS, timeout=30,
        )
        record("CHAT-3 状态码 200", r.status_code == 200, f"got {r.status_code}")
        if r.status_code == 200:
            data = r.json()
            # `or {}` also covers an explicit JSON null, which `.get(key, {})` returns as None.
            user_msg = data.get("userMessage") or {}
            ai_msg = data.get("aiReply") or {}
            record("CHAT-3 用户消息已返回", bool(user_msg.get("id")))
            record("CHAT-3 AI 回复已返回", bool(ai_msg.get("id")))
            record("CHAT-3 AI 回复非空", bool(ai_msg.get("content")))
            if ai_msg.get("content"):
                print(f" AI 回复: {ai_msg['content'][:80]}...")
    except Exception as e:
        record("CHAT-3 请求", False, str(e))


def _check_sse_stream(chat_id) -> None:
    """CHAT-4: the SSE endpoint streams ``message`` tokens followed by a ``done`` event."""
    print("\n── CHAT-4: POST /api/xcx/chat/stream (SSE) ──")
    sse_content = "请用一句话介绍台球运动。"
    sse_tokens: list[str] = []
    sse_done = False
    try:
        with httpx.stream(
            "POST",
            f"{BACKEND_URL}/api/xcx/chat/stream",
            json={"chatId": int(chat_id), "content": sse_content},
            headers=HEADERS,
            timeout=60,
        ) as resp:
            record("CHAT-4 状态码 200", resp.status_code == 200, f"got {resp.status_code}")
            for event_type, raw in _iter_sse_events(resp.iter_lines()):
                try:
                    d = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                if not isinstance(d, dict):
                    # Only JSON objects carry token / messageId / message fields.
                    continue
                if event_type == "message" and "token" in d:
                    sse_tokens.append(d["token"])
                elif event_type == "done":
                    sse_done = True
                elif event_type == "error":
                    record("CHAT-4 收到 error 事件", False, d.get("message", ""))
        full_reply = "".join(sse_tokens)
        record("CHAT-4 收到 token 事件", len(sse_tokens) > 0, f"{len(sse_tokens)} tokens")
        record("CHAT-4 收到 done 事件", sse_done)
        record("CHAT-4 回复语义通顺", len(full_reply) > 5, f"len={len(full_reply)}")
        if full_reply:
            print(f" SSE 回复: {full_reply[:80]}...")
    except Exception as e:
        record("CHAT-4 请求", False, str(e))


def _verify_db(chat_id) -> None:
    """Verify the exchanges were persisted in biz.ai_messages and biz.ai_conversations."""
    from contextlib import closing

    print("\n── 数据库持久化验证 ──")
    try:
        # psycopg2's connection context manager only ends the transaction and does not
        # close the connection; closing() releases it even if a query below raises.
        with closing(psycopg2.connect(DB_DSN)) as conn, conn.cursor() as cur:
            # ai_messages should contain both the user message and the AI reply.
            cur.execute(
                "SELECT id, role, tokens_used FROM biz.ai_messages "
                "WHERE conversation_id = %s ORDER BY created_at DESC LIMIT 4",
                (int(chat_id),),
            )
            rows = cur.fetchall()
            roles = [row[1] for row in rows]
            record("DB: ai_messages 有 user 消息", "user" in roles)
            record("DB: ai_messages 有 assistant 回复", "assistant" in roles)

            # tokens_used of the newest assistant row (rows are ordered newest-first).
            assistant_rows = [row for row in rows if row[1] == "assistant"]
            if assistant_rows:
                tokens = assistant_rows[0][2]
                record("DB: tokens_used 已记录", tokens is not None and tokens > 0,
                       f"tokens_used={tokens}")
            else:
                record("DB: tokens_used 已记录", False, "无 assistant 行")

            # Conversation metadata (last_message / last_message_at) should be set.
            cur.execute(
                "SELECT last_message, last_message_at FROM biz.ai_conversations WHERE id = %s",
                (int(chat_id),),
            )
            conv = cur.fetchone()
            if conv:
                record("DB: last_message 已更新", bool(conv[0]))
                record("DB: last_message_at 已更新", conv[1] is not None)
            else:
                record("DB: ai_conversations 记录存在", False)
    except Exception as e:
        record("DB 验证", False, str(e))


def main():
    """Run the CHAT end-to-end checks in order, then print and save the summary.

    If no chatId can be obtained, the chat-specific checks are skipped.
    """
    print(f"\n🔗 后端: {BACKEND_URL}")
    # DB_DSN may embed a password; redact it before echoing the (truncated) DSN.
    print(f"🗄️ 数据库: {_mask_dsn(DB_DSN)[:40]}...\n")

    _check_history()
    chat_id = _open_general_chat()
    if not chat_id:
        print("\n⚠️ 无法获取 chatId跳过后续测试")
        _print_summary()
        return

    _check_messages_by_id(chat_id)
    _check_send_message(chat_id)
    _check_sse_stream(chat_id)
    _verify_db(chat_id)
    _print_summary()
def _print_summary():
    """Print pass/fail totals and dump ``results`` to chat_e2e_report.json beside this script."""
    print("\n" + "=" * 50)
    total = len(results)
    n_pass = sum(bool(item["passed"]) for item in results)
    n_fail = total - n_pass
    print(f"总计: {total} 项 | ✅ {n_pass} 通过 | ❌ {n_fail} 失败")

    # Persist the full per-check list as pretty-printed UTF-8 JSON.
    out_file = Path(__file__).parent / "chat_e2e_report.json"
    payload = json.dumps(results, ensure_ascii=False, indent=2)
    out_file.write_text(payload, encoding="utf-8")
    print(f"📄 报告已保存: {out_file}")
if __name__ == "__main__":
    main()
    # Exit non-zero when any check failed so CI / shell callers can detect failures.
    sys.exit(1 if any(not r["passed"] for r in results) else 0)