feat: v4 multimodal chat input, multi-format support, and annotation detection

- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button)
- Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines)
- Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py
- Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting
- Add annotation_result field to AgentState with session persistence
- Wire annotation detection into process_input and _format_ocr_context
- Add 11 new tests: 7 annotation detector + 4 multi-format parser
- Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
This commit is contained in:
2026-05-20 23:43:16 +08:00
parent c9f003e1b7
commit 9bb011e429
16 changed files with 1257 additions and 164 deletions
+84 -3
View File
@@ -134,6 +134,23 @@ def process_input(state: AgentState) -> Dict:
"fields": len(ocr_result.get("fields", [])),
},
)
# 批注检测(圈选/箭头标记)
elements = ocr_result.get("elements", [])
if elements:
try:
from backend.annotation_detector import detect_annotations
ann_result = detect_annotations(uploaded_path, elements)
if ann_result.get("total", 0) > 0:
state["annotation_result"] = ann_result
_node_log.info(
"批注检测完成",
extra={
"circles": len(ann_result.get("circles", [])),
"arrows": len(ann_result.get("arrows", [])),
},
)
except Exception as e:
_node_log.warning(f"批注检测失败: {e}")
except Exception as e:
_node_log.warning(f"OCR 字段提取失败: {e}")
state["ocr_extraction_result"] = {"error": str(e)}
@@ -359,7 +376,9 @@ def load_session_node(state: AgentState) -> Dict:
# 恢复核心字段(不覆盖当前请求的 user_input / stage
for key in ("conversation_history", "full_conversation_history",
"current_jrxml", "final_jrxml", "compressed_history",
"session_name", "created_at", "history_states"):
"session_name", "created_at", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
if key in saved and key not in ("user_input", "stage"):
state[key] = saved[key]
state["session_name"] = data.get("session_name", "")
@@ -381,7 +400,9 @@ def save_session_node(state: AgentState) -> Dict:
persistable = {}
for key in ("conversation_history", "full_conversation_history",
"current_jrxml", "final_jrxml", "compressed_history",
"status", "error_msg", "history_states"):
"status", "error_msg", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
if key in state:
persistable[key] = state[key]
persistable["updated_at"] = _now_iso()
@@ -416,6 +437,59 @@ def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _format_ocr_context(state: AgentState) -> str:
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
ocr_result = state.get("ocr_extraction_result")
if not ocr_result or not isinstance(ocr_result, dict):
return ""
if ocr_result.get("error"):
return ""
parts = []
parts.append("[图片OCR识别结果]")
total = ocr_result.get("total_elements", 0)
if total:
parts.append(f"检测到 {total} 个文字元素")
# 提取到的字段
fields = ocr_result.get("fields", [])
if fields:
parts.append("\n提取的结构化字段:")
for f in fields:
if f.get("field_value"):
parts.append(
f" - {f['field_name']}: {f['field_value']} "
f"(方法={f.get('extraction_method','?')}, "
f"置信度={f.get('confidence',0):.2f})"
)
# 所有原始文本(用于表格匹配等需要全文的场景)
elements = ocr_result.get("elements", [])
if elements:
parts.append("\n全部文本元素(含坐标):")
for e in elements:
bbox = e.get("bbox", {})
x, y, w, h = bbox.get("x", 0), bbox.get("y", 0), bbox.get("w", 0), bbox.get("h", 0)
parts.append(
f" [{x},{y} {w}×{h}] {e['text']} "
f"(置信度={e.get('confidence',0):.2f})"
)
# 批注检测结果
ann_result = state.get("annotation_result")
if ann_result and isinstance(ann_result, dict):
try:
from backend.annotation_detector import format_annotation_context
ann_text = format_annotation_context(ann_result)
if ann_text:
parts.append("\n" + ann_text)
except Exception:
pass
return "\n".join(parts)
@log_node("retrieve")
def retrieve(state: AgentState) -> Dict:
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。"""
@@ -446,9 +520,15 @@ def generate(state: AgentState) -> Dict:
writer = get_stream_writer()
llm = get_llm(caller="generate")
user_request = state.get("user_input", "")
ocr_text = _format_ocr_context(state)
if ocr_text:
user_request = f"{ocr_text}\n\n---\n用户需求:\n{user_request}"
prompt = load_prompt("initial_generation").format(
context=state.get("retrieved_context", ""),
user_request=state.get("user_input", ""),
user_request=user_request,
)
full = []
for chunk in llm.stream(prompt):
@@ -480,6 +560,7 @@ def modify_jrxml(state: AgentState) -> Dict:
current_jrxml=state.get("current_jrxml", ""),
conversation_history=conv_text,
modification_request=state.get("user_modification_request", ""),
ocr_context=_format_ocr_context(state),
)
full = []
for chunk in llm.stream(prompt):
+3
View File
@@ -44,3 +44,6 @@ class AgentState(TypedDict, total=False):
# 需求7:OCR 单据字段精确提取结果
ocr_extraction_result: dict
uploaded_file_path: str
# 需求8:图片批注检测(圈选/箭头标记)
annotation_result: dict