"""JRXML 代理集成测试 - 4 个验收场景。 这些测试模拟多轮对话并验证代理管道。 需要验证服务在 8001 端口上运行。 """ import os import sys import pytest sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dotenv import load_dotenv load_dotenv() from agent.graph import build_graph, create_initial_state @pytest.fixture def graph(): return build_graph() def run_graph(graph, initial_state): """使用给定的初始状态运行图并返回最终状态。""" final = None for event in graph.stream(initial_state): for node_name, node_state in event.items(): final = node_state return final class TestAcceptanceScenarios: def test_scenario1_simple_report_generation(self, graph): """场景 1:生成简单的员工名册 - 应该通过验证。""" state = create_initial_state() state["user_input"] = ( "Generate an employee roster report with columns: employee_id (Integer), " "full_name (String), department (String), and salary (BigDecimal). " "Query from employees table. Include a title 'Employee Roster'." ) state["stage"] = "initial_generation" final = run_graph(graph, state) assert final.get("current_jrxml"), "应该已生成 JRXML" # 注意:通过/失败取决于 LLM 输出质量;我们检查是否得到了结果 print(f"场景 1 状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}") def test_scenario2_auto_correction(self, graph): """场景 2:故意提出一个可能初次失败的需求。""" state = create_initial_state() state["user_input"] = ( "Create a sales summary report. Show customer_name and order_total. " "Add a subtotal by customer group. Query from orders table." ) state["stage"] = "initial_generation" final = run_graph(graph, state) assert final.get("retry_count", 0) <= 3, "不应超过最大重试次数" print(f"场景 2 状态: {final.get('status')}, 重试次数: {final.get('retry_count', 0)}") def test_scenario3_multi_turn_modification(self, graph): """场景 3:多轮对话 - 先生成,再修改两次。""" # 第 1 轮:生成销售订单报表 state = create_initial_state() state["user_input"] = ( "Create a sales order report with order_id (String), customer (String), " "amount (BigDecimal), order_date (Date). Query from sales_orders." ) state["stage"] = "initial_generation" final = run_graph(graph, state) print(f"第 1 轮状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}") assert final.get("current_jrxml"), "第 1 轮应该已生成 JRXML" # 第 2 轮:添加月度销售汇总 state2 = final.copy() state2["user_input"] = "Add a monthly sales total summary in the summary band." state2["user_modification_request"] = "Add a monthly sales total summary in the summary band." state2["stage"] = "modification" state2["retry_count"] = 0 final2 = run_graph(graph, state2) print(f"第 2 轮状态: {final2.get('status')}") assert final2.get("current_jrxml"), "第 2 轮应该已修改 JRXML" # 第 3 轮:修改标题 state3 = final2.copy() state3["user_input"] = "Change the title to '2024 Annual Sales Report' and make it bold." state3["user_modification_request"] = "Change the title to '2024 Annual Sales Report' and make it bold." state3["stage"] = "modification" state3["retry_count"] = 0 final3 = run_graph(graph, state3) print(f"第 3 轮状态: {final3.get('status')}") jrxml = final3.get("current_jrxml", "") assert "2024" in jrxml or "Annual" in jrxml, "标题修改应该体现在 JRXML 中" def test_scenario4_context_aware_modification(self, graph): """场景 4:基于对话上下文的修改。""" # 第 1 轮:生成按客户分组的报表 state = create_initial_state() state["user_input"] = ( "Create a customer orders report grouped by customer_name with order_id, " "order_total fields. Include a subtotal for each customer group. " "Query from customer_orders." ) state["stage"] = "initial_generation" final = run_graph(graph, state) print(f"第 1 轮状态: {final.get('status')}") # 第 2 轮:上下文感知修改 state2 = final.copy() state2["user_input"] = "Make the subtotal row font larger and bold." state2["user_modification_request"] = "Make the subtotal row font larger and bold." state2["stage"] = "modification" state2["retry_count"] = 0 final2 = run_graph(graph, state2) print(f"第 2 轮状态: {final2.get('status')}") jrxml = final2.get("current_jrxml", "") assert "isBold" in jrxml or "size=" in jrxml, "字体修改应该体现在结果中" def test_max_retry_handling(self, graph): """测试在 MAX_RETRY 次失败后,图能否正常终止。""" state = create_initial_state() state["current_jrxml"] = "xml<<<" state["user_input"] = "Fix this" state["retry_count"] = 3 # 已达到最大重试次数 state["status"] = "fail" final = run_graph(graph, state) assert final.get("retry_count", 0) >= 3 or final.get("status") == "pass"