#!/usr/bin/env python3
"""batch_generate_summaries — batch-generate LLM summaries for session logs.

Finds every main-conversation entry in the dual index that is missing a
description, calls the Bailian Qwen API concurrently to generate summaries,
and writes them back to the index.

Usage:
    python -B scripts/ops/batch_generate_summaries.py                  # all entries
    python -B scripts/ops/batch_generate_summaries.py --limit 10       # first 10 only
    python -B scripts/ops/batch_generate_summaries.py --concurrency 5  # 5 concurrent calls
    python -B scripts/ops/batch_generate_summaries.py --dry-run        # preview
"""

import asyncio
import os
import re
import sys
import time

from dotenv import load_dotenv

load_dotenv()

BAILIAN_API_KEY = os.environ.get("BAILIAN_API_KEY", "")
if not BAILIAN_API_KEY:
    raise RuntimeError("BAILIAN_API_KEY is not set; check the .env file")

MODEL_NAME = os.environ.get("BAILIAN_MODEL", "qwen-plus")
BASE_URL = os.environ.get(
    "BAILIAN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"
)

# Make the sibling extractor module importable, then pull in the
# index-management helpers.
sys.path.insert(0, os.path.dirname(__file__))
from extract_kiro_session import (
    load_index,
    save_index,
    load_full_index,
    save_full_index,
)

SYSTEM_PROMPT = """You are a professional technical-conversation analyst. Your task is to write a concise Chinese summary for one execution round of an AI coding assistant.

Background: a conversation (chatSession) contains multiple execution rounds. One round = the user sends one message → the AI finishes its response. You receive the complete record of a single round.

Summary rules:
1. Describe only the work actually completed in this round; do not describe historical background
2. List the features/tasks completed (one round may complete several)
3. Include key technical details: file paths, module names, database tables, API endpoints, etc.
4. For bug fixes, state the cause and the solution
5. No process narration ("the user said..."); state results only
6. If the record is too short or has no substance, write "无实质内容" (no substantive content)
7. No length limit; completeness of information comes first, avoid lossy truncation

Important:
- The "执行摘要" (execution summary, 📋) section is the most reliable source; judge what this round did from it first
- If the "用户输入" (user input) section contains CONTEXT TRANSFER, that is a digest of earlier rounds, not this round's work
- The actual tool calls and file changes in the conversation record are this round's real operations

Output the summary directly, without any prefix or explanation."""


def _extract_summary_content(md_content: str) -> str:
    """Detect CONTEXT TRANSFER and replace the user-input block with a short note.

    The "## 2. 用户输入" / "## 3." headings are literals from the exported
    markdown files, so they stay in Chinese here.
    """
    ct_pattern = re.compile(
        r"## 2\. 用户输入\s*\n```\s*\n.*?CONTEXT TRANSFER", re.DOTALL
    )
    if ct_pattern.search(md_content):
        md_content = re.sub(
            r"(## 2\. 用户输入)\s*\n```[\s\S]*?```\s*\n(?=## 3\.)",
            # Chinese note: "this round continues via Context Transfer; omitted."
            r"\1\n\n[本轮为 Context Transfer 续接,已省略。]\n\n",
            md_content,
        )
    return md_content


# Placeholder marker ("summary pending") that the extractor writes into entries.
PLACEHOLDER_PREFIX = "[待生成摘要]"


def collect_targets(index: dict, include_placeholder: bool = False) -> list[tuple[str, dict]]:
    """Collect all main-conversation entries that are missing a description.

    With include_placeholder=True, entries carrying the placeholder marker are
    included as well (so real summaries can overwrite them).
    """
    targets = []
    for eid, ent in index.get("entries", {}).items():
        if ent.get("is_sub"):
            continue
        # Skip superseded entries and entries without a session log.
        if ent.get("superseded_by") or ent.get("no_log"):
            continue
        desc = ent.get("description", "")
        if not desc:
            targets.append((eid, ent))
        elif include_placeholder and desc.startswith(PLACEHOLDER_PREFIX):
            targets.append((eid, ent))
    # Sort by startTime, oldest first.
    targets.sort(key=lambda t: t[1].get("startTime", ""))
    return targets


def load_md_content(eid: str, entry: dict) -> str | None:
    """Load the main_*.md content belonging to an entry."""
    out_dir = entry.get("output_dir", "")
    if not out_dir or not os.path.isdir(out_dir):
        return None
    eid_short = eid[:8]
    main_files = sorted(
        f for f in os.listdir(out_dir)
        if f.startswith("main_") and f.endswith(".md") and eid_short in f
    )
    if not main_files:
        # Fall back to every main_*.md in the directory.
        main_files = sorted(
            f for f in os.listdir(out_dir)
            if f.startswith("main_") and f.endswith(".md")
        )
    if not main_files:
        return None
    parts = []
    for mf in main_files:
        try:
            with open(os.path.join(out_dir, mf), "r", encoding="utf-8") as fh:
                parts.append(fh.read())
        except Exception:
            continue
    if not parts:
        return None
    content = "\n\n---\n\n".join(parts)
    content = _extract_summary_content(content)
    # Hard cap to keep the request within the model's context budget.
    if len(content) > 60000:
        content = content[:60000] + "\n\n[TRUNCATED]"
    return content


async def generate_one(
    client,
    content: str,
    semaphore: asyncio.Semaphore,
    max_retries: int = 3,
) -> str:
    """Call the Bailian API to generate one summary, with throttling and exponential backoff."""
    async with semaphore:
        for attempt in range(max_retries):
            try:
                resp = await client.chat.completions.create(
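                    # Standard OpenAI-style chat.completions arguments; DashScope's
                    # compatible-mode endpoint accepts these unchanged. Temperature
                    # is left at the service default.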
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": f"Generate a summary for the following single-round execution record:\n\n{content}"},
                    ],
                    max_tokens=4096,
                )
                return resp.choices[0].message.content.strip()
            except Exception as e:
                if attempt < max_retries - 1:
                    wait = 2 ** attempt  # 1s, 2s, 4s, ...
                    print(f"  ⏳ retry {attempt + 1}: {e}", file=sys.stderr)
                    await asyncio.sleep(wait)
                else:
                    return ""
    return ""


async def process_target(
    client,
    eid: str,
    entry: dict,
    semaphore: asyncio.Semaphore,
) -> tuple[str, str]:
    """Process one target; returns (eid, description). An empty description means failure."""
    content = load_md_content(eid, entry)
    if not content:
        return (eid, "")
    desc = await generate_one(client, content, semaphore)
    return (eid, desc)


async def main():
    import argparse

    from openai import AsyncOpenAI

    parser = argparse.ArgumentParser(description="Batch-generate session summaries")
    parser.add_argument("--limit", type=int, default=0,
                        help="process only the first N entries (0 = all)")
    parser.add_argument("--concurrency", type=int, default=10,
                        help="number of concurrent API calls (default 10)")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="entries per batch; a checkpoint is saved after each batch (default 20)")
    parser.add_argument("--replace-placeholder", action="store_true",
                        help="also process entries marked with the placeholder '[待生成摘要]'")
    parser.add_argument("--dry-run", action="store_true",
                        help="preview the pending list without calling the API")
    args = parser.parse_args()

    index = load_index()
    full_index = load_full_index()

    targets = collect_targets(index, include_placeholder=args.replace_placeholder)
    if args.limit > 0:
        targets = targets[:args.limit]

    print(f"Pending: {len(targets)} main-conversation entries missing a description")
    if not targets:
        print("Every entry already has a summary; nothing to do")
        return

    if args.dry_run:
        for eid, ent in targets[:20]:
            st = ent.get("startTime", "?")
            od = ent.get("output_dir", "?")
            print(f"  {eid[:8]} | {st} | {od}")
        if len(targets) > 20:
            print(f"  ... {len(targets) - 20} more")
        return

    client = AsyncOpenAI(api_key=BAILIAN_API_KEY, base_url=BASE_URL)
    semaphore = asyncio.Semaphore(args.concurrency)

    t0 = time.time()
    total_generated = 0
    total_failed = 0
    batch_size = args.batch_size

    # Process in batches, saving a checkpoint after each batch.
    for batch_start in range(0, len(targets), batch_size):
        batch = targets[batch_start:batch_start + batch_size]
        batch_num = batch_start // batch_size + 1
        total_batches = (len(targets) + batch_size - 1) // batch_size
        print(f"\n📦 Batch {batch_num}/{total_batches} ({len(batch)} entries)...")

        tasks = [process_target(client, eid, ent, semaphore) for eid, ent in batch]
        results = await asyncio.gather(*tasks)

        # Write results back to both indexes.
        batch_generated = 0
        full_entries = full_index.get("entries", {})
        idx_entries = index.get("entries", {})
        for eid, desc in results:
            if not desc:
                total_failed += 1
                continue
            if eid in idx_entries:
                idx_entries[eid]["description"] = desc
            if eid in full_entries:
                full_entries[eid]["description"] = desc
            batch_generated += 1
        total_generated += batch_generated

        # Checkpoint: persist immediately after each batch.
        if batch_generated > 0:
            save_index(index)
            save_full_index(full_index)
        elapsed = time.time() - t0
        print(f"  ✅ Batch saved: {batch_generated} entries "
              f"(total {total_generated}, {elapsed:.1f}s elapsed)")

    elapsed = time.time() - t0
    print(f"\nDone: {total_generated}/{len(targets)} generated, "
          f"{total_failed} failed, {elapsed:.1f}s elapsed")


if __name__ == "__main__":
    asyncio.run(main())
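

# ---------------------------------------------------------------------------
# Illustrative shape of one index entry, inferred only from the fields this
# script reads; the authoritative schema lives in extract_kiro_session.py:
#
#   "entries": {
#       "<execution-id>": {
#           "is_sub": false,           # sub-executions are skipped
#           "superseded_by": null,     # set when a later entry replaces this one
#           "no_log": false,           # true when no session log was captured
#           "startTime": "<ISO ts>",   # targets are sorted by this, oldest first
#           "output_dir": "...",       # directory holding the main_*.md exports
#           "description": ""          # the summary this script fills in
#       }
#   }
# ---------------------------------------------------------------------------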