"""AI cache read/write service.

Handles CRUD and retention-policy management for the ``biz.ai_cache`` table.
Lookups and inserts are scoped by ``site_id``.

P14 changes:
- Added handling for the ``status`` field (valid/expired/invalidated/generating)
- Lookups only return rows with status='valid' that have not expired
- Expiry time is set per App type
- Retention keeps the newest 20,000 rows (see ``_cleanup_excess`` for the
  exact scope the limit is applied to)
"""

from __future__ import annotations

import json
import logging
from datetime import datetime, timedelta, timezone

from app.database import get_connection
from app.services.runtime_context import (
    LIVE_INSTANCE_ID,
    MODE_LIVE,
    MODE_SANDBOX,
    get_runtime_context,
)

logger = logging.getLogger(__name__)

# Cache expiry policy: cache_type -> days until expiry
# (0 means 23:59:59 of the current day, UTC+8).
# Names match the prompt file names (unified by the W1-AI-CLOSURE group 1
# database migration).
# P0-11 fix: added app2a_finance_area with the same same-day expiry policy as
# app2_finance; it was previously missing, which left 64 area-combination
# cache entries with expires_at=NULL, i.e. never expiring.
CACHE_EXPIRY_DAYS: dict[str, int] = {
    "app2_finance": 0,
    "app2a_finance_area": 0,
    "app3_clue": 7,
    "app4_analysis": 7,
    "app5_tactics": 7,
    "app6_note": 30,
    "app7_customer": 7,
    "app8_consolidation": 7,
}

# Retention limit used by the post-write cleanup.
CACHE_MAX_PER_APP = 20_000

# Fixed UTC+8 offset used to compute the "end of the current day" expiry.
_CACHE_LOCAL_TZ = timezone(timedelta(hours=8))


