fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: when given the full 34k-char JRXML, the LLM regenerated the report
from scratch instead of modifying coordinates in place, shrinking the output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
replacement. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
This commit is contained in:
@@ -0,0 +1,311 @@
|
||||
"""kb_parser.py 测试 — JRXML 解析, 文件处理, 分块, 字段提取。"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from backend.kb_parser import (
|
||||
parse_jrxml_fields, _extract_archive, process_file_for_kb,
|
||||
chunk_file_results, extract_fields_with_llm, _extract_fields_from_table,
|
||||
_find_tag, _find_all_tags, _collect_from_result, build_kb_from_files,
|
||||
)
|
||||
|
||||
# Namespaced JasperReports template used by most tests: two parameters (the
# first carries a parameterDescription), two fields, and a CDATA query string.
SAMPLE_JRXML = """<?xml version="1.0" encoding="UTF-8"?>
<jasperReport name="TestReport" pageWidth="595" pageHeight="842"
xmlns="http://jasperreports.sourceforge.net/jasperreports">
<parameter name="billNo" class="java.lang.String">
<parameterDescription>工单号</parameterDescription>
</parameter>
<parameter name="customerName" class="java.lang.String"/>
<field name="amount" class="java.math.BigDecimal"/>
<field name="createDate" class="java.sql.Date"/>
<queryString><![CDATA[SELECT * FROM orders WHERE bill_no = $P{billNo}]]></queryString>
</jasperReport>"""
|
||||
|
||||
# Same kind of template but without the JasperReports xmlns declaration:
# one parameter and one field.
SAMPLE_JRXML_NO_NS = """<?xml version="1.0" encoding="UTF-8"?>
<jasperReport name="SimpleReport" pageWidth="800" pageHeight="600">
<parameter name="title" class="java.lang.String"/>
<field name="name" class="java.lang.String"/>
</jasperReport>"""
|
||||
|
||||
# Deliberately malformed document: the <parameter> start tag is never closed,
# so an XML parser must reject it.
INVALID_XML = """<?xml version="1.0"?>
<jasperReport>
<parameter name="broken"
</jasperReport>"""
|
||||
|
||||
|
||||
@pytest.fixture
def jrxml_file():
    """Yield the path of a temporary ``.jrxml`` file containing SAMPLE_JRXML.

    The file is created with ``delete=False`` so that it survives being
    closed; it is unlinked explicitly once the consuming test has finished.
    """
    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".jrxml", delete=False, encoding="utf-8")
    with handle:
        handle.write(SAMPLE_JRXML)
    tmp_path_str = handle.name
    yield tmp_path_str
    os.unlink(tmp_path_str)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_kb(monkeypatch):
    """Provide a throw-away knowledge-base directory tree with stubbed lookups.

    Creates ``<tmp>/default/<kb_id>/raw`` and monkeypatches three
    ``backend.kb_manager`` functions so that only the fixture's kb id resolves
    to those paths (any other id yields ``None``) and metadata updates are
    no-ops.

    Yields:
        dict with keys ``kb_id``, ``kb_dir``, ``raw_dir`` and ``data_dir``.
    """
    known_id = "abcdef1234567890abcd"

    with tempfile.TemporaryDirectory(prefix="test_kb_parser_") as scratch:
        data_root = Path(scratch)
        kb_root = data_root / "default" / known_id
        raw_root = kb_root / "raw"
        raw_root.mkdir(parents=True)

        def fake_get_kb_raw_dir(kb_id):
            # Only the fixture's own knowledge base is considered to exist.
            if kb_id == known_id:
                return raw_root
            return None

        def fake_get_kb_chunks_path(kb_id):
            if kb_id == known_id:
                return kb_root / "chunks.json"
            return None

        def fake_update_kb_meta(kb_id, updates):
            # Metadata writes are swallowed.
            return None

        monkeypatch.setattr(
            "backend.kb_manager.get_kb_raw_dir", fake_get_kb_raw_dir)
        monkeypatch.setattr(
            "backend.kb_manager.get_kb_chunks_path", fake_get_kb_chunks_path)
        monkeypatch.setattr(
            "backend.kb_manager.update_kb_meta", fake_update_kb_meta)

        yield {
            "kb_id": known_id,
            "kb_dir": kb_root,
            "raw_dir": raw_root,
            "data_dir": data_root,
        }
|
||||
|
||||
|
||||
# ── JRXML 解析 ──────────────────────────────────────────────────
|
||||
|
||||
class TestParseJrxmlFields:
    """Tests for ``parse_jrxml_fields`` against the sample JRXML documents."""

    def test_parses_parameters(self, jrxml_file):
        """Both parameters are returned; the first keeps name, type, description."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert parsed["error"] is None

        params = parsed["parameters"]
        assert len(params) == 2

        first = params[0]
        assert first["name"] == "billNo"
        assert first["type"] == "java.lang.String"
        assert first["description"] == "工单号"

    def test_parses_fields(self, jrxml_file):
        """Both <field> declarations are reported by name."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert len(parsed["fields"]) == 2

        names = [entry["name"] for entry in parsed["fields"]]
        for expected in ("amount", "createDate"):
            assert expected in names

    def test_parses_query(self, jrxml_file):
        """The query text from the CDATA section is exposed under ``query``."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert "SELECT * FROM orders" in parsed["query"]

    def test_parses_report_metadata(self, jrxml_file):
        """Report name and page dimensions are returned as strings."""
        parsed = parse_jrxml_fields(jrxml_file)
        assert parsed["report_name"] == "TestReport"
        assert parsed["page_width"] == "595"
        assert parsed["page_height"] == "842"

    def test_parses_jrxml_without_namespace(self, tmp_path):
        """A document without an xmlns declaration is parsed as well."""
        source = tmp_path / "simple.jrxml"
        source.write_text(SAMPLE_JRXML_NO_NS, encoding="utf-8")

        parsed = parse_jrxml_fields(str(source))
        assert parsed["report_name"] == "SimpleReport"
        assert len(parsed["parameters"]) == 1

    def test_invalid_xml_returns_error(self, tmp_path):
        """Malformed XML yields an ``error`` message instead of raising."""
        source = tmp_path / "bad.jrxml"
        source.write_text(INVALID_XML, encoding="utf-8")

        parsed = parse_jrxml_fields(str(source))
        message = parsed["error"]
        assert message is not None
        assert "解析失败" in message

    def test_empty_jrxml_has_no_fields(self, tmp_path):
        """A bare <jasperReport/> element produces empty parameter/field lists."""
        source = tmp_path / "empty.jrxml"
        document = '<?xml version="1.0"?><jasperReport name="Empty"/>'
        source.write_text(document, encoding="utf-8")

        parsed = parse_jrxml_fields(str(source))
        assert parsed["parameters"] == []
        assert parsed["fields"] == []

    def test_nonexistent_file_raises(self):
        """A missing path propagates ``FileNotFoundError``."""
        missing = "/nonexistent/path.jrxml"
        with pytest.raises(FileNotFoundError):
            parse_jrxml_fields(missing)
