Files
Neo-ZQYY/apps/backend/app/ai/cache_service.py
Neo 6f8f12314f feat: 累积功能变更 — 聊天集成、租户管理、小程序更新、ETL 增强、迁移脚本
包含多个会话的累积代码变更:
- backend: AI 聊天服务、触发器调度、认证增强、WebSocket、调度器最小间隔
- admin-web: ETL 状态页、任务管理、调度配置、登录优化
- miniprogram: 看板页面、聊天集成、UI 组件、导航更新
- etl: DWS 新任务(finance_area_daily/board_cache)、连接器增强
- tenant-admin: 项目初始化
- db: 19 个迁移脚本(etl_feiqiu 11 + zqyy_app 8)
- packages/shared: 枚举和工具函数更新
- tools: 数据库工具、报表生成、健康检查
- docs: PRD/架构/部署/合约文档更新

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-06 00:03:48 +08:00

296 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI 缓存读写服务。
负责 biz.ai_cache 表的 CRUD 和保留策略管理。
所有查询和写入操作强制 site_id 隔离。
P14 改造:
- 新增 status 字段,处理 valid/expired/invalidated/generating 四种状态
- 查询仅返回 status='valid' 且未过期的记录
- 按 App 类型设置过期时间
- 每 App 保留最新 20,000 条
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timedelta, timezone
from app.database import get_connection
logger = logging.getLogger(__name__)
# Expiry policy: cache_type -> lifetime in days. A value of 0 means the
# entry expires at 23:59:59 of the *current* day (UTC+8) rather than
# after a rolling window — see AICacheService._calc_expires_at.
CACHE_EXPIRY_DAYS: dict[str, int] = {
    "app2_finance": 0,  # end of current day, 23:59:59 (UTC+8)
    "app3_clue": 7,
    "app4_analysis": 7,
    "app5_tactics": 7,
    "app6_note_analysis": 30,
    "app7_customer_analysis": 7,
    "app8_clue_consolidated": 7,
}
# Retention cap: newest rows kept per (cache_type, site_id, target_id) key;
# enforced best-effort by AICacheService._cleanup_excess after each write.
CACHE_MAX_PER_APP = 20_000
class AICacheService:
    """Read/write service for AI result caching (``biz.ai_cache``).

    Every query and write is scoped by ``site_id`` for tenant isolation.
    Row lifecycle (P14):
      - rows carry a status: valid / expired / invalidated / generating
      - reads return only status='valid' (or legacy NULL) non-expired rows
      - expiry is derived from cache_type (see CACHE_EXPIRY_DAYS)
      - at most CACHE_MAX_PER_APP rows are retained per key
    """

    def get_latest(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
    ) -> dict | None:
        """Return the newest usable cache record, or None if absent.

        A record is usable when its status is 'valid' (rows with a NULL
        status — presumably written before the status column existed —
        are also accepted) and it has not passed ``expires_at``.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, cache_type, site_id, target_id,
                           result_json, score, triggered_by,
                           created_at, expires_at, status
                    FROM biz.ai_cache
                    WHERE cache_type = %s AND site_id = %s AND target_id = %s
                      AND (status = 'valid' OR status IS NULL)
                      AND (expires_at IS NULL OR expires_at > now())
                    ORDER BY created_at DESC
                    LIMIT 1
                    """,
                    (cache_type, site_id, target_id),
                )
                # Column names come from the cursor metadata so the dict
                # keys always match the SELECT list.
                columns = [desc[0] for desc in cur.description]
                row = cur.fetchone()
                if row is None:
                    return None
                return _row_to_dict(columns, row)
        finally:
            conn.close()

    def get_history(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        limit: int = 2,
    ) -> list[dict]:
        """Return up to ``limit`` past cache records, newest first.

        Used to build prompt reference context. Unlike get_latest, no
        status or expiry filter is applied, so expired/invalidated rows
        are included on purpose. Returns an empty list when none match.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, cache_type, site_id, target_id,
                           result_json, score, triggered_by,
                           created_at, expires_at
                    FROM biz.ai_cache
                    WHERE cache_type = %s AND site_id = %s AND target_id = %s
                    ORDER BY created_at DESC
                    LIMIT %s
                    """,
                    (cache_type, site_id, target_id, limit),
                )
                columns = [desc[0] for desc in cur.description]
                rows = cur.fetchall()
                return [_row_to_dict(columns, row) for row in rows]
        finally:
            conn.close()

    def write_cache(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        result_json: dict,
        triggered_by: str | None = None,
        score: int | None = None,
        expires_at: datetime | None = None,
    ) -> int:
        """Insert a cache record with status='valid' and return its id.

        When ``expires_at`` is not given it is derived from ``cache_type``
        via _calc_expires_at. After a successful commit, rows beyond the
        per-app retention cap (CACHE_MAX_PER_APP) are pruned on a
        best-effort basis: a pruning failure is logged, never raised.
        """
        # Derive the expiry only when the caller did not pin one explicitly.
        if expires_at is None:
            expires_at = self._calc_expires_at(cache_type)
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_cache
                        (cache_type, site_id, target_id, result_json,
                         triggered_by, score, expires_at, status)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, 'valid')
                    RETURNING id
                    """,
                    (
                        cache_type,
                        site_id,
                        target_id,
                        json.dumps(result_json, ensure_ascii=False),
                        triggered_by,
                        score,
                        expires_at,
                    ),
                )
                row = cur.fetchone()
                conn.commit()
                cache_id: int = row[0]
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()
        # Prune only after the insert has committed and its connection is
        # closed; cleanup uses its own connection and must never undo or
        # fail the write — errors here are logged and swallowed.
        try:
            deleted = self._cleanup_excess(cache_type, site_id, target_id)
            if deleted > 0:
                logger.info(
                    "清理超限缓存: cache_type=%s site_id=%s target_id=%s 删除=%d",
                    cache_type, site_id, target_id, deleted,
                )
        except Exception:
            logger.warning(
                "清理超限缓存失败: cache_type=%s site_id=%s target_id=%s",
                cache_type, site_id, target_id,
                exc_info=True,
            )
        return cache_id

    def set_generating(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        triggered_by: str | None = None,
    ) -> int:
        """Insert a 'generating' placeholder row and return its id.

        The row holds an empty JSON object and no expiry; call
        finalize_cache with the returned id once generation completes.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO biz.ai_cache
                        (cache_type, site_id, target_id, result_json, status, triggered_by)
                    VALUES (%s, %s, %s, '{}', 'generating', %s)
                    RETURNING id
                    """,
                    (cache_type, site_id, target_id, triggered_by),
                )
                row = cur.fetchone()
                conn.commit()
                return row[0]
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def finalize_cache(
        self,
        cache_id: int,
        result_json: dict,
        score: int | None = None,
        cache_type: str | None = None,
    ) -> None:
        """Promote a 'generating' row to 'valid', filling result and expiry.

        Only a row still in status='generating' is touched; a row already
        finalized (or missing) is silently left as-is. When ``cache_type``
        is omitted, ``expires_at`` is set to NULL (never expires).
        """
        expires_at = self._calc_expires_at(cache_type) if cache_type else None
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE biz.ai_cache
                    SET result_json = %s, score = %s, status = 'valid', expires_at = %s
                    WHERE id = %s AND status = 'generating'
                    """,
                    (
                        json.dumps(result_json, ensure_ascii=False),
                        score,
                        expires_at,
                        cache_id,
                    ),
                )
                conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _calc_expires_at(cache_type: str | None) -> datetime | None:
        """Compute the UTC expiry timestamp for ``cache_type``.

        Returns None (no expiry) for a missing or unknown cache_type.
        A configured value of 0 days means "end of the current day in
        UTC+8": shift to the fixed +8h offset, clamp to 23:59:59, then
        shift back to UTC (no DST concerns for a fixed offset).
        """
        if cache_type is None:
            return None
        days = CACHE_EXPIRY_DAYS.get(cache_type)
        if days is None:
            return None
        now = datetime.now(timezone.utc)
        if days == 0:
            # End of today, 23:59:59 in UTC+8.
            local_now = now + timedelta(hours=8)
            end_of_day = local_now.replace(hour=23, minute=59, second=59, microsecond=0)
            return end_of_day - timedelta(hours=8)  # convert back to UTC
        return now + timedelta(days=days)

    def _cleanup_excess(
        self,
        cache_type: str,
        site_id: int,
        target_id: str,
        max_count: int = CACHE_MAX_PER_APP,
    ) -> int:
        """Delete rows beyond the newest ``max_count`` for this key.

        Returns the number of rows deleted (0 when under the cap).
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Everything past OFFSET max_count in created_at DESC order
                # (i.e. the oldest rows) is deleted; OFFSET without LIMIT
                # is valid PostgreSQL.
                cur.execute(
                    """
                    DELETE FROM biz.ai_cache
                    WHERE id IN (
                        SELECT id FROM biz.ai_cache
                        WHERE cache_type = %s AND site_id = %s AND target_id = %s
                        ORDER BY created_at DESC
                        OFFSET %s
                    )
                    """,
                    (cache_type, site_id, target_id, max_count),
                )
                deleted = cur.rowcount
                conn.commit()
                return deleted
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()
def _row_to_dict(columns: list[str], row: tuple) -> dict:
"""将数据库行转换为 dict处理特殊类型序列化。"""
result = {}
for col, val in zip(columns, row):
if isinstance(val, datetime):
result[col] = val.isoformat()
else:
result[col] = val
return result