43a0542a11
Three-phase pipeline that prevents LLM prompt overflow when OCR produces too many elements:
Phase 1 (generate_skeleton): compressed layout schema → skeleton JRXML
Phase 2 (refine_layout): sampled coordinates → pixel-level position tuning
Phase 3 (map_fields): OCR field names replace the $F{field_N} placeholders
Triggered only for the initial_generation intent when layout_schema.total_rows > 0.
Text requests and all other intents are unaffected (no behavior change).
268 lines
10 KiB
Python
268 lines
10 KiB
Python
"""测试分层精确生成:extract_layout_schema, _format_row_coordinates, 路由逻辑。"""
|
|
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
# 确保项目根在 path 中
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from backend.layout_analyzer import extract_layout_schema
|
|
from agent.nodes import _format_row_coordinates
|
|
from agent.graph import route_after_retrieve
|
|
from agent.state import AgentState
|
|
|
|
|
|
# ============================================================
|
|
# 测试夹具
|
|
# ============================================================
|
|
|
|
def _make_row(elements: list[dict]) -> dict:
|
|
y_center = sum(e["y"] + e["h"] / 2 for e in elements) / len(elements) if elements else 0
|
|
return {"y_center": round(y_center, 1), "elements": elements}
|
|
|
|
|
|
def _make_elem(x: float, y: float, w: float, h: float, text: str) -> dict:
|
|
return {"x": x, "y": y, "w": w, "h": h, "font_size": h, "text": text}
|
|
|
|
|
|
def _make_5col_10row_layout() -> dict:
    """Build a standard 5-column x 10-row report layout fixture.

    Row 0 is a single title element, row 1 holds the five column headers and
    rows 2-9 hold one employee record each. Rows start at y=100, are spaced
    30 px apart, and every element is 24 px tall.
    """
    column_xs = (50, 150, 300, 450, 550)
    column_ws = (40, 80, 100, 80, 80)
    headers = ("序号", "姓名", "部门", "职位", "入职日期")

    rows = []
    for row_index in range(10):
        top = 100 + row_index * 30
        if row_index == 0:
            # Title row: one element at x=300 with width 100.
            cells = [(300, 100, "员工名册")]
        else:
            if row_index == 1:
                texts = headers
            else:
                seq = row_index - 1
                texts = (str(seq), f"员工{seq}", "技术部", "工程师", "2024-01-01")
            cells = zip(column_xs, column_ws, texts)
        rows.append(_make_row([_make_elem(x, top, w, 24, text) for x, w, text in cells]))

    return {
        "rows": rows,
        "image_size": (595, 842),
        # Derived from the rows: 10 rows and 1 + 5 + 8 * 5 = 46 elements.
        "total_rows": len(rows),
        "total_elements": sum(len(row["elements"]) for row in rows),
    }
|
|
|
|
|
|
# ============================================================
|
|
# TestExtractLayoutSchema
|
|
# ============================================================
|
|
|
|
class TestExtractLayoutSchema:
    """Unit tests for extract_layout_schema()."""

    @staticmethod
    def _sample_schema() -> dict:
        """Run extract_layout_schema() on the standard 5-column x 10-row fixture."""
        return extract_layout_schema(_make_5col_10row_layout())

    def test_basic_table(self):
        """A 5-column x 10-row layout gives 5 columns, 10 rows and at least 3 regions."""
        schema = self._sample_schema()
        assert schema["total_columns"] == 5
        assert schema["total_rows"] == 10
        assert len(schema["columns"]) == 5
        # Minimum expected: title + header + data.
        assert len(schema["regions"]) >= 3

    def test_title_detection(self):
        """Row 0 holds a single element, so a "title" region is reported."""
        kinds = [region["type"] for region in self._sample_schema()["regions"]]
        assert "title" in kinds

    def test_footer_detection(self):
        """A trailing row with fewer elements (e.g. a grand total) gives a "footer" region."""
        body_rows = [
            _make_row(
                [_make_elem(50 + col * 100, 100 + row * 30, 80, 24, f"col{col}") for col in range(5)]
            )
            for row in range(8)
        ]
        # The last row carries one element only, like a grand-total line.
        total_row = _make_row([_make_elem(300, 340, 100, 24, "合计: 100")])
        layout = {
            "rows": [*body_rows, total_row],
            "image_size": (595, 842),
            "total_rows": 9,
            "total_elements": 41,
        }
        kinds = [region["type"] for region in extract_layout_schema(layout)["regions"]]
        assert "footer" in kinds

    def test_empty_layout(self):
        """No rows gives a zero-valued schema whose text carries the "cannot parse" marker."""
        schema = extract_layout_schema({"rows": [], "image_size": (0, 0)})
        assert schema["total_rows"] == 0
        assert schema["total_columns"] == 0
        assert schema["columns"] == []
        assert "无法解析" in schema["schema_text"]

    def test_single_row(self):
        """A single one-element row gives one row and one column."""
        # NOTE(review): the intended outcome is that every region is "data",
        # but that is not asserted here — TODO add a region check once confirmed.
        layout = {
            "rows": [_make_row([_make_elem(10, 10, 50, 20, "test")])],
            "image_size": (200, 200),
        }
        schema = extract_layout_schema(layout)
        assert schema["total_rows"] == 1
        assert schema["total_columns"] == 1

    def test_two_rows(self):
        """Two one-element rows are both counted as rows."""
        # NOTE(review): the intended outcome is a header + data split, but only
        # the row count is asserted — TODO add a region check once confirmed.
        first = _make_row([_make_elem(10, 10, 50, 20, "表头")])
        second = _make_row([_make_elem(10, 40, 50, 20, "数据")])
        schema = extract_layout_schema({"rows": [first, second], "image_size": (200, 200)})
        assert schema["total_rows"] == 2

    def test_schema_text_format(self):
        """schema_text carries the column-definition, region and A4-portrait labels."""
        text = self._sample_schema()["schema_text"]
        for label in ("列定义", "区域", "A4纵向"):
            assert label in text

    def test_column_width_categories(self):
        """schema_text contains at least one narrow/medium/wide width label."""
        text = self._sample_schema()["schema_text"]
        assert any(label in text for label in ("窄", "中", "宽"))

    def test_a4_dimensions(self):
        """The reported A4 page size is fixed at 595 x 842."""
        assert self._sample_schema()["a4_dimensions"] == {"width": 595, "height": 842}
