Merge remote v4/v5 features (multimodal chat input, layered generation, annotation detection) with local v3 features (dialog file upload, XLSX support, session fix)
Key resolutions: - agent/nodes.py: Merged session_id exclusion fix with new persistable fields (ocr_extraction_result, annotation_result, layout_schema, ocr_elements) - app.py: Adopted st-multimodal-chatinput for unified paste/drop/upload, removed custom JS paste bridge - backend/file_parser.py: Kept local XLSX parser, added remote XLS/DOC parsers - CLAUDE.md + CODE_GUIDE.md: Merged documentation from both branches Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+204
-3
@@ -154,6 +154,23 @@ def process_input(state: AgentState) -> Dict:
|
||||
user_input = f"{ocr_context}\n\n{user_input}"
|
||||
# 同时更新工作对话历史中的最后一条
|
||||
conv_history[-1]["content"] = user_input
|
||||
# 批注检测(圈选/箭头标记)
|
||||
elements = ocr_result.get("elements", [])
|
||||
if elements:
|
||||
try:
|
||||
from backend.annotation_detector import detect_annotations
|
||||
ann_result = detect_annotations(uploaded_path, elements)
|
||||
if ann_result.get("total", 0) > 0:
|
||||
state["annotation_result"] = ann_result
|
||||
_node_log.info(
|
||||
"批注检测完成",
|
||||
extra={
|
||||
"circles": len(ann_result.get("circles", [])),
|
||||
"arrows": len(ann_result.get("arrows", [])),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
_node_log.warning(f"批注检测失败: {e}")
|
||||
except Exception as e:
|
||||
_node_log.warning(f"OCR 字段提取失败: {e}")
|
||||
state["ocr_extraction_result"] = {"error": str(e)}
|
||||
@@ -379,7 +396,9 @@ def load_session_node(state: AgentState) -> Dict:
|
||||
# 恢复核心字段(不覆盖当前请求的 user_input / stage / session_id)
|
||||
for key in ("conversation_history", "full_conversation_history",
|
||||
"current_jrxml", "final_jrxml", "compressed_history",
|
||||
"session_name", "created_at", "history_states"):
|
||||
"session_name", "created_at", "history_states",
|
||||
"ocr_extraction_result", "uploaded_file_path",
|
||||
"annotation_result", "layout_schema", "ocr_elements"):
|
||||
if key in saved and key not in ("user_input", "stage", "session_id"):
|
||||
state[key] = saved[key]
|
||||
state["session_name"] = data.get("session_name", "")
|
||||
@@ -401,7 +420,9 @@ def save_session_node(state: AgentState) -> Dict:
|
||||
persistable = {}
|
||||
for key in ("session_id", "conversation_history", "full_conversation_history",
|
||||
"current_jrxml", "final_jrxml", "compressed_history",
|
||||
"status", "error_msg", "history_states"):
|
||||
"status", "error_msg", "history_states",
|
||||
"ocr_extraction_result", "uploaded_file_path",
|
||||
"annotation_result", "layout_schema", "ocr_elements"):
|
||||
if key in state:
|
||||
persistable[key] = state[key]
|
||||
persistable["updated_at"] = _now_iso()
|
||||
@@ -436,6 +457,81 @@ def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _format_row_coordinates(row: dict) -> dict:
|
||||
"""将单行 OCR 元素格式化为紧凑的坐标描述,供阶段二 refine_layout 使用。"""
|
||||
if not isinstance(row, dict):
|
||||
return {}
|
||||
elements = row.get("elements", [])
|
||||
if not elements:
|
||||
return {"y_center": row.get("y_center", 0), "columns": []}
|
||||
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
|
||||
cols = []
|
||||
for ci, e in enumerate(sorted_elems):
|
||||
cols.append({
|
||||
"col": ci,
|
||||
"x": e.get("x", 0),
|
||||
"y": e.get("y", 0),
|
||||
"w": e.get("w", 0),
|
||||
"h": e.get("h", 0),
|
||||
"font_size": e.get("font_size", 12),
|
||||
"text": e.get("text", ""),
|
||||
})
|
||||
return {"y_center": row.get("y_center", 0), "columns": cols}
|
||||
|
||||
|
||||
def _format_ocr_context(state: AgentState) -> str:
|
||||
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
|
||||
ocr_result = state.get("ocr_extraction_result")
|
||||
if not ocr_result or not isinstance(ocr_result, dict):
|
||||
return ""
|
||||
if ocr_result.get("error"):
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
parts.append("[图片OCR识别结果]")
|
||||
|
||||
total = ocr_result.get("total_elements", 0)
|
||||
if total:
|
||||
parts.append(f"检测到 {total} 个文字元素")
|
||||
|
||||
# 提取到的字段
|
||||
fields = ocr_result.get("fields", [])
|
||||
if fields:
|
||||
parts.append("\n提取的结构化字段:")
|
||||
for f in fields:
|
||||
if f.get("field_value"):
|
||||
parts.append(
|
||||
f" - {f['field_name']}: {f['field_value']} "
|
||||
f"(方法={f.get('extraction_method','?')}, "
|
||||
f"置信度={f.get('confidence',0):.2f})"
|
||||
)
|
||||
|
||||
# 所有原始文本(用于表格匹配等需要全文的场景)
|
||||
elements = ocr_result.get("elements", [])
|
||||
if elements:
|
||||
parts.append("\n全部文本元素(含坐标):")
|
||||
for e in elements:
|
||||
bbox = e.get("bbox", {})
|
||||
x, y, w, h = bbox.get("x", 0), bbox.get("y", 0), bbox.get("w", 0), bbox.get("h", 0)
|
||||
parts.append(
|
||||
f" [{x},{y} {w}×{h}] {e['text']} "
|
||||
f"(置信度={e.get('confidence',0):.2f})"
|
||||
)
|
||||
|
||||
# 批注检测结果
|
||||
ann_result = state.get("annotation_result")
|
||||
if ann_result and isinstance(ann_result, dict):
|
||||
try:
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
ann_text = format_annotation_context(ann_result)
|
||||
if ann_text:
|
||||
parts.append("\n" + ann_text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
@log_node("retrieve")
|
||||
def retrieve(state: AgentState) -> Dict:
|
||||
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。"""
|
||||
@@ -466,9 +562,15 @@ def generate(state: AgentState) -> Dict:
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="generate")
|
||||
|
||||
user_request = state.get("user_input", "")
|
||||
ocr_text = _format_ocr_context(state)
|
||||
if ocr_text:
|
||||
user_request = f"{ocr_text}\n\n---\n用户需求:\n{user_request}"
|
||||
|
||||
prompt = load_prompt("initial_generation").format(
|
||||
context=state.get("retrieved_context", ""),
|
||||
user_request=state.get("user_input", ""),
|
||||
user_request=user_request,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
@@ -480,6 +582,104 @@ def generate(state: AgentState) -> Dict:
|
||||
return state
|
||||
|
||||
|
||||
@log_node("generate_skeleton")
|
||||
def generate_skeleton(state: AgentState) -> Dict:
|
||||
"""阶段一:根据压缩的布局 schema 生成骨架 JRXML($F{field_N} 占位)。"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="generate_skeleton")
|
||||
|
||||
schema = state.get("layout_schema", {})
|
||||
schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""
|
||||
user_request = state.get("user_input", "")
|
||||
|
||||
prompt = load_prompt("skeleton_generation").format(
|
||||
layout_schema=schema_text,
|
||||
context=state.get("retrieved_context", ""),
|
||||
user_request=user_request,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "generate_skeleton", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
|
||||
|
||||
@log_node("refine_layout")
|
||||
def refine_layout(state: AgentState) -> Dict:
|
||||
"""阶段二:使用采样坐标(表头 + 首行数据 + 最后一行)精确调整元素位置。"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="refine_layout")
|
||||
|
||||
ocr_rows = state.get("ocr_elements", [])
|
||||
sampled = {}
|
||||
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
|
||||
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
|
||||
if len(ocr_rows) > 1:
|
||||
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
|
||||
if len(ocr_rows) > 2:
|
||||
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
|
||||
sampled_text = json.dumps(sampled, ensure_ascii=False, indent=2)
|
||||
|
||||
prompt = load_prompt("refine_layout").format(
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
sampled_coordinates=sampled_text,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "refine_layout", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
|
||||
|
||||
@log_node("map_fields")
|
||||
def map_fields(state: AgentState) -> Dict:
|
||||
"""阶段三:将占位字段名替换为 OCR 提取的真实字段名。"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="map_fields")
|
||||
|
||||
ocr_result = state.get("ocr_extraction_result", {})
|
||||
fields_text = ""
|
||||
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
|
||||
field_descs = []
|
||||
for f in ocr_result["fields"]:
|
||||
fname = f.get("field_name", "")
|
||||
fval = f.get("field_value", "")
|
||||
if fname:
|
||||
field_descs.append(f" - {fname}: {fval}")
|
||||
if field_descs:
|
||||
fields_text = "提取的字段:\n" + "\n".join(field_descs)
|
||||
if not fields_text:
|
||||
elements = ocr_result.get("elements", []) if isinstance(ocr_result, dict) else []
|
||||
if elements:
|
||||
texts = [e.get("text", "") for e in elements if e.get("text")]
|
||||
fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])
|
||||
|
||||
prompt = load_prompt("field_mapping").format(
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
ocr_fields=fields_text,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "map_fields", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
|
||||
|
||||
@log_node("modify_jrxml")
|
||||
def modify_jrxml(state: AgentState) -> Dict:
|
||||
"""根据用户的修改请求修改现有 JRXML。"""
|
||||
@@ -500,6 +700,7 @@ def modify_jrxml(state: AgentState) -> Dict:
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
conversation_history=conv_text,
|
||||
modification_request=state.get("user_modification_request", ""),
|
||||
ocr_context=_format_ocr_context(state),
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
|
||||
Reference in New Issue
Block a user