|
||||
|
||||
|
||||
# ── 表格字段提取 ────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFieldsFromTable:
    """Tests for ``_extract_fields_from_table`` on Markdown-style tables."""

    def test_extracts_from_markdown_table(self):
        """Each data row becomes one field with name, description, required."""
        table = (
            "| 字段名 | 含义 | 必填 | 类型 |\n"
            "|--------|------|------|------|\n"
            "| billNo | 工单号 | 是 | String |\n"
            "| amount | 金额 | 否 | BigDecimal |"
        )
        extracted = _extract_fields_from_table(table)
        assert len(extracted) == 2

        first, second = extracted[0], extracted[1]
        assert first["name"] == "billNo"
        assert first["description"] == "工单号"
        assert first["required"] is True
        assert second["name"] == "amount"

    def test_skips_separator_rows(self):
        """Separator rows (even repeated ones) never turn into fields."""
        table = (
            "| 字段 | 说明 |\n"
            "|------|------|\n"
            "|------|------|\n"
            "| name | 名称 |"
        )
        extracted = _extract_fields_from_table(table)
        assert len(extracted) == 1
        assert extracted[0]["name"] == "name"

    def test_returns_empty_for_plain_text(self):
        """Text without any table yields an empty list."""
        extracted = _extract_fields_from_table("这是一段普通文本,没有表格。")
        assert extracted == []

    def test_cells_with_bold_markers_stripped(self):
        """Markdown bold markers around a cell value are removed."""
        table = (
            "| 名称 | 含义 |\n"
            "|------|------|\n"
            "| **billNo** | 单号 |"
        )
        extracted = _extract_fields_from_table(table)
        assert extracted[0]["name"] == "billNo"
