"""AI cache read/write service.

Handles CRUD and the retention policy for the ``biz.ai_cache`` table.
Every query and write is scoped by ``site_id`` (site isolation).

P14 changes:
- Handles the ``status`` column (valid / expired / invalidated / generating).
- Lookups of the current result return only ``status='valid'`` rows that have
  not expired.
- The expiry time is set according to the App type (``cache_type``).
- Each App keeps only its newest 20,000 rows (counted per ``site_id``).
"""

from __future__ import annotations

import json
import logging
from datetime import datetime, timedelta, timezone

from app.database import get_connection

logger = logging.getLogger(__name__)

# Expiry policy: cache_type -> lifetime in days.
# 0 means "until 23:59:59 today" in the business timezone (UTC+8).
# Types missing from this map get expires_at = NULL, i.e. they never expire.
CACHE_EXPIRY_DAYS: dict[str, int] = {
    "app2_finance": 0,  # until 23:59:59 today (UTC+8)
    "app3_clue": 7,
    "app4_analysis": 7,
    "app5_tactics": 7,
    "app6_note_analysis": 30,
    "app7_customer_analysis": 7,
    "app8_clue_consolidated": 7,
}

# Maximum number of rows retained per App (cache_type) within one site.
CACHE_MAX_PER_APP = 20_000

# Business timezone used by the "end of today" expiry rule.
_LOCAL_TZ = timezone(timedelta(hours=8))


