Files
agent_jrxml/tests/test_kb_parser.py
T
panda bd5bfbac2d fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: when given the full 34k-char JRXML, the LLM regenerated the report
from scratch instead of modifying coordinates in place, shrinking the output to ~3k chars.

Solution (programmatic node control, not prompt engineering):

- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
  LLM) + individual bands. Split bands >4000 chars at element boundaries.
  Reassemble with element count validation (>10% change = rollback).

- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
  each), so the LLM can no longer "reimagine" the entire report.

- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
  replacement. Zero LLM calls, zero content loss.

- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
  valid JRXML identifiers.

- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
  Full suite 385 tests, zero regressions.
2026-05-24 08:55:38 +08:00

312 lines
12 KiB
Python

"""kb_parser.py 测试 — JRXML 解析, 文件处理, 分块, 字段提取。"""
import json
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from backend.kb_parser import (
parse_jrxml_fields, _extract_archive, process_file_for_kb,
chunk_file_results, extract_fields_with_llm, _extract_fields_from_table,
_find_tag, _find_all_tags, _collect_from_result, build_kb_from_files,
)
# Namespaced JasperReports sample: two <parameter>s (billNo carries a
# <parameterDescription>), two <field>s, and a <queryString> that references
# $P{billNo}. Written to disk by the jrxml_file fixture.
SAMPLE_JRXML = """<?xml version="1.0" encoding="UTF-8"?>
<jasperReport name="TestReport" pageWidth="595" pageHeight="842"
xmlns="http://jasperreports.sourceforge.net/jasperreports">
<parameter name="billNo" class="java.lang.String">
<parameterDescription>工单号</parameterDescription>
</parameter>
<parameter name="customerName" class="java.lang.String"/>
<field name="amount" class="java.math.BigDecimal"/>
<field name="createDate" class="java.sql.Date"/>
<queryString><![CDATA[SELECT * FROM orders WHERE bill_no = $P{billNo}]]></queryString>
</jasperReport>"""
# Same shape without the xmlns declaration: one parameter and one field.
SAMPLE_JRXML_NO_NS = """<?xml version="1.0" encoding="UTF-8"?>
<jasperReport name="SimpleReport" pageWidth="800" pageHeight="600">
<parameter name="title" class="java.lang.String"/>
<field name="name" class="java.lang.String"/>
</jasperReport>"""
# Malformed XML: the <parameter name="broken" start tag is never closed,
# so parsing must fail.
INVALID_XML = """<?xml version="1.0"?>
<jasperReport>
<parameter name="broken"
</jasperReport>"""
@pytest.fixture
def jrxml_file(tmp_path):
    """Write SAMPLE_JRXML to a ``.jrxml`` file and return its path as a string.

    The file lives in pytest's per-test ``tmp_path`` directory, whose lifetime
    pytest manages. No manual unlink is needed, so teardown cannot fail if a
    test (or the code under test) has already moved or removed the file.
    """
    path = tmp_path / "sample.jrxml"
    path.write_text(SAMPLE_JRXML, encoding="utf-8")
    return str(path)
@pytest.fixture
def temp_kb(monkeypatch):
    """Provide a throw-away knowledge base rooted in a temporary directory.

    Creates ``<tmp>/default/<kb_id>/raw`` and patches the path helpers in
    ``backend.kb_manager`` so that only this fixture's kb_id resolves; any
    other id yields None. ``update_kb_meta`` is replaced by a no-op.

    Yields a dict with keys ``kb_id``, ``kb_dir``, ``raw_dir`` and ``data_dir``.
    """
    known_id = "abcdef1234567890abcd"
    with tempfile.TemporaryDirectory(prefix="test_kb_parser_") as tmpdir:
        data_dir = Path(tmpdir)
        kb_dir = data_dir / "default" / known_id
        raw_dir = kb_dir / "raw"
        raw_dir.mkdir(parents=True)

        def fake_raw_dir(kb_id):
            return raw_dir if kb_id == known_id else None

        def fake_chunks_path(kb_id):
            return kb_dir / "chunks.json" if kb_id == known_id else None

        def fake_update_meta(kb_id, updates):
            return None

        monkeypatch.setattr("backend.kb_manager.get_kb_raw_dir", fake_raw_dir)
        monkeypatch.setattr("backend.kb_manager.get_kb_chunks_path", fake_chunks_path)
        monkeypatch.setattr("backend.kb_manager.update_kb_meta", fake_update_meta)
        yield {"kb_id": known_id, "kb_dir": kb_dir, "raw_dir": raw_dir,
               "data_dir": data_dir}
