fix: OCR字段提取集成修复 + 会话切换无限循环修复 + 一键启动脚本

- process_input 传入17个默认中文字段(修复空列表导致零字段提取)
- OCR提取结果自动注入 LLM 上下文
- save_session_node/load_session_node 持久化 session_id(修复切换会话无限 rerun)
- app.py 会话切换后显式设置 session_id(纵深防御)
- 新增 start.bat / stop.bat 一键启动/停止脚本
- 更新 CLAUDE.md + CODE_GUIDE.md 文档

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 10:17:05 +08:00
parent c9f003e1b7
commit da79640259
6 changed files with 166 additions and 6 deletions
+23 -3
View File
@@ -123,7 +123,13 @@ def process_input(state: AgentState) -> Dict:
try:
from backend.ocr_extractor import OcrExtractor
extractor = OcrExtractor()
ocr_result = extractor.extract(uploaded_path, [])
default_fields = [
"发票代码", "发票号码", "开票日期", "合计金额", "校验码",
"价税合计", "总金额", "日期", "金额", "数量", "单价", "税率",
"购买方名称", "销售方名称", "货物名称", "规格型号",
"不含税金额", "税额",
]
ocr_result = extractor.extract(uploaded_path, default_fields)
if ocr_result.get("ocr_available"):
state["ocr_extraction_result"] = ocr_result
_node_log.info(
@@ -134,6 +140,20 @@ def process_input(state: AgentState) -> Dict:
"fields": len(ocr_result.get("fields", [])),
},
)
# 将提取到的字段注入到对话上下文,供 LLM 使用
extracted_fields = ocr_result.get("fields", [])
non_empty = [f for f in extracted_fields if f.get("field_value")]
if non_empty:
lines = ["[OCR 单据字段提取结果]"]
for f in non_empty:
lines.append(
f"- {f['field_name']}: {f['field_value']}"
f"(置信度: {f['confidence']:.0%}, 方法: {f['extraction_method']}"
)
ocr_context = "\n".join(lines)
user_input = f"{ocr_context}\n\n{user_input}"
# 同时更新工作对话历史中的最后一条
conv_history[-1]["content"] = user_input
except Exception as e:
_node_log.warning(f"OCR 字段提取失败: {e}")
state["ocr_extraction_result"] = {"error": str(e)}
@@ -357,7 +377,7 @@ def load_session_node(state: AgentState) -> Dict:
if data and data.get("agent_state"):
saved = data["agent_state"]
# 恢复核心字段(不覆盖当前请求的 user_input / stage
for key in ("conversation_history", "full_conversation_history",
for key in ("session_id", "conversation_history", "full_conversation_history",
"current_jrxml", "final_jrxml", "compressed_history",
"session_name", "created_at", "history_states"):
if key in saved and key not in ("user_input", "stage"):
@@ -379,7 +399,7 @@ def save_session_node(state: AgentState) -> Dict:
try:
from backend.session import save_session
persistable = {}
for key in ("conversation_history", "full_conversation_history",
for key in ("session_id", "conversation_history", "full_conversation_history",
"current_jrxml", "final_jrxml", "compressed_history",
"status", "error_msg", "history_states"):
if key in state: