fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
replacement. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
This commit is contained in:
@@ -265,3 +265,249 @@ class TestBoundaries:
|
||||
json={"text": large_text, "file_ids": []},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── 用户 CRUD API ───────────────────────────────────────────────
|
||||
|
||||
class TestUserAPI:
    """Exercise the user CRUD endpoints under ``/api/users``.

    The ``client`` fixture requested by each test is defined outside this class.
    """

    @pytest.fixture(autouse=True)
    def temp_kb_data(self, monkeypatch, tmp_path):
        """Point ``backend.kb_manager`` storage paths at a per-test temp directory."""
        data_root = tmp_path / "kb_data"
        for target, value in (
            ("backend.kb_manager.KB_DATA_DIR", data_root),
            ("backend.kb_manager._USERS_FILE", data_root / "users.json"),
        ):
            monkeypatch.setattr(target, value)
        yield data_root

    def test_create_user(self, client):
        """POST returns 200, echoes the name, and yields a user_id of 12+ chars."""
        created = client.post("/api/users", json={"name": "测试用户"})
        assert created.status_code == 200
        payload = created.json()
        assert payload["name"] == "测试用户"
        assert len(payload["user_id"]) >= 12

    def test_create_user_empty_name_rejected(self, client):
        """An empty name is rejected with HTTP 400."""
        rejected = client.post("/api/users", json={"name": ""})
        assert rejected.status_code == 400

    def test_list_users(self, client):
        """Two created users both show up in the listing."""
        for user_name in ("A", "B"):
            client.post("/api/users", json={"name": user_name})
        listing = client.get("/api/users")
        assert listing.status_code == 200
        assert len(listing.json()["users"]) == 2

    def test_get_user(self, client):
        """A created user can be fetched back by its id."""
        created = client.post("/api/users", json={"name": "张三"})
        user_id = created.json()["user_id"]
        fetched = client.get(f"/api/users/{user_id}")
        assert fetched.status_code == 200
        assert fetched.json()["name"] == "张三"

    def test_get_user_not_found(self, client):
        """Fetching an unknown id yields HTTP 404."""
        missing = client.get("/api/users/deadbeef1234567890abcd")
        assert missing.status_code == 404

    def test_delete_user(self, client):
        """Deleting reports ``deleted`` and the user is no longer retrievable."""
        created = client.post("/api/users", json={"name": "待删除"})
        user_id = created.json()["user_id"]
        deleted = client.delete(f"/api/users/{user_id}")
        assert deleted.status_code == 200
        assert deleted.json()["status"] == "deleted"
        assert client.get(f"/api/users/{user_id}").status_code == 404

    def test_delete_nonexistent_user(self, client):
        """Deleting an unknown id yields HTTP 404."""
        missing = client.delete("/api/users/deadbeef1234567890abcd")
        assert missing.status_code == 404
|
||||
|
||||
|
||||
# ── 知识库 CRUD API ─────────────────────────────────────────────
|
||||
|
||||
class TestKbAPI:
    """Exercise the knowledge-base CRUD endpoints for a pre-created user.

    The ``client`` fixture requested by each test is defined outside this class.
    """

    @pytest.fixture(autouse=True)
    def setup_kb(self, monkeypatch, tmp_path):
        """Redirect KB storage to ``tmp_path`` and create the owning user.

        Stores the new user's id on ``self.uid`` for the tests to use.
        """
        data_root = tmp_path / "kb_data"
        monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", data_root)
        monkeypatch.setattr("backend.kb_manager._USERS_FILE", data_root / "users.json")
        # Use a raw TestClient to create the prerequisite user.
        from fastapi.testclient import TestClient as TC
        raw_client = TC(app)
        created = raw_client.post("/api/users", json={"name": "KB测试用户"})
        self.uid = created.json()["user_id"]

    def _new_kb(self, client, name):
        """Create a KB named *name* for ``self.uid`` and return its ``kb_id``."""
        created = client.post(f"/api/users/{self.uid}/kbs", json={"name": name})
        return created.json()["kb_id"]

    def test_create_kb(self, client):
        """A new KB echoes its name and starts with parse_status ``empty``."""
        created = client.post(
            f"/api/users/{self.uid}/kbs",
            json={"name": "测试库", "description": "描述"},
        )
        assert created.status_code == 200
        payload = created.json()
        assert payload["name"] == "测试库"
        assert payload["parse_status"] == "empty"

    def test_create_kb_empty_name_rejected(self, client):
        """An empty KB name is rejected with HTTP 400."""
        rejected = client.post(f"/api/users/{self.uid}/kbs", json={"name": ""})
        assert rejected.status_code == 400

    def test_list_kbs(self, client):
        """Two created KBs both show up in the user's listing."""
        for kb_name in ("KB1", "KB2"):
            client.post(f"/api/users/{self.uid}/kbs", json={"name": kb_name})
        listing = client.get(f"/api/users/{self.uid}/kbs")
        assert listing.status_code == 200
        assert len(listing.json()["kbs"]) == 2

    def test_get_kb(self, client):
        """A created KB can be fetched back by its id."""
        kb_id = self._new_kb(client, "查询库")
        fetched = client.get(f"/api/kbs/{kb_id}")
        assert fetched.status_code == 200
        assert fetched.json()["name"] == "查询库"

    def test_get_kb_not_found(self, client):
        """Fetching an unknown KB id yields HTTP 404."""
        missing = client.get("/api/kbs/deadbeef1234567890abcd")
        assert missing.status_code == 404

    def test_delete_kb(self, client):
        """Deleting a KB returns 200 with status ``deleted``."""
        kb_id = self._new_kb(client, "待删库")
        deleted = client.delete(f"/api/kbs/{kb_id}")
        assert deleted.status_code == 200
        assert deleted.json()["status"] == "deleted"

    def test_kb_status(self, client):
        """A fresh KB reports parse_status ``empty`` and zero files."""
        kb_id = self._new_kb(client, "状态库")
        status = client.get(f"/api/kbs/{kb_id}/status")
        assert status.status_code == 200
        assert status.json()["parse_status"] == "empty"
        assert status.json()["file_count"] == 0

    def test_kb_fields(self, client):
        """A fresh KB exposes empty ``fields`` and ``templates`` lists."""
        kb_id = self._new_kb(client, "字段库")
        fields = client.get(f"/api/kbs/{kb_id}/fields")
        assert fields.status_code == 200
        assert fields.json()["fields"] == []
        assert fields.json()["templates"] == []
|
||||
|
||||
|
||||
# ── KB 文件上传 & 构建 API ──────────────────────────────────────
|
||||
|
||||
class TestKbUploadBuild:
    """Exercise the KB upload, build and search endpoints.

    The ``client`` fixture requested by each test is defined outside this class.
    """

    @pytest.fixture(autouse=True)
    def setup_up(self, monkeypatch, tmp_path):
        """Redirect KB storage, stub the file parser, and create the owning user.

        Stores the new user's id on ``self.uid`` for the tests to use.
        """
        data_root = tmp_path / "kb_data"
        data_root.mkdir(parents=True, exist_ok=True)
        monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", data_root)
        monkeypatch.setattr("backend.kb_manager._USERS_FILE", data_root / "users.json")

        # Mock process_file_for_kb to avoid SameFileError (API already writes to raw_dir)
        def _fake_process(kb_id, file_path, source_name=""):
            return {"filename": source_name, "type": "txt", "error": None}

        monkeypatch.setattr("backend.kb_parser.process_file_for_kb", _fake_process)

        from fastapi.testclient import TestClient as TC
        raw_client = TC(app)
        created = raw_client.post("/api/users", json={"name": "上传测试用户"})
        self.uid = created.json()["user_id"]

    def _new_kb(self, client, name):
        """Create a KB named *name* for ``self.uid`` and return its ``kb_id``."""
        created = client.post(f"/api/users/{self.uid}/kbs", json={"name": name})
        return created.json()["kb_id"]

    def test_upload_to_kb(self, client):
        """Uploading a file returns 200 and echoes the filename."""
        kb_id = self._new_kb(client, "上传库")
        attachment = ("readme.md", io.BytesIO(b"# test"), "text/markdown")
        uploaded = client.post(f"/api/kbs/{kb_id}/upload", files={"file": attachment})
        assert uploaded.status_code == 200
        assert uploaded.json()["filename"] == "readme.md"

    def test_upload_to_nonexistent_kb(self, client):
        """Uploading to an unknown KB id yields HTTP 404."""
        attachment = ("x.txt", io.BytesIO(b"x"), "text/plain")
        uploaded = client.post(
            "/api/kbs/deadbeef1234567890abcd/upload",
            files={"file": attachment},
        )
        assert uploaded.status_code == 404

    def test_build_empty_kb_fails(self, client):
        """Building a KB with no uploaded files yields HTTP 400."""
        kb_id = self._new_kb(client, "空库")
        built = client.post(f"/api/kbs/{kb_id}/build")
        assert built.status_code == 400

    def test_search_kb_empty_query_rejected(self, client):
        """Searching without a query parameter yields HTTP 400."""
        kb_id = self._new_kb(client, "搜索库")
        searched = client.get(f"/api/kbs/{kb_id}/search")
        assert searched.status_code == 400
|
||||
|
||||
|
||||
# ── 会话-KB 绑定 API ────────────────────────────────────────────
|
||||
|
||||
class TestSessionKbBinding:
    """Exercise binding and unbinding a knowledge base on a session.

    The ``client`` fixture requested by each test is defined outside this class.
    """

    _UNKNOWN_ID = "deadbeef1234567890abcd"

    @pytest.fixture(autouse=True)
    def setup_bind(self, monkeypatch, tmp_path):
        """Redirect KB, session and upload directories into ``tmp_path``."""
        data_root = tmp_path / "kb_data"
        data_root.mkdir(parents=True, exist_ok=True)
        session_root = tmp_path / "sessions"
        redirects = (
            ("backend.kb_manager.KB_DATA_DIR", data_root),
            ("backend.kb_manager._USERS_FILE", data_root / "users.json"),
            ("backend.session.SESSIONS_DIR", session_root),
            ("api_server.UPLOADS_DIR", tmp_path / "uploads"),
        )
        for target, value in redirects:
            monkeypatch.setattr(target, value)

    @staticmethod
    def _new_user_kb(client, user_name, kb_name):
        """Create a user, then a KB owned by it; return the KB id."""
        user_id = client.post("/api/users", json={"name": user_name}).json()["user_id"]
        created = client.post(f"/api/users/{user_id}/kbs", json={"name": kb_name})
        return created.json()["kb_id"]

    @staticmethod
    def _new_session(client):
        """Create a session and return its id."""
        return client.post("/api/sessions").json()["session_id"]

    def test_bind_kb_to_session(self, client):
        """Binding an existing KB returns 200 and echoes the kb_id."""
        kb_id = self._new_user_kb(client, "绑定用户", "绑定库")
        session_id = self._new_session(client)
        bound = client.put(f"/api/sessions/{session_id}/kb", json={"kb_id": kb_id})
        assert bound.status_code == 200
        assert bound.json()["kb_id"] == kb_id

    def test_get_session_kb(self, client):
        """After binding, the session reports the KB's id and name."""
        kb_id = self._new_user_kb(client, "查询用户", "查询KB")
        session_id = self._new_session(client)
        client.put(f"/api/sessions/{session_id}/kb", json={"kb_id": kb_id})
        info = client.get(f"/api/sessions/{session_id}/kb")
        assert info.status_code == 200
        assert info.json()["kb_id"] == kb_id
        assert info.json()["kb_name"] == "查询KB"

    def test_unbind_kb(self, client):
        """Sending an empty kb_id returns 200 with ``kb_id`` set to None."""
        session_id = self._new_session(client)
        unbound = client.put(f"/api/sessions/{session_id}/kb", json={"kb_id": ""})
        assert unbound.status_code == 200
        assert unbound.json()["kb_id"] is None

    def test_bind_nonexistent_kb(self, client):
        """Binding an unknown KB id yields HTTP 404."""
        session_id = self._new_session(client)
        bound = client.put(
            f"/api/sessions/{session_id}/kb",
            json={"kb_id": "deadbeef1234567890abcd"},
        )
        assert bound.status_code == 404

    def test_bind_to_nonexistent_session(self, client):
        """Binding on an unknown session id yields HTTP 404."""
        bound = client.put("/api/sessions/deadbeef1234567890abcd/kb", json={"kb_id": ""})
        assert bound.status_code == 404
