bd5bfbac2d
Root cause: when given the full 34k-char JRXML, the LLM regenerated the report from scratch
instead of modifying coordinates in place, shrinking the output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: purely programmatic regex replacement of $F{field_N} with
the real field names. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
514 lines
22 KiB
Python
"""api_server.py 集成测试 — REST 端点 + SSE 流 + 文件上传/下载。
|
|
|
|
使用 FastAPI TestClient,不需要启动真实服务器。
|
|
"""
|
|
|
|
import io
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from api_server import app
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide an in-process FastAPI TestClient bound to ``api_server.app``.

    Requests are dispatched straight to the ASGI app; no real server is started.
    """
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture
def temp_sessions(monkeypatch):
    """Redirect upload and session storage into a throwaway temp directory.

    ``api_server.UPLOADS_DIR`` becomes ``<tmp>/uploads`` and
    ``backend.session.SESSIONS_DIR`` becomes ``<tmp>/sessions``, isolating
    test data. Yields the temp root, which is removed after the test.
    """
    with tempfile.TemporaryDirectory(prefix="test_api_") as tmpdir:
        root = Path(tmpdir)
        redirects = {
            "api_server.UPLOADS_DIR": root / "uploads",
            "backend.session.SESSIONS_DIR": root / "sessions",
        }
        for target, location in redirects.items():
            monkeypatch.setattr(target, location)
        yield root
|
|
|
|
|
|
# ── 健康检查 & 配置 ────────────────────────────────────────────
|
|
|
|
class TestHealthAndConfig:
    """Smoke tests for the /api/health and /api/config endpoints."""

    def test_health_returns_ok(self, client):
        response = client.get("/api/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "ok"
        assert payload["version"] == "5.0"
        assert "timestamp" in payload

    def test_config_returns_env_keys(self, client):
        response = client.get("/api/config")
        assert response.status_code == 200

        config = response.json()["config"]
        expected_keys = ("LLM_PROVIDER", "OCR_ENGINE", "MAX_RETRY")
        for expected in expected_keys:
            assert expected in config
|
|
|
|
|
|
# ── 会话 CRUD ──────────────────────────────────────────────────
|
|
|
|
class TestSessionCRUD:
    """Create / list / get / delete behaviour of the /api/sessions endpoints."""

    # Well-formed hex id that is never created by these tests; the tests
    # below expect 404 (not 400) for it.
    _MISSING_ID = "aabbccddeeff0011223344"

    @staticmethod
    def _new_session_id(client):
        """Create a session through the API and return its id."""
        return client.post("/api/sessions").json()["session_id"]

    @staticmethod
    def _session_list(client):
        """Return the ``sessions`` array from GET /api/sessions."""
        return client.get("/api/sessions").json()["sessions"]

    def test_create_session(self, client, temp_sessions):
        response = client.post("/api/sessions")
        assert response.status_code == 200

        body = response.json()
        assert len(body["session_id"]) == 32
        for field in ("session_name", "created_at"):
            assert field in body

    def test_list_sessions_empty(self, client, temp_sessions):
        assert self._session_list(client) == []

    def test_list_sessions_populated(self, client, temp_sessions):
        for _ in range(2):
            client.post("/api/sessions")
        assert len(self._session_list(client)) == 2

    def test_get_session_found(self, client, temp_sessions):
        created = client.post("/api/sessions").json()
        sid = created["session_id"]

        response = client.get(f"/api/sessions/{sid}")
        assert response.status_code == 200

        body = response.json()
        assert body["session_id"] == sid
        assert "agent_state" in body

    def test_get_session_invalid_id(self, client, temp_sessions):
        status = client.get("/api/sessions/nonexistent").status_code
        assert status == 400

    def test_get_session_not_found(self, client, temp_sessions):
        status = client.get(f"/api/sessions/{self._MISSING_ID}").status_code
        assert status == 404

    def test_delete_session(self, client, temp_sessions):
        url = f"/api/sessions/{self._new_session_id(client)}"

        response = client.delete(url)
        assert response.status_code == 200
        assert response.json()["status"] == "deleted"
        assert client.get(url).status_code == 404

    def test_delete_nonexistent(self, client, temp_sessions):
        status = client.delete(f"/api/sessions/{self._MISSING_ID}").status_code
        assert status == 404

    def test_full_crud_lifecycle(self, client, temp_sessions):
        url = f"/api/sessions/{self._new_session_id(client)}"

        assert client.get(url).status_code == 200
        assert len(self._session_list(client)) == 1

        client.delete(url)
        assert self._session_list(client) == []
|
|
|
|
|
|
# ── 文件上传 ───────────────────────────────────────────────────
|
|
|
|
class TestFileUpload:
    """POST /api/upload: response metadata and persistence on disk."""

    @staticmethod
    def _upload(client, url, name, data, mime):
        """Send ``data`` as the multipart field ``file`` and return the response."""
        return client.post(url, files={"file": (name, io.BytesIO(data), mime)})

    def test_upload_text_file(self, client, temp_sessions):
        payload = b"Hello, JRXML!"
        response = self._upload(client, "/api/upload", "test.txt", payload, "text/plain")
        assert response.status_code == 200

        meta = response.json()
        assert meta["filename"] == "test.txt"
        assert meta["size"] == len(payload)
        assert len(meta["file_id"]) == 12

    def test_upload_with_session_id_in_query(self, client, temp_sessions):
        response = self._upload(
            client,
            "/api/upload?session_id=aabbccddeeff0011223344",
            "data.csv",
            b"a,b,c",
            "text/csv",
        )
        assert response.status_code == 200

    def test_upload_png_gets_correct_content_type(self, client, temp_sessions):
        # 8-byte PNG signature followed by 100 zero bytes (not a decodable image).
        png_bytes = b"\x89PNG\r\n\x1a\n" + bytes(100)
        response = self._upload(client, "/api/upload", "chart.png", png_bytes, "image/png")
        assert response.status_code == 200
        assert response.json()["content_type"] == "image/png"

    def test_upload_writes_file_to_disk(self, client, temp_sessions):
        payload = b"persisted content"
        response = self._upload(client, "/api/upload", "note.txt", payload, "text/plain")
        file_id = response.json()["file_id"]

        # Exactly one stored file named "<file_id>_..." must exist under the temp root.
        stored = list(temp_sessions.rglob(f"{file_id}_*"))
        assert len(stored) == 1
        assert stored[0].read_bytes() == payload
|
|
|
|
|
|
# ── 下载 ───────────────────────────────────────────────────────
|
|
|
|
class TestDownload:
    """GET /api/sessions/{sid}/download/latest."""

    def test_download_missing_session_returns_404(self, client, temp_sessions):
        url = "/api/sessions/aabbccddeeff0011223344/download/latest"
        assert client.get(url).status_code == 404

    def test_download_no_jrxml_returns_404(self, client, temp_sessions):
        session_id = client.post("/api/sessions").json()["session_id"]
        response = client.get(f"/api/sessions/{session_id}/download/latest")
        assert response.status_code == 404

    def test_download_with_jrxml_returns_file(self, client, temp_sessions):
        import backend.session as sess

        # Create the session directly in the backend, then write a JRXML into
        # it by hand so there is something to download.
        sess.create_session(name="测试下载")
        session_id = sess.list_all_sessions()[0]["session_id"]
        sess.save_session(session_id, {"current_jrxml": "<jasperReport name='rpt'/>"})

        response = client.get(f"/api/sessions/{session_id}/download/latest")
        assert response.status_code == 200
        assert "<jasperReport" in response.text

        disposition = response.headers.get("content-disposition", "")
        assert "attachment" in disposition
|
|
|
|
|
|
# ── 聊天 SSE ───────────────────────────────────────────────────
|
|
|
|
class TestChatSSE:
    """POST /api/sessions/{sid}/chat with the LangGraph graph replaced by a mock."""

    @pytest.fixture(autouse=True)
    def mock_graph(self, monkeypatch):
        """Mock LangGraph's build_graph and stream so no real LLM is called.

        ``stream`` returns a fixed list of ``(mode, payload)`` tuples:
        classify_intent -> generate -> validate -> finalize, then ``done``.
        """
        mock_graph = MagicMock()
        mock_graph.stream.return_value = [
            ("updates", {"classify_intent": {"intent": "initial_generation"}}),
            ("updates", {"generate": {"current_jrxml": "<jasperReport name='test'/>", "status": "pass"}}),
            ("updates", {"validate": {"status": "pass"}}),
            ("updates", {"finalize": {}}),
            ("done", {"reason": "graph_completed"}),
        ]
        # NOTE: _graph is a module-level variable compiled at import time,
        # so it has to be replaced directly.
        monkeypatch.setattr("api_server._graph", mock_graph)
        # Also replace agent.graph.build_graph in case the graph is recompiled later.
        monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: mock_graph)
        return mock_graph

    def test_empty_payload_rejected(self, client, temp_sessions):
        sid = client.post("/api/sessions").json()["session_id"]
        resp = client.post(
            f"/api/sessions/{sid}/chat",
            json={"text": "", "file_ids": []},
        )
        assert resp.status_code == 400

    def test_sse_stream_returns_valid_events(self, client, temp_sessions):
        sid = client.post("/api/sessions").json()["session_id"]
        with client.stream(
            "POST",
            f"/api/sessions/{sid}/chat",
            json={"text": "生成一个简单的员工名册报表", "file_ids": []},
        ) as resp:
            assert resp.status_code == 200
            assert resp.headers["content-type"].startswith("text/event-stream")
            body = resp.read().decode("utf-8", errors="replace")
            assert "event: node_complete" in body
            assert "event: agent_complete" in body

    def test_auto_creates_session_on_chat(self, client, temp_sessions):
        """Chatting on a session id that was never created still streams events."""
        with client.stream(
            "POST",
            "/api/sessions/aabbccddeeff0011223344/chat",
            json={"text": "生成报表", "file_ids": []},
        ) as resp:
            assert resp.status_code == 200
            assert b"event:" in resp.read()

    def test_unknown_file_ids_not_crash(self, client, temp_sessions):
        """Unknown file ids must not break the chat stream."""
        sid = client.post("/api/sessions").json()["session_id"]
        with client.stream(
            "POST",
            f"/api/sessions/{sid}/chat",
            json={"text": "测试", "file_ids": ["fake_id_xyz"]},
        ) as resp:
            assert resp.status_code == 200
            # A bare 200 only proves the headers were sent; also require that
            # the body actually carries SSE events.
            assert b"event:" in resp.read()
|
|
|
|
|
|
# ── 边界 & 安全测试 ────────────────────────────────────────────
|
|
|
|
class TestBoundaries:
    """Boundary and security checks: id validation, path traversal, bad and large bodies."""

    @pytest.fixture
    def fake_graph(self, monkeypatch):
        """Replace the compiled LangGraph graph with a MagicMock.

        Without this, chat requests made from this class would run the real
        ``api_server._graph``, which may make real LLM calls. ``stream``
        returns a fixed, successful sequence of ``(mode, payload)`` tuples.
        """
        graph = MagicMock()
        graph.stream.return_value = [
            ("updates", {"classify_intent": {"intent": "initial_generation"}}),
            ("updates", {"generate": {"current_jrxml": "<jasperReport name='test'/>", "status": "pass"}}),
            ("updates", {"validate": {"status": "pass"}}),
            ("updates", {"finalize": {}}),
            ("done", {"reason": "graph_completed"}),
        ]
        # _graph is compiled at import time, so it must be replaced directly;
        # build_graph is patched too in case the graph is rebuilt.
        monkeypatch.setattr("api_server._graph", graph)
        monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: graph)
        return graph

    def test_session_id_invalid_format_returns_400(self, client, temp_sessions):
        """A session_id containing non-hex characters is rejected."""
        assert client.get("/api/sessions/not_valid_hex_id").status_code == 400

    def test_upload_with_path_traversal_session_id(self, client, temp_sessions):
        """A path-traversal session_id on upload is rejected."""
        resp = client.post(
            "/api/upload?session_id=../malicious",
            files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")},
        )
        assert resp.status_code == 400

    def test_invalid_json_body_rejected(self, client, temp_sessions):
        """A syntactically invalid JSON body yields 422."""
        sid = client.post("/api/sessions").json()["session_id"]
        resp = client.post(
            f"/api/sessions/{sid}/chat",
            content=b"{not valid json",
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code == 422

    def test_large_payload_survives(self, client, temp_sessions, fake_graph):
        """A large chat text (5000 field names, roughly 59 KB) must not crash the endpoint."""
        sid = client.post("/api/sessions").json()["session_id"]
        large_text = "生成报表包含字段: " + ", ".join(f"field_{i}" for i in range(5000))
        with client.stream(
            "POST",
            f"/api/sessions/{sid}/chat",
            json={"text": large_text, "file_ids": []},
        ) as resp:
            assert resp.status_code == 200
            # Also require that the stream emitted SSE events, not just a 200 header.
            assert b"event:" in resp.read()
|
|
|
|
|
|
# ── 用户 CRUD API ───────────────────────────────────────────────
|
|
|
|
class TestUserAPI:
    """CRUD for /api/users, with KB storage redirected to a per-test temp dir."""

    # Well-formed id that no test creates; lookups on it should 404.
    _UNKNOWN_ID = "deadbeef1234567890abcd"

    @pytest.fixture(autouse=True)
    def temp_kb_data(self, monkeypatch, tmp_path):
        """Point backend.kb_manager at ``tmp_path/kb_data`` (users.json lives inside it)."""
        kb_root = tmp_path / "kb_data"
        users_file = kb_root / "users.json"
        monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_root)
        monkeypatch.setattr("backend.kb_manager._USERS_FILE", users_file)
        yield kb_root

    @staticmethod
    def _create(client, name):
        """POST /api/users with ``name`` and return the raw response."""
        return client.post("/api/users", json={"name": name})

    def test_create_user(self, client):
        response = self._create(client, "测试用户")
        assert response.status_code == 200

        user = response.json()
        assert user["name"] == "测试用户"
        assert len(user["user_id"]) >= 12

    def test_create_user_empty_name_rejected(self, client):
        assert self._create(client, "").status_code == 400

    def test_list_users(self, client):
        for name in ("A", "B"):
            self._create(client, name)

        response = client.get("/api/users")
        assert response.status_code == 200
        assert len(response.json()["users"]) == 2

    def test_get_user(self, client):
        user_id = self._create(client, "张三").json()["user_id"]

        response = client.get(f"/api/users/{user_id}")
        assert response.status_code == 200
        assert response.json()["name"] == "张三"

    def test_get_user_not_found(self, client):
        assert client.get(f"/api/users/{self._UNKNOWN_ID}").status_code == 404

    def test_delete_user(self, client):
        url = f"/api/users/{self._create(client, '待删除').json()['user_id']}"

        response = client.delete(url)
        assert response.status_code == 200
        assert response.json()["status"] == "deleted"
        assert client.get(url).status_code == 404

    def test_delete_nonexistent_user(self, client):
        assert client.delete(f"/api/users/{self._UNKNOWN_ID}").status_code == 404
|
|
|
|
|
|
# ── 知识库 CRUD API ─────────────────────────────────────────────
|
|
|
|
class TestKbAPI:
    """Knowledge-base CRUD endpoints, exercised on behalf of a freshly created user."""

    @pytest.fixture(autouse=True)
    def setup_kb(self, monkeypatch, tmp_path, client):
        """Redirect KB storage to ``tmp_path/kb_data`` and create an owner user.

        The new user's id is stored on ``self.uid`` for the tests below.
        """
        kb_data = tmp_path / "kb_data"
        monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data)
        monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json")
        # Reuse the shared ``client`` fixture rather than building a second
        # TestClient, and fail here with a clear message if setup itself breaks
        # (otherwise the failure surfaces as an opaque KeyError on "user_id").
        resp = client.post("/api/users", json={"name": "KB测试用户"})
        assert resp.status_code == 200, resp.text
        self.uid = resp.json()["user_id"]

    def test_create_kb(self, client):
        resp = client.post(f"/api/users/{self.uid}/kbs", json={"name": "测试库", "description": "描述"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["name"] == "测试库"
        assert data["parse_status"] == "empty"

    def test_create_kb_empty_name_rejected(self, client):
        resp = client.post(f"/api/users/{self.uid}/kbs", json={"name": ""})
        assert resp.status_code == 400

    def test_list_kbs(self, client):
        client.post(f"/api/users/{self.uid}/kbs", json={"name": "KB1"})
        client.post(f"/api/users/{self.uid}/kbs", json={"name": "KB2"})
        resp = client.get(f"/api/users/{self.uid}/kbs")
        assert resp.status_code == 200
        assert len(resp.json()["kbs"]) == 2

    def test_get_kb(self, client):
        kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "查询库"}).json()["kb_id"]
        resp = client.get(f"/api/kbs/{kid}")
        assert resp.status_code == 200
        assert resp.json()["name"] == "查询库"

    def test_get_kb_not_found(self, client):
        resp = client.get("/api/kbs/deadbeef1234567890abcd")
        assert resp.status_code == 404

    def test_delete_kb(self, client):
        kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "待删库"}).json()["kb_id"]
        resp = client.delete(f"/api/kbs/{kid}")
        assert resp.status_code == 200
        assert resp.json()["status"] == "deleted"

    def test_kb_status(self, client):
        kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "状态库"}).json()["kb_id"]
        resp = client.get(f"/api/kbs/{kid}/status")
        assert resp.status_code == 200
        assert resp.json()["parse_status"] == "empty"
        assert resp.json()["file_count"] == 0

    def test_kb_fields(self, client):
        kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "字段库"}).json()["kb_id"]
        resp = client.get(f"/api/kbs/{kid}/fields")
        assert resp.status_code == 200
        assert resp.json()["fields"] == []
        assert resp.json()["templates"] == []
|
|
|
|
|
|
# ── KB 文件上传 & 构建 API ──────────────────────────────────────
|
|
|
|
class TestKbUploadBuild:
    """KB file upload, build and search endpoints."""

    @pytest.fixture(autouse=True)
    def setup_up(self, monkeypatch, tmp_path, client):
        """Redirect KB storage into ``tmp_path``, stub the file parser and create an owner user.

        The new user's id is stored on ``self.uid`` for the tests below.
        """
        kb_data = tmp_path / "kb_data"
        kb_data.mkdir(parents=True, exist_ok=True)
        monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data)
        monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json")
        # Stub process_file_for_kb to avoid a SameFileError: the upload
        # endpoint already writes the file into the KB's raw dir.
        monkeypatch.setattr(
            "backend.kb_parser.process_file_for_kb",
            lambda kb_id, file_path, source_name="": {
                "filename": source_name, "type": "txt", "error": None})
        # Reuse the shared ``client`` fixture rather than building a second
        # TestClient, and fail here with a clear message if user creation breaks.
        resp = client.post("/api/users", json={"name": "上传测试用户"})
        assert resp.status_code == 200, resp.text
        self.uid = resp.json()["user_id"]

    def test_upload_to_kb(self, client):
        kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "上传库"}).json()["kb_id"]
        resp = client.post(
            f"/api/kbs/{kid}/upload",
            files={"file": ("readme.md", io.BytesIO(b"# test"), "text/markdown")},
        )
        assert resp.status_code == 200
        assert resp.json()["filename"] == "readme.md"

    def test_upload_to_nonexistent_kb(self, client):
        resp = client.post(
            "/api/kbs/deadbeef1234567890abcd/upload",
            files={"file": ("x.txt", io.BytesIO(b"x"), "text/plain")},
        )
        assert resp.status_code == 404

    def test_build_empty_kb_fails(self, client):
        kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "空库"}).json()["kb_id"]
        resp = client.post(f"/api/kbs/{kid}/build")
        assert resp.status_code == 400

    def test_search_kb_empty_query_rejected(self, client):
        kid = client.post(f"/api/users/{self.uid}/kbs", json={"name": "搜索库"}).json()["kb_id"]
        resp = client.get(f"/api/kbs/{kid}/search")
        assert resp.status_code == 400
|
|
|
|
|
|
# ── 会话-KB 绑定 API ────────────────────────────────────────────
|
|
|
|
class TestSessionKbBinding:
    """PUT/GET /api/sessions/{sid}/kb: binding a knowledge base to a session."""

    # Well-formed id that no test creates.
    _UNKNOWN_ID = "deadbeef1234567890abcd"

    @pytest.fixture(autouse=True)
    def setup_bind(self, monkeypatch, tmp_path):
        """Redirect KB, session and upload storage into ``tmp_path``."""
        kb_data = tmp_path / "kb_data"
        kb_data.mkdir(parents=True, exist_ok=True)
        redirects = (
            ("backend.kb_manager.KB_DATA_DIR", kb_data),
            ("backend.kb_manager._USERS_FILE", kb_data / "users.json"),
            ("backend.session.SESSIONS_DIR", tmp_path / "sessions"),
            ("api_server.UPLOADS_DIR", tmp_path / "uploads"),
        )
        for target, location in redirects:
            monkeypatch.setattr(target, location)

    @staticmethod
    def _new_session(client):
        """Create a session through the API and return its id."""
        return client.post("/api/sessions").json()["session_id"]

    @staticmethod
    def _new_kb(client, user_name, kb_name):
        """Create a user plus a KB owned by it; return the KB id."""
        owner = client.post("/api/users", json={"name": user_name}).json()["user_id"]
        return client.post(f"/api/users/{owner}/kbs", json={"name": kb_name}).json()["kb_id"]

    def test_bind_kb_to_session(self, client):
        kb_id = self._new_kb(client, "绑定用户", "绑定库")
        sid = self._new_session(client)

        response = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": kb_id})
        assert response.status_code == 200
        assert response.json()["kb_id"] == kb_id

    def test_get_session_kb(self, client):
        kb_id = self._new_kb(client, "查询用户", "查询KB")
        sid = self._new_session(client)
        client.put(f"/api/sessions/{sid}/kb", json={"kb_id": kb_id})

        response = client.get(f"/api/sessions/{sid}/kb")
        assert response.status_code == 200
        body = response.json()
        assert body["kb_id"] == kb_id
        assert body["kb_name"] == "查询KB"

    def test_unbind_kb(self, client):
        sid = self._new_session(client)
        # Sending an empty kb_id clears the binding.
        response = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": ""})
        assert response.status_code == 200
        assert response.json()["kb_id"] is None

    def test_bind_nonexistent_kb(self, client):
        sid = self._new_session(client)
        response = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": self._UNKNOWN_ID})
        assert response.status_code == 404

    def test_bind_to_nonexistent_session(self, client):
        response = client.put(f"/api/sessions/{self._UNKNOWN_ID}/kb", json={"kb_id": ""})
        assert response.status_code == 404
|
|
|
|
|
|
# ── 用户-KB 端到端流程 ──────────────────────────────────────────
|
|
|
|
class TestUserKbE2E:
    """End-to-end flow: user -> KB -> upload -> session -> bind -> query binding."""

    @pytest.fixture(autouse=True)
    def setup_e2e(self, monkeypatch, tmp_path):
        """Redirect KB, session and upload storage into ``tmp_path`` and stub the file parser."""
        kb_data = tmp_path / "kb_data"
        kb_data.mkdir(parents=True, exist_ok=True)
        sessions_dir = tmp_path / "sessions"
        monkeypatch.setattr("backend.kb_manager.KB_DATA_DIR", kb_data)
        monkeypatch.setattr("backend.kb_manager._USERS_FILE", kb_data / "users.json")
        monkeypatch.setattr("backend.session.SESSIONS_DIR", sessions_dir)
        monkeypatch.setattr("api_server.UPLOADS_DIR", tmp_path / "uploads")
        # Stub process_file_for_kb to avoid a SameFileError.
        monkeypatch.setattr(
            "backend.kb_parser.process_file_for_kb",
            lambda kb_id, file_path, source_name="": {
                "filename": source_name, "type": "txt", "error": None})

    def test_full_flow(self, client):
        # Every step checks its status code first, so a broken step fails at
        # that step instead of as a KeyError further down the flow.

        # 1. Create a user.
        user_resp = client.post("/api/users", json={"name": "E2E用户"})
        assert user_resp.status_code == 200
        uid = user_resp.json()["user_id"]

        # 2. Create a KB owned by that user.
        kb_resp = client.post(f"/api/users/{uid}/kbs", json={"name": "E2E库"})
        assert kb_resp.status_code == 200
        kid = kb_resp.json()["kb_id"]

        # 3. Upload a file into the KB.
        resp = client.post(
            f"/api/kbs/{kid}/upload",
            files={"file": ("readme.md", io.BytesIO(b"# E2E test"), "text/markdown")},
        )
        assert resp.status_code == 200

        # 4. Create a session.
        session_resp = client.post("/api/sessions")
        assert session_resp.status_code == 200
        sid = session_resp.json()["session_id"]

        # 5. Bind the KB to the session.
        bind = client.put(f"/api/sessions/{sid}/kb", json={"kb_id": kid})
        assert bind.status_code == 200
        assert bind.json()["kb_id"] == kid

        # 6. Query the session's KB binding.
        info = client.get(f"/api/sessions/{sid}/kb")
        assert info.status_code == 200
        assert info.json()["kb_id"] == kid
        assert info.json()["kb_name"] == "E2E库"
|