class AICacheService:
    """Read/write access to ``biz.ai_cache``.

    Every public method opens its own connection via ``get_connection()`` and
    closes it before returning.
    """

    def get_latest(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
    ) -> dict | None:
        """Return the newest usable cache row for one target, or None.

        A row is usable when its status is ``'valid'`` or NULL and its
        ``expires_at`` is NULL or still in the future. NULL status is accepted
        presumably for rows written before the P14 status column existed.
        Datetime columns are returned as ISO-8601 strings.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, cache_type, site_id, target_id,
                           result_json, score, triggered_by,
                           created_at, expires_at, status
                    FROM biz.ai_cache
                    WHERE cache_type = %s
                      AND site_id = %s
                      AND target_id = %s
                      AND (status = 'valid' OR status IS NULL)
                      AND (expires_at IS NULL OR expires_at > now())
                    ORDER BY created_at DESC, id DESC
                    LIMIT 1
                    """,
                    (cache_type, site_id, target_id),
                )
                columns = [desc[0] for desc in cur.description]
                row = cur.fetchone()
                if row is None:
                    return None
                return _row_to_dict(columns, row)
        finally:
            conn.close()

    def get_history(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        limit: int = 2,
    ) -> list[dict]:
        """Return up to ``limit`` recent rows for one target, newest first.

        Intended as reference material for prompts. Expired and non-valid rows
        are still included (an older result is still a useful reference).
        ``'generating'`` placeholders are skipped, because their
        ``result_json`` is the empty ``'{}'`` written by ``set_generating``.

        Returns an empty list when nothing matches or ``limit`` <= 0.
        """
        if limit <= 0:
            # LIMIT 0 would return nothing anyway; a negative LIMIT is a SQL error.
            return []
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, cache_type, site_id, target_id,
                           result_json, score, triggered_by,
                           created_at, expires_at, status
                    FROM biz.ai_cache
                    WHERE cache_type = %s
                      AND site_id = %s
                      AND target_id = %s
                      AND status IS DISTINCT FROM 'generating'
                    ORDER BY created_at DESC, id DESC
                    LIMIT %s
                    """,
                    (cache_type, site_id, target_id, limit),
                )
                columns = [desc[0] for desc in cur.description]
                rows = cur.fetchall()
                return [_row_to_dict(columns, row) for row in rows]
        finally:
            conn.close()

    def write_cache(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        result_json: dict,
        triggered_by: str | None = None,
        score: int | None = None,
        expires_at: datetime | None = None,
    ) -> int:
        """Insert a ready-to-use cache row with ``status='valid'``.

        Args:
            cache_type: App type; selects the expiry rule in CACHE_EXPIRY_DAYS.
            site_id: Site the row belongs to.
            target_id: Entity the cached result describes.
            result_json: Result payload, stored as JSON (non-ASCII kept as-is).
            triggered_by: Optional trigger label stored with the row.
            score: Optional score stored with the row.
            expires_at: Explicit expiry. If None, it is derived from
                ``cache_type``; unknown types never expire.

        Returns:
            The id of the inserted row.

        Raises:
            Any database error from the INSERT, after rolling back.

        After a successful commit, rows beyond the per-App retention limit are
        pruned. Pruning failures are only logged.
        """
        if expires_at is None:
            expires_at = self._calc_expires_at(cache_type)

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_cache
                        (cache_type, site_id, target_id, result_json,
                         triggered_by, score, expires_at, status)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, 'valid')
                    RETURNING id
                    """,
                    (
                        cache_type,
                        site_id,
                        target_id,
                        json.dumps(result_json, ensure_ascii=False),
                        triggered_by,
                        score,
                        expires_at,
                    ),
                )
                row = cur.fetchone()
            conn.commit()
            cache_id: int = row[0]
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

        # Retention is best-effort: the new row is already committed.
        self._cleanup_quietly(cache_type, site_id, target_id)
        return cache_id

    def set_generating(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        triggered_by: str | None = None,
    ) -> int:
        """Insert a ``'generating'`` placeholder row and return its id.

        The placeholder has ``result_json = '{}'`` and no expiry. Call
        ``finalize_cache`` with the returned id once the result is ready.
        If that never happens, the row stays ``'generating'``. ``get_latest``
        and ``get_history`` ignore such rows.

        Raises:
            Any database error from the INSERT, after rolling back.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_cache
                        (cache_type, site_id, target_id, result_json, status, triggered_by)
                    VALUES (%s, %s, %s, '{}', 'generating', %s)
                    RETURNING id
                    """,
                    (cache_type, site_id, target_id, triggered_by),
                )
                row = cur.fetchone()
            conn.commit()
            return row[0]
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def finalize_cache(
        self,
        cache_id: int,
        result_json: dict,
        score: int | None = None,
        cache_type: str | None = None,
    ) -> None:
        """Turn a ``'generating'`` row into a ``'valid'`` one.

        Fills in the result, the score and an expiry time. The expiry is
        computed from ``cache_type``. When ``cache_type`` is omitted, the row's
        own stored ``cache_type`` is used, so the row still receives the
        correct per-App expiry.

        Only rows currently in ``'generating'`` are updated. If there is no such
        row (unknown id, already finalized, or already pruned), a warning is
        logged and nothing changes. After a successful update, per-App
        retention cleanup runs as best-effort, as in ``write_cache``.

        Raises:
            Any database error from the lookup/UPDATE, after rolling back.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                if cache_type is None:
                    cur.execute(
                        "SELECT cache_type FROM biz.ai_cache WHERE id = %s",
                        (cache_id,),
                    )
                    found = cur.fetchone()
                    if found is not None:
                        cache_type = found[0]
                expires_at = self._calc_expires_at(cache_type)
                cur.execute(
                    """
                    UPDATE biz.ai_cache
                    SET result_json = %s,
                        score = %s,
                        status = 'valid',
                        expires_at = %s
                    WHERE id = %s AND status = 'generating'
                    RETURNING cache_type, site_id, target_id
                    """,
                    (
                        json.dumps(result_json, ensure_ascii=False),
                        score,
                        expires_at,
                        cache_id,
                    ),
                )
                updated = cur.fetchone()
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

        if updated is None:
            logger.warning(
                "finalize_cache 未找到 generating 状态记录: cache_id=%s", cache_id
            )
            return
        row_cache_type, row_site_id, row_target_id = updated
        self._cleanup_quietly(row_cache_type, row_site_id, row_target_id)

    @staticmethod
    def _calc_expires_at(cache_type: str | None) -> datetime | None:
        """Compute the UTC expiry for ``cache_type``.

        Returns None (never expires) when ``cache_type`` is None or not listed
        in CACHE_EXPIRY_DAYS. A value of 0 days means 23:59:59 today in UTC+8.
        The result is a timezone-aware UTC datetime.
        """
        if cache_type is None:
            return None
        days = CACHE_EXPIRY_DAYS.get(cache_type)
        if days is None:
            return None

        now = datetime.now(timezone.utc)
        if days == 0:
            local_now = now.astimezone(_LOCAL_TZ)
            end_of_day = local_now.replace(hour=23, minute=59, second=59, microsecond=0)
            return end_of_day.astimezone(timezone.utc)
        return now + timedelta(days=days)

    def _cleanup_quietly(self, cache_type: str, site_id: int, target_id: str) -> None:
        """Run per-App retention cleanup; log instead of raising on failure.

        ``target_id`` only identifies, in the logs, the write that triggered the
        cleanup. The cleanup itself covers the whole App within the site.
        """
        try:
            deleted = self._cleanup_excess(cache_type, site_id)
        except Exception:
            logger.warning(
                "清理超限缓存失败: cache_type=%s site_id=%s target_id=%s",
                cache_type,
                site_id,
                target_id,
                exc_info=True,
            )
            return
        if deleted > 0:
            logger.info(
                "清理超限缓存: cache_type=%s site_id=%s target_id=%s 删除=%d",
                cache_type,
                site_id,
                target_id,
                deleted,
            )

    def _cleanup_excess(
        self,
        cache_type: str,
        site_id: int,
        target_id: str | None = None,
        max_count: int = CACHE_MAX_PER_APP,
    ) -> int:
        """Delete rows older than the newest ``max_count`` and return how many were removed.

        By default the window is one App (``cache_type``) within one site. This
        matches the documented "20,000 per App" policy. Passing ``target_id``
        narrows the window to a single target, which was the previous behavior.

        NOTE(review): runs on every write and orders all of the App's rows for
        the site. This assumes an index covering
        (cache_type, site_id, created_at). TODO: confirm one exists.

        Raises:
            Any database error from the DELETE, after rolling back.
        """
        # Both statements are fixed literals; all values are bound as parameters.
        # ``id DESC`` breaks created_at ties so OFFSET always keeps the same rows.
        if target_id is None:
            sql = """
                DELETE FROM biz.ai_cache
                WHERE id IN (
                    SELECT id FROM biz.ai_cache
                    WHERE cache_type = %s
                      AND site_id = %s
                    ORDER BY created_at DESC, id DESC
                    OFFSET %s
                )
            """
            params: tuple = (cache_type, site_id, max_count)
        else:
            sql = """
                DELETE FROM biz.ai_cache
                WHERE id IN (
                    SELECT id FROM biz.ai_cache
                    WHERE cache_type = %s
                      AND site_id = %s
                      AND target_id = %s
                    ORDER BY created_at DESC, id DESC
                    OFFSET %s
                )
            """
            params = (cache_type, site_id, target_id, max_count)

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params)
                deleted = cur.rowcount
            conn.commit()
            return deleted
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()


def _row_to_dict(columns: list[str], row: tuple) -> dict:
    """Pair a DB row with its column names, rendering datetimes as ISO-8601 strings."""
    return {
        col: val.isoformat() if isinstance(val, datetime) else val
        for col, val in zip(columns, row)
    }