|
||||
|
||||
|
||||
# ── 用户-KB 端到端流程 ──────────────────────────────────────────
|
||||
|
||||
class TestUserKbE2E:
    """End-to-end flow: user -> KB -> upload -> session -> bind -> query.

    The ``client`` fixture requested by the test is defined outside this class.
    """

    @pytest.fixture(autouse=True)
    def setup_e2e(self, monkeypatch, tmp_path):
        """Redirect all storage directories into ``tmp_path`` and stub the file parser."""
        data_root = tmp_path / "kb_data"
        data_root.mkdir(parents=True, exist_ok=True)
        session_root = tmp_path / "sessions"
        redirects = (
            ("backend.kb_manager.KB_DATA_DIR", data_root),
            ("backend.kb_manager._USERS_FILE", data_root / "users.json"),
            ("backend.session.SESSIONS_DIR", session_root),
            ("api_server.UPLOADS_DIR", tmp_path / "uploads"),
        )
        for target, value in redirects:
            monkeypatch.setattr(target, value)

        # Mock process_file_for_kb to avoid SameFileError
        def _fake_process(kb_id, file_path, source_name=""):
            return {"filename": source_name, "type": "txt", "error": None}

        monkeypatch.setattr("backend.kb_parser.process_file_for_kb", _fake_process)

    def test_full_flow(self, client):
        """Walk the whole user/KB/session lifecycle and check the bound KB name."""
        # 1. Create the user.
        user_id = client.post("/api/users", json={"name": "E2E用户"}).json()["user_id"]
        # 2. Create the KB.
        kb_id = client.post(
            f"/api/users/{user_id}/kbs", json={"name": "E2E库"}
        ).json()["kb_id"]
        # 3. Upload a file.
        attachment = ("readme.md", io.BytesIO(b"# E2E test"), "text/markdown")
        uploaded = client.post(f"/api/kbs/{kb_id}/upload", files={"file": attachment})
        assert uploaded.status_code == 200
        # 4. Create a session.
        session_id = client.post("/api/sessions").json()["session_id"]
        # 5. Bind the KB to the session.
        bound = client.put(f"/api/sessions/{session_id}/kb", json={"kb_id": kb_id})
        assert bound.status_code == 200
        assert bound.json()["kb_id"] == kb_id
        # 6. Query the session's KB.
        session_kb = client.get(f"/api/sessions/{session_id}/kb")
        assert session_kb.json()["kb_name"] == "E2E库"
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
"""datasource.py 测试 — 数据源模式解析, JDBC 检测, 上下文构建。"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from agent.datasource import (
|
||||
resolve_datasource_mode, _detect_jdbc_intent,
|
||||
build_datasource_context, configure_jdbc, ask_db_config,
|
||||
)
|
||||
|
||||
|
||||
def _make_state(**overrides):
|
||||
s = {
|
||||
"user_input": "",
|
||||
"conversation_history": [],
|
||||
"current_jrxml": "",
|
||||
"status": "",
|
||||
"error_msg": "",
|
||||
"natural_explanation": "",
|
||||
"retry_count": 0,
|
||||
"user_modification_request": "",
|
||||
"final_jrxml": "",
|
||||
"stage": "",
|
||||
"retrieved_context": "",
|
||||
**overrides,
|
||||
}
|
||||
return s
|
||||
|
||||
|
||||
# ── JDBC 意图检测 ───────────────────────────────────────────────
|
||||
|
||||
class TestDetectJdbcIntent:
    """Checks which user phrases ``_detect_jdbc_intent`` flags as JDBC requests."""

    def test_direct_connect_keywords(self):
        """Phrases mentioning a direct database connection are detected."""
        for phrase in ("我想从数据库直连查询", "直连数据库获取数据"):
            assert _detect_jdbc_intent(phrase) is True

    def test_db_name_mentions(self):
        """Phrases naming MySQL, PostgreSQL or Oracle are detected."""
        phrases = (
            "从MySQL数据库查询用户表",
            "在PostgreSQL中执行查询",
            "从Oracle读取数据",
        )
        for phrase in phrases:
            assert _detect_jdbc_intent(phrase) is True

    def test_jdbc_explicit_mention(self):
        """An explicit mention of JDBC is detected."""
        detected = _detect_jdbc_intent("使用JDBC连接")
        assert detected is True

    def test_sql_keywords(self):
        """A raw SQL statement and database-query phrasings are detected."""
        phrases = (
            "SELECT * FROM users",
            "从数据库查询用户表",
            "先查询 数据库",
        )
        for phrase in phrases:
            assert _detect_jdbc_intent(phrase) is True

    def test_normal_request_is_not_jdbc(self):
        """Ordinary report requests are not flagged."""
        for phrase in ("生成一个员工报表", "修改标题为XX公司"):
            assert _detect_jdbc_intent(phrase) is False

    def test_empty_input(self):
        """The empty string is not flagged."""
        detected = _detect_jdbc_intent("")
        assert detected is False
|
||||
|
||||
|
||||
# ── 模式解析 ────────────────────────────────────────────────────
|
||||
|
||||
class TestResolveDatasourceMode:
    """Checks how ``resolve_datasource_mode`` combines stored mode and user input."""

    def test_defaults_to_parameter_mode(self):
        """Without a stored mode or JDBC wording the result is ``parameter``."""
        resolved = resolve_datasource_mode(_make_state(user_input="生成报表"))
        assert resolved == "parameter"

    def test_detects_jdbc_from_input(self):
        """JDBC wording in the input alone selects ``jdbc``."""
        resolved = resolve_datasource_mode(_make_state(user_input="从数据库直连查询"))
        assert resolved == "jdbc"

    def test_respects_existing_mode_in_state(self):
        """A stored ``jdbc`` mode is kept even when the input is neutral."""
        current = _make_state(datasource_mode="jdbc", user_input="生成报表")
        resolved = resolve_datasource_mode(current)
        assert resolved == "jdbc"

    def test_existing_parameter_overrides_jdbc_input(self):
        """A stored ``parameter`` mode wins over JDBC wording in the input."""
        current = _make_state(datasource_mode="parameter", user_input="从数据库直连")
        resolved = resolve_datasource_mode(current)
        assert resolved == "parameter"

    def test_ignores_invalid_mode_in_state(self):
        """An unrecognised stored mode is ignored and the input decides."""
        current = _make_state(datasource_mode="unknown", user_input="从数据库直连")
        resolved = resolve_datasource_mode(current)
        assert resolved == "jdbc"
|
||||
|
||||
|
||||
# ── 上下文构建 ──────────────────────────────────────────────────
|
||||
|
||||
class TestBuildDatasourceContext:
    """Checks the text fragments produced by ``build_datasource_context``."""

    def test_parameter_mode_with_fields(self):
        """Parameter mode mentions the mode tag, the ``$P{xxx}`` hint and each field name."""
        kb_fields = [
            {"name": "billNo", "description": "工单号", "type": "java.lang.String"},
            {"name": "amount", "description": "金额", "type": "java.math.BigDecimal"},
        ]
        context = build_datasource_context("parameter", kb_fields)
        for fragment in ("[数据源模式: 参数]", "$P{xxx}", "billNo", "amount"):
            assert fragment in context

    def test_parameter_mode_without_fields(self):
        """Parameter mode with no fields still carries the mode tag and hint."""
        context = build_datasource_context("parameter", [])
        for fragment in ("[数据源模式: 参数]", "$P{xxx}"):
            assert fragment in context

    def test_jdbc_mode_with_config(self):
        """JDBC mode with a config shows the mode tag, the URL scheme and ``CDATA``."""
        connection = {
            "url": "jdbc:mysql://localhost:3306/mydb",
            "driver": "com.mysql.cj.jdbc.Driver",
        }
        context = build_datasource_context("jdbc", [], connection)
        for fragment in ("[数据源模式: JDBC]", "jdbc:mysql://", "CDATA"):
            assert fragment in context

    def test_jdbc_mode_without_config_shows_warning(self):
        """JDBC mode without a config contains the not-configured warning and ``P{xxx}``."""
        context = build_datasource_context("jdbc", [])
        for fragment in ("尚未配置数据库连接", "P{xxx}"):
            assert fragment in context
|
||||
|
||||
|
||||
# ── JDBC 配置 ───────────────────────────────────────────────────
|
||||
|
||||
class TestConfigureJdbc:
    """Checks the update dict returned by ``configure_jdbc``."""

    def test_configure_returns_update_dict(self):
        """The update sets ``jdbc`` mode and carries the given URL and username."""
        settings = {
            "url": "jdbc:mysql://localhost/db",
            "driver": "com.mysql.cj.jdbc.Driver",
            "username": "root",
            "password": "pass",
        }
        update = configure_jdbc(_make_state(), **settings)
        assert update["datasource_mode"] == "jdbc"
        stored = update["db_config"]
        assert stored["url"] == "jdbc:mysql://localhost/db"
        assert stored["username"] == "root"

    def test_default_driver_is_mysql(self):
        """With no driver given, the stored driver contains ``mysql``.

        Note the URL used here is a PostgreSQL one; the test pins the default
        driver regardless of the URL.
        """
        update = configure_jdbc(_make_state(), url="jdbc:postgresql://localhost/db")
        driver = update["db_config"]["driver"]
        assert "mysql" in driver
|
||||
|
||||
|
||||
# ── ask_db_config ───────────────────────────────────────────────
|
||||
|
||||
class TestAskDbConfig:
    """Checks when ``ask_db_config`` prompts the user for connection details."""

    def test_returns_none_for_parameter_mode(self):
        """Parameter mode never needs a database prompt."""
        state = _make_state(datasource_mode="parameter")
        assert ask_db_config(state) is None

    def test_returns_none_when_jdbc_configured(self):
        """JDBC mode with a populated ``db_config`` needs no prompt."""
        state = _make_state(
            datasource_mode="jdbc",
            db_config={"url": "jdbc:mysql://localhost/db"},
        )
        assert ask_db_config(state) is None

    def test_returns_prompt_when_jdbc_missing_config(self):
        """JDBC mode without ``db_config`` prompts for URL, username and password."""
        state = _make_state(datasource_mode="jdbc")
        msg = ask_db_config(state)
        assert msg is not None
        for fragment in ("JDBC URL", "用户名", "密码"):
            assert fragment in msg

    def test_returns_prompt_when_db_config_empty(self):
        """An empty ``db_config`` dict is treated as missing, so a prompt is returned.

        Previously named ``test_returns_none_when_db_config_empty``, which
        contradicted the assertion below.
        """
        state = _make_state(datasource_mode="jdbc", db_config={})
        msg = ask_db_config(state)
        assert msg is not None
|
||||
@@ -0,0 +1,157 @@
|
||||
"""field_matcher.py 测试 — OCR 字段 → KB 字段匹配, embedding + LLM。"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from backend.field_matcher import (
|
||||
_cosine_similarity, match_ocr_to_kb, _match_via_llm,
|
||||
format_field_mapping_context,
|
||||
)
|
||||
|
||||
|
||||
# ── 余弦相似度 ──────────────────────────────────────────────────
|
||||
|
||||
class TestCosineSimilarity:
    """Checks ``_cosine_similarity`` on a few hand-picked vector pairs."""

    def test_identical_vectors(self):
        """Identical vectors score exactly 1.0."""
        unit_x = [1, 0, 0]
        assert _cosine_similarity(unit_x, [1, 0, 0]) == 1.0

    def test_orthogonal_vectors(self):
        """Orthogonal vectors score exactly 0.0."""
        unit_x, unit_y = [1, 0, 0], [0, 1, 0]
        assert _cosine_similarity(unit_x, unit_y) == 0.0

    def test_opposite_vectors(self):
        """Opposite vectors score exactly -1.0."""
        forward, backward = [1, 0], [-1, 0]
        assert _cosine_similarity(forward, backward) == -1.0

    def test_normalized_vectors_range(self):
        """The score for two unit-length vectors lies within [-1, 1]."""
        score = _cosine_similarity([0.6, 0.8], [0.8, 0.6])
        assert -1.0 <= score <= 1.0
|
||||
|
||||
|
||||
# ── LLM 匹配 ────────────────────────────────────────────────────
|
||||
|
||||
class TestMatchViaLlm:
    """Checks ``_match_via_llm`` against a mocked LLM."""

    @staticmethod
    def _llm_returning(content):
        """Return a mock LLM whose ``invoke`` yields a response with *content*."""
        response = MagicMock()
        response.content = content
        llm = MagicMock()
        llm.invoke.return_value = response
        return llm

    def test_returns_json_mapping(self):
        """A JSON object in the LLM response is returned as the mapping."""
        mock_llm = self._llm_returning('{"工单号": "billNo", "客户": "customerName"}')
        kb_fields = [
            {"name": "billNo", "description": "工单号", "type": "String"},
            {"name": "customerName", "description": "客户名称", "type": "String"},
        ]
        result = _match_via_llm(["工单号", "客户"], kb_fields, mock_llm)
        assert result == {"工单号": "billNo", "客户": "customerName"}

    def test_includes_candidates_hint_when_provided(self):
        """Candidates passed in show up in the prompt sent to the LLM."""
        mock_llm = self._llm_returning('{"工单号": "billNo"}')
        candidates = {"工单号": [("billNo", 0.85), ("orderId", 0.62)]}
        # Only the prompt is inspected, so the return value is not kept.
        _match_via_llm(
            ["工单号"],
            [{"name": "billNo", "description": "工单号", "type": "String"}],
            mock_llm, candidates=candidates)
        # First positional argument of the single invoke() call.
        prompt = mock_llm.invoke.call_args[0][0]
        assert "候选" in prompt
        assert "billNo" in prompt

    def test_llm_error_returns_empty_dict(self):
        """An exception raised by the LLM results in an empty mapping."""
        mock_llm = MagicMock()
        mock_llm.invoke.side_effect = RuntimeError("LLM crash")
        result = _match_via_llm(
            ["x"], [{"name": "y", "description": "", "type": "String"}], mock_llm)
        assert result == {}

    def test_llm_returns_invalid_json_returns_empty(self):
        """A non-JSON LLM response results in an empty mapping."""
        mock_llm = self._llm_returning("not json at all")
        result = _match_via_llm(
            ["x"], [{"name": "y", "description": "", "type": "String"}], mock_llm)
        assert result == {}
|
||||
|
||||
|
||||
# ── 完整匹配流程 ────────────────────────────────────────────────
|
||||
|
||||
class TestMatchOcrToKb:
    """Checks the full ``match_ocr_to_kb`` flow with a fake embedder."""

    # (keywords, vector): text containing any keyword embeds to that axis.
    _KEYWORD_VECTORS = (
        (("billNo", "工单"), (1.0, 0.0, 0.0)),
        (("customerName", "客户"), (0.0, 1.0, 0.0)),
        (("amount", "金额"), (0.0, 0.0, 1.0)),
    )

    @pytest.fixture(autouse=True)
    def mock_embed(self):
        """Patch ``backend.field_matcher._embed`` with a keyword-based fake.

        Text matching none of the keywords embeds to the zero vector.
        """
        def _fake_embed(text):
            for keywords, vector in self._KEYWORD_VECTORS:
                if any(keyword in text for keyword in keywords):
                    return list(vector)
            return [0.0, 0.0, 0.0]

        with patch("backend.field_matcher._embed") as mock_embed:
            mock_embed.side_effect = _fake_embed
            yield mock_embed

    @staticmethod
    def _llm_returning(content):
        """Return a mock LLM whose ``invoke`` yields a response with *content*."""
        response = MagicMock()
        response.content = content
        llm = MagicMock()
        llm.invoke.return_value = response
        return llm

    def test_matches_without_llm(self):
        """With ``llm=None`` each OCR label maps to its matching KB field."""
        kb_fields = [
            {"name": "billNo", "description": "工单号", "type": "String"},
            {"name": "customerName", "description": "客户名称", "type": "String"},
            {"name": "amount", "description": "金额", "type": "BigDecimal"},
        ]
        mapping = match_ocr_to_kb(
            ["工单号", "客户名称", "金额"], kb_fields, llm=None)
        assert mapping["工单号"] == "billNo"
        assert mapping["客户名称"] == "customerName"
        assert mapping["金额"] == "amount"

    def test_empty_inputs_return_empty(self):
        """An empty OCR list or an empty KB field list yields an empty mapping."""
        lone_field = [{"name": "y", "description": "", "type": "String"}]
        assert match_ocr_to_kb([], [], llm=None) == {}
        assert match_ocr_to_kb(["x"], [], llm=None) == {}
        assert match_ocr_to_kb([], lone_field, llm=None) == {}

    def test_low_similarity_not_matched(self):
        """An OCR label unrelated to the only KB field must not be mapped."""
        kb_fields = [{"name": "far", "description": "不相关字段", "type": "String"}]
        mapping = match_ocr_to_kb(["无关"], kb_fields, llm=None)
        # An empty mapping satisfies this too, so no separate `== {}` check is needed.
        assert "无关" not in mapping

    def test_uses_llm_when_provided(self):
        """With an LLM supplied, the label maps to the field the LLM names."""
        mock_llm = self._llm_returning('{"工单号": "billNo", "客户名称": "customerName"}')
        kb_fields = [
            {"name": "billNo", "description": "工单号", "type": "String"},
            {"name": "customerName", "description": "客户", "type": "String"},
        ]
        mapping = match_ocr_to_kb(["工单号", "客户名称"], kb_fields, llm=mock_llm)
        assert mapping["工单号"] == "billNo"

    def test_embedding_failure_falls_back_to_llm(self):
        """When ``_embed`` raises, the LLM's answer is still returned."""
        mock_llm = self._llm_returning('{"工单号": "billNo"}')
        kb_fields = [{"name": "billNo", "description": "工单号", "type": "String"}]
        # Overrides the autouse fake for the duration of this call.
        with patch("backend.field_matcher._embed", side_effect=RuntimeError("model error")):
            mapping = match_ocr_to_kb(["工单号"], kb_fields, llm=mock_llm)
        assert mapping["工单号"] == "billNo"
|
||||
|
||||
|
||||
# ── 格式化上下文 ────────────────────────────────────────────────
|
||||
|
||||
class TestFormatFieldMappingContext:
    """Checks the text produced by ``format_field_mapping_context``."""

    def test_formats_mapping_as_table(self):
        """Each OCR label and its ``$P{...}`` parameter appear under the mapping tag."""
        context = format_field_mapping_context(
            {"工单号": "billNo", "客户": "customerName"}
        )
        expected_fragments = (
            "[字段映射",
            "$P{billNo}",
            "$P{customerName}",
            "工单号",
            "客户",
        )
        for fragment in expected_fragments:
            assert fragment in context

    def test_empty_mapping_returns_empty_string(self):
        """An empty mapping formats to the empty string."""
        context = format_field_mapping_context({})
        assert context == ""

    def test_single_entry(self):
        """A single-entry mapping still emits its ``$P{...}`` parameter."""
        context = format_field_mapping_context({"发票号码": "invoiceNo"})
        assert "$P{invoiceNo}" in context
|
||||
@@ -0,0 +1,325 @@
|
||||
"""JRXML 窗口化模块单元测试。
|
||||
|
||||
测试 decompose → split → reassemble 往返链路,
|
||||
以及元素计数和校验逻辑。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from agent.jrxml_windower import (
|
||||
decompose_jrxml, reassemble_jrxml,
|
||||
split_band_into_windows, reassemble_band_windows,
|
||||
count_elements, validate_element_count,
|
||||
BandInfo,
|
||||
)
|
||||
|
||||
# ── 最小 JRXML 测试夹具 ──────────────────────────────────────────────
|
||||
|
||||
MINIMAL_JRXML = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<jasperReport name="test" pageWidth="595" pageHeight="842" columnCount="3">
|
||||
<property name="test.prop" value="1"/>
|
||||
<field name="name" class="java.lang.String"/>
|
||||
<field name="amount" class="java.math.BigDecimal"/>
|
||||
<queryString><![CDATA[SELECT * FROM t]]></queryString>
|
||||
<title>
|
||||
<band height="50">
|
||||
<staticText>
|
||||
<reportElement x="0" y="0" width="100" height="20"/>
|
||||
<text><![CDATA[Title]]></text>
|
||||
</staticText>
|
||||
<textField>
|
||||
<reportElement x="200" y="0" width="80" height="20"/>
|
||||
<textFieldExpression><![CDATA[$F{name}]]></textFieldExpression>
|
||||
</textField>
|
||||
</band>
|
||||
</title>
|
||||
<columnHeader>
|
||||
<band height="30">
|
||||
<staticText>
|
||||
<reportElement x="0" y="0" width="100" height="30"/>
|
||||
<text><![CDATA[Header]]></text>
|
||||
</staticText>
|
||||
</band>
|
||||
</columnHeader>
|
||||
<detail>
|
||||
<band height="40">
|
||||
<textField>
|
||||
<reportElement x="0" y="0" width="100" height="20"/>
|
||||
<textFieldExpression><![CDATA[$F{name}]]></textFieldExpression>
|
||||
</textField>
|
||||
<textField>
|
||||
<reportElement x="200" y="0" width="80" height="20"/>
|
||||
<textFieldExpression><![CDATA[$F{amount}]]></textFieldExpression>
|
||||
</textField>
|
||||
</band>
|
||||
</detail>
|
||||
<pageFooter>
|
||||
<band height="30">
|
||||
<textField>
|
||||
<reportElement x="0" y="0" width="100" height="20"/>
|
||||
<textFieldExpression><![CDATA["Page " + $V{PAGE_NUMBER}]]></textFieldExpression>
|
||||
</textField>
|
||||
</band>
|
||||
</pageFooter>
|
||||
</jasperReport>"""
|
||||
|
||||
|
||||
# ── Decompose 测试 ────────────────────────────────────────────────────
|
||||
|
||||
class TestDecompose:
    """Checks the parts returned by ``decompose_jrxml`` for ``MINIMAL_JRXML``."""

    @staticmethod
    def _minimal_parts():
        """Decompose the shared minimal report."""
        return decompose_jrxml(MINIMAL_JRXML)

    def test_parses_minimal_jrxml(self):
        """The minimal report yields four bands holding six elements in total."""
        parts = self._minimal_parts()
        assert parts is not None
        assert parts.band_count == 4  # title, columnHeader, detail, pageFooter
        assert parts.total_elements == 6  # 2 + 1 + 2 + 1

    def test_declaration_preserved(self):
        """The XML declaration is kept in ``declaration``."""
        declaration = self._minimal_parts().declaration
        assert '<?xml' in declaration

    def test_root_open_has_jasperreport(self):
        """``root_open`` mentions the ``jasperReport`` root element."""
        root_open = self._minimal_parts().root_open
        assert 'jasperReport' in root_open

    def test_header_children_separated(self):
        """Field, queryString and property declarations land in ``header_xml``."""
        header = self._minimal_parts().header_xml
        expected_fragments = (
            'field name="name"',
            'field name="amount"',
            'queryString',
            'property name',
        )
        for fragment in expected_fragments:
            assert fragment in header

    def test_band_labels(self):
        """Bands are labelled by section, in document order."""
        labels = [band.label for band in self._minimal_parts().bands]
        assert labels == ["title", "columnHeader", "detail", "pageFooter"]

    def test_footer_closes_jasperreport(self):
        """``footer`` mentions ``jasperReport`` and ends with ``>``."""
        footer = self._minimal_parts().footer
        assert 'jasperReport' in footer
        assert footer.strip().endswith('>')

    def test_returns_none_for_non_jrxml(self):
        """Well-formed XML that is not a report yields None."""
        outcome = decompose_jrxml("<html><body></body></html>")
        assert outcome is None

    def test_returns_none_for_malformed_xml(self):
        """Malformed input yields None."""
        outcome = decompose_jrxml("not xml at all <<<")
        assert outcome is None
|
||||
|
||||
|
||||
# ── Roundtrip 测试 ────────────────────────────────────────────────────
|
||||
|
||||
class TestRoundtrip:
    """Checks that decompose followed by reassemble keeps the report content."""

    @staticmethod
    def _unchanged_band_map(parts):
        """Map each band label to its own, unmodified band XML."""
        band_map = {}
        for band in parts.bands:
            band_map[band.label] = band.band_xml
        return band_map

    def test_decompose_reassemble_element_count_unchanged(self):
        """Reassembling untouched bands keeps the element count."""
        parts = decompose_jrxml(MINIMAL_JRXML)
        result = reassemble_jrxml(parts, self._unchanged_band_map(parts))

        orig = count_elements(MINIMAL_JRXML)
        reassembled = count_elements(result)
        assert orig == reassembled, f"Elements: {orig} -> {reassembled}"

    def test_roundtrip_preserves_text_content(self):
        """Static texts and field expressions survive the roundtrip."""
        parts = decompose_jrxml(MINIMAL_JRXML)
        result = reassemble_jrxml(parts, self._unchanged_band_map(parts))

        for fragment in ('Title', 'Header', '$F{name}', '$F{amount}'):
            assert fragment in result

    def test_empty_bands_preserved(self):
        """An empty band (no elements) is not lost in the roundtrip."""
        jrxml = """<?xml version="1.0" encoding="UTF-8"?>
<jasperReport name="t" pageWidth="595" pageHeight="842">
  <queryString><![CDATA[]]></queryString>
  <background>
    <band height="10"/>
  </background>
  <title>
    <band height="50">
      <staticText>
        <reportElement x="0" y="0" width="100" height="20"/>
        <text><![CDATA[T]]></text>
      </staticText>
    </band>
  </title>
</jasperReport>"""
        parts = decompose_jrxml(jrxml)
        assert parts.band_count == 2
        result = reassemble_jrxml(parts, self._unchanged_band_map(parts))
        assert count_elements(jrxml) == count_elements(result)
|
||||
|
||||
|
||||
# ── Window Split 测试 ─────────────────────────────────────────────────
|
||||
|
||||
class TestWindowSplit:
|
||||
def test_small_band_not_split(self):
|
||||
"""小 band 不会被切分。"""
|
||||
band = BandInfo(
|
||||
section_name="title", band_index=0,
|
||||
band_xml='<band height="50"><staticText><reportElement x="0" y="0" width="1" height="1"/><text><![CDATA[X]]></text></staticText></band>',
|
||||
element_count=1, char_length=150,
|
||||
)
|
||||
windows = split_band_into_windows(band, max_chars=4000)
|
||||
assert len(windows) == 1
|
||||
|
||||
def test_large_band_split_at_element_boundaries(self):
|
||||
"""超过字符阈值的 band 在元素边界切分。"""
|
||||
inner = "<staticText><reportElement x=\"0\" y=\"0\" width=\"100\" height=\"20\"/><text><![CDATA[A]]></text></staticText>\n" * 80
|
||||
band_xml = f'<band height="50">{inner}</band>'
|
||||
band = BandInfo(
|
||||
section_name="detail", band_index=0,
|
||||
band_xml=band_xml,
|
||||
element_count=80, char_length=len(band_xml),
|
||||
)
|
||||
windows = split_band_into_windows(band, max_chars=4000)
|
||||
assert len(windows) > 1, f"Expected multiple windows, got {len(windows)}"
|
||||
|
||||
def test_split_preserves_element_count(self):
|
||||
"""切分后重组元素数不变。"""
|
||||
inner = "<staticText><reportElement x=\"0\" y=\"0\" width=\"100\" height=\"20\"/><text><![CDATA[A]]></text></staticText>\n" * 80
|
||||
band_xml = f'<band height="50">{inner}</band>'
|
||||
band = BandInfo(
|
||||
section_name="detail", band_index=0,
|
||||
band_xml=band_xml,
|
||||
element_count=80, char_length=len(band_xml),
|
||||
)
|
||||
windows = split_band_into_windows(band, max_chars=4000)
|
||||
reassembled = reassemble_band_windows(windows)
|
||||
assert count_elements(band_xml) == count_elements(reassembled)
|
||||
|
||||
def test_no_empty_windows(self):
|
||||
"""所有窗口非空。"""
|
||||
inner = "<staticText><reportElement x=\"0\" y=\"0\" width=\"100\" height=\"20\"/><text><![CDATA[A]]></text></staticText>\n" * 80
|
||||
band_xml = f'<band height="50">{inner}</band>'
|
||||
band = BandInfo(
|
||||
section_name="detail", band_index=0,
|
||||
band_xml=band_xml,
|
||||
element_count=80, char_length=len(band_xml),
|
||||
)
|
||||
windows = split_band_into_windows(band, max_chars=4000)
|
||||
for i, w in enumerate(windows):
|
||||
assert len(w.strip()) > 0, f"Window {i} is empty"
|
||||
assert '<band' in w, f"Window {i} missing <band>"
|
||||
|
||||
def test_namespaced_band_split(self):
    """A band written with an ``ns0:`` namespace prefix is split as well."""
    one_element = (
        "<ns0:staticText><ns0:reportElement x=\"0\" y=\"0\" width=\"100\" height=\"20\"/>"
        "<ns0:text><![CDATA[A]]></ns0:text></ns0:staticText>\n"
    )
    body = one_element * 80
    payload = (
        '<ns0:band xmlns:ns0="http://jasperreports.sourceforge.net/jasperreports" height="50">'
        + body
        + '</ns0:band>'
    )
    oversized = BandInfo(
        section_name="detail",
        band_index=0,
        band_xml=payload,
        element_count=80,
        char_length=len(payload),
    )

    pieces = split_band_into_windows(oversized, max_chars=4000)

    assert len(pieces) > 1, f"Expected multiple, got {len(pieces)}"
    # Each window must either close the namespaced band or open with it.
    for piece in pieces:
        assert '</ns0:band>' in piece or piece.startswith('<ns0:band')
|
||||
|
||||
|
||||
# ── Element Count 测试 ────────────────────────────────────────────────
|
||||
|
||||
class TestElementCount:
    """Checks of ``count_elements`` on small hand-written snippets."""

    def test_counts_textfield_statictext(self):
        """One textField plus one staticText counts as two elements."""
        snippet = '<textField/><staticText/>'
        assert count_elements(snippet) == 2

    def test_counts_field_declarations(self):
        """A single <field> declaration counts as one element."""
        snippet = '<field name="a" class="java.lang.String"/>'
        assert count_elements(snippet) == 1

    def test_counts_namespaced_elements(self):
        """Elements carrying an ``ns0:`` prefix are counted too."""
        snippet = '<ns0:textField/><ns0:staticText/><ns0:field name="x"/>'
        assert count_elements(snippet) == 3

    def test_minimal_jrxml_count(self):
        """The shared MINIMAL_JRXML sample is expected to hold eight elements."""
        total = count_elements(MINIMAL_JRXML)
        assert total == 8

    def test_empty_string_zero(self):
        """An empty string contains no elements."""
        total = count_elements("")
        assert total == 0
|
||||
|
||||
|
||||
# ── Validate 测试 ─────────────────────────────────────────────────────
|
||||
|
||||
class TestValidateElementCount:
    """Checks of the report dict returned by ``validate_element_count``."""

    def test_no_change_ok(self):
        """Comparing a document with itself is ok with a zero change ratio."""
        report = validate_element_count(MINIMAL_JRXML, MINIMAL_JRXML, "test")
        assert report["ok"] is True
        assert report["change_pct"] == 0

    def test_small_change_ok(self):
        """A change below 5% passes silently."""
        with_comment = MINIMAL_JRXML.replace(
            '<staticText>', '<staticText><!-- comment -->'
        )
        report = validate_element_count(MINIMAL_JRXML, with_comment, "test")
        # Comments are not counted as elements, so the change here is 0%.
        assert report["ok"] is True

    def test_large_change_not_ok(self):
        """A change above 10% is reported with ok=False."""
        truncated = MINIMAL_JRXML[:500]  # heavy truncation
        report = validate_element_count(MINIMAL_JRXML, truncated, "test")
        # The assertion only applies when the truncation really removed
        # more than 10% of a non-empty original.
        if not (report["original"] > 0 and report["change_pct"] > 0.10):
            return
        assert report["ok"] is False

    def test_zero_original_always_ok(self):
        """An original without elements is accepted."""
        report = validate_element_count("", MINIMAL_JRXML, "test")
        assert report["ok"] is True
|
||||
|
||||
|
||||
# ── 多 section 多 band 测试 ──────────────────────────────────────────
|
||||
|
||||
MULTI_BAND_JRXML = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<jasperReport name="multi" pageWidth="595" pageHeight="842">
|
||||
<field name="f1" class="java.lang.String"/>
|
||||
<queryString><![CDATA[SELECT 1]]></queryString>
|
||||
<detail>
|
||||
<band height="30">
|
||||
<textField>
|
||||
<reportElement x="0" y="0" width="100" height="20"/>
|
||||
<textFieldExpression><![CDATA[$F{f1}]]></textFieldExpression>
|
||||
</textField>
|
||||
</band>
|
||||
<band height="20">
|
||||
<staticText>
|
||||
<reportElement x="0" y="0" width="100" height="15"/>
|
||||
<text><![CDATA[Sub]]></text>
|
||||
</staticText>
|
||||
</band>
|
||||
</detail>
|
||||
<summary>
|
||||
<band height="40">
|
||||
<textField>
|
||||
<reportElement x="0" y="0" width="200" height="30"/>
|
||||
<textFieldExpression><![CDATA["Total"]]></textFieldExpression>
|
||||
</textField>
|
||||
</band>
|
||||
</summary>
|
||||
</jasperReport>"""
|
||||
|
||||
|
||||
class TestMultiBand:
    """Decompose/reassemble checks on a report with several bands per section."""

    def test_multiple_bands_same_section(self):
        """Multiple bands inside one section are handled as separate bands."""
        decomposed = decompose_jrxml(MULTI_BAND_JRXML)
        # Two detail bands plus one summary band.
        assert decomposed.band_count == 3
        seen_labels = [band.label for band in decomposed.bands]
        assert seen_labels == ["detail", "detail_band1", "summary"]

    def test_multi_band_roundtrip(self):
        """Reassembling the untouched bands keeps the element count."""
        decomposed = decompose_jrxml(MULTI_BAND_JRXML)
        untouched = {band.label: band.band_xml for band in decomposed.bands}
        rebuilt = reassemble_jrxml(decomposed, untouched)
        assert count_elements(MULTI_BAND_JRXML) == count_elements(rebuilt)

    def test_reassemble_opens_closes_sections(self):
        """Each section tag is opened and closed exactly once after reassembly."""
        decomposed = decompose_jrxml(MULTI_BAND_JRXML)
        untouched = {band.label: band.band_xml for band in decomposed.bands}
        rebuilt = reassemble_jrxml(decomposed, untouched)
        for tag in ('<detail>', '</detail>', '<summary>', '</summary>'):
            assert rebuilt.count(tag) == 1
|
||||
@@ -0,0 +1,265 @@
|
||||
"""kb_manager.py 测试 — 用户 + KB CRUD, 原子写入, ID 验证。"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from backend.kb_manager import (
|
||||
_validate_id, _now_iso, _ensure_dir, _read_json, _write_json_atomic,
|
||||
_load_users, _save_users,
|
||||
create_user, list_users, get_user, delete_user,
|
||||
create_kb, list_kbs, get_kb, update_kb_meta, delete_kb,
|
||||
get_kb_raw_dir, get_kb_chunks_path, get_kb_chroma_path,
|
||||
KB_DATA_DIR, _USERS_FILE,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_kb_data(monkeypatch):
    """Point kb_manager's data directory and users file at a temporary directory.

    Yields the temporary directory as a ``Path``; it is removed when the
    fixture is torn down.
    """
    with tempfile.TemporaryDirectory(prefix="test_kb_") as tmpdir:
        overrides = (
            ("backend.kb_manager.KB_DATA_DIR", Path(tmpdir)),
            ("backend.kb_manager._USERS_FILE", Path(tmpdir) / "users.json"),
        )
        for target, replacement in overrides:
            monkeypatch.setattr(target, replacement)
        yield Path(tmpdir)