class AICacheService:
    """Read/write service for AI cache rows stored in ``biz.ai_cache``."""

    @staticmethod
    def _runtime_scope(site_id: int, target_id: str, conn) -> tuple[str, str, str]:
        """Return the runtime mode, instance id and effective cache target_id.

        In sandbox mode (with a sandbox instance id present) the target id is
        prefixed with ``"<sandbox_instance_id>:"``; otherwise the live mode,
        the live instance id and the unmodified ``target_id`` are returned.
        """
        ctx = get_runtime_context(site_id, conn=conn)
        if ctx.is_sandbox and ctx.sandbox_instance_id:
            return (
                MODE_SANDBOX,
                ctx.sandbox_instance_id,
                f"{ctx.sandbox_instance_id}:{target_id}",
            )
        return MODE_LIVE, LIVE_INSTANCE_ID, target_id

    def get_latest(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
    ) -> dict | None:
        """Return the newest usable cache row, or ``None`` when there is none.

        A row is usable when its status is 'valid' (or NULL) and its
        ``expires_at`` is NULL or still in the future.
        """
        conn = get_connection()
        try:
            runtime_mode, sandbox_instance_id, scoped_target_id = self._runtime_scope(
                site_id, target_id, conn
            )
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, cache_type, site_id, target_id, result_json,
                           score, triggered_by, created_at, expires_at, status
                    FROM biz.ai_cache
                    WHERE cache_type = %s
                      AND site_id = %s
                      AND target_id = %s
                      AND COALESCE(runtime_mode, 'live') = %s
                      AND COALESCE(sandbox_instance_id, 'live') = %s
                      AND (status = 'valid' OR status IS NULL)
                      AND (expires_at IS NULL OR expires_at > now())
                    ORDER BY created_at DESC
                    LIMIT 1
                    """,
                    (
                        cache_type,
                        site_id,
                        scoped_target_id,
                        runtime_mode,
                        sandbox_instance_id,
                    ),
                )
                columns = [desc[0] for desc in cur.description]
                row = cur.fetchone()
            if row is None:
                return None
            return _row_to_dict(columns, row)
        finally:
            conn.close()

    def get_history(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        limit: int = 2,
    ) -> list[dict]:
        """Return up to ``limit`` past cache rows, newest first.

        Intended as reference material for prompts. Returns an empty list when
        nothing matches. Expiry is deliberately not checked here, but
        'generating' placeholder rows are skipped: they carry an empty
        ``result_json`` ('{}') and, because ``status`` is not part of the
        returned columns, callers could not tell them apart from real results.
        """
        conn = get_connection()
        try:
            runtime_mode, sandbox_instance_id, scoped_target_id = self._runtime_scope(
                site_id, target_id, conn
            )
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, cache_type, site_id, target_id, result_json,
                           score, triggered_by, created_at, expires_at
                    FROM biz.ai_cache
                    WHERE cache_type = %s
                      AND site_id = %s
                      AND target_id = %s
                      AND COALESCE(runtime_mode, 'live') = %s
                      AND COALESCE(sandbox_instance_id, 'live') = %s
                      AND (status IS NULL OR status <> 'generating')
                    ORDER BY created_at DESC
                    LIMIT %s
                    """,
                    (
                        cache_type,
                        site_id,
                        scoped_target_id,
                        runtime_mode,
                        sandbox_instance_id,
                        limit,
                    ),
                )
                columns = [desc[0] for desc in cur.description]
                rows = cur.fetchall()
            return [_row_to_dict(columns, row) for row in rows]
        finally:
            conn.close()

    def write_cache(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        result_json: dict,
        triggered_by: str | None = None,
        score: int | None = None,
        expires_at: datetime | None = None,
    ) -> int:
        """Insert a 'valid' cache row and return its id.

        When ``expires_at`` is not given it is derived from ``cache_type`` via
        ``_calc_expires_at``. After a successful insert, rows beyond the
        retention limit are removed; a failure of that cleanup is logged and
        does not affect the returned id.

        Raises:
            Exception: any error raised while inserting is re-raised after the
                transaction has been rolled back.
        """
        # Derive the expiry automatically unless the caller supplied one.
        if expires_at is None:
            expires_at = self._calc_expires_at(cache_type)

        conn = get_connection()
        try:
            runtime_mode, sandbox_instance_id, scoped_target_id = self._runtime_scope(
                site_id, target_id, conn
            )
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_cache
                        (cache_type, site_id, target_id, result_json,
                         triggered_by, score, expires_at, status,
                         runtime_mode, sandbox_instance_id)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, 'valid', %s, %s)
                    RETURNING id
                    """,
                    (
                        cache_type,
                        site_id,
                        scoped_target_id,
                        json.dumps(result_json, ensure_ascii=False),
                        triggered_by,
                        score,
                        expires_at,
                        runtime_mode,
                        sandbox_instance_id,
                    ),
                )
                row = cur.fetchone()
            conn.commit()
            cache_id: int = row[0]
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

        # Best-effort retention cleanup once the insert has been committed.
        try:
            deleted = self._cleanup_excess(cache_type, site_id, scoped_target_id)
            if deleted > 0:
                logger.info(
                    "清理超限缓存: cache_type=%s site_id=%s target_id=%s 删除=%d",
                    cache_type,
                    site_id,
                    target_id,
                    deleted,
                )
        except Exception:
            logger.warning(
                "清理超限缓存失败: cache_type=%s site_id=%s target_id=%s",
                cache_type,
                site_id,
                target_id,
                exc_info=True,
            )

        return cache_id

    def set_generating(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        triggered_by: str | None = None,
    ) -> int:
        """Insert a 'generating' placeholder row and return its id.

        The placeholder stores an empty JSON object as its result; call
        ``finalize_cache`` with the returned id once the result is available.

        Raises:
            Exception: any error raised while inserting is re-raised after the
                transaction has been rolled back.
        """
        conn = get_connection()
        try:
            runtime_mode, sandbox_instance_id, scoped_target_id = self._runtime_scope(
                site_id, target_id, conn
            )
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_cache
                        (cache_type, site_id, target_id, result_json, status,
                         triggered_by, runtime_mode, sandbox_instance_id)
                    VALUES (%s, %s, %s, '{}', 'generating', %s, %s, %s)
                    RETURNING id
                    """,
                    (
                        cache_type,
                        site_id,
                        scoped_target_id,
                        triggered_by,
                        runtime_mode,
                        sandbox_instance_id,
                    ),
                )
                row = cur.fetchone()
            conn.commit()
            return row[0]
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def finalize_cache(
        self,
        cache_id: int,
        result_json: dict,
        score: int | None = None,
        cache_type: str | None = None,
    ) -> None:
        """Turn a 'generating' row into a 'valid' one with result and expiry.

        Only a row whose status is still 'generating' is updated. When
        ``cache_type`` is omitted it is read from the stored row so the expiry
        policy is still applied; without that lookup the row would be written
        with ``expires_at = NULL`` and never expire. If no row is updated a
        warning is logged.

        Raises:
            Exception: any database error is re-raised after the transaction
                has been rolled back.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                if cache_type is None:
                    # Fall back to the type recorded on the placeholder row.
                    cur.execute(
                        "SELECT cache_type FROM biz.ai_cache WHERE id = %s",
                        (cache_id,),
                    )
                    stored = cur.fetchone()
                    if stored is not None:
                        cache_type = stored[0]

                expires_at = self._calc_expires_at(cache_type)
                cur.execute(
                    """
                    UPDATE biz.ai_cache
                    SET result_json = %s,
                        score = %s,
                        status = 'valid',
                        expires_at = %s
                    WHERE id = %s AND status = 'generating'
                    """,
                    (
                        json.dumps(result_json, ensure_ascii=False),
                        score,
                        expires_at,
                        cache_id,
                    ),
                )
                updated = cur.rowcount
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

        if updated == 0:
            # Either the id does not exist or the row already left 'generating'.
            logger.warning(
                "finalize_cache 未更新任何记录: cache_id=%s cache_type=%s",
                cache_id,
                cache_type,
            )

    @staticmethod
    def _calc_expires_at(cache_type: str | None) -> datetime | None:
        """Compute the expiry timestamp (UTC) for ``cache_type``.

        Returns ``None`` when ``cache_type`` is ``None`` or has no entry in
        ``CACHE_EXPIRY_DAYS``. A policy of 0 days means 23:59:59 of the current
        day in UTC+8; any other value means "now plus that many days".
        """
        if cache_type is None:
            return None
        days = CACHE_EXPIRY_DAYS.get(cache_type)
        if days is None:
            return None

        now = datetime.now(timezone.utc)
        if days == 0:
            # End of the current day in UTC+8, converted back to UTC.
            end_of_day = now.astimezone(_CACHE_LOCAL_TZ).replace(
                hour=23, minute=59, second=59, microsecond=0
            )
            return end_of_day.astimezone(timezone.utc)
        return now + timedelta(days=days)

    def _cleanup_excess(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        max_count: int = CACHE_MAX_PER_APP,
    ) -> int:
        """Delete rows beyond the newest ``max_count`` and return the count.

        The limit is applied to rows matching the given
        ``(cache_type, site_id, target_id)``; ``target_id`` is compared as-is,
        so callers pass the already scoped target id.

        Raises:
            Exception: any database error is re-raised after the transaction
                has been rolled back.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Everything past the first max_count rows (newest first) goes.
                cur.execute(
                    """
                    DELETE FROM biz.ai_cache
                    WHERE id IN (
                        SELECT id FROM biz.ai_cache
                        WHERE cache_type = %s
                          AND site_id = %s
                          AND target_id = %s
                        ORDER BY created_at DESC
                        OFFSET %s
                    )
                    """,
                    (cache_type, site_id, target_id, max_count),
                )
                deleted = cur.rowcount
            conn.commit()
            return deleted
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()


def _row_to_dict(columns: list[str], row: tuple) -> dict:
    """Convert a database row to a dict, rendering datetimes as ISO strings."""
    return {
        col: val.isoformat() if isinstance(val, datetime) else val
        for col, val in zip(columns, row)
    }