Files
agent_jrxml/tests/test_layered_generation.py
T
panda 43a0542a11 feat: layered precise generation for A4 report images
3-phase pipeline to solve LLM prompt overflow from too many OCR elements:
Phase 1 (generate_skeleton): compressed layout schema → skeleton JRXML
Phase 2 (refine_layout): sampled coordinates → pixel-level position tuning
Phase 3 (map_fields): OCR field names → replace $F{field_N} placeholders

Only triggered when layout_schema.total_rows > 0 on initial_generation intent.
Text requests and all other intents are unaffected (zero behavior change).
2026-05-21 08:34:32 +08:00

268 lines
10 KiB
Python

"""测试分层精确生成:extract_layout_schema, _format_row_coordinates, 路由逻辑。"""
import json
import sys
from pathlib import Path
import pytest
# 确保项目根在 path 中
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from backend.layout_analyzer import extract_layout_schema
from agent.nodes import _format_row_coordinates
from agent.graph import route_after_retrieve
from agent.state import AgentState
# ============================================================
# 测试夹具
# ============================================================
def _make_row(elements: list[dict]) -> dict:
    """Build a layout-row dict from a list of element dicts.

    Args:
        elements: Element dicts; each one must provide numeric ``"y"`` and
            ``"h"`` entries.

    Returns:
        A dict with two keys: ``"y_center"``, the mean of every element's
        vertical midpoint (``y + h / 2``) rounded to one decimal place, or
        ``0`` when *elements* is empty; and ``"elements"``, the very list
        that was passed in (not copied).
    """
    if elements:
        midpoints = (e["y"] + e["h"] / 2 for e in elements)
        y_center = sum(midpoints) / len(elements)
    else:
        # An empty row has no midpoint; keep 0 so no division by zero occurs.
        y_center = 0
    return {"y_center": round(y_center, 1), "elements": elements}
def _make_elem(x: float, y: float, w: float, h: float, text: str) -> dict:
    """Build a single element dict for a test layout.

    Args:
        x: Value stored under ``"x"``.
        y: Value stored under ``"y"``.
        w: Value stored under ``"w"``.
        h: Value stored under ``"h"``; it is also reused as ``"font_size"``.
        text: Value stored under ``"text"``.

    Returns:
        A dict with the keys ``x``, ``y``, ``w``, ``h``, ``font_size`` and
        ``text``.
    """
    return {"x": x, "y": y, "w": w, "h": h, "font_size": h, "text": text}
def _make_5col_10row_layout() -> dict:
    """Build a 10-row report layout fixture with a 5-column table.

    Row 0 holds a single title element, row 1 holds the five header cells,
    and rows 2-9 hold five data cells each (1 + 5 * 9 = 46 elements). Rows
    are spaced 30 units apart starting at y=100 and every element is 24 high.

    Returns:
        A dict with ``"rows"`` (list of row dicts from ``_make_row``),
        ``"image_size"`` ``(595, 842)``, ``"total_rows"`` ``10`` and
        ``"total_elements"`` ``46``.
    """
    # Header and data rows share the same column geometry.
    column_xs = [50, 150, 300, 450, 550]
    column_ws = [40, 80, 100, 80, 80]

    rows = []
    for ri in range(10):
        y = 100 + ri * 30
        if ri == 0:
            texts, xs, ws = ["员工名册"], [300], [100]
        elif ri == 1:
            texts = ["序号", "姓名", "部门", "职位", "入职日期"]
            xs, ws = column_xs, column_ws
        else:
            texts = [str(ri - 1), f"员工{ri-1}", "技术部", "工程师", "2024-01-01"]
            xs, ws = column_xs, column_ws
        elements = [
            _make_elem(x, y, w, 24, text) for x, w, text in zip(xs, ws, texts)
        ]
        rows.append(_make_row(elements))
    return {
        "rows": rows,
        "image_size": (595, 842),
        "total_rows": 10,
        "total_elements": 46,
    }
