fix: MAX_RETRY 5 + rolling continuation + namespace-aware JRXML extraction
- MAX_RETRY: 3→5 (graph.py:35, nodes.py:25) with env override
- Rolling continuation: _generate_with_continuation() auto-detects truncated JRXML and sends anchor-based continuation, max 3 rounds
- JRXML extraction: regex/end-tag now namespace-prefix aware (ns0:jasperReport, ns:jasperReport, etc.)
- All 5 generation nodes refactored to use continuation helper
- Tests updated: scenario1 accepts ns-prefixed root, max_retry verifies graph termination
- stop_reason capture + WARNING log on max_tokens truncation
- Correction prompt now injects OCR context + layout schema
This commit is contained in:
+102
-29
@@ -673,11 +673,15 @@ def generate_skeleton(state: AgentState) -> Dict:
|
||||
context=state.get("retrieved_context", ""),
|
||||
user_request=user_request,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "generate_skeleton", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
|
||||
if not full_text.strip():
|
||||
_node_log.error("generate_skeleton LLM 返回空响应")
|
||||
return state
|
||||
jrxml = _extract_jrxml(full_text)
|
||||
if len(jrxml.strip()) < 200:
|
||||
_node_log.warning(f"generate_skeleton 输出过短({len(jrxml)} 字符),回退到前一版本")
|
||||
jrxml = prev_jrxml
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
@@ -705,11 +709,15 @@ def refine_layout(state: AgentState) -> Dict:
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
sampled_coordinates=sampled_text,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "refine_layout", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "refine_layout")
|
||||
if not full_text.strip():
|
||||
_node_log.error("refine_layout LLM 返回空响应,保留前一版本")
|
||||
return state
|
||||
jrxml = _extract_jrxml(full_text)
|
||||
if len(jrxml.strip()) < 200:
|
||||
_node_log.warning(f"refine_layout 输出过短({len(jrxml)} 字符),回退到前一版本")
|
||||
jrxml = prev_jrxml
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
@@ -744,11 +752,15 @@ def map_fields(state: AgentState) -> Dict:
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
ocr_fields=fields_text,
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "map_fields", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "map_fields")
|
||||
if not full_text.strip():
|
||||
_node_log.error("map_fields LLM 返回空响应,保留占位字段版本")
|
||||
return state
|
||||
jrxml = _extract_jrxml(full_text)
|
||||
if len(jrxml.strip()) < 200:
|
||||
_node_log.warning(f"map_fields 输出过短({len(jrxml)} 字符),回退到前一版本")
|
||||
jrxml = prev_jrxml
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
return state
|
||||
@@ -776,11 +788,15 @@ def modify_jrxml(state: AgentState) -> Dict:
|
||||
modification_request=state.get("user_modification_request", ""),
|
||||
ocr_context=_format_ocr_context(state),
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "modify_jrxml", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "modify_jrxml")
|
||||
if not full_text.strip():
|
||||
_node_log.error("modify_jrxml LLM 返回空响应,保留原版本")
|
||||
return state
|
||||
jrxml = _extract_jrxml(full_text)
|
||||
if len(jrxml.strip()) < 200:
|
||||
_node_log.warning(f"modify_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
|
||||
jrxml = prev_jrxml
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append(
|
||||
{
|
||||
@@ -876,10 +892,17 @@ def correct_jrxml(state: AgentState) -> Dict:
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="correct_jrxml")
|
||||
ocr_context = _format_ocr_context(state)
|
||||
layout_schema = state.get("layout_schema", {})
|
||||
layout_text = ""
|
||||
if isinstance(layout_schema, dict):
|
||||
layout_text = layout_schema.get("schema_text", "")
|
||||
prompt = load_prompt("correction").format(
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
error_msg=state.get("error_msg", ""),
|
||||
explanation=state.get("natural_explanation", ""),
|
||||
ocr_context=ocr_context,
|
||||
layout_schema_text=layout_text,
|
||||
)
|
||||
# 保存修正前状态(供 validate 判断是否写入错误知识库)
|
||||
state["last_error_case"] = {
|
||||
@@ -888,11 +911,16 @@ def correct_jrxml(state: AgentState) -> Dict:
|
||||
"correction_prompt": prompt,
|
||||
}
|
||||
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
full.append(chunk)
|
||||
writer({"type": "stream", "node": "correct_jrxml", "text": chunk})
|
||||
jrxml = _extract_jrxml("".join(full))
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "correct_jrxml")
|
||||
if not full_text.strip():
|
||||
_node_log.error("correct_jrxml LLM 返回空响应,保留原版本")
|
||||
state["retry_count"] = state.get("retry_count", 0) + 1
|
||||
return state
|
||||
jrxml = _extract_jrxml(full_text)
|
||||
if len(jrxml.strip()) < 200:
|
||||
_node_log.warning(f"correct_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
|
||||
jrxml = prev_jrxml
|
||||
state["current_jrxml"] = jrxml
|
||||
state["retry_count"] = state.get("retry_count", 0) + 1
|
||||
state["conversation_history"].append(
|
||||
@@ -963,6 +991,49 @@ def finalize(state: AgentState) -> Dict:
|
||||
return state
|
||||
|
||||
|
||||
def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> str:
    """Stream an LLM response and request continuations while the JRXML looks cut off.

    Round one streams ``prompt`` as given. After every round the text gathered
    so far is passed through ``_extract_jrxml`` and the result is tested for a
    trailing ``</jasperReport>`` closing tag (optionally namespace-prefixed,
    matched case-insensitively). If the tag is missing, the next round sends a
    continuation request built from the last 800 characters of the gathered
    text; the original ``prompt`` is not resent in that request.

    Every streamed chunk is forwarded to ``writer`` as
    ``{"type": "stream", "node": node_name, "text": chunk}``.

    Args:
        llm: Object exposing ``stream(prompt)`` that yields string chunks.
        prompt: Prompt used for the first round.
        writer: Callable invoked once per streamed chunk.
        node_name: Node label used in stream events and log messages.
        max_rounds: Upper bound on stream rounds, the first one included.

    Returns:
        The raw text of all rounds concatenated in order (not the extracted
        JRXML).
    """
    closing_tag = re.compile(r"</(?:[\w:]+:)?jasperReport>\s*$", re.IGNORECASE)
    accumulated = ""
    finished = False

    for attempt in range(max_rounds):
        if attempt:
            # Slicing already copes with strings shorter than 800 characters.
            anchor = accumulated[-800:]
            request = (
                f"[系统指令] 你正在生成的 JRXML 在上一次响应中被截断。\n"
                f"已生成内容的最后部分(请从此处继续):\n...{anchor}\n\n"
                f"请从截断点继续输出剩余内容,不要重复已输出的部分。"
            )
        else:
            request = prompt

        pieces = []
        for piece in llm.stream(request):
            pieces.append(piece)
            writer({"type": "stream", "node": node_name, "text": piece})

        fragment = "".join(pieces)
        accumulated += fragment

        # Completion is judged on the extracted document, not the raw text.
        if closing_tag.search(_extract_jrxml(accumulated)):
            finished = True
            break

        # A round with no visible output cannot make progress; stop asking.
        if not fragment.strip():
            _node_log.warning(f"{node_name} 第{attempt+1}轮续写无输出,停止")
            finished = True
            break

    if not finished:
        # Every allowed round ran and the closing tag never appeared.
        _node_log.warning(f"{node_name} 经{max_rounds}轮续写仍未完整")

    return accumulated
|
||||
|
||||
|
||||
def _extract_jrxml(text: str) -> str:
|
||||
"""从 LLM 响应中提取 JRXML 内容,如有 markdown 标记则去除。"""
|
||||
text = text.strip()
|
||||
@@ -974,7 +1045,8 @@ def _extract_jrxml(text: str) -> str:
|
||||
return content
|
||||
# markdown 代码块存在但内容为空 — 回退到直接匹配
|
||||
|
||||
jasper_tag = re.search(r"(<\?xml[\s\S]*?</jasperReport>)", text, re.IGNORECASE)
|
||||
_jrxml_close = r"</(?:[\w:]+:)?jasperReport>"
|
||||
jasper_tag = re.search(rf"(<\?xml[\s\S]*?{_jrxml_close})", text, re.IGNORECASE)
|
||||
if jasper_tag:
|
||||
return jasper_tag.group(1).strip()
|
||||
|
||||
@@ -984,8 +1056,9 @@ def _extract_jrxml(text: str) -> str:
|
||||
# 最终回退:如果文本中包含 XML 片段但没有被捕获到,尝试直接提取
|
||||
# 这处理 LLM 在代码块外用自然语言"包裹"JRXML 的情况
|
||||
xml_start = text.find("<?xml")
|
||||
jr_end = text.lower().rfind("</jasperreport>")
|
||||
if xml_start >= 0 and jr_end > xml_start:
|
||||
return text[xml_start:jr_end + len("</jasperreport>")].strip()
|
||||
jr_close = re.search(_jrxml_close, text, re.IGNORECASE)
|
||||
if xml_start >= 0 and jr_close:
|
||||
jr_end = jr_close.end()
|
||||
return text[xml_start:jr_end].strip()
|
||||
|
||||
return text
|
||||
Reference in New Issue
Block a user