|
||||
|
||||
|
||||
@pytest.fixture
def user(temp_kb_data):
    """Create a user inside the temporary data directory and return it."""
    created = create_user("测试用户")
    return created
|
||||
|
||||
|
||||
@pytest.fixture
def kb(temp_kb_data, user):
    """Create a knowledge base owned by the ``user`` fixture and return it."""
    owner_id = user["user_id"]
    return create_kb(owner_id, "测试知识库", "测试描述")
|
||||
|
||||
|
||||
# ── ID 验证 ─────────────────────────────────────────────────────
|
||||
|
||||
class TestIDValidation:
    """``_validate_id`` accepts a long hex id and rejects malformed ones."""

    def test_valid_hex_id_passes(self):
        """A 22-character hex string is accepted without raising."""
        good_id = "aabbccddeeff0011223344"
        _validate_id(good_id, "test_id")

    def test_short_id_raises(self):
        """A three-character id is rejected."""
        with pytest.raises(ValueError, match="Invalid"):
            _validate_id("abc", "test_id")

    def test_non_hex_id_raises(self):
        """An id containing non-hex characters is rejected."""
        with pytest.raises(ValueError, match="Invalid"):
            _validate_id("not_valid!!!", "test_id")

    def test_empty_id_raises(self):
        """The empty string is rejected."""
        with pytest.raises(ValueError, match="Invalid"):
            _validate_id("", "test_id")
|
||||
|
||||
|
||||
# ── 原子写入 ────────────────────────────────────────────────────
|
||||
|
||||
class TestAtomicWrite:
    """Behaviour of ``_write_json_atomic``."""

    def test_write_json_atomic_creates_file(self, temp_kb_data):
        """Writing to a new path creates the file with the given JSON content."""
        target = temp_kb_data / "test.json"
        _write_json_atomic(target, {"key": "value"})
        assert target.exists()
        stored = json.loads(target.read_text(encoding="utf-8"))
        assert stored == {"key": "value"}

    def test_write_json_atomic_overwrites(self, temp_kb_data):
        """A second write replaces the earlier content."""
        target = temp_kb_data / "test.json"
        for payload in ({"a": 1}, {"b": 2}):
            _write_json_atomic(target, payload)
        stored = json.loads(target.read_text(encoding="utf-8"))
        assert stored == {"b": 2}

    def test_write_json_atomic_creates_parent_dir(self, temp_kb_data):
        """Missing parent directories are created on demand."""
        target = temp_kb_data / "deep" / "nested" / "test.json"
        _write_json_atomic(target, {"ok": True})
        assert target.exists()

    def test_write_json_atomic_no_partial_file_on_error(self, temp_kb_data):
        """A failing dump leaves neither the target nor a ``.tmp`` file behind."""
        target = temp_kb_data / "fail.json"
        with patch("json.dump", side_effect=RuntimeError("boom")), \
                pytest.raises(RuntimeError):
            _write_json_atomic(target, {"x": 1})
        assert not target.exists()
        leftovers = list(temp_kb_data.glob("*.json*"))
        assert not any(entry.name.endswith(".tmp") for entry in leftovers)
|
||||
|
||||
|
||||
# ── 用户 CRUD ───────────────────────────────────────────────────
|
||||
|
||||
class TestUserCRUD:
    """Create / list / get / delete behaviour for users."""

    def test_create_user_returns_dict(self, temp_kb_data):
        """The created record exposes name, an id of 12+ chars and created_at."""
        record = create_user("张三")
        assert record["name"] == "张三"
        assert len(record["user_id"]) >= 12
        assert "created_at" in record

    def test_create_user_persists_to_disk(self, temp_kb_data):
        """A created user shows up in a subsequent ``list_users`` call."""
        record = create_user("李四")
        known_ids = [entry["user_id"] for entry in list_users()]
        assert record["user_id"] in known_ids

    def test_create_user_with_custom_id(self, temp_kb_data):
        """An explicitly supplied user_id is kept."""
        wanted = "abcdef1234567890abcdef"
        record = create_user("王五", user_id=wanted)
        assert record["user_id"] == wanted

    def test_create_duplicate_user_id_raises(self, temp_kb_data):
        """Re-using an existing user_id raises ValueError."""
        wanted = "deadbeef1234567890abcd"
        create_user("用户1", user_id=wanted)
        with pytest.raises(ValueError, match="already exists"):
            create_user("用户2", user_id=wanted)

    def test_list_users_empty(self, temp_kb_data):
        """With nothing created the user list is empty."""
        assert list_users() == []

    def test_list_users_returns_all(self, temp_kb_data):
        """Two created users give a list of length two."""
        for name in ("A", "B"):
            create_user(name)
        assert len(list_users()) == 2

    def test_get_user_found(self, user):
        """Looking up an existing user returns its record."""
        found = get_user(user["user_id"])
        assert found is not None
        assert found["name"] == user["name"]

    def test_get_user_not_found(self, temp_kb_data):
        """A well-formed but unknown id yields None."""
        assert get_user("deadbeef1234567890abcd") is None

    def test_get_user_invalid_id_raises(self, temp_kb_data):
        """A malformed id raises ValueError."""
        with pytest.raises(ValueError):
            get_user("invalid")

    def test_delete_user_returns_true(self, user):
        """Deleting an existing user returns True."""
        outcome = delete_user(user["user_id"])
        assert outcome is True

    def test_delete_user_removes_from_list(self, user):
        """After deletion the user can no longer be fetched."""
        uid = user["user_id"]
        delete_user(uid)
        assert get_user(uid) is None

    def test_delete_user_removes_dir(self, temp_kb_data, user):
        """The user's directory exists before deletion and is gone afterwards."""
        uid = user["user_id"]
        home = temp_kb_data / uid
        assert home.exists()
        delete_user(uid)
        assert not home.exists()

    def test_delete_user_not_found_returns_false(self, temp_kb_data):
        """Deleting an unknown user returns False."""
        outcome = delete_user("deadbeef1234567890abcd")
        assert outcome is False

    def test_delete_user_invalid_id_raises(self, temp_kb_data):
        """A malformed id raises ValueError on delete."""
        with pytest.raises(ValueError):
            delete_user("bad_id")
|
||||
|
||||
|
||||
# ── KB CRUD ─────────────────────────────────────────────────────
|
||||
|
||||
class TestKbCRUD:
    """Create / list / get / update / delete behaviour for knowledge bases."""

    def test_create_kb_returns_meta(self, kb):
        """A new KB reports its name, an id, 'empty' status and zero files."""
        assert kb["name"] == "测试知识库"
        assert len(kb["kb_id"]) >= 12
        assert kb["parse_status"] == "empty"
        assert kb["file_count"] == 0

    def test_create_kb_creates_dir_structure(self, temp_kb_data, user, kb):
        """Creation lays out <user>/<kb>/raw and a meta.json on disk."""
        root = temp_kb_data / user["user_id"] / kb["kb_id"]
        assert root.is_dir()
        assert (root / "raw").is_dir()
        assert (root / "meta.json").exists()

    def test_create_kb_with_custom_id(self, user):
        """An explicitly supplied kb_id is kept."""
        wanted = "cafebabe1234567890feed"
        created = create_kb(user["user_id"], "自定义ID库", kb_id=wanted)
        assert created["kb_id"] == wanted

    def test_list_kbs_empty(self, user):
        """A user without KBs gets an empty list."""
        assert list_kbs(user["user_id"]) == []

    def test_list_kbs_returns_all(self, user):
        """Two created KBs give a list of length two."""
        owner = user["user_id"]
        for name in ("B库", "A库"):
            create_kb(owner, name)
        assert len(list_kbs(owner)) == 2

    def test_list_kbs_summary_format(self, user, kb):
        """Each list entry carries the expected summary keys."""
        summary = list_kbs(user["user_id"])[0]
        expected_keys = ("kb_id", "name", "field_count", "template_count", "parse_status")
        for name in expected_keys:
            assert name in summary

    def test_get_kb_found(self, kb):
        """Looking up an existing KB returns its metadata."""
        found = get_kb(kb["kb_id"])
        assert found is not None
        assert found["name"] == kb["name"]

    def test_get_kb_not_found(self, temp_kb_data):
        """A well-formed but unknown id yields None."""
        assert get_kb("deadbeef1234567890abcd") is None

    def test_get_kb_invalid_id_raises(self, temp_kb_data):
        """A malformed id raises ValueError."""
        with pytest.raises(ValueError):
            get_kb("bad")

    def test_update_kb_meta_changes_fields(self, kb):
        """Updated keys are reflected and an updated_at key is present."""
        changes = {"parse_status": "ready", "file_count": 5}
        refreshed = update_kb_meta(kb["kb_id"], changes)
        assert refreshed is not None
        assert refreshed["parse_status"] == "ready"
        assert refreshed["file_count"] == 5
        assert "updated_at" in refreshed

    def test_update_kb_meta_not_found(self, temp_kb_data):
        """Updating an unknown KB returns None."""
        outcome = update_kb_meta("deadbeef1234567890abcd", {"x": 1})
        assert outcome is None

    def test_delete_kb_returns_true(self, kb):
        """Deleting an existing KB returns True."""
        outcome = delete_kb(kb["kb_id"])
        assert outcome is True

    def test_delete_kb_removes_dir(self, temp_kb_data, user, kb):
        """The KB directory exists before deletion and is gone afterwards."""
        root = temp_kb_data / user["user_id"] / kb["kb_id"]
        assert root.exists()
        delete_kb(kb["kb_id"])
        assert not root.exists()

    def test_delete_kb_not_found_returns_false(self, temp_kb_data):
        """Deleting an unknown KB returns False."""
        outcome = delete_kb("deadbeef1234567890abcd")
        assert outcome is False
|
||||
|
||||
|
||||
# ── 工具函数 ────────────────────────────────────────────────────
|
||||
|
||||
class TestHelpers:
    """Path helper functions and cross-user / cascade behaviour."""

    def test_get_kb_raw_dir(self, kb):
        """The raw directory of an existing KB is named 'raw'."""
        raw = get_kb_raw_dir(kb["kb_id"])
        assert raw is not None
        assert raw.name == "raw"

    def test_get_kb_raw_dir_not_found(self, temp_kb_data):
        """An unknown KB has no raw directory."""
        assert get_kb_raw_dir("deadbeef1234567890abcd") is None

    def test_get_kb_chunks_path(self, kb):
        """The chunks path of an existing KB is named 'chunks.json'."""
        chunks = get_kb_chunks_path(kb["kb_id"])
        assert chunks is not None
        assert chunks.name == "chunks.json"

    def test_get_kb_chroma_path_creates_dir(self, kb):
        """The chroma path is named 'chroma' and exists after the call."""
        chroma = get_kb_chroma_path(kb["kb_id"])
        assert chroma is not None
        assert chroma.name == "chroma"
        assert chroma.exists()

    def test_user_can_own_multiple_kbs(self, user):
        """Three KBs created for one user are all listed."""
        owner = user["user_id"]
        for name in ("KB1", "KB2", "KB3"):
            create_kb(owner, name)
        assert len(list_kbs(owner)) == 3

    def test_different_users_have_isolated_kbs(self, temp_kb_data):
        """Each user only sees the KB created under their own id."""
        first = create_user("用户A")
        second = create_user("用户B")
        create_kb(first["user_id"], "A的库")
        create_kb(second["user_id"], "B的库")
        assert len(list_kbs(first["user_id"])) == 1
        assert len(list_kbs(second["user_id"])) == 1

    def test_delete_user_cascades_to_kbs(self, temp_kb_data, user):
        """Deleting a user removes the user directory that held its KB."""
        owner = user["user_id"]
        create_kb(owner, "要被删除的库")
        delete_user(owner)
        assert not (temp_kb_data / owner).exists()
