"""Aggregation service layer for the AI monitoring admin console.

Provides: dashboard overview, trigger-job (scheduling) management, run-log
queries, cache invalidation, token budget, batch execution (with a two-step
cost confirmation), and alert management.

All database access uses synchronous psycopg2 connections obtained from
``get_connection()``; the public methods are declared ``async`` so FastAPI
handlers can await them.

NOTE(review): because the DB calls are synchronous, each query blocks the
event loop while it runs. Consider ``asyncio.to_thread`` or an async driver if
these endpoints become latency-sensitive.

Site isolation: queries filter by ``site_id`` whenever a ``site_id`` argument
(or ``filters["site_id"]``) is not None.
"""

from __future__ import annotations

import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone, timedelta

from app.ai.budget_tracker import BudgetTracker
from app.database import get_connection

logger = logging.getLogger(__name__)

# Batch-execution estimate: average tokens consumed per AI call.
AVG_TOKENS_PER_CALL = 2000

# TTL (seconds) of pending batch entries kept in memory awaiting confirmation.
_BATCH_TTL_SECONDS = 600  # 10 minutes

# Fallback token limits used by get_budget() when no BudgetTracker is injected.
_DEFAULT_DAILY_TOKEN_LIMIT = 100_000
_DEFAULT_MONTHLY_TOKEN_LIMIT = 2_000_000

# SQL predicate selecting ai_run_logs rows that count as alerts.
_ALERT_STATUS_SQL = "status IN ('failed', 'timeout', 'circuit_open')"