# ============================================================
# TestExtractLayoutSchema
# ============================================================
class TestExtractLayoutSchema:
    """Unit tests for ``extract_layout_schema()``."""

    def test_basic_table(self):
        """The 5-column x 10-row fixture yields 5 columns, 10 rows, >= 3 regions."""
        layout = _make_5col_10row_layout()
        schema = extract_layout_schema(layout)
        assert schema["total_columns"] == 5
        assert schema["total_rows"] == 10
        assert len(schema["columns"]) == 5
        assert len(schema["regions"]) >= 3  # title + header + data at minimum

    def test_title_detection(self):
        """Row 0 has a single element, so a ``title`` region is expected."""
        layout = _make_5col_10row_layout()
        schema = extract_layout_schema(layout)
        region_types = [r["type"] for r in schema["regions"]]
        assert "title" in region_types

    def test_footer_detection(self):
        """A trailing row with fewer elements is expected to become a ``footer``."""
        rows = []
        for ri in range(8):
            y = 100 + ri * 30
            elements = [
                _make_elem(50 + i * 100, y, 80, 24, f"col{i}") for i in range(5)
            ]
            rows.append(_make_row(elements))
        # The last row has only one element (e.g. a totals line).
        rows.append(_make_row([_make_elem(300, 340, 100, 24, "合计: 100")]))
        layout = {
            "rows": rows,
            "image_size": (595, 842),
            "total_rows": 9,
            "total_elements": 41,
        }
        schema = extract_layout_schema(layout)
        region_types = [r["type"] for r in schema["regions"]]
        assert "footer" in region_types

    def test_empty_layout(self):
        """No rows -> a zero-valued schema whose text reports a parse failure."""
        schema = extract_layout_schema({"rows": [], "image_size": (0, 0)})
        assert schema["total_rows"] == 0
        assert schema["total_columns"] == 0
        assert schema["columns"] == []
        assert "无法解析" in schema["schema_text"]

    def test_single_row(self):
        """Single-row layout (expected to be all data); only totals are asserted."""
        layout = {
            "rows": [_make_row([_make_elem(10, 10, 50, 20, "test")])],
            "image_size": (200, 200),
        }
        schema = extract_layout_schema(layout)
        assert schema["total_rows"] == 1
        assert schema["total_columns"] == 1

    def test_two_rows(self):
        """Two rows (header + data expected); only ``total_rows`` is asserted."""
        rows = [
            _make_row([_make_elem(10, 10, 50, 20, "表头")]),
            _make_row([_make_elem(10, 40, 50, 20, "数据")]),
        ]
        schema = extract_layout_schema({"rows": rows, "image_size": (200, 200)})
        assert schema["total_rows"] == 2

    def test_schema_text_format(self):
        """``schema_text`` contains the column, region and A4-portrait labels."""
        layout = _make_5col_10row_layout()
        schema = extract_layout_schema(layout)
        assert "列定义" in schema["schema_text"]
        assert "区域" in schema["schema_text"]
        assert "A4纵向" in schema["schema_text"]

    def test_column_width_categories(self):
        """At least one width-category label (narrow/medium/wide) is present."""
        layout = _make_5col_10row_layout()
        schema = extract_layout_schema(layout)
        text = schema["schema_text"]
        # The labels must be non-empty: `"" in text` is always True, which
        # would make this assertion pass for any schema_text whatsoever.
        assert any(w in text for w in ("窄", "中", "宽"))

    def test_a4_dimensions(self):
        """The reported A4 dimensions are fixed at 595 x 842."""
        layout = _make_5col_10row_layout()
        schema = extract_layout_schema(layout)
        assert schema["a4_dimensions"] == {"width": 595, "height": 842}
# ============================================================
# TestFormatRowCoordinates
# ============================================================
class TestFormatRowCoordinates:
    """Unit tests for ``_format_row_coordinates()``."""

    def test_formats_single_row(self):
        """One row -> ``columns`` has an entry per element with ``col``/``text``."""
        row = _make_row([
            _make_elem(10, 100, 50, 20, "序号"),
            _make_elem(60, 100, 80, 20, "姓名"),
        ])
        result = _format_row_coordinates(row)
        assert len(result["columns"]) == 2
        assert result["y_center"] > 0
        assert "col" in result["columns"][0]
        assert "text" in result["columns"][0]

    def test_sorts_by_x(self):
        """Elements come back ordered left-to-right by their x coordinate."""
        # Deliberately supplied right-first to prove the function reorders.
        row = _make_row([
            _make_elem(200, 100, 80, 20, "right"),
            _make_elem(10, 100, 50, 20, "left"),
        ])
        result = _format_row_coordinates(row)
        assert result["columns"][0]["text"] == "left"
        assert result["columns"][1]["text"] == "right"

    def test_empty_row(self):
        """An empty element list -> ``columns`` is empty."""
        result = _format_row_coordinates({"y_center": 100, "elements": []})
        assert result["columns"] == []

    def test_non_dict_input(self):
        """Non-dict input -> an empty dict is returned."""
        assert _format_row_coordinates(None) == {}
        assert _format_row_coordinates("not a dict") == {}
