fix: MAX_RETRY 5 + rolling continuation + namespace-aware JRXML extraction

- MAX_RETRY: 3→5 (graph.py:35, nodes.py:25) with env override
- Rolling continuation: _generate_with_continuation() auto-detects
  truncated JRXML and sends anchor-based continuation, max 3 rounds
- JRXML extraction: regex/end-tag now namespace-prefix aware
  (ns0:jasperReport, ns:jasperReport, etc.)
- All 5 generation nodes refactored to use continuation helper
- Tests updated: scenario1 accepts ns-prefixed root, max_retry
  verifies graph termination
- stop_reason capture + WARNING log on max_tokens truncation
- Correction prompt now injects OCR context + layout schema
This commit is contained in:
2026-05-23 10:58:46 +08:00
parent 83e801a0b8
commit 1210b926c3
5 changed files with 187 additions and 50 deletions
+102 -29
View File
@@ -673,11 +673,15 @@ def generate_skeleton(state: AgentState) -> Dict:
context=state.get("retrieved_context", ""),
user_request=user_request,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "generate_skeleton", "text": chunk})
jrxml = _extract_jrxml("".join(full))
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
if not full_text.strip():
_node_log.error("generate_skeleton LLM 返回空响应")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"generate_skeleton 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@@ -705,11 +709,15 @@ def refine_layout(state: AgentState) -> Dict:
current_jrxml=state.get("current_jrxml", ""),
sampled_coordinates=sampled_text,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "refine_layout", "text": chunk})
jrxml = _extract_jrxml("".join(full))
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "refine_layout")
if not full_text.strip():
_node_log.error("refine_layout LLM 返回空响应,保留前一版本")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"refine_layout 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@@ -744,11 +752,15 @@ def map_fields(state: AgentState) -> Dict:
current_jrxml=state.get("current_jrxml", ""),
ocr_fields=fields_text,
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "map_fields", "text": chunk})
jrxml = _extract_jrxml("".join(full))
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "map_fields")
if not full_text.strip():
_node_log.error("map_fields LLM 返回空响应,保留占位字段版本")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"map_fields 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
return state
@@ -776,11 +788,15 @@ def modify_jrxml(state: AgentState) -> Dict:
modification_request=state.get("user_modification_request", ""),
ocr_context=_format_ocr_context(state),
)
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "modify_jrxml", "text": chunk})
jrxml = _extract_jrxml("".join(full))
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "modify_jrxml")
if not full_text.strip():
_node_log.error("modify_jrxml LLM 返回空响应,保留原版本")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"modify_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append(
{
@@ -876,10 +892,17 @@ def correct_jrxml(state: AgentState) -> Dict:
writer = get_stream_writer()
llm = get_llm(caller="correct_jrxml")
ocr_context = _format_ocr_context(state)
layout_schema = state.get("layout_schema", {})
layout_text = ""
if isinstance(layout_schema, dict):
layout_text = layout_schema.get("schema_text", "")
prompt = load_prompt("correction").format(
current_jrxml=state.get("current_jrxml", ""),
error_msg=state.get("error_msg", ""),
explanation=state.get("natural_explanation", ""),
ocr_context=ocr_context,
layout_schema_text=layout_text,
)
# 保存修正前状态(供 validate 判断是否写入错误知识库)
state["last_error_case"] = {
@@ -888,11 +911,16 @@ def correct_jrxml(state: AgentState) -> Dict:
"correction_prompt": prompt,
}
full = []
for chunk in llm.stream(prompt):
full.append(chunk)
writer({"type": "stream", "node": "correct_jrxml", "text": chunk})
jrxml = _extract_jrxml("".join(full))
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "correct_jrxml")
if not full_text.strip():
_node_log.error("correct_jrxml LLM 返回空响应,保留原版本")
state["retry_count"] = state.get("retry_count", 0) + 1
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"correct_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["retry_count"] = state.get("retry_count", 0) + 1
state["conversation_history"].append(
@@ -963,6 +991,49 @@ def finalize(state: AgentState) -> Dict:
return state
def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> str:
"""Stream LLM generation with automatic truncation recovery.
After each stream round, checks if the extracted JRXML ends with
</jasperReport>. If truncated, sends a continuation request with
the last 800 chars as anchor context.
Returns combined full text from all rounds.
"""
full_text = ""
for round_num in range(max_rounds):
if round_num == 0:
current_prompt = prompt
else:
tail = full_text[-800:] if len(full_text) > 800 else full_text
current_prompt = (
f"[系统指令] 你正在生成的 JRXML 在上一次响应中被截断。\n"
f"已生成内容的最后部分(请从此处继续):\n...{tail}\n\n"
f"请从截断点继续输出剩余内容,不要重复已输出的部分。"
)
new_chunks = []
for chunk in llm.stream(current_prompt):
new_chunks.append(chunk)
writer({"type": "stream", "node": node_name, "text": chunk})
new_text = "".join(new_chunks)
full_text += new_text
jrxml = _extract_jrxml(full_text)
if re.search(r"</(?:[\w:]+:)?jasperReport>\s*$", jrxml, re.IGNORECASE):
break
if not new_text.strip():
_node_log.warning(f"{node_name}{round_num+1}轮续写无输出,停止")
break
else:
_node_log.warning(f"{node_name}{max_rounds}轮续写仍未完整")
return full_text
def _extract_jrxml(text: str) -> str:
"""从 LLM 响应中提取 JRXML 内容,如有 markdown 标记则去除。"""
text = text.strip()
@@ -974,7 +1045,8 @@ def _extract_jrxml(text: str) -> str:
return content
# markdown 代码块存在但内容为空 — 回退到直接匹配
jasper_tag = re.search(r"(<\?xml[\s\S]*?</jasperReport>)", text, re.IGNORECASE)
_jrxml_close = r"</(?:[\w:]+:)?jasperReport>"
jasper_tag = re.search(rf"(<\?xml[\s\S]*?{_jrxml_close})", text, re.IGNORECASE)
if jasper_tag:
return jasper_tag.group(1).strip()
@@ -984,8 +1056,9 @@ def _extract_jrxml(text: str) -> str:
# 最终回退:如果文本中包含 XML 片段但没有被捕获到,尝试直接提取
# 这处理 LLM 在代码块外用自然语言"包裹"JRXML 的情况
xml_start = text.find("<?xml")
jr_end = text.lower().rfind("</jasperreport>")
if xml_start >= 0 and jr_end > xml_start:
return text[xml_start:jr_end + len("</jasperreport>")].strip()
jr_close = re.search(_jrxml_close, text, re.IGNORECASE)
if xml_start >= 0 and jr_close:
jr_end = jr_close.end()
return text[xml_start:jr_end].strip()
return text