"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。 支持: - 流式输出(LLM 逐字展示) - 节点平铺展开(每个处理阶段独立展示) - 完成后自动折叠节点区 - 过程总结卡片 """ import os import sys os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error") try: import torchvision except Exception: pass import time from pathlib import Path import streamlit as st from dotenv import load_dotenv load_dotenv() from agent.graph import build_graph, create_initial_state from backend.session import ( create_session, load_session, delete_session, list_all_sessions, ) from backend.logger import get_logger, set_trace_id, generate_trace_id _app_log = get_logger("app") st.set_page_config( page_title="JRXML 代理", page_icon="📊", layout="wide", initial_sidebar_state="expanded", ) # 阻止 Streamlit 裸 'c' 键清除缓存,保留 Ctrl+C 复制行为 st.html(""" """) # ---- 节点名称 → 中文标签 ---- NODE_LABELS = { "load_session": "📂 加载会话", "process_input": "📝 记录输入", "manage_context": "🧠 管理上下文", "save_state_snapshot": "💾 保存快照", "classify_intent": "🔍 识别意图", "retrieve": "📚 检索模板", "generate": "⚙️ 生成 JRXML", "modify_jrxml": "🔧 修改 JRXML", "validate": "✅ 验证", "explain_error": "🔎 分析错误", "correct_jrxml": "🛠 自动修正", "finalize": "📋 完成", "handle_consult": "💬 咨询回答", "handle_undo": "↩ 撤销操作", "handle_reset": "🔄 重置会话", "save_session": "💾 保存会话", } INTENT_LABELS = { "initial_generation": "新建报表", "modify_report": "修改报表", "preview_report": "预览报表", "export_pdf": "导出 PDF", "export_jrxml": "下载 JRXML", "undo_modification": "撤销修改", "consult_question": "咨询问题", "reset_session": "重置会话", } SKIP_NODES = {"load_session", "process_input", "manage_context", "save_state_snapshot", "save_session"} def _render_jrxml(jrxml: str, max_lines: int = 30): """展示 JRXML 代码(折叠、限行)。""" lines = jrxml.strip().split("\n") preview = "\n".join(lines[:max_lines]) if len(lines) > max_lines: preview += f"\n... (共 {len(lines)} 行)" st.code(preview, language="xml") # ---- URL 参数 ---- query_params = st.query_params url_session_id = query_params.get("session_id", "") # ---- 会话状态初始化 ---- if "messages" not in st.session_state: st.session_state.messages = [] if "graph" not in st.session_state: st.session_state.graph = build_graph() if "pending_action" not in st.session_state: st.session_state.pending_action = None if "agent_state" not in st.session_state: if url_session_id: data = load_session(url_session_id) if data and data.get("agent_state"): st.session_state.agent_state = data["agent_state"] st.session_state.agent_state["session_id"] = url_session_id else: st.session_state.agent_state = create_initial_state() new_data = create_session(name="", agent_state=st.session_state.agent_state) st.session_state.agent_state["session_id"] = new_data["session_id"] st.session_state.agent_state["session_name"] = new_data["session_name"] st.session_state.agent_state["created_at"] = new_data["created_at"] else: st.session_state.agent_state = create_initial_state() new_data = create_session(name="", agent_state=st.session_state.agent_state) st.session_state.agent_state["session_id"] = new_data["session_id"] st.session_state.agent_state["session_name"] = new_data["session_name"] st.session_state.agent_state["created_at"] = new_data["created_at"] current_session_id = st.session_state.agent_state.get("session_id", "") def run_agent(user_input: str): """运行代理图:流式渲染节点进度 + LLM 文本。""" trace_id = generate_trace_id() set_trace_id(trace_id) agent_state = st.session_state.agent_state session_id = agent_state.get("session_id", "") _app_log.info( "代理执行开始", extra={ "session_id": session_id, "trace_id": trace_id, "user_input_preview": user_input[:200], "user_input_length": len(user_input), "has_jrxml": bool(agent_state.get("current_jrxml", "").strip()), "intent": agent_state.get("intent", ""), }, ) if agent_state.get("current_jrxml") and agent_state.get("status") == "pass": agent_state["user_modification_request"] = user_input agent_state["user_input"] = user_input agent_state["retry_count"] = 0 # ---- UI 占位 ---- progress_placeholder = st.empty() # 实时节点进度 streaming_placeholder = st.empty() # 流式文本 summary_placeholder = st.empty() # 总结卡片 # 初始状态提示 progress_placeholder.info("⏳ 正在分析您的需求...") executed_nodes: list[dict] = [] stream_text = "" stream_active = False final_state = None def _render_progress(nodes: list[dict]): """渲染实时节点进度到占位符。""" if not nodes: return lines = [] for i, node in enumerate(nodes): icon = "●" if i == len(nodes) - 1 else "✓" detail = f" — {node['detail']}" if node.get("detail") else "" lines.append(f"{icon} {node['label']}{detail}") progress_placeholder.markdown("\n\n".join(lines)) try: for event in st.session_state.graph.stream( agent_state, stream_mode=["updates", "custom"] ): mode, data = event if mode == "updates": for node_name, node_state in data.items(): label = NODE_LABELS.get(node_name, node_name) if node_name not in SKIP_NODES: executed_nodes.append({ "name": node_name, "label": label, }) if node_name == "classify_intent": intent = node_state.get("intent", "") il = INTENT_LABELS.get(intent, intent) executed_nodes[-1]["detail"] = f"意图: {il}" elif node_name == "retrieve": ctx = node_state.get("retrieved_context", "") executed_nodes[-1]["detail"] = ( f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板" ) elif node_name in ("generate", "modify_jrxml", "correct_jrxml"): jrxml = node_state.get("current_jrxml", "") executed_nodes[-1]["detail"] = f"生成 {len(jrxml)} 字符 JRXML" elif node_name == "validate": status = node_state.get("status", "") if status == "pass": executed_nodes[-1]["detail"] = "验证通过 ✓" else: err = node_state.get("error_msg", "") executed_nodes[-1]["detail"] = f"验证失败: {err[:80]}" elif node_name == "explain_error": expl = node_state.get("natural_explanation", "") executed_nodes[-1]["detail"] = expl[:120] elif node_name == "handle_consult": ans = node_state.get("consult_answer", "") executed_nodes[-1]["detail"] = ans[:150] final_state = node_state # 每个节点完成后立即更新进度 _render_progress(executed_nodes) elif mode == "custom": cd = data if cd.get("type") == "stream": stream_text += cd.get("text", "") stream_active = True streaming_placeholder.code(stream_text, language="xml") except Exception as e: progress_placeholder.empty() _app_log.error( f"代理执行异常: {e}", extra={"session_id": session_id, "error": str(e)}, ) st.error(f"工作流异常: {e}") return # ---- 清理临时占位 ---- progress_placeholder.empty() if stream_active: streaming_placeholder.empty() # 清理已处理的临时文件 for p in st.session_state.get("uploaded_temp_paths", []): try: Path(p).unlink(missing_ok=True) except Exception: pass st.session_state.uploaded_temp_paths = [] # ---- 总结卡片 ---- # 注:node_state 只含变更字段,用 agent_state(被所有节点就地修改)获取完整状态 final_state = agent_state if final_state: st.session_state.agent_state = final_state intent = final_state.get("intent", "") status = final_state.get("status", "") with summary_placeholder.container(border=True): if intent == "consult_question": answer = final_state.get("consult_answer", "") st.info(answer) st.session_state.messages.append({ "role": "assistant", "content": answer, "type": "consult", }) elif intent in ("undo_modification", "reset_session"): st.success("操作已完成") elif intent in ("preview_report", "export_pdf", "export_jrxml"): jrxml = final_state.get("current_jrxml", "") if jrxml: st.success("✅ 当前报表") _render_jrxml(jrxml) st.session_state.messages.append({ "role": "assistant", "content": jrxml, "type": "jrxml", }) else: st.warning("⚠ 当前没有报表可以展示。") elif status == "pass": jrxml = final_state.get("current_jrxml", "") st.success("✅ JRXML 生成成功") st.markdown("**生成结果:**") _render_jrxml(jrxml) st.caption("您可以从侧边栏下载文件,或继续对话进行修改。") st.session_state.messages.append({ "role": "assistant", "content": jrxml, "type": "jrxml", }) st.session_state.messages.append({ "role": "assistant", "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。", "type": "success", }) else: jrxml = final_state.get("current_jrxml", "") error_msg = final_state.get("error_msg", "未知错误") explanation = final_state.get("natural_explanation", "") retries = final_state.get("retry_count", 0) st.error(f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML") st.markdown(f"**错误:** {error_msg}") if explanation: st.markdown(f"**原因:** {explanation}") if jrxml: with st.expander("查看当前 JRXML"): _render_jrxml(jrxml, max_lines=80) st.caption("💡 下次输入修改需求时,系统会自动加载失败上下文继续修复。") st.session_state.messages.append({ "role": "assistant", "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。", "type": "error_explanation", }) # OCR 字段提取结果展示 ocr_result = agent_state.get("ocr_extraction_result", {}) if ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields"): with st.expander("🔍 OCR 单据字段提取结果", expanded=False): fields = ocr_result.get("fields", []) non_empty = [f for f in fields if f.get("field_value")] empty = [f for f in fields if not f.get("field_value")] if non_empty: st.markdown("**已提取字段:**") for f in non_empty: method = f.get("extraction_method", "") conf = f.get("confidence", 0) st.markdown( f"- **{f['field_name']}**: `{f['field_value']}` " f"(置信度: {conf:.0%}, 方法: {method})" ) if empty: st.caption( f"未提取到值的字段: {', '.join(f['field_name'] for f in empty)}" ) st.caption( f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素" ) else: st.error("未产生结果,请重试。") _app_log.info( "代理执行完成", extra={ "session_id": session_id, "intent": final_state.get("intent", ""), "status": final_state.get("status", ""), "jrxml_length": len(final_state.get("current_jrxml", "")), "retry_count": final_state.get("retry_count", 0), }, ) # ---- 侧边栏 ---- with st.sidebar: st.title("📊 JRXML 代理") st.markdown("通过自然语言生成 JasperReports 模板。") st.divider() # 会话管理 st.markdown("### 会话管理") sessions = list_all_sessions() session_options = {} for s in sessions: sid = s["session_id"] name = s.get("session_name", sid) updated = s.get("updated_at", "")[:16] session_options[f"{name} ({updated})"] = sid selected_label = None for label, sid in session_options.items(): if sid == current_session_id: selected_label = label break selected = st.selectbox( "切换会话", options=list(session_options.keys()), index=list(session_options.keys()).index(selected_label) if selected_label else 0, key="session_selector", ) if selected and session_options.get(selected) != current_session_id: new_sid = session_options[selected] data = load_session(new_sid) if data and data.get("agent_state"): _app_log.info( "切换会话", extra={"from_session": current_session_id, "to_session": new_sid}, ) st.session_state.agent_state = data["agent_state"] st.session_state.messages = [] st.rerun() col1, col2 = st.columns(2) with col1: if st.button("➕ 新建", use_container_width=True): new_data = create_session(name="", agent_state=create_initial_state()) _app_log.info( "新建会话", extra={"session_id": new_data["session_id"]}, ) st.session_state.agent_state = create_initial_state() st.session_state.agent_state["session_id"] = new_data["session_id"] st.session_state.agent_state["session_name"] = new_data["session_name"] st.session_state.agent_state["created_at"] = new_data["created_at"] st.session_state.messages = [] st.rerun() with col2: if st.button("🗑 删除", use_container_width=True): if current_session_id: _app_log.info( "删除会话", extra={"session_id": current_session_id}, ) delete_session(current_session_id) st.session_state.agent_state = create_initial_state() new_data = create_session(name="", agent_state=st.session_state.agent_state) st.session_state.agent_state["session_id"] = new_data["session_id"] st.session_state.agent_state["session_name"] = new_data["session_name"] st.session_state.agent_state["created_at"] = new_data["created_at"] st.session_state.messages = [] st.rerun() current_name = st.session_state.agent_state.get("session_name", "") st.caption(f"当前: {current_name} (`{current_session_id}`)") st.divider() st.markdown("### 快捷操作") has_jrxml = bool(st.session_state.agent_state.get("current_jrxml", "").strip()) has_history = bool(st.session_state.agent_state.get("history_states", [])) qcol1, qcol2 = st.columns(2) with qcol1: if st.button("👁 预览", use_container_width=True, disabled=not has_jrxml): with st.spinner("正在准备预览..."): run_agent("预览报表") st.rerun() with qcol2: if st.button("↩ 撤销", use_container_width=True, disabled=not has_history): with st.spinner("正在撤销..."): run_agent("撤销上一步修改") st.rerun() if st.button("🔄 重置会话", use_container_width=True): with st.spinner("正在重置..."): run_agent("重新来,清空当前报表") st.rerun() st.divider() st.markdown("### 上传文件") st.caption("支持图片 (OCR)、PDF、Word、文本文件。内容将附加到您的下一条消息中。") if "uploaded_files" not in st.session_state: st.session_state.uploaded_files = [] # [{name, text, type}] if "uploaded_temp_paths" not in st.session_state: st.session_state.uploaded_temp_paths = [] # 待清理的临时文件路径 uploaded = st.file_uploader( "选择文件", type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"], accept_multiple_files=True, key="file_uploader", label_visibility="collapsed", ) if uploaded: for uf in uploaded: # 去重 if any(f["name"] == uf.name for f in st.session_state.uploaded_files): continue import tempfile from backend.file_parser import parse_file from backend.layout_analyzer import analyze_layout suffix = Path(uf.name).suffix.lower() with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: tmp.write(uf.getvalue()) tmp_path = tmp.name result = parse_file(tmp_path, suffix) # 对图片/PDF 进行 A4 模板布局分析 parsed_text = result["text"] parsed_type = result["file_type"] if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".pdf"): layout = analyze_layout(tmp_path) tt = layout.get("template_type", "unknown") current_jrxml = st.session_state.agent_state.get("current_jrxml", "") if tt == "full_a4": parsed_text = layout["description"] parsed_type = "a4_template" elif tt == "partial_rows": parsed_type = "a4_partial" if current_jrxml.strip(): # 修改模式:尝试行匹配 from backend.layout_analyzer import match_rows_to_jrxml match = match_rows_to_jrxml(layout, current_jrxml) parsed_text = ( f"[行片段修改] 上传图片包含 {layout['total_rows']} 行," f"视为 A4 报表的一部分。\n\n" f"{match['description']}\n\n" f"--- 行结构 ---\n{layout['description']}" ) else: # 新建模式:按 A4 模板处理 parsed_text = layout["description"] else: # tt == "unknown": OCR 不可用或未检测到文字元素 has_ocr = result.get("method") not in ("metadata_only", None) img_w, img_h = layout["image_size"] ratio = layout["aspect_ratio"] if has_ocr: parsed_text = ( f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。" f"未检测到 A4 报表结构,图片将被视为参考样式。\n" f"请根据用户的文字描述生成报表。" ) else: parsed_text = ( f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。\n" f"⚠ OCR 引擎未安装,无法识别图片中的文字内容。\n" f"请严格根据用户的文字描述来推断图片中的报表需求。\n" f"(提示:如需图片文字识别,请运行 pip install paddleocr)" ) parsed_type = "image_reference" if parsed_text: st.session_state.uploaded_files.append({ "name": uf.name, "text": parsed_text, "type": parsed_type, }) # 对图片类型,保存路径以便 OCR 字段提取(延迟到 process_input 阶段) img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp") if suffix in img_suffixes and result.get("method") not in ("metadata_only", None): st.session_state.agent_state["uploaded_file_path"] = tmp_path st.session_state.uploaded_temp_paths.append(tmp_path) else: Path(tmp_path).unlink(missing_ok=True) if st.session_state.uploaded_files: for i, f in enumerate(st.session_state.uploaded_files): cols = st.columns([5, 1]) with cols[0]: st.caption(f"📎 {f['name']} ({f['type']}, {len(f['text'])} 字符)") with cols[1]: if st.button("✕", key=f"rm_uf_{i}", help="移除"): st.session_state.uploaded_files.pop(i) st.rerun() st.divider() st.markdown("### 配置") llm_backend = os.getenv("LLM_BACKEND", "cloud") llm_model = os.getenv("LLM_MODEL", os.getenv("LOCAL_LLM_MODEL", "gpt-4o")) st.caption(f"大语言模型: {llm_backend} / {llm_model}") st.caption(f"最大重试次数: {os.getenv('MAX_RETRY', '3')}") st.caption(f"验证服务: {os.getenv('VALIDATION_SERVICE_URL', 'http://localhost:8001/validate')}") st.divider() st.markdown("### 下载") final = st.session_state.agent_state.get("final_jrxml", "") versions = st.session_state.agent_state.get("jrxml_versions", []) if final: st.download_button( label="📥 下载最新 JRXML", data=final, file_name="report.jrxml", mime="application/xml", use_container_width=True, ) if versions: with st.expander("📋 历史版本", expanded=False): for i, v in enumerate(reversed(versions)): ts = v.get("ts", "")[:16] label = v.get("label", "版本") status = v.get("status", "") icon = "✅" if status == "pass" else "❌" dl_label = f"{icon} v{len(versions)-i} — {label} ({ts})" st.download_button( label=dl_label, data=v.get("jrxml", ""), file_name=f"report_v{len(versions)-i}.jrxml", mime="application/xml", use_container_width=True, key=f"dl_v{i}", ) # ---- 标题 ---- st.title("📝 JRXML 报表生成器") st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。") # ---- 聊天历史 ---- for msg in st.session_state.messages: with st.chat_message(msg["role"]): if msg.get("type") == "jrxml": with st.expander("查看生成的 JRXML", expanded=False): st.code(msg["content"], language="xml") elif msg.get("type") == "error_explanation": st.warning(msg["content"]) elif msg.get("type") == "success": st.success(msg["content"]) elif msg.get("type") == "consult": st.info(msg["content"]) else: st.markdown(msg["content"]) # ---- 聊天输入 ---- if prompt := st.chat_input("描述您的报表需求..."): # 拼接上传文件的文本 uploaded_texts = [] uploaded_files_info = [] if st.session_state.get("uploaded_files"): for f in st.session_state.uploaded_files: uploaded_texts.append(f"[上传文件: {f['name']}]\n{f['text']}") uploaded_files_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])}) if uploaded_texts: full_prompt = "\n\n".join(uploaded_texts) + "\n\n---\n用户需求:\n" + prompt st.session_state.uploaded_files = [] # 用后即清 else: full_prompt = prompt _app_log.info( "收到用户输入", extra={ "session_id": current_session_id, "prompt_preview": prompt[:200], "prompt_length": len(prompt), "has_uploaded_files": bool(uploaded_files_info), "uploaded_files": uploaded_files_info, }, ) st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) run_agent(full_prompt) st.rerun()