feat: layered precise generation for A4 report images
3-phase pipeline to solve LLM prompt overflow from too many OCR elements:
Phase 1 (generate_skeleton): compressed layout schema → skeleton JRXML
Phase 2 (refine_layout): sampled coordinates → pixel-level position tuning
Phase 3 (map_fields): OCR field names → replace $F{field_N} placeholders
Only triggered when layout_schema.total_rows > 0 on initial_generation intent.
Text requests and all other intents are unaffected (zero behavior change).
This commit is contained in:
@@ -0,0 +1,267 @@
|
||||
"""测试分层精确生成:extract_layout_schema, _format_row_coordinates, 路由逻辑。"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# 确保项目根在 path 中
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from backend.layout_analyzer import extract_layout_schema
|
||||
from agent.nodes import _format_row_coordinates
|
||||
from agent.graph import route_after_retrieve
|
||||
from agent.state import AgentState
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 测试夹具
|
||||
# ============================================================
|
||||
|
||||
def _make_row(elements: list[dict]) -> dict:
|
||||
y_center = sum(e["y"] + e["h"] / 2 for e in elements) / len(elements) if elements else 0
|
||||
return {"y_center": round(y_center, 1), "elements": elements}
|
||||
|
||||
|
||||
def _make_elem(x: float, y: float, w: float, h: float, text: str) -> dict:
|
||||
return {"x": x, "y": y, "w": w, "h": h, "font_size": h, "text": text}
|
||||
|
||||
|
||||
def _make_5col_10row_layout() -> dict:
|
||||
"""构造一个标准的 5列 x 10行 报表布局。"""
|
||||
rows = []
|
||||
for ri in range(10):
|
||||
y = 100 + ri * 30
|
||||
if ri == 0:
|
||||
texts = ["员工名册"]
|
||||
xs = [300]
|
||||
ws = [100]
|
||||
elif ri == 1:
|
||||
texts = ["序号", "姓名", "部门", "职位", "入职日期"]
|
||||
xs = [50, 150, 300, 450, 550]
|
||||
ws = [40, 80, 100, 80, 80]
|
||||
else:
|
||||
texts = [str(ri - 1), f"员工{ri-1}", "技术部", "工程师", "2024-01-01"]
|
||||
xs = [50, 150, 300, 450, 550]
|
||||
ws = [40, 80, 100, 80, 80]
|
||||
elements = [_make_elem(xs[i], y, ws[i], 24, texts[i]) for i in range(len(texts))]
|
||||
rows.append(_make_row(elements))
|
||||
return {"rows": rows, "image_size": (595, 842), "total_rows": 10, "total_elements": 46}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestExtractLayoutSchema
|
||||
# ============================================================
|
||||
|
||||
class TestExtractLayoutSchema:
|
||||
"""extract_layout_schema() 单元测试。"""
|
||||
|
||||
def test_basic_table(self):
|
||||
"""5列x10行 布局 → 正确列数和区域分类。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
assert schema["total_columns"] == 5
|
||||
assert schema["total_rows"] == 10
|
||||
assert len(schema["columns"]) == 5
|
||||
assert len(schema["regions"]) >= 3 # title + header + data at minimum
|
||||
|
||||
def test_title_detection(self):
|
||||
"""第 0 行只有 1 个元素 → 标题区域。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
region_types = [r["type"] for r in schema["regions"]]
|
||||
assert "title" in region_types
|
||||
|
||||
def test_footer_detection(self):
|
||||
"""尾部行元素更少 → 表尾区域。"""
|
||||
rows = []
|
||||
for ri in range(8):
|
||||
y = 100 + ri * 30
|
||||
elements = [_make_elem(50 + i * 100, y, 80, 24, f"col{i}") for i in range(5)]
|
||||
rows.append(_make_row(elements))
|
||||
# 最后一行只有 1 个元素(如合计)
|
||||
rows.append(_make_row([_make_elem(300, 340, 100, 24, "合计: 100")]))
|
||||
layout = {"rows": rows, "image_size": (595, 842), "total_rows": 9, "total_elements": 41}
|
||||
schema = extract_layout_schema(layout)
|
||||
region_types = [r["type"] for r in schema["regions"]]
|
||||
assert "footer" in region_types
|
||||
|
||||
def test_empty_layout(self):
|
||||
"""空行 → 返回零值 schema。"""
|
||||
schema = extract_layout_schema({"rows": [], "image_size": (0, 0)})
|
||||
assert schema["total_rows"] == 0
|
||||
assert schema["total_columns"] == 0
|
||||
assert schema["columns"] == []
|
||||
assert "无法解析" in schema["schema_text"]
|
||||
|
||||
def test_single_row(self):
|
||||
"""单行 → 全部为 data。"""
|
||||
layout = {
|
||||
"rows": [_make_row([_make_elem(10, 10, 50, 20, "test")])],
|
||||
"image_size": (200, 200),
|
||||
}
|
||||
schema = extract_layout_schema(layout)
|
||||
assert schema["total_rows"] == 1
|
||||
assert schema["total_columns"] == 1
|
||||
|
||||
def test_two_rows(self):
|
||||
"""两行 → header + data。"""
|
||||
rows = [
|
||||
_make_row([_make_elem(10, 10, 50, 20, "表头")]),
|
||||
_make_row([_make_elem(10, 40, 50, 20, "数据")]),
|
||||
]
|
||||
schema = extract_layout_schema({"rows": rows, "image_size": (200, 200)})
|
||||
assert schema["total_rows"] == 2
|
||||
|
||||
def test_schema_text_format(self):
|
||||
"""schema_text 包含中文标签(列定义、区域)。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
assert "列定义" in schema["schema_text"]
|
||||
assert "区域" in schema["schema_text"]
|
||||
assert "A4纵向" in schema["schema_text"]
|
||||
|
||||
def test_column_width_categories(self):
|
||||
"""宽度分类:窄/中/宽 标签存在。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
text = schema["schema_text"]
|
||||
assert any(w in text for w in ("窄", "中", "宽"))
|
||||
|
||||
def test_a4_dimensions(self):
|
||||
"""A4 尺寸固定为 595x842。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
assert schema["a4_dimensions"] == {"width": 595, "height": 842}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestFormatRowCoordinates
|
||||
# ============================================================
|
||||
|
||||
class TestFormatRowCoordinates:
|
||||
"""_format_row_coordinates() 单元测试。"""
|
||||
|
||||
def test_formats_single_row(self):
|
||||
"""单行 → columns 列表包含正确字段。"""
|
||||
row = _make_row([
|
||||
_make_elem(10, 100, 50, 20, "序号"),
|
||||
_make_elem(60, 100, 80, 20, "姓名"),
|
||||
])
|
||||
result = _format_row_coordinates(row)
|
||||
assert len(result["columns"]) == 2
|
||||
assert result["y_center"] > 0
|
||||
assert "col" in result["columns"][0]
|
||||
assert "text" in result["columns"][0]
|
||||
|
||||
def test_sorts_by_x(self):
|
||||
"""元素按 x 坐标从左到右排序。"""
|
||||
row = _make_row([
|
||||
_make_elem(200, 100, 80, 20, "right"),
|
||||
_make_elem(10, 100, 50, 20, "left"),
|
||||
])
|
||||
result = _format_row_coordinates(row)
|
||||
assert result["columns"][0]["text"] == "left"
|
||||
assert result["columns"][1]["text"] == "right"
|
||||
|
||||
def test_empty_row(self):
|
||||
"""空元素列表 → columns 为空。"""
|
||||
result = _format_row_coordinates({"y_center": 100, "elements": []})
|
||||
assert result["columns"] == []
|
||||
|
||||
def test_non_dict_input(self):
|
||||
"""非 dict 输入 → 返回空 dict。"""
|
||||
assert _format_row_coordinates(None) == {}
|
||||
assert _format_row_coordinates("not a dict") == {}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestRouting
|
||||
# ============================================================
|
||||
|
||||
class TestRouting:
|
||||
"""route_after_retrieve() 路由逻辑测试。"""
|
||||
|
||||
def _state(self, **kwargs) -> AgentState:
|
||||
s = {"current_jrxml": "", "user_input": "test", "conversation_history": []}
|
||||
s.update(kwargs)
|
||||
return s
|
||||
|
||||
def test_with_schema(self):
|
||||
"""layout_schema 有行 → generate_skeleton。"""
|
||||
state = self._state(layout_schema={"total_rows": 5, "total_columns": 3})
|
||||
assert route_after_retrieve(state) == "generate_skeleton"
|
||||
|
||||
def test_without_schema(self):
|
||||
"""无 layout_schema → generate。"""
|
||||
state = self._state()
|
||||
assert route_after_retrieve(state) == "generate"
|
||||
|
||||
def test_empty_schema(self):
|
||||
"""空 dict → generate。"""
|
||||
state = self._state(layout_schema={})
|
||||
assert route_after_retrieve(state) == "generate"
|
||||
|
||||
def test_zero_rows(self):
|
||||
"""total_rows = 0 → generate。"""
|
||||
state = self._state(layout_schema={"total_rows": 0, "total_columns": 0})
|
||||
assert route_after_retrieve(state) == "generate"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestIntegration
|
||||
# ============================================================
|
||||
|
||||
class TestIntegration:
|
||||
"""集成测试:验证图执行路径。"""
|
||||
|
||||
def test_three_phase_nodes_exist(self):
|
||||
"""三个新节点可被导入。"""
|
||||
from agent.nodes import generate_skeleton, refine_layout, map_fields
|
||||
assert callable(generate_skeleton)
|
||||
assert callable(refine_layout)
|
||||
assert callable(map_fields)
|
||||
|
||||
def test_graph_builds_with_new_nodes(self):
|
||||
"""图构建成功包含新节点。"""
|
||||
from agent.graph import build_graph
|
||||
graph = build_graph()
|
||||
nodes = graph.get_graph().nodes
|
||||
node_names = {n for n in nodes}
|
||||
assert "generate_skeleton" in node_names
|
||||
assert "refine_layout" in node_names
|
||||
assert "map_fields" in node_names
|
||||
|
||||
def test_one_shot_path_unchanged(self):
|
||||
"""无 schema 时 route_after_retrieve → generate。"""
|
||||
from agent.graph import route_after_retrieve as rar
|
||||
state = {"layout_schema": {}, "current_jrxml": "", "user_input": "test"}
|
||||
assert rar(state) == "generate"
|
||||
|
||||
def test_modify_report_unaffected(self):
|
||||
"""modify_report 意图不受影响 — route_by_intent 不变。"""
|
||||
from agent.graph import route_by_intent
|
||||
state = {
|
||||
"intent": "modify_report",
|
||||
"current_jrxml": "<jasperReport>...</jasperReport>",
|
||||
"layout_schema": {"total_rows": 5},
|
||||
}
|
||||
assert route_by_intent(state) == "modify_jrxml"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestLayoutSchemaJSONRoundTrip
|
||||
# ============================================================
|
||||
|
||||
class TestLayoutSchemaJSONRoundTrip:
|
||||
"""验证 schema 可被 JSON 序列化/反序列化(用于会话持久化)。"""
|
||||
|
||||
def test_serializable(self):
|
||||
"""extract_layout_schema 输出可 JSON 序列化。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
dumped = json.dumps(schema, ensure_ascii=False)
|
||||
loaded = json.loads(dumped)
|
||||
assert loaded["total_rows"] == schema["total_rows"]
|
||||
assert loaded["total_columns"] == schema["total_columns"]
|
||||
Reference in New Issue
Block a user