|
||||
@@ -0,0 +1,311 @@
|
||||
"""kb_parser.py 测试 — JRXML 解析, 文件处理, 分块, 字段提取。"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from backend.kb_parser import (
|
||||
parse_jrxml_fields, _extract_archive, process_file_for_kb,
|
||||
chunk_file_results, extract_fields_with_llm, _extract_fields_from_table,
|
||||
_find_tag, _find_all_tags, _collect_from_result, build_kb_from_files,
|
||||
)
|
||||
|
||||
SAMPLE_JRXML = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<jasperReport name="TestReport" pageWidth="595" pageHeight="842"
|
||||
xmlns="http://jasperreports.sourceforge.net/jasperreports">
|
||||
<parameter name="billNo" class="java.lang.String">
|
||||
<parameterDescription>工单号</parameterDescription>
|
||||
</parameter>
|
||||
<parameter name="customerName" class="java.lang.String"/>
|
||||
<field name="amount" class="java.math.BigDecimal"/>
|
||||
<field name="createDate" class="java.sql.Date"/>
|
||||
<queryString><![CDATA[SELECT * FROM orders WHERE bill_no = $P{billNo}]]></queryString>
|
||||
</jasperReport>"""
|
||||
|
||||
SAMPLE_JRXML_NO_NS = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<jasperReport name="SimpleReport" pageWidth="800" pageHeight="600">
|
||||
<parameter name="title" class="java.lang.String"/>
|
||||
<field name="name" class="java.lang.String"/>
|
||||
</jasperReport>"""
|
||||
|
||||
INVALID_XML = """<?xml version="1.0"?>
|
||||
<jasperReport>
|
||||
<parameter name="broken"
|
||||
</jasperReport>"""
|
||||
|
||||
|
||||
@pytest.fixture
def jrxml_file():
    """Write SAMPLE_JRXML to a temporary .jrxml file and yield its path.

    The file is created with ``delete=False`` and unlinked after the test.
    """
    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".jrxml", delete=False, encoding="utf-8"
    )
    with handle:
        handle.write(SAMPLE_JRXML)
        sample_path = handle.name
    yield sample_path
    os.unlink(sample_path)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_kb(monkeypatch):
    """Build a throwaway KB directory tree and stub the kb_manager lookups.

    ``get_kb_raw_dir`` and ``get_kb_chunks_path`` are replaced so they only
    resolve the single fixed kb id; ``update_kb_meta`` becomes a no-op.
    Yields a dict with the kb id and the created paths.
    """
    known_id = "abcdef1234567890abcd"
    with tempfile.TemporaryDirectory(prefix="test_kb_parser_") as tmpdir:
        kb_data = Path(tmpdir)
        user_dir = kb_data / "default"
        kb_dir = user_dir / known_id
        raw_dir = kb_dir / "raw"
        raw_dir.mkdir(parents=True)

        def _raw_dir_for(kb_id):
            return raw_dir if kb_id == known_id else None

        def _chunks_path_for(kb_id):
            return kb_dir / "chunks.json" if kb_id == known_id else None

        def _ignore_meta_update(kb_id, updates):
            return None

        monkeypatch.setattr("backend.kb_manager.get_kb_raw_dir", _raw_dir_for)
        monkeypatch.setattr("backend.kb_manager.get_kb_chunks_path", _chunks_path_for)
        monkeypatch.setattr("backend.kb_manager.update_kb_meta", _ignore_meta_update)

        yield {
            "kb_id": known_id,
            "kb_dir": kb_dir,
            "raw_dir": raw_dir,
            "data_dir": kb_data,
        }
|
||||
|
||||
|
||||
# ── JRXML 解析 ──────────────────────────────────────────────────
|
||||
|
||||
class TestParseJrxmlFields:
    """``parse_jrxml_fields`` on valid, namespace-free, broken and empty reports."""

    def test_parses_parameters(self, jrxml_file):
        """Both parameters are found; the first has name, type and description."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert parsed["error"] is None
        assert len(parsed["parameters"]) == 2
        assert parsed["parameters"][0]["name"] == "billNo"
        assert parsed["parameters"][0]["type"] == "java.lang.String"
        assert parsed["parameters"][0]["description"] == "工单号"

    def test_parses_fields(self, jrxml_file):
        """Both <field> declarations are returned by name."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert len(parsed["fields"]) == 2
        names = [entry["name"] for entry in parsed["fields"]]
        assert "amount" in names
        assert "createDate" in names

    def test_parses_query(self, jrxml_file):
        """The query string text is exposed under 'query'."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert "SELECT * FROM orders" in parsed["query"]

    def test_parses_report_metadata(self, jrxml_file):
        """Report name and page dimensions come back as strings."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert parsed["report_name"] == "TestReport"
        assert parsed["page_width"] == "595"
        assert parsed["page_height"] == "842"

    def test_parses_jrxml_without_namespace(self, tmp_path):
        """A report without an xmlns declaration is parsed as well."""
        sample = tmp_path / "simple.jrxml"
        sample.write_text(SAMPLE_JRXML_NO_NS, encoding="utf-8")
        parsed = parse_jrxml_fields(str(sample))
        assert parsed["report_name"] == "SimpleReport"
        assert len(parsed["parameters"]) == 1

    def test_invalid_xml_returns_error(self, tmp_path):
        """Malformed XML is reported through the 'error' entry."""
        sample = tmp_path / "bad.jrxml"
        sample.write_text(INVALID_XML, encoding="utf-8")
        parsed = parse_jrxml_fields(str(sample))
        assert parsed["error"] is not None
        assert "解析失败" in parsed["error"]

    def test_empty_jrxml_has_no_fields(self, tmp_path):
        """A report element without children yields empty lists."""
        sample = tmp_path / "empty.jrxml"
        content = '<?xml version="1.0"?><jasperReport name="Empty"/>'
        sample.write_text(content, encoding="utf-8")
        parsed = parse_jrxml_fields(str(sample))
        assert parsed["parameters"] == []
        assert parsed["fields"] == []

    def test_nonexistent_file_raises(self):
        """A missing file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            parse_jrxml_fields("/nonexistent/path.jrxml")
|
||||
|
||||
|
||||
# ── 表格字段提取 ────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFieldsFromTable:
    """``_extract_fields_from_table`` on markdown tables and plain text."""

    def test_extracts_from_markdown_table(self):
        """Two data rows give two fields with name, description and required flag."""
        table = (
            "| 字段名 | 含义 | 必填 | 类型 |\n"
            "|--------|------|------|------|\n"
            "| billNo | 工单号 | 是 | String |\n"
            "| amount | 金额 | 否 | BigDecimal |"
        )
        extracted = _extract_fields_from_table(table)
        assert len(extracted) == 2
        assert extracted[0]["name"] == "billNo"
        assert extracted[0]["description"] == "工单号"
        assert extracted[0]["required"] is True
        assert extracted[1]["name"] == "amount"

    def test_skips_separator_rows(self):
        """Repeated separator rows do not produce fields."""
        table = (
            "| 字段 | 说明 |\n"
            "|------|------|\n"
            "|------|------|\n"
            "| name | 名称 |"
        )
        extracted = _extract_fields_from_table(table)
        assert len(extracted) == 1
        assert extracted[0]["name"] == "name"

    def test_returns_empty_for_plain_text(self):
        """Text without a table yields an empty list."""
        extracted = _extract_fields_from_table("这是一段普通文本,没有表格。")
        assert extracted == []

    def test_cells_with_bold_markers_stripped(self):
        """Markdown bold markers around a cell value are removed from the name."""
        table = (
            "| 名称 | 含义 |\n"
            "|------|------|\n"
            "| **billNo** | 单号 |"
        )
        extracted = _extract_fields_from_table(table)
        assert extracted[0]["name"] == "billNo"
|
||||
|
||||
|
||||
# ── LLM 字段提取 ────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFieldsWithLlm:
    """``extract_fields_with_llm`` with no LLM, a working mock and a failing mock."""

    def test_falls_back_to_table_when_no_llm(self):
        """Without an LLM the markdown table is still parsed."""
        table = "| 字段 | 说明 |\n|------|------|\n| code | 编码 |"
        extracted = extract_fields_with_llm(table, llm=None)
        assert len(extracted) >= 1
        assert any(entry["name"] == "code" for entry in extracted)

    def test_uses_llm_when_provided(self):
        """The JSON content returned by the LLM's invoke() becomes the field list."""
        reply = MagicMock()
        reply.content = '[{"name": "id", "description": "ID", "type": "Long", "required": true}]'
        fake_llm = MagicMock()
        fake_llm.invoke.return_value = reply
        extracted = extract_fields_with_llm("some text", llm=fake_llm)
        assert len(extracted) == 1
        assert extracted[0]["name"] == "id"

    def test_llm_failure_falls_back_to_table(self):
        """When invoke() raises, the markdown table is parsed instead."""
        fake_llm = MagicMock()
        fake_llm.invoke.side_effect = RuntimeError("LLM down")
        table = "| 字段 | 说明 |\n|------|------|\n| code | 编码 |"
        extracted = extract_fields_with_llm(table, llm=fake_llm)
        assert any(entry["name"] == "code" for entry in extracted)
|
||||
|
||||
|
||||
# ── 文件处理 ────────────────────────────────────────────────────
|
||||
|
||||
class TestProcessFileForKb:
    """``process_file_for_kb`` for JRXML input, an unknown KB and a text file."""

    def test_process_jrxml_copies_and_parses(self, jrxml_file, temp_kb):
        """A JRXML file is parsed and one copy lands in the KB's raw directory."""
        outcome = process_file_for_kb(temp_kb["kb_id"], jrxml_file)
        assert outcome["type"] == "jrxml"
        assert outcome["jrxml_info"]["report_name"] == "TestReport"
        assert outcome["error"] is None
        copies = list(temp_kb["raw_dir"].glob("*.jrxml"))
        assert len(copies) == 1

    def test_process_nonexistent_kb_returns_error(self, jrxml_file):
        """An unknown kb id is reported through the 'error' entry."""
        outcome = process_file_for_kb("deadbeef1234567890abcd", jrxml_file)
        assert outcome["error"] is not None

    def test_process_text_file(self, tmp_path, temp_kb):
        """A markdown file is processed with parse_file patched out."""
        doc = tmp_path / "readme.md"
        doc.write_text("# 标题\n\n这是一段内容。\n\n另一段内容。", encoding="utf-8")
        with patch("backend.kb_parser.parse_file") as fake_parse:
            fake_parse.return_value = {"text": "parsed content", "error": None}
            outcome = process_file_for_kb(temp_kb["kb_id"], str(doc))
        assert outcome["filename"] is not None
        assert outcome["error"] is None
