"""api_server.py 集成测试 — REST 端点 + SSE 流 + 文件上传/下载。 使用 FastAPI TestClient,不需要启动真实服务器。 """ import io import json import os import sys import tempfile from pathlib import Path from unittest.mock import patch, MagicMock import pytest from fastapi.testclient import TestClient sys.path.insert(0, str(Path(__file__).parent.parent)) from api_server import app @pytest.fixture def client(): return TestClient(app) @pytest.fixture def temp_sessions(monkeypatch): """重定向上传目录到临时目录,隔离测试数据。""" with tempfile.TemporaryDirectory(prefix="test_api_") as tmpdir: monkeypatch.setattr("api_server.UPLOADS_DIR", Path(tmpdir) / "uploads") monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir) / "sessions") yield Path(tmpdir) # ── 健康检查 & 配置 ──────────────────────────────────────────── class TestHealthAndConfig: def test_health_returns_ok(self, client): resp = client.get("/api/health") assert resp.status_code == 200 data = resp.json() assert data["status"] == "ok" assert data["version"] == "5.0" assert "timestamp" in data def test_config_returns_env_keys(self, client): resp = client.get("/api/config") assert resp.status_code == 200 cfg = resp.json()["config"] for key in ("LLM_PROVIDER", "OCR_ENGINE", "MAX_RETRY"): assert key in cfg # ── 会话 CRUD ────────────────────────────────────────────────── class TestSessionCRUD: def test_create_session(self, client, temp_sessions): resp = client.post("/api/sessions") assert resp.status_code == 200 data = resp.json() assert len(data["session_id"]) == 32 assert "session_name" in data assert "created_at" in data def test_list_sessions_empty(self, client, temp_sessions): assert client.get("/api/sessions").json()["sessions"] == [] def test_list_sessions_populated(self, client, temp_sessions): client.post("/api/sessions") client.post("/api/sessions") assert len(client.get("/api/sessions").json()["sessions"]) == 2 def test_get_session_found(self, client, temp_sessions): created = client.post("/api/sessions").json() resp = client.get(f"/api/sessions/{created['session_id']}") assert resp.status_code == 200 assert resp.json()["session_id"] == created["session_id"] assert "agent_state" in resp.json() def test_get_session_invalid_id(self, client, temp_sessions): assert client.get("/api/sessions/nonexistent").status_code == 400 def test_get_session_not_found(self, client, temp_sessions): assert client.get("/api/sessions/aabbccddeeff0011223344").status_code == 404 def test_delete_session(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.delete(f"/api/sessions/{sid}") assert resp.status_code == 200 assert resp.json()["status"] == "deleted" assert client.get(f"/api/sessions/{sid}").status_code == 404 def test_delete_nonexistent(self, client, temp_sessions): assert client.delete("/api/sessions/aabbccddeeff0011223344").status_code == 404 def test_full_crud_lifecycle(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] assert client.get(f"/api/sessions/{sid}").status_code == 200 assert len(client.get("/api/sessions").json()["sessions"]) == 1 client.delete(f"/api/sessions/{sid}") assert client.get("/api/sessions").json()["sessions"] == [] # ── 文件上传 ─────────────────────────────────────────────────── class TestFileUpload: def test_upload_text_file(self, client, temp_sessions): content = b"Hello, JRXML!" resp = client.post( "/api/upload", files={"file": ("test.txt", io.BytesIO(content), "text/plain")}, ) assert resp.status_code == 200 data = resp.json() assert data["filename"] == "test.txt" assert data["size"] == len(content) assert len(data["file_id"]) == 12 def test_upload_with_session_id_in_query(self, client, temp_sessions): resp = client.post( "/api/upload?session_id=aabbccddeeff0011223344", files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")}, ) assert resp.status_code == 200 def test_upload_png_gets_correct_content_type(self, client, temp_sessions): png_minimal = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 resp = client.post( "/api/upload", files={"file": ("chart.png", io.BytesIO(png_minimal), "image/png")}, ) assert resp.status_code == 200 assert resp.json()["content_type"] == "image/png" def test_upload_writes_file_to_disk(self, client, temp_sessions): data = b"persisted content" file_id = client.post( "/api/upload", files={"file": ("note.txt", io.BytesIO(data), "text/plain")}, ).json()["file_id"] matches = list(temp_sessions.rglob(f"{file_id}_*")) assert len(matches) == 1 assert matches[0].read_bytes() == data # ── 下载 ─────────────────────────────────────────────────────── class TestDownload: def test_download_missing_session_returns_404(self, client, temp_sessions): assert client.get("/api/sessions/aabbccddeeff0011223344/download/latest").status_code == 404 def test_download_no_jrxml_returns_404(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.get(f"/api/sessions/{sid}/download/latest") assert resp.status_code == 404 def test_download_with_jrxml_returns_file(self, client, temp_sessions): import backend.session as sess sess.create_session(name="测试下载") # 需要手动写入 JRXML 到会话 sessions = sess.list_all_sessions() sid = sessions[0]["session_id"] sess.save_session(sid, {"current_jrxml": ""}) resp = client.get(f"/api/sessions/{sid}/download/latest") assert resp.status_code == 200 assert "", "status": "pass"}}), ("updates", {"validate": {"status": "pass"}}), ("updates", {"finalize": {}}), ("done", {"reason": "graph_completed"}), ] # 注意:_graph 是模块级变量,在导入时就编译了。需要直接替换。 monkeypatch.setattr("api_server._graph", mock_graph) # 同时替换 agent.graph.build_graph 以防后续重新编译 monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: mock_graph) return mock_graph def test_empty_payload_rejected(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.post( f"/api/sessions/{sid}/chat", json={"text": "", "file_ids": []}, ) assert resp.status_code == 400 def test_sse_stream_returns_valid_events(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] with client.stream( "POST", f"/api/sessions/{sid}/chat", json={"text": "生成一个简单的员工名册报表", "file_ids": []}, ) as resp: assert resp.status_code == 200 assert resp.headers["content-type"].startswith("text/event-stream") body = resp.read().decode("utf-8", errors="replace") assert "event: node_complete" in body assert "event: agent_complete" in body def test_auto_creates_session_on_chat(self, client, temp_sessions): with client.stream( "POST", "/api/sessions/aabbccddeeff0011223344/chat", json={"text": "生成报表", "file_ids": []}, ) as resp: assert resp.status_code == 200 assert b"event:" in resp.read() def test_unknown_file_ids_not_crash(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] with client.stream( "POST", f"/api/sessions/{sid}/chat", json={"text": "测试", "file_ids": ["fake_id_xyz"]}, ) as resp: assert resp.status_code == 200 # ── 边界 & 安全测试 ──────────────────────────────────────────── class TestBoundaries: def test_session_id_invalid_format_returns_400(self, client, temp_sessions): """非 hex 字符的 session_id 应被拒绝。""" assert client.get("/api/sessions/not_valid_hex_id").status_code == 400 def test_upload_with_path_traversal_session_id(self, client, temp_sessions): """路径穿越 session_id 被拒绝。""" resp = client.post( "/api/upload?session_id=../malicious", files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")}, ) assert resp.status_code == 400 def test_invalid_json_body_rejected(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.post( f"/api/sessions/{sid}/chat", content=b"{not valid json", headers={"Content-Type": "application/json"}, ) assert resp.status_code == 422 def test_large_payload_survives(self, client, temp_sessions): """大文本(100KB)不应崩溃。""" sid = client.post("/api/sessions").json()["session_id"] large_text = "生成报表包含字段: " + ", ".join(f"field_{i}" for i in range(5000)) with client.stream( "POST", f"/api/sessions/{sid}/chat", json={"text": large_text, "file_ids": []}, ) as resp: assert resp.status_code == 200 # ── 用户 CRUD API ─────────────────────────────────────────────── class TestUserAPI: @pytest.fixture(autouse=True) def temp_kb_data(self, monkeypatch, tmp_path): kb_data = tmp_path / "kb_data" monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data) monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json") yield kb_data def test_create_user(self, client): resp = client.post("/api/users", json={"name": "测试用户"}) assert resp.status_code == 200 data = resp.json() assert data["name"] == "测试用户" assert len(data["user_id"]) >= 12 def test_create_user_empty_name_rejected(self, client): resp = client.post("/api/users", json={"name": ""}) assert resp.status_code == 400 def test_list_users(self, client): client.post("/api/users", json={"name": "A"}) client.post("/api/users", json={"name": "B"}) resp = client.get("/api/users") assert resp.status_code == 200 assert len(resp.json()["users"]) == 2 def test_get_user(self, client): uid = client.post("/api/users", json={"name": "张三"}).json()["user_id"] resp = client.get(f"/api/users/{uid}") assert resp.status_code == 200 assert resp.json()["name"] == "张三" def test_get_user_not_found(self, client): resp = client.get("/api/users/deadbeef1234567890abcd") assert resp.status_code == 404 def test_delete_user(self, client): uid = client.post("/api/users", json={"name": "待删除"}).json()["user_id"] resp = client.delete(f"/api/users/{uid}") assert resp.status_code == 200 assert resp.json()["status"] == "deleted" assert client.get(f"/api/users/{uid}").status_code == 404 def test_delete_nonexistent_user(self, client): resp = client.delete("/api/users/deadbeef1234567890abcd") assert resp.status_code == 404 # ── 知识库 CRUD API ───────────────────────────────────────────── class TestKbAPI: @pytest.fixture(autouse=True) def setup_kb(self, monkeypatch, tmp_path): kb_data = tmp_path / "kb_data" monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data) monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json") # 使用 raw TestClient 来创建前置用户 from fastapi.testclient import TestClient as TC tc = TC(app) resp = tc.post("/api/users", json={"name": "KB测试用户"}) self.uid = resp.json()["user_id"] def test_create_kb(self, client): resp = client.post(f"/api/users/{self.uid}/kbs", json={"name": "测试库", "description": "描述"}) assert resp.status_code == 200 data = resp.json() assert data["name"] == "测试库" assert data["parse_status"] == "empty" def test_create_kb_empty_name_rejected(self, client): resp = client.post(f"/api/users/{self.uid}/kbs", json={"name": ""}) assert resp.status_code == 400 def test_list_kbs(self, client): client.post(f"/api/users/{self.uid}/kbs", json={"name": "KB1"}) client.post(f"/api/users/{self.uid}/kbs", json={"name": "KB2"}) resp = client.get(f"/api/users/{self.uid}/kbs") assert resp.status_code == 200 assert len(resp.json()["kbs"]) == 2 def test_get_kb(self, client): kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "查询库"}).json()["kb_id"] resp = client.get(f"/api/kbs/{kid}") assert resp.status_code == 200 assert resp.json()["name"] == "查询库" def test_get_kb_not_found(self, client): resp = client.get("/api/kbs/deadbeef1234567890abcd") assert resp.status_code == 404 def test_delete_kb(self, client): kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "待删库"}).json()["kb_id"] resp = client.delete(f"/api/kbs/{kid}") assert resp.status_code == 200 assert resp.json()["status"] == "deleted" def test_kb_status(self, client): kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "状态库"}).json()["kb_id"] resp = client.get(f"/api/kbs/{kid}/status") assert resp.status_code == 200 assert resp.json()["parse_status"] == "empty" assert resp.json()["file_count"] == 0 def test_kb_fields(self, client): kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "字段库"}).json()["kb_id"] resp = client.get(f"/api/kbs/{kid}/fields") assert resp.status_code == 200 assert resp.json()["fields"] == [] assert resp.json()["templates"] == [] # ── KB 文件上传 & 构建 API ────────────────────────────────────── class TestKbUploadBuild: @pytest.fixture(autouse=True) def setup_up(self, monkeypatch, tmp_path): kb_data = tmp_path / "kb_data" kb_data.mkdir(parents=True, exist_ok=True) monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data) monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json") # Mock process_file_for_kb to avoid SameFileError (API already writes to raw_dir) monkeypatch.setattr( "backend.kb_parser.process_file_for_kb", lambda kb_id, file_path, source_name="": { "filename": source_name, "type": "txt", "error": None}) from fastapi.testclient import TestClient as TC tc = TC(app) resp = tc.post("/api/users", json={"name": "上传测试用户"}) self.uid = resp.json()["user_id"] def test_upload_to_kb(self, client): kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "上传库"}).json()["kb_id"] resp = client.post( f"/api/kbs/{kid}/upload", files={"file": ("readme.md", io.BytesIO(b"# test"), "text/markdown")}, ) assert resp.status_code == 200 assert resp.json()["filename"] == "readme.md" def test_upload_to_nonexistent_kb(self, client): resp = client.post( "/api/kbs/deadbeef1234567890abcd/upload", files={"file": ("x.txt", io.BytesIO(b"x"), "text/plain")}, ) assert resp.status_code == 404 def test_build_empty_kb_fails(self, client): kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "空库"}).json()["kb_id"] resp = client.post(f"/api/kbs/{kid}/build") assert resp.status_code == 400 def test_search_kb_empty_query_rejected(self, client): kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "搜索库"}).json()["kb_id"] resp = client.get(f"/api/kbs/{kid}/search") assert resp.status_code == 400 # ── 会话-KB 绑定 API ──────────────────────────────────────────── class TestSessionKbBinding: @pytest.fixture(autouse=True) def setup_bind(self, monkeypatch, tmp_path): kb_data = tmp_path / "kb_data" kb_data.mkdir(parents=True, exist_ok=True) sessions_dir = tmp_path / "sessions" monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data) monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json") monkeypatch.setattr("backend.session.SESSIONS_DIR", sessions_dir) monkeypatch.setattr("api_server.UPLOADS_DIR", tmp_path / "uploads") def test_bind_kb_to_session(self, client): uid = client.post("/api/users", json={"name": "绑定用户"}).json()["user_id"] kid = client.post(f"/api/users/{uid}/kbs", json={"name": "绑定库"}).json()["kb_id"] sid = client.post("/api/sessions").json()["session_id"] resp = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": kid}) assert resp.status_code == 200 assert resp.json()["kb_id"] == kid def test_get_session_kb(self, client): uid = client.post("/api/users", json={"name": "查询用户"}).json()["user_id"] kid = client.post(f"/api/users/{uid}/kbs", json={"name": "查询KB"}).json()["kb_id"] sid = client.post("/api/sessions").json()["session_id"] client.put(f"/api/sessions/{sid}/kb", json={"kb_id": kid}) resp = client.get(f"/api/sessions/{sid}/kb") assert resp.status_code == 200 assert resp.json()["kb_id"] == kid assert resp.json()["kb_name"] == "查询KB" def test_unbind_kb(self, client): sid = client.post("/api/sessions").json()["session_id"] resp = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": ""}) assert resp.status_code == 200 assert resp.json()["kb_id"] is None def test_bind_nonexistent_kb(self, client): sid = client.post("/api/sessions").json()["session_id"] resp = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": "deadbeef1234567890abcd"}) assert resp.status_code == 404 def test_bind_to_nonexistent_session(self, client): resp = client.put("/api/sessions/deadbeef1234567890abcd/kb", json={"kb_id": ""}) assert resp.status_code == 404 # ── 用户-KB 端到端流程 ────────────────────────────────────────── class TestUserKbE2E: @pytest.fixture(autouse=True) def setup_e2e(self, monkeypatch, tmp_path): kb_data = tmp_path / "kb_data" kb_data.mkdir(parents=True, exist_ok=True) sessions_dir = tmp_path / "sessions" monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data) monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json") monkeypatch.setattr("backend.session.SESSIONS_DIR", sessions_dir) monkeypatch.setattr("api_server.UPLOADS_DIR", tmp_path / "uploads") # Mock process_file_for_kb to avoid SameFileError monkeypatch.setattr( "backend.kb_parser.process_file_for_kb", lambda kb_id, file_path, source_name="": { "filename": source_name, "type": "txt", "error": None}) def test_full_flow(self, client): # 1. 创建用户 uid = client.post("/api/users", json={"name": "E2E用户"}).json()["user_id"] # 2. 创建 KB kid = client.post(f"/api/users/{uid}/kbs", json={"name": "E2E库"}).json()["kb_id"] # 3. 上传文件 resp = client.post( f"/api/kbs/{kid}/upload", files={"file": ("readme.md", io.BytesIO(b"# E2E test"), "text/markdown")}, ) assert resp.status_code == 200 # 4. 创建会话 sid = client.post("/api/sessions").json()["session_id"] # 5. 绑定 KB 到会话 bind = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": kid}) assert bind.status_code == 200 assert bind.json()["kb_id"] == kid # 6. 查询会话 KB info = client.get(f"/api/sessions/{sid}/kb") assert info.json()["kb_name"] == "E2E库"