# ── JRXML parsing ───────────────────────────────────────────────
class TestParseJrxmlFields:
    """Tests for ``parse_jrxml_fields`` (JRXML file -> report metadata dict)."""

    @staticmethod
    def _parse_text(directory, filename, content):
        """Write *content* to ``directory / filename`` and parse the result."""
        target = directory / filename
        target.write_text(content, encoding="utf-8")
        return parse_jrxml_fields(str(target))

    def test_parses_parameters(self, jrxml_file):
        parsed = parse_jrxml_fields(jrxml_file)
        assert parsed["error"] is None
        params = parsed["parameters"]
        assert len(params) == 2
        bill_no = params[0]
        assert bill_no["name"] == "billNo"
        assert bill_no["type"] == "java.lang.String"
        assert bill_no["description"] == "工单号"

    def test_parses_fields(self, jrxml_file):
        parsed_fields = parse_jrxml_fields(jrxml_file)["fields"]
        assert len(parsed_fields) == 2
        names = [entry["name"] for entry in parsed_fields]
        for expected in ("amount", "createDate"):
            assert expected in names

    def test_parses_query(self, jrxml_file):
        query = parse_jrxml_fields(jrxml_file)["query"]
        assert "SELECT * FROM orders" in query

    def test_parses_report_metadata(self, jrxml_file):
        parsed = parse_jrxml_fields(jrxml_file)
        expected = {
            "report_name": "TestReport",
            "page_width": "595",
            "page_height": "842",
        }
        for key, value in expected.items():
            assert parsed[key] == value

    def test_parses_jrxml_without_namespace(self, tmp_path):
        parsed = self._parse_text(tmp_path, "simple.jrxml", SAMPLE_JRXML_NO_NS)
        assert parsed["report_name"] == "SimpleReport"
        assert len(parsed["parameters"]) == 1

    def test_invalid_xml_returns_error(self, tmp_path):
        parsed = self._parse_text(tmp_path, "bad.jrxml", INVALID_XML)
        error = parsed["error"]
        assert error is not None
        assert "解析失败" in error

    def test_empty_jrxml_has_no_fields(self, tmp_path):
        parsed = self._parse_text(
            tmp_path, "empty.jrxml",
            '<?xml version="1.0"?><jasperReport name="Empty"/>')
        assert parsed["parameters"] == []
        assert parsed["fields"] == []

    def test_nonexistent_file_raises(self):
        missing = "/nonexistent/path.jrxml"
        with pytest.raises(FileNotFoundError):
            parse_jrxml_fields(missing)
