feat: layered precise generation for A4 report images

3-phase pipeline to solve LLM prompt overflow from too many OCR elements:
Phase 1 (generate_skeleton): compressed layout schema → skeleton JRXML
Phase 2 (refine_layout): sampled coordinates → pixel-level position tuning
Phase 3 (map_fields): OCR field names → replace $F{field_N} placeholders

Only triggered when layout_schema.total_rows > 0 on initial_generation intent.
Text requests and all other intents are unaffected (zero behavior change).
This commit is contained in:
2026-05-21 08:34:32 +08:00
parent 9bb011e429
commit 43a0542a11
14 changed files with 882 additions and 81 deletions
+36 -1
View File
@@ -16,6 +16,9 @@ from agent.nodes import (
classify_intent,
retrieve,
generate,
generate_skeleton,
refine_layout,
map_fields,
modify_jrxml,
handle_consult,
handle_undo,
@@ -87,6 +90,15 @@ def route_by_intent(state: AgentState) -> Literal[
return "retrieve"
@_log_route("route_after_retrieve")
def route_after_retrieve(state: AgentState) -> Literal["generate", "generate_skeleton"]:
"""当 layout_schema 存在时走三层精确生成,否则走原有 1-shot。"""
layout_schema = state.get("layout_schema")
if layout_schema and isinstance(layout_schema, dict) and layout_schema.get("total_rows", 0) > 0:
return "generate_skeleton"
return "generate"
@_log_route("route_after_generate")
def route_after_generate(state: AgentState) -> Literal["save_session"]:
return "save_session"
@@ -158,6 +170,11 @@ def build_graph() -> StateGraph:
workflow.add_node("handle_undo", handle_undo)
workflow.add_node("handle_reset", handle_reset)
# 新增节点:分层精确生成(阶段一~三)
workflow.add_node("generate_skeleton", generate_skeleton)
workflow.add_node("refine_layout", refine_layout)
workflow.add_node("map_fields", map_fields)
# ---- 入口和前置流程 ----
workflow.set_entry_point("load_session")
workflow.add_edge("load_session", "process_input")
@@ -180,12 +197,28 @@ def build_graph() -> StateGraph:
)
# ---- 初始生成分支 ----
workflow.add_edge("retrieve", "generate")
workflow.add_conditional_edges(
"retrieve",
route_after_retrieve,
{
"generate": "generate",
"generate_skeleton": "generate_skeleton",
},
)
# 原有 1-shot 路径
workflow.add_conditional_edges(
"generate",
route_after_generate,
{"save_session": "save_session"},
)
# 分层精确生成 3 阶段路径
workflow.add_edge("generate_skeleton", "refine_layout")
workflow.add_edge("refine_layout", "map_fields")
workflow.add_conditional_edges(
"map_fields",
route_after_generate,
{"save_session": "save_session"},
)
# ---- 修改分支 ----
workflow.add_conditional_edges(
@@ -264,4 +297,6 @@ def create_initial_state() -> AgentState:
jrxml_versions=[],
last_error_case={},
pending_failure_context={},
layout_schema={},
ocr_elements=[],
)
+122 -2
View File
@@ -378,7 +378,7 @@ def load_session_node(state: AgentState) -> Dict:
"current_jrxml", "final_jrxml", "compressed_history",
"session_name", "created_at", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
"annotation_result", "layout_schema", "ocr_elements"):
if key in saved and key not in ("user_input", "stage"):
state[key] = saved[key]
state["session_name"] = data.get("session_name", "")
@@ -402,7 +402,7 @@ def save_session_node(state: AgentState) -> Dict:
"current_jrxml", "final_jrxml", "compressed_history",
"status", "error_msg", "history_states",
"ocr_extraction_result", "uploaded_file_path",
"annotation_result"):
"annotation_result", "layout_schema", "ocr_elements"):
if key in state:
persistable[key] = state[key]
persistable["updated_at"] = _now_iso()
@@ -437,6 +437,28 @@ def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _format_row_coordinates(row: dict) -> dict:
"""将单行 OCR 元素格式化为紧凑的坐标描述,供阶段二 refine_layout 使用。"""
if not isinstance(row, dict):
return {}
elements = row.get("elements", [])
if not elements:
return {"y_center": row.get("y_center", 0), "columns": []}
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
cols = []
for ci, e in enumerate(sorted_elems):
cols.append({
"col": ci,
"x": e.get("x", 0),
"y": e.get("y", 0),
"w": e.get("w", 0),
"h": e.get("h", 0),
"font_size": e.get("font_size", 12),
"text": e.get("text", ""),
})
return {"y_center": row.get("y_center", 0), "columns": cols}
def _format_ocr_context(state: AgentState) -> str:
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
ocr_result = state.get("ocr_extraction_result")
@@ -540,6 +562,104 @@ def generate(state: AgentState) -> Dict:
return state
@log_node("generate_skeleton")
def generate_skeleton(state: AgentState) -> Dict:
"""阶段一:根据压缩的布局 schema 生成骨架 JRXML$F{field_N} 占位)。"""
from langgraph.config import get_stream_writer
writer = get_stream_writer()
llm = get_llm(caller="generate_skeleton")
schema = state.get("layout_schema", {})
schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""
user_request = state.get("user_input", "")
prompt = load_prompt("skeleton_generation").format(
layout_schema=schema_text,
context=state.get("retrieved_context", ""),
user_request=user_request,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "generate_skeleton", "text": chunk})
jrxml = _extract_jrxml("".join(full))
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@log_node("refine_layout")
def refine_layout(state: AgentState) -> Dict:
"""阶段二:使用采样坐标(表头 + 首行数据 + 最后一行)精确调整元素位置。"""
from langgraph.config import get_stream_writer
writer = get_stream_writer()
llm = get_llm(caller="refine_layout")
ocr_rows = state.get("ocr_elements", [])
sampled = {}
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
if len(ocr_rows) > 1:
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
if len(ocr_rows) > 2:
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
sampled_text = json.dumps(sampled, ensure_ascii=False, indent=2)
prompt = load_prompt("refine_layout").format(
current_jrxml=state.get("current_jrxml", ""),
sampled_coordinates=sampled_text,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "refine_layout", "text": chunk})
jrxml = _extract_jrxml("".join(full))
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@log_node("map_fields")
def map_fields(state: AgentState) -> Dict:
"""阶段三:将占位字段名替换为 OCR 提取的真实字段名。"""
from langgraph.config import get_stream_writer
writer = get_stream_writer()
llm = get_llm(caller="map_fields")
ocr_result = state.get("ocr_extraction_result", {})
fields_text = ""
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
field_descs = []
for f in ocr_result["fields"]:
fname = f.get("field_name", "")
fval = f.get("field_value", "")
if fname:
field_descs.append(f" - {fname}: {fval}")
if field_descs:
fields_text = "提取的字段:\n" + "\n".join(field_descs)
if not fields_text:
elements = ocr_result.get("elements", []) if isinstance(ocr_result, dict) else []
if elements:
texts = [e.get("text", "") for e in elements if e.get("text")]
fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])
prompt = load_prompt("field_mapping").format(
current_jrxml=state.get("current_jrxml", ""),
ocr_fields=fields_text,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "map_fields", "text": chunk})
jrxml = _extract_jrxml("".join(full))
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@log_node("modify_jrxml")
def modify_jrxml(state: AgentState) -> Dict:
"""根据用户的修改请求修改现有 JRXML。"""
+4
View File
@@ -47,3 +47,7 @@ class AgentState(TypedDict, total=False):
# 需求8:图片批注检测(圈选/箭头标记)
annotation_result: dict
# 需求9:分层精确生成
layout_schema: dict # extract_layout_schema() 输出,列+区域结构
ocr_elements: list # OCR 原始行数据(用于阶段二坐标采样)