Neo-ZQYY/apps/etl/connectors/feiqiu/pipeline/unified_pipeline.py
# -*- coding: utf-8 -*-
"""统一管道引擎:串行请求 + 异步处理 + 单线程写库。
核心执行流程:
主线程_request_loop串行发送 API 请求 → processing_queue
→ N 个 worker 线程_process_worker并行处理 → write_queue
→ 1 个 writer 线程_write_worker批量写入数据库
线程安全保证:
- PipelineResult 的计数更新通过 threading.Lock 保护
- 队列通信使用 queue.Queue内置线程安全
- SENTINELNone用于通知线程退出
"""
from __future__ import annotations
import logging
import queue
import threading
import time
from typing import Any, Callable, Iterable
from api.rate_limiter import RateLimiter
from config.pipeline_config import PipelineConfig
from utils.cancellation import CancellationToken
from pipeline.models import PipelineRequest, PipelineResult, WriteResult
# Runtime metrics logging interval (queue depth etc. is logged once every N requests)
_METRICS_LOG_INTERVAL = 10
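
# PipelineConfig fields consumed by this module: workers, queue_size, batch_size,
# batch_timeout, max_consecutive_failures, and rate_min/rate_max (for RateLimiter).
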
class UnifiedPipeline:
"""统一管道引擎:串行请求 + 异步处理 + 单线程写库。
Args:
api_client: API 客户端duck typing需有 post 方法)
db_connection: 数据库连接duck typing
logger: 日志记录器
config: 管道配置
cancel_token: 取消令牌None 时自动创建一个不会取消的令牌
etl_timer: 可选的 EtlTimer 实例,用于在 FlowRunner 计时报告中记录阶段耗时
task_code: 任务代码,与 etl_timer 配合使用作为步骤名前缀
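
    Example (minimal sketch; the api/db objects, config and the request/processing
    callables are assumed to be supplied by the caller, e.g. a BaseOdsTask):

        pipeline = UnifiedPipeline(api_client, db_connection, logger, config)
        result = pipeline.run(requests, process_fn, write_fn)
        if result.status != "SUCCESS":
            logger.warning("pipeline finished with status %s", result.status)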
"""
def __init__(
self,
        api_client,  # duck typing: must have a post(endpoint, params) method
db_connection, # duck typing
logger: logging.Logger,
config: PipelineConfig,
cancel_token: CancellationToken | None = None,
        etl_timer=None,  # optional EtlTimer (duck typing)
task_code: str | None = None,
) -> None:
self.api = api_client
self.db = db_connection
self.logger = logger
self.config = config
self.cancel_token = cancel_token or CancellationToken()
self._rate_limiter = RateLimiter(config.rate_min, config.rate_max)
self._etl_timer = etl_timer
self._task_code = task_code
        # Lock guarding concurrent updates to the PipelineResult counters
self._lock = threading.Lock()
        # References to the worker threads, used to count active threads in the runtime metrics log
self._workers: list[threading.Thread] = []
def run(
self,
requests: Iterable[PipelineRequest],
process_fn: Callable[[Any], list[dict]],
write_fn: Callable[[list[dict]], WriteResult],
) -> PipelineResult:
"""执行管道。
Args:
requests: 请求迭代器(由 BaseOdsTask 生成)
process_fn: 处理函数,将 API 响应转换为待写入记录列表
write_fn: 写入函数,将记录批量写入数据库
Returns:
PipelineResult 包含各阶段统计和最终状态
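
        Example (sketch only; the response shape and the db.upsert helper are
        illustrative assumptions, not part of this module):

            def process_fn(response) -> list[dict]:
                return response.get("data", [])

            def write_fn(records: list[dict]) -> WriteResult:
                return db.upsert("ods_example_table", records)

            result = pipeline.run(requests, process_fn, write_fn)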
"""
        # Pre-cancellation check: if cancel_token is already cancelled, return an empty result immediately
if self.cancel_token.is_cancelled:
return PipelineResult(status="CANCELLED", cancelled=True)
processing_queue: queue.Queue = queue.Queue(
maxsize=self.config.queue_size,
)
write_queue: queue.Queue = queue.Queue(
maxsize=self.config.queue_size * 2,
)
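        # Note: write_queue holds twice as many items as processing_queue (presumably
        # to give the N workers some headroom ahead of the single writer thread);
        # _write_worker warns once this capacity is reached.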
result = PipelineResult()
        # Keep queue references for the runtime metrics logging in _request_loop
self._processing_queue = processing_queue
self._write_queue = write_queue
start_time = time.monotonic()
        # EtlTimer integration: record per-phase sub-steps (request / process / write)
timer = self._etl_timer
step_name = self._task_code
        # Start N processing worker threads
self._workers = []
for i in range(self.config.workers):
t = threading.Thread(
target=self._process_worker,
args=(processing_queue, write_queue, process_fn, result),
name=f"pipeline-worker-{i}",
daemon=True,
)
t.start()
self._workers.append(t)
        # Start a single writer thread
writer = threading.Thread(
target=self._write_worker,
args=(write_queue, write_fn, result),
name="pipeline-writer",
daemon=True,
)
writer.start()
        # Main thread: issue requests serially
if timer and step_name:
try:
timer.start_sub_step(step_name, "request")
except KeyError:
                pass  # silently skip when the parent step does not exist
request_start = time.monotonic()
self._request_loop(requests, processing_queue, result)
request_elapsed = time.monotonic() - request_start
if timer and step_name:
try:
timer.stop_sub_step(step_name, "request")
except KeyError:
pass
        # Put one SENTINEL (None) per worker on the processing queue so every worker exits
if timer and step_name:
try:
timer.start_sub_step(step_name, "process")
except KeyError:
pass
process_start = time.monotonic()
for _ in self._workers:
processing_queue.put(None)
for w in self._workers:
w.join()
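        # At this point every worker has consumed its sentinel and exited; anything
        # they produced is either already written or still sitting in write_queue.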
process_elapsed = time.monotonic() - process_start
if timer and step_name:
try:
timer.stop_sub_step(step_name, "process")
except KeyError:
pass
        # Send the SENTINEL (None) to the write queue so the writer exits
if timer and step_name:
try:
timer.start_sub_step(step_name, "write")
except KeyError:
pass
write_start = time.monotonic()
write_queue.put(None)
writer.join()
write_elapsed = time.monotonic() - write_start
if timer and step_name:
try:
timer.stop_sub_step(step_name, "write")
except KeyError:
pass
total_elapsed = time.monotonic() - start_time
result.timing["total"] = round(total_elapsed, 3)
result.timing["request"] = round(request_elapsed, 3)
result.timing["process"] = round(process_elapsed, 3)
result.timing["write"] = round(write_elapsed, 3)
        # Determine the final status
if result.cancelled:
result.status = "CANCELLED"
elif result.status == "FAILED":
            pass  # FAILED was already set by the consecutive-failure guard; keep it
elif (
result.request_failures
+ result.processing_failures
+ result.write_failures
> 0
):
result.status = "PARTIAL"
else:
result.status = "SUCCESS"
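        # Net precedence of the branches above: CANCELLED > FAILED > PARTIAL > SUCCESS.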
        # Execution summary log (requirement 8.2)
self.logger.info(
"管道执行摘要: status=%s, 总耗时=%.1fs "
"[请求=%.1fs, 处理=%.1fs, 写入=%.1fs], "
"请求=%d/%d, 获取=%d, "
"写入(inserted=%d, updated=%d, skipped=%d), "
"失败(request=%d, process=%d, write=%d)",
result.status,
total_elapsed,
request_elapsed,
process_elapsed,
write_elapsed,
result.completed_requests,
result.total_requests,
result.total_fetched,
result.total_inserted,
result.total_updated,
result.total_skipped,
result.request_failures,
result.processing_failures,
result.write_failures,
)
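        # Example rendered line (illustrative numbers only):
        #   Pipeline summary: status=SUCCESS, total=42.3s [request=30.1s, process=8.0s, write=4.2s],
        #   requests=120/120, fetched=58000, writes(inserted=57000, updated=800, skipped=200),
        #   failures(request=0, process=0, write=0)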
        # Clear the queue references
self._processing_queue = None
self._write_queue = None
self._workers = []
return result
def _request_loop(
self,
requests: Iterable[PipelineRequest],
processing_queue: queue.Queue,
result: PipelineResult,
) -> None:
"""主线程:串行发送 API 请求,限流等待,背压阻塞。
流程:
1. 遍历 requests 迭代器
2. 检查取消信号
3. 调用 api.post() 发送请求
4. 将响应 put 到 processing_queue满时阻塞 = 背压)
5. 调用 rate_limiter.wait(),被取消则 break
6. 连续失败超过阈值则中断status=FAILED
"""
consecutive_failures = 0
for req in requests:
            # Cancellation check
if self.cancel_token.is_cancelled:
with self._lock:
result.cancelled = True
                self.logger.info("Cancellation signal received; stopping new requests")
break
with self._lock:
result.total_requests += 1
req_start = time.monotonic()
try:
                # Prefetch mode: iter_paginated already fetched the data; use it directly
if req._prefetched_response is not None:
response = req._prefetched_response
else:
response = self.api.post(req.endpoint, req.params)
elapsed = time.monotonic() - req_start
self.logger.debug(
"请求完成: endpoint=%s, 耗时=%.2fs",
req.endpoint,
elapsed,
)
                # Put the response on the processing queue (blocks when full = backpressure)
processing_queue.put((req, response))
with self._lock:
result.completed_requests += 1
completed = result.completed_requests
total = result.total_requests
                # Reset the consecutive-failure counter after a success
consecutive_failures = 0
                # Runtime metrics log (requirement 8.1): record queue depth and progress every N requests
if completed % _METRICS_LOG_INTERVAL == 0:
self._log_runtime_metrics(result, completed, total)
except Exception as exc:
elapsed = time.monotonic() - req_start
consecutive_failures += 1
self.logger.error(
"请求失败: endpoint=%s, 耗时=%.2fs, 错误=%s",
req.endpoint,
elapsed,
exc,
)
with self._lock:
result.request_failures += 1
result.errors.append({
"phase": "request",
"endpoint": req.endpoint,
"error": str(exc),
})
                # Abort once consecutive failures exceed the threshold
if consecutive_failures >= self.config.max_consecutive_failures:
self.logger.error(
"连续失败 %d 次,超过阈值 %d,中断管道",
consecutive_failures,
self.config.max_consecutive_failures,
)
with self._lock:
result.status = "FAILED"
break
            # Rate-limit wait (also after the last request, to keep the spacing towards upstream consistent)
if not self._rate_limiter.wait(self.cancel_token.event):
with self._lock:
result.cancelled = True
                self.logger.info("Cancellation received during rate-limit wait; stopping new requests")
break
def _process_worker(
self,
processing_queue: queue.Queue,
write_queue: queue.Queue,
process_fn: Callable[[Any], list[dict]],
result: PipelineResult,
) -> None:
"""处理线程:从 processing_queue 消费数据,调用 process_fn结果放入 write_queue。
收到 SENTINELNone时退出。
单条记录处理异常时捕获、记录错误、继续处理。
"""
while True:
item = processing_queue.get()
            # SENTINEL: exit signal
if item is None:
processing_queue.task_done()
break
req, response = item
try:
records = process_fn(response)
if records:
                    # Put the processed records on the write queue
write_queue.put(records)
with self._lock:
result.total_fetched += len(records)
except Exception as exc:
self.logger.error(
"处理失败: endpoint=%s, 错误=%s",
req.endpoint,
exc,
)
with self._lock:
result.processing_failures += 1
result.errors.append({
"phase": "processing",
"endpoint": req.endpoint,
"error": str(exc),
})
processing_queue.task_done()
def _write_worker(
self,
write_queue: queue.Queue,
write_fn: Callable[[list[dict]], WriteResult],
result: PipelineResult,
) -> None:
"""写入线程:从 write_queue 消费数据,累积到 batch_size 或超时后批量写入。
- 累积到 batch_size 条记录时立即写入
- 等待 batch_timeout 秒后将已累积的记录写入(即使不足 batch_size
- 写入失败时记录错误、继续处理后续批次
- 收到 SENTINELNone时将剩余数据 flush 后退出
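
        Worked example (illustrative numbers): with batch_size=500 and three
        process_fn outputs of 300, 300 and 100 records, the first flush happens
        once the running total reaches 500 (300 + 300 → flush 500, carry 100);
        the remaining 200 are flushed on the next batch_timeout or at shutdown.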
"""
batch: list[dict] = []
batch_size = self.config.batch_size
batch_timeout = self.config.batch_timeout
while True:
try:
item = write_queue.get(timeout=batch_timeout)
except queue.Empty:
                # Timeout: flush whatever has accumulated so far
if batch:
self._flush_batch(batch, write_fn, result)
batch = []
continue
            # SENTINEL: exit signal
if item is None:
write_queue.task_done()
break
            # item is a list[dict] (the output of one process_fn call)
batch.extend(item)
write_queue.task_done()
            # Write-queue backlog warning
qsize = write_queue.qsize()
if qsize >= self.config.queue_size * 2:
self.logger.warning(
"写入队列积压: qsize=%d, 阈值=%d",
qsize,
self.config.queue_size * 2,
)
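            # Note: the threshold equals write_queue's maxsize (queue_size * 2), so
            # this warning only fires when the queue is completely full and workers
            # will block on their next put().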
            # Flush in batch_size chunks once enough records have accumulated
            # (a single process_fn output may exceed batch_size, hence the loop)
while len(batch) >= batch_size:
chunk = batch[:batch_size]
batch = batch[batch_size:]
self._flush_batch(chunk, write_fn, result)
        # Flush any remaining records before exiting
if batch:
self._flush_batch(batch, write_fn, result)
def _flush_batch(
self,
batch: list[dict],
write_fn: Callable[[list[dict]], WriteResult],
result: PipelineResult,
) -> None:
"""执行一次批量写入,更新结果计数。"""
if not batch:
return
try:
wr = write_fn(batch)
with self._lock:
result.total_inserted += wr.inserted
result.total_updated += wr.updated
result.total_skipped += wr.skipped
except Exception as exc:
self.logger.error(
"批量写入失败: batch_size=%d, 错误=%s",
len(batch),
exc,
)
with self._lock:
result.write_failures += 1
result.errors.append({
"phase": "write",
"batch_size": len(batch),
"error": str(exc),
})
def _log_runtime_metrics(
self,
result: PipelineResult,
completed: int,
total: int,
) -> None:
"""记录运行时指标:队列深度、活跃线程数、进度(需求 8.1)。"""
pq_depth = self._processing_queue.qsize() if self._processing_queue else 0
wq_depth = self._write_queue.qsize() if self._write_queue else 0
active_workers = sum(1 for w in self._workers if w.is_alive())
self.logger.debug(
"运行时指标: 进度=%d/%d, 处理队列=%d, 活跃线程=%d, 写入队列=%d",
completed,
total,
pq_depth,
active_workers,
wq_depth,
)