# ── Table-based field extraction ────────────────────────────────
class TestExtractFieldsFromTable:
    """Tests for ``_extract_fields_from_table`` (markdown table -> field dicts)."""

    def test_extracts_from_markdown_table(self):
        text = """| 字段名 | 含义 | 必填 | 类型 |
|--------|------|------|------|
| billNo | 工单号 | 是 | String |
| amount | 金额 | 否 | BigDecimal |"""
        fields = _extract_fields_from_table(text)
        assert len(fields) == 2
        assert fields[0]["name"] == "billNo"
        assert fields[0]["description"] == "工单号"
        assert fields[0]["required"] is True
        assert fields[1]["name"] == "amount"

    def test_skips_separator_rows(self):
        text = """| 字段 | 说明 |
|------|------|
|------|------|
| name | 名称 |"""
        fields = _extract_fields_from_table(text)
        assert len(fields) == 1
        assert fields[0]["name"] == "name"

    def test_returns_empty_for_plain_text(self):
        fields = _extract_fields_from_table("这是一段普通文本,没有表格。")
        assert fields == []

    def test_cells_with_bold_markers_stripped(self):
        text = """| 名称 | 含义 |
|------|------|
| **billNo** | 单号 |"""
        fields = _extract_fields_from_table(text)
        # Check the row count first (as the sibling tests do) so a regression
        # fails with a clear assertion instead of an IndexError on fields[0].
        assert len(fields) == 1
        assert fields[0]["name"] == "billNo"
# ── LLM-based field extraction ──────────────────────────────────
class TestExtractFieldsWithLlm:
    """Tests for ``extract_fields_with_llm`` and its table-parsing fallback."""

    # One-row markdown table; the table fallback should yield a field named "code".
    TABLE_TEXT = "| 字段 | 说明 |\n|------|------|\n| code | 编码 |"

    def test_falls_back_to_table_when_no_llm(self):
        extracted = extract_fields_with_llm(self.TABLE_TEXT, llm=None)
        assert len(extracted) >= 1
        assert any(item["name"] == "code" for item in extracted)

    def test_uses_llm_when_provided(self):
        llm = MagicMock()
        # The stub LLM's invoke() returns an object whose .content is a JSON list.
        llm.invoke.return_value = MagicMock(
            content='[{"name": "id", "description": "ID", "type": "Long", "required": true}]')
        extracted = extract_fields_with_llm("some text", llm=llm)
        assert len(extracted) == 1
        assert extracted[0]["name"] == "id"

    def test_llm_failure_falls_back_to_table(self):
        llm = MagicMock()
        llm.invoke.side_effect = RuntimeError("LLM down")
        extracted = extract_fields_with_llm(self.TABLE_TEXT, llm=llm)
        assert any(item["name"] == "code" for item in extracted)
# ── File processing ─────────────────────────────────────────────
class TestProcessFileForKb:
    """Tests for ``process_file_for_kb`` (ingest one file into a knowledge base)."""

    def test_process_jrxml_copies_and_parses(self, jrxml_file, temp_kb):
        outcome = process_file_for_kb(temp_kb["kb_id"], jrxml_file)
        assert outcome["type"] == "jrxml"
        assert outcome["jrxml_info"]["report_name"] == "TestReport"
        assert outcome["error"] is None
        # Exactly one .jrxml file should now sit in the KB's raw directory.
        assert len(list(temp_kb["raw_dir"].glob("*.jrxml"))) == 1

    def test_process_nonexistent_kb_returns_error(self, jrxml_file):
        # NOTE(review): temp_kb is not requested here, so the lookup goes through
        # the real backend.kb_manager helpers. Confirm that is intended.
        unknown_kb_id = "deadbeef1234567890abcd"
        outcome = process_file_for_kb(unknown_kb_id, jrxml_file)
        assert outcome["error"] is not None

    def test_process_text_file(self, tmp_path, temp_kb):
        doc = tmp_path / "readme.md"
        doc.write_text("# 标题\n\n这是一段内容。\n\n另一段内容。", encoding="utf-8")
        stub_result = {"text": "parsed content", "error": None}
        with patch("backend.kb_parser.parse_file", return_value=stub_result):
            outcome = process_file_for_kb(temp_kb["kb_id"], str(doc))
        assert outcome["filename"] is not None
        assert outcome["error"] is None