|
||||
|
||||
|
||||
# ── LLM 字段提取 ────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFieldsWithLlm:
    """Tests for ``extract_fields_with_llm`` with and without a usable LLM."""

    def test_falls_back_to_table_when_no_llm(self):
        """With ``llm=None`` the table content is still turned into fields."""
        table = "| 字段 | 说明 |\n|------|------|\n| code | 编码 |"
        extracted = extract_fields_with_llm(table, llm=None)
        assert len(extracted) >= 1
        assert any(item["name"] == "code" for item in extracted)

    def test_uses_llm_when_provided(self):
        """The JSON list in the LLM reply's ``content`` is returned as fields."""
        reply = MagicMock(
            content='[{"name": "id", "description": "ID", "type": "Long", "required": true}]')
        fake_llm = MagicMock()
        fake_llm.invoke.return_value = reply

        extracted = extract_fields_with_llm("some text", llm=fake_llm)
        assert len(extracted) == 1
        assert extracted[0]["name"] == "id"

    def test_llm_failure_falls_back_to_table(self):
        """When ``invoke`` raises, the table content is still extracted."""
        fake_llm = MagicMock()
        fake_llm.invoke.side_effect = RuntimeError("LLM down")

        table = "| 字段 | 说明 |\n|------|------|\n| code | 编码 |"
        extracted = extract_fields_with_llm(table, llm=fake_llm)
        assert any(item["name"] == "code" for item in extracted)
|
||||
|
||||
|
||||
# ── 文件处理 ────────────────────────────────────────────────────
|
||||
|
||||
class TestProcessFileForKb:
    """Tests for ``process_file_for_kb`` using the ``temp_kb`` sandbox."""

    def test_process_jrxml_copies_and_parses(self, jrxml_file, temp_kb):
        """A JRXML upload is parsed and exactly one copy lands in the raw dir."""
        outcome = process_file_for_kb(temp_kb["kb_id"], jrxml_file)
        assert outcome["type"] == "jrxml"
        assert outcome["jrxml_info"]["report_name"] == "TestReport"
        assert outcome["error"] is None

        stored = list(temp_kb["raw_dir"].glob("*.jrxml"))
        assert len(stored) == 1

    def test_process_nonexistent_kb_returns_error(self, jrxml_file):
        """An unknown knowledge-base id is reported through ``error``."""
        outcome = process_file_for_kb("deadbeef1234567890abcd", jrxml_file)
        assert outcome["error"] is not None

    def test_process_text_file(self, tmp_path, temp_kb):
        """A Markdown file is processed without error when parse_file is stubbed."""
        doc = tmp_path / "readme.md"
        doc.write_text("# 标题\n\n这是一段内容。\n\n另一段内容。", encoding="utf-8")

        stub_reply = {"text": "parsed content", "error": None}
        with patch("backend.kb_parser.parse_file", return_value=stub_reply):
            outcome = process_file_for_kb(temp_kb["kb_id"], str(doc))

        assert outcome["filename"] is not None
        assert outcome["error"] is None
