Merge remote v4/v5 features (multimodal chat input, layered generation, annotation detection) with local v3 features (dialog file upload, XLSX support, session fix)
Key resolutions: - agent/nodes.py: Merged session_id exclusion fix with new persistable fields (ocr_extraction_result, annotation_result, layout_schema, ocr_elements) - app.py: Adopted st-multimodal-chatinput for unified paste/drop/upload, removed custom JS paste bridge - backend/file_parser.py: Kept local XLSX parser, added remote XLS/DOC parsers - CLAUDE.md + CODE_GUIDE.md: Merged documentation from both branches Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -23,7 +23,7 @@ STREAMLIT_SERVER_HEADLESS=true streamlit run app.py --server.port 8501
|
||||
|
||||
## 当前配置(.env)
|
||||
|
||||
- **OCR**: EasyOCR(优先,ch_sim+en)→ PaddleOCR(回退),两者均未安装时仅返回图片元信息
|
||||
- **OCR**: PaddleOCR(精确识别首选,ppocr-v4)→ EasyOCR(回退,ch_sim+en),两者均未安装时仅返回图片元信息
|
||||
- **LLM**: `cloud` / `anthropic` → MiniMax Anthropic 兼容 API (`MiniMax-M2.7`)
|
||||
- Base URL: `https://api.minimaxi.com/anthropic`
|
||||
- 认证: Anthropic SDK 自动读取 `ANTHROPIC_API_KEY`(fallback `OPENAI_API_KEY`)
|
||||
@@ -44,7 +44,10 @@ agent/graph.py (LangGraph 状态机)
|
||||
│ 节点流程:
|
||||
│ load_session → process_input → manage_context → save_state_snapshot
|
||||
│ → classify_intent (8种意图路由)
|
||||
│ ├─ retrieve → generate → save_session → validate → ... → finalize
|
||||
│ ├─ retrieve → route_after_retrieve
|
||||
│ ├─ [有布局schema] generate_skeleton → refine_layout → map_fields
|
||||
│ └─ [无布局schema] generate
|
||||
├─ generate/map_fields → save_session → validate → ... → finalize
|
||||
│ ├─ modify_jrxml → save_session → validate → ... → finalize
|
||||
│ ├─ handle_consult / handle_undo / handle_reset → finalize
|
||||
│ └─ preview/export → save_session → finalize (跳过验证)
|
||||
@@ -53,14 +56,15 @@ agent/graph.py (LangGraph 状态机)
|
||||
│ ▲ │
|
||||
│ └──────── (retry < MAX_RETRY=3) ───────────────────┘
|
||||
│
|
||||
├──► prompts/loader.py Prompt 外部化:7 个 .md 文件热重载
|
||||
├──► prompts/loader.py Prompt 外部化:10 个 .md 文件热重载
|
||||
├──► backend/llm.py LLM 工厂: Anthropic SDK / OpenAI / Ollama (统一 stream/invoke)
|
||||
├──► backend/logger.py 集中日志: JSON + trace_id + llm.log/app.log 分离
|
||||
├──► backend/rag_adapter.py 语义搜索: ChromaDB + SentenceTransformer
|
||||
├──► backend/error_kb.py 错误知识库: 指纹去重 + ChromaDB 持久化
|
||||
├──► backend/file_parser.py 文件解析: PDF/DOCX/图片/文本
|
||||
├──► backend/file_parser.py 文件解析: PDF/DOCX/XLSX/XLS/DOC/图片/文本
|
||||
├──► backend/layout_analyzer.py A4布局分析: OCR + 行分组 + JRXML行匹配
|
||||
├──► backend/ocr_extractor.py OCR字段提取: 两阶段4策略精确提取单据字段值
|
||||
├──► backend/ocr_extractor.py OCR字段精确提取: 4策略优先级 + 置信度
|
||||
├──► backend/annotation_detector.py 批注检测: 圈选(HoughCircles) + 箭头(HoughLinesP) + OCR关联
|
||||
├──► backend/validation.py HTTP 客户端: POST /validate
|
||||
├──► backend/session.py 会话持久化: JSON 文件 CRUD
|
||||
└──► validation_service/ 独立 FastAPI: 结构检查 + XSD 校验
|
||||
@@ -71,18 +75,19 @@ agent/graph.py (LangGraph 状态机)
|
||||
| 文件 | 职责 | 修改频率 |
|
||||
|------|------|---------|
|
||||
| `app.py` | Streamlit UI 入口,聊天界面 + 对话文件上传(粘贴/拖拽) + 侧边栏 + 下载 | **高** |
|
||||
| `agent/state.py` | AgentState 类型定义(~24 字段,含 pending_failure_context) | 低 |
|
||||
| `agent/nodes.py` | 14 个工作流节点 + 流式生成 + 错误记录 | **高** |
|
||||
| `agent/state.py` | AgentState 类型定义(~28 字段,含 layout_schema / annotation_result) | 低 |
|
||||
| `agent/nodes.py` | 18 个工作流节点 + 流式生成 + 错误记录 | **高** |
|
||||
| `agent/graph.py` | 状态图编译 + 路由函数(预览跳过验证) | 中 |
|
||||
| `prompts/loader.py` | Prompt 加载器(从 .md 文件热重载) | 低 |
|
||||
| `prompts/*.md` | 7 个独立 Prompt 模板 | **高** |
|
||||
| `prompts/*.md` | 10 个独立 Prompt 模板 | **高** |
|
||||
| `backend/llm.py` | LLM 工厂,统一 `_BaseLLM` 接口(invoke + stream)+ `_LLMLoggingWrapper` | 中 |
|
||||
| `backend/logger.py` | 集中日志模块:JSON 格式化 + trace_id + 独立 llm.log | 低 |
|
||||
| `backend/rag_adapter.py` | RAGSearcher 单例,语义搜索接口 | 中 |
|
||||
| `backend/error_kb.py` | ErrorKB — 错误指纹去重 + ChromaDB 持久化 + 语义检索 | 中 |
|
||||
| `backend/file_parser.py` | 文件解析: PDF/DOCX/XLSX/图片(EasyOCR→PaddleOCR回退)/文本 | 中 |
|
||||
| `backend/layout_analyzer.py` | A4模板分析: 比例检测/EasyOCR→PaddleOCR元素提取/行分组/JRXML行匹配 | 中 |
|
||||
| `backend/ocr_extractor.py` | OCR单据字段提取: 两阶段流水线 + 4种策略(精确KV/模糊KV/正则/表格) + 17个默认中文字段 | 中 |
|
||||
| `backend/file_parser.py` | 文件解析: PDF/DOCX/XLSX/XLS/DOC/图片(EasyOCR→PaddleOCR回退)/文本 | 中 |
|
||||
| `backend/layout_analyzer.py` | A4模板分析: 比例检测/EasyOCR→PaddleOCR元素提取/行分组/JRXML行匹配/布局schema提取 | 中 |
|
||||
| `backend/ocr_extractor.py` | OCR字段精确提取: 4策略(exact→kv_pair→regex→table_match) + 置信度 | 中 |
|
||||
| `backend/annotation_detector.py` | 批注检测: 圈选(cv2 HoughCircles) + 箭头(HoughLinesP聚类) + OCR关联 + LLM格式化 | 中 |
|
||||
| `backend/embeddings.py` | 嵌入模型工厂 (HuggingFace/OpenAI) | 低 |
|
||||
| `backend/validation.py` | 验证服务 HTTP 客户端 | 低 |
|
||||
| `backend/session.py` | 会话 JSON 文件 CRUD | 低 |
|
||||
@@ -116,6 +121,9 @@ agent/graph.py (LangGraph 状态机)
|
||||
| `prompts/explain_error.md` | 错误转人话 |
|
||||
| `prompts/compression.md` | 对话压缩摘要 |
|
||||
| `prompts/consult.md` | 咨询解答 |
|
||||
| `prompts/skeleton_generation.md` | 分层生成-骨架 |
|
||||
| `prompts/refine_layout.md` | 分层生成-精调 |
|
||||
| `prompts/field_mapping.md` | 分层生成-字段映射 |
|
||||
|
||||
## 新增功能 (v2)
|
||||
|
||||
@@ -170,6 +178,50 @@ agent/graph.py (LangGraph 状态机)
|
||||
- `@log_node` / `@_log_route` — 装饰器自动记录节点和路由
|
||||
- 日志分离: `logs/app.log` (业务) + `logs/llm.log` (AI 调用)
|
||||
|
||||
## 新增功能 (v3/v4)
|
||||
|
||||
### OCR 单据字段精确提取 (v3)
|
||||
- `backend/ocr_extractor.py` — 4 策略优先级提取: exact_match → kv_pair → regex → table_match
|
||||
- PaddleOCR 首次识别后将原始结果(含所有文本元素 + bbox坐标)持久化
|
||||
- `_format_ocr_context()` — 将 OCR 结果(字段 + 原始元素坐标)格式化为 LLM prompt 注入
|
||||
- OCR 结果在 `modify_jrxml` 和 `generate` 节点中自动注入 prompt
|
||||
- `process_input` 节点在上传图片时自动触发 OCR 字段提取
|
||||
- 结果持久化到会话文件(`save_session_node` / `load_session_node`)
|
||||
|
||||
### 多模态聊天输入 + 多格式文件 (v4)
|
||||
- `app.py` — `st.chat_input` 替换为 `st_multimodal_chatinput`(支持 Ctrl+V 粘贴 + 拖拽 + 文件按钮)
|
||||
- `_process_uploaded_file()` — 提取共享文件处理逻辑(侧边栏 + 聊天共用,消除 ~70 行重复代码)
|
||||
- 新增文件格式支持: XLSX (openpyxl)、XLS (xlrd)、DOC (olefile)
|
||||
- 剪贴板粘贴文件通过 base64 解码 + MIME type → 扩展名推断
|
||||
- 侧边栏上传器类型列表中新增 xlsx/xls/doc
|
||||
|
||||
### 批注检测 (v4)
|
||||
- `backend/annotation_detector.py` — 识别用户在手写单据上的圈选和箭头标记
|
||||
- **圆圈检测**: 红色通道增强 → HoughCircles → 圆形度验证
|
||||
- **箭头检测**: Canny边缘 → HoughLinesP → 线段方向聚类 → 端点边缘密度判定方向
|
||||
- **OCR 关联**: 批注与附近 OCR 文本元素关联(15% 图片尺寸内)
|
||||
- **LLM 注入**: `format_annotation_context()` 将批注结果格式化为中文提示
|
||||
- `process_input` 节点在 OCR 提取后自动运行批注检测
|
||||
- `annotation_result` 字段持久化到 AgentState + 会话文件
|
||||
|
||||
### OCR 上下文提示增强 (v3/v4)
|
||||
- `prompts/modification.md` — 新增 `{ocr_context}` 占位符
|
||||
- `modify_jrxml` 节点 — 将 OCR 上下文注入 modification prompt
|
||||
- OCR 上下文包含: 结构化字段、全部文本元素(含坐标)、批注检测结果
|
||||
|
||||
## 新增功能 (v5)
|
||||
|
||||
### 分层精确生成
|
||||
- 解决 A4 报表图片 OCR 元素过多(数百个)导致 LLM prompt 超长的问题
|
||||
- **3 阶段管线**(仅对 `initial_generation` + 有布局 schema 时触发):
|
||||
1. `generate_skeleton` — 压缩的布局 schema → 骨架 JRXML (`$F{field_N}` 占位)
|
||||
2. `refine_layout` — 采样坐标(表头+首行数据+末行)→ 像素级位置精调
|
||||
3. `map_fields` — OCR 字段名 → 替换占位符
|
||||
- `backend/layout_analyzer.py` — 新增 `extract_layout_schema()`: 列聚类 + 区域分类 + schema_text
|
||||
- `agent/graph.py` — 新增 `route_after_retrieve()`: 有 schema 走 3 阶段,无 schema 走原有 1-shot
|
||||
- `prompts/` — 新增 `skeleton_generation.md`, `refine_layout.md`, `field_mapping.md`
|
||||
- 文本请求和所有其他意图零行为变更
|
||||
|
||||
## 已知注意点
|
||||
|
||||
- **Anthropic SDK**: 使用原始 `anthropic` 包(非 `langchain-anthropic`),因为需要直连 MiniMax 兼容端点。API Key 优先读 `ANTHROPIC_API_KEY`,fallback `OPENAI_API_KEY`。Anthropic SDK 会自动将 key 放入 `x-api-key` header。
|
||||
@@ -179,12 +231,15 @@ agent/graph.py (LangGraph 状态机)
|
||||
- **验证服务结构检查**: 字段引用一致性 (`$F{field}` vs `<field>` 声明)、SQL SELECT 存在性、pageWidth/pageHeight/name 属性。
|
||||
- **XSD 校验可选**: 需要 `validation_service/schemas/jasperreport_7_0_6.xsd` 存在。
|
||||
- **rag 子模块**: 内部有独立的管线脚本(`batch_chunker.py` → `embed_chunks.py` → `import_to_chroma.py`),通常不需要在主项目中运行。
|
||||
- **OCR 引擎**: 优先使用 EasyOCR(Windows 兼容性更好,`pip install easyocr`),回退 PaddleOCR。两者均未安装时仅返回图片元信息,建议至少安装 EasyOCR。
|
||||
- **OCR 字段提取**: `process_input` 自动检测上传图片,调用 `OcrExtractor` 提取 17 个常见中文字段(发票代码/号码/金额/日期等),提取结果自动注入 LLM 上下文。
|
||||
- **会话持久化**: `session_id` 现已包含在 `save_session_node` 的持久化字段中,避免切换会话时因 `session_id` 丢失导致的无限 rerun bug。
|
||||
- **v3 修复**: `create_session` 现在存盘前强制写入 `agent_state["session_id"] = sid`。`load_session_node` 不再从磁盘覆盖 `session_id`。切换会话增加 `_last_switched_to` 哨兵防止重复触发。
|
||||
- **OCR 引擎**: 优先 PaddleOCR 2.9.x(精确识别,`pip install paddleocr`),回退 EasyOCR 1.7+。两者均未安装时仅返回图片元信息。PaddlePaddle 3.x 在 Windows 上有 ONEDNN bug,固定在 2.6.x。
|
||||
- **OCR 字段提取**: `process_input` 自动检测上传图片,调用 `OcrExtractor` 提取常见中文字段(发票代码/号码/金额/日期等),提取结果自动注入 LLM 上下文。
|
||||
- **会话持久化**: `session_id` 现已包含在 `save_session_node` 的持久化字段中,避免切换会话时因 `session_id` 丢失导致的无限 rerun bug。`create_session` 存盘前强制写入 `agent_state["session_id"] = sid`。`load_session_node` 不从磁盘覆盖 `session_id`。切换会话增加 `_last_switched_to` 哨兵防止重复触发。
|
||||
- **MAX_RETRY**: 默认 3 次。重试耗尽后 `pending_failure_context` 记录失败信息,下次用户输入时自动注入。
|
||||
- **验证最小内容检查**: 验证服务额外检查至少 1 个 `<band>` + 1 个 `<textField>` 或 `<staticText>`,拦截空壳 JRXML。
|
||||
- **XLSX 支持 (v3)**: 需要 `openpyxl>=3.1.0`(已加入 requirements.txt)。表格按工作表逐行读取,单元格用 `|` 分隔。
|
||||
- **粘贴功能限制**: 文件以 base64 编码在 sessionStorage 中传递,单文件上限 20MB。大文件建议使用 file_uploader 按钮。
|
||||
- **torchvision**: `transformers` 库的懒加载需要 `torchvision`,已作为依赖安装。
|
||||
- **opencv-python-headless**: 批注检测(圈选/箭头)依赖,通过 `pip install -r requirements.txt` 安装。
|
||||
- **st-multimodal-chatinput**: Streamlit 聊天输入增强组件,替代 `st.chat_input`,支持粘贴/拖拽文件。返回 base64 编码文件内容。
|
||||
- **xlwt**: 仅在测试中使用(生成 .xls 测试文件)。
|
||||
- **分层精确生成**: 3 阶段管线仅在 `layout_schema.total_rows > 0` 时触发。文本请求和 `modify_report` 等意图不受影响,走原有 `generate` 节点。中间阶段(骨架/精调)跳过验证,只有最终 mapped 结果进入 `validate`。
|
||||
|
||||
+179
-126
@@ -11,20 +11,21 @@
|
||||
3. [架构全景图](#3-架构全景图)
|
||||
4. [数据总线:AgentState](#4-数据总线agentstate)
|
||||
5. [状态机:graphpy](#5-状态机graphpy)
|
||||
6. [14 个节点详解:nodespy](#6-14-个节点详解nodespy)
|
||||
6. [18 个节点详解:nodespy](#6-18-个节点详解nodespy)
|
||||
7. [LLM 调用层:llmpy](#7-llm-调用层llmpy)
|
||||
8. [Prompt 系统:prompts](#8-prompt-系统prompts)
|
||||
9. [RAG 与向量搜索](#9-rag-与向量搜索)
|
||||
10. [错误自增长知识库](#10-错误自增长知识库)
|
||||
11. [布局分析器](#11-布局分析器)
|
||||
12. [文件解析器](#12-文件解析器)
|
||||
12b. [OCR 单据字段提取器](#12b-ocr-单据字段提取器)
|
||||
13. [验证服务](#13-验证服务)
|
||||
14. [会话持久化](#14-会话持久化)
|
||||
15. [Streamlit UI:apppy](#15-streamlit-uiapppy)
|
||||
16. [配置参考](#16-配置参考)
|
||||
17. [如何添加新功能](#17-如何添加新功能)
|
||||
18. [调试指南](#18-调试指南)
|
||||
10. [分层精确生成](#10-分层精确生成)
|
||||
11. [错误自增长知识库](#11-错误自增长知识库)
|
||||
12. [布局分析器](#12-布局分析器)
|
||||
13. [文件解析器](#13-文件解析器)
|
||||
14. [验证服务](#14-验证服务)
|
||||
15. [会话持久化](#15-会话持久化)
|
||||
16. [日志系统:loggerpy](#16-日志系统loggerpy)
|
||||
17. [Streamlit UI:apppy](#17-streamlit-uiapppy)
|
||||
18. [配置参考](#18-配置参考)
|
||||
19. [如何添加新功能](#19-如何添加新功能)
|
||||
20. [调试指南](#20-调试指南)
|
||||
|
||||
---
|
||||
|
||||
@@ -91,7 +92,10 @@ streamlit run app.py --server.port 8501
|
||||
│ │
|
||||
│ load_session → process_input → manage_context → save_snapshot│
|
||||
│ → classify_intent │
|
||||
│ ├─ initial_generation → retrieve → generate │
|
||||
│ ├─ initial_generation → retrieve │
|
||||
│ │ ├─ [有布局schema] → generate_skeleton → refine │
|
||||
│ │ │ → map_fields (3 阶段精确生成) │
|
||||
│ │ └─ [无布局schema] → generate (原 1-shot) │
|
||||
│ ├─ modify_report → modify_jrxml │
|
||||
│ ├─ consult_question → handle_consult │
|
||||
│ ├─ undo_modification → handle_undo │
|
||||
@@ -116,7 +120,7 @@ streamlit run app.py --server.port 8501
|
||||
┌──────────┐ ┌──────────────┐ ┌───────────────┐
|
||||
│backend/ │ │prompts/ │ │validation_ │
|
||||
│llm.py │ │loader.py │ │service/main.py│
|
||||
│logger.py │ │*.md (7个 │ │(FastAPI, │
|
||||
│logger.py │ │*.md (10个 │ │(FastAPI, │
|
||||
│rag_ │ │Prompt模板) │ │独立进程) │
|
||||
│adapter.py│ └──────────────┘ └───────────────┘
|
||||
│error_kb │
|
||||
@@ -131,6 +135,12 @@ streamlit run app.py --server.port 8501
|
||||
│.py │
|
||||
│file_ │
|
||||
│parser.py │
|
||||
│ocr_ │
|
||||
│extractor │
|
||||
│.py │
|
||||
│annotation│
|
||||
│_detector │
|
||||
│.py │
|
||||
│validation│
|
||||
│.py │
|
||||
│session.py│
|
||||
@@ -153,7 +163,7 @@ streamlit run app.py --server.port 8501
|
||||
|
||||
## 4. 数据总线:AgentState
|
||||
|
||||
`agent/state.py` — 只有 23 个字段的定义,不包含任何逻辑。
|
||||
`agent/state.py` — 只有 28 个字段的定义,不包含任何逻辑。
|
||||
|
||||
```python
|
||||
class AgentState(TypedDict, total=False):
|
||||
@@ -194,9 +204,14 @@ class AgentState(TypedDict, total=False):
|
||||
# ── 失败上下文传递 ──
|
||||
pending_failure_context: dict # 重试耗尽后暂存失败信息,下次用户输入时自动注入
|
||||
|
||||
# ── OCR 单据字段提取 ──
|
||||
ocr_extraction_result: dict # OCR字段提取结果(来自 OcrExtractor)
|
||||
# ── 分层精确生成 (v5) ──
|
||||
layout_schema: dict # extract_layout_schema() 输出,列+区域结构
|
||||
ocr_elements: list # OCR 原始行数据(用于阶段二坐标采样)
|
||||
|
||||
# ── OCR 与批注 (v3/v4) ──
|
||||
ocr_extraction_result: dict # OCR 字段精确提取结果
|
||||
uploaded_file_path: str # 上传图片的临时路径
|
||||
annotation_result: dict # 批注检测结果(圈选+箭头)
|
||||
```
|
||||
|
||||
**数据流向**:每个节点函数接收 `state`,修改后返回 `state`(实际上是 dict)。LangGraph 自动合并返回值到全局状态。
|
||||
@@ -225,6 +240,13 @@ def route_by_intent(state) -> Literal["retrieve", "modify_jrxml", ...]:
|
||||
def route_after_validate(state) -> Literal["finalize", "explain_error"]:
|
||||
return "finalize" if state.get("status") == "pass" else "explain_error"
|
||||
|
||||
def route_after_retrieve(state) -> Literal["generate", "generate_skeleton"]:
|
||||
"""layout_schema 有行时走 3 阶段精确生成,否则走原 1-shot"""
|
||||
schema = state.get("layout_schema")
|
||||
if schema and isinstance(schema, dict) and schema.get("total_rows", 0) > 0:
|
||||
return "generate_skeleton"
|
||||
return "generate"
|
||||
|
||||
def route_after_correct(state) -> Literal["validate", "finalize"]:
|
||||
return "validate" if state.get("retry_count", 0) < MAX_RETRY else "finalize"
|
||||
```
|
||||
@@ -234,6 +256,7 @@ def route_after_correct(state) -> Literal["validate", "finalize"]:
|
||||
|
||||
**关键路由逻辑**:
|
||||
- `route_by_intent`:8 种意图分叉,是整个系统的"交通枢纽"
|
||||
- `route_after_retrieve`:有 layout_schema → 3 阶段精确生成(generate_skeleton → refine_layout → map_fields),无 schema → 原 1-shot generate
|
||||
- `route_after_save`:预览/导出意图**跳过验证**直通 finalize(这是修复预览问题的关键)
|
||||
- `route_after_correct`:重试次数 < 3 则继续验证循环,否则认输
|
||||
|
||||
@@ -246,7 +269,7 @@ def build_graph():
|
||||
# 注册节点
|
||||
workflow.add_node("load_session", load_session_node)
|
||||
workflow.add_node("process_input", process_input)
|
||||
# ... 14 个节点
|
||||
# ... 18 个节点
|
||||
|
||||
# 连线
|
||||
workflow.set_entry_point("load_session")
|
||||
@@ -288,38 +311,53 @@ def build_graph():
|
||||
retrieve modify save_ handle_ handle_ handle_
|
||||
_jrxml session consult undo reset
|
||||
│ │ │ │ │
|
||||
▼ │ │ ▼ │
|
||||
generate │ │ save_session │
|
||||
┌────┤ │ │ ▼ │
|
||||
│ │ │ │ save_session │
|
||||
▼ │ │ │ │ │
|
||||
generate│ │ │ ▼ │
|
||||
(1-shot) │ │ │ finalize │
|
||||
│ │ │ │ │
|
||||
└───┬────┘ │ ▼ │
|
||||
│ │ finalize │
|
||||
▼ │ │
|
||||
save_session ◄───────────┘ │
|
||||
│ │
|
||||
│ ▼ │ │ │
|
||||
│ generate │ │ │
|
||||
│ _skeleton │ │ │
|
||||
│ │ │ │ │
|
||||
│ ▼ │ │ │
|
||||
│ refine │ │ │
|
||||
│ _layout │ │ │
|
||||
│ │ │ │ │
|
||||
│ ▼ │ │ │
|
||||
│ map_ │ │ │
|
||||
│ fields │ │ │
|
||||
│ │ │ │ │
|
||||
└──┬──┘ │ │ │
|
||||
│ │ │ │
|
||||
▼ │ │ │
|
||||
save_session ◄─┘ │ │
|
||||
│ │ │
|
||||
├── preview/export? ──► finalize │
|
||||
│ ▲ │
|
||||
▼ │ │
|
||||
validate ◄─────────────────────┘ │
|
||||
│ │ │
|
||||
pass fail │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ explain_error │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ correct_jrxml │
|
||||
│ │ │
|
||||
│ ├── retry < 3? ──► validate (循环) │
|
||||
│ │ │
|
||||
│ └── retry >= 3? ──► finalize (放弃) │
|
||||
│ │
|
||||
▼ │
|
||||
validate ◄────────────────────────────────┘
|
||||
│ │
|
||||
pass fail
|
||||
│ │
|
||||
│ ▼
|
||||
│ explain_error
|
||||
│ │
|
||||
│ ▼
|
||||
│ correct_jrxml
|
||||
│ │
|
||||
│ ├── retry < 3? ──► validate (循环)
|
||||
│ │
|
||||
│ └── retry >= 3? ──► finalize (放弃)
|
||||
│
|
||||
▼
|
||||
finalize ──► END
|
||||
finalize ──► END │
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 14 个节点详解:nodes.py
|
||||
## 6. 18 个节点详解:nodes.py
|
||||
|
||||
`agent/nodes.py` 是系统的"血肉",每个节点实现一个处理步骤。
|
||||
|
||||
@@ -596,17 +634,20 @@ def load_prompt(name: str) -> str:
|
||||
|
||||
这意味着你可以直接编辑 `prompts/*.md`,下次请求立即生效,无需重启。
|
||||
|
||||
### 8.2 7 个 Prompt 文件
|
||||
### 8.2 10 个 Prompt 文件
|
||||
|
||||
| 文件 | 调用节点 | 占位符 | 用途 |
|
||||
|------|---------|--------|------|
|
||||
| `intent_classify.md` | classify_intent | `{has_report}`, `{user_input}` | 8 分类意图识别 |
|
||||
| `initial_generation.md` | generate | `{context}`, `{user_request}` | 首次生成 JRXML |
|
||||
| `modification.md` | modify_jrxml | `{current_jrxml}`, `{conversation_history}`, `{modification_request}` | 修改现有 JRXML |
|
||||
| `modification.md` | modify_jrxml | `{current_jrxml}`, `{conversation_history}`, `{modification_request}`, `{ocr_context}` | 修改现有 JRXML |
|
||||
| `correction.md` | correct_jrxml | `{current_jrxml}`, `{error_msg}`, `{explanation}` | 修正验证错误 |
|
||||
| `explain_error.md` | explain_error | `{error_msg}`, `{jrxml_snippet}` | 技术错误转人话 |
|
||||
| `compression.md` | manage_context | `{conversation_text}` | 对话摘要压缩 |
|
||||
| `consult.md` | handle_consult | `{question}` | 咨询问答 |
|
||||
| `skeleton_generation.md` | generate_skeleton | `{layout_schema}`, `{context}`, `{user_request}` | 骨架 JRXML ($F{field_N}) |
|
||||
| `refine_layout.md` | refine_layout | `{current_jrxml}`, `{sampled_coordinates}` | 像素级位置精调 |
|
||||
| `field_mapping.md` | map_fields | `{current_jrxml}`, `{ocr_fields}` | 占位符 → 真实字段名 |
|
||||
|
||||
### 8.3 Prompt 模板写法
|
||||
|
||||
@@ -663,7 +704,72 @@ class RAGSearcher:
|
||||
|
||||
---
|
||||
|
||||
## 10. 错误自增长知识库
|
||||
## 10. 分层精确生成
|
||||
|
||||
专为 A4 报表图片上传场景设计,解决 OCR 元素过多(数百个)导致 LLM prompt 超长的问题。
|
||||
|
||||
### 10.1 触发条件
|
||||
|
||||
仅当满足以下条件时走 3 阶段管线:
|
||||
- `intent == "initial_generation"`(新建报表)
|
||||
- `layout_schema` 存在且 `total_rows > 0`(成功提取布局 schema)
|
||||
|
||||
其他所有意图(modify_report、文本新建等)走原有 1-shot `generate` 节点,零行为变更。
|
||||
|
||||
### 10.2 3 阶段管线
|
||||
|
||||
```
|
||||
上传 A4 图片
|
||||
│ analyze_layout() → layout dict
|
||||
│ extract_layout_schema() → schema
|
||||
▼
|
||||
route_after_retrieve()
|
||||
├─ 有 schema → generate_skeleton → refine_layout → map_fields
|
||||
└─ 无 schema → generate (原 1-shot)
|
||||
```
|
||||
|
||||
**Phase 1: generate_skeleton**
|
||||
- 输入:压缩的布局 schema(`schema_text`:列定义 + 区域 + 宽度分类)
|
||||
- 输出:骨架 JRXML,所有字段用 `$F{field_N}` 占位
|
||||
- 目标:正确的 band 结构和大致位置
|
||||
|
||||
**Phase 2: refine_layout**
|
||||
- 输入:当前 JRXML + 采样坐标(表头行 + 首行数据 + 末行)
|
||||
- 输出:像素级位置精调后的 JRXML
|
||||
- 目标:精确的 x/y/w/h 数值,中间行通过插值处理
|
||||
|
||||
**Phase 3: map_fields**
|
||||
- 输入:当前 JRXML + OCR 字段名列表(来自 `ocr_extraction_result.fields`)
|
||||
- 输出:`$F{field_N}` → 真实字段名(如 `$F{name}`、`$F{department}`)
|
||||
- 目标:可读且可编译的完整 JRXML
|
||||
|
||||
**关键设计**:中间阶段(骨架/精调)跳过验证,只有最终 mapped 结果进入 validate 循环。
|
||||
|
||||
### 10.3 extract_layout_schema()
|
||||
|
||||
位于 `backend/layout_analyzer.py`,在 `analyze_layout()` 之后调用:
|
||||
|
||||
```python
|
||||
def extract_layout_schema(layout_result: dict) -> dict:
|
||||
# 列检测:X 坐标聚类,同列条件 → X 中心距离 < avg_width * 0.5
|
||||
# 区域分类:row[0] 元素少 → title; row[1] → header; 末尾1-2行 → footer
|
||||
# 宽度分类:< A4宽度 10% → 窄; > 25% → 宽; 其余 → 中
|
||||
# 返回: {columns, regions, total_rows, total_columns, a4_dimensions, schema_text}
|
||||
```
|
||||
|
||||
`schema_text` 示例:`"报表布局: 5列 x 10行, A4纵向\n列定义: 序号(窄), 姓名(中), 部门(中), 职位(中), 入职日期(宽)\n区域: 标题(1行) → 表头(1行) → 数据(8行)"`
|
||||
|
||||
### 10.4 _format_row_coordinates()
|
||||
|
||||
```python
|
||||
def _format_row_coordinates(row: dict) -> dict:
|
||||
# 将 OCR 单行元素转为 {y_center, columns: [{col, x, y, w, h, font_size, text}]}
|
||||
# 按 x 坐标从左到右排序
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 11. 错误自增长知识库
|
||||
|
||||
`backend/error_kb.py` — 自动积累修正成功的错误案例,下次遇到相似错误时提供参考。
|
||||
|
||||
@@ -709,9 +815,9 @@ ChromaDB 中每条记录:
|
||||
|
||||
---
|
||||
|
||||
## 11. 布局分析器
|
||||
## 12. 布局分析器
|
||||
|
||||
`backend/layout_analyzer.py` — 处理用户上传的图片/PDF,识别报表布局结构。
|
||||
`backend/layout_analyzer.py` — 处理用户上传的图片/PDF,识别报表布局结构。另有 `extract_layout_schema()` 从 OCR 行数据提取列+区域的紧凑描述(用于分层精确生成)。
|
||||
|
||||
### 11.1 三种处理路径
|
||||
|
||||
@@ -772,7 +878,7 @@ def _parse_jrxml_sections(jrxml):
|
||||
|
||||
---
|
||||
|
||||
## 12. 文件解析器
|
||||
## 13. 文件解析器
|
||||
|
||||
`backend/file_parser.py` — 统一的多格式文件解析入口。
|
||||
|
||||
@@ -784,79 +890,25 @@ def parse_file(file_path, file_type="") -> dict:
|
||||
# .png/.jpg/.jpeg/.bmp/.webp → _parse_image()
|
||||
# .pdf → _parse_pdf()
|
||||
# .docx → _parse_docx()
|
||||
# .xlsx → _parse_xlsx()
|
||||
# .xls → _parse_xls()
|
||||
# .doc → _parse_doc()
|
||||
# 其他 → _parse_text() (UTF-8 / GBK)
|
||||
```
|
||||
|
||||
### 各解析器的回退链
|
||||
|
||||
- **图片**:EasyOCR(ch_sim+en)→ PaddleOCR → 仅返回元信息 + 安装提示
|
||||
- **图片**:PaddleOCR(精确识别首选)→ EasyOCR(ch_sim+en)→ 仅返回元信息 + 安装提示
|
||||
- **PDF**:pdfplumber → PyMuPDF → 失败
|
||||
- **DOCX**:python-docx(含表格内容提取)→ 失败
|
||||
- **XLSX**:openpyxl(含多 sheet 支持)→ 失败
|
||||
- **XLS**:xlrd(旧版 Excel 格式)→ 失败
|
||||
- **DOC**:olefile(二进制格式,尽力而为提取)→ 失败
|
||||
- **文本**:UTF-8 → GBK → 失败
|
||||
|
||||
---
|
||||
|
||||
## 12b. OCR 单据字段提取器
|
||||
|
||||
`backend/ocr_extractor.py` — 两阶段精确提取单据图像中的字段值。
|
||||
|
||||
### 12b.1 数据模型
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class OcrTextElement: # OCR 文本元素,含精确坐标
|
||||
text: str
|
||||
x_min, y_min, x_max, y_max: float
|
||||
confidence: float = 1.0
|
||||
# 属性: center_x, center_y, width, height, bbox
|
||||
|
||||
@dataclass
|
||||
class ExtractedField: # 提取的字段结果
|
||||
field_name: str
|
||||
field_value: str
|
||||
bbox: list[float]
|
||||
confidence: float
|
||||
extraction_method: str # exact_match / kv_pair / regex / table_match / none
|
||||
|
||||
@dataclass
|
||||
class ExtractionResult: # 完整提取结果
|
||||
file_path: str
|
||||
image_size: tuple
|
||||
fields: list[ExtractedField]
|
||||
all_elements: list[OcrTextElement]
|
||||
errors: list[str]
|
||||
ocr_available: bool
|
||||
```
|
||||
|
||||
### 12b.2 两阶段流水线
|
||||
|
||||
**阶段1 — 文档分析** (`_analyze_document`):
|
||||
- 加载图片 → `_ocr_elements_enhanced()` → EasyOCR(ch_sim+en) → PaddleOCR 回退
|
||||
- 按 `OCR_CONFIDENCE_THRESHOLD` (默认 0.5) 过滤低置信度元素
|
||||
- 返回按 (y, x) 排序的 `OcrTextElement` 列表
|
||||
|
||||
**阶段2 — 字段提取** (`_extract_field`):
|
||||
按优先级尝试 4 种策略:
|
||||
1. **精确键值对** (`_exact_kv_match`, conf=0.95/0.85): 同一元素中 "字段名: 值" 模式
|
||||
2. **模糊键值对** (`_fuzzy_kv_match`, conf=0.75/0.60): 相邻元素匹配,同行/下一行搜索
|
||||
3. **正则模式** (`_regex_match`, conf=0.70/0.60): 12 种预定义模式 (发票代码/号码/金额/日期等)
|
||||
4. **表格结构** (`_table_match`, conf=0.55): 行列分组 + 表头匹配
|
||||
|
||||
### 12b.3 集成点
|
||||
|
||||
- **`process_input`**: 检测到上传图片后自动调用,传入 17 个默认中文字段
|
||||
- **结果注入**: 提取到的字段值自动拼入 `user_input` 前缀(`[OCR 单据字段提取结果]`)
|
||||
- **结果展示**: `app.py` 总结卡片中 "🔍 OCR 单据字段提取结果" 折叠区
|
||||
|
||||
### 12b.4 回退能力
|
||||
|
||||
- 任一 OCR 引擎不可用时静默回退,不影响主流程
|
||||
- 两种复用路径: `extract()` (全流程) 和 `extract_from_layout_result()` (复用已有布局分析)
|
||||
- 便捷函数: `extract_ocr_fields()`, `extract_from_layout()`
|
||||
|
||||
---
|
||||
|
||||
## 13. 验证服务
|
||||
## 14. 验证服务
|
||||
|
||||
`validation_service/main.py` — 独立的 FastAPI 进程,提供 JRXML 验证。
|
||||
|
||||
@@ -892,7 +944,7 @@ def validate_jrxml(jrxml_text):
|
||||
|
||||
---
|
||||
|
||||
## 14. 会话持久化
|
||||
## 15. 会话持久化
|
||||
|
||||
`backend/session.py` — 基于 JSON 文件的简单 CRUD,每个会话一个文件。
|
||||
|
||||
@@ -920,7 +972,7 @@ generate_session_id() → str # UUID hex[:12]
|
||||
|
||||
---
|
||||
|
||||
## 15. 日志系统:logger.py
|
||||
## 16. 日志系统:logger.py
|
||||
|
||||
`backend/logger.py` 提供结构化日志能力,是整个系统的"黑匣子"。
|
||||
|
||||
@@ -975,14 +1027,14 @@ backend/logger.py
|
||||
|
||||
### 15.5 `@log_node` 装饰器
|
||||
|
||||
[agent/nodes.py](file:///d:/Idea%20Project/jaspersoft/agent/nodes.py) 中 17 个节点均使用 `@log_node("节点名")` 装饰器,自动记录:
|
||||
[agent/nodes.py](file:///d:/Idea%20Project/jaspersoft/agent/nodes.py) 中 18 个节点均使用 `@log_node("节点名")` 装饰器,自动记录:
|
||||
- **入口日志** — 节点开始执行时的 state 摘要
|
||||
- **出口日志** — 节点完成时的 state 摘要 + 耗时 (duration_ms)
|
||||
- **异常日志** — 节点抛异常时的错误信息 + state 摘要
|
||||
|
||||
### 15.6 `@_log_route` 装饰器
|
||||
|
||||
[agent/graph.py](file:///d:/Idea%20Project/jaspersoft/agent/graph.py) 中 8 个路由函数均使用 `@_log_route("路由名")`,自动记录每次路由决策(from → to)。
|
||||
[agent/graph.py](file:///d:/Idea%20Project/jaspersoft/agent/graph.py) 中 9 个路由函数均使用 `@_log_route("路由名")`,自动记录每次路由决策(from → to)。
|
||||
|
||||
### 15.7 日志分析示例
|
||||
|
||||
@@ -999,7 +1051,7 @@ jq 'select(.extra.direction=="response") | {caller: .extra.caller, ms: .extra.du
|
||||
|
||||
---
|
||||
|
||||
## 16. Streamlit UI:app.py
|
||||
## 17. Streamlit UI:app.py
|
||||
|
||||
`app.py` 是整个系统的入口,约 560 行。分为几个区域:
|
||||
|
||||
@@ -1096,7 +1148,7 @@ parent.addEventListener('keydown', function(e) {
|
||||
|
||||
---
|
||||
|
||||
## 17. 配置参考
|
||||
## 18. 配置参考
|
||||
|
||||
所有配置通过 `.env` 文件管理。完整配置项:
|
||||
|
||||
@@ -1127,7 +1179,7 @@ parent.addEventListener('keydown', function(e) {
|
||||
|
||||
---
|
||||
|
||||
## 18. 如何添加新功能
|
||||
## 19. 如何添加新功能
|
||||
|
||||
### 18.1 添加新的意图类型
|
||||
|
||||
@@ -1171,7 +1223,7 @@ elif provider == "my_provider":
|
||||
|
||||
---
|
||||
|
||||
## 19. 调试指南
|
||||
## 20. 调试指南
|
||||
|
||||
### 19.1 常见问题
|
||||
|
||||
@@ -1251,24 +1303,25 @@ st.json(state) # 打印完整状态(调试用,记得删除)
|
||||
|
||||
| 文件 | 行数 | 角色 |
|
||||
|------|------|------|
|
||||
| `app.py` | ~530 | Streamlit UI 入口 |
|
||||
| `agent/state.py` | ~40 | 状态类型定义 |
|
||||
| `agent/nodes.py` | ~523 | 14 个工作流节点 |
|
||||
| `agent/graph.py` | ~232 | 状态图编译 + 路由 |
|
||||
| `app.py` | ~690 | Streamlit UI 入口(多模态聊天输入) |
|
||||
| `agent/state.py` | ~52 | 状态类型定义(28 字段) |
|
||||
| `agent/nodes.py` | ~900 | 18 个工作流节点 |
|
||||
| `agent/graph.py` | ~270 | 状态图编译 + 路由(9 个路由函数) |
|
||||
| `backend/llm.py` | ~105 | LLM 工厂 (3 个后端) |
|
||||
| `backend/rag_adapter.py` | ~156 | ChromaDB 语义搜索 |
|
||||
| `backend/error_kb.py` | ~226 | 错误知识库 |
|
||||
| `backend/embeddings.py` | ~49 | 嵌入模型工厂 |
|
||||
| `backend/file_parser.py` | ~194 | 多格式文件解析 |
|
||||
| `backend/layout_analyzer.py` | ~495 | A4 模板布局分析 |
|
||||
| `backend/ocr_extractor.py` | ~797 | OCR 单据字段精确提取 (两阶段+4策略) |
|
||||
| `backend/file_parser.py` | ~320 | 多格式文件解析(7 种格式) |
|
||||
| `backend/layout_analyzer.py` | ~600 | A4 模板布局分析 + 布局 schema 提取 |
|
||||
| `backend/ocr_extractor.py` | ~380 | OCR 字段精确提取 |
|
||||
| `backend/annotation_detector.py` | ~250 | 批注检测(圈选 + 箭头) |
|
||||
| `backend/validation.py` | ~27 | 验证服务 HTTP 客户端 |
|
||||
| `backend/session.py` | ~113 | 会话 JSON CRUD |
|
||||
| `prompts/loader.py` | ~54 | Prompt 热重载 |
|
||||
| `prompts/*.md` (7 个) | — | Prompt 模板 |
|
||||
| `prompts/*.md` (10 个) | — | Prompt 模板 |
|
||||
| `validation_service/main.py` | ~130 | FastAPI 验证服务 |
|
||||
| `tests/test_ocr_extraction.py` | ~543 | OCR 提取器单元测试 (48 项) |
|
||||
| `start.bat` | — | 一键启动脚本 (Windows) |
|
||||
| `stop.bat` | — | 一键停止脚本 (Windows) |
|
||||
| `.env.example` | ~62 | 配置模板 |
|
||||
| `requirements.txt` | ~32 | Python 依赖 |
|
||||
| `requirements.txt` | ~42 | Python 依赖 |
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
- **自动验证**:每次生成或修改后都会验证 JRXML
|
||||
- **自动修正**:如果验证失败,代理会分析错误并自动修正(最多 3 次)
|
||||
- **模板检索**:使用 Chroma 向量数据库检索相关的 JRXML 示例以获得更好的生成效果
|
||||
- **文件上传**:支持图片(OCR识别)、PDF、Word、Excel、文本文件等
|
||||
- **聊天粘贴/拖拽**:支持直接在对话框中 Ctrl+V 粘贴或拖拽文件(图片/PDF/Excel/Word)
|
||||
- **单据OCR识别**:上传报表单据图片后自动提取所有字段(4策略优先级 + 置信度评分)
|
||||
- **批注检测**:识别手写单据上的圈选和箭头标记,自动定位用户要修改的字段
|
||||
- **分层精确生成**:A4 报表图片先提取布局 schema,再分 3 阶段(骨架→精调→字段映射)生成,避免 OCR 元素过多导致 prompt 溢出
|
||||
- **下载**:导出已验证的、可供 JasperReports 使用的 JRXML 文件
|
||||
|
||||
## 架构
|
||||
@@ -17,7 +22,7 @@ Streamlit 界面 (app.py)
|
||||
|
|
||||
LangGraph 代理 (agent/)
|
||||
|-- retrieve (Chroma/embeddings)
|
||||
|-- generate (LLM)
|
||||
|-- generate / generate_skeleton → refine_layout → map_fields (分层生成)
|
||||
|-- validate (FastAPI service)
|
||||
|-- explain + correct (auto-fix loop)
|
||||
|-- modify (multi-turn edits)
|
||||
@@ -105,11 +110,11 @@ pytest tests/ -v
|
||||
|
||||
```
|
||||
jrxml-agent/
|
||||
app.py Streamlit 聊天界面
|
||||
app.py Streamlit 聊天界面(多模态输入)
|
||||
agent/
|
||||
state.py AgentState 定义
|
||||
nodes.py 图节点(generate, validate, modify 等)
|
||||
graph.py LangGraph 状态机
|
||||
state.py AgentState 定义(28 字段)
|
||||
nodes.py 图节点(generate, generate_skeleton, refine_layout 等,18 节点)
|
||||
graph.py LangGraph 状态机(含分层生成路由)
|
||||
backend/
|
||||
llm.py LLM 工厂(Anthropic SDK / OpenAI / Ollama)
|
||||
logger.py 集中日志模块(JSON + trace_id)
|
||||
@@ -117,12 +122,14 @@ jrxml-agent/
|
||||
validation.py 验证服务客户端
|
||||
rag_adapter.py RAG 语义搜索适配器
|
||||
error_kb.py 错误自增长知识库
|
||||
file_parser.py 文件解析器(PDF/DOCX/图片)
|
||||
layout_analyzer.py A4 模板布局分析
|
||||
file_parser.py 文件解析器(PDF/DOCX/XLSX/XLS/DOC/图片/文本)
|
||||
layout_analyzer.py A4 模板布局分析(含布局 schema 提取)
|
||||
ocr_extractor.py OCR 字段精确提取(4 策略 + 置信度)
|
||||
annotation_detector.py 批注检测(圈选 + 箭头 + OCR 关联)
|
||||
session.py 会话持久化 CRUD
|
||||
prompts/
|
||||
loader.py Prompt 加载器(热重载)
|
||||
*.md 7 个 Prompt 模板文件
|
||||
*.md 10 个 Prompt 模板文件
|
||||
validation_service/
|
||||
main.py FastAPI 验证服务器
|
||||
validate.bat Windows 启动器
|
||||
@@ -137,6 +144,11 @@ jrxml-agent/
|
||||
tests/
|
||||
test_validation.py 验证服务测试
|
||||
test_agent.py 代理集成测试
|
||||
test_e2e_ocr.py OCR 端到端测试
|
||||
test_ocr_extraction.py OCR 字段提取单元测试
|
||||
test_annotation_detector.py 批注检测测试
|
||||
test_file_parser_formats.py 多格式解析测试
|
||||
test_layered_generation.py 分层生成测试
|
||||
requirements.txt
|
||||
.env.example
|
||||
README.md
|
||||
|
||||
+78
-1
@@ -122,4 +122,81 @@
|
||||
10. 结构化日志系统
|
||||
```
|
||||
|
||||
阶段一立即可做,无外部依赖。阶段二是主要工作量。阶段三是收尾。阶段四是可观测性基础。
|
||||
---
|
||||
|
||||
## 阶段五:OCR 与智能上传 (v3/v4) ✓
|
||||
|
||||
### 11. OCR 单据字段精确提取 ✓
|
||||
- [x] `backend/ocr_extractor.py` — 4 策略优先级提取 (exact_match → kv_pair → regex → table_match)
|
||||
- [x] PaddleOCR 首次识别后将原始结果(含所有文本元素 + bbox坐标)持久化
|
||||
- [x] `_format_ocr_context()` — OCR 结果格式化为 LLM prompt 注入
|
||||
- [x] `process_input` 节点在上传图片时自动触发 OCR 字段提取
|
||||
- [x] OCR 结果持久化到会话文件
|
||||
|
||||
### 12. 多模态聊天输入 ✓
|
||||
- [x] `app.py` — `st.chat_input` 替换为 `st_multimodal_chatinput`
|
||||
- [x] 支持 Ctrl+V 粘贴文件 + 拖拽 + 文件按钮
|
||||
- [x] `_process_uploaded_file()` — 提取共享文件处理逻辑(消除 ~70 行重复代码)
|
||||
- [x] 剪贴板文件 base64 解码 + MIME type → 扩展名推断
|
||||
|
||||
### 13. 多格式文件支持 ✓
|
||||
- [x] `backend/file_parser.py` — 新增 XLSX (openpyxl)、XLS (xlrd)、DOC (olefile)
|
||||
- [x] 侧边栏上传器类型列表中新增 xlsx/xls/doc
|
||||
- [x] 单元测试: `tests/test_file_parser_formats.py` (4 tests)
|
||||
|
||||
### 14. 批注检测 ✓
|
||||
- [x] `backend/annotation_detector.py` — 圈选 + 箭头 + OCR 关联
|
||||
- [x] 圆圈检测: 红色通道增强 → HoughCircles
|
||||
- [x] 箭头检测: Canny → HoughLinesP → 线段聚类 → 端点方向判定
|
||||
- [x] `format_annotation_context()` — 批注结果格式化为中文提示
|
||||
- [x] `process_input` 节点在 OCR 提取后自动运行批注检测
|
||||
- [x] `annotation_result` 字段持久化到 AgentState + 会话文件
|
||||
- [x] 单元测试: `tests/test_annotation_detector.py` (7 tests)
|
||||
|
||||
### 15. OCR 上下文 LLM 注入 ✓
|
||||
- [x] `prompts/modification.md` — 新增 `{ocr_context}` 占位符
|
||||
- [x] `modify_jrxml` + `generate` 节点注入 OCR 上下文
|
||||
- [x] OCR 上下文包含: 结构化字段、全部文本元素(含坐标)、批注检测结果
|
||||
|
||||
---
|
||||
|
||||
## 阶段六:分层精确生成 (v5) ✓
|
||||
|
||||
### 16. 布局 Schema 提取 ✓
|
||||
- [x] `backend/layout_analyzer.py` — 新增 `extract_layout_schema()` 函数(+107 行)
|
||||
- [x] X 坐标聚类列检测(avg_width * 0.5 阈值)
|
||||
- [x] 区域分类:标题/表头/数据/表尾(启发式算法)
|
||||
- [x] `schema_text` 紧凑中文描述(列定义 + 区域 + 宽度分类)
|
||||
- [x] 空行/单行/双行边界情况处理
|
||||
- [x] 单元测试: `tests/test_layered_generation.py::TestExtractLayoutSchema` (9 tests)
|
||||
|
||||
### 17. 3 阶段生成管线 ✓
|
||||
- [x] Phase 1: `generate_skeleton` — 压缩布局 schema → 骨架 JRXML (`$F{field_N}` 占位)
|
||||
- [x] Phase 2: `refine_layout` — 采样坐标(表头+首行数据+末行)→ 像素级位置精调
|
||||
- [x] Phase 3: `map_fields` — OCR 字段名 → 替换占位符为真实字段名
|
||||
- [x] 中间阶段跳过验证(仅最终 mapped 结果进入 validate 循环)
|
||||
- [x] 流式输出支持(每阶段逐字生成)
|
||||
- [x] 单元测试: `tests/test_layered_generation.py::TestIntegration` (4 tests)
|
||||
|
||||
### 18. 路由与状态 ✓
|
||||
- [x] `agent/graph.py` — 新增 `route_after_retrieve()` 条件路由
|
||||
- [x] `layout_schema.total_rows > 0` → 3 阶段,否则 → 原有 1-shot
|
||||
- [x] `agent/state.py` — 新增 `layout_schema: dict` 和 `ocr_elements: list`
|
||||
- [x] 会话持久化支持(`save_session_node` / `load_session_node`)
|
||||
- [x] 文本请求和其他意图零行为变更
|
||||
- [x] 单元测试: `tests/test_layered_generation.py::TestRouting` (4 tests)
|
||||
|
||||
### 19. Prompt 模板 ✓
|
||||
- [x] `prompts/skeleton_generation.md` — 骨架生成 prompt
|
||||
- [x] `prompts/refine_layout.md` — 布局精调 prompt
|
||||
- [x] `prompts/field_mapping.md` — 字段映射 prompt
|
||||
- [x] `prompts/loader.py` — 注册 3 个新模板(热重载)
|
||||
|
||||
### 20. UI 集成 ✓
|
||||
- [x] `app.py` — 上传 A4 图片时自动调用 `extract_layout_schema()`
|
||||
- [x] 新增节点标签:`🏗 生成骨架` / `📐 精调布局` / `🏷 映射字段`
|
||||
- [x] 3 个新节点的详情渲染
|
||||
|
||||
---
|
||||
|
||||
阶段一立即可做,无外部依赖。阶段二是主要工作量。阶段三是收尾。阶段四是可观测性基础。阶段五是 OCR 智能增强和用户体验改进。阶段六解决 A4 报表图片 OCR 元素过多(数百个)导致 LLM prompt 超长的问题。
|
||||
|
||||
+36
-1
@@ -16,6 +16,9 @@ from agent.nodes import (
|
||||
classify_intent,
|
||||
retrieve,
|
||||
generate,
|
||||
generate_skeleton,
|
||||
refine_layout,
|
||||
map_fields,
|
||||
modify_jrxml,
|
||||
handle_consult,
|
||||
handle_undo,
|
||||
@@ -87,6 +90,15 @@ def route_by_intent(state: AgentState) -> Literal[
|
||||
return "retrieve"
|
||||
|
||||
|
||||
@_log_route("route_after_retrieve")
|
||||
def route_after_retrieve(state: AgentState) -> Literal["generate", "generate_skeleton"]:
|
||||
"""当 layout_schema 存在时走三层精确生成,否则走原有 1-shot。"""
|
||||
layout_schema = state.get("layout_schema")
|
||||
if layout_schema and isinstance(layout_schema, dict) and layout_schema.get("total_rows", 0) > 0:
|
||||
return "generate_skeleton"
|
||||
return "generate"
|
||||
|
||||
|
||||
@_log_route("route_after_generate")
|
||||
def route_after_generate(state: AgentState) -> Literal["save_session"]:
|
||||
return "save_session"
|
||||
@@ -158,6 +170,11 @@ def build_graph() -> StateGraph:
|
||||
workflow.add_node("handle_undo", handle_undo)
|
||||
workflow.add_node("handle_reset", handle_reset)
|
||||
|
||||
# 新增节点:分层精确生成(阶段一~三)
|
||||
workflow.add_node("generate_skeleton", generate_skeleton)
|
||||
workflow.add_node("refine_layout", refine_layout)
|
||||
workflow.add_node("map_fields", map_fields)
|
||||
|
||||
# ---- 入口和前置流程 ----
|
||||
workflow.set_entry_point("load_session")
|
||||
workflow.add_edge("load_session", "process_input")
|
||||
@@ -180,12 +197,28 @@ def build_graph() -> StateGraph:
|
||||
)
|
||||
|
||||
# ---- 初始生成分支 ----
|
||||
workflow.add_edge("retrieve", "generate")
|
||||
workflow.add_conditional_edges(
|
||||
"retrieve",
|
||||
route_after_retrieve,
|
||||
{
|
||||
"generate": "generate",
|
||||
"generate_skeleton": "generate_skeleton",
|
||||
},
|
||||
)
|
||||
# 原有 1-shot 路径
|
||||
workflow.add_conditional_edges(
|
||||
"generate",
|
||||
route_after_generate,
|
||||
{"save_session": "save_session"},
|
||||
)
|
||||
# 分层精确生成 3 阶段路径
|
||||
workflow.add_edge("generate_skeleton", "refine_layout")
|
||||
workflow.add_edge("refine_layout", "map_fields")
|
||||
workflow.add_conditional_edges(
|
||||
"map_fields",
|
||||
route_after_generate,
|
||||
{"save_session": "save_session"},
|
||||
)
|
||||
|
||||
# ---- 修改分支 ----
|
||||
workflow.add_conditional_edges(
|
||||
@@ -264,4 +297,6 @@ def create_initial_state() -> AgentState:
|
||||
jrxml_versions=[],
|
||||
last_error_case={},
|
||||
pending_failure_context={},
|
||||
layout_schema={},
|
||||
ocr_elements=[],
|
||||
)
|
||||
|
||||
+204
-3
@@ -154,6 +154,23 @@ def process_input(state: AgentState) -> Dict:
|
||||
user_input = f"{ocr_context}\n\n{user_input}"
|
||||
# 同时更新工作对话历史中的最后一条
|
||||
conv_history[-1]["content"] = user_input
|
||||
# 批注检测(圈选/箭头标记)
|
||||
elements = ocr_result.get("elements", [])
|
||||
if elements:
|
||||
try:
|
||||
from backend.annotation_detector import detect_annotations
|
||||
ann_result = detect_annotations(uploaded_path, elements)
|
||||
if ann_result.get("total", 0) > 0:
|
||||
state["annotation_result"] = ann_result
|
||||
_node_log.info(
|
||||
"批注检测完成",
|
||||
extra={
|
||||
"circles": len(ann_result.get("circles", [])),
|
||||
"arrows": len(ann_result.get("arrows", [])),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
_node_log.warning(f"批注检测失败: {e}")
|
||||
except Exception as e:
|
||||
_node_log.warning(f"OCR 字段提取失败: {e}")
|
||||
state["ocr_extraction_result"] = {"error": str(e)}
|
||||
@@ -379,7 +396,9 @@ def load_session_node(state: AgentState) -> Dict:
|
||||
# 恢复核心字段(不覆盖当前请求的 user_input / stage / session_id)
|
||||
for key in ("conversation_history", "full_conversation_history",
|
||||
"current_jrxml", "final_jrxml", "compressed_history",
|
||||
"session_name", "created_at", "history_states"):
|
||||
"session_name", "created_at", "history_states",
|
||||
"ocr_extraction_result", "uploaded_file_path",
|
||||
"annotation_result", "layout_schema", "ocr_elements"):
|
||||
if key in saved and key not in ("user_input", "stage", "session_id"):
|
||||
state[key] = saved[key]
|
||||
state["session_name"] = data.get("session_name", "")
|
||||
@@ -401,7 +420,9 @@ def save_session_node(state: AgentState) -> Dict:
|
||||
persistable = {}
|
||||
for key in ("session_id", "conversation_history", "full_conversation_history",
|
||||
"current_jrxml", "final_jrxml", "compressed_history",
|
||||
"status", "error_msg", "history_states"):
|
||||
"status", "error_msg", "history_states",
|
||||
"ocr_extraction_result", "uploaded_file_path",
|
||||
"annotation_result", "layout_schema", "ocr_elements"):
|
||||
if key in state:
|
||||
persistable[key] = state[key]
|
||||
persistable["updated_at"] = _now_iso()
|
||||
@@ -436,6 +457,81 @@ def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _format_row_coordinates(row: dict) -> dict:
|
||||
"""将单行 OCR 元素格式化为紧凑的坐标描述,供阶段二 refine_layout 使用。"""
|
||||
if not isinstance(row, dict):
|
||||
return {}
|
||||
elements = row.get("elements", [])
|
||||
if not elements:
|
||||
return {"y_center": row.get("y_center", 0), "columns": []}
|
||||
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
|
||||
cols = []
|
||||
for ci, e in enumerate(sorted_elems):
|
||||
cols.append({
|
||||
"col": ci,
|
||||
"x": e.get("x", 0),
|
||||
"y": e.get("y", 0),
|
||||
"w": e.get("w", 0),
|
||||
"h": e.get("h", 0),
|
||||
"font_size": e.get("font_size", 12),
|
||||
"text": e.get("text", ""),
|
||||
})
|
||||
return {"y_center": row.get("y_center", 0), "columns": cols}
|
||||
|
||||
|
||||
def _format_ocr_context(state: AgentState) -> str:
|
||||
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
|
||||
ocr_result = state.get("ocr_extraction_result")
|
||||
if not ocr_result or not isinstance(ocr_result, dict):
|
||||
return ""
|
||||
if ocr_result.get("error"):
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
parts.append("[图片OCR识别结果]")
|
||||
|
||||
total = ocr_result.get("total_elements", 0)
|
||||
if total:
|
||||
parts.append(f"检测到 {total} 个文字元素")
|
||||
|
||||
# 提取到的字段
|
||||
fields = ocr_result.get("fields", [])
|
||||
if fields:
|
||||
parts.append("\n提取的结构化字段:")
|
||||
for f in fields:
|
||||
if f.get("field_value"):
|
||||
parts.append(
|
||||
f" - {f['field_name']}: {f['field_value']} "
|
||||
f"(方法={f.get('extraction_method','?')}, "
|
||||
f"置信度={f.get('confidence',0):.2f})"
|
||||
)
|
||||
|
||||
# 所有原始文本(用于表格匹配等需要全文的场景)
|
||||
elements = ocr_result.get("elements", [])
|
||||
if elements:
|
||||
parts.append("\n全部文本元素(含坐标):")
|
||||
for e in elements:
|
||||
bbox = e.get("bbox", {})
|
||||
x, y, w, h = bbox.get("x", 0), bbox.get("y", 0), bbox.get("w", 0), bbox.get("h", 0)
|
||||
parts.append(
|
||||
f" [{x},{y} {w}×{h}] {e['text']} "
|
||||
f"(置信度={e.get('confidence',0):.2f})"
|
||||
)
|
||||
|
||||
# 批注检测结果
|
||||
ann_result = state.get("annotation_result")
|
||||
if ann_result and isinstance(ann_result, dict):
|
||||
try:
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
ann_text = format_annotation_context(ann_result)
|
||||
if ann_text:
|
||||
parts.append("\n" + ann_text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
@log_node("retrieve")
|
||||
def retrieve(state: AgentState) -> Dict:
|
||||
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。"""
|
||||
@@ -466,9 +562,15 @@ def generate(state: AgentState) -> Dict:
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="generate")
|
||||
|
||||
user_request = state.get("user_input", "")
|
||||
ocr_text = _format_ocr_context(state)
|
||||
if ocr_text:
|
||||
user_request = f"{ocr_text}\n\n---\n用户需求:\n{user_request}"
|
||||
|
||||
prompt = load_prompt("initial_generation").format(
|
||||
context=state.get("retrieved_context", ""),
|
||||
user_request=state.get("user_input", ""),
|
||||
user_request=user_request,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
@@ -480,6 +582,104 @@ def generate(state: AgentState) -> Dict:
|
||||
return state
|
||||
|
||||
|
||||
@log_node("generate_skeleton")
|
||||
def generate_skeleton(state: AgentState) -> Dict:
|
||||
"""阶段一:根据压缩的布局 schema 生成骨架 JRXML($F{field_N} 占位)。"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="generate_skeleton")
|
||||
|
||||
schema = state.get("layout_schema", {})
|
||||
schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""
|
||||
user_request = state.get("user_input", "")
|
||||
|
||||
prompt = load_prompt("skeleton_generation").format(
|
||||
layout_schema=schema_text,
|
||||
context=state.get("retrieved_context", ""),
|
||||
user_request=user_request,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "generate_skeleton", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
|
||||
|
||||
@log_node("refine_layout")
|
||||
def refine_layout(state: AgentState) -> Dict:
|
||||
"""阶段二:使用采样坐标(表头 + 首行数据 + 最后一行)精确调整元素位置。"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="refine_layout")
|
||||
|
||||
ocr_rows = state.get("ocr_elements", [])
|
||||
sampled = {}
|
||||
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
|
||||
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
|
||||
if len(ocr_rows) > 1:
|
||||
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
|
||||
if len(ocr_rows) > 2:
|
||||
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
|
||||
sampled_text = json.dumps(sampled, ensure_ascii=False, indent=2)
|
||||
|
||||
prompt = load_prompt("refine_layout").format(
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
sampled_coordinates=sampled_text,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "refine_layout", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
|
||||
|
||||
@log_node("map_fields")
|
||||
def map_fields(state: AgentState) -> Dict:
|
||||
"""阶段三:将占位字段名替换为 OCR 提取的真实字段名。"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="map_fields")
|
||||
|
||||
ocr_result = state.get("ocr_extraction_result", {})
|
||||
fields_text = ""
|
||||
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
|
||||
field_descs = []
|
||||
for f in ocr_result["fields"]:
|
||||
fname = f.get("field_name", "")
|
||||
fval = f.get("field_value", "")
|
||||
if fname:
|
||||
field_descs.append(f" - {fname}: {fval}")
|
||||
if field_descs:
|
||||
fields_text = "提取的字段:\n" + "\n".join(field_descs)
|
||||
if not fields_text:
|
||||
elements = ocr_result.get("elements", []) if isinstance(ocr_result, dict) else []
|
||||
if elements:
|
||||
texts = [e.get("text", "") for e in elements if e.get("text")]
|
||||
fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])
|
||||
|
||||
prompt = load_prompt("field_mapping").format(
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
ocr_fields=fields_text,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "map_fields", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
|
||||
|
||||
@log_node("modify_jrxml")
|
||||
def modify_jrxml(state: AgentState) -> Dict:
|
||||
"""根据用户的修改请求修改现有 JRXML。"""
|
||||
@@ -500,6 +700,7 @@ def modify_jrxml(state: AgentState) -> Dict:
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
conversation_history=conv_text,
|
||||
modification_request=state.get("user_modification_request", ""),
|
||||
ocr_context=_format_ocr_context(state),
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
|
||||
@@ -44,3 +44,10 @@ class AgentState(TypedDict, total=False):
|
||||
# 需求7:OCR 单据字段精确提取结果
|
||||
ocr_extraction_result: dict
|
||||
uploaded_file_path: str
|
||||
|
||||
# 需求8:图片批注检测(圈选/箭头标记)
|
||||
annotation_result: dict
|
||||
|
||||
# 需求9:分层精确生成
|
||||
layout_schema: dict # extract_layout_schema() 输出,列+区域结构
|
||||
ocr_elements: list # OCR 原始行数据(用于阶段二坐标采样)
|
||||
|
||||
@@ -21,7 +21,6 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import streamlit.components.v1 as components
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
@@ -81,6 +80,9 @@ NODE_LABELS = {
|
||||
"handle_undo": "↩ 撤销操作",
|
||||
"handle_reset": "🔄 重置会话",
|
||||
"save_session": "💾 保存会话",
|
||||
"generate_skeleton": "🏗 生成骨架",
|
||||
"refine_layout": "📐 精调布局",
|
||||
"map_fields": "🏷 映射字段",
|
||||
}
|
||||
|
||||
INTENT_LABELS = {
|
||||
@@ -107,6 +109,86 @@ def _render_jrxml(jrxml: str, max_lines: int = 30):
|
||||
st.code(preview, language="xml")
|
||||
|
||||
|
||||
# ---- 共享文件上传处理 ----
|
||||
def _process_uploaded_file(uploaded_file, suffix: str) -> dict:
|
||||
"""处理单个上传文件:保存临时文件、解析、布局分析。
|
||||
|
||||
返回: {"name": str, "text": str, "type": str, "tmp_path": str|None}
|
||||
"""
|
||||
import tempfile
|
||||
from backend.file_parser import parse_file
|
||||
from backend.layout_analyzer import analyze_layout
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
tmp.write(uploaded_file.getvalue())
|
||||
tmp_path = tmp.name
|
||||
|
||||
result = parse_file(tmp_path, suffix)
|
||||
parsed_text = result["text"]
|
||||
parsed_type = result["file_type"]
|
||||
|
||||
# 对图片/PDF 进行 A4 模板布局分析
|
||||
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".pdf"):
|
||||
layout = analyze_layout(tmp_path)
|
||||
tt = layout.get("template_type", "unknown")
|
||||
current_jrxml = st.session_state.agent_state.get("current_jrxml", "")
|
||||
|
||||
if tt == "full_a4":
|
||||
parsed_text = layout["description"]
|
||||
parsed_type = "a4_template"
|
||||
# 存储布局 schema 供分层精确生成使用
|
||||
from backend.layout_analyzer import extract_layout_schema
|
||||
schema = extract_layout_schema(layout)
|
||||
st.session_state.agent_state["layout_schema"] = schema
|
||||
st.session_state.agent_state["ocr_elements"] = layout.get("rows", [])
|
||||
elif tt == "partial_rows":
|
||||
parsed_type = "a4_partial"
|
||||
if current_jrxml.strip():
|
||||
from backend.layout_analyzer import match_rows_to_jrxml
|
||||
match = match_rows_to_jrxml(layout, current_jrxml)
|
||||
parsed_text = (
|
||||
f"[行片段修改] 上传图片包含 {layout['total_rows']} 行,"
|
||||
f"视为 A4 报表的一部分。\n\n"
|
||||
f"{match['description']}\n\n"
|
||||
f"--- 行结构 ---\n{layout['description']}"
|
||||
)
|
||||
else:
|
||||
parsed_text = layout["description"]
|
||||
else:
|
||||
has_ocr = result.get("method") not in ("metadata_only", None)
|
||||
img_w, img_h = layout["image_size"]
|
||||
ratio = layout["aspect_ratio"]
|
||||
if has_ocr:
|
||||
parsed_text = (
|
||||
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。"
|
||||
f"未检测到 A4 报表结构,图片将被视为参考样式。\n"
|
||||
f"请根据用户的文字描述生成报表。"
|
||||
)
|
||||
else:
|
||||
parsed_text = (
|
||||
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。\n"
|
||||
f"⚠ OCR 引擎未安装,无法识别图片中的文字内容。\n"
|
||||
f"请严格根据用户的文字描述来推断图片中的报表需求。\n"
|
||||
f"(提示:如需图片文字识别,请运行 pip install paddleocr)"
|
||||
)
|
||||
parsed_type = "image_reference"
|
||||
|
||||
elif suffix in (".pdf", ".docx", ".xlsx", ".xls", ".doc"):
|
||||
parsed_type = suffix.lstrip(".")
|
||||
|
||||
keep_temp = (
|
||||
suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
||||
and result.get("method") not in ("metadata_only", None)
|
||||
)
|
||||
|
||||
return {
|
||||
"name": uploaded_file.name,
|
||||
"text": parsed_text,
|
||||
"type": parsed_type,
|
||||
"tmp_path": tmp_path if keep_temp else None,
|
||||
}
|
||||
|
||||
|
||||
# ---- URL 参数 ----
|
||||
query_params = st.query_params
|
||||
url_session_id = query_params.get("session_id", "")
|
||||
@@ -118,11 +200,6 @@ if "graph" not in st.session_state:
|
||||
st.session_state.graph = build_graph()
|
||||
if "pending_action" not in st.session_state:
|
||||
st.session_state.pending_action = None
|
||||
if "chat_attached_files" not in st.session_state:
|
||||
st.session_state.chat_attached_files = [] # [{name, text, type, path}]
|
||||
if "_paste_processed_ts" not in st.session_state:
|
||||
st.session_state._paste_processed_ts = 0
|
||||
|
||||
if "agent_state" not in st.session_state:
|
||||
if url_session_id:
|
||||
data = load_session(url_session_id)
|
||||
@@ -220,7 +297,8 @@ def run_agent(user_input: str):
|
||||
f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
|
||||
)
|
||||
|
||||
elif node_name in ("generate", "modify_jrxml", "correct_jrxml"):
|
||||
elif node_name in ("generate", "modify_jrxml", "correct_jrxml",
|
||||
"generate_skeleton", "refine_layout", "map_fields"):
|
||||
jrxml = node_state.get("current_jrxml", "")
|
||||
executed_nodes[-1]["detail"] = f"生成 {len(jrxml)} 字符 JRXML"
|
||||
|
||||
@@ -491,7 +569,8 @@ with st.sidebar:
|
||||
|
||||
uploaded = st.file_uploader(
|
||||
"选择文件",
|
||||
type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "xlsx", "txt", "csv", "json", "xml"],
|
||||
type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "xlsx", "xls", "doc",
|
||||
"txt", "csv", "json", "xml"],
|
||||
accept_multiple_files=True,
|
||||
key="file_uploader",
|
||||
label_visibility="collapsed",
|
||||
@@ -502,77 +581,21 @@ with st.sidebar:
|
||||
# 去重
|
||||
if any(f["name"] == uf.name for f in st.session_state.uploaded_files):
|
||||
continue
|
||||
import tempfile
|
||||
from backend.file_parser import parse_file
|
||||
from backend.layout_analyzer import analyze_layout
|
||||
|
||||
suffix = Path(uf.name).suffix.lower()
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
tmp.write(uf.getvalue())
|
||||
tmp_path = tmp.name
|
||||
result = _process_uploaded_file(uf, suffix)
|
||||
|
||||
result = parse_file(tmp_path, suffix)
|
||||
|
||||
# 对图片/PDF 进行 A4 模板布局分析
|
||||
parsed_text = result["text"]
|
||||
parsed_type = result["file_type"]
|
||||
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".pdf"):
|
||||
layout = analyze_layout(tmp_path)
|
||||
tt = layout.get("template_type", "unknown")
|
||||
current_jrxml = st.session_state.agent_state.get("current_jrxml", "")
|
||||
|
||||
if tt == "full_a4":
|
||||
parsed_text = layout["description"]
|
||||
parsed_type = "a4_template"
|
||||
elif tt == "partial_rows":
|
||||
parsed_type = "a4_partial"
|
||||
if current_jrxml.strip():
|
||||
# 修改模式:尝试行匹配
|
||||
from backend.layout_analyzer import match_rows_to_jrxml
|
||||
match = match_rows_to_jrxml(layout, current_jrxml)
|
||||
parsed_text = (
|
||||
f"[行片段修改] 上传图片包含 {layout['total_rows']} 行,"
|
||||
f"视为 A4 报表的一部分。\n\n"
|
||||
f"{match['description']}\n\n"
|
||||
f"--- 行结构 ---\n{layout['description']}"
|
||||
)
|
||||
else:
|
||||
# 新建模式:按 A4 模板处理
|
||||
parsed_text = layout["description"]
|
||||
else:
|
||||
# tt == "unknown": OCR 不可用或未检测到文字元素
|
||||
has_ocr = result.get("method") not in ("metadata_only", None)
|
||||
img_w, img_h = layout["image_size"]
|
||||
ratio = layout["aspect_ratio"]
|
||||
if has_ocr:
|
||||
parsed_text = (
|
||||
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。"
|
||||
f"未检测到 A4 报表结构,图片将被视为参考样式。\n"
|
||||
f"请根据用户的文字描述生成报表。"
|
||||
)
|
||||
else:
|
||||
parsed_text = (
|
||||
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。\n"
|
||||
f"⚠ OCR 引擎未安装,无法识别图片中的文字内容。\n"
|
||||
f"请严格根据用户的文字描述来推断图片中的报表需求。\n"
|
||||
f"(提示:如需图片文字识别,请运行 pip install paddleocr)"
|
||||
)
|
||||
parsed_type = "image_reference"
|
||||
|
||||
if parsed_text:
|
||||
if result["text"]:
|
||||
st.session_state.uploaded_files.append({
|
||||
"name": uf.name,
|
||||
"text": parsed_text,
|
||||
"type": parsed_type,
|
||||
"name": result["name"],
|
||||
"text": result["text"],
|
||||
"type": result["type"],
|
||||
})
|
||||
|
||||
# 对图片类型,保存路径以便 OCR 字段提取(延迟到 process_input 阶段)
|
||||
img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
||||
if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
|
||||
tmp_path = result["tmp_path"]
|
||||
if tmp_path:
|
||||
st.session_state.agent_state["uploaded_file_path"] = tmp_path
|
||||
st.session_state.uploaded_temp_paths.append(tmp_path)
|
||||
else:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
if st.session_state.uploaded_files:
|
||||
for i, f in enumerate(st.session_state.uploaded_files):
|
||||
@@ -624,95 +647,6 @@ with st.sidebar:
|
||||
key=f"dl_v{i}",
|
||||
)
|
||||
|
||||
# ---- 文件粘贴/拖拽全局处理器 ----
|
||||
st.html("""
|
||||
<script>
|
||||
(function() {
|
||||
if (window.__jrxml_drop_paste) return;
|
||||
window.__jrxml_drop_paste = true;
|
||||
var MAX_SIZE = 20 * 1024 * 1024;
|
||||
function handleFiles(files) {
|
||||
var fd = []; var n = 0; var total = Math.min(files.length, 10);
|
||||
for (var i = 0; i < total; i++) {
|
||||
var f = files[i];
|
||||
if (f.size > MAX_SIZE) { n++; continue; }
|
||||
var reader = new FileReader();
|
||||
reader.onload = (function(file) {
|
||||
return function(e) {
|
||||
fd.push({name: file.name, size: file.size, data: e.target.result});
|
||||
n++;
|
||||
if (n === total && fd.length) {
|
||||
sessionStorage.setItem('_jrxml_paste', JSON.stringify({ts: Date.now(), files: fd}));
|
||||
}
|
||||
};
|
||||
})(f);
|
||||
reader.readAsDataURL(f);
|
||||
}
|
||||
}
|
||||
document.addEventListener('paste', function(e) {
|
||||
var fs = e.clipboardData && e.clipboardData.files;
|
||||
if (fs && fs.length) { e.preventDefault(); handleFiles(fs); }
|
||||
});
|
||||
document.addEventListener('dragover', function(e) {
|
||||
e.preventDefault(); e.dataTransfer.dropEffect = 'copy';
|
||||
});
|
||||
document.addEventListener('drop', function(e) {
|
||||
var fs = e.dataTransfer && e.dataTransfer.files;
|
||||
if (fs && fs.length) { e.preventDefault(); handleFiles(fs); }
|
||||
});
|
||||
})();
|
||||
</script>
|
||||
""")
|
||||
|
||||
# ---- 粘贴桥接组件 ----
|
||||
paste_data = components.html("""
|
||||
<script>
|
||||
(function poll() {
|
||||
var raw = sessionStorage.getItem('_jrxml_paste');
|
||||
if (raw) {
|
||||
try { sessionStorage.removeItem('_jrxml_paste'); Streamlit.setComponentValue(JSON.parse(raw)); return; }
|
||||
catch(e) {}
|
||||
}
|
||||
setTimeout(poll, 800);
|
||||
})();
|
||||
</script>
|
||||
""", height=0, default=0)
|
||||
|
||||
if paste_data and paste_data != 0:
|
||||
pts = paste_data.get("ts", 0)
|
||||
if pts > st.session_state._paste_processed_ts:
|
||||
st.session_state._paste_processed_ts = pts
|
||||
import base64, tempfile
|
||||
from backend.file_parser import parse_file
|
||||
from backend.layout_analyzer import analyze_layout
|
||||
for fi in paste_data.get("files", []):
|
||||
if not any(f["name"] == fi["name"] for f in st.session_state.chat_attached_files):
|
||||
header, b64 = fi["data"].split(",", 1)
|
||||
raw = base64.b64decode(b64)
|
||||
suffix = Path(fi["name"]).suffix.lower()
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
tmp.write(raw)
|
||||
tmp_path = tmp.name
|
||||
result = parse_file(tmp_path, suffix)
|
||||
text = result["text"]
|
||||
file_type = result["file_type"]
|
||||
img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
||||
if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
|
||||
try:
|
||||
layout = analyze_layout(tmp_path)
|
||||
tt = layout.get("template_type", "unknown")
|
||||
if tt == "full_a4":
|
||||
text = layout["description"]
|
||||
file_type = "a4_template"
|
||||
elif tt == "partial_rows":
|
||||
file_type = "a4_partial"
|
||||
except Exception:
|
||||
pass
|
||||
st.session_state.chat_attached_files.append({
|
||||
"name": fi["name"], "text": text, "type": file_type, "path": tmp_path
|
||||
})
|
||||
st.rerun()
|
||||
|
||||
# ---- 标题 ----
|
||||
st.title("📝 JRXML 报表生成器")
|
||||
st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。")
|
||||
@@ -732,127 +666,106 @@ for msg in st.session_state.messages:
|
||||
else:
|
||||
st.markdown(msg["content"])
|
||||
|
||||
# ---- 已附加文件预览 ----
|
||||
if st.session_state.chat_attached_files:
|
||||
n_files = len(st.session_state.chat_attached_files)
|
||||
chip_cols = st.columns(min(n_files, 4))
|
||||
files_to_remove = []
|
||||
for i, f in enumerate(st.session_state.chat_attached_files):
|
||||
with chip_cols[i % len(chip_cols)]:
|
||||
c1, c2 = st.columns([5, 1])
|
||||
with c1:
|
||||
name = f["name"]
|
||||
short_name = name[:16] + ("…" if len(name) > 16 else "")
|
||||
emoji_map = {"a4_template": "📷", "image": "🖼", "pdf": "📄", "docx": "📝", "xlsx": "📊"}
|
||||
emoji = emoji_map.get(f["type"], "📎")
|
||||
st.caption(f"{emoji} {short_name}")
|
||||
with c2:
|
||||
if st.button("✕", key=f"rm_chip_{i}"):
|
||||
files_to_remove.append(i)
|
||||
if files_to_remove:
|
||||
for i in sorted(files_to_remove, reverse=True):
|
||||
try:
|
||||
Path(st.session_state.chat_attached_files[i]["path"]).unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
st.session_state.chat_attached_files.pop(i)
|
||||
st.rerun()
|
||||
# ---- 聊天输入(支持粘贴/拖拽文件) ----
|
||||
from st_multimodal_chatinput import multimodal_chatinput
|
||||
import base64
|
||||
import io
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# ---- 对话区域文件上传 ----
|
||||
col_fu, col_hint = st.columns([5, 1])
|
||||
with col_fu:
|
||||
chat_uploads = st.file_uploader(
|
||||
"附加文件",
|
||||
type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "xlsx", "txt", "csv", "json", "xml"],
|
||||
accept_multiple_files=True,
|
||||
key="chat_file_uploader",
|
||||
label_visibility="visible",
|
||||
)
|
||||
with col_hint:
|
||||
st.caption("Ctrl+V 粘贴\n或拖拽到页面")
|
||||
# MIME type → 文件扩展名映射(用于剪贴板粘贴无扩展名的文件)
|
||||
MIME_TO_EXT = {
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/bmp": ".bmp",
|
||||
"image/webp": ".webp",
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/msword": ".doc",
|
||||
"text/plain": ".txt",
|
||||
"text/csv": ".csv",
|
||||
"application/json": ".json",
|
||||
"text/xml": ".xml",
|
||||
}
|
||||
|
||||
if chat_uploads:
|
||||
newly_added = False
|
||||
import tempfile
|
||||
from backend.file_parser import parse_file
|
||||
from backend.layout_analyzer import analyze_layout
|
||||
for uf in chat_uploads:
|
||||
if not any(f["name"] == uf.name for f in st.session_state.chat_attached_files):
|
||||
suffix = Path(uf.name).suffix.lower()
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
tmp.write(uf.getvalue())
|
||||
tmp_path = tmp.name
|
||||
result = parse_file(tmp_path, suffix)
|
||||
text = result["text"]
|
||||
file_type = result["file_type"]
|
||||
img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
||||
if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
|
||||
try:
|
||||
layout = analyze_layout(tmp_path)
|
||||
tt = layout.get("template_type", "unknown")
|
||||
if tt == "full_a4":
|
||||
text = layout["description"]
|
||||
file_type = "a4_template"
|
||||
elif tt == "partial_rows":
|
||||
file_type = "a4_partial"
|
||||
except Exception:
|
||||
pass
|
||||
st.session_state.chat_attached_files.append({
|
||||
"name": uf.name, "text": text, "type": file_type, "path": tmp_path
|
||||
})
|
||||
newly_added = True
|
||||
if newly_added:
|
||||
st.session_state.chat_file_uploader = []
|
||||
st.rerun()
|
||||
chat_result = multimodal_chatinput()
|
||||
if chat_result:
|
||||
prompt = (chat_result.get("textInput") or "").strip()
|
||||
chat_files = chat_result.get("uploadedFiles") or []
|
||||
|
||||
# ---- 聊天输入 ----
|
||||
if prompt := st.chat_input("描述您的报表需求..."):
|
||||
# 拼接对话区域附加文件的文本
|
||||
file_texts = []
|
||||
attached_info = []
|
||||
for f in st.session_state.chat_attached_files:
|
||||
file_texts.append(f"[附加文件: {f['name']} ({f['type']})]\n{f['text']}")
|
||||
attached_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])})
|
||||
# 处理聊天中上传/粘贴的文件
|
||||
uploaded_texts = []
|
||||
uploaded_files_info = []
|
||||
|
||||
# 同时拼接侧边栏上传的文件(向后兼容)
|
||||
# 先收集侧边栏已上传的文件
|
||||
if st.session_state.get("uploaded_files"):
|
||||
for f in st.session_state.uploaded_files:
|
||||
file_texts.append(f"[上传文件: {f['name']}]\n{f['text']}")
|
||||
attached_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])})
|
||||
uploaded_texts.append(f"[上传文件: {f['name']}]\n{f['text']}")
|
||||
uploaded_files_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])})
|
||||
st.session_state.uploaded_files = []
|
||||
|
||||
if file_texts:
|
||||
full_prompt = "\n\n".join(file_texts) + "\n\n---\n用户需求:\n" + prompt
|
||||
# 处理聊天中的文件
|
||||
class _Base64File:
|
||||
"""包装 base64 文件为类 UploadedFile 接口。"""
|
||||
def __init__(self, name, data_bytes):
|
||||
self.name = name
|
||||
self._data = data_bytes
|
||||
|
||||
def getvalue(self):
|
||||
return self._data
|
||||
|
||||
for cf in chat_files:
|
||||
name = cf.get("name", "clipboard_file")
|
||||
mime = cf.get("type", "")
|
||||
content_b64 = cf.get("content", "")
|
||||
if not content_b64:
|
||||
continue
|
||||
try:
|
||||
data = base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
suffix = _Path(name).suffix.lower()
|
||||
if not suffix and mime in MIME_TO_EXT:
|
||||
suffix = MIME_TO_EXT[mime]
|
||||
name = f"{_Path(name).stem}{suffix}"
|
||||
|
||||
wrapper = _Base64File(name, data)
|
||||
result = _process_uploaded_file(wrapper, suffix)
|
||||
|
||||
if result["text"]:
|
||||
uploaded_texts.append(f"[上传文件: {result['name']}]\n{result['text']}")
|
||||
uploaded_files_info.append({"name": result["name"], "type": result["type"], "length": len(result["text"])})
|
||||
|
||||
tmp_path = result["tmp_path"]
|
||||
if tmp_path:
|
||||
st.session_state.agent_state["uploaded_file_path"] = tmp_path
|
||||
st.session_state.uploaded_temp_paths.append(tmp_path)
|
||||
|
||||
if prompt or uploaded_texts:
|
||||
if uploaded_texts:
|
||||
full_prompt = "\n\n".join(uploaded_texts)
|
||||
if prompt:
|
||||
full_prompt += "\n\n---\n用户需求:\n" + prompt
|
||||
else:
|
||||
full_prompt = prompt
|
||||
|
||||
# 将第一个图片文件的路径传给 agent,供 OCR 字段精确提取
|
||||
for f in st.session_state.chat_attached_files:
|
||||
if f["type"] in ("image", "a4_template", "a4_partial"):
|
||||
st.session_state.agent_state["uploaded_file_path"] = f["path"]
|
||||
break
|
||||
|
||||
# 清理临时文件和状态
|
||||
st.session_state.uploaded_files = []
|
||||
for f in st.session_state.chat_attached_files:
|
||||
try:
|
||||
Path(f["path"]).unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
st.session_state.chat_attached_files = []
|
||||
displayed_prompt = prompt or "(已上传文件,未输入文字)"
|
||||
|
||||
_app_log.info(
|
||||
"收到用户输入",
|
||||
extra={
|
||||
"session_id": current_session_id,
|
||||
"prompt_preview": prompt[:200],
|
||||
"prompt_length": len(prompt),
|
||||
"has_uploaded_files": bool(attached_info),
|
||||
"uploaded_files": attached_info,
|
||||
"prompt_preview": displayed_prompt[:200],
|
||||
"prompt_length": len(full_prompt),
|
||||
"has_uploaded_files": bool(uploaded_files_info),
|
||||
"uploaded_files": uploaded_files_info,
|
||||
},
|
||||
)
|
||||
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
st.session_state.messages.append({"role": "user", "content": displayed_prompt})
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
st.markdown(displayed_prompt)
|
||||
run_agent(full_prompt)
|
||||
st.rerun()
|
||||
|
||||
@@ -0,0 +1,331 @@
|
||||
"""批注检测器:识别图片上的圈选(圆)和箭头,定位用户要修改的字段。
|
||||
|
||||
依赖 OpenCV (cv2),从 PaddleOCR 传递依赖已安装。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class Annotation:
|
||||
"""单个批注标记。"""
|
||||
type: str # "circle" | "arrow"
|
||||
bbox: dict # {"x": int, "y": int, "w": int, "h": int}
|
||||
center: tuple[int, int] # (cx, cy)
|
||||
nearby_texts: list[str] = field(default_factory=list)
|
||||
from_text: str = "" # 箭头出发点的文本
|
||||
to_text: str = "" # 箭头指向的文本
|
||||
from_pt: Optional[tuple[int, int]] = None
|
||||
to_pt: Optional[tuple[int, int]] = None
|
||||
|
||||
|
||||
def detect_annotations(image_path: str, ocr_elements: list[dict]) -> dict:
|
||||
"""检测图片上的手写批注(圈选 + 箭头),并与 OCR 文本关联。
|
||||
|
||||
Args:
|
||||
image_path: 图片文件路径
|
||||
ocr_elements: OCR 元素列表 [{"text": str, "bbox": {x,y,w,h}, "confidence": float}]
|
||||
|
||||
Returns:
|
||||
{"circles": [...], "arrows": [...], "total": int}
|
||||
"""
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
return {"circles": [], "arrows": [], "total": 0, "error": "无法读取图片"}
|
||||
|
||||
h, w = img.shape[:2]
|
||||
|
||||
circles = _detect_circles(img)
|
||||
arrows = _detect_arrows(img)
|
||||
|
||||
all_annotations = circles + arrows
|
||||
_correlate_with_ocr(all_annotations, ocr_elements, w, h)
|
||||
|
||||
result: dict = {
|
||||
"circles": [_annotation_to_dict(a) for a in circles],
|
||||
"arrows": [_annotation_to_dict(a) for a in arrows],
|
||||
"total": len(all_annotations),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _annotation_to_dict(a: Annotation) -> dict:
|
||||
d = {
|
||||
"type": a.type,
|
||||
"bbox": a.bbox,
|
||||
"center": list(a.center),
|
||||
"nearby_texts": a.nearby_texts,
|
||||
}
|
||||
if a.type == "arrow":
|
||||
d["from_text"] = a.from_text
|
||||
d["to_text"] = a.to_text
|
||||
if a.from_pt:
|
||||
d["from_pt"] = list(a.from_pt)
|
||||
if a.to_pt:
|
||||
d["to_pt"] = list(a.to_pt)
|
||||
return d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 圆圈检测
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_circles(img: np.ndarray) -> list[Annotation]:
|
||||
"""检测图片中可能是手绘批注的圆圈。"""
|
||||
h, w = img.shape[:2]
|
||||
b, g, r = cv2.split(img)
|
||||
red_enhanced = cv2.addWeighted(r.astype(np.float32), 1.5,
|
||||
g.astype(np.float32), -0.3, 0)
|
||||
red_enhanced = cv2.addWeighted(red_enhanced, 1.2,
|
||||
b.astype(np.float32), -0.3, 0)
|
||||
red_enhanced = np.clip(red_enhanced, 0, 255).astype(np.uint8)
|
||||
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
combined = cv2.addWeighted(gray, 0.5, red_enhanced, 0.5, 0)
|
||||
blurred = cv2.GaussianBlur(combined, (9, 9), 2)
|
||||
|
||||
min_radius = max(15, min(w, h) // 40)
|
||||
max_radius = min(200, max(w, h) // 8)
|
||||
|
||||
circles_raw = cv2.HoughCircles(
|
||||
blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=min_radius * 2,
|
||||
param1=50, param2=30, minRadius=min_radius, maxRadius=max_radius,
|
||||
)
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
|
||||
if circles_raw is not None:
|
||||
for cx, cy, r in circles_raw[0]:
|
||||
bbox = {
|
||||
"x": max(0, int(cx - r)),
|
||||
"y": max(0, int(cy - r)),
|
||||
"w": int(r * 2),
|
||||
"h": int(r * 2),
|
||||
}
|
||||
annotations.append(Annotation(
|
||||
type="circle",
|
||||
bbox=bbox,
|
||||
center=(int(cx), int(cy)),
|
||||
))
|
||||
|
||||
return annotations
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 箭头检测
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_arrows(img: np.ndarray) -> list[Annotation]:
|
||||
"""检测图片中的手绘箭头(直线段 + 端点三角形)。"""
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
edges = cv2.Canny(gray, 50, 150, apertureSize=3)
|
||||
|
||||
lines = cv2.HoughLinesP(
|
||||
edges, rho=1, theta=np.pi / 180, threshold=40,
|
||||
minLineLength=30, maxLineGap=15,
|
||||
)
|
||||
|
||||
if lines is None:
|
||||
return []
|
||||
|
||||
segments = [(x1, y1, x2, y2) for x1, y1, x2, y2 in lines[:, 0]]
|
||||
clusters = _cluster_segments(segments)
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for segs in clusters:
|
||||
if len(segs) < 2:
|
||||
continue
|
||||
all_pts = []
|
||||
for x1, y1, x2, y2 in segs:
|
||||
all_pts.append((x1, y1))
|
||||
all_pts.append((x2, y2))
|
||||
all_pts_arr = np.array(all_pts)
|
||||
max_dist = 0
|
||||
p1 = p2 = all_pts[0]
|
||||
for i in range(len(all_pts)):
|
||||
for j in range(i + 1, len(all_pts)):
|
||||
d = (all_pts[i][0] - all_pts[j][0]) ** 2 + (all_pts[i][1] - all_pts[j][1]) ** 2
|
||||
if d > max_dist:
|
||||
max_dist = d
|
||||
p1, p2 = all_pts[i], all_pts[j]
|
||||
|
||||
from_pt, to_pt = _find_arrow_direction(edges, p1, p2)
|
||||
|
||||
x1, y1 = from_pt
|
||||
x2, y2 = to_pt
|
||||
bbox = {
|
||||
"x": min(x1, x2),
|
||||
"y": min(y1, y2),
|
||||
"w": abs(x2 - x1),
|
||||
"h": abs(y2 - y1),
|
||||
}
|
||||
cx = (x1 + x2) // 2
|
||||
cy = (y1 + y2) // 2
|
||||
|
||||
annotations.append(Annotation(
|
||||
type="arrow",
|
||||
bbox=bbox,
|
||||
center=(cx, cy),
|
||||
from_pt=from_pt,
|
||||
to_pt=to_pt,
|
||||
))
|
||||
|
||||
return annotations
|
||||
|
||||
|
||||
def _cluster_segments(segments: list[tuple]) -> list[list[tuple]]:
|
||||
"""将线段按方向和空间距离聚类。"""
|
||||
clusters: list[list[tuple]] = []
|
||||
used = [False] * len(segments)
|
||||
|
||||
for i, (x1, y1, x2, y2) in enumerate(segments):
|
||||
if used[i]:
|
||||
continue
|
||||
cluster = [(x1, y1, x2, y2)]
|
||||
used[i] = True
|
||||
angle_i = math.atan2(y2 - y1, x2 - x1)
|
||||
|
||||
for j in range(i + 1, len(segments)):
|
||||
if used[j]:
|
||||
continue
|
||||
x3, y3, x4, y4 = segments[j]
|
||||
angle_j = math.atan2(y4 - y3, x4 - x3)
|
||||
angle_diff = abs(angle_i - angle_j)
|
||||
if angle_diff > math.pi:
|
||||
angle_diff = 2 * math.pi - angle_diff
|
||||
|
||||
if angle_diff < 0.35:
|
||||
d1 = math.hypot(x3 - x2, y3 - y2)
|
||||
d2 = math.hypot(x1 - x4, y1 - y4)
|
||||
d3 = math.hypot(x3 - x1, y3 - y1)
|
||||
d4 = math.hypot(x4 - x2, y4 - y2)
|
||||
if min(d1, d2, d3, d4) < 80:
|
||||
cluster.append((x3, y3, x4, y4))
|
||||
used[j] = True
|
||||
|
||||
clusters.append(cluster)
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def _find_arrow_direction(edges: np.ndarray, p1: tuple, p2: tuple) -> tuple[tuple, tuple]:
|
||||
"""判断箭头的方向(哪端是箭头/三角形汇聚点)。"""
|
||||
r = 20
|
||||
h, w = edges.shape[:2]
|
||||
|
||||
def edge_density(cx, cy):
|
||||
x1 = max(0, int(cx - r))
|
||||
y1 = max(0, int(cy - r))
|
||||
x2 = min(w, int(cx + r))
|
||||
y2 = min(h, int(cy + r))
|
||||
roi = edges[y1:y2, x1:x2]
|
||||
if roi.size == 0:
|
||||
return 0
|
||||
return float(np.count_nonzero(roi)) / roi.size
|
||||
|
||||
d1 = edge_density(p1[0], p1[1])
|
||||
d2 = edge_density(p2[0], p2[1])
|
||||
|
||||
if d1 > d2 * 1.3:
|
||||
return p2, p1
|
||||
if d2 > d1 * 1.3:
|
||||
return p1, p2
|
||||
return p1, p2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OCR 关联
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _correlate_with_ocr(
|
||||
annotations: list[Annotation],
|
||||
ocr_elements: list[dict],
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
) -> None:
|
||||
"""将批注与附近的 OCR 文本关联。"""
|
||||
if not ocr_elements:
|
||||
return
|
||||
|
||||
for ann in annotations:
|
||||
ax = ann.center[0]
|
||||
ay = ann.center[1]
|
||||
|
||||
near_texts: list[tuple[str, float]] = []
|
||||
|
||||
for elem in ocr_elements:
|
||||
bbox = elem.get("bbox", {})
|
||||
ex = bbox.get("x", 0) + bbox.get("w", 0) / 2
|
||||
ey = bbox.get("y", 0) + bbox.get("h", 0) / 2
|
||||
dist = math.hypot(ax - ex, ay - ey)
|
||||
max_dist = max(img_w, img_h) * 0.15
|
||||
if dist < max_dist:
|
||||
near_texts.append((elem.get("text", ""), dist))
|
||||
|
||||
near_texts.sort(key=lambda x: x[1])
|
||||
ann.nearby_texts = [t for t, _ in near_texts[:5]]
|
||||
|
||||
if ann.type == "arrow" and ann.from_pt and ann.to_pt:
|
||||
ann.from_text = _closest_text(ann.from_pt, ocr_elements, img_w, img_h)
|
||||
ann.to_text = _closest_text(ann.to_pt, ocr_elements, img_w, img_h)
|
||||
|
||||
|
||||
def _closest_text(pt: tuple[int, int], ocr_elements: list[dict], img_w: int, img_h: int) -> str:
|
||||
"""找到离 pt 最近的 OCR 文本。"""
|
||||
best_text = ""
|
||||
best_dist = max(img_w, img_h) * 0.12
|
||||
for elem in ocr_elements:
|
||||
bbox = elem.get("bbox", {})
|
||||
ex = bbox.get("x", 0) + bbox.get("w", 0) / 2
|
||||
ey = bbox.get("y", 0) + bbox.get("h", 0) / 2
|
||||
dist = math.hypot(pt[0] - ex, pt[1] - ey)
|
||||
if dist < best_dist:
|
||||
best_dist = dist
|
||||
best_text = elem.get("text", "")
|
||||
return best_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM 上下文格式化
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def format_annotation_context(annotation_result: dict) -> str:
|
||||
"""将批注检测结果格式化为中文 LLM 提示文本。"""
|
||||
if not annotation_result or not isinstance(annotation_result, dict):
|
||||
return ""
|
||||
|
||||
circles = annotation_result.get("circles", [])
|
||||
arrows = annotation_result.get("arrows", [])
|
||||
total = annotation_result.get("total", len(circles) + len(arrows))
|
||||
|
||||
if total == 0:
|
||||
return ""
|
||||
|
||||
parts = ["[图片批注检测结果]"]
|
||||
|
||||
if circles:
|
||||
parts.append(f"\n检测到 {len(circles)} 个圈选标记:")
|
||||
for i, c in enumerate(circles):
|
||||
center = c.get("center", [0, 0])
|
||||
near = c.get("nearby_texts", [])
|
||||
parts.append(
|
||||
f" 圈{i+1}. 位置 ({center[0]},{center[1]})"
|
||||
f" — 圈选内容: {', '.join(near) if near else '(附近无文字)'}"
|
||||
)
|
||||
|
||||
if arrows:
|
||||
parts.append(f"\n检测到 {len(arrows)} 个箭头标记:")
|
||||
for i, a in enumerate(arrows):
|
||||
ft = a.get("from_text", "")
|
||||
tt = a.get("to_text", "")
|
||||
parts.append(f" 箭头{i+1}. 从「{ft}」→ 指向「{tt}」")
|
||||
|
||||
parts.append("\n请根据上述圈选/箭头定位用户要修改的报表字段。")
|
||||
return "\n".join(parts)
|
||||
+95
-38
@@ -52,6 +52,8 @@ def parse_file(file_path: str, file_type: str = "") -> dict:
|
||||
".pdf": _parse_pdf,
|
||||
".docx": _parse_docx,
|
||||
".xlsx": _parse_xlsx,
|
||||
".xls": _parse_xls,
|
||||
".doc": _parse_doc,
|
||||
}
|
||||
|
||||
parser = parsers.get(suffix)
|
||||
@@ -73,26 +75,7 @@ def _parse_image(path: Path) -> dict:
|
||||
except Exception:
|
||||
info = "[图片: 无法读取元数据]"
|
||||
|
||||
# 优先 EasyOCR(Windows 兼容性更好)
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
lines = [text.strip() for (_, text, _) in result if text.strip()]
|
||||
if lines:
|
||||
return {
|
||||
"text": f"{info}\n识别文本:\n" + "\n".join(lines),
|
||||
"file_type": "image",
|
||||
"method": "easyocr",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 PaddleOCR
|
||||
# 优先 PaddleOCR(精确识别)
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR(lang="ch")
|
||||
@@ -115,6 +98,25 @@ def _parse_image(path: Path) -> dict:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 EasyOCR
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
lines = [text.strip() for (_, text, _) in result if text.strip()]
|
||||
if lines:
|
||||
return {
|
||||
"text": f"{info}\n识别文本:\n" + "\n".join(lines),
|
||||
"file_type": "image",
|
||||
"method": "easyocr",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# OCR 不可用 → 返回图片元信息 + 安装提示
|
||||
return {
|
||||
"text": f"{info}\n(如需 OCR 文字识别,请安装: pip install easyocr)",
|
||||
@@ -197,36 +199,91 @@ def _parse_docx(path: Path) -> dict:
|
||||
|
||||
|
||||
def _parse_xlsx(path: Path) -> dict:
|
||||
"""提取 Excel (.xlsx) 表格内容为文本。"""
|
||||
"""提取 Excel .xlsx 文件中的文本。"""
|
||||
try:
|
||||
import openpyxl
|
||||
wb = openpyxl.load_workbook(path, read_only=True, data_only=True)
|
||||
sheets_text = []
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
from openpyxl import load_workbook
|
||||
wb = load_workbook(path, read_only=True, data_only=True)
|
||||
parts = []
|
||||
for name in wb.sheetnames:
|
||||
ws = wb[name]
|
||||
rows = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) if c is not None else "" for c in row]
|
||||
if any(c.strip() for c in cells):
|
||||
rows.append(" | ".join(cells))
|
||||
if any(c for c in cells):
|
||||
rows.append("\t".join(cells))
|
||||
if rows:
|
||||
sheets_text.append(f"--- 工作表: {sheet_name} ---\n" + "\n".join(rows))
|
||||
parts.append(f"[Sheet: {name}]\n" + "\n".join(rows))
|
||||
wb.close()
|
||||
if sheets_text:
|
||||
return {
|
||||
"text": "\n\n".join(sheets_text),
|
||||
"file_type": "xlsx",
|
||||
"method": "openpyxl",
|
||||
"error": None,
|
||||
}
|
||||
text = "\n\n".join(parts)
|
||||
return {"text": text, "file_type": "xlsx", "method": "openpyxl", "error": None}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
return {"text": "", "file_type": "xlsx", "method": "none",
|
||||
"error": f"XLSX 解析失败: {e}"}
|
||||
return {"text": "", "file_type": "xlsx", "method": "none",
|
||||
"error": "XLSX 解析需要安装 openpyxl"}
|
||||
|
||||
|
||||
def _parse_xls(path: Path) -> dict:
|
||||
"""提取旧版 Excel .xls 文件中的文本。"""
|
||||
try:
|
||||
import xlrd
|
||||
wb = xlrd.open_workbook(path)
|
||||
parts = []
|
||||
for name in wb.sheet_names():
|
||||
ws = wb.sheet_by_name(name)
|
||||
rows = []
|
||||
for rx in range(ws.nrows):
|
||||
cells = [str(ws.cell_value(rx, cx)) if ws.cell_value(rx, cx) != "" else ""
|
||||
for cx in range(ws.ncols)]
|
||||
if any(c for c in cells):
|
||||
rows.append("\t".join(cells))
|
||||
if rows:
|
||||
parts.append(f"[Sheet: {name}]\n" + "\n".join(rows))
|
||||
text = "\n\n".join(parts)
|
||||
return {"text": text, "file_type": "xls", "method": "xlrd", "error": None}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
return {"text": "", "file_type": "xls", "method": "none",
|
||||
"error": f"XLS 解析失败: {e}"}
|
||||
return {"text": "", "file_type": "xls", "method": "none",
|
||||
"error": "XLS 解析需要安装 xlrd"}
|
||||
|
||||
|
||||
def _parse_doc(path: Path) -> dict:
|
||||
"""提取旧版 Word .doc 文件中的文本(尽力而为,二进制格式)。"""
|
||||
try:
|
||||
import olefile
|
||||
ole = olefile.OleFileIO(path)
|
||||
if not ole.exists("WordDocument"):
|
||||
ole.close()
|
||||
return {"text": "", "file_type": "doc", "method": "none",
|
||||
"error": "不是有效的 .doc 文件"}
|
||||
raw = ole.openstream("WordDocument").read()
|
||||
ole.close()
|
||||
# 提取可打印 UTF-16LE 字符段
|
||||
text = ""
|
||||
try:
|
||||
decoded = raw.decode("utf-16-le", errors="ignore")
|
||||
text = "".join(c for c in decoded if c.isprintable() or c in "\n\r\t")
|
||||
except Exception:
|
||||
pass
|
||||
if not text.strip():
|
||||
return {"text": "", "file_type": "doc", "method": "olefile",
|
||||
"error": "无法提取文本(.doc 为二进制格式,建议转换为 .docx)"}
|
||||
return {"text": text.strip(), "file_type": "doc", "method": "olefile", "error": None}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
return {"text": "", "file_type": "doc", "method": "none",
|
||||
"error": f"DOC 解析失败: {e}"}
|
||||
return {"text": "", "file_type": "doc", "method": "none",
|
||||
"error": "DOC 解析需要安装 olefile"}
|
||||
|
||||
|
||||
|
||||
def _parse_text(path: Path) -> dict:
|
||||
"""读取纯文本文件。"""
|
||||
try:
|
||||
|
||||
+174
-34
@@ -119,6 +119,146 @@ def analyze_layout(
|
||||
}
|
||||
|
||||
|
||||
def extract_layout_schema(layout_result: dict) -> dict:
|
||||
"""将 analyze_layout() 的完整 OCR 行数据压缩为高层布局 schema。
|
||||
|
||||
列检测:跨所有行对元素 X 坐标进行聚类。
|
||||
区域分类:启发式识别标题/表头/数据/表尾行。
|
||||
输出紧凑的 schema_text,供 LLM 阶段一骨架生成使用。
|
||||
"""
|
||||
rows = layout_result.get("rows", [])
|
||||
if not rows:
|
||||
return _empty_schema()
|
||||
|
||||
img_w, img_h = layout_result.get("image_size", (595, 842))
|
||||
if img_w <= 0:
|
||||
img_w = 595
|
||||
|
||||
all_elements = []
|
||||
for row in rows:
|
||||
all_elements.extend(row.get("elements", []))
|
||||
if not all_elements:
|
||||
return _empty_schema()
|
||||
|
||||
x_centers = sorted((e["x"] + e["w"] / 2) for e in all_elements)
|
||||
avg_width = sum(e["w"] for e in all_elements) / len(all_elements)
|
||||
cluster_threshold = avg_width * 0.5
|
||||
|
||||
clusters = []
|
||||
current_cluster = [x_centers[0]]
|
||||
for xc in x_centers[1:]:
|
||||
if xc - current_cluster[-1] < cluster_threshold:
|
||||
current_cluster.append(xc)
|
||||
else:
|
||||
clusters.append(current_cluster)
|
||||
current_cluster = [xc]
|
||||
if current_cluster:
|
||||
clusters.append(current_cluster)
|
||||
|
||||
columns = []
|
||||
for ci, cluster in enumerate(clusters):
|
||||
cx_min = min(cluster)
|
||||
cx_max = max(cluster)
|
||||
col_elements = [
|
||||
e for e in all_elements
|
||||
if cx_min - cluster_threshold <= (e["x"] + e["w"] / 2) <= cx_max + cluster_threshold
|
||||
]
|
||||
avg_w = sum(e["w"] for e in col_elements) / len(col_elements) if col_elements else 0
|
||||
x_start = min(e["x"] for e in col_elements)
|
||||
|
||||
col_elements_by_y = sorted(col_elements, key=lambda e: e["y"])
|
||||
header_text = col_elements_by_y[0]["text"] if col_elements_by_y else f"列{ci+1}"
|
||||
|
||||
columns.append({
|
||||
"index": ci,
|
||||
"header_text": header_text,
|
||||
"avg_width": round(avg_w, 1),
|
||||
"x_start": round(x_start, 1),
|
||||
})
|
||||
|
||||
columns.sort(key=lambda c: c["x_start"])
|
||||
|
||||
row_element_counts = [len(r.get("elements", [])) for r in rows]
|
||||
median_count = sorted(row_element_counts)[len(row_element_counts) // 2] if row_element_counts else 0
|
||||
total_rows = len(rows)
|
||||
|
||||
regions = []
|
||||
current_region = None
|
||||
|
||||
for ri in range(total_rows):
|
||||
count = row_element_counts[ri]
|
||||
if ri == 0 and count < median_count * 0.6 and total_rows > 2:
|
||||
rtype = "title"
|
||||
elif ri == 0 and total_rows <= 2:
|
||||
rtype = "header"
|
||||
elif ri == 1 and total_rows > 2:
|
||||
rtype = "header" if median_count > 0 else "data"
|
||||
elif ri >= total_rows - 2 and count < median_count * 0.7 and total_rows > 3:
|
||||
rtype = "footer"
|
||||
else:
|
||||
rtype = "data"
|
||||
|
||||
if current_region and current_region["type"] == rtype:
|
||||
current_region["row_indices"].append(ri)
|
||||
current_region["element_count"] += count
|
||||
else:
|
||||
if current_region:
|
||||
regions.append(current_region)
|
||||
current_region = {"type": rtype, "row_indices": [ri], "element_count": count}
|
||||
|
||||
if current_region:
|
||||
regions.append(current_region)
|
||||
|
||||
# schema_text
|
||||
width_ratios = [c["avg_width"] / img_w for c in columns]
|
||||
width_labels = []
|
||||
for r in width_ratios:
|
||||
if r < 0.08:
|
||||
width_labels.append("窄")
|
||||
elif r > 0.20:
|
||||
width_labels.append("宽")
|
||||
else:
|
||||
width_labels.append("中")
|
||||
|
||||
col_descs = []
|
||||
for ci, col in enumerate(columns):
|
||||
wl = width_labels[ci] if ci < len(width_labels) else "中"
|
||||
col_descs.append(f"{col['header_text']}({wl})")
|
||||
|
||||
_rn = {"title": "标题", "header": "表头", "data": "数据", "footer": "表尾"}
|
||||
region_parts = []
|
||||
for r in regions:
|
||||
label = _rn.get(r["type"], r["type"])
|
||||
region_parts.append(f"{label}({len(r['row_indices'])}行)")
|
||||
region_summary = " → ".join(region_parts)
|
||||
|
||||
schema_text = (
|
||||
f"报表布局: {len(columns)}列 x {total_rows}行, A4纵向\n"
|
||||
f"列定义: {', '.join(col_descs)}\n"
|
||||
f"区域: {region_summary}"
|
||||
)
|
||||
|
||||
return {
|
||||
"columns": columns,
|
||||
"regions": regions,
|
||||
"total_rows": total_rows,
|
||||
"total_columns": len(columns),
|
||||
"a4_dimensions": {"width": 595, "height": 842},
|
||||
"schema_text": schema_text,
|
||||
}
|
||||
|
||||
|
||||
def _empty_schema() -> dict:
|
||||
return {
|
||||
"columns": [],
|
||||
"regions": [],
|
||||
"total_rows": 0,
|
||||
"total_columns": 0,
|
||||
"a4_dimensions": {"width": 595, "height": 842},
|
||||
"schema_text": "无法解析报表布局",
|
||||
}
|
||||
|
||||
|
||||
def match_rows_to_jrxml(
|
||||
layout_result: dict,
|
||||
current_jrxml: str,
|
||||
@@ -373,40 +513,7 @@ def _load_image(path: Path) -> Optional[PIL.Image.Image]:
|
||||
def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
|
||||
"""OCR 提取图片中的文字元素(位置+内容)。优先 EasyOCR,回退 PaddleOCR。"""
|
||||
|
||||
# 优先 EasyOCR
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
|
||||
elements = []
|
||||
for (bbox, text, confidence) in result:
|
||||
if not text.strip():
|
||||
continue
|
||||
xs = [p[0] for p in bbox]
|
||||
ys = [p[1] for p in bbox]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
elements.append({
|
||||
"x": round(x_min, 1),
|
||||
"y": round(y_min, 1),
|
||||
"w": round(x_max - x_min, 1),
|
||||
"h": round(y_max - y_min, 1),
|
||||
"font_size": round(y_max - y_min, 1),
|
||||
"text": text.strip(),
|
||||
})
|
||||
|
||||
elements.sort(key=lambda e: (e["y"], e["x"]))
|
||||
return elements
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 PaddleOCR
|
||||
# 优先 PaddleOCR(精确识别)
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
import numpy as np
|
||||
@@ -446,6 +553,39 @@ def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 EasyOCR
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
|
||||
elements = []
|
||||
for (bbox, text, confidence) in result:
|
||||
if not text.strip():
|
||||
continue
|
||||
xs = [p[0] for p in bbox]
|
||||
ys = [p[1] for p in bbox]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
elements.append({
|
||||
"x": round(x_min, 1),
|
||||
"y": round(y_min, 1),
|
||||
"w": round(x_max - x_min, 1),
|
||||
"h": round(y_max - y_min, 1),
|
||||
"font_size": round(y_max - y_min, 1),
|
||||
"text": text.strip(),
|
||||
})
|
||||
|
||||
elements.sort(key=lambda e: (e["y"], e["x"]))
|
||||
return elements
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -284,13 +284,13 @@ class OcrExtractor:
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
easyocr_result = self._try_easyocr(np.array(img))
|
||||
if easyocr_result:
|
||||
return easyocr_result
|
||||
|
||||
paddleocr_result = self._try_paddleocr(img, file_path)
|
||||
if paddleocr_result:
|
||||
return paddleocr_result
|
||||
|
||||
easyocr_result = self._try_easyocr(np.array(img))
|
||||
if easyocr_result:
|
||||
return easyocr_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
你是一位资深 JasperReports 工程师。当前有一个 JRXML 使用占位字段名($F{field_1}, $F{field_2}, ...),需要替换为从 OCR 提取的真实字段名。
|
||||
|
||||
关键规则:
|
||||
- 只输出完整修改后的 JRXML 代码,不要解释,不要 markdown 标记。
|
||||
- 将每个 $F{field_N} 占位符替换为 OCR 提取结果中对应的真实字段名。
|
||||
- 替换规则:根据列的顺序映射——$F{field_1} 对应第 1 列的 OCR 字段名,$F{field_2} 对应第 2 列,以此类推。
|
||||
- 同时更新 <field name="..."> 声明和所有 $F{...} 表达式中的引用。
|
||||
- 如果 OCR 提取的字段数少于占位字段数,保留多余的占位字段。
|
||||
- 不要修改 band 结构、元素位置或大小。
|
||||
- 确保 JRXML 兼容 JasperReports 7.0.6。
|
||||
|
||||
当前 JRXML(含占位字段):
|
||||
{current_jrxml}
|
||||
|
||||
OCR 提取的结构化字段:
|
||||
{ocr_fields}
|
||||
@@ -21,6 +21,9 @@ _NAME_MAP = {
|
||||
"correction": "correction.md",
|
||||
"explain_error": "explain_error.md",
|
||||
"compression": "compression.md",
|
||||
"skeleton_generation": "skeleton_generation.md",
|
||||
"refine_layout": "refine_layout.md",
|
||||
"field_mapping": "field_mapping.md",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
- 如果添加新字段,正确声明它们。
|
||||
- 确保 <queryString> 是 <![CDATA[...]]> 中有效的 SQL。
|
||||
|
||||
{ocr_context}
|
||||
|
||||
当前 JRXML:
|
||||
{current_jrxml}
|
||||
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
你是一位资深 JasperReports 工程师。当前有一个骨架 JRXML,需要根据精确的像素坐标调整每个元素的位置。
|
||||
|
||||
关键规则:
|
||||
- 只输出完整修改后的 JRXML 代码,不要解释,不要 markdown 标记。
|
||||
- 根据提供的采样坐标,精确调整每个 textField/staticText 的 x, y, width, height。
|
||||
- 表头行的坐标直接使用采样坐标中 header_row 对应列的 x, y, width, height。
|
||||
- 数据行:根据 first_data_row 的坐标模式,向下插值生成剩余数据行(每行 y 递增行高)。
|
||||
- 标题行(如有)和表尾行:保持其在骨架中的 y 位置大致不变,但调整 x 和 width 与列的采样坐标对齐。
|
||||
- 不要修改字段名(保持 $F{field_N} 占位名不变)。
|
||||
- 不要修改 band 结构。
|
||||
- 确保 JRXML 兼容 JasperReports 7.0.6。
|
||||
|
||||
当前骨架 JRXML:
|
||||
{current_jrxml}
|
||||
|
||||
采样坐标(表头行 + 第一行数据行,像素位置):
|
||||
{sampled_coordinates}
|
||||
@@ -0,0 +1,19 @@
|
||||
你是一位资深 JasperReports 工程师。根据以下报表布局描述和用户需求,生成一个完整的骨架 JRXML 文件。
|
||||
|
||||
关键规则:
|
||||
- 只输出 JRXML 代码,不要解释,不要 markdown 标记。
|
||||
- 使用 $F{field_1}, $F{field_2}, ... 作为占位字段名,并在 <field> 部分声明它们。
|
||||
- 报表结构必须正确(title, pageHeader, columnHeader, detail, pageFooter 等 band)。
|
||||
- 元素位置使用近似值即可,后续会精确调整。
|
||||
- 根元素为 <jasperReport>,包含正确的 xmlns 属性。
|
||||
- 包含 <queryString>,在 <![CDATA[...]]> 中放置占位 SQL(SELECT * FROM table_name)。
|
||||
- 确保 JRXML 兼容 JasperReports 7.0.6。
|
||||
|
||||
报表布局描述:
|
||||
{layout_schema}
|
||||
|
||||
参考模板和组件:
|
||||
{context}
|
||||
|
||||
用户需求:
|
||||
{user_request}
|
||||
@@ -27,6 +27,24 @@ httpx>=0.27.0
|
||||
tiktoken>=0.7.0
|
||||
openpyxl>=3.1.0
|
||||
|
||||
# OCR 依赖(PaddleOCR 精确识别优先,EasyOCR 回退)
|
||||
# Pinned: paddleocr 2.9.x + paddlepaddle 2.6.x known-stable on Windows CPU
|
||||
# 3.x has ONEDNN compatibility issues on Windows
|
||||
paddleocr>=2.9.0,<3.0.0
|
||||
paddlepaddle>=2.6.0,<3.0.0
|
||||
easyocr>=1.7.0
|
||||
# 聊天输入增强(粘贴/拖拽上传)
|
||||
st-multimodal-chatinput>=0.2.1
|
||||
|
||||
# 多格式文件解析
|
||||
openpyxl>=3.1.0
|
||||
xlrd>=2.0.0
|
||||
olefile>=0.47
|
||||
|
||||
# 批注检测(圈选/箭头识别)
|
||||
opencv-python-headless>=4.8.0
|
||||
|
||||
# 测试
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.24.0
|
||||
xlwt>=1.3.0
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
"""测试批注检测器:圆圈检测、箭头检测、OCR 关联、格式化。"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def _draw_circle_image(path: str, size: tuple = (400, 300)) -> None:
|
||||
"""生成包含红色圆圈的合成测试图片。"""
|
||||
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
|
||||
cv2.circle(img, (200, 150), 50, (0, 0, 255), 2)
|
||||
cv2.imwrite(path, img)
|
||||
|
||||
|
||||
def _draw_arrow_image(path: str, size: tuple = (400, 300)) -> None:
|
||||
"""生成包含手绘风格箭头的合成测试图片(多段线模拟手绘)。"""
|
||||
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
|
||||
# 多段略微偏移的线段模拟手绘箭杆(产生多个 HoughLinesP 段)
|
||||
for offset_y in (-1, 0, 1):
|
||||
cv2.line(img, (50, 150 + offset_y), (200, 150 + offset_y), (0, 0, 255), 2)
|
||||
for offset_y in (-1, 0, 1):
|
||||
cv2.line(img, (200, 150 + offset_y), (340, 150 + offset_y), (0, 0, 255), 2)
|
||||
# 箭头三角形
|
||||
pts = np.array([[350, 150], [330, 135], [330, 165]], np.int32)
|
||||
cv2.fillPoly(img, [pts], (0, 0, 255))
|
||||
# 额外三角形边缘线
|
||||
cv2.line(img, (350, 150), (330, 135), (0, 0, 255), 2)
|
||||
cv2.line(img, (350, 150), (330, 165), (0, 0, 255), 2)
|
||||
cv2.imwrite(path, img)
|
||||
|
||||
|
||||
def _draw_circle_and_text_image(path: str, size: tuple = (500, 400)) -> None:
|
||||
"""生成包含红色圆圈和"文本"的合成图片(模拟圈选批注)。"""
|
||||
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
|
||||
cv2.circle(img, (250, 150), 60, (0, 0, 255), 3)
|
||||
cv2.putText(img, "项目A", (20, 160), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
||||
cv2.imwrite(path, img)
|
||||
|
||||
|
||||
class TestAnnotationDetector:
|
||||
"""测试 annotation_detector.py 各功能。"""
|
||||
|
||||
def test_detect_circles_finds_circle(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_draw_circle_image(path)
|
||||
ocr_elements = [
|
||||
{"text": "测试字段", "bbox": {"x": 170, "y": 120, "w": 60, "h": 20}, "confidence": 0.95},
|
||||
]
|
||||
result = detect_annotations(path, ocr_elements)
|
||||
assert result["total"] >= 1
|
||||
circles = result["circles"]
|
||||
assert len(circles) >= 1
|
||||
c = circles[0]
|
||||
assert c["type"] == "circle"
|
||||
assert "center" in c
|
||||
assert "bbox" in c
|
||||
assert "nearby_texts" in c
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_detect_arrows_finds_arrow(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_draw_arrow_image(path)
|
||||
ocr_elements = [
|
||||
{"text": "起点", "bbox": {"x": 30, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
|
||||
{"text": "终点", "bbox": {"x": 310, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
|
||||
]
|
||||
result = detect_annotations(path, ocr_elements)
|
||||
assert result["total"] >= 1
|
||||
arrows = result["arrows"]
|
||||
assert len(arrows) >= 1
|
||||
a = arrows[0]
|
||||
assert a["type"] == "arrow"
|
||||
assert "from_pt" in a
|
||||
assert "to_pt" in a
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_correlate_with_ocr_links_nearby_texts(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_draw_circle_and_text_image(path)
|
||||
ocr_elements = [
|
||||
{"text": "项目A", "bbox": {"x": 20, "y": 140, "w": 80, "h": 30}, "confidence": 0.98},
|
||||
{"text": "金额", "bbox": {"x": 350, "y": 200, "w": 50, "h": 20}, "confidence": 0.9},
|
||||
]
|
||||
result = detect_annotations(path, ocr_elements)
|
||||
circles = result["circles"]
|
||||
if circles:
|
||||
near = circles[0].get("nearby_texts", [])
|
||||
if near:
|
||||
assert "项目A" in near
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_invalid_image_path(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
result = detect_annotations("/nonexistent/file.png", [])
|
||||
assert result["total"] == 0
|
||||
assert "error" in result
|
||||
|
||||
def test_format_annotation_context_empty(self):
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
|
||||
assert format_annotation_context({}) == ""
|
||||
assert format_annotation_context(None) == ""
|
||||
assert format_annotation_context({"circles": [], "arrows": [], "total": 0}) == ""
|
||||
|
||||
def test_format_annotation_context_with_circles(self):
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
|
||||
ann = {
|
||||
"circles": [
|
||||
{"center": [100, 200], "nearby_texts": ["项目A", "金额"]},
|
||||
],
|
||||
"arrows": [],
|
||||
"total": 1,
|
||||
}
|
||||
text = format_annotation_context(ann)
|
||||
assert "圈选标记" in text
|
||||
assert "项目A" in text
|
||||
|
||||
def test_format_annotation_context_with_arrows(self):
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
|
||||
ann = {
|
||||
"circles": [],
|
||||
"arrows": [
|
||||
{"from_text": "修理号", "to_text": "车架号"},
|
||||
],
|
||||
"total": 1,
|
||||
}
|
||||
text = format_annotation_context(ann)
|
||||
assert "箭头标记" in text
|
||||
assert "修理号" in text
|
||||
assert "车架号" in text
|
||||
@@ -0,0 +1,143 @@
|
||||
"""端到端测试:OCR 字段精确提取完整流水线。
|
||||
|
||||
覆盖:
|
||||
1. PaddleOCR 精确识别(优先)
|
||||
2. EasyOCR 降级回退
|
||||
3. 4种提取策略
|
||||
4. 验证服务连通性
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from backend.ocr_extractor import OcrExtractor, extract_ocr_fields
|
||||
from backend.file_parser import parse_file
|
||||
from backend.validation import validate_jrxml
|
||||
|
||||
|
||||
def create_test_invoice(path: str):
|
||||
"""创建一张模拟中文发票图片,包含已知字段。"""
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
draw.text((300, 20), "增值税普通发票", fill="black")
|
||||
draw.text((300, 60), "发票代码: 1234567890", fill="black")
|
||||
draw.text((300, 100), "发票号码: 87654321", fill="black")
|
||||
draw.text((50, 160), "开票日期: 2024年1月15日", fill="black")
|
||||
draw.text((50, 200), "购买方名称: 测试公司", fill="black")
|
||||
draw.text((50, 240), "合计金额: 1,234.56", fill="black")
|
||||
draw.text((50, 280), "校验码: ABC12345678", fill="black")
|
||||
|
||||
draw.text((50, 350), "名称 数量 单价", fill="black")
|
||||
draw.text((50, 390), "商品A 2 10.00", fill="black")
|
||||
draw.text((50, 430), "商品B 5 20.00", fill="black")
|
||||
|
||||
img.save(path)
|
||||
print(f"[OK] 测试图片已创建: {path}")
|
||||
return path
|
||||
|
||||
|
||||
def test_ocr_extraction_pipeline():
|
||||
"""端到端测试:图片 -> OCR -> 字段提取。"""
|
||||
print("\n=== 端到端OCR字段提取测试 ===\n")
|
||||
|
||||
img_path = create_test_invoice("test_invoice_e2e.png")
|
||||
|
||||
# 阶段1: 文件解析(含OCR)
|
||||
print("\n--- 阶段1: 文件解析(OCR) ---")
|
||||
result = parse_file(img_path)
|
||||
method = result.get("method", "N/A")
|
||||
print(f" OCR方法: {method}")
|
||||
print(f" 文件类型: {result.get('file_type', 'N/A')}")
|
||||
text_preview = result.get("text", "")[:200]
|
||||
print(f" 文本预览: {text_preview}")
|
||||
|
||||
# 阶段2: OCR精确提取
|
||||
print("\n--- 阶段2: 字段精确提取 ---")
|
||||
target_fields = ["发票代码", "发票号码", "开票日期", "合计金额", "校验码"]
|
||||
extraction = extract_ocr_fields(img_path, target_fields)
|
||||
|
||||
print(f" OCR可用: {extraction.get('ocr_available')}")
|
||||
print(f" 图片尺寸: {extraction.get('image_size')}")
|
||||
print(f" 元素总数: {extraction.get('total_elements')}")
|
||||
print(f" 错误: {extraction.get('errors')}")
|
||||
|
||||
print("\n 提取结果:")
|
||||
all_ok = True
|
||||
for field in extraction.get("fields", []):
|
||||
status = "PASS" if field["field_value"] else "FAIL"
|
||||
if not field["field_value"]:
|
||||
all_ok = False
|
||||
print(f" [{status}] {field['field_name']:10s} = {field['field_value']:20s} "
|
||||
f"方法={field['extraction_method']:12s} 置信度={field['confidence']:.2f} "
|
||||
f"bbox={field['bbox']}")
|
||||
|
||||
# 策略独立验证
|
||||
print("\n--- 4种策略独立验证 ---")
|
||||
extractor = OcrExtractor()
|
||||
result_obj = extractor.extract(img_path, ["合计金额"])
|
||||
fields = result_obj.get("fields", [])
|
||||
if fields:
|
||||
print(f" 策略: {fields[0].get('extraction_method', 'N/A')}")
|
||||
print(f" 值: {fields[0].get('field_value', 'N/A')}")
|
||||
print(f" 坐标: {fields[0].get('bbox', 'N/A')}")
|
||||
|
||||
return all_ok
|
||||
|
||||
|
||||
def test_validation_service():
|
||||
"""测试验证服务连通性。"""
|
||||
print("\n=== 验证服务连通性测试 ===\n")
|
||||
result = validate_jrxml("<jasperReport/>")
|
||||
print(f" 状态: {'OK' if result else 'FAIL'}")
|
||||
print(f" 响应: {result}")
|
||||
return True
|
||||
|
||||
|
||||
def test_ocr_fallback():
|
||||
"""测试OCR回退:无图片时优雅降级。"""
|
||||
print("\n=== OCR降级测试 ===\n")
|
||||
result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])
|
||||
print(f" OCR可用: {result.get('ocr_available')}")
|
||||
print(f" 错误: {result.get('errors')}")
|
||||
assert not result["ocr_available"]
|
||||
assert len(result["errors"]) > 0
|
||||
print(" [PASS] 降级行为正常(不阻断流程)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
errors = []
|
||||
|
||||
try:
|
||||
test_ocr_fallback()
|
||||
except Exception as e:
|
||||
print(f"[FAIL] 降级测试: {e}")
|
||||
errors.append(str(e))
|
||||
|
||||
try:
|
||||
ok = test_ocr_extraction_pipeline()
|
||||
if not ok:
|
||||
print("\n 部分字段未提取到(可能因字体渲染差异)")
|
||||
except Exception as e:
|
||||
print(f"[FAIL] 流水线测试: {e}")
|
||||
errors.append(str(e))
|
||||
|
||||
try:
|
||||
test_validation_service()
|
||||
except Exception as e:
|
||||
print(f"[FAIL] 验证服务: {e}")
|
||||
errors.append(str(e))
|
||||
|
||||
Path("test_invoice_e2e.png").unlink(missing_ok=True)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
if errors:
|
||||
print(f"测试完成,{len(errors)} 个错误:")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("所有端到端测试通过!")
|
||||
@@ -0,0 +1,90 @@
|
||||
"""测试多格式文件解析器:XLSX, XLS, DOC。"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_xlsx(path: str) -> None:
|
||||
"""生成最小 .xlsx 测试文件。"""
|
||||
from openpyxl import Workbook
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "Sheet1"
|
||||
ws["A1"] = "名称"
|
||||
ws["B1"] = "金额"
|
||||
ws["A2"] = "项目A"
|
||||
ws["B2"] = 100
|
||||
ws["A3"] = "项目B"
|
||||
ws["B3"] = 200
|
||||
wb.save(path)
|
||||
|
||||
|
||||
def _make_xls(path: str) -> None:
|
||||
"""生成最小 .xls 测试文件。"""
|
||||
from xlwt import Workbook
|
||||
wb = Workbook()
|
||||
ws = wb.add_sheet("Sheet1")
|
||||
ws.write(0, 0, "名称")
|
||||
ws.write(0, 1, "金额")
|
||||
ws.write(1, 0, "项目A")
|
||||
ws.write(1, 1, 100)
|
||||
ws.write(2, 0, "项目B")
|
||||
ws.write(2, 1, 200)
|
||||
wb.save(path)
|
||||
|
||||
|
||||
class TestMultiFormatParsers:
|
||||
"""测试 file_parser.py 的多格式解析器。"""
|
||||
|
||||
def test_parse_xlsx(self):
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_make_xlsx(path)
|
||||
result = parse_file(path, ".xlsx")
|
||||
assert result["file_type"] == "xlsx"
|
||||
assert result["method"] == "openpyxl"
|
||||
assert result["error"] is None
|
||||
assert "Sheet1" in result["text"]
|
||||
assert "项目A" in result["text"]
|
||||
assert "100" in result["text"]
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_parse_xls(self):
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".xls", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_make_xls(path)
|
||||
result = parse_file(path, ".xls")
|
||||
assert result["file_type"] == "xls"
|
||||
assert result["method"] == "xlrd"
|
||||
assert result["error"] is None
|
||||
assert "Sheet1" in result["text"]
|
||||
assert "项目A" in result["text"]
|
||||
assert "100.0" in result["text"]
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_parse_doc_nonexistent(self):
|
||||
"""测试 .doc 文件不存在时的错误处理。"""
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
result = parse_file("/nonexistent/file.doc", ".doc")
|
||||
assert result["file_type"] == ".doc"
|
||||
assert result["method"] == "none"
|
||||
assert result.get("error") is not None
|
||||
|
||||
def test_dispatch_adds_new_formats(self):
|
||||
"""验证新格式已在 parse_file 调度表中注册。"""
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
for ext in [".xlsx", ".xls", ".doc"]:
|
||||
result = parse_file("/tmp/test" + ext, ext)
|
||||
assert result["file_type"] in (ext, "xlsx", "xls", "doc")
|
||||
@@ -0,0 +1,267 @@
|
||||
"""测试分层精确生成:extract_layout_schema, _format_row_coordinates, 路由逻辑。"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# 确保项目根在 path 中
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from backend.layout_analyzer import extract_layout_schema
|
||||
from agent.nodes import _format_row_coordinates
|
||||
from agent.graph import route_after_retrieve
|
||||
from agent.state import AgentState
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 测试夹具
|
||||
# ============================================================
|
||||
|
||||
def _make_row(elements: list[dict]) -> dict:
|
||||
y_center = sum(e["y"] + e["h"] / 2 for e in elements) / len(elements) if elements else 0
|
||||
return {"y_center": round(y_center, 1), "elements": elements}
|
||||
|
||||
|
||||
def _make_elem(x: float, y: float, w: float, h: float, text: str) -> dict:
|
||||
return {"x": x, "y": y, "w": w, "h": h, "font_size": h, "text": text}
|
||||
|
||||
|
||||
def _make_5col_10row_layout() -> dict:
|
||||
"""构造一个标准的 5列 x 10行 报表布局。"""
|
||||
rows = []
|
||||
for ri in range(10):
|
||||
y = 100 + ri * 30
|
||||
if ri == 0:
|
||||
texts = ["员工名册"]
|
||||
xs = [300]
|
||||
ws = [100]
|
||||
elif ri == 1:
|
||||
texts = ["序号", "姓名", "部门", "职位", "入职日期"]
|
||||
xs = [50, 150, 300, 450, 550]
|
||||
ws = [40, 80, 100, 80, 80]
|
||||
else:
|
||||
texts = [str(ri - 1), f"员工{ri-1}", "技术部", "工程师", "2024-01-01"]
|
||||
xs = [50, 150, 300, 450, 550]
|
||||
ws = [40, 80, 100, 80, 80]
|
||||
elements = [_make_elem(xs[i], y, ws[i], 24, texts[i]) for i in range(len(texts))]
|
||||
rows.append(_make_row(elements))
|
||||
return {"rows": rows, "image_size": (595, 842), "total_rows": 10, "total_elements": 46}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestExtractLayoutSchema
|
||||
# ============================================================
|
||||
|
||||
class TestExtractLayoutSchema:
|
||||
"""extract_layout_schema() 单元测试。"""
|
||||
|
||||
def test_basic_table(self):
|
||||
"""5列x10行 布局 → 正确列数和区域分类。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
assert schema["total_columns"] == 5
|
||||
assert schema["total_rows"] == 10
|
||||
assert len(schema["columns"]) == 5
|
||||
assert len(schema["regions"]) >= 3 # title + header + data at minimum
|
||||
|
||||
def test_title_detection(self):
|
||||
"""第 0 行只有 1 个元素 → 标题区域。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
region_types = [r["type"] for r in schema["regions"]]
|
||||
assert "title" in region_types
|
||||
|
||||
def test_footer_detection(self):
|
||||
"""尾部行元素更少 → 表尾区域。"""
|
||||
rows = []
|
||||
for ri in range(8):
|
||||
y = 100 + ri * 30
|
||||
elements = [_make_elem(50 + i * 100, y, 80, 24, f"col{i}") for i in range(5)]
|
||||
rows.append(_make_row(elements))
|
||||
# 最后一行只有 1 个元素(如合计)
|
||||
rows.append(_make_row([_make_elem(300, 340, 100, 24, "合计: 100")]))
|
||||
layout = {"rows": rows, "image_size": (595, 842), "total_rows": 9, "total_elements": 41}
|
||||
schema = extract_layout_schema(layout)
|
||||
region_types = [r["type"] for r in schema["regions"]]
|
||||
assert "footer" in region_types
|
||||
|
||||
def test_empty_layout(self):
|
||||
"""空行 → 返回零值 schema。"""
|
||||
schema = extract_layout_schema({"rows": [], "image_size": (0, 0)})
|
||||
assert schema["total_rows"] == 0
|
||||
assert schema["total_columns"] == 0
|
||||
assert schema["columns"] == []
|
||||
assert "无法解析" in schema["schema_text"]
|
||||
|
||||
def test_single_row(self):
|
||||
"""单行 → 全部为 data。"""
|
||||
layout = {
|
||||
"rows": [_make_row([_make_elem(10, 10, 50, 20, "test")])],
|
||||
"image_size": (200, 200),
|
||||
}
|
||||
schema = extract_layout_schema(layout)
|
||||
assert schema["total_rows"] == 1
|
||||
assert schema["total_columns"] == 1
|
||||
|
||||
def test_two_rows(self):
|
||||
"""两行 → header + data。"""
|
||||
rows = [
|
||||
_make_row([_make_elem(10, 10, 50, 20, "表头")]),
|
||||
_make_row([_make_elem(10, 40, 50, 20, "数据")]),
|
||||
]
|
||||
schema = extract_layout_schema({"rows": rows, "image_size": (200, 200)})
|
||||
assert schema["total_rows"] == 2
|
||||
|
||||
def test_schema_text_format(self):
|
||||
"""schema_text 包含中文标签(列定义、区域)。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
assert "列定义" in schema["schema_text"]
|
||||
assert "区域" in schema["schema_text"]
|
||||
assert "A4纵向" in schema["schema_text"]
|
||||
|
||||
def test_column_width_categories(self):
|
||||
"""宽度分类:窄/中/宽 标签存在。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
text = schema["schema_text"]
|
||||
assert any(w in text for w in ("窄", "中", "宽"))
|
||||
|
||||
def test_a4_dimensions(self):
|
||||
"""A4 尺寸固定为 595x842。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
assert schema["a4_dimensions"] == {"width": 595, "height": 842}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestFormatRowCoordinates
|
||||
# ============================================================
|
||||
|
||||
class TestFormatRowCoordinates:
|
||||
"""_format_row_coordinates() 单元测试。"""
|
||||
|
||||
def test_formats_single_row(self):
|
||||
"""单行 → columns 列表包含正确字段。"""
|
||||
row = _make_row([
|
||||
_make_elem(10, 100, 50, 20, "序号"),
|
||||
_make_elem(60, 100, 80, 20, "姓名"),
|
||||
])
|
||||
result = _format_row_coordinates(row)
|
||||
assert len(result["columns"]) == 2
|
||||
assert result["y_center"] > 0
|
||||
assert "col" in result["columns"][0]
|
||||
assert "text" in result["columns"][0]
|
||||
|
||||
def test_sorts_by_x(self):
|
||||
"""元素按 x 坐标从左到右排序。"""
|
||||
row = _make_row([
|
||||
_make_elem(200, 100, 80, 20, "right"),
|
||||
_make_elem(10, 100, 50, 20, "left"),
|
||||
])
|
||||
result = _format_row_coordinates(row)
|
||||
assert result["columns"][0]["text"] == "left"
|
||||
assert result["columns"][1]["text"] == "right"
|
||||
|
||||
def test_empty_row(self):
|
||||
"""空元素列表 → columns 为空。"""
|
||||
result = _format_row_coordinates({"y_center": 100, "elements": []})
|
||||
assert result["columns"] == []
|
||||
|
||||
def test_non_dict_input(self):
|
||||
"""非 dict 输入 → 返回空 dict。"""
|
||||
assert _format_row_coordinates(None) == {}
|
||||
assert _format_row_coordinates("not a dict") == {}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestRouting
|
||||
# ============================================================
|
||||
|
||||
class TestRouting:
|
||||
"""route_after_retrieve() 路由逻辑测试。"""
|
||||
|
||||
def _state(self, **kwargs) -> AgentState:
|
||||
s = {"current_jrxml": "", "user_input": "test", "conversation_history": []}
|
||||
s.update(kwargs)
|
||||
return s
|
||||
|
||||
def test_with_schema(self):
|
||||
"""layout_schema 有行 → generate_skeleton。"""
|
||||
state = self._state(layout_schema={"total_rows": 5, "total_columns": 3})
|
||||
assert route_after_retrieve(state) == "generate_skeleton"
|
||||
|
||||
def test_without_schema(self):
|
||||
"""无 layout_schema → generate。"""
|
||||
state = self._state()
|
||||
assert route_after_retrieve(state) == "generate"
|
||||
|
||||
def test_empty_schema(self):
|
||||
"""空 dict → generate。"""
|
||||
state = self._state(layout_schema={})
|
||||
assert route_after_retrieve(state) == "generate"
|
||||
|
||||
def test_zero_rows(self):
|
||||
"""total_rows = 0 → generate。"""
|
||||
state = self._state(layout_schema={"total_rows": 0, "total_columns": 0})
|
||||
assert route_after_retrieve(state) == "generate"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestIntegration
|
||||
# ============================================================
|
||||
|
||||
class TestIntegration:
|
||||
"""集成测试:验证图执行路径。"""
|
||||
|
||||
def test_three_phase_nodes_exist(self):
|
||||
"""三个新节点可被导入。"""
|
||||
from agent.nodes import generate_skeleton, refine_layout, map_fields
|
||||
assert callable(generate_skeleton)
|
||||
assert callable(refine_layout)
|
||||
assert callable(map_fields)
|
||||
|
||||
def test_graph_builds_with_new_nodes(self):
|
||||
"""图构建成功包含新节点。"""
|
||||
from agent.graph import build_graph
|
||||
graph = build_graph()
|
||||
nodes = graph.get_graph().nodes
|
||||
node_names = {n for n in nodes}
|
||||
assert "generate_skeleton" in node_names
|
||||
assert "refine_layout" in node_names
|
||||
assert "map_fields" in node_names
|
||||
|
||||
def test_one_shot_path_unchanged(self):
|
||||
"""无 schema 时 route_after_retrieve → generate。"""
|
||||
from agent.graph import route_after_retrieve as rar
|
||||
state = {"layout_schema": {}, "current_jrxml": "", "user_input": "test"}
|
||||
assert rar(state) == "generate"
|
||||
|
||||
def test_modify_report_unaffected(self):
|
||||
"""modify_report 意图不受影响 — route_by_intent 不变。"""
|
||||
from agent.graph import route_by_intent
|
||||
state = {
|
||||
"intent": "modify_report",
|
||||
"current_jrxml": "<jasperReport>...</jasperReport>",
|
||||
"layout_schema": {"total_rows": 5},
|
||||
}
|
||||
assert route_by_intent(state) == "modify_jrxml"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TestLayoutSchemaJSONRoundTrip
|
||||
# ============================================================
|
||||
|
||||
class TestLayoutSchemaJSONRoundTrip:
|
||||
"""验证 schema 可被 JSON 序列化/反序列化(用于会话持久化)。"""
|
||||
|
||||
def test_serializable(self):
|
||||
"""extract_layout_schema 输出可 JSON 序列化。"""
|
||||
layout = _make_5col_10row_layout()
|
||||
schema = extract_layout_schema(layout)
|
||||
dumped = json.dumps(schema, ensure_ascii=False)
|
||||
loaded = json.loads(dumped)
|
||||
assert loaded["total_rows"] == schema["total_rows"]
|
||||
assert loaded["total_columns"] == schema["total_columns"]
|
||||
Reference in New Issue
Block a user