"""AI 运行日志服务 — biz.ai_run_logs 表的 CRUD 操作。 每次 Application API 调用前创建 pending 记录,调用过程中更新状态, 调用结束后记录结果。同时提供日/月 token 聚合查询,实现 UsageProvider 协议 以便注入 BudgetTracker。 request_prompt 写入前截断为前 2000 字符,避免大 prompt 占用过多存储。 """ from __future__ import annotations from datetime import datetime, timezone from typing import Callable import psycopg2.extensions # prompt 最大存储长度 _MAX_PROMPT_LENGTH = 2000 def _truncate_prompt(prompt: str | None) -> str | None: """截断 prompt 为前 2000 字符。None 原样返回。""" if prompt is None: return None return prompt[:_MAX_PROMPT_LENGTH] class AIRunLogService: """AI 运行日志 CRUD,实现 UsageProvider 协议。 构造函数接受 get_conn callable,每次操作时获取数据库连接, 避免长期持有连接导致超时或连接池耗尽。 """ def __init__(self, get_conn: Callable[[], psycopg2.extensions.connection]) -> None: self._get_conn = get_conn # ── 创建 ────────────────────────────────────────────── def create_log( self, site_id: int, app_type: str, trigger_type: str, *, member_id: int | None = None, request_prompt: str | None = None, session_id: str | None = None, ) -> int: """创建日志记录(status: pending),返回 log_id。 request_prompt 自动截断为前 2000 字符。 """ truncated = _truncate_prompt(request_prompt) conn = self._get_conn() try: with conn.cursor() as cur: cur.execute( """ INSERT INTO biz.ai_run_logs (site_id, app_type, trigger_type, member_id, request_prompt, session_id, status) VALUES (%s, %s, %s, %s, %s, %s, 'pending') RETURNING id """, (site_id, app_type, trigger_type, member_id, truncated, session_id), ) row = cur.fetchone() assert row is not None, "INSERT RETURNING 应返回 id" log_id: int = row[0] conn.commit() return log_id except Exception: conn.rollback() raise # ── 状态转换 ────────────────────────────────────────── def update_running(self, log_id: int) -> None: """更新为 running。""" conn = self._get_conn() try: with conn.cursor() as cur: cur.execute( """ UPDATE biz.ai_run_logs SET status = 'running' WHERE id = %s """, (log_id,), ) conn.commit() except Exception: conn.rollback() raise def update_success( self, log_id: int, response_text: str, tokens_used: int, latency_ms: int, ) -> None: """更新为 success,记录响应、token 消耗和耗时。""" now = datetime.now(timezone.utc) conn = self._get_conn() try: with conn.cursor() as cur: cur.execute( """ UPDATE biz.ai_run_logs SET status = 'success', response_text = %s, tokens_used = %s, latency_ms = %s, finished_at = %s WHERE id = %s """, (response_text, tokens_used, latency_ms, now, log_id), ) conn.commit() except Exception: conn.rollback() raise def update_failed( self, log_id: int, error_message: str, latency_ms: int, ) -> None: """更新为 failed,记录错误信息和耗时。""" now = datetime.now(timezone.utc) conn = self._get_conn() try: with conn.cursor() as cur: cur.execute( """ UPDATE biz.ai_run_logs SET status = 'failed', error_message = %s, latency_ms = %s, finished_at = %s WHERE id = %s """, (error_message, latency_ms, now, log_id), ) conn.commit() except Exception: conn.rollback() raise def update_timeout(self, log_id: int, latency_ms: int) -> None: """更新为 timeout。""" now = datetime.now(timezone.utc) conn = self._get_conn() try: with conn.cursor() as cur: cur.execute( """ UPDATE biz.ai_run_logs SET status = 'timeout', latency_ms = %s, finished_at = %s WHERE id = %s """, (latency_ms, now, log_id), ) conn.commit() except Exception: conn.rollback() raise # ── UsageProvider 协议实现 ──────────────────────────── def get_daily_usage(self) -> int: """聚合今日 token 消耗(status='success',created_at 为今日)。""" conn = self._get_conn() with conn.cursor() as cur: cur.execute( """ SELECT COALESCE(SUM(tokens_used), 0) FROM biz.ai_run_logs WHERE status = 'success' AND created_at >= CURRENT_DATE AND created_at < CURRENT_DATE + INTERVAL '1 day' """ ) row = cur.fetchone() return int(row[0]) if row else 0 def get_monthly_usage(self) -> int: """聚合本月 token 消耗(status='success',created_at 为本月)。""" conn = self._get_conn() with conn.cursor() as cur: cur.execute( """ SELECT COALESCE(SUM(tokens_used), 0) FROM biz.ai_run_logs WHERE status = 'success' AND created_at >= date_trunc('month', CURRENT_DATE) AND created_at < date_trunc('month', CURRENT_DATE) + INTERVAL '1 month' """ ) row = cur.fetchone() return int(row[0]) if row else 0