Files
agent_jrxml/tests/test_api_integration.py
T
panda 93ad5e8876 fix: address audit findings — session_id validation, streaming reset, state isolation
- Replace truncated 12-char UUID with full 32-char UUID (128-bit entropy)
- Add validate_session_id() regex check to prevent path traversal
- Add _check_session_id() guard on all 6 API endpoints
- Change _step_counter from module global to contextvars.ContextVar
- Filter None values from node_state before merging into agent_state
- Log save_session failures instead of silently swallowing them
- Add finishStreaming() in catch/finally blocks to prevent UI lockup
- Fix broken multiline docstring in chat() endpoint
2026-05-23 09:08:53 +08:00

268 lines
11 KiB
Python

"""api_server.py 集成测试 — REST 端点 + SSE 流 + 文件上传/下载。
使用 FastAPI TestClient,不需要启动真实服务器。
"""
import io
import json
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
sys.path.insert(0, str(Path(__file__).parent.parent))
from api_server import app
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the FastAPI ``app`` under test."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def temp_sessions(monkeypatch):
    """Redirect upload and session storage into a throwaway directory.

    Patches ``api_server.UPLOADS_DIR`` and ``backend.session.SESSIONS_DIR``
    to sub-paths of a fresh temporary directory so tests do not share data,
    then yields that directory as a ``Path``. The directory is deleted when
    the ``TemporaryDirectory`` context exits after the test.
    """
    scratch = tempfile.TemporaryDirectory(prefix="test_api_")
    with scratch as raw_path:
        root = Path(raw_path)
        monkeypatch.setattr("api_server.UPLOADS_DIR", root / "uploads")
        monkeypatch.setattr("backend.session.SESSIONS_DIR", root / "sessions")
        yield root