# ── Chunking ────────────────────────────────────────────────────
class TestChunkFileResults:
    """Tests for ``chunk_file_results`` (parsed file results -> KB chunks)."""

    def test_jrxml_result_produces_template_chunk(self, jrxml_file):
        info = parse_jrxml_fields(jrxml_file)
        raw = Path(jrxml_file).read_text(encoding="utf-8")
        results = [{
            "filename": "test.jrxml", "type": "jrxml",
            "text": "text content", "raw_xml": raw,
            "jrxml_info": info, "error": None,
        }]
        chunks = chunk_file_results(results, kb_name="测试库")
        assert len(chunks) >= 1
        # .get(): a chunk whose metadata lacks chunk_type is simply not a
        # template; it must not abort the filter with KeyError.
        tmpl = [c for c in chunks
                if c["metadata"].get("chunk_type") == "jrxml_template"]
        assert len(tmpl) == 1
        assert tmpl[0]["metadata"]["report_name"] == "TestReport"
        assert "TestReport" in tmpl[0]["content"]

    def test_archive_result_recurses(self):
        results = [{
            "filename": "bundle.zip", "type": "archive", "text": "",
            "archive_contents": [
                {"filename": "inner.jrxml", "type": "jrxml",
                 "text": "inner text", "raw_xml": "<xml/>",
                 "jrxml_info": {"report_name": "Inner", "parameters": [], "fields": []},
                 "error": None},
            ], "error": None,
        }]
        chunks = chunk_file_results(results)
        # .get(): the recursion may also emit chunks without report_name in
        # their metadata; ignore those rather than failing with KeyError.
        assert any(c["metadata"].get("report_name") == "Inner" for c in chunks)

    def test_empty_text_skipped(self):
        results = [{"filename": "empty.md", "type": "md", "text": "", "error": None}]
        assert chunk_file_results(results) == []

    def test_short_paragraphs_skipped(self):
        results = [{"filename": "short.txt", "type": "txt", "text": "hi", "error": None}]
        assert chunk_file_results(results) == []

    def test_text_split_into_paragraphs(self):
        long_para = "A" * 50
        results = [
            {"filename": "doc.txt", "type": "txt",
             "text": f"{long_para}\n\n{long_para}\n\n{long_para}", "error": None},
        ]
        chunks = chunk_file_results(results)
        assert len(chunks) == 3
# ── _collect_from_result ────────────────────────────────────────
class TestCollectFromResult:
    """Tests for ``_collect_from_result`` (accumulate fields/templates from a result)."""

    def test_collects_jrxml_parameters_as_fields(self):
        fields, templates = [], []
        result = {
            "jrxml_info": {
                "report_name": "R1",
                "parameters": [{"name": "p1", "type": "String", "description": "参数1"}],
                "fields": [],
            },
            "filename": "r1.jrxml",
        }
        _collect_from_result(result, fields, templates)
        assert len(templates) == 1
        assert any(entry["name"] == "p1" for entry in fields)

    def test_collects_jrxml_fields(self):
        fields, templates = [], []
        result = {
            "jrxml_info": {
                "report_name": "R2",
                "parameters": [],
                "fields": [{"name": "f1", "type": "Double", "description": ""}],
            },
            "filename": "r2.jrxml",
        }
        _collect_from_result(result, fields, templates)
        assert any(entry["name"] == "f1" for entry in fields)

    def test_skips_non_jrxml(self):
        fields, templates = [], []
        _collect_from_result({"type": "csv", "filename": "data.csv"}, fields, templates)
        assert templates == []
        assert fields == []

    def test_deduplicates_fields(self):
        fields, templates = [], []
        # The same info dict is fed twice under different filenames.
        shared_info = {
            "report_name": "R",
            "parameters": [{"name": "dup", "type": "String", "description": ""}],
            "fields": [],
        }
        for filename in ("a.jrxml", "b.jrxml"):
            _collect_from_result({"jrxml_info": shared_info, "filename": filename},
                                 fields, templates)
        assert [entry["name"] for entry in fields].count("dup") == 1