feat: v4 multimodal chat input, multi-format support, and annotation detection

- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button)
- Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines)
- Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py
- Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting
- Add annotation_result field to AgentState with session persistence
- Wire annotation detection into process_input and _format_ocr_context
- Add 11 new tests: 7 annotation detector + 4 multi-format parser
- Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
This commit is contained in:
2026-05-20 23:43:16 +08:00
parent c9f003e1b7
commit 9bb011e429
16 changed files with 1257 additions and 164 deletions
+43 -5
View File
@@ -20,7 +20,7 @@ STREAMLIT_SERVER_HEADLESS=true streamlit run app.py --server.port 8501
## 当前配置(.env ## 当前配置(.env
- **OCR**: EasyOCR优先ch_sim+en→ PaddleOCR(回退),两者均未安装时仅返回图片元信息 - **OCR**: PaddleOCR(精确识别首选,ppocr-v4)→ EasyOCR回退,ch_sim+en),两者均未安装时仅返回图片元信息
- **LLM**: `cloud` / `anthropic` → MiniMax Anthropic 兼容 API (`MiniMax-M2.7`) - **LLM**: `cloud` / `anthropic` → MiniMax Anthropic 兼容 API (`MiniMax-M2.7`)
- Base URL: `https://api.minimaxi.com/anthropic` - Base URL: `https://api.minimaxi.com/anthropic`
- 认证: Anthropic SDK 自动读取 `ANTHROPIC_API_KEY`fallback `OPENAI_API_KEY` - 认证: Anthropic SDK 自动读取 `ANTHROPIC_API_KEY`fallback `OPENAI_API_KEY`
@@ -55,8 +55,10 @@ agent/graph.py (LangGraph 状态机)
├──► backend/logger.py 集中日志: JSON + trace_id + llm.log/app.log 分离 ├──► backend/logger.py 集中日志: JSON + trace_id + llm.log/app.log 分离
├──► backend/rag_adapter.py 语义搜索: ChromaDB + SentenceTransformer ├──► backend/rag_adapter.py 语义搜索: ChromaDB + SentenceTransformer
├──► backend/error_kb.py 错误知识库: 指纹去重 + ChromaDB 持久化 ├──► backend/error_kb.py 错误知识库: 指纹去重 + ChromaDB 持久化
├──► backend/file_parser.py 文件解析: PDF/DOCX/图片/文本 ├──► backend/file_parser.py 文件解析: PDF/DOCX/XLSX/XLS/DOC/图片/文本
├──► backend/layout_analyzer.py A4布局分析: OCR + 行分组 + JRXML行匹配 ├──► backend/layout_analyzer.py A4布局分析: OCR + 行分组 + JRXML行匹配
├──► backend/ocr_extractor.py OCR字段精确提取: 4策略优先级 + 置信度
├──► backend/annotation_detector.py 批注检测: 圈选(HoughCircles) + 箭头(HoughLinesP) + OCR关联
├──► backend/validation.py HTTP 客户端: POST /validate ├──► backend/validation.py HTTP 客户端: POST /validate
├──► backend/session.py 会话持久化: JSON 文件 CRUD ├──► backend/session.py 会话持久化: JSON 文件 CRUD
└──► validation_service/ 独立 FastAPI: 结构检查 + XSD 校验 └──► validation_service/ 独立 FastAPI: 结构检查 + XSD 校验
@@ -67,7 +69,7 @@ agent/graph.py (LangGraph 状态机)
| 文件 | 职责 | 修改频率 | | 文件 | 职责 | 修改频率 |
|------|------|---------| |------|------|---------|
| `app.py` | Streamlit UI 入口,聊天界面 + 侧边栏 + 下载 + 文件上传 | **高** | | `app.py` | Streamlit UI 入口,聊天界面 + 侧边栏 + 下载 + 文件上传 | **高** |
| `agent/state.py` | AgentState 类型定义(~24 字段,含 pending_failure_context | 低 | | `agent/state.py` | AgentState 类型定义(~26 字段,含 pending_failure_context / annotation_result | 低 |
| `agent/nodes.py` | 14 个工作流节点 + 流式生成 + 错误记录 | **高** | | `agent/nodes.py` | 14 个工作流节点 + 流式生成 + 错误记录 | **高** |
| `agent/graph.py` | 状态图编译 + 路由函数(预览跳过验证) | 中 | | `agent/graph.py` | 状态图编译 + 路由函数(预览跳过验证) | 中 |
| `prompts/loader.py` | Prompt 加载器(从 .md 文件热重载) | 低 | | `prompts/loader.py` | Prompt 加载器(从 .md 文件热重载) | 低 |
@@ -76,8 +78,10 @@ agent/graph.py (LangGraph 状态机)
| `backend/logger.py` | 集中日志模块:JSON 格式化 + trace_id + 独立 llm.log | 低 | | `backend/logger.py` | 集中日志模块:JSON 格式化 + trace_id + 独立 llm.log | 低 |
| `backend/rag_adapter.py` | RAGSearcher 单例,语义搜索接口 | 中 | | `backend/rag_adapter.py` | RAGSearcher 单例,语义搜索接口 | 中 |
| `backend/error_kb.py` | ErrorKB — 错误指纹去重 + ChromaDB 持久化 + 语义检索 | 中 | | `backend/error_kb.py` | ErrorKB — 错误指纹去重 + ChromaDB 持久化 + 语义检索 | 中 |
| `backend/file_parser.py` | 文件解析: PDF/DOCX/图片(EasyOCR→PaddleOCR回退)/文本 | 中 | | `backend/file_parser.py` | 文件解析: PDF/DOCX/XLSX/XLS/DOC/图片(EasyOCR→PaddleOCR回退)/文本 | 中 |
| `backend/layout_analyzer.py` | A4模板分析: 比例检测/EasyOCR→PaddleOCR元素提取/行分组/JRXML行匹配 | 中 | | `backend/layout_analyzer.py` | A4模板分析: 比例检测/EasyOCR→PaddleOCR元素提取/行分组/JRXML行匹配 | 中 |
| `backend/ocr_extractor.py` | OCR字段精确提取: 4策略(exact→kv_pair→regex→table_match) + 置信度 | 中 |
| `backend/annotation_detector.py` | 批注检测: 圈选(cv2 HoughCircles) + 箭头(HoughLinesP聚类) + OCR关联 + LLM格式化 | 中 |
| `backend/embeddings.py` | 嵌入模型工厂 (HuggingFace/OpenAI) | 低 | | `backend/embeddings.py` | 嵌入模型工厂 (HuggingFace/OpenAI) | 低 |
| `backend/validation.py` | 验证服务 HTTP 客户端 | 低 | | `backend/validation.py` | 验证服务 HTTP 客户端 | 低 |
| `backend/session.py` | 会话 JSON 文件 CRUD | 低 | | `backend/session.py` | 会话 JSON 文件 CRUD | 低 |
@@ -156,6 +160,37 @@ agent/graph.py (LangGraph 状态机)
- `@log_node` / `@_log_route` — 装饰器自动记录节点和路由 - `@log_node` / `@_log_route` — 装饰器自动记录节点和路由
- 日志分离: `logs/app.log` (业务) + `logs/llm.log` (AI 调用) - 日志分离: `logs/app.log` (业务) + `logs/llm.log` (AI 调用)
## 新增功能 (v3/v4)
### OCR 单据字段精确提取 (v3)
- `backend/ocr_extractor.py` — 4 策略优先级提取: exact_match → kv_pair → regex → table_match
- PaddleOCR 首次识别后将原始结果(含所有文本元素 + bbox坐标)持久化
- `_format_ocr_context()` — 将 OCR 结果(字段 + 原始元素坐标)格式化为 LLM prompt 注入
- OCR 结果在 `modify_jrxml``generate` 节点中自动注入 prompt
- `process_input` 节点在上传图片时自动触发 OCR 字段提取
- 结果持久化到会话文件(`save_session_node` / `load_session_node`
### 多模态聊天输入 + 多格式文件 (v4)
- `app.py``st.chat_input` 替换为 `st_multimodal_chatinput`(支持 Ctrl+V 粘贴 + 拖拽 + 文件按钮)
- `_process_uploaded_file()` — 提取共享文件处理逻辑(侧边栏 + 聊天共用,消除 ~70 行重复代码)
- 新增文件格式支持: XLSX (openpyxl)、XLS (xlrd)、DOC (olefile)
- 剪贴板粘贴文件通过 base64 解码 + MIME type → 扩展名推断
- 侧边栏上传器类型列表中新增 xlsx/xls/doc
### 批注检测 (v4)
- `backend/annotation_detector.py` — 识别用户在手写单据上的圈选和箭头标记
- **圆圈检测**: 红色通道增强 → HoughCircles → 圆形度验证
- **箭头检测**: Canny边缘 → HoughLinesP → 线段方向聚类 → 端点边缘密度判定方向
- **OCR 关联**: 批注与附近 OCR 文本元素关联(15% 图片尺寸内)
- **LLM 注入**: `format_annotation_context()` 将批注结果格式化为中文提示
- `process_input` 节点在 OCR 提取后自动运行批注检测
- `annotation_result` 字段持久化到 AgentState + 会话文件
### OCR 上下文提示增强 (v3/v4)
- `prompts/modification.md` — 新增 `{ocr_context}` 占位符
- `modify_jrxml` 节点 — 将 OCR 上下文注入 modification prompt
- OCR 上下文包含: 结构化字段、全部文本元素(含坐标)、批注检测结果
## 已知注意点 ## 已知注意点
- **Anthropic SDK**: 使用原始 `anthropic` 包(非 `langchain-anthropic`),因为需要直连 MiniMax 兼容端点。API Key 优先读 `ANTHROPIC_API_KEY`fallback `OPENAI_API_KEY`。Anthropic SDK 会自动将 key 放入 `x-api-key` header。 - **Anthropic SDK**: 使用原始 `anthropic` 包(非 `langchain-anthropic`),因为需要直连 MiniMax 兼容端点。API Key 优先读 `ANTHROPIC_API_KEY`fallback `OPENAI_API_KEY`。Anthropic SDK 会自动将 key 放入 `x-api-key` header。
@@ -165,7 +200,10 @@ agent/graph.py (LangGraph 状态机)
- **验证服务结构检查**: 字段引用一致性 (`$F{field}` vs `<field>` 声明)、SQL SELECT 存在性、pageWidth/pageHeight/name 属性。 - **验证服务结构检查**: 字段引用一致性 (`$F{field}` vs `<field>` 声明)、SQL SELECT 存在性、pageWidth/pageHeight/name 属性。
- **XSD 校验可选**: 需要 `validation_service/schemas/jasperreport_7_0_6.xsd` 存在。 - **XSD 校验可选**: 需要 `validation_service/schemas/jasperreport_7_0_6.xsd` 存在。
- **rag 子模块**: 内部有独立的管线脚本(`batch_chunker.py``embed_chunks.py``import_to_chroma.py`),通常不需要在主项目中运行。 - **rag 子模块**: 内部有独立的管线脚本(`batch_chunker.py``embed_chunks.py``import_to_chroma.py`),通常不需要在主项目中运行。
- **OCR 引擎**: 优先使用 EasyOCRWindows 兼容性更好`pip install easyocr`),回退 PaddleOCR。两者均未安装时仅返回图片元信息,建议至少安装 EasyOCR - **OCR 引擎**: 优先 PaddleOCR 2.9.x(精确识别`pip install paddleocr`),回退 EasyOCR 1.7+。两者均未安装时仅返回图片元信息。PaddlePaddle 3.x 在 Windows 上有 ONEDNN bug,固定在 2.6.x
- **MAX_RETRY**: 默认 3 次。重试耗尽后 `pending_failure_context` 记录失败信息,下次用户输入时自动注入。 - **MAX_RETRY**: 默认 3 次。重试耗尽后 `pending_failure_context` 记录失败信息,下次用户输入时自动注入。
- **验证最小内容检查**: 验证服务额外检查至少 1 个 `<band>` + 1 个 `<textField>``<staticText>`,拦截空壳 JRXML。 - **验证最小内容检查**: 验证服务额外检查至少 1 个 `<band>` + 1 个 `<textField>``<staticText>`,拦截空壳 JRXML。
- **torchvision**: `transformers` 库的懒加载需要 `torchvision`,已作为依赖安装。 - **torchvision**: `transformers` 库的懒加载需要 `torchvision`,已作为依赖安装。
- **opencv-python-headless**: 批注检测(圈选/箭头)依赖,通过 `pip install -r requirements.txt` 安装。
- **st-multimodal-chatinput**: Streamlit 聊天输入增强组件,替代 `st.chat_input`,支持粘贴/拖拽文件。返回 base64 编码文件内容。
- **xlwt**: 仅在测试中使用(生成 .xls 测试文件)。
+14 -6
View File
@@ -751,14 +751,20 @@ def parse_file(file_path, file_type="") -> dict:
# .png/.jpg/.jpeg/.bmp/.webp → _parse_image() # .png/.jpg/.jpeg/.bmp/.webp → _parse_image()
# .pdf → _parse_pdf() # .pdf → _parse_pdf()
# .docx → _parse_docx() # .docx → _parse_docx()
# .xlsx → _parse_xlsx()
# .xls → _parse_xls()
# .doc → _parse_doc()
# 其他 → _parse_text() (UTF-8 / GBK) # 其他 → _parse_text() (UTF-8 / GBK)
``` ```
### 各解析器的回退链 ### 各解析器的回退链
- **图片**EasyOCRch_sim+en)→ PaddleOCR → 仅返回元信息 + 安装提示 - **图片**PaddleOCR(精确识别首选)→ EasyOCRch_sim+en)→ 仅返回元信息 + 安装提示
- **PDF**pdfplumber → PyMuPDF → 失败 - **PDF**pdfplumber → PyMuPDF → 失败
- **DOCX**python-docx(含表格内容提取)→ 失败 - **DOCX**python-docx(含表格内容提取)→ 失败
- **XLSX**openpyxl(含多 sheet 支持)→ 失败
- **XLS**xlrd(旧版 Excel 格式)→ 失败
- **DOC**olefile(二进制格式,尽力而为提取)→ 失败
- **文本**UTF-8 → GBK → 失败 - **文本**UTF-8 → GBK → 失败
--- ---
@@ -1158,20 +1164,22 @@ st.json(state) # 打印完整状态(调试用,记得删除)
| 文件 | 行数 | 角色 | | 文件 | 行数 | 角色 |
|------|------|------| |------|------|------|
| `app.py` | ~530 | Streamlit UI 入口 | | `app.py` | ~670 | Streamlit UI 入口(多模态聊天输入) |
| `agent/state.py` | ~40 | 状态类型定义 | | `agent/state.py` | ~48 | 状态类型定义26 字段) |
| `agent/nodes.py` | ~523 | 14 个工作流节点 | | `agent/nodes.py` | ~740 | 15 个工作流节点 |
| `agent/graph.py` | ~232 | 状态图编译 + 路由 | | `agent/graph.py` | ~232 | 状态图编译 + 路由 |
| `backend/llm.py` | ~105 | LLM 工厂 (3 个后端) | | `backend/llm.py` | ~105 | LLM 工厂 (3 个后端) |
| `backend/rag_adapter.py` | ~156 | ChromaDB 语义搜索 | | `backend/rag_adapter.py` | ~156 | ChromaDB 语义搜索 |
| `backend/error_kb.py` | ~226 | 错误知识库 | | `backend/error_kb.py` | ~226 | 错误知识库 |
| `backend/embeddings.py` | ~49 | 嵌入模型工厂 | | `backend/embeddings.py` | ~49 | 嵌入模型工厂 |
| `backend/file_parser.py` | ~194 | 多格式文件解析 | | `backend/file_parser.py` | ~320 | 多格式文件解析7 种格式) |
| `backend/layout_analyzer.py` | ~495 | A4 模板布局分析 | | `backend/layout_analyzer.py` | ~495 | A4 模板布局分析 |
| `backend/ocr_extractor.py` | ~380 | OCR 字段精确提取 |
| `backend/annotation_detector.py` | ~250 | 批注检测(圈选 + 箭头) |
| `backend/validation.py` | ~27 | 验证服务 HTTP 客户端 | | `backend/validation.py` | ~27 | 验证服务 HTTP 客户端 |
| `backend/session.py` | ~113 | 会话 JSON CRUD | | `backend/session.py` | ~113 | 会话 JSON CRUD |
| `prompts/loader.py` | ~54 | Prompt 热重载 | | `prompts/loader.py` | ~54 | Prompt 热重载 |
| `prompts/*.md` (7 个) | — | Prompt 模板 | | `prompts/*.md` (7 个) | — | Prompt 模板 |
| `validation_service/main.py` | ~130 | FastAPI 验证服务 | | `validation_service/main.py` | ~130 | FastAPI 验证服务 |
| `.env.example` | ~62 | 配置模板 | | `.env.example` | ~62 | 配置模板 |
| `requirements.txt` | ~32 | Python 依赖 | | `requirements.txt` | ~42 | Python 依赖 |
+14 -4
View File
@@ -8,6 +8,10 @@
- **自动验证**:每次生成或修改后都会验证 JRXML - **自动验证**:每次生成或修改后都会验证 JRXML
- **自动修正**:如果验证失败,代理会分析错误并自动修正(最多 3 次) - **自动修正**:如果验证失败,代理会分析错误并自动修正(最多 3 次)
- **模板检索**:使用 Chroma 向量数据库检索相关的 JRXML 示例以获得更好的生成效果 - **模板检索**:使用 Chroma 向量数据库检索相关的 JRXML 示例以获得更好的生成效果
- **文件上传**:支持图片(OCR识别)、PDF、Word、Excel、文本文件等
- **聊天粘贴/拖拽**:支持直接在对话框中 Ctrl+V 粘贴或拖拽文件(图片/PDF/Excel/Word
- **单据OCR识别**:上传报表单据图片后自动提取所有字段(4策略优先级 + 置信度评分)
- **批注检测**:识别手写单据上的圈选和箭头标记,自动定位用户要修改的字段
- **下载**:导出已验证的、可供 JasperReports 使用的 JRXML 文件 - **下载**:导出已验证的、可供 JasperReports 使用的 JRXML 文件
## 架构 ## 架构
@@ -105,10 +109,10 @@ pytest tests/ -v
``` ```
jrxml-agent/ jrxml-agent/
app.py Streamlit 聊天界面 app.py Streamlit 聊天界面(多模态输入)
agent/ agent/
state.py AgentState 定义 state.py AgentState 定义26 字段)
nodes.py 图节点(generate, validate, modify 等) nodes.py 图节点(generate, validate, modify 等15 节点
graph.py LangGraph 状态机 graph.py LangGraph 状态机
backend/ backend/
llm.py LLM 工厂(Anthropic SDK / OpenAI / Ollama llm.py LLM 工厂(Anthropic SDK / OpenAI / Ollama
@@ -117,8 +121,10 @@ jrxml-agent/
validation.py 验证服务客户端 validation.py 验证服务客户端
rag_adapter.py RAG 语义搜索适配器 rag_adapter.py RAG 语义搜索适配器
error_kb.py 错误自增长知识库 error_kb.py 错误自增长知识库
file_parser.py 文件解析器(PDF/DOCX/图片 file_parser.py 文件解析器(PDF/DOCX/XLSX/XLS/DOC/图片/文本
layout_analyzer.py A4 模板布局分析 layout_analyzer.py A4 模板布局分析
ocr_extractor.py OCR 字段精确提取(4 策略 + 置信度)
annotation_detector.py 批注检测(圈选 + 箭头 + OCR 关联)
session.py 会话持久化 CRUD session.py 会话持久化 CRUD
prompts/ prompts/
loader.py Prompt 加载器(热重载) loader.py Prompt 加载器(热重载)
@@ -137,6 +143,10 @@ jrxml-agent/
tests/ tests/
test_validation.py 验证服务测试 test_validation.py 验证服务测试
test_agent.py 代理集成测试 test_agent.py 代理集成测试
test_e2e_ocr.py OCR 端到端测试
test_ocr_extraction.py OCR 字段提取单元测试
test_annotation_detector.py 批注检测测试
test_file_parser_formats.py 多格式解析测试
requirements.txt requirements.txt
.env.example .env.example
README.md README.md
+39 -1
View File
@@ -122,4 +122,42 @@
10. 结构化日志系统 10. 结构化日志系统
``` ```
阶段一立即可做,无外部依赖。阶段二是主要工作量。阶段三是收尾。阶段四是可观测性基础。 ---
## 阶段五:OCR 与智能上传 (v3/v4) ✓
### 11. OCR 单据字段精确提取 ✓
- [x] `backend/ocr_extractor.py` — 4 策略优先级提取 (exact_match → kv_pair → regex → table_match)
- [x] PaddleOCR 首次识别后将原始结果(含所有文本元素 + bbox坐标)持久化
- [x] `_format_ocr_context()` — OCR 结果格式化为 LLM prompt 注入
- [x] `process_input` 节点在上传图片时自动触发 OCR 字段提取
- [x] OCR 结果持久化到会话文件
### 12. 多模态聊天输入 ✓
- [x] `app.py``st.chat_input` 替换为 `st_multimodal_chatinput`
- [x] 支持 Ctrl+V 粘贴文件 + 拖拽 + 文件按钮
- [x] `_process_uploaded_file()` — 提取共享文件处理逻辑(消除 ~70 行重复代码)
- [x] 剪贴板文件 base64 解码 + MIME type → 扩展名推断
### 13. 多格式文件支持 ✓
- [x] `backend/file_parser.py` — 新增 XLSX (openpyxl)、XLS (xlrd)、DOC (olefile)
- [x] 侧边栏上传器类型列表中新增 xlsx/xls/doc
- [x] 单元测试: `tests/test_file_parser_formats.py` (4 tests)
### 14. 批注检测 ✓
- [x] `backend/annotation_detector.py` — 圈选 + 箭头 + OCR 关联
- [x] 圆圈检测: 红色通道增强 → HoughCircles
- [x] 箭头检测: Canny → HoughLinesP → 线段聚类 → 端点方向判定
- [x] `format_annotation_context()` — 批注结果格式化为中文提示
- [x] `process_input` 节点在 OCR 提取后自动运行批注检测
- [x] `annotation_result` 字段持久化到 AgentState + 会话文件
- [x] 单元测试: `tests/test_annotation_detector.py` (7 tests)
### 15. OCR 上下文 LLM 注入 ✓
- [x] `prompts/modification.md` — 新增 `{ocr_context}` 占位符
- [x] `modify_jrxml` + `generate` 节点注入 OCR 上下文
- [x] OCR 上下文包含: 结构化字段、全部文本元素(含坐标)、批注检测结果
---
阶段一立即可做,无外部依赖。阶段二是主要工作量。阶段三是收尾。阶段四是可观测性基础。阶段五是 OCR 智能增强和用户体验改进。
+84 -3
View File
@@ -134,6 +134,23 @@ def process_input(state: AgentState) -> Dict:
"fields": len(ocr_result.get("fields", [])), "fields": len(ocr_result.get("fields", [])),
}, },
) )
# 批注检测(圈选/箭头标记)
elements = ocr_result.get("elements", [])
if elements:
try:
from backend.annotation_detector import detect_annotations
ann_result = detect_annotations(uploaded_path, elements)
if ann_result.get("total", 0) > 0:
state["annotation_result"] = ann_result
_node_log.info(
"批注检测完成",
extra={
"circles": len(ann_result.get("circles", [])),
"arrows": len(ann_result.get("arrows", [])),
},
)
except Exception as e:
_node_log.warning(f"批注检测失败: {e}")
except Exception as e: except Exception as e:
_node_log.warning(f"OCR 字段提取失败: {e}") _node_log.warning(f"OCR 字段提取失败: {e}")
state["ocr_extraction_result"] = {"error": str(e)} state["ocr_extraction_result"] = {"error": str(e)}
@@ -359,7 +376,9 @@ def load_session_node(state: AgentState) -> Dict:
# 恢复核心字段(不覆盖当前请求的 user_input / stage # 恢复核心字段(不覆盖当前请求的 user_input / stage
for key in ("conversation_history", "full_conversation_history", for key in ("conversation_history", "full_conversation_history",
"current_jrxml", "final_jrxml", "compressed_history", "current_jrxml", "final_jrxml", "compressed_history",
"session_name", "created_at", "history_states"): "session_name", "created_at", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
if key in saved and key not in ("user_input", "stage"): if key in saved and key not in ("user_input", "stage"):
state[key] = saved[key] state[key] = saved[key]
state["session_name"] = data.get("session_name", "") state["session_name"] = data.get("session_name", "")
@@ -381,7 +400,9 @@ def save_session_node(state: AgentState) -> Dict:
persistable = {} persistable = {}
for key in ("conversation_history", "full_conversation_history", for key in ("conversation_history", "full_conversation_history",
"current_jrxml", "final_jrxml", "compressed_history", "current_jrxml", "final_jrxml", "compressed_history",
"status", "error_msg", "history_states"): "status", "error_msg", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
if key in state: if key in state:
persistable[key] = state[key] persistable[key] = state[key]
persistable["updated_at"] = _now_iso() persistable["updated_at"] = _now_iso()
@@ -416,6 +437,59 @@ def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat() return datetime.now(timezone.utc).isoformat()
def _format_ocr_context(state: AgentState) -> str:
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
ocr_result = state.get("ocr_extraction_result")
if not ocr_result or not isinstance(ocr_result, dict):
return ""
if ocr_result.get("error"):
return ""
parts = []
parts.append("[图片OCR识别结果]")
total = ocr_result.get("total_elements", 0)
if total:
parts.append(f"检测到 {total} 个文字元素")
# 提取到的字段
fields = ocr_result.get("fields", [])
if fields:
parts.append("\n提取的结构化字段:")
for f in fields:
if f.get("field_value"):
parts.append(
f" - {f['field_name']}: {f['field_value']} "
f"(方法={f.get('extraction_method','?')}, "
f"置信度={f.get('confidence',0):.2f})"
)
# 所有原始文本(用于表格匹配等需要全文的场景)
elements = ocr_result.get("elements", [])
if elements:
parts.append("\n全部文本元素(含坐标):")
for e in elements:
bbox = e.get("bbox", {})
x, y, w, h = bbox.get("x", 0), bbox.get("y", 0), bbox.get("w", 0), bbox.get("h", 0)
parts.append(
f" [{x},{y} {w}×{h}] {e['text']} "
f"(置信度={e.get('confidence',0):.2f})"
)
# 批注检测结果
ann_result = state.get("annotation_result")
if ann_result and isinstance(ann_result, dict):
try:
from backend.annotation_detector import format_annotation_context
ann_text = format_annotation_context(ann_result)
if ann_text:
parts.append("\n" + ann_text)
except Exception:
pass
return "\n".join(parts)
@log_node("retrieve") @log_node("retrieve")
def retrieve(state: AgentState) -> Dict: def retrieve(state: AgentState) -> Dict:
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。""" """在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。"""
@@ -446,9 +520,15 @@ def generate(state: AgentState) -> Dict:
writer = get_stream_writer() writer = get_stream_writer()
llm = get_llm(caller="generate") llm = get_llm(caller="generate")
user_request = state.get("user_input", "")
ocr_text = _format_ocr_context(state)
if ocr_text:
user_request = f"{ocr_text}\n\n---\n用户需求:\n{user_request}"
prompt = load_prompt("initial_generation").format( prompt = load_prompt("initial_generation").format(
context=state.get("retrieved_context", ""), context=state.get("retrieved_context", ""),
user_request=state.get("user_input", ""), user_request=user_request,
) )
full = [] full = []
for chunk in llm.stream(prompt): for chunk in llm.stream(prompt):
@@ -480,6 +560,7 @@ def modify_jrxml(state: AgentState) -> Dict:
current_jrxml=state.get("current_jrxml", ""), current_jrxml=state.get("current_jrxml", ""),
conversation_history=conv_text, conversation_history=conv_text,
modification_request=state.get("user_modification_request", ""), modification_request=state.get("user_modification_request", ""),
ocr_context=_format_ocr_context(state),
) )
full = [] full = []
for chunk in llm.stream(prompt): for chunk in llm.stream(prompt):
+3
View File
@@ -44,3 +44,6 @@ class AgentState(TypedDict, total=False):
# 需求7:OCR 单据字段精确提取结果 # 需求7:OCR 单据字段精确提取结果
ocr_extraction_result: dict ocr_extraction_result: dict
uploaded_file_path: str uploaded_file_path: str
# 需求8:图片批注检测(圈选/箭头标记)
annotation_result: dict
+179 -87
View File
@@ -106,6 +106,81 @@ def _render_jrxml(jrxml: str, max_lines: int = 30):
st.code(preview, language="xml") st.code(preview, language="xml")
# ---- 共享文件上传处理 ----
def _process_uploaded_file(uploaded_file, suffix: str) -> dict:
"""处理单个上传文件:保存临时文件、解析、布局分析。
返回: {"name": str, "text": str, "type": str, "tmp_path": str|None}
"""
import tempfile
from backend.file_parser import parse_file
from backend.layout_analyzer import analyze_layout
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
tmp.write(uploaded_file.getvalue())
tmp_path = tmp.name
result = parse_file(tmp_path, suffix)
parsed_text = result["text"]
parsed_type = result["file_type"]
# 对图片/PDF 进行 A4 模板布局分析
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".pdf"):
layout = analyze_layout(tmp_path)
tt = layout.get("template_type", "unknown")
current_jrxml = st.session_state.agent_state.get("current_jrxml", "")
if tt == "full_a4":
parsed_text = layout["description"]
parsed_type = "a4_template"
elif tt == "partial_rows":
parsed_type = "a4_partial"
if current_jrxml.strip():
from backend.layout_analyzer import match_rows_to_jrxml
match = match_rows_to_jrxml(layout, current_jrxml)
parsed_text = (
f"[行片段修改] 上传图片包含 {layout['total_rows']} 行,"
f"视为 A4 报表的一部分。\n\n"
f"{match['description']}\n\n"
f"--- 行结构 ---\n{layout['description']}"
)
else:
parsed_text = layout["description"]
else:
has_ocr = result.get("method") not in ("metadata_only", None)
img_w, img_h = layout["image_size"]
ratio = layout["aspect_ratio"]
if has_ocr:
parsed_text = (
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}"
f"未检测到 A4 报表结构,图片将被视为参考样式。\n"
f"请根据用户的文字描述生成报表。"
)
else:
parsed_text = (
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}\n"
f"⚠ OCR 引擎未安装,无法识别图片中的文字内容。\n"
f"请严格根据用户的文字描述来推断图片中的报表需求。\n"
f"(提示:如需图片文字识别,请运行 pip install paddleocr"
)
parsed_type = "image_reference"
elif suffix in (".pdf", ".docx", ".xlsx", ".xls", ".doc"):
parsed_type = suffix.lstrip(".")
keep_temp = (
suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp")
and result.get("method") not in ("metadata_only", None)
)
return {
"name": uploaded_file.name,
"text": parsed_text,
"type": parsed_type,
"tmp_path": tmp_path if keep_temp else None,
}
# ---- URL 参数 ---- # ---- URL 参数 ----
query_params = st.query_params query_params = st.query_params
url_session_id = query_params.get("session_id", "") url_session_id = query_params.get("session_id", "")
@@ -480,7 +555,8 @@ with st.sidebar:
uploaded = st.file_uploader( uploaded = st.file_uploader(
"选择文件", "选择文件",
type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"], type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "xlsx", "xls", "doc",
"txt", "csv", "json", "xml"],
accept_multiple_files=True, accept_multiple_files=True,
key="file_uploader", key="file_uploader",
label_visibility="collapsed", label_visibility="collapsed",
@@ -491,77 +567,21 @@ with st.sidebar:
# 去重 # 去重
if any(f["name"] == uf.name for f in st.session_state.uploaded_files): if any(f["name"] == uf.name for f in st.session_state.uploaded_files):
continue continue
import tempfile
from backend.file_parser import parse_file
from backend.layout_analyzer import analyze_layout
suffix = Path(uf.name).suffix.lower() suffix = Path(uf.name).suffix.lower()
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: result = _process_uploaded_file(uf, suffix)
tmp.write(uf.getvalue())
tmp_path = tmp.name
result = parse_file(tmp_path, suffix) if result["text"]:
# 对图片/PDF 进行 A4 模板布局分析
parsed_text = result["text"]
parsed_type = result["file_type"]
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".pdf"):
layout = analyze_layout(tmp_path)
tt = layout.get("template_type", "unknown")
current_jrxml = st.session_state.agent_state.get("current_jrxml", "")
if tt == "full_a4":
parsed_text = layout["description"]
parsed_type = "a4_template"
elif tt == "partial_rows":
parsed_type = "a4_partial"
if current_jrxml.strip():
# 修改模式:尝试行匹配
from backend.layout_analyzer import match_rows_to_jrxml
match = match_rows_to_jrxml(layout, current_jrxml)
parsed_text = (
f"[行片段修改] 上传图片包含 {layout['total_rows']} 行,"
f"视为 A4 报表的一部分。\n\n"
f"{match['description']}\n\n"
f"--- 行结构 ---\n{layout['description']}"
)
else:
# 新建模式:按 A4 模板处理
parsed_text = layout["description"]
else:
# tt == "unknown": OCR 不可用或未检测到文字元素
has_ocr = result.get("method") not in ("metadata_only", None)
img_w, img_h = layout["image_size"]
ratio = layout["aspect_ratio"]
if has_ocr:
parsed_text = (
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}"
f"未检测到 A4 报表结构,图片将被视为参考样式。\n"
f"请根据用户的文字描述生成报表。"
)
else:
parsed_text = (
f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}\n"
f"⚠ OCR 引擎未安装,无法识别图片中的文字内容。\n"
f"请严格根据用户的文字描述来推断图片中的报表需求。\n"
f"(提示:如需图片文字识别,请运行 pip install paddleocr"
)
parsed_type = "image_reference"
if parsed_text:
st.session_state.uploaded_files.append({ st.session_state.uploaded_files.append({
"name": uf.name, "name": result["name"],
"text": parsed_text, "text": result["text"],
"type": parsed_type, "type": result["type"],
}) })
# 对图片类型,保存路径以便 OCR 字段提取(延迟到 process_input 阶段) tmp_path = result["tmp_path"]
img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp") if tmp_path:
if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
st.session_state.agent_state["uploaded_file_path"] = tmp_path st.session_state.agent_state["uploaded_file_path"] = tmp_path
st.session_state.uploaded_temp_paths.append(tmp_path) st.session_state.uploaded_temp_paths.append(tmp_path)
else:
Path(tmp_path).unlink(missing_ok=True)
if st.session_state.uploaded_files: if st.session_state.uploaded_files:
for i, f in enumerate(st.session_state.uploaded_files): for i, f in enumerate(st.session_state.uploaded_files):
@@ -632,34 +652,106 @@ for msg in st.session_state.messages:
else: else:
st.markdown(msg["content"]) st.markdown(msg["content"])
# ---- 聊天输入 ---- # ---- 聊天输入(支持粘贴/拖拽文件) ----
if prompt := st.chat_input("描述您的报表需求..."): from st_multimodal_chatinput import multimodal_chatinput
# 拼接上传文件的文本 import base64
import io
from pathlib import Path as _Path
# MIME type → 文件扩展名映射(用于剪贴板粘贴无扩展名的文件)
MIME_TO_EXT = {
"image/png": ".png",
"image/jpeg": ".jpg",
"image/bmp": ".bmp",
"image/webp": ".webp",
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.ms-excel": ".xls",
"application/msword": ".doc",
"text/plain": ".txt",
"text/csv": ".csv",
"application/json": ".json",
"text/xml": ".xml",
}
chat_result = multimodal_chatinput()
if chat_result:
prompt = (chat_result.get("textInput") or "").strip()
chat_files = chat_result.get("uploadedFiles") or []
# 处理聊天中上传/粘贴的文件
uploaded_texts = [] uploaded_texts = []
uploaded_files_info = [] uploaded_files_info = []
# 先收集侧边栏已上传的文件
if st.session_state.get("uploaded_files"): if st.session_state.get("uploaded_files"):
for f in st.session_state.uploaded_files: for f in st.session_state.uploaded_files:
uploaded_texts.append(f"[上传文件: {f['name']}]\n{f['text']}") uploaded_texts.append(f"[上传文件: {f['name']}]\n{f['text']}")
uploaded_files_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])}) uploaded_files_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])})
if uploaded_texts: st.session_state.uploaded_files = []
full_prompt = "\n\n".join(uploaded_texts) + "\n\n---\n用户需求:\n" + prompt
st.session_state.uploaded_files = [] # 用后即清
else:
full_prompt = prompt
_app_log.info( # 处理聊天中的文件
"收到用户输入", class _Base64File:
extra={ """包装 base64 文件为类 UploadedFile 接口。"""
"session_id": current_session_id, def __init__(self, name, data_bytes):
"prompt_preview": prompt[:200], self.name = name
"prompt_length": len(prompt), self._data = data_bytes
"has_uploaded_files": bool(uploaded_files_info),
"uploaded_files": uploaded_files_info,
},
)
st.session_state.messages.append({"role": "user", "content": prompt}) def getvalue(self):
with st.chat_message("user"): return self._data
st.markdown(prompt)
run_agent(full_prompt) for cf in chat_files:
st.rerun() name = cf.get("name", "clipboard_file")
mime = cf.get("type", "")
content_b64 = cf.get("content", "")
if not content_b64:
continue
try:
data = base64.b64decode(content_b64)
except Exception:
continue
suffix = _Path(name).suffix.lower()
if not suffix and mime in MIME_TO_EXT:
suffix = MIME_TO_EXT[mime]
name = f"{_Path(name).stem}{suffix}"
wrapper = _Base64File(name, data)
result = _process_uploaded_file(wrapper, suffix)
if result["text"]:
uploaded_texts.append(f"[上传文件: {result['name']}]\n{result['text']}")
uploaded_files_info.append({"name": result["name"], "type": result["type"], "length": len(result["text"])})
tmp_path = result["tmp_path"]
if tmp_path:
st.session_state.agent_state["uploaded_file_path"] = tmp_path
st.session_state.uploaded_temp_paths.append(tmp_path)
if prompt or uploaded_texts:
if uploaded_texts:
full_prompt = "\n\n".join(uploaded_texts)
if prompt:
full_prompt += "\n\n---\n用户需求:\n" + prompt
else:
full_prompt = prompt
displayed_prompt = prompt or "(已上传文件,未输入文字)"
_app_log.info(
"收到用户输入",
extra={
"session_id": current_session_id,
"prompt_preview": displayed_prompt[:200],
"prompt_length": len(full_prompt),
"has_uploaded_files": bool(uploaded_files_info),
"uploaded_files": uploaded_files_info,
},
)
st.session_state.messages.append({"role": "user", "content": displayed_prompt})
with st.chat_message("user"):
st.markdown(displayed_prompt)
run_agent(full_prompt)
st.rerun()
+331
View File
@@ -0,0 +1,331 @@
"""批注检测器:识别图片上的圈选(圆)和箭头,定位用户要修改的字段。
依赖 OpenCV (cv2),从 PaddleOCR 传递依赖已安装。
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Optional
import cv2
import numpy as np
@dataclass
class Annotation:
"""单个批注标记。"""
type: str # "circle" | "arrow"
bbox: dict # {"x": int, "y": int, "w": int, "h": int}
center: tuple[int, int] # (cx, cy)
nearby_texts: list[str] = field(default_factory=list)
from_text: str = "" # 箭头出发点的文本
to_text: str = "" # 箭头指向的文本
from_pt: Optional[tuple[int, int]] = None
to_pt: Optional[tuple[int, int]] = None
def detect_annotations(image_path: str, ocr_elements: list[dict]) -> dict:
"""检测图片上的手写批注(圈选 + 箭头),并与 OCR 文本关联。
Args:
image_path: 图片文件路径
ocr_elements: OCR 元素列表 [{"text": str, "bbox": {x,y,w,h}, "confidence": float}]
Returns:
{"circles": [...], "arrows": [...], "total": int}
"""
img = cv2.imread(image_path)
if img is None:
return {"circles": [], "arrows": [], "total": 0, "error": "无法读取图片"}
h, w = img.shape[:2]
circles = _detect_circles(img)
arrows = _detect_arrows(img)
all_annotations = circles + arrows
_correlate_with_ocr(all_annotations, ocr_elements, w, h)
result: dict = {
"circles": [_annotation_to_dict(a) for a in circles],
"arrows": [_annotation_to_dict(a) for a in arrows],
"total": len(all_annotations),
}
return result
def _annotation_to_dict(a: Annotation) -> dict:
d = {
"type": a.type,
"bbox": a.bbox,
"center": list(a.center),
"nearby_texts": a.nearby_texts,
}
if a.type == "arrow":
d["from_text"] = a.from_text
d["to_text"] = a.to_text
if a.from_pt:
d["from_pt"] = list(a.from_pt)
if a.to_pt:
d["to_pt"] = list(a.to_pt)
return d
# ---------------------------------------------------------------------------
# 圆圈检测
# ---------------------------------------------------------------------------
def _detect_circles(img: np.ndarray) -> list[Annotation]:
"""检测图片中可能是手绘批注的圆圈。"""
h, w = img.shape[:2]
b, g, r = cv2.split(img)
red_enhanced = cv2.addWeighted(r.astype(np.float32), 1.5,
g.astype(np.float32), -0.3, 0)
red_enhanced = cv2.addWeighted(red_enhanced, 1.2,
b.astype(np.float32), -0.3, 0)
red_enhanced = np.clip(red_enhanced, 0, 255).astype(np.uint8)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
combined = cv2.addWeighted(gray, 0.5, red_enhanced, 0.5, 0)
blurred = cv2.GaussianBlur(combined, (9, 9), 2)
min_radius = max(15, min(w, h) // 40)
max_radius = min(200, max(w, h) // 8)
circles_raw = cv2.HoughCircles(
blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=min_radius * 2,
param1=50, param2=30, minRadius=min_radius, maxRadius=max_radius,
)
annotations: list[Annotation] = []
if circles_raw is not None:
for cx, cy, r in circles_raw[0]:
bbox = {
"x": max(0, int(cx - r)),
"y": max(0, int(cy - r)),
"w": int(r * 2),
"h": int(r * 2),
}
annotations.append(Annotation(
type="circle",
bbox=bbox,
center=(int(cx), int(cy)),
))
return annotations
# ---------------------------------------------------------------------------
# 箭头检测
# ---------------------------------------------------------------------------
def _detect_arrows(img: np.ndarray) -> list[Annotation]:
"""检测图片中的手绘箭头(直线段 + 端点三角形)。"""
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
edges = cv2.Canny(gray, 50, 150, apertureSize=3)
lines = cv2.HoughLinesP(
edges, rho=1, theta=np.pi / 180, threshold=40,
minLineLength=30, maxLineGap=15,
)
if lines is None:
return []
segments = [(x1, y1, x2, y2) for x1, y1, x2, y2 in lines[:, 0]]
clusters = _cluster_segments(segments)
annotations: list[Annotation] = []
for segs in clusters:
if len(segs) < 2:
continue
all_pts = []
for x1, y1, x2, y2 in segs:
all_pts.append((x1, y1))
all_pts.append((x2, y2))
all_pts_arr = np.array(all_pts)
max_dist = 0
p1 = p2 = all_pts[0]
for i in range(len(all_pts)):
for j in range(i + 1, len(all_pts)):
d = (all_pts[i][0] - all_pts[j][0]) ** 2 + (all_pts[i][1] - all_pts[j][1]) ** 2
if d > max_dist:
max_dist = d
p1, p2 = all_pts[i], all_pts[j]
from_pt, to_pt = _find_arrow_direction(edges, p1, p2)
x1, y1 = from_pt
x2, y2 = to_pt
bbox = {
"x": min(x1, x2),
"y": min(y1, y2),
"w": abs(x2 - x1),
"h": abs(y2 - y1),
}
cx = (x1 + x2) // 2
cy = (y1 + y2) // 2
annotations.append(Annotation(
type="arrow",
bbox=bbox,
center=(cx, cy),
from_pt=from_pt,
to_pt=to_pt,
))
return annotations
def _cluster_segments(segments: list[tuple]) -> list[list[tuple]]:
"""将线段按方向和空间距离聚类。"""
clusters: list[list[tuple]] = []
used = [False] * len(segments)
for i, (x1, y1, x2, y2) in enumerate(segments):
if used[i]:
continue
cluster = [(x1, y1, x2, y2)]
used[i] = True
angle_i = math.atan2(y2 - y1, x2 - x1)
for j in range(i + 1, len(segments)):
if used[j]:
continue
x3, y3, x4, y4 = segments[j]
angle_j = math.atan2(y4 - y3, x4 - x3)
angle_diff = abs(angle_i - angle_j)
if angle_diff > math.pi:
angle_diff = 2 * math.pi - angle_diff
if angle_diff < 0.35:
d1 = math.hypot(x3 - x2, y3 - y2)
d2 = math.hypot(x1 - x4, y1 - y4)
d3 = math.hypot(x3 - x1, y3 - y1)
d4 = math.hypot(x4 - x2, y4 - y2)
if min(d1, d2, d3, d4) < 80:
cluster.append((x3, y3, x4, y4))
used[j] = True
clusters.append(cluster)
return clusters
def _find_arrow_direction(edges: np.ndarray, p1: tuple, p2: tuple) -> tuple[tuple, tuple]:
"""判断箭头的方向(哪端是箭头/三角形汇聚点)。"""
r = 20
h, w = edges.shape[:2]
def edge_density(cx, cy):
x1 = max(0, int(cx - r))
y1 = max(0, int(cy - r))
x2 = min(w, int(cx + r))
y2 = min(h, int(cy + r))
roi = edges[y1:y2, x1:x2]
if roi.size == 0:
return 0
return float(np.count_nonzero(roi)) / roi.size
d1 = edge_density(p1[0], p1[1])
d2 = edge_density(p2[0], p2[1])
if d1 > d2 * 1.3:
return p2, p1
if d2 > d1 * 1.3:
return p1, p2
return p1, p2
# ---------------------------------------------------------------------------
# OCR 关联
# ---------------------------------------------------------------------------
def _correlate_with_ocr(
annotations: list[Annotation],
ocr_elements: list[dict],
img_w: int,
img_h: int,
) -> None:
"""将批注与附近的 OCR 文本关联。"""
if not ocr_elements:
return
for ann in annotations:
ax = ann.center[0]
ay = ann.center[1]
near_texts: list[tuple[str, float]] = []
for elem in ocr_elements:
bbox = elem.get("bbox", {})
ex = bbox.get("x", 0) + bbox.get("w", 0) / 2
ey = bbox.get("y", 0) + bbox.get("h", 0) / 2
dist = math.hypot(ax - ex, ay - ey)
max_dist = max(img_w, img_h) * 0.15
if dist < max_dist:
near_texts.append((elem.get("text", ""), dist))
near_texts.sort(key=lambda x: x[1])
ann.nearby_texts = [t for t, _ in near_texts[:5]]
if ann.type == "arrow" and ann.from_pt and ann.to_pt:
ann.from_text = _closest_text(ann.from_pt, ocr_elements, img_w, img_h)
ann.to_text = _closest_text(ann.to_pt, ocr_elements, img_w, img_h)
def _closest_text(pt: tuple[int, int], ocr_elements: list[dict], img_w: int, img_h: int) -> str:
"""找到离 pt 最近的 OCR 文本。"""
best_text = ""
best_dist = max(img_w, img_h) * 0.12
for elem in ocr_elements:
bbox = elem.get("bbox", {})
ex = bbox.get("x", 0) + bbox.get("w", 0) / 2
ey = bbox.get("y", 0) + bbox.get("h", 0) / 2
dist = math.hypot(pt[0] - ex, pt[1] - ey)
if dist < best_dist:
best_dist = dist
best_text = elem.get("text", "")
return best_text
# ---------------------------------------------------------------------------
# LLM 上下文格式化
# ---------------------------------------------------------------------------
def format_annotation_context(annotation_result: dict) -> str:
"""将批注检测结果格式化为中文 LLM 提示文本。"""
if not annotation_result or not isinstance(annotation_result, dict):
return ""
circles = annotation_result.get("circles", [])
arrows = annotation_result.get("arrows", [])
total = annotation_result.get("total", len(circles) + len(arrows))
if total == 0:
return ""
parts = ["[图片批注检测结果]"]
if circles:
parts.append(f"\n检测到 {len(circles)} 个圈选标记:")
for i, c in enumerate(circles):
center = c.get("center", [0, 0])
near = c.get("nearby_texts", [])
parts.append(
f"{i+1}. 位置 ({center[0]},{center[1]})"
f" — 圈选内容: {', '.join(near) if near else '(附近无文字)'}"
)
if arrows:
parts.append(f"\n检测到 {len(arrows)} 个箭头标记:")
for i, a in enumerate(arrows):
ft = a.get("from_text", "")
tt = a.get("to_text", "")
parts.append(f" 箭头{i+1}. 从「{ft}」→ 指向「{tt}")
parts.append("\n请根据上述圈选/箭头定位用户要修改的报表字段。")
return "\n".join(parts)
+108 -20
View File
@@ -51,6 +51,9 @@ def parse_file(file_path: str, file_type: str = "") -> dict:
".webp": _parse_image, ".webp": _parse_image,
".pdf": _parse_pdf, ".pdf": _parse_pdf,
".docx": _parse_docx, ".docx": _parse_docx,
".xlsx": _parse_xlsx,
".xls": _parse_xls,
".doc": _parse_doc,
} }
parser = parsers.get(suffix) parser = parsers.get(suffix)
@@ -72,26 +75,7 @@ def _parse_image(path: Path) -> dict:
except Exception: except Exception:
info = "[图片: 无法读取元数据]" info = "[图片: 无法读取元数据]"
# 优先 EasyOCRWindows 兼容性更好 # 优先 PaddleOCR(精确识别
try:
import easyocr
import numpy as np
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
result = reader.readtext(np.array(img))
lines = [text.strip() for (_, text, _) in result if text.strip()]
if lines:
return {
"text": f"{info}\n识别文本:\n" + "\n".join(lines),
"file_type": "image",
"method": "easyocr",
"error": None,
}
except ImportError:
pass
except Exception:
pass
# 回退 PaddleOCR
try: try:
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
ocr = PaddleOCR(lang="ch") ocr = PaddleOCR(lang="ch")
@@ -114,6 +98,25 @@ def _parse_image(path: Path) -> dict:
except Exception: except Exception:
pass pass
# 回退 EasyOCR
try:
import easyocr
import numpy as np
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
result = reader.readtext(np.array(img))
lines = [text.strip() for (_, text, _) in result if text.strip()]
if lines:
return {
"text": f"{info}\n识别文本:\n" + "\n".join(lines),
"file_type": "image",
"method": "easyocr",
"error": None,
}
except ImportError:
pass
except Exception:
pass
# OCR 不可用 → 返回图片元信息 + 安装提示 # OCR 不可用 → 返回图片元信息 + 安装提示
return { return {
"text": f"{info}\n(如需 OCR 文字识别,请安装: pip install easyocr)", "text": f"{info}\n(如需 OCR 文字识别,请安装: pip install easyocr)",
@@ -195,6 +198,91 @@ def _parse_docx(path: Path) -> dict:
"error": "DOCX 解析需要安装 python-docx"} "error": "DOCX 解析需要安装 python-docx"}
def _parse_xlsx(path: Path) -> dict:
"""提取 Excel .xlsx 文件中的文本。"""
try:
from openpyxl import load_workbook
wb = load_workbook(path, read_only=True, data_only=True)
parts = []
for name in wb.sheetnames:
ws = wb[name]
rows = []
for row in ws.iter_rows(values_only=True):
cells = [str(c) if c is not None else "" for c in row]
if any(c for c in cells):
rows.append("\t".join(cells))
if rows:
parts.append(f"[Sheet: {name}]\n" + "\n".join(rows))
wb.close()
text = "\n\n".join(parts)
return {"text": text, "file_type": "xlsx", "method": "openpyxl", "error": None}
except ImportError:
pass
except Exception as e:
return {"text": "", "file_type": "xlsx", "method": "none",
"error": f"XLSX 解析失败: {e}"}
return {"text": "", "file_type": "xlsx", "method": "none",
"error": "XLSX 解析需要安装 openpyxl"}
def _parse_xls(path: Path) -> dict:
"""提取旧版 Excel .xls 文件中的文本。"""
try:
import xlrd
wb = xlrd.open_workbook(path)
parts = []
for name in wb.sheet_names():
ws = wb.sheet_by_name(name)
rows = []
for rx in range(ws.nrows):
cells = [str(ws.cell_value(rx, cx)) if ws.cell_value(rx, cx) != "" else ""
for cx in range(ws.ncols)]
if any(c for c in cells):
rows.append("\t".join(cells))
if rows:
parts.append(f"[Sheet: {name}]\n" + "\n".join(rows))
text = "\n\n".join(parts)
return {"text": text, "file_type": "xls", "method": "xlrd", "error": None}
except ImportError:
pass
except Exception as e:
return {"text": "", "file_type": "xls", "method": "none",
"error": f"XLS 解析失败: {e}"}
return {"text": "", "file_type": "xls", "method": "none",
"error": "XLS 解析需要安装 xlrd"}
def _parse_doc(path: Path) -> dict:
"""提取旧版 Word .doc 文件中的文本(尽力而为,二进制格式)。"""
try:
import olefile
ole = olefile.OleFileIO(path)
if not ole.exists("WordDocument"):
ole.close()
return {"text": "", "file_type": "doc", "method": "none",
"error": "不是有效的 .doc 文件"}
raw = ole.openstream("WordDocument").read()
ole.close()
# 提取可打印 UTF-16LE 字符段
text = ""
try:
decoded = raw.decode("utf-16-le", errors="ignore")
text = "".join(c for c in decoded if c.isprintable() or c in "\n\r\t")
except Exception:
pass
if not text.strip():
return {"text": "", "file_type": "doc", "method": "olefile",
"error": "无法提取文本(.doc 为二进制格式,建议转换为 .docx)"}
return {"text": text.strip(), "file_type": "doc", "method": "olefile", "error": None}
except ImportError:
pass
except Exception as e:
return {"text": "", "file_type": "doc", "method": "none",
"error": f"DOC 解析失败: {e}"}
return {"text": "", "file_type": "doc", "method": "none",
"error": "DOC 解析需要安装 olefile"}
def _parse_text(path: Path) -> dict: def _parse_text(path: Path) -> dict:
"""读取纯文本文件。""" """读取纯文本文件。"""
try: try:
+34 -34
View File
@@ -373,40 +373,7 @@ def _load_image(path: Path) -> Optional[PIL.Image.Image]:
def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]: def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
"""OCR 提取图片中的文字元素(位置+内容)。优先 EasyOCR,回退 PaddleOCR。""" """OCR 提取图片中的文字元素(位置+内容)。优先 EasyOCR,回退 PaddleOCR。"""
# 优先 EasyOCR # 优先 PaddleOCR(精确识别)
try:
import easyocr
import numpy as np
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
result = reader.readtext(np.array(img))
elements = []
for (bbox, text, confidence) in result:
if not text.strip():
continue
xs = [p[0] for p in bbox]
ys = [p[1] for p in bbox]
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
elements.append({
"x": round(x_min, 1),
"y": round(y_min, 1),
"w": round(x_max - x_min, 1),
"h": round(y_max - y_min, 1),
"font_size": round(y_max - y_min, 1),
"text": text.strip(),
})
elements.sort(key=lambda e: (e["y"], e["x"]))
return elements
except ImportError:
pass
except Exception:
pass
# 回退 PaddleOCR
try: try:
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
import numpy as np import numpy as np
@@ -446,6 +413,39 @@ def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
except Exception: except Exception:
pass pass
# 回退 EasyOCR
try:
import easyocr
import numpy as np
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
result = reader.readtext(np.array(img))
elements = []
for (bbox, text, confidence) in result:
if not text.strip():
continue
xs = [p[0] for p in bbox]
ys = [p[1] for p in bbox]
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
elements.append({
"x": round(x_min, 1),
"y": round(y_min, 1),
"w": round(x_max - x_min, 1),
"h": round(y_max - y_min, 1),
"font_size": round(y_max - y_min, 1),
"text": text.strip(),
})
elements.sort(key=lambda e: (e["y"], e["x"]))
return elements
except ImportError:
pass
except Exception:
pass
return [] return []
+4 -4
View File
@@ -284,13 +284,13 @@ class OcrExtractor:
try: try:
import numpy as np import numpy as np
easyocr_result = self._try_easyocr(np.array(img))
if easyocr_result:
return easyocr_result
paddleocr_result = self._try_paddleocr(img, file_path) paddleocr_result = self._try_paddleocr(img, file_path)
if paddleocr_result: if paddleocr_result:
return paddleocr_result return paddleocr_result
easyocr_result = self._try_easyocr(np.array(img))
if easyocr_result:
return easyocr_result
except Exception: except Exception:
pass pass
+2
View File
@@ -8,6 +8,8 @@
- 如果添加新字段,正确声明它们。 - 如果添加新字段,正确声明它们。
- 确保 <queryString> 是 <![CDATA[...]]> 中有效的 SQL。 - 确保 <queryString> 是 <![CDATA[...]]> 中有效的 SQL。
{ocr_context}
当前 JRXML 当前 JRXML
{current_jrxml} {current_jrxml}
+18
View File
@@ -26,6 +26,24 @@ python-dotenv>=1.0.0
httpx>=0.27.0 httpx>=0.27.0
tiktoken>=0.7.0 tiktoken>=0.7.0
# OCR 依赖(PaddleOCR 精确识别优先,EasyOCR 回退)
# Pinned: paddleocr 2.9.x + paddlepaddle 2.6.x known-stable on Windows CPU
# 3.x has ONEDNN compatibility issues on Windows
paddleocr>=2.9.0,<3.0.0
paddlepaddle>=2.6.0,<3.0.0
easyocr>=1.7.0
# 聊天输入增强(粘贴/拖拽上传)
st-multimodal-chatinput>=0.2.1
# 多格式文件解析
openpyxl>=3.1.0
xlrd>=2.0.0
olefile>=0.47
# 批注检测(圈选/箭头识别)
opencv-python-headless>=4.8.0
# 测试 # 测试
pytest>=8.0.0 pytest>=8.0.0
pytest-asyncio>=0.24.0 pytest-asyncio>=0.24.0
xlwt>=1.3.0
+151
View File
@@ -0,0 +1,151 @@
"""测试批注检测器:圆圈检测、箭头检测、OCR 关联、格式化。"""
import tempfile
from pathlib import Path
import cv2
import numpy as np
import pytest
def _draw_circle_image(path: str, size: tuple = (400, 300)) -> None:
"""生成包含红色圆圈的合成测试图片。"""
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
cv2.circle(img, (200, 150), 50, (0, 0, 255), 2)
cv2.imwrite(path, img)
def _draw_arrow_image(path: str, size: tuple = (400, 300)) -> None:
"""生成包含手绘风格箭头的合成测试图片(多段线模拟手绘)。"""
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
# 多段略微偏移的线段模拟手绘箭杆(产生多个 HoughLinesP 段)
for offset_y in (-1, 0, 1):
cv2.line(img, (50, 150 + offset_y), (200, 150 + offset_y), (0, 0, 255), 2)
for offset_y in (-1, 0, 1):
cv2.line(img, (200, 150 + offset_y), (340, 150 + offset_y), (0, 0, 255), 2)
# 箭头三角形
pts = np.array([[350, 150], [330, 135], [330, 165]], np.int32)
cv2.fillPoly(img, [pts], (0, 0, 255))
# 额外三角形边缘线
cv2.line(img, (350, 150), (330, 135), (0, 0, 255), 2)
cv2.line(img, (350, 150), (330, 165), (0, 0, 255), 2)
cv2.imwrite(path, img)
def _draw_circle_and_text_image(path: str, size: tuple = (500, 400)) -> None:
"""生成包含红色圆圈和"文本"的合成图片(模拟圈选批注)。"""
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
cv2.circle(img, (250, 150), 60, (0, 0, 255), 3)
cv2.putText(img, "项目A", (20, 160), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
cv2.imwrite(path, img)
class TestAnnotationDetector:
"""测试 annotation_detector.py 各功能。"""
def test_detect_circles_finds_circle(self):
from backend.annotation_detector import detect_annotations
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
path = tmp.name
try:
_draw_circle_image(path)
ocr_elements = [
{"text": "测试字段", "bbox": {"x": 170, "y": 120, "w": 60, "h": 20}, "confidence": 0.95},
]
result = detect_annotations(path, ocr_elements)
assert result["total"] >= 1
circles = result["circles"]
assert len(circles) >= 1
c = circles[0]
assert c["type"] == "circle"
assert "center" in c
assert "bbox" in c
assert "nearby_texts" in c
finally:
Path(path).unlink(missing_ok=True)
def test_detect_arrows_finds_arrow(self):
from backend.annotation_detector import detect_annotations
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
path = tmp.name
try:
_draw_arrow_image(path)
ocr_elements = [
{"text": "起点", "bbox": {"x": 30, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
{"text": "终点", "bbox": {"x": 310, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
]
result = detect_annotations(path, ocr_elements)
assert result["total"] >= 1
arrows = result["arrows"]
assert len(arrows) >= 1
a = arrows[0]
assert a["type"] == "arrow"
assert "from_pt" in a
assert "to_pt" in a
finally:
Path(path).unlink(missing_ok=True)
def test_correlate_with_ocr_links_nearby_texts(self):
from backend.annotation_detector import detect_annotations
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
path = tmp.name
try:
_draw_circle_and_text_image(path)
ocr_elements = [
{"text": "项目A", "bbox": {"x": 20, "y": 140, "w": 80, "h": 30}, "confidence": 0.98},
{"text": "金额", "bbox": {"x": 350, "y": 200, "w": 50, "h": 20}, "confidence": 0.9},
]
result = detect_annotations(path, ocr_elements)
circles = result["circles"]
if circles:
near = circles[0].get("nearby_texts", [])
if near:
assert "项目A" in near
finally:
Path(path).unlink(missing_ok=True)
def test_invalid_image_path(self):
from backend.annotation_detector import detect_annotations
result = detect_annotations("/nonexistent/file.png", [])
assert result["total"] == 0
assert "error" in result
def test_format_annotation_context_empty(self):
from backend.annotation_detector import format_annotation_context
assert format_annotation_context({}) == ""
assert format_annotation_context(None) == ""
assert format_annotation_context({"circles": [], "arrows": [], "total": 0}) == ""
def test_format_annotation_context_with_circles(self):
from backend.annotation_detector import format_annotation_context
ann = {
"circles": [
{"center": [100, 200], "nearby_texts": ["项目A", "金额"]},
],
"arrows": [],
"total": 1,
}
text = format_annotation_context(ann)
assert "圈选标记" in text
assert "项目A" in text
def test_format_annotation_context_with_arrows(self):
from backend.annotation_detector import format_annotation_context
ann = {
"circles": [],
"arrows": [
{"from_text": "修理号", "to_text": "车架号"},
],
"total": 1,
}
text = format_annotation_context(ann)
assert "箭头标记" in text
assert "修理号" in text
assert "车架号" in text
+143
View File
@@ -0,0 +1,143 @@
"""端到端测试:OCR 字段精确提取完整流水线。
覆盖:
1. PaddleOCR 精确识别优先
2. EasyOCR 降级回退
3. 4种提取策略
4. 验证服务连通性
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from PIL import Image, ImageDraw
from backend.ocr_extractor import OcrExtractor, extract_ocr_fields
from backend.file_parser import parse_file
from backend.validation import validate_jrxml
def create_test_invoice(path: str):
"""创建一张模拟中文发票图片,包含已知字段。"""
img = Image.new("RGB", (800, 600), color="white")
draw = ImageDraw.Draw(img)
draw.text((300, 20), "增值税普通发票", fill="black")
draw.text((300, 60), "发票代码: 1234567890", fill="black")
draw.text((300, 100), "发票号码: 87654321", fill="black")
draw.text((50, 160), "开票日期: 2024年1月15日", fill="black")
draw.text((50, 200), "购买方名称: 测试公司", fill="black")
draw.text((50, 240), "合计金额: 1,234.56", fill="black")
draw.text((50, 280), "校验码: ABC12345678", fill="black")
draw.text((50, 350), "名称 数量 单价", fill="black")
draw.text((50, 390), "商品A 2 10.00", fill="black")
draw.text((50, 430), "商品B 5 20.00", fill="black")
img.save(path)
print(f"[OK] 测试图片已创建: {path}")
return path
def test_ocr_extraction_pipeline():
"""端到端测试:图片 -> OCR -> 字段提取。"""
print("\n=== 端到端OCR字段提取测试 ===\n")
img_path = create_test_invoice("test_invoice_e2e.png")
# 阶段1: 文件解析(含OCR
print("\n--- 阶段1: 文件解析(OCR) ---")
result = parse_file(img_path)
method = result.get("method", "N/A")
print(f" OCR方法: {method}")
print(f" 文件类型: {result.get('file_type', 'N/A')}")
text_preview = result.get("text", "")[:200]
print(f" 文本预览: {text_preview}")
# 阶段2: OCR精确提取
print("\n--- 阶段2: 字段精确提取 ---")
target_fields = ["发票代码", "发票号码", "开票日期", "合计金额", "校验码"]
extraction = extract_ocr_fields(img_path, target_fields)
print(f" OCR可用: {extraction.get('ocr_available')}")
print(f" 图片尺寸: {extraction.get('image_size')}")
print(f" 元素总数: {extraction.get('total_elements')}")
print(f" 错误: {extraction.get('errors')}")
print("\n 提取结果:")
all_ok = True
for field in extraction.get("fields", []):
status = "PASS" if field["field_value"] else "FAIL"
if not field["field_value"]:
all_ok = False
print(f" [{status}] {field['field_name']:10s} = {field['field_value']:20s} "
f"方法={field['extraction_method']:12s} 置信度={field['confidence']:.2f} "
f"bbox={field['bbox']}")
# 策略独立验证
print("\n--- 4种策略独立验证 ---")
extractor = OcrExtractor()
result_obj = extractor.extract(img_path, ["合计金额"])
fields = result_obj.get("fields", [])
if fields:
print(f" 策略: {fields[0].get('extraction_method', 'N/A')}")
print(f" 值: {fields[0].get('field_value', 'N/A')}")
print(f" 坐标: {fields[0].get('bbox', 'N/A')}")
return all_ok
def test_validation_service():
"""测试验证服务连通性。"""
print("\n=== 验证服务连通性测试 ===\n")
result = validate_jrxml("<jasperReport/>")
print(f" 状态: {'OK' if result else 'FAIL'}")
print(f" 响应: {result}")
return True
def test_ocr_fallback():
"""测试OCR回退:无图片时优雅降级。"""
print("\n=== OCR降级测试 ===\n")
result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])
print(f" OCR可用: {result.get('ocr_available')}")
print(f" 错误: {result.get('errors')}")
assert not result["ocr_available"]
assert len(result["errors"]) > 0
print(" [PASS] 降级行为正常(不阻断流程)")
if __name__ == "__main__":
errors = []
try:
test_ocr_fallback()
except Exception as e:
print(f"[FAIL] 降级测试: {e}")
errors.append(str(e))
try:
ok = test_ocr_extraction_pipeline()
if not ok:
print("\n 部分字段未提取到(可能因字体渲染差异)")
except Exception as e:
print(f"[FAIL] 流水线测试: {e}")
errors.append(str(e))
try:
test_validation_service()
except Exception as e:
print(f"[FAIL] 验证服务: {e}")
errors.append(str(e))
Path("test_invoice_e2e.png").unlink(missing_ok=True)
print("\n" + "=" * 50)
if errors:
print(f"测试完成,{len(errors)} 个错误:")
for e in errors:
print(f" - {e}")
sys.exit(1)
else:
print("所有端到端测试通过!")
+90
View File
@@ -0,0 +1,90 @@
"""测试多格式文件解析器:XLSX, XLS, DOC。"""
import tempfile
from pathlib import Path
import pytest
def _make_xlsx(path: str) -> None:
"""生成最小 .xlsx 测试文件。"""
from openpyxl import Workbook
wb = Workbook()
ws = wb.active
ws.title = "Sheet1"
ws["A1"] = "名称"
ws["B1"] = "金额"
ws["A2"] = "项目A"
ws["B2"] = 100
ws["A3"] = "项目B"
ws["B3"] = 200
wb.save(path)
def _make_xls(path: str) -> None:
"""生成最小 .xls 测试文件。"""
from xlwt import Workbook
wb = Workbook()
ws = wb.add_sheet("Sheet1")
ws.write(0, 0, "名称")
ws.write(0, 1, "金额")
ws.write(1, 0, "项目A")
ws.write(1, 1, 100)
ws.write(2, 0, "项目B")
ws.write(2, 1, 200)
wb.save(path)
class TestMultiFormatParsers:
"""测试 file_parser.py 的多格式解析器。"""
def test_parse_xlsx(self):
from backend.file_parser import parse_file
with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as tmp:
path = tmp.name
try:
_make_xlsx(path)
result = parse_file(path, ".xlsx")
assert result["file_type"] == "xlsx"
assert result["method"] == "openpyxl"
assert result["error"] is None
assert "Sheet1" in result["text"]
assert "项目A" in result["text"]
assert "100" in result["text"]
finally:
Path(path).unlink(missing_ok=True)
def test_parse_xls(self):
from backend.file_parser import parse_file
with tempfile.NamedTemporaryFile(suffix=".xls", delete=False) as tmp:
path = tmp.name
try:
_make_xls(path)
result = parse_file(path, ".xls")
assert result["file_type"] == "xls"
assert result["method"] == "xlrd"
assert result["error"] is None
assert "Sheet1" in result["text"]
assert "项目A" in result["text"]
assert "100.0" in result["text"]
finally:
Path(path).unlink(missing_ok=True)
def test_parse_doc_nonexistent(self):
"""测试 .doc 文件不存在时的错误处理。"""
from backend.file_parser import parse_file
result = parse_file("/nonexistent/file.doc", ".doc")
assert result["file_type"] == ".doc"
assert result["method"] == "none"
assert result.get("error") is not None
def test_dispatch_adds_new_formats(self):
"""验证新格式已在 parse_file 调度表中注册。"""
from backend.file_parser import parse_file
for ext in [".xlsx", ".xls", ".doc"]:
result = parse_file("/tmp/test" + ext, ext)
assert result["file_type"] in (ext, "xlsx", "xls", "doc")