|
||||
|
||||
|
||||
# ── 分块 ────────────────────────────────────────────────────────
|
||||
|
||||
class TestChunkFileResults:
    """Tests for ``chunk_file_results`` on JRXML, archive and text results."""

    def test_jrxml_result_produces_template_chunk(self, jrxml_file):
        """A JRXML result yields exactly one ``jrxml_template`` chunk."""
        jrxml_info = parse_jrxml_fields(jrxml_file)
        raw_xml = Path(jrxml_file).read_text(encoding="utf-8")
        entry = {
            "filename": "test.jrxml",
            "type": "jrxml",
            "text": "text content",
            "raw_xml": raw_xml,
            "jrxml_info": jrxml_info,
            "error": None,
        }

        produced = chunk_file_results([entry], kb_name="测试库")
        assert len(produced) >= 1

        templates = [
            chunk for chunk in produced
            if chunk["metadata"]["chunk_type"] == "jrxml_template"
        ]
        assert len(templates) == 1

        only = templates[0]
        assert only["metadata"]["report_name"] == "TestReport"
        assert "TestReport" in only["content"]

    def test_archive_result_recurses(self):
        """Entries nested under ``archive_contents`` are chunked too."""
        nested = {
            "filename": "inner.jrxml",
            "type": "jrxml",
            "text": "inner text",
            "raw_xml": "<xml/>",
            "jrxml_info": {"report_name": "Inner", "parameters": [], "fields": []},
            "error": None,
        }
        archive = {
            "filename": "bundle.zip",
            "type": "archive",
            "text": "",
            "archive_contents": [nested],
            "error": None,
        }

        produced = chunk_file_results([archive])
        assert any(
            chunk["metadata"]["report_name"] == "Inner" for chunk in produced)

    def test_empty_text_skipped(self):
        """A result whose text is empty contributes no chunks."""
        entry = {"filename": "empty.md", "type": "md", "text": "", "error": None}
        assert chunk_file_results([entry]) == []

    def test_short_paragraphs_skipped(self):
        """A two-character text produces no chunks."""
        entry = {"filename": "short.txt", "type": "txt", "text": "hi", "error": None}
        assert chunk_file_results([entry]) == []

    def test_text_split_into_paragraphs(self):
        """Three 50-character paragraphs separated by blank lines give 3 chunks."""
        paragraph = "A" * 50
        body = "\n\n".join([paragraph, paragraph, paragraph])
        entry = {"filename": "doc.txt", "type": "txt", "text": body, "error": None}

        produced = chunk_file_results([entry])
        assert len(produced) == 3
|
||||
|
||||
|
||||
# ── _collect_from_result ────────────────────────────────────────
|
||||
|
||||
class TestCollectFromResult:
    """Tests for ``_collect_from_result``, which fills the passed-in lists."""

    def test_collects_jrxml_parameters_as_fields(self):
        """Parameters of a JRXML result are collected; one template is added."""
        collected_fields, collected_templates = [], []
        result = {
            "jrxml_info": {
                "report_name": "R1",
                "parameters": [{"name": "p1", "type": "String", "description": "参数1"}],
                "fields": [],
            },
            "filename": "r1.jrxml",
        }

        _collect_from_result(result, collected_fields, collected_templates)

        assert len(collected_templates) == 1
        assert any(item["name"] == "p1" for item in collected_fields)

    def test_collects_jrxml_fields(self):
        """Field declarations of a JRXML result are collected."""
        collected_fields, collected_templates = [], []
        result = {
            "jrxml_info": {
                "report_name": "R2",
                "parameters": [],
                "fields": [{"name": "f1", "type": "Double", "description": ""}],
            },
            "filename": "r2.jrxml",
        }

        _collect_from_result(result, collected_fields, collected_templates)

        assert any(item["name"] == "f1" for item in collected_fields)

    def test_skips_non_jrxml(self):
        """A result without ``jrxml_info`` leaves both lists untouched."""
        collected_fields, collected_templates = [], []
        result = {"type": "csv", "filename": "data.csv"}

        _collect_from_result(result, collected_fields, collected_templates)

        assert collected_templates == []
        assert collected_fields == []

    def test_deduplicates_fields(self):
        """The same parameter seen in two files is recorded only once."""
        collected_fields, collected_templates = [], []
        shared_info = {
            "report_name": "R",
            "parameters": [{"name": "dup", "type": "String", "description": ""}],
            "fields": [],
        }

        for filename in ("a.jrxml", "b.jrxml"):
            _collect_from_result(
                {"jrxml_info": shared_info, "filename": filename},
                collected_fields, collected_templates)

        names = [item["name"] for item in collected_fields]
        assert names.count("dup") == 1
|
||||
Reference in New Issue
Block a user