Files
Neo-ZQYY/scripts/ops/batch_generate_summaries.py

266 lines
9.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""batch_generate_summaries — batch-generate LLM summaries for session logs.

Finds every main-conversation entry in the dual index that is missing a
``description``, calls the Bailian Qwen API concurrently to generate a
summary, and writes the results back into the index.

Usage:
    python -B scripts/ops/batch_generate_summaries.py                  # everything
    python -B scripts/ops/batch_generate_summaries.py --limit 10       # first 10 only
    python -B scripts/ops/batch_generate_summaries.py --concurrency 5  # 5 workers
    python -B scripts/ops/batch_generate_summaries.py --dry-run        # preview
"""
import asyncio
import json
import os
import re
import sys
import time

from dotenv import load_dotenv

load_dotenv()

# API credentials and endpoint come from the environment (.env).  Fail fast
# at import time when the key is missing so the script never runs
# half-configured.
BAILIAN_API_KEY = os.environ.get("BAILIAN_API_KEY", "")
if not BAILIAN_API_KEY:
    raise RuntimeError("BAILIAN_API_KEY 未设置,请检查 .env 文件")
MODEL_NAME = os.environ.get("BAILIAN_MODEL", "qwen-plus")
BASE_URL = os.environ.get(
    "BAILIAN_BASE_URL",
    "https://dashscope.aliyuncs.com/compatible-mode/v1",
)

# Make the sibling extract_kiro_session module importable regardless of the
# working directory.  (os.path.join() with a single argument was a no-op
# wrapper around os.path.dirname — dropped.)
sys.path.insert(0, os.path.dirname(__file__))
from extract_kiro_session import (
    load_index, save_index, load_full_index, save_full_index,
)
# System prompt for the summarizer LLM.  Deliberately written in Chinese:
# both the session transcripts and the desired summaries are Chinese.  It
# instructs the model to summarize exactly one execution round, to trust the
# "执行摘要" (execution summary) section as the primary source, and to ignore
# CONTEXT TRANSFER payloads, which are carried-over history from earlier
# rounds rather than work done in this round.
SYSTEM_PROMPT = """你是一个专业的技术对话分析师。你的任务是为 AI 编程助手的一轮执行execution生成简洁的中文摘要。
背景一个对话chatSession包含多轮执行execution。每轮执行 = 用户发一条消息 → AI 完成响应。你收到的是单轮执行的完整记录。
摘要规则:
1. 只描述本轮执行实际完成的工作,不要描述历史背景
2. 列出完成的功能点/任务(一轮可能完成多个)
3. 包含关键技术细节文件路径、模块名、数据库表、API 端点等
4. bug 修复要说明原因和方案
5. 不写过程性描述("用户说..."),只写结果
6. 内容太短或无实质内容的,写"无实质内容"
7. 不限字数,信息完整优先,避免截断失真
重要:
- "执行摘要"(📋)是最可靠的信息源,优先基于它判断本轮做了什么
- 如果"用户输入"包含 CONTEXT TRANSFER那是之前多轮的历史摘要不是本轮工作
- 对话记录中的实际工具调用和文件变更才是本轮的真实操作
请直接输出摘要,不要添加任何前缀或解释。"""
def _extract_summary_content(md_content: str) -> str:
"""检测 CONTEXT TRANSFER替换用户输入为简短标注。"""
ct_pattern = re.compile(
r"## 2\. 用户输入\s*\n```\s*\n.*?CONTEXT TRANSFER", re.DOTALL
)
if ct_pattern.search(md_content):
md_content = re.sub(
r"(## 2\. 用户输入)\s*\n```[\s\S]*?```\s*\n(?=## 3\.)",
r"\1\n\n[本轮为 Context Transfer 续接,已省略。]\n\n",
md_content,
)
return md_content
PLACEHOLDER_PREFIX = "[待生成摘要]"


def collect_targets(index: dict, include_placeholder: bool = False) -> list[tuple[str, dict]]:
    """Collect the main-conversation entries that still need a description.

    Sub-conversation entries are ignored, as are placeholder entries that
    were superseded or have no log.  With ``include_placeholder=True``,
    entries whose description starts with the "[待生成摘要]" marker are
    picked up again so a real summary can overwrite the placeholder.

    Args:
        index: The loaded index dict with an ``entries`` mapping.
        include_placeholder: Also reprocess placeholder-marked entries.

    Returns:
        ``(entry_id, entry)`` pairs ordered by ``startTime`` ascending
        (oldest first).
    """
    def _needs_summary(entry: dict) -> bool:
        # Main conversations only; skip superseded / log-less placeholders.
        if entry.get("is_sub"):
            return False
        if entry.get("superseded_by") or entry.get("no_log"):
            return False
        description = entry.get("description", "")
        if not description:
            return True
        return include_placeholder and description.startswith(PLACEHOLDER_PREFIX)

    pending = [
        (entry_id, entry)
        for entry_id, entry in index.get("entries", {}).items()
        if _needs_summary(entry)
    ]
    return sorted(pending, key=lambda pair: pair[1].get("startTime", ""))
def load_md_content(eid: str, entry: dict) -> str | None:
    """Load and concatenate the main_*.md transcript(s) for *entry*.

    Files whose name contains the first 8 characters of the entry id are
    preferred; otherwise every main_*.md in the output directory is used.
    Multiple files are joined with a "---" divider, CONTEXT TRANSFER
    sections are collapsed, and the result is truncated to 60 000 chars to
    bound prompt size.

    Args:
        eid: Entry id; its 8-char prefix is matched against file names.
        entry: Index entry; ``output_dir`` locates the transcripts.

    Returns:
        The prepared markdown text, or ``None`` when the output directory
        is missing or no transcript file could be read.
    """
    out_dir = entry.get("output_dir", "")
    if not out_dir or not os.path.isdir(out_dir):
        return None

    # Scan the directory once; both the id-filtered selection and the
    # fallback are derived from the same listing (was two os.listdir calls).
    all_md = [
        f for f in os.listdir(out_dir)
        if f.startswith("main_") and f.endswith(".md")
    ]
    eid_short = eid[:8]
    main_files = sorted(f for f in all_md if eid_short in f) or sorted(all_md)
    if not main_files:
        return None

    parts = []
    for mf in main_files:
        try:
            with open(os.path.join(out_dir, mf), "r", encoding="utf-8") as fh:
                parts.append(fh.read())
        except Exception:
            # Best-effort: an unreadable file is skipped, not fatal.
            continue
    if not parts:
        return None

    content = _extract_summary_content("\n\n---\n\n".join(parts))
    if len(content) > 60000:
        content = content[:60000] + "\n\n[TRUNCATED]"
    return content
async def generate_one(
    client,
    content: str,
    semaphore: asyncio.Semaphore,
    max_retries: int = 3,
) -> str:
    """Ask the Bailian API for a summary of one execution transcript.

    Concurrency is throttled by *semaphore*; failures are retried up to
    *max_retries* times with exponential backoff (1s, 2s, ...).

    Args:
        client: AsyncOpenAI-compatible client.
        content: Markdown transcript of a single execution round.
        semaphore: Limits the number of in-flight API calls.
        max_retries: Total attempts before giving up.

    Returns:
        The stripped summary text, or "" when every attempt failed.
    """
    async with semaphore:
        for attempt in range(max_retries):
            try:
                response = await client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user",
                         "content": f"请为以下单轮执行记录生成摘要:\n\n{content}"},
                    ],
                    max_tokens=4096,
                )
                return response.choices[0].message.content.strip()
            except Exception as exc:
                # Out of attempts: give up; caller counts "" as a failure.
                if attempt >= max_retries - 1:
                    return ""
                print(f" ⏳ 重试 {attempt+1}: {exc}", file=sys.stderr)
                await asyncio.sleep(2 ** attempt)
    return ""
async def process_target(
    client,
    eid: str,
    entry: dict,
    semaphore: asyncio.Semaphore,
) -> tuple[str, str]:
    """Summarize a single index entry.

    Returns:
        ``(eid, description)`` — the description is "" when the transcript
        is missing or the API call ultimately failed.
    """
    transcript = load_md_content(eid, entry)
    if not transcript:
        return (eid, "")
    return (eid, await generate_one(client, transcript, semaphore))
async def main():
    """CLI entry point: batch-generate summaries with per-batch checkpoints.

    Loads both indexes, collects the pending entries, then processes them in
    batches of ``--batch-size`` with ``--concurrency`` parallel API calls.
    Both indexes are saved after every batch so an interrupted run loses at
    most one batch of work.
    """
    import argparse
    from openai import AsyncOpenAI

    parser = argparse.ArgumentParser(description="批量生成 session 摘要")
    # NOTE: the fullwidth parentheses in the help/progress strings below were
    # unbalanced (garbled characters) — restored to matched pairs.
    parser.add_argument("--limit", type=int, default=0,
                        help="只处理前 N 条(0=全量)")
    parser.add_argument("--concurrency", type=int, default=10,
                        help="并发数(默认 10)")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="每批处理条数,每批完成后保存 checkpoint(默认 20)")
    parser.add_argument("--replace-placeholder", action="store_true",
                        help="也处理占位标记 '[待生成摘要]' 的 entry")
    parser.add_argument("--dry-run", action="store_true",
                        help="预览待处理列表")
    args = parser.parse_args()

    index = load_index()
    full_index = load_full_index()
    targets = collect_targets(index, include_placeholder=args.replace_placeholder)
    if args.limit > 0:
        targets = targets[:args.limit]
    print(f"待处理: {len(targets)} 条缺少 description 的主对话 entry")
    if not targets:
        print("全部已有摘要,无需处理")
        return

    if args.dry_run:
        # Preview mode: list at most the first 20 pending entries, no writes.
        for eid, ent in targets[:20]:
            st = ent.get("startTime", "?")
            od = ent.get("output_dir", "?")
            print(f" {eid[:8]} | {st} | {od}")
        if len(targets) > 20:
            print(f" ... 还有 {len(targets) - 20}")
        return

    client = AsyncOpenAI(api_key=BAILIAN_API_KEY, base_url=BASE_URL)
    semaphore = asyncio.Semaphore(args.concurrency)
    t0 = time.time()
    total_generated = 0
    total_failed = 0
    batch_size = args.batch_size

    # Batched processing with a checkpoint save after each batch.
    for batch_start in range(0, len(targets), batch_size):
        batch = targets[batch_start:batch_start + batch_size]
        batch_num = batch_start // batch_size + 1
        total_batches = (len(targets) + batch_size - 1) // batch_size
        print(f"\n📦 批次 {batch_num}/{total_batches}({len(batch)} 条)...")
        tasks = [process_target(client, eid, ent, semaphore)
                 for eid, ent in batch]
        results = await asyncio.gather(*tasks)

        # Write successful summaries back into BOTH indexes; "" means the
        # entry failed (no transcript or API error) and is only counted.
        batch_generated = 0
        full_entries = full_index.get("entries", {})
        idx_entries = index.get("entries", {})
        for eid, desc in results:
            if not desc:
                total_failed += 1
                continue
            if eid in idx_entries:
                idx_entries[eid]["description"] = desc
            if eid in full_entries:
                full_entries[eid]["description"] = desc
            batch_generated += 1
        total_generated += batch_generated

        # Checkpoint: persist immediately once the batch produced anything.
        if batch_generated > 0:
            save_index(index)
            save_full_index(full_index)
            elapsed = time.time() - t0
            print(f" ✅ 本批 {batch_generated} 条已保存(累计 {total_generated},耗时 {elapsed:.1f}s)")

    elapsed = time.time() - t0
    print(f"\n完成: {total_generated}/{len(targets)} 生成成功,"
          f"{total_failed} 失败,耗时 {elapsed:.1f}s")


if __name__ == "__main__":
    asyncio.run(main())