"""Tests for datasource.py — datasource mode resolution, JDBC intent detection, and context building."""

import sys
from pathlib import Path
from unittest.mock import MagicMock

import pytest

# Make the project root importable so `agent.datasource` resolves when the
# tests are run from this directory.
sys.path.insert(0, str(Path(__file__).parent.parent))

from agent.datasource import (
    resolve_datasource_mode,
    _detect_jdbc_intent,
    build_datasource_context,
    configure_jdbc,
    ask_db_config,
)


def _make_state(**overrides):
    """Build a minimal state dict for the functions under test.

    Every default key is set to an empty value ("" / [] / 0). Keyword
    arguments override those defaults or add extra keys such as
    ``datasource_mode`` and ``db_config``.
    """
    return {
        "user_input": "",
        "conversation_history": [],
        "current_jrxml": "",
        "status": "",
        "error_msg": "",
        "natural_explanation": "",
        "retry_count": 0,
        "user_modification_request": "",
        "final_jrxml": "",
        "stage": "",
        "retrieved_context": "",
        **overrides,
    }


# ── JDBC intent detection ───────────────────────────────────────


class TestDetectJdbcIntent:
    """_detect_jdbc_intent: recognise requests that ask to query a database directly."""

    def test_direct_connect_keywords(self):
        # "I want to query via a direct database connection"
        assert _detect_jdbc_intent("我想从数据库直连查询") is True
        # "Connect directly to the database to fetch data"
        assert _detect_jdbc_intent("直连数据库获取数据") is True

    def test_db_name_mentions(self):
        # Mentions of concrete database products: MySQL, PostgreSQL, Oracle.
        assert _detect_jdbc_intent("从MySQL数据库查询用户表") is True
        assert _detect_jdbc_intent("在PostgreSQL中执行查询") is True
        assert _detect_jdbc_intent("从Oracle读取数据") is True

    def test_jdbc_explicit_mention(self):
        # "Connect using JDBC"
        assert _detect_jdbc_intent("使用JDBC连接") is True

    def test_sql_keywords(self):
        assert _detect_jdbc_intent("SELECT * FROM users") is True
        # "Query the user table from the database"
        assert _detect_jdbc_intent("从数据库查询用户表") is True
        # "First query the database" (note the embedded space)
        assert _detect_jdbc_intent("先查询 数据库") is True

    def test_normal_request_is_not_jdbc(self):
        # "Generate an employee report" / "Change the title to XX Company"
        assert _detect_jdbc_intent("生成一个员工报表") is False
        assert _detect_jdbc_intent("修改标题为XX公司") is False

    def test_empty_input(self):
        assert _detect_jdbc_intent("") is False


# ── Mode resolution ─────────────────────────────────────────────


class TestResolveDatasourceMode:
    """resolve_datasource_mode: an explicit valid mode in state wins, else infer from input."""

    def test_defaults_to_parameter_mode(self):
        # "Generate report" — no database intent.
        state = _make_state(user_input="生成报表")
        assert resolve_datasource_mode(state) == "parameter"

    def test_detects_jdbc_from_input(self):
        state = _make_state(user_input="从数据库直连查询")
        assert resolve_datasource_mode(state) == "jdbc"

    def test_respects_existing_mode_in_state(self):
        state = _make_state(datasource_mode="jdbc", user_input="生成报表")
        assert resolve_datasource_mode(state) == "jdbc"

    def test_existing_parameter_overrides_jdbc_input(self):
        # An explicit "parameter" mode beats JDBC-sounding input.
        state = _make_state(datasource_mode="parameter", user_input="从数据库直连")
        assert resolve_datasource_mode(state) == "parameter"

    def test_ignores_invalid_mode_in_state(self):
        # An unrecognised mode falls back to detection from the input.
        state = _make_state(datasource_mode="unknown", user_input="从数据库直连")
        assert resolve_datasource_mode(state) == "jdbc"


# ── Context building ────────────────────────────────────────────


class TestBuildDatasourceContext:
    """build_datasource_context: prompt context text for each datasource mode."""

    def test_parameter_mode_with_fields(self):
        fields = [
            {"name": "billNo", "description": "工单号", "type": "java.lang.String"},
            {"name": "amount", "description": "金额", "type": "java.math.BigDecimal"},
        ]
        ctx = build_datasource_context("parameter", fields)
        assert "[数据源模式: 参数]" in ctx  # "[Datasource mode: parameter]"
        assert "$P{xxx}" in ctx
        assert "billNo" in ctx
        assert "amount" in ctx

    def test_parameter_mode_without_fields(self):
        ctx = build_datasource_context("parameter", [])
        assert "[数据源模式: 参数]" in ctx
        assert "$P{xxx}" in ctx

    def test_jdbc_mode_with_config(self):
        db_config = {
            "url": "jdbc:mysql://localhost:3306/mydb",
            "driver": "com.mysql.cj.jdbc.Driver",
        }
        ctx = build_datasource_context("jdbc", [], db_config)
        assert "[数据源模式: JDBC]" in ctx
        assert "jdbc:mysql://" in ctx
        assert "CDATA" in ctx

    def test_jdbc_mode_without_config_shows_warning(self):
        ctx = build_datasource_context("jdbc", [])
        # "Database connection not yet configured"
        assert "尚未配置数据库连接" in ctx
        # NOTE(review): the other tests check "$P{xxx}"; this one checks the
        # looser "P{xxx}". Confirm whether dropping the "$" is intentional.
        assert "P{xxx}" in ctx


# ── JDBC configuration ──────────────────────────────────────────


class TestConfigureJdbc:
    """configure_jdbc: returns a state-update dict that switches to JDBC mode."""

    def test_configure_returns_update_dict(self):
        state = _make_state()
        update = configure_jdbc(
            state,
            url="jdbc:mysql://localhost/db",
            driver="com.mysql.cj.jdbc.Driver",
            username="root",
            password="pass",
        )
        assert update["datasource_mode"] == "jdbc"
        assert update["db_config"]["url"] == "jdbc:mysql://localhost/db"
        assert update["db_config"]["username"] == "root"

    def test_default_driver_is_mysql(self):
        # NOTE(review): the default driver is expected to be MySQL even for a
        # PostgreSQL URL. Confirm this is intended rather than inferring the
        # driver from the URL scheme.
        update = configure_jdbc(_make_state(), url="jdbc:postgresql://localhost/db")
        assert "mysql" in update["db_config"]["driver"]


# ── ask_db_config ───────────────────────────────────────────────


class TestAskDbConfig:
    """ask_db_config: returns a prompt only when JDBC mode lacks a usable config."""

    def test_returns_none_for_parameter_mode(self):
        state = _make_state(datasource_mode="parameter")
        assert ask_db_config(state) is None

    def test_returns_none_when_jdbc_configured(self):
        state = _make_state(
            datasource_mode="jdbc", db_config={"url": "jdbc:mysql://localhost/db"}
        )
        assert ask_db_config(state) is None

    def test_returns_prompt_when_jdbc_missing_config(self):
        state = _make_state(datasource_mode="jdbc")
        msg = ask_db_config(state)
        assert msg is not None
        assert "JDBC URL" in msg
        assert "用户名" in msg  # "username"
        assert "密码" in msg  # "password"

    def test_returns_prompt_when_db_config_empty(self):
        # An empty db_config counts as "not configured", so the user must
        # still be prompted (the old name claimed it returned None).
        state = _make_state(datasource_mode="jdbc", db_config={})
        msg = ask_db_config(state)
        assert msg is not None