# ── 健康检查 & 配置 ────────────────────────────────────────────
class TestHealthAndConfig:
    """Checks for the health-check and configuration endpoints."""

    def test_health_returns_ok(self, client):
        """``GET /api/health`` returns 200 with status, version and timestamp."""
        response = client.get("/api/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "ok"
        assert payload["version"] == "5.0"
        assert "timestamp" in payload

    def test_config_returns_env_keys(self, client):
        """``GET /api/config`` exposes the expected keys under ``config``."""
        response = client.get("/api/config")
        assert response.status_code == 200

        config = response.json()["config"]
        expected_keys = ("LLM_PROVIDER", "OCR_ENGINE", "MAX_RETRY")
        for name in expected_keys:
            assert name in config
# ── 会话 CRUD ──────────────────────────────────────────────────
class TestSessionCRUD:
    """Create / list / fetch / delete round-trips against ``/api/sessions``."""

    def test_create_session(self, client, temp_sessions):
        """Creating a session returns a 32-character id plus name and timestamp."""
        response = client.post("/api/sessions")
        assert response.status_code == 200

        created = response.json()
        assert len(created["session_id"]) == 32
        assert "session_name" in created
        assert "created_at" in created

    def test_list_sessions_empty(self, client, temp_sessions):
        """With isolated storage the session listing starts out empty."""
        listing = client.get("/api/sessions").json()
        assert listing["sessions"] == []

    def test_list_sessions_populated(self, client, temp_sessions):
        """Two created sessions both show up in the listing."""
        for _ in range(2):
            client.post("/api/sessions")

        listing = client.get("/api/sessions").json()
        assert len(listing["sessions"]) == 2

    def test_get_session_found(self, client, temp_sessions):
        """Fetching a freshly created session returns its id and agent state."""
        session_id = client.post("/api/sessions").json()["session_id"]

        response = client.get(f"/api/sessions/{session_id}")
        assert response.status_code == 200

        fetched = response.json()
        assert fetched["session_id"] == session_id
        assert "agent_state" in fetched

    def test_get_session_invalid_id(self, client, temp_sessions):
        """The id ``nonexistent`` is answered with 400."""
        response = client.get("/api/sessions/nonexistent")
        assert response.status_code == 400

    def test_get_session_not_found(self, client, temp_sessions):
        """An all-hex id that was never created is answered with 404."""
        response = client.get("/api/sessions/aabbccddeeff0011223344")
        assert response.status_code == 404

    def test_delete_session(self, client, temp_sessions):
        """Deleting a session reports ``deleted`` and later fetches return 404."""
        session_id = client.post("/api/sessions").json()["session_id"]
        url = f"/api/sessions/{session_id}"

        deletion = client.delete(url)
        assert deletion.status_code == 200
        assert deletion.json()["status"] == "deleted"

        assert client.get(url).status_code == 404

    def test_delete_nonexistent(self, client, temp_sessions):
        """Deleting an id that was never created is answered with 404."""
        response = client.delete("/api/sessions/aabbccddeeff0011223344")
        assert response.status_code == 404

    def test_full_crud_lifecycle(self, client, temp_sessions):
        """Create, fetch, list, delete, then confirm the listing is empty again."""
        session_id = client.post("/api/sessions").json()["session_id"]
        url = f"/api/sessions/{session_id}"

        assert client.get(url).status_code == 200
        assert len(client.get("/api/sessions").json()["sessions"]) == 1

        client.delete(url)
        remaining = client.get("/api/sessions").json()["sessions"]
        assert remaining == []
# ── 文件上传 ───────────────────────────────────────────────────
class TestFileUpload:
    """Multipart upload tests for ``POST /api/upload``."""

    def test_upload_text_file(self, client, temp_sessions):
        """A plain-text upload echoes filename and size and gets a 12-char file id."""
        content = b"Hello, JRXML!"
        multipart = {"file": ("test.txt", io.BytesIO(content), "text/plain")}

        response = client.post("/api/upload", files=multipart)
        assert response.status_code == 200

        info = response.json()
        assert info["filename"] == "test.txt"
        assert info["size"] == len(content)
        assert len(info["file_id"]) == 12

    def test_upload_with_session_id_in_query(self, client, temp_sessions):
        """An upload carrying ``session_id`` in the query string is accepted."""
        multipart = {"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")}

        response = client.post(
            "/api/upload?session_id=aabbccddeeff0011223344",
            files=multipart,
        )
        assert response.status_code == 200

    def test_upload_png_gets_correct_content_type(self, client, temp_sessions):
        """A PNG upload is reported back with content type ``image/png``."""
        # PNG signature followed by zero padding; not a decodable image.
        png_minimal = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
        multipart = {"file": ("chart.png", io.BytesIO(png_minimal), "image/png")}

        response = client.post("/api/upload", files=multipart)
        assert response.status_code == 200
        assert response.json()["content_type"] == "image/png"

    def test_upload_writes_file_to_disk(self, client, temp_sessions):
        """The uploaded bytes land in exactly one ``<file_id>_*`` file on disk."""
        data = b"persisted content"
        multipart = {"file": ("note.txt", io.BytesIO(data), "text/plain")}

        response = client.post("/api/upload", files=multipart)
        file_id = response.json()["file_id"]

        # Search the whole temporary tree for files prefixed with the file id.
        stored = list(temp_sessions.rglob(f"{file_id}_*"))
        assert len(stored) == 1
        assert stored[0].read_bytes() == data
# ── 下载 ───────────────────────────────────────────────────────
class TestDownload:
    """Tests for ``GET /api/sessions/{id}/download/latest``."""

    def test_download_missing_session_returns_404(self, client, temp_sessions):
        """Downloading from a session that was never created returns 404."""
        response = client.get(
            "/api/sessions/aabbccddeeff0011223344/download/latest"
        )
        assert response.status_code == 404

    def test_download_no_jrxml_returns_404(self, client, temp_sessions):
        """A freshly created session has nothing to download, so 404."""
        session_id = client.post("/api/sessions").json()["session_id"]

        response = client.get(f"/api/sessions/{session_id}/download/latest")
        assert response.status_code == 404

    def test_download_with_jrxml_returns_file(self, client, temp_sessions):
        """A session holding JRXML is served back as an attachment."""
        import backend.session as sess

        sess.create_session(name="测试下载")
        # The JRXML has to be written into the session by hand.
        first_session = sess.list_all_sessions()[0]
        session_id = first_session["session_id"]
        sess.save_session(
            session_id,
            {"current_jrxml": "<jasperReport name='rpt'/>"},
        )

        response = client.get(f"/api/sessions/{session_id}/download/latest")
        assert response.status_code == 200
        assert "<jasperReport" in response.text

        disposition = response.headers.get("content-disposition", "")
        assert "attachment" in disposition