|
|
|
|
|
|
# ============================================================
|
|
# TestFormatRowCoordinates
|
|
# ============================================================
|
|
|
|
class TestFormatRowCoordinates:
    """Unit tests for _format_row_coordinates()."""

    def test_formats_single_row(self):
        """A two-element row gives two column entries that carry "col" and "text" keys."""
        row = _make_row(
            [
                _make_elem(10, 100, 50, 20, "序号"),
                _make_elem(60, 100, 80, 20, "姓名"),
            ]
        )
        result = _format_row_coordinates(row)
        columns = result["columns"]
        assert len(columns) == 2
        assert result["y_center"] > 0
        first = columns[0]
        assert "col" in first
        assert "text" in first

    def test_sorts_by_x(self):
        """Columns come out ordered left-to-right by x, regardless of input order."""
        row = _make_row(
            [
                _make_elem(200, 100, 80, 20, "right"),
                _make_elem(10, 100, 50, 20, "left"),
            ]
        )
        ordered = [column["text"] for column in _format_row_coordinates(row)["columns"][:2]]
        assert ordered == ["left", "right"]

    def test_empty_row(self):
        """A row with no elements gives an empty columns list."""
        assert _format_row_coordinates({"y_center": 100, "elements": []})["columns"] == []

    def test_non_dict_input(self):
        """Anything that is not a dict gives an empty dict."""
        for bad_input in (None, "not a dict"):
            assert _format_row_coordinates(bad_input) == {}
|
|
|
|
|
|
# ============================================================
|
|
# TestRouting
|
|
# ============================================================
|
|
|
|
class TestRouting:
    """Routing tests for route_after_retrieve()."""

    def _state(self, **kwargs) -> AgentState:
        """Return a minimal state dict; ``kwargs`` add keys or override the defaults."""
        return {"current_jrxml": "", "user_input": "test", "conversation_history": [], **kwargs}

    def test_with_schema(self):
        """A layout_schema with rows routes to "generate_skeleton"."""
        schema = {"total_rows": 5, "total_columns": 3}
        assert route_after_retrieve(self._state(layout_schema=schema)) == "generate_skeleton"

    def test_without_schema(self):
        """With no layout_schema key the route is "generate"."""
        assert route_after_retrieve(self._state()) == "generate"

    def test_empty_schema(self):
        """An empty layout_schema dict routes to "generate"."""
        assert route_after_retrieve(self._state(layout_schema={})) == "generate"

    def test_zero_rows(self):
        """A layout_schema with total_rows == 0 routes to "generate"."""
        schema = {"total_rows": 0, "total_columns": 0}
        assert route_after_retrieve(self._state(layout_schema=schema)) == "generate"
|
|
|
|
|
|
# ============================================================
|
|
# TestIntegration
|
|
# ============================================================
|
|
|
|
class TestIntegration:
    """Integration tests: check the graph's execution paths and node wiring."""

    def test_three_phase_nodes_exist(self):
        """The three new node functions can be imported and are callable."""
        # A function-level import confines an ImportError to this test
        # instead of breaking collection of the whole module.
        from agent.nodes import generate_skeleton, refine_layout, map_fields

        for node in (generate_skeleton, refine_layout, map_fields):
            assert callable(node)

    def test_graph_builds_with_new_nodes(self):
        """build_graph() succeeds and the built graph contains the three new nodes."""
        from agent.graph import build_graph

        # Iterating the nodes collection yields the node names.
        node_names = set(build_graph().get_graph().nodes)
        for expected in ("generate_skeleton", "refine_layout", "map_fields"):
            assert expected in node_names

    def test_one_shot_path_unchanged(self):
        """With an empty schema, route_after_retrieve still returns "generate"."""
        # route_after_retrieve is already imported at module level; no re-import needed.
        state = {"layout_schema": {}, "current_jrxml": "", "user_input": "test"}
        assert route_after_retrieve(state) == "generate"

    def test_modify_report_unaffected(self):
        """The modify_report intent still routes to modify_jrxml even when a layout_schema is present."""
        from agent.graph import route_by_intent

        state = {
            "intent": "modify_report",
            "current_jrxml": "<jasperReport>...</jasperReport>",
            "layout_schema": {"total_rows": 5},
        }
        assert route_by_intent(state) == "modify_jrxml"
|
|
|
|
|
|
# ============================================================
|
|
# TestLayoutSchemaJSONRoundTrip
|
|
# ============================================================
|
|
|
|
class TestLayoutSchemaJSONRoundTrip:
    """Check that the schema survives JSON serialization (used for session persistence)."""

    def test_serializable(self):
        """extract_layout_schema output round-trips through json.dumps / json.loads."""
        original = extract_layout_schema(_make_5col_10row_layout())
        restored = json.loads(json.dumps(original, ensure_ascii=False))
        for key in ("total_rows", "total_columns"):
            assert restored[key] == original[key]
|