diff --git a/app.py b/app.py new file mode 100644 index 0000000..8fda4a6 --- /dev/null +++ b/app.py @@ -0,0 +1,318 @@ +"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。""" + +import os +import sys + +import streamlit as st + +from dotenv import load_dotenv +load_dotenv() + +from agent.graph import build_graph, create_initial_state +from backend.session import ( + create_session, + load_session, + delete_session, + list_all_sessions, +) + +st.set_page_config( + page_title="JRXML 代理", + page_icon="📊", + layout="wide", + initial_sidebar_state="expanded", +) + +# ---- URL 参数:session_id ---- +query_params = st.query_params +url_session_id = query_params.get("session_id", "") + +# ---- 会话状态初始化 ---- +if "messages" not in st.session_state: + st.session_state.messages = [] +if "graph" not in st.session_state: + st.session_state.graph = build_graph() +if "pending_action" not in st.session_state: + st.session_state.pending_action = None + +# 确定活跃的 session_id +if "agent_state" not in st.session_state: + if url_session_id: + data = load_session(url_session_id) + if data and data.get("agent_state"): + st.session_state.agent_state = data["agent_state"] + st.session_state.agent_state["session_id"] = url_session_id + else: + st.session_state.agent_state = create_initial_state() + new_data = create_session(name="", agent_state=st.session_state.agent_state) + st.session_state.agent_state["session_id"] = new_data["session_id"] + st.session_state.agent_state["session_name"] = new_data["session_name"] + st.session_state.agent_state["created_at"] = new_data["created_at"] + else: + st.session_state.agent_state = create_initial_state() + new_data = create_session(name="", agent_state=st.session_state.agent_state) + st.session_state.agent_state["session_id"] = new_data["session_id"] + st.session_state.agent_state["session_name"] = new_data["session_name"] + st.session_state.agent_state["created_at"] = new_data["created_at"] + +current_session_id = st.session_state.agent_state.get("session_id", "") + + +def run_agent(user_input: str) -> dict: + """运行代理图,返回最终状态。""" + agent_state = st.session_state.agent_state + + if agent_state.get("current_jrxml") and agent_state.get("status") == "pass": + agent_state["user_modification_request"] = user_input + + agent_state["user_input"] = user_input + agent_state["retry_count"] = 0 + + final_state = None + with st.chat_message("assistant"): + status_placeholder = st.empty() + jrxml_placeholder = st.empty() + + for event in st.session_state.graph.stream(agent_state): + for node_name, node_state in event.items(): + final_state = node_state + if node_name == "classify_intent": + intent = node_state.get("intent", "") + intent_labels = { + "initial_generation": "🆕 识别为新建报表请求", + "modify_report": "✏️ 识别为修改报表请求", + "preview_report": "👁 识别为预览请求", + "export_pdf": "📄 识别为导出PDF请求", + "export_jrxml": "📥 识别为导出JRXML请求", + "undo_modification": "↩ 识别为撤销请求", + "consult_question": "💬 识别为咨询问题", + "reset_session": "🔄 识别为重置会话请求", + } + label = intent_labels.get(intent, f"🔍 意图: {intent}") + status_placeholder.info(label) + elif node_name == "generate": + status_placeholder.info("🔧 正在生成 JRXML...") + elif node_name == "modify_jrxml": + status_placeholder.info("🔧 正在根据您的请求修改 JRXML...") + elif node_name == "validate": + if node_state.get("status") == "pass": + status_placeholder.success("✅ 验证通过!") + else: + status_placeholder.warning("⚠ 验证失败,正在分析错误...") + elif node_name == "explain_error": + explanation = node_state.get("natural_explanation", "") + status_placeholder.warning(f"🔍 {explanation}") + elif node_name == "correct_jrxml": + status_placeholder.info(f"🛠 正在自动修正(尝试 {node_state.get('retry_count', 1)})...") + elif node_name == "handle_consult": + pass + elif node_name == "handle_undo": + status_placeholder.info("↩ 已撤销上一步修改") + elif node_name == "handle_reset": + status_placeholder.info("🔄 会话已重置") + elif node_name == "manage_context": + pass + elif node_name == "save_state_snapshot": + pass + elif node_name == "save_session": + pass + elif node_name == "finalize": + pass + + if final_state: + st.session_state.agent_state = final_state + intent = final_state.get("intent", "") + + if intent == "consult_question": + answer = final_state.get("consult_answer", "") + st.session_state.messages.append({ + "role": "assistant", + "content": answer, + "type": "consult", + }) + status_placeholder.empty() + st.markdown(answer) + elif intent in ("undo_modification", "reset_session"): + # 消息已在节点中添加,不需要额外输出 + status_placeholder.empty() + jrxml_placeholder.empty() + elif intent in ("preview_report", "export_pdf", "export_jrxml"): + jrxml = final_state.get("current_jrxml", "") + if jrxml: + st.session_state.messages.append({ + "role": "assistant", + "content": jrxml, + "type": "jrxml", + }) + status_placeholder.success("✅ 当前报表") + jrxml_placeholder.code(jrxml, language="xml") + else: + status_placeholder.warning("⚠ 当前没有报表可以预览或导出。") + jrxml_placeholder.empty() + elif final_state.get("status") == "pass": + st.session_state.messages.append({ + "role": "assistant", + "content": final_state.get("current_jrxml", ""), + "type": "jrxml", + }) + st.session_state.messages.append({ + "role": "assistant", + "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。", + "type": "success", + }) + status_placeholder.success("✅ JRXML 验证通过!") + jrxml_placeholder.code(final_state.get("current_jrxml", ""), language="xml") + else: + error_msg = final_state.get("error_msg", "未知错误") + explanation = final_state.get("natural_explanation", "") + retries = final_state.get("retry_count", 0) + st.session_state.messages.append({ + "role": "assistant", + "content": final_state.get("current_jrxml", ""), + "type": "jrxml", + }) + st.session_state.messages.append({ + "role": "assistant", + "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n**解释:** {explanation}\n\n请重新描述您的需求或简化报表结构。", + "type": "error_explanation", + }) + status_placeholder.error(f"❌ 经过 {retries} 次重试后验证失败") + jrxml_placeholder.text("") + else: + st.error("未产生结果,请重试。") + + return final_state + + +# ---- 侧边栏 ---- +with st.sidebar: + st.title("📊 JRXML 代理") + st.markdown("通过自然语言生成 JasperReports 模板。") + st.divider() + + # 会话管理 + st.markdown("### 会话管理") + + sessions = list_all_sessions() + session_options = {} + for s in sessions: + sid = s["session_id"] + name = s.get("session_name", sid) + updated = s.get("updated_at", "")[:16] + session_options[f"{name} ({updated})"] = sid + + selected_label = None + for label, sid in session_options.items(): + if sid == current_session_id: + selected_label = label + break + + selected = st.selectbox( + "切换会话", + options=list(session_options.keys()), + index=list(session_options.keys()).index(selected_label) if selected_label else 0, + key="session_selector", + ) + + if selected and session_options.get(selected) != current_session_id: + new_sid = session_options[selected] + data = load_session(new_sid) + if data and data.get("agent_state"): + st.session_state.agent_state = data["agent_state"] + st.session_state.messages = [] + st.rerun() + + col1, col2 = st.columns(2) + with col1: + if st.button("➕ 新建", use_container_width=True): + new_data = create_session(name="", agent_state=create_initial_state()) + st.session_state.agent_state = create_initial_state() + st.session_state.agent_state["session_id"] = new_data["session_id"] + st.session_state.agent_state["session_name"] = new_data["session_name"] + st.session_state.agent_state["created_at"] = new_data["created_at"] + st.session_state.messages = [] + st.rerun() + with col2: + if st.button("🗑 删除", use_container_width=True): + if current_session_id: + delete_session(current_session_id) + st.session_state.agent_state = create_initial_state() + new_data = create_session(name="", agent_state=st.session_state.agent_state) + st.session_state.agent_state["session_id"] = new_data["session_id"] + st.session_state.agent_state["session_name"] = new_data["session_name"] + st.session_state.agent_state["created_at"] = new_data["created_at"] + st.session_state.messages = [] + st.rerun() + + current_name = st.session_state.agent_state.get("session_name", "") + st.caption(f"当前: {current_name} (`{current_session_id}`)") + + st.divider() + st.markdown("### 快捷操作") + + has_jrxml = bool(st.session_state.agent_state.get("current_jrxml", "").strip()) + has_history = bool(st.session_state.agent_state.get("history_states", [])) + + qcol1, qcol2 = st.columns(2) + with qcol1: + if st.button("👁 预览", use_container_width=True, disabled=not has_jrxml): + with st.spinner("正在准备预览..."): + run_agent("预览报表") + st.rerun() + with qcol2: + if st.button("↩ 撤销", use_container_width=True, disabled=not has_history): + with st.spinner("正在撤销..."): + run_agent("撤销上一步修改") + st.rerun() + + if st.button("🔄 重置会话", use_container_width=True): + with st.spinner("正在重置..."): + run_agent("重新来,清空当前报表") + st.rerun() + + st.divider() + st.markdown("### 配置") + llm_backend = os.getenv("LLM_BACKEND", "cloud") + llm_model = os.getenv("LLM_MODEL", os.getenv("LOCAL_LLM_MODEL", "gpt-4o")) + st.caption(f"大语言模型: {llm_backend} / {llm_model}") + st.caption(f"最大重试次数: {os.getenv('MAX_RETRY', '3')}") + st.caption(f"验证服务: {os.getenv('VALIDATION_SERVICE_URL', 'http://localhost:8001/validate')}") + + st.divider() + st.markdown("### 下载") + final = st.session_state.agent_state.get("final_jrxml", "") + if final: + st.download_button( + label="📥 下载 JRXML", + data=final, + file_name="report.jrxml", + mime="application/xml", + use_container_width=True, + ) + +# ---- 标题 ---- +st.title("📝 JRXML 报表生成器") +st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。") + +# ---- 聊天历史 ---- +for msg in st.session_state.messages: + with st.chat_message(msg["role"]): + if msg["role"] == "assistant" and msg.get("type") == "jrxml": + with st.expander("查看生成的 JRXML", expanded=False): + st.code(msg["content"], language="xml") + elif msg["role"] == "assistant" and msg.get("type") == "error_explanation": + st.warning(msg["content"]) + elif msg["role"] == "assistant" and msg.get("type") == "success": + st.success(msg["content"]) + elif msg["role"] == "assistant" and msg.get("type") == "consult": + st.info(msg["content"]) + else: + st.markdown(msg["content"]) + +# ---- 聊天输入 ---- +if prompt := st.chat_input("描述您的报表需求..."): + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.markdown(prompt) + run_agent(prompt) + st.rerun() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 0000000..61a85c8 --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,135 @@ +"""JRXML 代理集成测试 - 4 个验收场景。 + +这些测试模拟多轮对话并验证代理管道。 +需要验证服务在 8001 端口上运行。 +""" + +import os +import sys +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dotenv import load_dotenv +load_dotenv() + +from agent.graph import build_graph, create_initial_state + + +@pytest.fixture +def graph(): + return build_graph() + + +def run_graph(graph, initial_state): + """使用给定的初始状态运行图并返回最终状态。""" + final = None + for event in graph.stream(initial_state): + for node_name, node_state in event.items(): + final = node_state + return final + + +class TestAcceptanceScenarios: + + def test_scenario1_simple_report_generation(self, graph): + """场景 1:生成简单的员工名册 - 应该通过验证。""" + state = create_initial_state() + state["user_input"] = ( + "Generate an employee roster report with columns: employee_id (Integer), " + "full_name (String), department (String), and salary (BigDecimal). " + "Query from employees table. Include a title 'Employee Roster'." + ) + state["stage"] = "initial_generation" + + final = run_graph(graph, state) + assert final.get("current_jrxml"), "应该已生成 JRXML" + # 注意:通过/失败取决于 LLM 输出质量;我们检查是否得到了结果 + print(f"场景 1 状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}") + + def test_scenario2_auto_correction(self, graph): + """场景 2:故意提出一个可能初次失败的需求。""" + state = create_initial_state() + state["user_input"] = ( + "Create a sales summary report. Show customer_name and order_total. " + "Add a subtotal by customer group. Query from orders table." + ) + state["stage"] = "initial_generation" + + final = run_graph(graph, state) + assert final.get("retry_count", 0) <= 3, "不应超过最大重试次数" + print(f"场景 2 状态: {final.get('status')}, 重试次数: {final.get('retry_count', 0)}") + + def test_scenario3_multi_turn_modification(self, graph): + """场景 3:多轮对话 - 先生成,再修改两次。""" + # 第 1 轮:生成销售订单报表 + state = create_initial_state() + state["user_input"] = ( + "Create a sales order report with order_id (String), customer (String), " + "amount (BigDecimal), order_date (Date). Query from sales_orders." + ) + state["stage"] = "initial_generation" + + final = run_graph(graph, state) + print(f"第 1 轮状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}") + assert final.get("current_jrxml"), "第 1 轮应该已生成 JRXML" + + # 第 2 轮:添加月度销售汇总 + state2 = final.copy() + state2["user_input"] = "Add a monthly sales total summary in the summary band." + state2["user_modification_request"] = "Add a monthly sales total summary in the summary band." + state2["stage"] = "modification" + state2["retry_count"] = 0 + + final2 = run_graph(graph, state2) + print(f"第 2 轮状态: {final2.get('status')}") + assert final2.get("current_jrxml"), "第 2 轮应该已修改 JRXML" + + # 第 3 轮:修改标题 + state3 = final2.copy() + state3["user_input"] = "Change the title to '2024 Annual Sales Report' and make it bold." + state3["user_modification_request"] = "Change the title to '2024 Annual Sales Report' and make it bold." + state3["stage"] = "modification" + state3["retry_count"] = 0 + + final3 = run_graph(graph, state3) + print(f"第 3 轮状态: {final3.get('status')}") + jrxml = final3.get("current_jrxml", "") + assert "2024" in jrxml or "Annual" in jrxml, "标题修改应该体现在 JRXML 中" + + def test_scenario4_context_aware_modification(self, graph): + """场景 4:基于对话上下文的修改。""" + # 第 1 轮:生成按客户分组的报表 + state = create_initial_state() + state["user_input"] = ( + "Create a customer orders report grouped by customer_name with order_id, " + "order_total fields. Include a subtotal for each customer group. " + "Query from customer_orders." + ) + state["stage"] = "initial_generation" + + final = run_graph(graph, state) + print(f"第 1 轮状态: {final.get('status')}") + + # 第 2 轮:上下文感知修改 + state2 = final.copy() + state2["user_input"] = "Make the subtotal row font larger and bold." + state2["user_modification_request"] = "Make the subtotal row font larger and bold." + state2["stage"] = "modification" + state2["retry_count"] = 0 + + final2 = run_graph(graph, state2) + print(f"第 2 轮状态: {final2.get('status')}") + jrxml = final2.get("current_jrxml", "") + assert "isBold" in jrxml or "size=" in jrxml, "字体修改应该体现在结果中" + + def test_max_retry_handling(self, graph): + """测试在 MAX_RETRY 次失败后,图能否正常终止。""" + state = create_initial_state() + state["current_jrxml"] = "xml<<<" + state["user_input"] = "Fix this" + state["retry_count"] = 3 # 已达到最大重试次数 + state["status"] = "fail" + + final = run_graph(graph, state) + assert final.get("retry_count", 0) >= 3 or final.get("status") == "pass" diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..24cb8ab --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,105 @@ +"""JRXML 验证服务的单元测试。""" + +import pytest +from fastapi.testclient import TestClient + +from validation_service.main import app + +client = TestClient(app) + + +class TestValidationService: + def test_health_endpoint(self): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + + def test_empty_jrxml(self): + resp = client.post("/validate", json={"jrxml": ""}) + assert resp.status_code == 200 + assert resp.json()["valid"] is False + assert "空" in resp.json()["error"] + + def test_invalid_xml(self): + resp = client.post("/validate", json={"jrxml": "xml<<<"}) + assert resp.status_code == 200 + data = resp.json() + assert data["valid"] is False + + def test_missing_page_dimensions(self): + jrxml = """ + + + + <band height="30"/> + + + + + + + + +""" + resp = client.post("/validate", json={"jrxml": jrxml}) + assert resp.status_code == 200 + data = resp.json() + assert data["valid"] is False + assert "pageWidth" in data["error"] + + def test_valid_jrxml(self): + jrxml = """ + + + + + <band height="30"> + <staticText> + <reportElement x="0" y="0" width="555" height="30"/> + <text><![CDATA[Report Title]]></text> + </staticText> + </band> + + + + + + + + + + + + +""" + resp = client.post("/validate", json={"jrxml": jrxml}) + assert resp.status_code == 200 + data = resp.json() + assert data["valid"] is True, f"验证应该通过,实际错误: {data.get('error')}" + + def test_missing_field_declaration(self): + jrxml = """ + + + + + + + + + + + +""" + resp = client.post("/validate", json={"jrxml": jrxml}) + assert resp.status_code == 200 + data = resp.json() + assert data["valid"] is False + assert "missing_field" in data["error"]