# ── 聊天 SSE ───────────────────────────────────────────────────
class TestChatSSE:
    """Tests for the streaming chat endpoint ``POST /api/sessions/{id}/chat``."""

    @pytest.fixture(autouse=True)
    def mock_graph(self, monkeypatch):
        """Mock LangGraph's ``build_graph`` and ``stream`` to avoid real LLM calls.

        Returns the ``MagicMock`` standing in for the compiled graph; its
        ``stream`` method returns a fixed list of events.
        """
        scripted_events = [
            ("updates", {"classify_intent": {"intent": "initial_generation"}}),
            ("updates", {"generate": {"current_jrxml": "<jasperReport name='test'/>", "status": "pass"}}),
            ("updates", {"validate": {"status": "pass"}}),
            ("updates", {"finalize": {}}),
            ("done", {"reason": "graph_completed"}),
        ]
        fake_graph = MagicMock()
        fake_graph.stream.return_value = scripted_events

        def _build_fake_graph(on_node_start=None):
            return fake_graph

        # NOTE: ``_graph`` is a module-level variable compiled at import time,
        # so it has to be replaced directly.
        monkeypatch.setattr("api_server._graph", fake_graph)
        # Also replace agent.graph.build_graph in case the graph is rebuilt later.
        monkeypatch.setattr("agent.graph.build_graph", _build_fake_graph)
        return fake_graph

    def test_empty_payload_rejected(self, client, temp_sessions):
        """Empty text with no file ids is answered with 400."""
        session_id = client.post("/api/sessions").json()["session_id"]

        response = client.post(
            f"/api/sessions/{session_id}/chat",
            json={"text": "", "file_ids": []},
        )
        assert response.status_code == 400

    def test_sse_stream_returns_valid_events(self, client, temp_sessions):
        """A chat request streams ``node_complete`` and ``agent_complete`` events."""
        session_id = client.post("/api/sessions").json()["session_id"]
        request_body = {"text": "生成一个简单的员工名册报表", "file_ids": []}

        with client.stream(
            "POST",
            f"/api/sessions/{session_id}/chat",
            json=request_body,
        ) as response:
            assert response.status_code == 200
            assert response.headers["content-type"].startswith("text/event-stream")
            raw = response.read()
            body = raw.decode("utf-8", errors="replace")
            assert "event: node_complete" in body
            assert "event: agent_complete" in body

    def test_auto_creates_session_on_chat(self, client, temp_sessions):
        """Chatting with a session id that was never created still streams events."""
        request_body = {"text": "生成报表", "file_ids": []}

        with client.stream(
            "POST",
            "/api/sessions/aabbccddeeff0011223344/chat",
            json=request_body,
        ) as response:
            assert response.status_code == 200
            streamed = response.read()
            assert b"event:" in streamed

    def test_unknown_file_ids_not_crash(self, client, temp_sessions):
        """A file id that was never uploaded still yields a 200 response."""
        session_id = client.post("/api/sessions").json()["session_id"]
        request_body = {"text": "测试", "file_ids": ["fake_id_xyz"]}

        with client.stream(
            "POST",
            f"/api/sessions/{session_id}/chat",
            json=request_body,
        ) as response:
            assert response.status_code == 200
# ── 边界 & 安全测试 ────────────────────────────────────────────
class TestBoundaries:
    """Boundary and security-oriented request tests."""

    def test_session_id_invalid_format_returns_400(self, client, temp_sessions):
        """A session_id containing non-hex characters should be rejected."""
        response = client.get("/api/sessions/not_valid_hex_id")
        assert response.status_code == 400

    def test_upload_with_path_traversal_session_id(self, client, temp_sessions):
        """A path-traversal session_id is rejected."""
        multipart = {"file": ("t.txt", io.BytesIO(b"x"), "text/plain")}

        response = client.post(
            "/api/upload?session_id=../malicious",
            files=multipart,
        )
        assert response.status_code == 400

    def test_invalid_json_body_rejected(self, client, temp_sessions):
        """A malformed JSON body on the chat endpoint is answered with 422."""
        session_id = client.post("/api/sessions").json()["session_id"]

        response = client.post(
            f"/api/sessions/{session_id}/chat",
            content=b"{not valid json",
            headers={"Content-Type": "application/json"},
        )
        assert response.status_code == 422

    def test_large_payload_survives(self, client, temp_sessions):
        """A large prompt (5000 field names, roughly 59 KB) must not crash."""
        session_id = client.post("/api/sessions").json()["session_id"]
        field_names = [f"field_{i}" for i in range(5000)]
        large_text = "生成报表包含字段: " + ", ".join(field_names)

        with client.stream(
            "POST",
            f"/api/sessions/{session_id}/chat",
            json={"text": large_text, "file_ids": []},
        ) as response:
            assert response.status_code == 200