"""api_server.py 集成测试 — REST 端点 + SSE 流 + 文件上传/下载。 使用 FastAPI TestClient,不需要启动真实服务器。 """ import io import json import os import sys import tempfile from pathlib import Path from unittest.mock import patch, MagicMock import pytest from fastapi.testclient import TestClient sys.path.insert(0, str(Path(__file__).parent.parent)) from api_server import app @pytest.fixture def client(): return TestClient(app) @pytest.fixture def temp_sessions(monkeypatch): """重定向上传目录到临时目录,隔离测试数据。""" with tempfile.TemporaryDirectory(prefix="test_api_") as tmpdir: monkeypatch.setattr("api_server.UPLOADS_DIR", Path(tmpdir) / "uploads") monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir) / "sessions") yield Path(tmpdir) # ── 健康检查 & 配置 ──────────────────────────────────────────── class TestHealthAndConfig: def test_health_returns_ok(self, client): resp = client.get("/api/health") assert resp.status_code == 200 data = resp.json() assert data["status"] == "ok" assert data["version"] == "5.0" assert "timestamp" in data def test_config_returns_env_keys(self, client): resp = client.get("/api/config") assert resp.status_code == 200 cfg = resp.json()["config"] for key in ("LLM_PROVIDER", "OCR_ENGINE", "MAX_RETRY"): assert key in cfg # ── 会话 CRUD ────────────────────────────────────────────────── class TestSessionCRUD: def test_create_session(self, client, temp_sessions): resp = client.post("/api/sessions") assert resp.status_code == 200 data = resp.json() assert len(data["session_id"]) == 32 assert "session_name" in data assert "created_at" in data def test_list_sessions_empty(self, client, temp_sessions): assert client.get("/api/sessions").json()["sessions"] == [] def test_list_sessions_populated(self, client, temp_sessions): client.post("/api/sessions") client.post("/api/sessions") assert len(client.get("/api/sessions").json()["sessions"]) == 2 def test_get_session_found(self, client, temp_sessions): created = client.post("/api/sessions").json() resp = client.get(f"/api/sessions/{created['session_id']}") assert resp.status_code == 200 assert resp.json()["session_id"] == created["session_id"] assert "agent_state" in resp.json() def test_get_session_invalid_id(self, client, temp_sessions): assert client.get("/api/sessions/nonexistent").status_code == 400 def test_get_session_not_found(self, client, temp_sessions): assert client.get("/api/sessions/aabbccddeeff0011223344").status_code == 404 def test_delete_session(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.delete(f"/api/sessions/{sid}") assert resp.status_code == 200 assert resp.json()["status"] == "deleted" assert client.get(f"/api/sessions/{sid}").status_code == 404 def test_delete_nonexistent(self, client, temp_sessions): assert client.delete("/api/sessions/aabbccddeeff0011223344").status_code == 404 def test_full_crud_lifecycle(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] assert client.get(f"/api/sessions/{sid}").status_code == 200 assert len(client.get("/api/sessions").json()["sessions"]) == 1 client.delete(f"/api/sessions/{sid}") assert client.get("/api/sessions").json()["sessions"] == [] # ── 文件上传 ─────────────────────────────────────────────────── class TestFileUpload: def test_upload_text_file(self, client, temp_sessions): content = b"Hello, JRXML!" resp = client.post( "/api/upload", files={"file": ("test.txt", io.BytesIO(content), "text/plain")}, ) assert resp.status_code == 200 data = resp.json() assert data["filename"] == "test.txt" assert data["size"] == len(content) assert len(data["file_id"]) == 12 def test_upload_with_session_id_in_query(self, client, temp_sessions): resp = client.post( "/api/upload?session_id=aabbccddeeff0011223344", files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")}, ) assert resp.status_code == 200 def test_upload_png_gets_correct_content_type(self, client, temp_sessions): png_minimal = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 resp = client.post( "/api/upload", files={"file": ("chart.png", io.BytesIO(png_minimal), "image/png")}, ) assert resp.status_code == 200 assert resp.json()["content_type"] == "image/png" def test_upload_writes_file_to_disk(self, client, temp_sessions): data = b"persisted content" file_id = client.post( "/api/upload", files={"file": ("note.txt", io.BytesIO(data), "text/plain")}, ).json()["file_id"] matches = list(temp_sessions.rglob(f"{file_id}_*")) assert len(matches) == 1 assert matches[0].read_bytes() == data # ── 下载 ─────────────────────────────────────────────────────── class TestDownload: def test_download_missing_session_returns_404(self, client, temp_sessions): assert client.get("/api/sessions/aabbccddeeff0011223344/download/latest").status_code == 404 def test_download_no_jrxml_returns_404(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.get(f"/api/sessions/{sid}/download/latest") assert resp.status_code == 404 def test_download_with_jrxml_returns_file(self, client, temp_sessions): import backend.session as sess sess.create_session(name="测试下载") # 需要手动写入 JRXML 到会话 sessions = sess.list_all_sessions() sid = sessions[0]["session_id"] sess.save_session(sid, {"current_jrxml": ""}) resp = client.get(f"/api/sessions/{sid}/download/latest") assert resp.status_code == 200 assert "", "status": "pass"}}), ("updates", {"validate": {"status": "pass"}}), ("updates", {"finalize": {}}), ("done", {"reason": "graph_completed"}), ] # 注意:_graph 是模块级变量,在导入时就编译了。需要直接替换。 monkeypatch.setattr("api_server._graph", mock_graph) # 同时替换 agent.graph.build_graph 以防后续重新编译 monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: mock_graph) return mock_graph def test_empty_payload_rejected(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.post( f"/api/sessions/{sid}/chat", json={"text": "", "file_ids": []}, ) assert resp.status_code == 400 def test_sse_stream_returns_valid_events(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] with client.stream( "POST", f"/api/sessions/{sid}/chat", json={"text": "生成一个简单的员工名册报表", "file_ids": []}, ) as resp: assert resp.status_code == 200 assert resp.headers["content-type"].startswith("text/event-stream") body = resp.read().decode("utf-8", errors="replace") assert "event: node_complete" in body assert "event: agent_complete" in body def test_auto_creates_session_on_chat(self, client, temp_sessions): with client.stream( "POST", "/api/sessions/aabbccddeeff0011223344/chat", json={"text": "生成报表", "file_ids": []}, ) as resp: assert resp.status_code == 200 assert b"event:" in resp.read() def test_unknown_file_ids_not_crash(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] with client.stream( "POST", f"/api/sessions/{sid}/chat", json={"text": "测试", "file_ids": ["fake_id_xyz"]}, ) as resp: assert resp.status_code == 200 # ── 边界 & 安全测试 ──────────────────────────────────────────── class TestBoundaries: def test_session_id_invalid_format_returns_400(self, client, temp_sessions): """非 hex 字符的 session_id 应被拒绝。""" assert client.get("/api/sessions/not_valid_hex_id").status_code == 400 def test_upload_with_path_traversal_session_id(self, client, temp_sessions): """路径穿越 session_id 被拒绝。""" resp = client.post( "/api/upload?session_id=../malicious", files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")}, ) assert resp.status_code == 400 def test_invalid_json_body_rejected(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] resp = client.post( f"/api/sessions/{sid}/chat", content=b"{not valid json", headers={"Content-Type": "application/json"}, ) assert resp.status_code == 422 def test_large_payload_survives(self, client, temp_sessions): """大文本(100KB)不应崩溃。""" sid = client.post("/api/sessions").json()["session_id"] large_text = "生成报表包含字段: " + ", ".join(f"field_{i}" for i in range(5000)) with client.stream( "POST", f"/api/sessions/{sid}/chat", json={"text": large_text, "file_ids": []}, ) as resp: assert resp.status_code == 200