改 相对路径 完成客户端
This commit is contained in:
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from types import SimpleNamespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -99,16 +101,87 @@ def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
|
||||
class FakeCursor:
|
||||
"""极简游标桩对象,记录 SQL/参数并支持上下文管理,供 FakeDBOperations 与 SCD2Handler 使用。"""
|
||||
|
||||
def __init__(self, recorder: List[Dict], db_ops=None):
    """Bind the shared statement recorder and, optionally, the owning fake DB ops.

    Args:
        recorder: list shared with the connection; every execute() appends
            a {"sql": ..., "params": ...} dict here.
        db_ops: owner whose ``upserts`` list receives decoded batch inserts
            (see _record_upserts); None disables that logging.
    """
    self.recorder = recorder
    self._db_ops = db_ops
    # Rows staged by mogrify(); flushed by the next execute().
    self._pending_rows: List[Tuple] = []
    # Result set served by fetchall(); populated by information_schema handling.
    self._fetchall_rows: List[Tuple] = []
    self.rowcount = 0
    # Code under test reads cursor.connection.encoding (psycopg-style) —
    # NOTE(review): UTF8 stub; confirm against the real connection object.
    self.connection = SimpleNamespace(encoding="UTF8")
|
||||
|
||||
# pylint: disable=unused-argument
def execute(self, sql: str, params=None):
    """Record the statement, then emulate the few query shapes the code under test issues.

    information_schema lookups get canned rows; a batch staged via mogrify()
    is flushed into the db_ops upsert log; anything else is a no-op with
    rowcount 0.
    """
    if isinstance(sql, (bytes, bytearray)):
        sql_text = sql.decode("utf-8", errors="ignore")
    else:
        sql_text = str(sql)
    self.recorder.append({"sql": sql_text.strip(), "params": params})
    self._fetchall_rows = []

    # Canned answers for schema-introspection queries (schema-aware inserts).
    lowered = sql_text.lower()
    if "from information_schema.columns" in lowered:
        table_name = params[1] if params and len(params) >= 2 else None
        self._fetchall_rows = self._fake_columns(table_name)
        return
    if "from information_schema.table_constraints" in lowered:
        # No constraints reported; _fetchall_rows was already cleared above.
        return

    if not self._pending_rows:
        self.rowcount = 0
        return
    # Batched insert: report the staged row count, log it, clear the buffer.
    self.rowcount = len(self._pending_rows)
    self._record_upserts(sql_text)
    self._pending_rows = []
|
||||
|
||||
def fetchone(self):
    """Always report no single-row result; tests drive data through fetchall()."""
    return None
|
||||
|
||||
def fetchall(self):
    """Return a shallow copy of the rows staged by the last execute()."""
    return list(self._fetchall_rows)
|
||||
|
||||
def mogrify(self, template, args):
    """Stage one row of values for the next execute(); return a dummy placeholder.

    The template is ignored — this stub only cares about capturing the values.
    """
    self._pending_rows.append(tuple(args))
    return b"(?)"
|
||||
|
||||
def _record_upserts(self, sql_text: str):
    """Decode rows staged by mogrify() into column dicts and log them on db_ops.upserts."""
    if not self._db_ops:
        return
    # Extract the column list from "INSERT INTO <table> (<cols>) VALUES ...".
    header = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
    if not header:
        return
    columns = [name.strip().strip('"') for name in header.group(1).split(",")]

    decoded_rows = []
    for position, values in enumerate(self._pending_rows):
        if len(values) != len(columns):
            # Row shape doesn't match the column list — skip rather than guess.
            continue
        entry = {}
        for column, value in zip(columns, values):
            if column == "record_index" and value in (None, ""):
                # Backfill a missing record_index from the row's position.
                entry[column] = position
            elif hasattr(value, "adapted"):
                # psycopg Json-style wrapper: store the serialized payload.
                entry[column] = json.dumps(value.adapted, ensure_ascii=False)
            else:
                entry[column] = value
        decoded_rows.append(entry)

    if decoded_rows:
        self._db_ops.upserts.append(
            {
                "sql": sql_text.strip(),
                "count": len(decoded_rows),
                "page_size": len(decoded_rows),
                "rows": decoded_rows,
            }
        )
|
||||
|
||||
@staticmethod
def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
    """Canned (column_name, data_type, udt_name) rows, mimicking what a real
    information_schema.columns query would return for the target table.

    The table name is ignored — every table gets the same fixed schema.
    """
    return [
        ("id", "bigint", "int8"),
        ("sitegoodsstockid", "bigint", "int8"),
        ("record_index", "integer", "int4"),
        ("source_file", "text", "text"),
        ("source_endpoint", "text", "text"),
        ("fetched_at", "timestamp with time zone", "timestamptz"),
        ("payload", "jsonb", "jsonb"),
    ]
|
||||
|
||||
def __enter__(self):
    """Context-manager entry: hand back the cursor itself."""
    return self
|
||||
|
||||
@@ -119,11 +192,12 @@ class FakeCursor:
|
||||
class FakeConnection:
    """Minimal psycopg-like connection: hands out FakeCursor objects and keeps
    a shared log of every statement they execute."""

    def __init__(self, db_ops):
        # Shared statement log; every cursor created here appends into it.
        self.statements: List[Dict] = []
        # Owning fake DB ops, forwarded to cursors so they can log upserts.
        self._db_ops = db_ops

    def cursor(self):
        """Return a fresh stub cursor bound to the shared statement log."""
        return FakeCursor(self.statements, self._db_ops)
|
||||
|
||||
|
||||
class FakeDBOperations:
|
||||
@@ -134,7 +208,7 @@ class FakeDBOperations:
|
||||
self.executes: List[Dict] = []
|
||||
self.commits = 0
|
||||
self.rollbacks = 0
|
||||
self.conn = FakeConnection()
|
||||
self.conn = FakeConnection(self)
|
||||
# Pre-seeded query results (FIFO) to let tests control DB-returned rows
|
||||
self.query_results: List[List[Dict]] = []
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from config.defaults import DEFAULTS
|
||||
|
||||
def test_config_load():
    """Loading with a minimal override still exposes default values."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    assert cfg.get("app.timezone") == DEFAULTS["app"]["timezone"]
|
||||
|
||||
def test_config_override():
|
||||
@@ -19,6 +19,6 @@ def test_config_override():
|
||||
|
||||
def test_config_get_nested():
    """Dotted-path lookups resolve nested keys; unknown paths yield the fallback."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    assert cfg.get("db.batch_size") == 1000
    assert cfg.get("nonexistent.key", "default") == "default"
|
||||
|
||||
Reference in New Issue
Block a user