#!/usr/bin/env python3
"""batch_generate_summaries — batch-generate LLM summaries for session logs.

Finds every main-conversation entry in the dual index that is missing a
description, calls the Bailian Qwen API concurrently to generate summaries,
and writes the results back to the index.

Usage:
    python -B scripts/ops/batch_generate_summaries.py                  # full run
    python -B scripts/ops/batch_generate_summaries.py --limit 10       # first 10 entries only
    python -B scripts/ops/batch_generate_summaries.py --concurrency 5  # 5 concurrent calls
    python -B scripts/ops/batch_generate_summaries.py --dry-run        # preview
"""
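
# Third-party dependencies: python-dotenv (load_dotenv) and the OpenAI SDK v1+
# (AsyncOpenAI). The Bailian endpoint below speaks the OpenAI-compatible
# chat-completions protocol, which is why the stock OpenAI client works here.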

import asyncio
import os
import re
import sys
import time

from dotenv import load_dotenv

load_dotenv()

BAILIAN_API_KEY = os.environ.get("BAILIAN_API_KEY", "")
if not BAILIAN_API_KEY:
    raise RuntimeError("BAILIAN_API_KEY is not set; check the .env file")
MODEL_NAME = os.environ.get("BAILIAN_MODEL", "qwen-plus")
BASE_URL = os.environ.get("BAILIAN_BASE_URL",
                          "https://dashscope.aliyuncs.com/compatible-mode/v1")

# Import the index-management helpers (the extractor lives in the same directory).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from extract_kiro_session import (
    load_index, save_index, load_full_index, save_full_index,
)
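
# Shape of an index entry as assumed by this script (inferred from the fields
# read below; extract_kiro_session owns the authoritative schema):
#   index["entries"][eid] = {
#       "description":   str,   # LLM summary; empty (or a placeholder) when pending
#       "is_sub":        bool,  # sub-conversation entries are skipped
#       "superseded_by": str,   # set when a later entry replaces this one
#       "no_log":        bool,  # placeholder entry without a session log
#       "output_dir":    str,   # directory holding the extracted main_*.md files
#       "startTime":     str,   # ISO-ish timestamp, used for oldest-first ordering
#   }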

SYSTEM_PROMPT = """You are a professional analyst of technical conversations. Your task is to write a concise Chinese summary for one execution round of an AI coding assistant.

Background: one conversation (chatSession) contains multiple execution rounds. One round = the user sends one message → the AI completes its response. You receive the full record of a single round.

Summary rules:
1. Describe only the work actually completed in this round; do not recount historical background.
2. List the features/tasks completed (one round may complete several).
3. Include key technical details: file paths, module names, database tables, API endpoints, etc.
4. For bug fixes, state the cause and the fix.
5. No process narration ("the user said ..."); write results only.
6. If the record is too short or has no substance, write "无实质内容" ("no substantive content").
7. No length limit; prioritize complete information over brevity to avoid lossy truncation.

Important:
- The "execution summary" (📋) is the most reliable source; judge what this round did primarily from it.
- If the "user input" contains CONTEXT TRANSFER, that is a digest of earlier rounds, not this round's work.
- The actual tool calls and file changes in the transcript are this round's real operations.

Output the summary directly, with no prefix and no explanation."""


def _extract_summary_content(md_content: str) -> str:
    """If the user input is a CONTEXT TRANSFER block, replace it with a short note.

    The Chinese literals below are kept as-is on purpose: the regexes must
    match the section headers emitted by extract_kiro_session ("## 2. 用户输入"
    is the "user input" section), and the inserted note ("this round continues
    via Context Transfer; omitted") is read by the Chinese-language model.
    """
    ct_pattern = re.compile(
        r"## 2\. 用户输入\s*\n```\s*\n.*?CONTEXT TRANSFER", re.DOTALL
    )
    if ct_pattern.search(md_content):
        md_content = re.sub(
            r"(## 2\. 用户输入)\s*\n```[\s\S]*?```\s*\n(?=## 3\.)",
            r"\1\n\n[本轮为 Context Transfer 续接,已省略。]\n\n",
            md_content,
        )
    return md_content
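
# Illustrative input/output for _extract_summary_content (hypothetical record):
#   "## 2. 用户输入\n```\nCONTEXT TRANSFER: ...history...\n```\n## 3. ..."
# becomes
#   "## 2. 用户输入\n\n[本轮为 Context Transfer 续接,已省略。]\n\n## 3. ..."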


PLACEHOLDER_PREFIX = "[待生成摘要]"  # placeholder marker, "summary pending"


def collect_targets(index: dict, include_placeholder: bool = False) -> list[tuple[str, dict]]:
    """Collect every main-conversation entry that is missing a description.

    With include_placeholder=True, entries carrying the placeholder marker
    are also included (so real summaries can overwrite them).
    """
    targets = []
    for eid, ent in index.get("entries", {}).items():
        if ent.get("is_sub"):
            continue
        # Skip superseded entries and log-less placeholder entries.
        if ent.get("superseded_by") or ent.get("no_log"):
            continue
        desc = ent.get("description", "")
        if not desc:
            targets.append((eid, ent))
        elif include_placeholder and desc.startswith(PLACEHOLDER_PREFIX):
            targets.append((eid, ent))
    # Sort by startTime so the oldest entries are processed first.
    targets.sort(key=lambda t: t[1].get("startTime", ""))
    return targets
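
# Illustrative return value (hypothetical IDs and timestamps):
#   [("a1b2c3d4-...", {"startTime": "2025-01-02T03:04:05", ...}),
#    ("e5f6a7b8-...", {"startTime": "2025-01-03T10:00:00", ...})]
# Oldest-first ordering keeps re-runs deterministic after a checkpointed stop.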


def load_md_content(eid: str, entry: dict) -> str | None:
    """Load the combined content of the entry's main_*.md files."""
    out_dir = entry.get("output_dir", "")
    if not out_dir or not os.path.isdir(out_dir):
        return None
    eid_short = eid[:8]
    # Prefer files whose names carry this entry's short ID ...
    main_files = sorted(
        f for f in os.listdir(out_dir)
        if f.startswith("main_") and f.endswith(".md") and eid_short in f
    )
    # ... and fall back to every main_*.md in the directory.
    if not main_files:
        main_files = sorted(
            f for f in os.listdir(out_dir)
            if f.startswith("main_") and f.endswith(".md")
        )
    if not main_files:
        return None
    parts = []
    for mf in main_files:
        try:
            with open(os.path.join(out_dir, mf), "r", encoding="utf-8") as fh:
                parts.append(fh.read())
        except Exception:
            continue
    if not parts:
        return None
    content = "\n\n---\n\n".join(parts)
    content = _extract_summary_content(content)
    if len(content) > 60000:
        content = content[:60000] + "\n\n[TRUNCATED]"
    return content
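
# The 60,000-character cap is a conservative guess at what fits in the model's
# context window (an assumption, not a measured qwen-plus limit); the trailing
# "[TRUNCATED]" marker tells the model that the record was cut off.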


async def generate_one(
    client,
    content: str,
    semaphore: asyncio.Semaphore,
    max_retries: int = 3,
) -> str:
    """Call the Bailian API to generate one summary, with throttling and exponential backoff."""
    async with semaphore:
        for attempt in range(max_retries):
            try:
                resp = await client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user",
                         "content": f"Generate a summary for the following single-round execution record:\n\n{content}"},
                    ],
                    max_tokens=4096,
                )
                return resp.choices[0].message.content.strip()
            except Exception as e:
                if attempt < max_retries - 1:
                    wait = 2 ** attempt
                    print(f"  ⏳ retry {attempt + 1}: {e}", file=sys.stderr)
                    await asyncio.sleep(wait)
                else:
                    return ""
    return ""
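
# With the default max_retries=3, the backoff waits are 1 s and 2 s
# (2 ** attempt for attempts 0 and 1); a final failure returns "" instead of
# raising, and main() counts the empty result as a failure.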


async def process_target(
    client,
    eid: str,
    entry: dict,
    semaphore: asyncio.Semaphore,
) -> tuple[str, str]:
    """Process a single target; returns (eid, description), with "" on failure."""
    content = load_md_content(eid, entry)
    if not content:
        return (eid, "")
    desc = await generate_one(client, content, semaphore)
    return (eid, desc)


async def main():
    import argparse
    from openai import AsyncOpenAI  # deferred import: only needed for real runs

    parser = argparse.ArgumentParser(description="Batch-generate session summaries")
    parser.add_argument("--limit", type=int, default=0,
                        help="process only the first N entries (0 = all)")
    parser.add_argument("--concurrency", type=int, default=10,
                        help="number of concurrent API calls (default: 10)")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="entries per batch; a checkpoint is saved after each batch (default: 20)")
    parser.add_argument("--replace-placeholder", action="store_true",
                        help="also process entries marked with the placeholder '[待生成摘要]'")
    parser.add_argument("--dry-run", action="store_true",
                        help="preview the pending list without calling the API")
    args = parser.parse_args()

    index = load_index()
    full_index = load_full_index()
    targets = collect_targets(index, include_placeholder=args.replace_placeholder)

    if args.limit > 0:
        targets = targets[:args.limit]

    print(f"Pending: {len(targets)} main-conversation entries missing a description")

    if not targets:
        print("All entries already have summaries; nothing to do")
        return

    if args.dry_run:
        for eid, ent in targets[:20]:
            st = ent.get("startTime", "?")
            od = ent.get("output_dir", "?")
            print(f"  {eid[:8]} | {st} | {od}")
        if len(targets) > 20:
            print(f"  ... plus {len(targets) - 20} more")
        return

    client = AsyncOpenAI(api_key=BAILIAN_API_KEY, base_url=BASE_URL)
    semaphore = asyncio.Semaphore(args.concurrency)

    t0 = time.time()
    total_generated = 0
    total_failed = 0
    batch_size = args.batch_size

    # Process in batches, saving a checkpoint after each one; gather() waits
    # for the slowest call in a batch before the next batch starts.
    for batch_start in range(0, len(targets), batch_size):
        batch = targets[batch_start:batch_start + batch_size]
        batch_num = batch_start // batch_size + 1
        total_batches = (len(targets) + batch_size - 1) // batch_size  # ceil division
        print(f"\n📦 Batch {batch_num}/{total_batches} ({len(batch)} entries)...")

        tasks = [process_target(client, eid, ent, semaphore)
                 for eid, ent in batch]
        results = await asyncio.gather(*tasks)

        # Write the results back into both indexes.
        batch_generated = 0
        full_entries = full_index.get("entries", {})
        idx_entries = index.get("entries", {})
        for eid, desc in results:
            if not desc:
                total_failed += 1
                continue
            if eid in idx_entries:
                idx_entries[eid]["description"] = desc
            if eid in full_entries:
                full_entries[eid]["description"] = desc
            batch_generated += 1

        total_generated += batch_generated

        # Checkpoint: persist both indexes as soon as the batch is done.
        if batch_generated > 0:
            save_index(index)
            save_full_index(full_index)
            elapsed = time.time() - t0
            print(f"  ✅ Batch saved: {batch_generated} entries "
                  f"(cumulative {total_generated}, {elapsed:.1f}s elapsed)")

    elapsed = time.time() - t0
    print(f"\nDone: {total_generated}/{len(targets)} generated, "
          f"{total_failed} failed, {elapsed:.1f}s elapsed")


if __name__ == "__main__":
    asyncio.run(main())