class AdminAIService:
    """Aggregation service backing the AI monitoring admin console."""

    def __init__(self, budget_tracker: BudgetTracker | None = None) -> None:
        self._budget = budget_tracker
        # batch_id -> {"params": {...}, "expires_at": timezone-aware UTC datetime}
        self._batch_store: dict[str, dict] = {}
        # Strong references to in-flight batch tasks. The event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected
        # before it completes.
        self._background_tasks: set[asyncio.Task] = set()

    # ── Dashboard ─────────────────────────────────────────

    async def get_dashboard(self, site_id: int | None = None) -> dict:
        """Aggregate every dashboard section into a single payload.

        Args:
            site_id: Restrict per-site sections to this site; None means all
                sites. The ``budget`` section is never site-filtered.

        Returns:
            The keys from ``_get_today_stats`` merged with ``trend_7d``,
            ``app_distribution``, ``budget``, ``recent_alerts`` and
            ``app_health``.
        """
        today_stats = await self._get_today_stats(site_id)
        trend_7d = await self._get_7d_trend(site_id)
        app_dist = await self._get_app_distribution(site_id)
        app_health = await self._get_app_health(site_id)
        budget = await self.get_budget()
        recent_alerts = await self._get_recent_alerts(site_id)
        return {
            **today_stats,
            "trend_7d": trend_7d,
            "app_distribution": app_dist,
            "budget": budget,
            "recent_alerts": recent_alerts,
            "app_health": app_health,
        }

    async def _get_today_stats(self, site_id: int | None) -> dict:
        """Return today's call count, success rate, token usage and mean latency.

        "Today" is ``[CURRENT_DATE, CURRENT_DATE + 1 day)`` as evaluated by
        the database session.
        """
        site_clause, params = _site_filter(site_id)
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    SELECT
                        COUNT(*) AS total_calls,
                        COUNT(*) FILTER (WHERE status = 'success') AS success_count,
                        COALESCE(SUM(tokens_used), 0) AS total_tokens,
                        COALESCE(AVG(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0)
                            AS avg_latency
                    FROM biz.ai_run_logs
                    WHERE created_at >= CURRENT_DATE
                      AND created_at < CURRENT_DATE + INTERVAL '1 day'
                      {site_clause}
                    """,
                    params,
                )
                row = cur.fetchone()
            conn.commit()
        finally:
            conn.close()

        total, success, tokens, avg_lat = row if row else (0, 0, 0, 0)
        rate = round(success / total, 4) if total > 0 else 0.0
        return {
            "today_calls": total,
            "today_success_rate": rate,
            "today_tokens": int(tokens),
            "today_avg_latency_ms": round(float(avg_lat), 2),
        }

    async def _get_7d_trend(self, site_id: int | None) -> list[dict]:
        """Return per-day call counts and success rates for the last 7 days.

        The window starts at ``CURRENT_DATE - 6 days`` (today included). Days
        without any calls produce no entry.
        """
        site_clause, params = _site_filter(site_id)
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    SELECT
                        created_at::date AS day,
                        COUNT(*) AS calls,
                        COUNT(*) FILTER (WHERE status = 'success') AS success_count
                    FROM biz.ai_run_logs
                    WHERE created_at >= CURRENT_DATE - INTERVAL '6 days'
                      {site_clause}
                    GROUP BY day
                    ORDER BY day
                    """,
                    params,
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return [
            {
                "date": row[0].isoformat(),
                "calls": row[1],
                "success_rate": round(row[2] / row[1], 4) if row[1] > 0 else 0.0,
            }
            for row in rows
        ]

    async def _get_app_distribution(self, site_id: int | None) -> list[dict]:
        """Return each app_type's share of calls over the same 7-day window."""
        site_clause, params = _site_filter(site_id)
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    SELECT app_type, COUNT(*) AS cnt
                    FROM biz.ai_run_logs
                    WHERE created_at >= CURRENT_DATE - INTERVAL '6 days'
                      {site_clause}
                    GROUP BY app_type
                    ORDER BY cnt DESC
                    """,
                    params,
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        # `or 1` avoids division by zero when there are no rows at all.
        total = sum(r[1] for r in rows) or 1
        return [
            {
                "app_type": row[0],
                "count": row[1],
                "percentage": round(row[1] / total, 4),
            }
            for row in rows
        ]

    async def _get_app_health(self, site_id: int | None) -> list[dict]:
        """Return the status and timestamp of the most recent call per app_type."""
        site_clause, params = _site_filter(site_id)
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # DISTINCT ON + ORDER BY created_at DESC keeps the newest row per app.
                cur.execute(
                    f"""
                    SELECT DISTINCT ON (app_type)
                        app_type,
                        status AS last_status,
                        created_at AS last_call_at
                    FROM biz.ai_run_logs
                    WHERE TRUE {site_clause}
                    ORDER BY app_type, created_at DESC
                    """,
                    params,
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return [
            {
                "app_type": row[0],
                "last_status": row[1],
                "last_call_at": row[2].isoformat() if row[2] else None,
            }
            for row in rows
        ]

    async def _get_recent_alerts(self, site_id: int | None, limit: int = 10) -> list[dict]:
        """Return the newest alert-status run logs (dashboard widget)."""
        site_clause, params = _site_filter(site_id)
        params = (*params, limit)
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    SELECT id, app_type, status, alert_status, error_message, created_at
                    FROM biz.ai_run_logs
                    WHERE {_ALERT_STATUS_SQL}
                      {site_clause}
                    ORDER BY created_at DESC
                    LIMIT %s
                    """,
                    params,
                )
                cols = [d[0] for d in cur.description]
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return [_row_to_dict(cols, r) for r in rows]

    # ── Trigger jobs ──────────────────────────────────────

    async def list_trigger_jobs(
        self,
        filters: dict,
        page: int = 1,
        page_size: int = 20,
    ) -> dict:
        """Page through biz.ai_trigger_jobs plus today's de-duplication count.

        Args:
            filters: Optional ``event_type``, ``status``, ``site_id`` (applied
                when not None) and ``date_from`` / ``date_to`` (applied when
                truthy, inclusive bounds on ``created_at``).
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``items``, ``total``, ``page``, ``page_size`` and
            ``today_skipped_duplicates``. The skip count honours
            ``filters["site_id"]`` so it never leaks other sites' data.
        """
        where_sql, params = _build_where(filters, ("event_type", "status", "site_id"))
        page, page_size, offset = _normalize_page(page, page_size)
        skip_site_clause, skip_params = _site_filter(filters.get("site_id"))

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Total row count for the filter set.
                cur.execute(
                    f"SELECT COUNT(*) FROM biz.ai_trigger_jobs {where_sql}",
                    params,
                )
                total = cur.fetchone()[0]

                # Current page.
                cur.execute(
                    f"""
                    SELECT id, event_type, member_id, status, app_chain,
                           is_forced, site_id, started_at, finished_at, created_at
                    FROM biz.ai_trigger_jobs
                    {where_sql}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                    """,
                    (*params, page_size, offset),
                )
                cols = [d[0] for d in cur.description]
                rows = cur.fetchall()

                # Jobs skipped as duplicates today (site-scoped when requested).
                cur.execute(
                    f"""
                    SELECT COUNT(*)
                    FROM biz.ai_trigger_jobs
                    WHERE status = 'skipped_duplicate'
                      AND created_at >= CURRENT_DATE
                      AND created_at < CURRENT_DATE + INTERVAL '1 day'
                      {skip_site_clause}
                    """,
                    skip_params,
                )
                today_skipped = cur.fetchone()[0]
            conn.commit()
        finally:
            conn.close()

        return {
            "items": [_row_to_dict(cols, r) for r in rows],
            "total": total,
            "page": page,
            "page_size": page_size,
            "today_skipped_duplicates": today_skipped,
        }

    async def get_trigger_job(self, job_id: int) -> dict | None:
        """Return one trigger job (including payload/error details), or None."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, event_type, member_id, status, app_chain,
                           is_forced, site_id, started_at, finished_at, created_at,
                           payload, error_message, connector_type
                    FROM biz.ai_trigger_jobs
                    WHERE id = %s
                    """,
                    (job_id,),
                )
                cols = [d[0] for d in cur.description]
                row = cur.fetchone()
            conn.commit()
        finally:
            conn.close()

        if row is None:
            return None
        return _row_to_dict(cols, row)

    async def retry_trigger_job(self, job_id: int) -> int:
        """Clone a trigger job as a new forced, pending job.

        Args:
            job_id: Id of the job to retry.

        Returns:
            The id of the newly inserted job (``is_forced = true``,
            ``status = 'pending'``).

        Raises:
            ValueError: If ``job_id`` does not exist.
        """
        original = await self.get_trigger_job(job_id)
        if original is None:
            raise ValueError(f"trigger_job {job_id} 不存在")

        # The column is always selected, so .get(key, default) would never fall
        # back; use the default when the stored value is NULL/empty instead.
        connector_type = original.get("connector_type") or "feiqiu"

        payload = original.get("payload")
        if isinstance(payload, dict):
            # A json/jsonb column is read back as a dict, which psycopg2 cannot
            # bind as a parameter unless a Json adapter is registered. A JSON
            # string is accepted by json/jsonb columns either way.
            # NOTE(review): assumes payload is a json/jsonb column — confirm.
            payload = json.dumps(payload, ensure_ascii=False)

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_trigger_jobs
                        (event_type, member_id, site_id, connector_type,
                         payload, app_chain, is_forced, status)
                    VALUES (%s, %s, %s, %s, %s, %s, true, 'pending')
                    RETURNING id
                    """,
                    (
                        original["event_type"],
                        original.get("member_id"),
                        original["site_id"],
                        connector_type,
                        payload,
                        original.get("app_chain"),
                    ),
                )
                new_id = cur.fetchone()[0]
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()
        return new_id

    # ── Run logs ──────────────────────────────────────────

    async def list_run_logs(
        self,
        filters: dict,
        page: int = 1,
        page_size: int = 20,
    ) -> dict:
        """Page through biz.ai_run_logs.

        Args:
            filters: Optional ``app_type``, ``status``, ``trigger_type``,
                ``site_id`` (applied when not None) and ``date_from`` /
                ``date_to`` (applied when truthy, inclusive on ``created_at``).
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``items``, ``total``, ``page`` and ``page_size``.
        """
        where_sql, params = _build_where(
            filters, ("app_type", "status", "trigger_type", "site_id")
        )
        page, page_size, offset = _normalize_page(page, page_size)

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"SELECT COUNT(*) FROM biz.ai_run_logs {where_sql}",
                    params,
                )
                total = cur.fetchone()[0]

                cur.execute(
                    f"""
                    SELECT id, app_type, trigger_type, member_id, tokens_used,
                           latency_ms, status, site_id, created_at
                    FROM biz.ai_run_logs
                    {where_sql}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                    """,
                    (*params, page_size, offset),
                )
                cols = [d[0] for d in cur.description]
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return {
            "items": [_row_to_dict(cols, r) for r in rows],
            "total": total,
            "page": page,
            "page_size": page_size,
        }

    async def get_run_log(self, log_id: int) -> dict | None:
        """Return one run log with the full prompt/response (not redacted), or None."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, app_type, trigger_type, member_id, tokens_used,
                           latency_ms, status, site_id, created_at,
                           request_prompt, response_text, error_message,
                           session_id, finished_at
                    FROM biz.ai_run_logs
                    WHERE id = %s
                    """,
                    (log_id,),
                )
                cols = [d[0] for d in cur.description]
                row = cur.fetchone()
            conn.commit()
        finally:
            conn.close()

        if row is None:
            return None
        return _row_to_dict(cols, row)

    # ── Cache management ──────────────────────────────────

    async def invalidate_cache(
        self,
        site_id: int,
        app_type: str | None = None,
        member_id: int | None = None,
    ) -> int:
        """Mark matching biz.ai_cache rows as 'invalidated'.

        Args:
            site_id: Required site scope.
            app_type: When given, matched against ``cache_type``.
            member_id: When given, matched against ``target_id`` as a string.

        Returns:
            Number of rows changed (rows already invalidated are not counted).
        """
        where_parts = ["site_id = %s"]
        params: list = [site_id]
        if app_type is not None:
            where_parts.append("cache_type = %s")
            params.append(app_type)
        if member_id is not None:
            where_parts.append("target_id = %s")
            params.append(str(member_id))
        where_sql = " AND ".join(where_parts)

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    UPDATE biz.ai_cache
                    SET status = 'invalidated'
                    WHERE {where_sql} AND status != 'invalidated'
                    """,
                    params,
                )
                affected = cur.rowcount
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()
        return affected

    # ── Token budget ──────────────────────────────────────

    async def get_budget(self) -> dict:
        """Return daily/monthly token usage against their limits.

        Uses the injected BudgetTracker when present. Otherwise it sums
        ``tokens_used`` of successful run logs for the current day and month
        and compares them against the module's default limits.
        """
        if self._budget is not None:
            status = self._budget.check_budget()
            daily_limit = self._budget.daily_limit
            monthly_limit = self._budget.monthly_limit
            return {
                "daily_used": status.daily_used,
                "daily_limit": daily_limit,
                "daily_pct": _ratio(status.daily_used, daily_limit),
                "monthly_used": status.monthly_used,
                "monthly_limit": monthly_limit,
                "monthly_pct": _ratio(status.monthly_used, monthly_limit),
            }

        # No BudgetTracker injected: query the run logs directly.
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT
                        COALESCE(SUM(tokens_used) FILTER (
                            WHERE created_at >= CURRENT_DATE
                              AND created_at < CURRENT_DATE + INTERVAL '1 day'
                        ), 0) AS daily_used,
                        COALESCE(SUM(tokens_used) FILTER (
                            WHERE created_at >= date_trunc('month', CURRENT_DATE)
                              AND created_at < date_trunc('month', CURRENT_DATE)
                                               + INTERVAL '1 month'
                        ), 0) AS monthly_used
                    FROM biz.ai_run_logs
                    WHERE status = 'success'
                    """,
                )
                row = cur.fetchone()
            conn.commit()
        finally:
            conn.close()

        daily_used, monthly_used = (int(row[0]), int(row[1])) if row else (0, 0)
        daily_limit = _DEFAULT_DAILY_TOKEN_LIMIT
        monthly_limit = _DEFAULT_MONTHLY_TOKEN_LIMIT
        return {
            "daily_used": daily_used,
            "daily_limit": daily_limit,
            "daily_pct": _ratio(daily_used, daily_limit),
            "monthly_used": monthly_used,
            "monthly_limit": monthly_limit,
            "monthly_pct": _ratio(monthly_used, monthly_limit),
        }

    # ── Batch execution ───────────────────────────────────

    async def estimate_batch(
        self,
        app_types: list[str],
        member_ids: list[int],
        site_id: int,
    ) -> dict:
        """Register a pending batch and return its cost estimate.

        The parameters are stored in memory under a fresh ``batch_id`` for
        ``_BATCH_TTL_SECONDS``. They run only once ``confirm_batch`` is
        called with that id.

        Returns:
            ``batch_id``, ``estimated_calls`` (apps × members) and
            ``estimated_tokens`` (calls × AVG_TOKENS_PER_CALL).
        """
        self._cleanup_expired_batches()
        batch_id = uuid.uuid4().hex
        estimated_calls = len(app_types) * len(member_ids)
        estimated_tokens = estimated_calls * AVG_TOKENS_PER_CALL

        self._batch_store[batch_id] = {
            "params": {
                "app_types": app_types,
                "member_ids": member_ids,
                "site_id": site_id,
            },
            "expires_at": datetime.now(timezone.utc) + timedelta(seconds=_BATCH_TTL_SECONDS),
        }
        return {
            "batch_id": batch_id,
            "estimated_calls": estimated_calls,
            "estimated_tokens": estimated_tokens,
        }

    async def confirm_batch(self, batch_id: str) -> None:
        """Consume a pending batch and start it as a background task.

        Raises:
            ValueError: If ``batch_id`` is unknown, expired, or already used.
        """
        self._cleanup_expired_batches()
        entry = self._batch_store.pop(batch_id, None)
        if entry is None:
            raise ValueError(f"batch_id 无效或已过期: {batch_id}")

        params = entry["params"]
        logger.info(
            "批量执行确认: batch_id=%s apps=%s members=%d site_id=%s",
            batch_id,
            params["app_types"],
            len(params["member_ids"]),
            params["site_id"],
        )
        # Run in the background; the concrete call chain is driven by the
        # dispatcher injected at the router layer. Keep a strong reference so
        # the task is not garbage-collected mid-run.
        task = asyncio.create_task(
            self._run_batch(params["app_types"], params["member_ids"], params["site_id"])
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._on_batch_task_done)

    def _on_batch_task_done(self, task: asyncio.Task) -> None:
        """Release a finished batch task and log any exception it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("批量执行失败", exc_info=exc)

    async def _run_batch(
        self,
        app_types: list[str],
        member_ids: list[int],
        site_id: int,
    ) -> None:
        """Placeholder for the background batch run; currently it only logs.

        The actual execution is driven at the router layer through
        ``dispatcher.handle_trigger``. This keeps the service layer free of a
        direct dependency on a dispatcher instance.
        """
        logger.info(
            "批量执行开始: apps=%s members=%d site_id=%s",
            app_types,
            len(member_ids),
            site_id,
        )

    def _cleanup_expired_batches(self) -> None:
        """Drop pending batches whose ``expires_at`` has passed."""
        now = datetime.now(timezone.utc)
        expired = [
            bid for bid, entry in self._batch_store.items()
            if entry["expires_at"] <= now
        ]
        for bid in expired:
            del self._batch_store[bid]
        if expired:
            logger.debug("清理过期 batch: %d 个", len(expired))

    # ── Alert management ──────────────────────────────────

    async def list_alerts(
        self,
        alert_status: str | None = None,
        site_id: int | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> dict:
        """Page through alerts (run logs with status failed/timeout/circuit_open).

        Args:
            alert_status: Filter on ``alert_status``; ``"pending"`` also
                matches rows whose ``alert_status`` is NULL.
            site_id: Restrict to this site when not None.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``items``, ``total``, ``page`` and ``page_size``.
        """
        where_parts = [_ALERT_STATUS_SQL]
        params: list = []

        if alert_status is not None:
            if alert_status == "pending":
                # Rows never triaged have NULL alert_status; treat them as pending.
                where_parts.append("(alert_status IS NULL OR alert_status = 'pending')")
            else:
                where_parts.append("alert_status = %s")
                params.append(alert_status)

        if site_id is not None:
            where_parts.append("site_id = %s")
            params.append(site_id)

        where_sql = "WHERE " + " AND ".join(where_parts)
        page, page_size, offset = _normalize_page(page, page_size)

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"SELECT COUNT(*) FROM biz.ai_run_logs {where_sql}",
                    params,
                )
                total = cur.fetchone()[0]

                cur.execute(
                    f"""
                    SELECT id, app_type, status, alert_status, error_message, created_at
                    FROM biz.ai_run_logs
                    {where_sql}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                    """,
                    (*params, page_size, offset),
                )
                cols = [d[0] for d in cur.description]
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return {
            "items": [_row_to_dict(cols, r) for r in rows],
            "total": total,
            "page": page,
            "page_size": page_size,
        }

    async def ack_alert(self, log_id: int) -> str:
        """Acknowledge an alert (alert_status -> 'acknowledged'); returns the new status."""
        return self._set_alert_status(log_id, "acknowledged")

    async def ignore_alert(self, log_id: int) -> str:
        """Ignore an alert (alert_status -> 'ignored'); returns the new status."""
        return self._set_alert_status(log_id, "ignored")

    def _set_alert_status(self, log_id: int, alert_status: str) -> str:
        """Set ``alert_status`` on an alert-status run log and return it.

        Only rows matching the alert status predicate are updated. When no row
        matches (unknown id or not an alert), a warning is logged and the
        status is still returned, preserving the previous API behaviour.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    UPDATE biz.ai_run_logs
                    SET alert_status = %s
                    WHERE id = %s AND {_ALERT_STATUS_SQL}
                    """,
                    (alert_status, log_id),
                )
                affected = cur.rowcount
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

        if affected == 0:
            logger.warning(
                "告警更新未命中: log_id=%s alert_status=%s", log_id, alert_status
            )
        return alert_status


