feat: layered precise generation for A4 report images

3-phase pipeline to solve LLM prompt overflow from too many OCR elements:
Phase 1 (generate_skeleton): compressed layout schema → skeleton JRXML
Phase 2 (refine_layout): sampled coordinates → pixel-level position tuning
Phase 3 (map_fields): OCR field names → replace $F{field_N} placeholders

Only triggered when layout_schema.total_rows > 0 on initial_generation intent.
Text requests and all other intents are unaffected (zero behavior change).
This commit is contained in:
2026-05-21 08:34:32 +08:00
parent 9bb011e429
commit 43a0542a11
14 changed files with 882 additions and 81 deletions
+26 -6
View File
@@ -41,7 +41,10 @@ agent/graph.py (LangGraph 状态机)
│ 节点流程:
│ load_session → process_input → manage_context → save_state_snapshot
│ → classify_intent (8种意图路由)
│ ├─ retrieve → generate → save_session → validate → ... → finalize
│ ├─ retrieve → route_after_retrieve
│ ├─ [有布局schema] generate_skeleton → refine_layout → map_fields
│ └─ [无布局schema] generate
├─ generate/map_fields → save_session → validate → ... → finalize
│ ├─ modify_jrxml → save_session → validate → ... → finalize
│ ├─ handle_consult / handle_undo / handle_reset → finalize
│ └─ preview/export → save_session → finalize (跳过验证)
@@ -50,7 +53,7 @@ agent/graph.py (LangGraph 状态机)
│ ▲ │
│ └──────── (retry < MAX_RETRY=3) ───────────────────┘
├──► prompts/loader.py Prompt 外部化:7 个 .md 文件热重载
├──► prompts/loader.py Prompt 外部化:10 个 .md 文件热重载
├──► backend/llm.py LLM 工厂: Anthropic SDK / OpenAI / Ollama (统一 stream/invoke)
├──► backend/logger.py 集中日志: JSON + trace_id + llm.log/app.log 分离
├──► backend/rag_adapter.py 语义搜索: ChromaDB + SentenceTransformer
@@ -69,17 +72,17 @@ agent/graph.py (LangGraph 状态机)
| 文件 | 职责 | 修改频率 |
|------|------|---------|
| `app.py` | Streamlit UI 入口,聊天界面 + 侧边栏 + 下载 + 文件上传 | **高** |
| `agent/state.py` | AgentState 类型定义(~26 字段,含 pending_failure_context / annotation_result | 低 |
| `agent/nodes.py` | 14 个工作流节点 + 流式生成 + 错误记录 | **高** |
| `agent/state.py` | AgentState 类型定义(~28 字段,含 layout_schema / annotation_result | 低 |
| `agent/nodes.py` | 18 个工作流节点 + 流式生成 + 错误记录 | **高** |
| `agent/graph.py` | 状态图编译 + 路由函数(预览跳过验证) | 中 |
| `prompts/loader.py` | Prompt 加载器(从 .md 文件热重载) | 低 |
| `prompts/*.md` | 7 个独立 Prompt 模板 | **高** |
| `prompts/*.md` | 10 个独立 Prompt 模板 | **高** |
| `backend/llm.py` | LLM 工厂,统一 `_BaseLLM` 接口(invoke + stream+ `_LLMLoggingWrapper` | 中 |
| `backend/logger.py` | 集中日志模块:JSON 格式化 + trace_id + 独立 llm.log | 低 |
| `backend/rag_adapter.py` | RAGSearcher 单例,语义搜索接口 | 中 |
| `backend/error_kb.py` | ErrorKB — 错误指纹去重 + ChromaDB 持久化 + 语义检索 | 中 |
| `backend/file_parser.py` | 文件解析: PDF/DOCX/XLSX/XLS/DOC/图片(EasyOCR→PaddleOCR回退)/文本 | 中 |
| `backend/layout_analyzer.py` | A4模板分析: 比例检测/EasyOCR→PaddleOCR元素提取/行分组/JRXML行匹配 | 中 |
| `backend/layout_analyzer.py` | A4模板分析: 比例检测/EasyOCR→PaddleOCR元素提取/行分组/JRXML行匹配/布局schema提取 | 中 |
| `backend/ocr_extractor.py` | OCR字段精确提取: 4策略(exact→kv_pair→regex→table_match) + 置信度 | 中 |
| `backend/annotation_detector.py` | 批注检测: 圈选(cv2 HoughCircles) + 箭头(HoughLinesP聚类) + OCR关联 + LLM格式化 | 中 |
| `backend/embeddings.py` | 嵌入模型工厂 (HuggingFace/OpenAI) | 低 |
@@ -115,6 +118,9 @@ agent/graph.py (LangGraph 状态机)
| `prompts/explain_error.md` | 错误转人话 |
| `prompts/compression.md` | 对话压缩摘要 |
| `prompts/consult.md` | 咨询解答 |
| `prompts/skeleton_generation.md` | 分层生成-骨架 |
| `prompts/refine_layout.md` | 分层生成-精调 |
| `prompts/field_mapping.md` | 分层生成-字段映射 |
## 新增功能 (v2)
@@ -191,6 +197,19 @@ agent/graph.py (LangGraph 状态机)
- `modify_jrxml` 节点 — 将 OCR 上下文注入 modification prompt
- OCR 上下文包含: 结构化字段、全部文本元素(含坐标)、批注检测结果
## 新增功能 (v5)
### 分层精确生成
- 解决 A4 报表图片 OCR 元素过多(数百个)导致 LLM prompt 超长的问题
- **3 阶段管线**(仅对 `initial_generation` + 有布局 schema 时触发):
1. `generate_skeleton` — 压缩的布局 schema → 骨架 JRXML (`$F{field_N}` 占位)
2. `refine_layout` — 采样坐标(表头+首行数据+末行)→ 像素级位置精调
3. `map_fields` — OCR 字段名 → 替换占位符
- `backend/layout_analyzer.py` — 新增 `extract_layout_schema()`: 列聚类 + 区域分类 + schema_text
- `agent/graph.py` — 新增 `route_after_retrieve()`: 有 schema 走 3 阶段,无 schema 走原有 1-shot
- `prompts/` — 新增 `skeleton_generation.md`, `refine_layout.md`, `field_mapping.md`
- 文本请求和所有其他意图零行为变更
## 已知注意点
- **Anthropic SDK**: 使用原始 `anthropic` 包(非 `langchain-anthropic`),因为需要直连 MiniMax 兼容端点。API Key 优先读 `ANTHROPIC_API_KEY`fallback `OPENAI_API_KEY`。Anthropic SDK 会自动将 key 放入 `x-api-key` header。
@@ -207,3 +226,4 @@ agent/graph.py (LangGraph 状态机)
- **opencv-python-headless**: 批注检测(圈选/箭头)依赖,通过 `pip install -r requirements.txt` 安装。
- **st-multimodal-chatinput**: Streamlit 聊天输入增强组件,替代 `st.chat_input`,支持粘贴/拖拽文件。返回 base64 编码文件内容。
- **xlwt**: 仅在测试中使用(生成 .xls 测试文件)。
- **分层精确生成**: 3 阶段管线仅在 `layout_schema.total_rows > 0` 时触发。文本请求和 `modify_report` 等意图不受影响,走原有 `generate` 节点。中间阶段(骨架/精调)跳过验证,只有最终 mapped 结果进入 `validate`
+173 -63
View File
@@ -11,19 +11,21 @@
3. [架构全景图](#3-架构全景图)
4. [数据总线:AgentState](#4-数据总线agentstate)
5. [状态机:graphpy](#5-状态机graphpy)
6. [14 个节点详解:nodespy](#6-14-个节点详解nodespy)
6. [18 个节点详解:nodespy](#6-18-个节点详解nodespy)
7. [LLM 调用层:llmpy](#7-llm-调用层llmpy)
8. [Prompt 系统:prompts](#8-prompt-系统prompts)
9. [RAG 与向量搜索](#9-rag-与向量搜索)
10. [错误自增长知识库](#10-错误自增长知识库)
11. [布局分析器](#11-布局分析器)
12. [文件解析器](#12-文件解析器)
13. [验证服务](#13-验证服务)
14. [会话持久化](#14-会话持久化)
15. [Streamlit UIapppy](#15-streamlit-uiapppy)
16. [配置参考](#16-配置参考)
17. [如何添加新功能](#17-如何添加新功能)
18. [调试指南](#18-调试指南)
10. [分层精确生成](#10-分层精确生成)
11. [错误自增长知识库](#11-错误自增长知识库)
12. [布局分析器](#12-布局分析器)
13. [文件解析器](#13-文件解析器)
14. [验证服务](#14-验证服务)
15. [会话持久化](#15-会话持久化)
16. [日志系统:loggerpy](#16-日志系统loggerpy)
17. [Streamlit UIapppy](#17-streamlit-uiapppy)
18. [配置参考](#18-配置参考)
19. [如何添加新功能](#19-如何添加新功能)
20. [调试指南](#20-调试指南)
---
@@ -89,7 +91,10 @@ streamlit run app.py --server.port 8501
│ │
│ load_session → process_input → manage_context → save_snapshot│
│ → classify_intent │
│ ├─ initial_generation → retrieve → generate
│ ├─ initial_generation → retrieve
│ │ ├─ [有布局schema] → generate_skeleton → refine │
│ │ │ → map_fields (3 阶段精确生成) │
│ │ └─ [无布局schema] → generate (原 1-shot) │
│ ├─ modify_report → modify_jrxml │
│ ├─ consult_question → handle_consult │
│ ├─ undo_modification → handle_undo │
@@ -114,7 +119,7 @@ streamlit run app.py --server.port 8501
┌──────────┐ ┌──────────────┐ ┌───────────────┐
│backend/ │ │prompts/ │ │validation_ │
│llm.py │ │loader.py │ │service/main.py│
│logger.py │ │*.md (7个 │ │(FastAPI, │
│logger.py │ │*.md (10个 │ │(FastAPI, │
│rag_ │ │Prompt模板) │ │独立进程) │
│adapter.py│ └──────────────┘ └───────────────┘
│error_kb │
@@ -126,6 +131,12 @@ streamlit run app.py --server.port 8501
│.py │
│file_ │
│parser.py │
│ocr_ │
│extractor │
│.py │
│annotation│
│_detector │
│.py │
│validation│
│.py │
│session.py│
@@ -148,7 +159,7 @@ streamlit run app.py --server.port 8501
## 4. 数据总线:AgentState
`agent/state.py` — 只有 23 个字段的定义,不包含任何逻辑。
`agent/state.py` — 只有 28 个字段的定义,不包含任何逻辑。
```python
class AgentState(TypedDict, total=False):
@@ -188,6 +199,14 @@ class AgentState(TypedDict, total=False):
# ── 失败上下文传递 ──
pending_failure_context: dict # 重试耗尽后暂存失败信息,下次用户输入时自动注入
# ── 分层精确生成 (v5) ──
layout_schema: dict # extract_layout_schema() 输出,列+区域结构
ocr_elements: list # OCR 原始行数据(用于阶段二坐标采样)
# ── OCR 与批注 (v3/v4) ──
ocr_extraction_result: dict # OCR 字段精确提取结果
annotation_result: dict # 批注检测结果(圈选+箭头)
```
**数据流向**:每个节点函数接收 `state`,修改后返回 `state`(实际上是 dict)。LangGraph 自动合并返回值到全局状态。
@@ -216,6 +235,13 @@ def route_by_intent(state) -> Literal["retrieve", "modify_jrxml", ...]:
def route_after_validate(state) -> Literal["finalize", "explain_error"]:
return "finalize" if state.get("status") == "pass" else "explain_error"
def route_after_retrieve(state) -> Literal["generate", "generate_skeleton"]:
"""layout_schema 有行时走 3 阶段精确生成,否则走原 1-shot"""
schema = state.get("layout_schema")
if schema and isinstance(schema, dict) and schema.get("total_rows", 0) > 0:
return "generate_skeleton"
return "generate"
def route_after_correct(state) -> Literal["validate", "finalize"]:
return "validate" if state.get("retry_count", 0) < MAX_RETRY else "finalize"
```
@@ -225,6 +251,7 @@ def route_after_correct(state) -> Literal["validate", "finalize"]:
**关键路由逻辑**
- `route_by_intent`8 种意图分叉,是整个系统的"交通枢纽"
- `route_after_retrieve`:有 layout_schema → 3 阶段精确生成(generate_skeleton → refine_layout → map_fields),无 schema → 原 1-shot generate
- `route_after_save`:预览/导出意图**跳过验证**直通 finalize(这是修复预览问题的关键)
- `route_after_correct`:重试次数 < 3 则继续验证循环,否则认输
@@ -237,7 +264,7 @@ def build_graph():
# 注册节点
workflow.add_node("load_session", load_session_node)
workflow.add_node("process_input", process_input)
# ... 14 个节点
# ... 18 个节点
# 连线
workflow.set_entry_point("load_session")
@@ -279,38 +306,53 @@ def build_graph():
retrieve modify save_ handle_ handle_ handle_
_jrxml session consult undo reset
│ │ │ │ │
│ │ ▼ │
generate │ │ save_session │
│ │ │ │ │
└───┬────┘ │ ▼ │
│ finalize │
│ │
save_session ◄───────────┘
├── preview/export? ──► finalize
validate ◄────────────────────────────────┘
pass fail
│ explain_error
correct_jrxml
│ │
│ ├── retry < 3? ──► validate (循环)
│ │
│ └── retry >= 3? ──► finalize (放弃)
finalize ──► END
┌────┤ │ │ ▼ │
│ │ │ │ save_session │
│ │ │ │ │
generate│ │ │ ▼ │
(1-shot) │ │ │ finalize │
│ │
│ ▼ │ │
generate
│ _skeleton │ │
│ refine │ │ │
│ _layout │ │
│ │ │ │ │
│ ▼ │ │
│ map_ │ │
│ fields │ │ │
│ │ │ │
└──┬──┘ │ │
│ │
save_session ◄─┘ │ │
── preview/export? ──► finalize
│ ▲
│ │
validate ─────────────────────┘ │
│ │ │
pass fail │
│ │ │
│ ▼ │
│ explain_error │
│ │ │
│ ▼ │
│ correct_jrxml │
│ │ │
│ ├── retry < 3? ──► validate (循环) │
│ │ │
│ └── retry >= 3? ──► finalize (放弃) │
│ │
▼ │
finalize ──► END │
```
---
## 6. 14 个节点详解:nodes.py
## 6. 18 个节点详解:nodes.py
`agent/nodes.py` 是系统的"血肉",每个节点实现一个处理步骤。
@@ -563,17 +605,20 @@ def load_prompt(name: str) -> str:
这意味着你可以直接编辑 `prompts/*.md`,下次请求立即生效,无需重启。
### 8.2 7 个 Prompt 文件
### 8.2 10 个 Prompt 文件
| 文件 | 调用节点 | 占位符 | 用途 |
|------|---------|--------|------|
| `intent_classify.md` | classify_intent | `{has_report}`, `{user_input}` | 8 分类意图识别 |
| `initial_generation.md` | generate | `{context}`, `{user_request}` | 首次生成 JRXML |
| `modification.md` | modify_jrxml | `{current_jrxml}`, `{conversation_history}`, `{modification_request}` | 修改现有 JRXML |
| `modification.md` | modify_jrxml | `{current_jrxml}`, `{conversation_history}`, `{modification_request}`, `{ocr_context}` | 修改现有 JRXML |
| `correction.md` | correct_jrxml | `{current_jrxml}`, `{error_msg}`, `{explanation}` | 修正验证错误 |
| `explain_error.md` | explain_error | `{error_msg}`, `{jrxml_snippet}` | 技术错误转人话 |
| `compression.md` | manage_context | `{conversation_text}` | 对话摘要压缩 |
| `consult.md` | handle_consult | `{question}` | 咨询问答 |
| `skeleton_generation.md` | generate_skeleton | `{layout_schema}`, `{context}`, `{user_request}` | 骨架 JRXML ($F{field_N}) |
| `refine_layout.md` | refine_layout | `{current_jrxml}`, `{sampled_coordinates}` | 像素级位置精调 |
| `field_mapping.md` | map_fields | `{current_jrxml}`, `{ocr_fields}` | 占位符 → 真实字段名 |
### 8.3 Prompt 模板写法
@@ -630,7 +675,72 @@ class RAGSearcher:
---
## 10. 错误自增长知识库
## 10. 分层精确生成
专为 A4 报表图片上传场景设计,解决 OCR 元素过多(数百个)导致 LLM prompt 超长的问题。
### 10.1 触发条件
仅当满足以下条件时走 3 阶段管线:
- `intent == "initial_generation"`(新建报表)
- `layout_schema` 存在且 `total_rows > 0`(成功提取布局 schema
其他所有意图(modify_report、文本新建等)走原有 1-shot `generate` 节点,零行为变更。
### 10.2 3 阶段管线
```
上传 A4 图片
│ analyze_layout() → layout dict
│ extract_layout_schema() → schema
route_after_retrieve()
├─ 有 schema → generate_skeleton → refine_layout → map_fields
└─ 无 schema → generate (原 1-shot)
```
**Phase 1: generate_skeleton**
- 输入:压缩的布局 schema`schema_text`:列定义 + 区域 + 宽度分类)
- 输出:骨架 JRXML,所有字段用 `$F{field_N}` 占位
- 目标:正确的 band 结构和大致位置
**Phase 2: refine_layout**
- 输入:当前 JRXML + 采样坐标(表头行 + 首行数据 + 末行)
- 输出:像素级位置精调后的 JRXML
- 目标:精确的 x/y/w/h 数值,中间行通过插值处理
**Phase 3: map_fields**
- 输入:当前 JRXML + OCR 字段名列表(来自 `ocr_extraction_result.fields`
- 输出:`$F{field_N}` → 真实字段名(如 `$F{name}``$F{department}`
- 目标:可读且可编译的完整 JRXML
**关键设计**:中间阶段(骨架/精调)跳过验证,只有最终 mapped 结果进入 validate 循环。
### 10.3 extract_layout_schema()
位于 `backend/layout_analyzer.py`,在 `analyze_layout()` 之后调用:
```python
def extract_layout_schema(layout_result: dict) -> dict:
# 列检测:X 坐标聚类,同列条件 → X 中心距离 < avg_width * 0.5
# 区域分类:row[0] 元素少 → title; row[1] → header; 末尾1-2行 → footer
# 宽度分类:< A4宽度 10% → 窄; > 25% → 宽; 其余 → 中
# 返回: {columns, regions, total_rows, total_columns, a4_dimensions, schema_text}
```
`schema_text` 示例:`"报表布局: 5列 x 10行, A4纵向\n列定义: 序号(窄), 姓名(中), 部门(中), 职位(中), 入职日期(宽)\n区域: 标题(1行) → 表头(1行) → 数据(8行)"`
### 10.4 _format_row_coordinates()
```python
def _format_row_coordinates(row: dict) -> dict:
# 将 OCR 单行元素转为 {y_center, columns: [{col, x, y, w, h, font_size, text}]}
# 按 x 坐标从左到右排序
```
---
## 11. 错误自增长知识库
`backend/error_kb.py` — 自动积累修正成功的错误案例,下次遇到相似错误时提供参考。
@@ -676,9 +786,9 @@ ChromaDB 中每条记录:
---
## 11. 布局分析器
## 12. 布局分析器
`backend/layout_analyzer.py` — 处理用户上传的图片/PDF,识别报表布局结构。
`backend/layout_analyzer.py` — 处理用户上传的图片/PDF,识别报表布局结构。另有 `extract_layout_schema()` 从 OCR 行数据提取列+区域的紧凑描述(用于分层精确生成)。
### 11.1 三种处理路径
@@ -739,7 +849,7 @@ def _parse_jrxml_sections(jrxml):
---
## 12. 文件解析器
## 13. 文件解析器
`backend/file_parser.py` — 统一的多格式文件解析入口。
@@ -769,7 +879,7 @@ def parse_file(file_path, file_type="") -> dict:
---
## 13. 验证服务
## 14. 验证服务
`validation_service/main.py` — 独立的 FastAPI 进程,提供 JRXML 验证。
@@ -805,7 +915,7 @@ def validate_jrxml(jrxml_text):
---
## 14. 会话持久化
## 15. 会话持久化
`backend/session.py` — 基于 JSON 文件的简单 CRUD,每个会话一个文件。
@@ -833,7 +943,7 @@ generate_session_id() → str # UUID hex[:12]
---
## 15. 日志系统:logger.py
## 16. 日志系统:logger.py
`backend/logger.py` 提供结构化日志能力,是整个系统的"黑匣子"。
@@ -888,14 +998,14 @@ backend/logger.py
### 15.5 `@log_node` 装饰器
[agent/nodes.py](file:///d:/Idea%20Project/jaspersoft/agent/nodes.py) 中 17 个节点均使用 `@log_node("节点名")` 装饰器,自动记录:
[agent/nodes.py](file:///d:/Idea%20Project/jaspersoft/agent/nodes.py) 中 18 个节点均使用 `@log_node("节点名")` 装饰器,自动记录:
- **入口日志** — 节点开始执行时的 state 摘要
- **出口日志** — 节点完成时的 state 摘要 + 耗时 (duration_ms)
- **异常日志** — 节点抛异常时的错误信息 + state 摘要
### 15.6 `@_log_route` 装饰器
[agent/graph.py](file:///d:/Idea%20Project/jaspersoft/agent/graph.py) 中 8 个路由函数均使用 `@_log_route("路由名")`,自动记录每次路由决策(from → to)。
[agent/graph.py](file:///d:/Idea%20Project/jaspersoft/agent/graph.py) 中 9 个路由函数均使用 `@_log_route("路由名")`,自动记录每次路由决策(from → to)。
### 15.7 日志分析示例
@@ -912,7 +1022,7 @@ jq 'select(.extra.direction=="response") | {caller: .extra.caller, ms: .extra.du
---
## 16. Streamlit UIapp.py
## 17. Streamlit UIapp.py
`app.py` 是整个系统的入口,约 560 行。分为几个区域:
@@ -1009,7 +1119,7 @@ parent.addEventListener('keydown', function(e) {
---
## 17. 配置参考
## 18. 配置参考
所有配置通过 `.env` 文件管理。完整配置项:
@@ -1040,7 +1150,7 @@ parent.addEventListener('keydown', function(e) {
---
## 18. 如何添加新功能
## 19. 如何添加新功能
### 18.1 添加新的意图类型
@@ -1084,7 +1194,7 @@ elif provider == "my_provider":
---
## 19. 调试指南
## 20. 调试指南
### 19.1 常见问题
@@ -1164,22 +1274,22 @@ st.json(state) # 打印完整状态(调试用,记得删除)
| 文件 | 行数 | 角色 |
|------|------|------|
| `app.py` | ~670 | Streamlit UI 入口(多模态聊天输入) |
| `agent/state.py` | ~48 | 状态类型定义(26 字段) |
| `agent/nodes.py` | ~740 | 15 个工作流节点 |
| `agent/graph.py` | ~232 | 状态图编译 + 路由 |
| `app.py` | ~690 | Streamlit UI 入口(多模态聊天输入) |
| `agent/state.py` | ~52 | 状态类型定义(28 字段) |
| `agent/nodes.py` | ~900 | 18 个工作流节点 |
| `agent/graph.py` | ~270 | 状态图编译 + 路由9 个路由函数) |
| `backend/llm.py` | ~105 | LLM 工厂 (3 个后端) |
| `backend/rag_adapter.py` | ~156 | ChromaDB 语义搜索 |
| `backend/error_kb.py` | ~226 | 错误知识库 |
| `backend/embeddings.py` | ~49 | 嵌入模型工厂 |
| `backend/file_parser.py` | ~320 | 多格式文件解析(7 种格式) |
| `backend/layout_analyzer.py` | ~495 | A4 模板布局分析 |
| `backend/layout_analyzer.py` | ~600 | A4 模板布局分析 + 布局 schema 提取 |
| `backend/ocr_extractor.py` | ~380 | OCR 字段精确提取 |
| `backend/annotation_detector.py` | ~250 | 批注检测(圈选 + 箭头) |
| `backend/validation.py` | ~27 | 验证服务 HTTP 客户端 |
| `backend/session.py` | ~113 | 会话 JSON CRUD |
| `prompts/loader.py` | ~54 | Prompt 热重载 |
| `prompts/*.md` (7 个) | — | Prompt 模板 |
| `prompts/*.md` (10 个) | — | Prompt 模板 |
| `validation_service/main.py` | ~130 | FastAPI 验证服务 |
| `.env.example` | ~62 | 配置模板 |
| `requirements.txt` | ~42 | Python 依赖 |
+8 -6
View File
@@ -12,6 +12,7 @@
- **聊天粘贴/拖拽**:支持直接在对话框中 Ctrl+V 粘贴或拖拽文件(图片/PDF/Excel/Word
- **单据OCR识别**:上传报表单据图片后自动提取所有字段(4策略优先级 + 置信度评分)
- **批注检测**:识别手写单据上的圈选和箭头标记,自动定位用户要修改的字段
- **分层精确生成**:A4 报表图片先提取布局 schema,再分 3 阶段(骨架→精调→字段映射)生成,避免 OCR 元素过多导致 prompt 溢出
- **下载**:导出已验证的、可供 JasperReports 使用的 JRXML 文件
## 架构
@@ -21,7 +22,7 @@ Streamlit 界面 (app.py)
|
LangGraph 代理 (agent/)
|-- retrieve (Chroma/embeddings)
|-- generate (LLM)
|-- generate / generate_skeleton → refine_layout → map_fields (分层生成)
|-- validate (FastAPI service)
|-- explain + correct (auto-fix loop)
|-- modify (multi-turn edits)
@@ -111,9 +112,9 @@ pytest tests/ -v
jrxml-agent/
app.py Streamlit 聊天界面(多模态输入)
agent/
state.py AgentState 定义(26 字段)
nodes.py 图节点(generate, validate, modify 等,15 节点)
graph.py LangGraph 状态机
state.py AgentState 定义(28 字段)
nodes.py 图节点(generate, generate_skeleton, refine_layout 等,18 节点)
graph.py LangGraph 状态机(含分层生成路由)
backend/
llm.py LLM 工厂(Anthropic SDK / OpenAI / Ollama
logger.py 集中日志模块(JSON + trace_id
@@ -122,13 +123,13 @@ jrxml-agent/
rag_adapter.py RAG 语义搜索适配器
error_kb.py 错误自增长知识库
file_parser.py 文件解析器(PDF/DOCX/XLSX/XLS/DOC/图片/文本)
layout_analyzer.py A4 模板布局分析
layout_analyzer.py A4 模板布局分析(含布局 schema 提取)
ocr_extractor.py OCR 字段精确提取(4 策略 + 置信度)
annotation_detector.py 批注检测(圈选 + 箭头 + OCR 关联)
session.py 会话持久化 CRUD
prompts/
loader.py Prompt 加载器(热重载)
*.md 7 个 Prompt 模板文件
*.md 10 个 Prompt 模板文件
validation_service/
main.py FastAPI 验证服务器
validate.bat Windows 启动器
@@ -147,6 +148,7 @@ jrxml-agent/
test_ocr_extraction.py OCR 字段提取单元测试
test_annotation_detector.py 批注检测测试
test_file_parser_formats.py 多格式解析测试
test_layered_generation.py 分层生成测试
requirements.txt
.env.example
README.md
+40 -1
View File
@@ -160,4 +160,43 @@
---
阶段一立即可做,无外部依赖。阶段二是主要工作量。阶段三是收尾。阶段四是可观测性基础。阶段五是 OCR 智能增强和用户体验改进。
## 阶段六:分层精确生成 (v5) ✓
### 16. 布局 Schema 提取 ✓
- [x] `backend/layout_analyzer.py` — 新增 `extract_layout_schema()` 函数(+107 行)
- [x] X 坐标聚类列检测(avg_width * 0.5 阈值)
- [x] 区域分类:标题/表头/数据/表尾(启发式算法)
- [x] `schema_text` 紧凑中文描述(列定义 + 区域 + 宽度分类)
- [x] 空行/单行/双行边界情况处理
- [x] 单元测试: `tests/test_layered_generation.py::TestExtractLayoutSchema` (9 tests)
### 17. 3 阶段生成管线 ✓
- [x] Phase 1: `generate_skeleton` — 压缩布局 schema → 骨架 JRXML (`$F{field_N}` 占位)
- [x] Phase 2: `refine_layout` — 采样坐标(表头+首行数据+末行)→ 像素级位置精调
- [x] Phase 3: `map_fields` — OCR 字段名 → 替换占位符为真实字段名
- [x] 中间阶段跳过验证(仅最终 mapped 结果进入 validate 循环)
- [x] 流式输出支持(每阶段逐字生成)
- [x] 单元测试: `tests/test_layered_generation.py::TestIntegration` (4 tests)
### 18. 路由与状态 ✓
- [x] `agent/graph.py` — 新增 `route_after_retrieve()` 条件路由
- [x] `layout_schema.total_rows > 0` → 3 阶段,否则 → 原有 1-shot
- [x] `agent/state.py` — 新增 `layout_schema: dict``ocr_elements: list`
- [x] 会话持久化支持(`save_session_node` / `load_session_node`
- [x] 文本请求和其他意图零行为变更
- [x] 单元测试: `tests/test_layered_generation.py::TestRouting` (4 tests)
### 19. Prompt 模板 ✓
- [x] `prompts/skeleton_generation.md` — 骨架生成 prompt
- [x] `prompts/refine_layout.md` — 布局精调 prompt
- [x] `prompts/field_mapping.md` — 字段映射 prompt
- [x] `prompts/loader.py` — 注册 3 个新模板(热重载)
### 20. UI 集成 ✓
- [x] `app.py` — 上传 A4 图片时自动调用 `extract_layout_schema()`
- [x] 新增节点标签:`🏗 生成骨架` / `📐 精调布局` / `🏷 映射字段`
- [x] 3 个新节点的详情渲染
---
阶段一立即可做,无外部依赖。阶段二是主要工作量。阶段三是收尾。阶段四是可观测性基础。阶段五是 OCR 智能增强和用户体验改进。阶段六解决 A4 报表图片 OCR 元素过多(数百个)导致 LLM prompt 超长的问题。
+36 -1
View File
@@ -16,6 +16,9 @@ from agent.nodes import (
classify_intent,
retrieve,
generate,
generate_skeleton,
refine_layout,
map_fields,
modify_jrxml,
handle_consult,
handle_undo,
@@ -87,6 +90,15 @@ def route_by_intent(state: AgentState) -> Literal[
return "retrieve"
@_log_route("route_after_retrieve")
def route_after_retrieve(state: AgentState) -> Literal["generate", "generate_skeleton"]:
"""当 layout_schema 存在时走三层精确生成,否则走原有 1-shot。"""
layout_schema = state.get("layout_schema")
if layout_schema and isinstance(layout_schema, dict) and layout_schema.get("total_rows", 0) > 0:
return "generate_skeleton"
return "generate"
@_log_route("route_after_generate")
def route_after_generate(state: AgentState) -> Literal["save_session"]:
return "save_session"
@@ -158,6 +170,11 @@ def build_graph() -> StateGraph:
workflow.add_node("handle_undo", handle_undo)
workflow.add_node("handle_reset", handle_reset)
# 新增节点:分层精确生成(阶段一~三)
workflow.add_node("generate_skeleton", generate_skeleton)
workflow.add_node("refine_layout", refine_layout)
workflow.add_node("map_fields", map_fields)
# ---- 入口和前置流程 ----
workflow.set_entry_point("load_session")
workflow.add_edge("load_session", "process_input")
@@ -180,12 +197,28 @@ def build_graph() -> StateGraph:
)
# ---- 初始生成分支 ----
workflow.add_edge("retrieve", "generate")
workflow.add_conditional_edges(
"retrieve",
route_after_retrieve,
{
"generate": "generate",
"generate_skeleton": "generate_skeleton",
},
)
# 原有 1-shot 路径
workflow.add_conditional_edges(
"generate",
route_after_generate,
{"save_session": "save_session"},
)
# 分层精确生成 3 阶段路径
workflow.add_edge("generate_skeleton", "refine_layout")
workflow.add_edge("refine_layout", "map_fields")
workflow.add_conditional_edges(
"map_fields",
route_after_generate,
{"save_session": "save_session"},
)
# ---- 修改分支 ----
workflow.add_conditional_edges(
@@ -264,4 +297,6 @@ def create_initial_state() -> AgentState:
jrxml_versions=[],
last_error_case={},
pending_failure_context={},
layout_schema={},
ocr_elements=[],
)
+122 -2
View File
@@ -378,7 +378,7 @@ def load_session_node(state: AgentState) -> Dict:
"current_jrxml", "final_jrxml", "compressed_history",
"session_name", "created_at", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
"annotation_result", "layout_schema", "ocr_elements"):
if key in saved and key not in ("user_input", "stage"):
state[key] = saved[key]
state["session_name"] = data.get("session_name", "")
@@ -402,7 +402,7 @@ def save_session_node(state: AgentState) -> Dict:
"current_jrxml", "final_jrxml", "compressed_history",
"status", "error_msg", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
"annotation_result", "layout_schema", "ocr_elements"):
if key in state:
persistable[key] = state[key]
persistable["updated_at"] = _now_iso()
@@ -437,6 +437,28 @@ def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _format_row_coordinates(row: dict) -> dict:
"""将单行 OCR 元素格式化为紧凑的坐标描述,供阶段二 refine_layout 使用。"""
if not isinstance(row, dict):
return {}
elements = row.get("elements", [])
if not elements:
return {"y_center": row.get("y_center", 0), "columns": []}
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
cols = []
for ci, e in enumerate(sorted_elems):
cols.append({
"col": ci,
"x": e.get("x", 0),
"y": e.get("y", 0),
"w": e.get("w", 0),
"h": e.get("h", 0),
"font_size": e.get("font_size", 12),
"text": e.get("text", ""),
})
return {"y_center": row.get("y_center", 0), "columns": cols}
def _format_ocr_context(state: AgentState) -> str:
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
ocr_result = state.get("ocr_extraction_result")
@@ -540,6 +562,104 @@ def generate(state: AgentState) -> Dict:
return state
@log_node("generate_skeleton")
def generate_skeleton(state: AgentState) -> Dict:
"""阶段一:根据压缩的布局 schema 生成骨架 JRXML$F{field_N} 占位)。"""
from langgraph.config import get_stream_writer
writer = get_stream_writer()
llm = get_llm(caller="generate_skeleton")
schema = state.get("layout_schema", {})
schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""
user_request = state.get("user_input", "")
prompt = load_prompt("skeleton_generation").format(
layout_schema=schema_text,
context=state.get("retrieved_context", ""),
user_request=user_request,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "generate_skeleton", "text": chunk})
jrxml = _extract_jrxml("".join(full))
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@log_node("refine_layout")
def refine_layout(state: AgentState) -> Dict:
"""阶段二:使用采样坐标(表头 + 首行数据 + 最后一行)精确调整元素位置。"""
from langgraph.config import get_stream_writer
writer = get_stream_writer()
llm = get_llm(caller="refine_layout")
ocr_rows = state.get("ocr_elements", [])
sampled = {}
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
if len(ocr_rows) > 1:
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
if len(ocr_rows) > 2:
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
sampled_text = json.dumps(sampled, ensure_ascii=False, indent=2)
prompt = load_prompt("refine_layout").format(
current_jrxml=state.get("current_jrxml", ""),
sampled_coordinates=sampled_text,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "refine_layout", "text": chunk})
jrxml = _extract_jrxml("".join(full))
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@log_node("map_fields")
def map_fields(state: AgentState) -> Dict:
"""阶段三:将占位字段名替换为 OCR 提取的真实字段名。"""
from langgraph.config import get_stream_writer
writer = get_stream_writer()
llm = get_llm(caller="map_fields")
ocr_result = state.get("ocr_extraction_result", {})
fields_text = ""
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
field_descs = []
for f in ocr_result["fields"]:
fname = f.get("field_name", "")
fval = f.get("field_value", "")
if fname:
field_descs.append(f" - {fname}: {fval}")
if field_descs:
fields_text = "提取的字段:\n" + "\n".join(field_descs)
if not fields_text:
elements = ocr_result.get("elements", []) if isinstance(ocr_result, dict) else []
if elements:
texts = [e.get("text", "") for e in elements if e.get("text")]
fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])
prompt = load_prompt("field_mapping").format(
current_jrxml=state.get("current_jrxml", ""),
ocr_fields=fields_text,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "map_fields", "text": chunk})
jrxml = _extract_jrxml("".join(full))
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@log_node("modify_jrxml")
def modify_jrxml(state: AgentState) -> Dict:
"""根据用户的修改请求修改现有 JRXML。"""
+4
View File
@@ -47,3 +47,7 @@ class AgentState(TypedDict, total=False):
# 需求8:图片批注检测(圈选/箭头标记)
annotation_result: dict
# 需求9:分层精确生成
layout_schema: dict # extract_layout_schema() 输出,列+区域结构
ocr_elements: list # OCR 原始行数据(用于阶段二坐标采样)
+10 -1
View File
@@ -80,6 +80,9 @@ NODE_LABELS = {
"handle_undo": "↩ 撤销操作",
"handle_reset": "🔄 重置会话",
"save_session": "💾 保存会话",
"generate_skeleton": "🏗 生成骨架",
"refine_layout": "📐 精调布局",
"map_fields": "🏷 映射字段",
}
INTENT_LABELS = {
@@ -133,6 +136,11 @@ def _process_uploaded_file(uploaded_file, suffix: str) -> dict:
if tt == "full_a4":
parsed_text = layout["description"]
parsed_type = "a4_template"
# 存储布局 schema 供分层精确生成使用
from backend.layout_analyzer import extract_layout_schema
schema = extract_layout_schema(layout)
st.session_state.agent_state["layout_schema"] = schema
st.session_state.agent_state["ocr_elements"] = layout.get("rows", [])
elif tt == "partial_rows":
parsed_type = "a4_partial"
if current_jrxml.strip():
@@ -290,7 +298,8 @@ def run_agent(user_input: str):
f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
)
elif node_name in ("generate", "modify_jrxml", "correct_jrxml"):
elif node_name in ("generate", "modify_jrxml", "correct_jrxml",
"generate_skeleton", "refine_layout", "map_fields"):
jrxml = node_state.get("current_jrxml", "")
executed_nodes[-1]["detail"] = f"生成 {len(jrxml)} 字符 JRXML"
+140
View File
@@ -119,6 +119,146 @@ def analyze_layout(
}
def extract_layout_schema(layout_result: dict) -> dict:
"""将 analyze_layout() 的完整 OCR 行数据压缩为高层布局 schema。
列检测跨所有行对元素 X 坐标进行聚类
区域分类启发式识别标题/表头/数据/表尾行
输出紧凑的 schema_text LLM 阶段一骨架生成使用
"""
rows = layout_result.get("rows", [])
if not rows:
return _empty_schema()
img_w, img_h = layout_result.get("image_size", (595, 842))
if img_w <= 0:
img_w = 595
all_elements = []
for row in rows:
all_elements.extend(row.get("elements", []))
if not all_elements:
return _empty_schema()
x_centers = sorted((e["x"] + e["w"] / 2) for e in all_elements)
avg_width = sum(e["w"] for e in all_elements) / len(all_elements)
cluster_threshold = avg_width * 0.5
clusters = []
current_cluster = [x_centers[0]]
for xc in x_centers[1:]:
if xc - current_cluster[-1] < cluster_threshold:
current_cluster.append(xc)
else:
clusters.append(current_cluster)
current_cluster = [xc]
if current_cluster:
clusters.append(current_cluster)
columns = []
for ci, cluster in enumerate(clusters):
cx_min = min(cluster)
cx_max = max(cluster)
col_elements = [
e for e in all_elements
if cx_min - cluster_threshold <= (e["x"] + e["w"] / 2) <= cx_max + cluster_threshold
]
avg_w = sum(e["w"] for e in col_elements) / len(col_elements) if col_elements else 0
x_start = min(e["x"] for e in col_elements)
col_elements_by_y = sorted(col_elements, key=lambda e: e["y"])
header_text = col_elements_by_y[0]["text"] if col_elements_by_y else f"{ci+1}"
columns.append({
"index": ci,
"header_text": header_text,
"avg_width": round(avg_w, 1),
"x_start": round(x_start, 1),
})
columns.sort(key=lambda c: c["x_start"])
row_element_counts = [len(r.get("elements", [])) for r in rows]
median_count = sorted(row_element_counts)[len(row_element_counts) // 2] if row_element_counts else 0
total_rows = len(rows)
regions = []
current_region = None
for ri in range(total_rows):
count = row_element_counts[ri]
if ri == 0 and count < median_count * 0.6 and total_rows > 2:
rtype = "title"
elif ri == 0 and total_rows <= 2:
rtype = "header"
elif ri == 1 and total_rows > 2:
rtype = "header" if median_count > 0 else "data"
elif ri >= total_rows - 2 and count < median_count * 0.7 and total_rows > 3:
rtype = "footer"
else:
rtype = "data"
if current_region and current_region["type"] == rtype:
current_region["row_indices"].append(ri)
current_region["element_count"] += count
else:
if current_region:
regions.append(current_region)
current_region = {"type": rtype, "row_indices": [ri], "element_count": count}
if current_region:
regions.append(current_region)
# schema_text
width_ratios = [c["avg_width"] / img_w for c in columns]
width_labels = []
for r in width_ratios:
if r < 0.08:
width_labels.append("")
elif r > 0.20:
width_labels.append("")
else:
width_labels.append("")
col_descs = []
for ci, col in enumerate(columns):
wl = width_labels[ci] if ci < len(width_labels) else ""
col_descs.append(f"{col['header_text']}({wl})")
_rn = {"title": "标题", "header": "表头", "data": "数据", "footer": "表尾"}
region_parts = []
for r in regions:
label = _rn.get(r["type"], r["type"])
region_parts.append(f"{label}({len(r['row_indices'])}行)")
region_summary = "".join(region_parts)
schema_text = (
f"报表布局: {len(columns)}列 x {total_rows}行, A4纵向\n"
f"列定义: {', '.join(col_descs)}\n"
f"区域: {region_summary}"
)
return {
"columns": columns,
"regions": regions,
"total_rows": total_rows,
"total_columns": len(columns),
"a4_dimensions": {"width": 595, "height": 842},
"schema_text": schema_text,
}
def _empty_schema() -> dict:
return {
"columns": [],
"regions": [],
"total_rows": 0,
"total_columns": 0,
"a4_dimensions": {"width": 595, "height": 842},
"schema_text": "无法解析报表布局",
}
def match_rows_to_jrxml(
layout_result: dict,
current_jrxml: str,
+16
View File
@@ -0,0 +1,16 @@
你是一位资深 JasperReports 工程师。当前有一个 JRXML 使用占位字段名($F{field_1}, $F{field_2}, ...),需要替换为从 OCR 提取的真实字段名。
关键规则:
- 只输出完整修改后的 JRXML 代码,不要解释,不要 markdown 标记。
- 将每个 $F{field_N} 占位符替换为 OCR 提取结果中对应的真实字段名。
- 替换规则:根据列的顺序映射——$F{field_1} 对应第 1 列的 OCR 字段名,$F{field_2} 对应第 2 列,以此类推。
- 同时更新 <field name="..."> 声明和所有 $F{...} 表达式中的引用。
- 如果 OCR 提取的字段数少于占位字段数,保留多余的占位字段。
- 不要修改 band 结构、元素位置或大小。
- 确保 JRXML 兼容 JasperReports 7.0.6。
当前 JRXML(含占位字段):
{current_jrxml}
OCR 提取的结构化字段:
{ocr_fields}
+4 -1
View File
@@ -20,7 +20,10 @@ _NAME_MAP = {
"modification": "modification.md",
"correction": "correction.md",
"explain_error": "explain_error.md",
"compression": "compression.md",
"compression": "compression.md",
"skeleton_generation": "skeleton_generation.md",
"refine_layout": "refine_layout.md",
"field_mapping": "field_mapping.md",
}
+17
View File
@@ -0,0 +1,17 @@
你是一位资深 JasperReports 工程师。当前有一个骨架 JRXML,需要根据精确的像素坐标调整每个元素的位置。
关键规则:
- 只输出完整修改后的 JRXML 代码,不要解释,不要 markdown 标记。
- 根据提供的采样坐标,精确调整每个 textField/staticText 的 x, y, width, height。
- 表头行的坐标直接使用采样坐标中 header_row 对应列的 x, y, width, height。
- 数据行:根据 first_data_row 的坐标模式,向下插值生成剩余数据行(每行 y 递增行高)。
- 标题行(如有)和表尾行:保持其在骨架中的 y 位置大致不变,但调整 x 和 width 与列的采样坐标对齐。
- 不要修改字段名(保持 $F{field_N} 占位名不变)。
- 不要修改 band 结构。
- 确保 JRXML 兼容 JasperReports 7.0.6。
当前骨架 JRXML
{current_jrxml}
采样坐标(表头行 + 第一行数据行,像素位置):
{sampled_coordinates}
+19
View File
@@ -0,0 +1,19 @@
你是一位资深 JasperReports 工程师。根据以下报表布局描述和用户需求,生成一个完整的骨架 JRXML 文件。
关键规则:
- 只输出 JRXML 代码,不要解释,不要 markdown 标记。
- 使用 $F{field_1}, $F{field_2}, ... 作为占位字段名,并在 <field> 部分声明它们。
- 报表结构必须正确(title, pageHeader, columnHeader, detail, pageFooter 等 band)。
- 元素位置使用近似值即可,后续会精确调整。
- 根元素为 <jasperReport>,包含正确的 xmlns 属性。
- 包含 <queryString>,在 <![CDATA[...]]> 中放置占位 SQLSELECT * FROM table_name)。
- 确保 JRXML 兼容 JasperReports 7.0.6。
报表布局描述:
{layout_schema}
参考模板和组件:
{context}
用户需求:
{user_request}
+267
View File
@@ -0,0 +1,267 @@
"""测试分层精确生成:extract_layout_schema, _format_row_coordinates, 路由逻辑。"""
import json
import sys
from pathlib import Path
import pytest
# 确保项目根在 path 中
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from backend.layout_analyzer import extract_layout_schema
from agent.nodes import _format_row_coordinates
from agent.graph import route_after_retrieve
from agent.state import AgentState
# ============================================================
# 测试夹具
# ============================================================
def _make_row(elements: list[dict]) -> dict:
y_center = sum(e["y"] + e["h"] / 2 for e in elements) / len(elements) if elements else 0
return {"y_center": round(y_center, 1), "elements": elements}
def _make_elem(x: float, y: float, w: float, h: float, text: str) -> dict:
return {"x": x, "y": y, "w": w, "h": h, "font_size": h, "text": text}
def _make_5col_10row_layout() -> dict:
"""构造一个标准的 5列 x 10行 报表布局。"""
rows = []
for ri in range(10):
y = 100 + ri * 30
if ri == 0:
texts = ["员工名册"]
xs = [300]
ws = [100]
elif ri == 1:
texts = ["序号", "姓名", "部门", "职位", "入职日期"]
xs = [50, 150, 300, 450, 550]
ws = [40, 80, 100, 80, 80]
else:
texts = [str(ri - 1), f"员工{ri-1}", "技术部", "工程师", "2024-01-01"]
xs = [50, 150, 300, 450, 550]
ws = [40, 80, 100, 80, 80]
elements = [_make_elem(xs[i], y, ws[i], 24, texts[i]) for i in range(len(texts))]
rows.append(_make_row(elements))
return {"rows": rows, "image_size": (595, 842), "total_rows": 10, "total_elements": 46}
# ============================================================
# TestExtractLayoutSchema
# ============================================================
class TestExtractLayoutSchema:
"""extract_layout_schema() 单元测试。"""
def test_basic_table(self):
"""5列x10行 布局 → 正确列数和区域分类。"""
layout = _make_5col_10row_layout()
schema = extract_layout_schema(layout)
assert schema["total_columns"] == 5
assert schema["total_rows"] == 10
assert len(schema["columns"]) == 5
assert len(schema["regions"]) >= 3 # title + header + data at minimum
def test_title_detection(self):
"""第 0 行只有 1 个元素 → 标题区域。"""
layout = _make_5col_10row_layout()
schema = extract_layout_schema(layout)
region_types = [r["type"] for r in schema["regions"]]
assert "title" in region_types
def test_footer_detection(self):
"""尾部行元素更少 → 表尾区域。"""
rows = []
for ri in range(8):
y = 100 + ri * 30
elements = [_make_elem(50 + i * 100, y, 80, 24, f"col{i}") for i in range(5)]
rows.append(_make_row(elements))
# 最后一行只有 1 个元素(如合计)
rows.append(_make_row([_make_elem(300, 340, 100, 24, "合计: 100")]))
layout = {"rows": rows, "image_size": (595, 842), "total_rows": 9, "total_elements": 41}
schema = extract_layout_schema(layout)
region_types = [r["type"] for r in schema["regions"]]
assert "footer" in region_types
def test_empty_layout(self):
"""空行 → 返回零值 schema。"""
schema = extract_layout_schema({"rows": [], "image_size": (0, 0)})
assert schema["total_rows"] == 0
assert schema["total_columns"] == 0
assert schema["columns"] == []
assert "无法解析" in schema["schema_text"]
def test_single_row(self):
"""单行 → 全部为 data。"""
layout = {
"rows": [_make_row([_make_elem(10, 10, 50, 20, "test")])],
"image_size": (200, 200),
}
schema = extract_layout_schema(layout)
assert schema["total_rows"] == 1
assert schema["total_columns"] == 1
def test_two_rows(self):
"""两行 → header + data。"""
rows = [
_make_row([_make_elem(10, 10, 50, 20, "表头")]),
_make_row([_make_elem(10, 40, 50, 20, "数据")]),
]
schema = extract_layout_schema({"rows": rows, "image_size": (200, 200)})
assert schema["total_rows"] == 2
def test_schema_text_format(self):
"""schema_text 包含中文标签(列定义、区域)。"""
layout = _make_5col_10row_layout()
schema = extract_layout_schema(layout)
assert "列定义" in schema["schema_text"]
assert "区域" in schema["schema_text"]
assert "A4纵向" in schema["schema_text"]
def test_column_width_categories(self):
"""宽度分类:窄/中/宽 标签存在。"""
layout = _make_5col_10row_layout()
schema = extract_layout_schema(layout)
text = schema["schema_text"]
assert any(w in text for w in ("", "", ""))
def test_a4_dimensions(self):
"""A4 尺寸固定为 595x842。"""
layout = _make_5col_10row_layout()
schema = extract_layout_schema(layout)
assert schema["a4_dimensions"] == {"width": 595, "height": 842}
# ============================================================
# TestFormatRowCoordinates
# ============================================================
class TestFormatRowCoordinates:
"""_format_row_coordinates() 单元测试。"""
def test_formats_single_row(self):
"""单行 → columns 列表包含正确字段。"""
row = _make_row([
_make_elem(10, 100, 50, 20, "序号"),
_make_elem(60, 100, 80, 20, "姓名"),
])
result = _format_row_coordinates(row)
assert len(result["columns"]) == 2
assert result["y_center"] > 0
assert "col" in result["columns"][0]
assert "text" in result["columns"][0]
def test_sorts_by_x(self):
"""元素按 x 坐标从左到右排序。"""
row = _make_row([
_make_elem(200, 100, 80, 20, "right"),
_make_elem(10, 100, 50, 20, "left"),
])
result = _format_row_coordinates(row)
assert result["columns"][0]["text"] == "left"
assert result["columns"][1]["text"] == "right"
def test_empty_row(self):
"""空元素列表 → columns 为空。"""
result = _format_row_coordinates({"y_center": 100, "elements": []})
assert result["columns"] == []
def test_non_dict_input(self):
"""非 dict 输入 → 返回空 dict。"""
assert _format_row_coordinates(None) == {}
assert _format_row_coordinates("not a dict") == {}
# ============================================================
# TestRouting
# ============================================================
class TestRouting:
"""route_after_retrieve() 路由逻辑测试。"""
def _state(self, **kwargs) -> AgentState:
s = {"current_jrxml": "", "user_input": "test", "conversation_history": []}
s.update(kwargs)
return s
def test_with_schema(self):
"""layout_schema 有行 → generate_skeleton。"""
state = self._state(layout_schema={"total_rows": 5, "total_columns": 3})
assert route_after_retrieve(state) == "generate_skeleton"
def test_without_schema(self):
"""无 layout_schema → generate。"""
state = self._state()
assert route_after_retrieve(state) == "generate"
def test_empty_schema(self):
"""空 dict → generate。"""
state = self._state(layout_schema={})
assert route_after_retrieve(state) == "generate"
def test_zero_rows(self):
"""total_rows = 0 → generate。"""
state = self._state(layout_schema={"total_rows": 0, "total_columns": 0})
assert route_after_retrieve(state) == "generate"
# ============================================================
# TestIntegration
# ============================================================
class TestIntegration:
"""集成测试:验证图执行路径。"""
def test_three_phase_nodes_exist(self):
"""三个新节点可被导入。"""
from agent.nodes import generate_skeleton, refine_layout, map_fields
assert callable(generate_skeleton)
assert callable(refine_layout)
assert callable(map_fields)
def test_graph_builds_with_new_nodes(self):
"""图构建成功包含新节点。"""
from agent.graph import build_graph
graph = build_graph()
nodes = graph.get_graph().nodes
node_names = {n for n in nodes}
assert "generate_skeleton" in node_names
assert "refine_layout" in node_names
assert "map_fields" in node_names
def test_one_shot_path_unchanged(self):
"""无 schema 时 route_after_retrieve → generate。"""
from agent.graph import route_after_retrieve as rar
state = {"layout_schema": {}, "current_jrxml": "", "user_input": "test"}
assert rar(state) == "generate"
def test_modify_report_unaffected(self):
"""modify_report 意图不受影响 — route_by_intent 不变。"""
from agent.graph import route_by_intent
state = {
"intent": "modify_report",
"current_jrxml": "<jasperReport>...</jasperReport>",
"layout_schema": {"total_rows": 5},
}
assert route_by_intent(state) == "modify_jrxml"
# ============================================================
# TestLayoutSchemaJSONRoundTrip
# ============================================================
class TestLayoutSchemaJSONRoundTrip:
"""验证 schema 可被 JSON 序列化/反序列化(用于会话持久化)。"""
def test_serializable(self):
"""extract_layout_schema 输出可 JSON 序列化。"""
layout = _make_5col_10row_layout()
schema = extract_layout_schema(layout)
dumped = json.dumps(schema, ensure_ascii=False)
loaded = json.loads(dumped)
assert loaded["total_rows"] == schema["total_rows"]
assert loaded["total_columns"] == schema["total_columns"]