feat: 新增 OCR 单据字段精确提取模块
- 新增 backend/ocr_extractor.py: 两阶段提取流水线 (文档分析 + 字段提取) - 四种提取策略: 精确KV匹配/模糊KV匹配/正则模式/表格结构匹配 - agent/state.py: 新增 ocr_extraction_result 和 uploaded_file_path 字段 - agent/nodes.py: process_input() 中自动触发 OCR 提取钩子 - app.py: 文件上传时保留图片路径, 总结卡片中展示提取结果 - .env.example: 新增 OCR_USE_GPU / OCR_CONFIDENCE_THRESHOLD 配置项 - tests/test_ocr_extraction.py: 48 个单元测试全部通过
This commit is contained in:
@@ -261,6 +261,14 @@ def run_agent(user_input: str):
|
||||
if stream_active:
|
||||
streaming_placeholder.empty()
|
||||
|
||||
# 清理已处理的临时文件
|
||||
for p in st.session_state.get("uploaded_temp_paths", []):
|
||||
try:
|
||||
Path(p).unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
st.session_state.uploaded_temp_paths = []
|
||||
|
||||
# ---- 总结卡片 ----
|
||||
# 注:node_state 只含变更字段,用 agent_state(被所有节点就地修改)获取完整状态
|
||||
final_state = agent_state
|
||||
@@ -324,6 +332,30 @@ def run_agent(user_input: str):
|
||||
"content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。",
|
||||
"type": "error_explanation",
|
||||
})
|
||||
|
||||
# OCR 字段提取结果展示
|
||||
ocr_result = agent_state.get("ocr_extraction_result", {})
|
||||
if ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields"):
|
||||
with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
|
||||
fields = ocr_result.get("fields", [])
|
||||
non_empty = [f for f in fields if f.get("field_value")]
|
||||
empty = [f for f in fields if not f.get("field_value")]
|
||||
if non_empty:
|
||||
st.markdown("**已提取字段:**")
|
||||
for f in non_empty:
|
||||
method = f.get("extraction_method", "")
|
||||
conf = f.get("confidence", 0)
|
||||
st.markdown(
|
||||
f"- **{f['field_name']}**: `{f['field_value']}` "
|
||||
f"(置信度: {conf:.0%}, 方法: {method})"
|
||||
)
|
||||
if empty:
|
||||
st.caption(
|
||||
f"未提取到值的字段: {', '.join(f['field_name'] for f in empty)}"
|
||||
)
|
||||
st.caption(
|
||||
f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素"
|
||||
)
|
||||
else:
|
||||
st.error("未产生结果,请重试。")
|
||||
|
||||
@@ -443,6 +475,9 @@ with st.sidebar:
|
||||
if "uploaded_files" not in st.session_state:
|
||||
st.session_state.uploaded_files = [] # [{name, text, type}]
|
||||
|
||||
if "uploaded_temp_paths" not in st.session_state:
|
||||
st.session_state.uploaded_temp_paths = [] # 待清理的临时文件路径
|
||||
|
||||
uploaded = st.file_uploader(
|
||||
"选择文件",
|
||||
type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"],
|
||||
@@ -513,8 +548,6 @@ with st.sidebar:
|
||||
)
|
||||
parsed_type = "image_reference"
|
||||
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
if parsed_text:
|
||||
st.session_state.uploaded_files.append({
|
||||
"name": uf.name,
|
||||
@@ -522,6 +555,14 @@ with st.sidebar:
|
||||
"type": parsed_type,
|
||||
})
|
||||
|
||||
# 对图片类型,保存路径以便 OCR 字段提取(延迟到 process_input 阶段)
|
||||
img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
||||
if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
|
||||
st.session_state.agent_state["uploaded_file_path"] = tmp_path
|
||||
st.session_state.uploaded_temp_paths.append(tmp_path)
|
||||
else:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
if st.session_state.uploaded_files:
|
||||
for i, f in enumerate(st.session_state.uploaded_files):
|
||||
cols = st.columns([5, 1])
|
||||
|
||||
Reference in New Issue
Block a user