|
||||
|
||||
|
||||
# ── 分块 ────────────────────────────────────────────────────────
|
||||
|
||||
class TestChunkFileResults:
    """``chunk_file_results`` for JRXML, archive and plain-text inputs."""

    def test_jrxml_result_produces_template_chunk(self, jrxml_file):
        """A JRXML result yields exactly one 'jrxml_template' chunk naming the report."""
        parsed = parse_jrxml_fields(jrxml_file)
        xml_text = Path(jrxml_file).read_text(encoding="utf-8")
        entry = {
            "filename": "test.jrxml",
            "type": "jrxml",
            "text": "text content",
            "raw_xml": xml_text,
            "jrxml_info": parsed,
            "error": None,
        }
        produced = chunk_file_results([entry], kb_name="测试库")
        assert len(produced) >= 1
        templates = [
            chunk for chunk in produced
            if chunk["metadata"]["chunk_type"] == "jrxml_template"
        ]
        assert len(templates) == 1
        assert templates[0]["metadata"]["report_name"] == "TestReport"
        assert "TestReport" in templates[0]["content"]

    def test_archive_result_recurses(self):
        """Entries nested under 'archive_contents' are chunked too."""
        nested = {
            "filename": "inner.jrxml",
            "type": "jrxml",
            "text": "inner text",
            "raw_xml": "<xml/>",
            "jrxml_info": {"report_name": "Inner", "parameters": [], "fields": []},
            "error": None,
        }
        archive = {
            "filename": "bundle.zip",
            "type": "archive",
            "text": "",
            "archive_contents": [nested],
            "error": None,
        }
        produced = chunk_file_results([archive])
        assert any(chunk["metadata"]["report_name"] == "Inner" for chunk in produced)

    def test_empty_text_skipped(self):
        """A result with empty text produces no chunks."""
        entry = {"filename": "empty.md", "type": "md", "text": "", "error": None}
        assert chunk_file_results([entry]) == []

    def test_short_paragraphs_skipped(self):
        """A two-character text produces no chunks."""
        entry = {"filename": "short.txt", "type": "txt", "text": "hi", "error": None}
        assert chunk_file_results([entry]) == []

    def test_text_split_into_paragraphs(self):
        """Three 50-character paragraphs separated by blank lines give three chunks."""
        paragraph = "A" * 50
        entry = {
            "filename": "doc.txt",
            "type": "txt",
            "text": f"{paragraph}\n\n{paragraph}\n\n{paragraph}",
            "error": None,
        }
        produced = chunk_file_results([entry])
        assert len(produced) == 3
|
||||
|
||||
|
||||
# ── _collect_from_result ────────────────────────────────────────
|
||||
|
||||
class TestCollectFromResult:
    """``_collect_from_result`` fills the caller-supplied field and template lists."""

    def test_collects_jrxml_parameters_as_fields(self):
        """Report parameters are appended to the fields list; one template is recorded."""
        collected_fields = []
        collected_templates = []
        entry = {
            "jrxml_info": {
                "report_name": "R1",
                "parameters": [{"name": "p1", "type": "String", "description": "参数1"}],
                "fields": [],
            },
            "filename": "r1.jrxml",
        }
        _collect_from_result(entry, collected_fields, collected_templates)
        assert len(collected_templates) == 1
        assert any(item["name"] == "p1" for item in collected_fields)

    def test_collects_jrxml_fields(self):
        """Report fields are appended to the fields list."""
        collected_fields = []
        collected_templates = []
        entry = {
            "jrxml_info": {
                "report_name": "R2",
                "parameters": [],
                "fields": [{"name": "f1", "type": "Double", "description": ""}],
            },
            "filename": "r2.jrxml",
        }
        _collect_from_result(entry, collected_fields, collected_templates)
        assert any(item["name"] == "f1" for item in collected_fields)

    def test_skips_non_jrxml(self):
        """A result without jrxml_info leaves both lists empty."""
        collected_fields = []
        collected_templates = []
        entry = {"type": "csv", "filename": "data.csv"}
        _collect_from_result(entry, collected_fields, collected_templates)
        assert collected_templates == []
        assert collected_fields == []

    def test_deduplicates_fields(self):
        """The same parameter seen in two files is kept only once."""
        collected_fields = []
        collected_templates = []
        info = {"report_name": "R", "parameters": [{"name": "dup", "type": "String", "description": ""}], "fields": []}
        for filename in ("a.jrxml", "b.jrxml"):
            _collect_from_result(
                {"jrxml_info": info, "filename": filename},
                collected_fields,
                collected_templates,
            )
        occurrences = sum(1 for item in collected_fields if item["name"] == "dup")
        assert occurrences == 1
|
||||
@@ -0,0 +1,214 @@
|
||||
"""kb_searcher.py 测试 — KBChromaSearcher 创建, 搜索, 模板检索。"""
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from backend.kb_searcher import (
|
||||
KBChromaSearcher, get_kb_searcher, search_kb, search_templates_in_kb,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chromadb(monkeypatch):
    """Replace chromadb.PersistentClient and SentenceTransformer with mocks.

    Yields a dict holding the mock client, its collection, the mock
    embedding model and the mock SentenceTransformer factory.
    """
    fake_client = MagicMock()
    fake_collection = MagicMock()
    # Both lookup and creation hand back the same collection mock.
    fake_client.get_collection.return_value = fake_collection
    fake_client.create_collection.return_value = fake_collection

    def _client_factory(path):
        return fake_client

    monkeypatch.setattr("chromadb.PersistentClient", _client_factory)

    fake_st = MagicMock()
    fake_model = MagicMock()
    fake_model.encode.return_value = MagicMock()
    fake_model.encode.return_value.tolist.return_value = [0.1, 0.2, 0.3]
    fake_st.return_value = fake_model
    monkeypatch.setattr("sentence_transformers.SentenceTransformer", fake_st)

    yield {
        "client": fake_client,
        "collection": fake_collection,
        "st_model": fake_model,
        "st": fake_st,
    }
|
||||
|
||||
|
||||
@pytest.fixture
def searcher(mock_chromadb):
    """Yield a KBChromaSearcher whose model, client and collection are the mocks."""
    with tempfile.TemporaryDirectory(prefix="test_chroma_") as tmpdir:
        instance = KBChromaSearcher(chroma_path=tmpdir, collection_name="test_kb")
        # Inject the mocks directly into the searcher's private attributes.
        instance._model = mock_chromadb["st_model"]
        instance._client = mock_chromadb["client"]
        instance._collection = mock_chromadb["collection"]
        yield instance
|
||||
|
||||
|
||||
# ── 创建 & 就绪检查 ─────────────────────────────────────────────
|
||||
|
||||
class TestKBChromaSearcherInit:
    """Construction defaults and the ``is_ready`` check."""

    def test_creates_with_defaults(self, mock_chromadb):
        """The default collection is 'kb_chunks' and the path is stored as a string."""
        with tempfile.TemporaryDirectory(prefix="test_chroma_") as tmpdir:
            instance = KBChromaSearcher(chroma_path=tmpdir)
            assert instance.collection_name == "kb_chunks"
            assert instance.chroma_path == str(tmpdir)

    def test_custom_collection_name(self, mock_chromadb):
        """A collection name passed to the constructor is kept."""
        with tempfile.TemporaryDirectory(prefix="test_chroma_") as tmpdir:
            instance = KBChromaSearcher(chroma_path=tmpdir, collection_name="custom")
            assert instance.collection_name == "custom"

    def test_model_lazy_loaded(self, mock_chromadb):
        """Right after construction no embedding model is set."""
        with tempfile.TemporaryDirectory(prefix="test_chroma_") as tmpdir:
            instance = KBChromaSearcher(chroma_path=tmpdir)
            assert instance._model is None

    def test_is_ready_true(self, searcher):
        """With the mocked collection available the searcher reports ready."""
        ready = searcher.is_ready()
        assert ready is True

    def test_is_ready_false(self, searcher, mock_chromadb):
        """When get_collection raises, the searcher reports not ready."""
        mock_chromadb["client"].get_collection.side_effect = Exception("no collection")
        ready = searcher.is_ready()
        assert ready is False
|
||||
|
||||
|
||||
# ── 搜索 ────────────────────────────────────────────────────────
|
||||
|
||||
class TestSearch:
    """``KBChromaSearcher.search`` against a mocked collection."""

    def test_search_returns_empty_when_not_ready(self, searcher, mock_chromadb):
        """When get_collection raises, search returns an empty list."""
        mock_chromadb["client"].get_collection.side_effect = Exception("no collection")
        hits = searcher.search("test query")
        assert hits == []

    def test_search_calls_collection_query(self, searcher, mock_chromadb):
        """Query results are mapped to dicts with id, content, metadata and distance."""
        canned = {
            "ids": [["chunk_0", "chunk_1"]],
            "documents": [["doc1", "doc2"]],
            "metadatas": [[{"chunk_type": "md"}, {"chunk_type": "txt"}]],
            "distances": [[0.1, 0.3]],
        }
        mock_chromadb["collection"].query.return_value = canned
        hits = searcher.search("query", k=5)
        assert len(hits) == 2
        assert hits[0]["id"] == "chunk_0"
        assert hits[0]["content"] == "doc1"
        assert hits[0]["metadata"]["chunk_type"] == "md"
        assert hits[0]["distance"] == 0.1

    def test_search_respects_threshold(self, searcher, mock_chromadb):
        """With threshold 0.5 only the hit at distance 0.2 is returned, not the one at 0.8."""
        canned = {
            "ids": [["chunk_0", "chunk_1"]],
            "documents": [["doc1", "doc2"]],
            "metadatas": [[{}, {}]],
            "distances": [[0.2, 0.8]],
        }
        mock_chromadb["collection"].query.return_value = canned
        hits = searcher.search("query", threshold=0.5)
        assert len(hits) == 1
        assert hits[0]["id"] == "chunk_0"

    def test_search_empty_results(self, searcher, mock_chromadb):
        """An empty query result yields an empty list."""
        canned = {
            "ids": [[]],
            "documents": [[]],
            "metadatas": [[]],
            "distances": [[]],
        }
        mock_chromadb["collection"].query.return_value = canned
        assert searcher.search("query") == []
|
||||
|
||||
|
||||
# ── 模板搜索 ────────────────────────────────────────────────────
|
||||
|
||||
class TestSearchTemplates:
    """``KBChromaSearcher.search_templates`` against a mocked collection."""

    def test_filters_jrxml_chunks(self, searcher, mock_chromadb):
        """Every returned hit is a jrxml chunk type or carries a report name."""
        canned = {
            "ids": [["c0", "c1", "c2"]],
            "documents": [["t1", "t2", "t3"]],
            "metadatas": [[
                {"chunk_type": "jrxml_template", "report_name": "R1"},
                {"chunk_type": "md_section"},
                {"chunk_type": "jrxml_template", "report_name": "R2"},
            ]],
            "distances": [[0.1, 0.2, 0.3]],
        }
        mock_chromadb["collection"].query.return_value = canned
        found = searcher.search_templates("query", k=3)
        assert len(found) >= 1
        for hit in found:
            details = hit["metadata"]
            assert "jrxml" in details.get("chunk_type", "").lower() or details.get("report_name")

    def test_no_jrxml_chunks_returns_empty(self, searcher, mock_chromadb):
        """A result holding only a markdown chunk yields no templates."""
        canned = {
            "ids": [["c0"]],
            "documents": [["text"]],
            "metadatas": [[{"chunk_type": "md_section"}]],
            "distances": [[0.1]],
        }
        mock_chromadb["collection"].query.return_value = canned
        found = searcher.search_templates("query")
        assert found == []
|
||||
|
||||
|
||||
# ── search_as_context ───────────────────────────────────────────
|
||||
|
||||
class TestSearchAsContext:
    """Checks for ``searcher.search_as_context`` against a mocked collection."""

    def test_returns_formatted_context(self, searcher, mock_chromadb):
        """The context string contains both documents, labels and a separator."""
        mocked_metadata = [
            {"chunk_type": "md", "report_name": "R1"},
            {"chunk_type": "txt"},
        ]
        mock_chromadb["collection"].query.return_value = dict(
            ids=[["c0", "c1"]],
            documents=[["内容1", "内容2"]],
            metadatas=[mocked_metadata],
            distances=[[0.1, 0.2]],
        )

        context = searcher.search_as_context("q", k=2)

        # Document bodies, the two label words, and the "---" divider.
        for fragment in ("内容1", "内容2", "类型", "报表", "---"):
            assert fragment in context

    def test_empty_returns_empty_string(self, searcher, mock_chromadb):
        """An empty query response yields an empty context string."""
        empty_response = dict(ids=[[]], documents=[[]], metadatas=[[]], distances=[[]])
        mock_chromadb["collection"].query.return_value = empty_response

        context = searcher.search_as_context("q")
        assert context == ""
|
||||
|
||||
|
||||
# ── add_chunks ──────────────────────────────────────────────────
|
||||
|
||||
class TestAddChunks:
    """Checks that ``searcher.add_chunks`` forwards to the collection's upsert."""

    def test_add_chunks_calls_upsert(self, searcher, mock_chromadb):
        """One chunk results in a single upsert carrying its id and content."""
        single_chunk = {
            "id": "c0",
            "content": "test content",
            "metadata": {"chunk_type": "md"},
        }

        searcher.add_chunks([single_chunk])

        upsert = mock_chromadb["collection"].upsert
        upsert.assert_called_once()
        # call_args unpacks as (positional args, keyword args).
        _positional, keyword_args = upsert.call_args
        assert keyword_args["ids"] == ["c0"]
        assert keyword_args["documents"] == ["test content"]

    def test_empty_chunks_noop(self, searcher, mock_chromadb):
        """An empty chunk list must not trigger any upsert call."""
        searcher.add_chunks([])

        upsert = mock_chromadb["collection"].upsert
        upsert.assert_not_called()
|
||||
|
||||
|
||||
# ── 工厂函数 ────────────────────────────────────────────────────
|
||||
|
||||
class TestGetKbSearcher:
    """Checks for the ``get_kb_searcher`` factory function."""

    def test_returns_cached_instance(self, monkeypatch, mock_chromadb):
        """Two lookups of the same kb id return the very same object."""
        known_kb_id = "abcdef1234567890abcd"

        with tempfile.TemporaryDirectory(prefix="test_chroma_") as tmpdir:

            def fake_chroma_path(kb_id):
                # Only the known kb id resolves to a storage directory.
                if kb_id == known_kb_id:
                    return Path(tmpdir)
                return None

            monkeypatch.setattr(
                "backend.kb_manager.get_kb_chroma_path", fake_chroma_path
            )

            first = get_kb_searcher(known_kb_id)
            second = get_kb_searcher(known_kb_id)
            assert first is second

    def test_returns_none_for_invalid_kb(self, monkeypatch):
        """When no chroma path resolves, the factory returns None."""

        def no_chroma_path(kb_id):
            return None

        monkeypatch.setattr("backend.kb_manager.get_kb_chroma_path", no_chroma_path)

        searcher_or_none = get_kb_searcher("deadbeef1234567890abcd")
        assert searcher_or_none is None
|
||||
|
||||
|
||||
class TestSearchKbFunction:
    """Checks the module-level search helpers when no chroma path resolves."""

    def test_returns_empty_for_invalid_kb(self, monkeypatch):
        """``search_kb`` returns an empty string for an unresolvable kb id."""

        def no_chroma_path(kb_id):
            return None

        monkeypatch.setattr("backend.kb_manager.get_kb_chroma_path", no_chroma_path)

        context = search_kb("deadbeef1234567890abcd", "query")
        assert context == ""

    def test_returns_empty_for_invalid_template_search(self, monkeypatch):
        """``search_templates_in_kb`` returns an empty list for such a kb id."""

        def no_chroma_path(kb_id):
            return None

        monkeypatch.setattr("backend.kb_manager.get_kb_chroma_path", no_chroma_path)

        templates = search_templates_in_kb("deadbeef1234567890abcd", "query")
        assert templates == []
|
||||
@@ -0,0 +1,200 @@
|
||||
"""程序化字段映射单元测试。
|
||||
|
||||
测试 _programmatic_map_fields 和 _sanitize_field_name
|
||||
的确定性替换行为,以及 validate_element_count 校验。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from agent.nodes import _programmatic_map_fields, _sanitize_field_name
|
||||
from agent.jrxml_windower import count_elements, validate_element_count
|
||||
|
||||
# ── 最小 JRXML 模板(含占位字段)────────────────────────────────────
|
||||
|
||||
JRXML_WITH_PLACEHOLDERS = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<jasperReport name="test" pageWidth="595" pageHeight="842">
|
||||
<field name="field_1" class="java.lang.String"/>
|
||||
<field name="field_2" class="java.math.BigDecimal"/>
|
||||
<field name="field_3" class="java.lang.String"/>
|
||||
<queryString><![CDATA[SELECT * FROM t]]></queryString>
|
||||
<title>
|
||||
<band height="50">
|
||||
<staticText>
|
||||
<reportElement x="0" y="0" width="100" height="20"/>
|
||||
<text><![CDATA[$F{field_1}]]></text>
|
||||
</staticText>
|
||||
<textField>
|
||||
<reportElement x="100" y="0" width="80" height="20"/>
|
||||
<textFieldExpression><![CDATA[$F{field_2}]]></textFieldExpression>
|
||||
</textField>
|
||||
<textField>
|
||||
<reportElement x="200" y="0" width="80" height="20"/>
|
||||
<textFieldExpression><![CDATA[$F{field_3}]]></textFieldExpression>
|
||||
</textField>
|
||||
</band>
|
||||
</title>
|
||||
<detail>
|
||||
<band height="30">
|
||||
<textField>
|
||||
<reportElement x="0" y="0" width="100" height="20"/>
|
||||
<textFieldExpression><![CDATA[$F{field_1} + " " + $F{field_2}]]></textFieldExpression>
|
||||
</textField>
|
||||
</band>
|
||||
</detail>
|
||||
</jasperReport>"""
|
||||
|
||||
|
||||
# ── _sanitize_field_name 测试 ────────────────────────────────────────
|
||||
|
||||
class TestSanitizeFieldName:
    """Expected outputs of ``_sanitize_field_name`` for assorted raw names."""

    def test_ascii_name_passes_through(self):
        """A lowercase ASCII identifier comes back unchanged."""
        sanitized = _sanitize_field_name("customer_name")
        assert sanitized == "customer_name"

    def test_uppercase_lowered(self):
        """Uppercase letters are folded to lowercase."""
        sanitized = _sanitize_field_name("CustomerName")
        assert sanitized == "customername"

    def test_spaces_replaced(self):
        """A space between words becomes an underscore."""
        sanitized = _sanitize_field_name("customer name")
        assert sanitized == "customer_name"

    def test_chinese_characters_escaped(self):
        """Chinese characters are replaced by their uXXXX_ code-point escapes."""
        sanitized = _sanitize_field_name("发票代码")
        assert "发票" not in sanitized
        for escape in ("u53d1_", "u7968_"):
            assert escape in sanitized

    def test_mixed_ascii_chinese(self):
        """The ASCII part is kept while the Chinese part is escaped."""
        sanitized = _sanitize_field_name("发票_code")
        assert "_code" in sanitized
        assert "u53d1_" in sanitized

    def test_empty_returns_unnamed(self):
        """An empty name falls back to ``unnamed_field``."""
        sanitized = _sanitize_field_name("")
        assert sanitized == "unnamed_field"

    def test_all_special_chars_returns_unnamed(self):
        """A name made only of ``!`` characters falls back to ``unnamed_field``."""
        sanitized = _sanitize_field_name("!!!")
        assert sanitized == "unnamed_field"

    def test_leading_digit_prefixed(self):
        """A name starting with a digit gains an ``f_`` prefix."""
        sanitized = _sanitize_field_name("123abc")
        assert sanitized == "f_123abc"

    def test_consecutive_underscores_collapsed(self):
        """Runs of underscores shrink to a single underscore."""
        sanitized = _sanitize_field_name("a__b___c")
        assert sanitized == "a_b_c"

    def test_japanese_characters_escaped(self):
        """Japanese characters do not survive verbatim in the output."""
        sanitized = _sanitize_field_name("請求書")
        assert "請求" not in sanitized
|
||||
|
||||
|
||||
# ── _programmatic_map_fields 测试 ────────────────────────────────────
|
||||
|
||||
class TestProgrammaticMapFields:
    """Checks for ``_programmatic_map_fields`` on the placeholder JRXML fixture."""

    @staticmethod
    def _three_named_fields():
        """Build a fresh OCR field list that names all three placeholders."""
        return [
            {"field_name": "customer_name"},
            {"field_name": "total_amount"},
            {"field_name": "invoice_date"},
        ]

    def _map_all_three(self):
        """Run the mapping on the fixture with the three-name OCR list."""
        return _programmatic_map_fields(
            JRXML_WITH_PLACEHOLDERS, self._three_named_fields()
        )

    def test_replaces_field_declarations(self):
        """``<field name=...>`` declarations carry the real names afterwards."""
        mapped = self._map_all_three()
        expected_declarations = (
            'field name="customer_name"',
            'field name="total_amount"',
            'field name="invoice_date"',
        )
        for declaration in expected_declarations:
            assert declaration in mapped
        assert 'field name="field_1"' not in mapped

    def test_replaces_field_references(self):
        """``$F{...}`` references carry the real names afterwards."""
        mapped = self._map_all_three()
        assert "$F{field_1}" not in mapped
        for reference in ("$F{customer_name}", "$F{total_amount}", "$F{invoice_date}"):
            assert reference in mapped

    def test_preserves_element_count(self):
        """The mapping leaves ``count_elements`` unchanged."""
        mapped = self._map_all_three()
        before = count_elements(JRXML_WITH_PLACEHOLDERS)
        after = count_elements(mapped)
        assert before == after, f"Elements: {before} -> {after}"

    def test_preserves_coordinates(self):
        """Position and size attributes from the fixture are still present."""
        mapped = self._map_all_three()
        layout_attributes = (
            'x="0"',
            'x="100"',
            'x="200"',
            'y="0"',
            'width="100"',
            'height="20"',
        )
        for attribute in layout_attributes:
            assert attribute in mapped

    def test_partial_fields_preserved(self):
        """With fewer OCR fields than placeholders, surplus placeholders stay."""
        two_fields = [
            {"field_name": "customer_name"},
            {"field_name": "total_amount"},
        ]
        mapped = _programmatic_map_fields(JRXML_WITH_PLACEHOLDERS, two_fields)
        assert 'field name="field_3"' in mapped
        assert "$F{field_3}" in mapped

    def test_empty_field_name_skipped(self):
        """OCR fields with an empty field_name trigger no replacement."""
        sparse_fields = [
            {"field_name": ""},
            {"field_name": "total_amount"},
            {"field_name": ""},
        ]
        mapped = _programmatic_map_fields(JRXML_WITH_PLACEHOLDERS, sparse_fields)
        # Only the middle placeholder has a usable name to map to.
        for reference in ('$F{field_1}', '$F{total_amount}', '$F{field_3}'):
            assert reference in mapped

    def test_no_ocr_fields_no_change(self):
        """An empty OCR list returns the fixture unchanged."""
        mapped = _programmatic_map_fields(JRXML_WITH_PLACEHOLDERS, [])
        assert mapped == JRXML_WITH_PLACEHOLDERS

    def test_chinese_field_names_sanitized(self):
        """A raw Chinese field name never appears verbatim in the output."""
        chinese_fields = [
            {"field_name": "发票代码"},
            {"field_name": "发票号码"},
            {"field_name": "金额"},
        ]
        mapped = _programmatic_map_fields(JRXML_WITH_PLACEHOLDERS, chinese_fields)
        assert "发票代码" not in mapped

    def test_validate_element_count_passes(self):
        """``validate_element_count`` reports ok with equal counts."""
        mapped = self._map_all_three()
        report = validate_element_count(
            JRXML_WITH_PLACEHOLDERS, mapped, "map_fields"
        )
        assert report["ok"] is True
        assert report["modified"] == report["original"]

    def test_expression_with_multiple_fields(self):
        """Expressions containing multiple $F{} references are replaced correctly."""
        two_fields = [
            {"field_name": "unit_price"},
            {"field_name": "quantity"},
        ]
        mapped = _programmatic_map_fields(JRXML_WITH_PLACEHOLDERS, two_fields)
        for reference in ('$F{unit_price}', '$F{quantity}'):
            assert reference in mapped
        for placeholder in ('$F{field_1}', '$F{field_2}'):
            assert placeholder not in mapped
|
||||
Reference in New Issue
Block a user