# ── Helper functions ──────────────────────────────────────


def _site_filter(site_id: int | None) -> tuple[str, tuple]:
    """Return an ``AND site_id = %s`` clause and its params, or ``("", ())`` for None."""
    if site_id is None:
        return "", ()
    return "AND site_id = %s", (site_id,)


def _build_where(filters: dict, equality_keys: tuple[str, ...]) -> tuple[str, list]:
    """Build a WHERE clause from equality filters plus an optional date range.

    ``equality_keys`` are interpolated into the SQL as column names, so they
    must be trusted constants; filter values are always bound as parameters.
    Equality filters apply when the value is not None. ``date_from`` and
    ``date_to`` apply when truthy and are inclusive bounds on ``created_at``.

    Returns:
        ``(where_sql, params)``; ``where_sql`` is ``""`` when nothing applies.
    """
    where_parts: list[str] = []
    params: list = []
    for key in equality_keys:
        if filters.get(key) is not None:
            where_parts.append(f"{key} = %s")
            params.append(filters[key])
    if filters.get("date_from"):
        where_parts.append("created_at >= %s")
        params.append(filters["date_from"])
    if filters.get("date_to"):
        where_parts.append("created_at <= %s")
        params.append(filters["date_to"])
    where_sql = ("WHERE " + " AND ".join(where_parts)) if where_parts else ""
    return where_sql, params


def _normalize_page(page: int, page_size: int) -> tuple[int, int, int]:
    """Clamp pagination inputs and return ``(page, page_size, offset)``.

    PostgreSQL rejects a negative LIMIT or OFFSET, so ``page`` is clamped to
    at least 1 and ``page_size`` to at least 0 instead of failing in SQL.
    """
    page = max(page, 1)
    page_size = max(page_size, 0)
    return page, page_size, (page - 1) * page_size


def _ratio(used: int | float, limit: int | float) -> float:
    """Return ``used / limit`` rounded to 4 places, or 0.0 when limit is not positive."""
    return round(used / limit, 4) if limit > 0 else 0.0


def _row_to_dict(columns: list[str], row: tuple) -> dict:
    """Zip column names with a DB row, converting datetimes to ISO-8601 strings."""
    result = {}
    for col, val in zip(columns, row):
        if isinstance(val, datetime):
            result[col] = val.isoformat()
        else:
            result[col] = val
    return result