# ============================================================
# TestRouting
# ============================================================
class TestRouting:
    """Tests for the ``route_after_retrieve()`` routing decision."""

    def _state(self, **kwargs) -> AgentState:
        """Return a minimal state dict, with *kwargs* overriding the defaults.

        Args:
            **kwargs: Extra or overriding state entries (e.g. ``layout_schema``).

        Returns:
            A plain dict containing ``current_jrxml``, ``user_input`` and
            ``conversation_history`` plus anything supplied in *kwargs*.
        """
        defaults = {"current_jrxml": "", "user_input": "test", "conversation_history": []}
        return {**defaults, **kwargs}

    def test_with_schema(self):
        """A ``layout_schema`` with rows routes to ``generate_skeleton``."""
        state = self._state(layout_schema={"total_rows": 5, "total_columns": 3})
        assert route_after_retrieve(state) == "generate_skeleton"

    def test_without_schema(self):
        """No ``layout_schema`` key routes to ``generate``."""
        state = self._state()
        assert route_after_retrieve(state) == "generate"

    def test_empty_schema(self):
        """An empty ``layout_schema`` dict routes to ``generate``."""
        state = self._state(layout_schema={})
        assert route_after_retrieve(state) == "generate"

    def test_zero_rows(self):
        """``total_rows == 0`` routes to ``generate``."""
        state = self._state(layout_schema={"total_rows": 0, "total_columns": 0})
        assert route_after_retrieve(state) == "generate"
# ============================================================
# TestIntegration
# ============================================================
class TestIntegration:
    """Integration tests covering node availability and graph routing paths."""

    def test_three_phase_nodes_exist(self):
        """The three new phase nodes can be imported and are callable."""
        from agent.nodes import generate_skeleton, refine_layout, map_fields

        assert callable(generate_skeleton)
        assert callable(refine_layout)
        assert callable(map_fields)

    def test_graph_builds_with_new_nodes(self):
        """The built graph exposes the three new nodes by name."""
        from agent.graph import build_graph

        graph = build_graph()
        # Iterating the nodes collection yields the node names.
        node_names = set(graph.get_graph().nodes)
        assert "generate_skeleton" in node_names
        assert "refine_layout" in node_names
        assert "map_fields" in node_names

    def test_one_shot_path_unchanged(self):
        """With an empty schema, ``route_after_retrieve`` still picks ``generate``."""
        from agent.graph import route_after_retrieve as rar

        state = {"layout_schema": {}, "current_jrxml": "", "user_input": "test"}
        assert rar(state) == "generate"

    def test_modify_report_unaffected(self):
        """``modify_report`` intent still routes to ``modify_jrxml`` despite a schema."""
        from agent.graph import route_by_intent

        state = {
            "intent": "modify_report",
            "current_jrxml": "<jasperReport>...</jasperReport>",
            "layout_schema": {"total_rows": 5},
        }
        assert route_by_intent(state) == "modify_jrxml"
# ============================================================
# TestLayoutSchemaJSONRoundTrip
# ============================================================
class TestLayoutSchemaJSONRoundTrip:
    """Check that a schema survives a JSON dump/load cycle (session persistence)."""

    def test_serializable(self):
        """``extract_layout_schema`` output serializes to JSON and loads back."""
        layout = _make_5col_10row_layout()
        schema = extract_layout_schema(layout)
        # ensure_ascii=False keeps the non-ASCII text unescaped in the dump.
        dumped = json.dumps(schema, ensure_ascii=False)
        loaded = json.loads(dumped)
        assert loaded["total_rows"] == schema["total_rows"]
        assert loaded["total_columns"] == schema["total_columns"]