From 93ad5e8876139ee89937f63a3b212341467b36bd Mon Sep 17 00:00:00 2001 From: panda <1415243231@qq.com> Date: Sat, 23 May 2026 09:08:53 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20address=20audit=20findings=20=E2=80=94?= =?UTF-8?q?=20session=5Fid=20validation,=20streaming=20reset,=20state=20is?= =?UTF-8?q?olation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace truncated 12-char UUID with full 32-char UUID (128-bit entropy) - Add validate_session_id() regex check to prevent path traversal - Add _check_session_id() guard on all 6 API endpoints - Change _step_counter from module global to contextvars.ContextVar - Filter None values from node_state before merging into agent_state - Log save_session failures instead of silently swallowing them - Add finishStreaming() in catch/finally blocks to prevent UI lockup - Fix broken multiline docstring in chat() endpoint --- api_server.py | 37 +++++++++++++++++++++++++---------- backend/session.py | 21 +++++++++++++++++--- frontend/src/App.vue | 5 +++++ tests/test_api_integration.py | 24 +++++++++++++---------- tests/test_session.py | 4 ++-- 5 files changed, 66 insertions(+), 25 deletions(-) diff --git a/api_server.py b/api_server.py index 630807f..f7ab8d5 100644 --- a/api_server.py +++ b/api_server.py @@ -16,6 +16,7 @@ Usage: import asyncio import base64 +import contextvars import json import mimetypes import os @@ -97,25 +98,30 @@ SKIP_NODES = {"load_session", "process_input", "manage_context", _api_log = get_logger("api") UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads")) +def _check_session_id(session_id: str) -> None: + """校验 session_id 合法性(防路径穿越),非法时抛出 HTTPException(400)。""" + from backend.session import validate_session_id + if not validate_session_id(session_id): + raise HTTPException(status_code=400, detail=f"Invalid session_id: {session_id!r}") + # ───────────────────────────────────────────── # 图编译(全局单例,带 node_start 回调) # ───────────────────────────────────────────── -# 当前请求的事件队列(单个用户桌面应用,无并发问题) +# 当前请求的事件队列(单个用户桌面应用) _current_event_queue: Optional[queue.Queue] = None -_step_counter: int = 0 +_step_counter: contextvars.ContextVar[int] = contextvars.ContextVar('_step_counter', default=0) def _on_node_start(node_name: str): """全局 node_start 回调 — 将事件推入当前请求的事件队列。""" - global _step_counter q = _current_event_queue if q is not None: - _step_counter += 1 + _step_counter.set(_step_counter.get() + 1) q.put(("node_start", { "node": node_name, "label": NODE_LABELS.get(node_name, node_name), - "step_index": _step_counter, + "step_index": _step_counter.get(), })) @@ -180,14 +186,18 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue): if mode == "updates" and isinstance(data, dict): for node_state in data.values(): if isinstance(node_state, dict): - agent_state.update(node_state) + agent_state.update({k: v for k, v in node_state.items() if v is not None}) # 在 graph 完成后立即保存 session,防止 SSE 流中断导致数据丢失 sid = agent_state.get("session_id", "") if sid: try: save_session(sid, agent_state) - except Exception: - pass # 静默失败,SSE 流中还有一次保存机会 + except Exception as exc: + _api_log.error("图运行中保存会话失败", extra={ + "session_id": sid, + "error": str(exc), + "traceback": traceback.format_exc(), + }) event_q.put(("done", {"reason": "graph_completed"})) except Exception as exc: event_q.put(("error", { @@ -198,9 +208,9 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue): async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str: """SSE 事件生成器 —— 在后台线程运行图,异步产出 SSE 字符串。""" - global _current_event_queue, _step_counter + global _current_event_queue - _step_counter = 0 + _step_counter.set(0) t_start = time.time() event_q: queue.Queue = queue.Queue() _current_event_queue = event_q @@ -347,6 +357,7 @@ async def list_sessions(): @app.get("/api/sessions/{session_id}") async def get_session(session_id: str): + _check_session_id(session_id) data = get_session_state(session_id) if data is None: raise HTTPException(status_code=404, detail="会话不存在") @@ -361,6 +372,7 @@ async def get_session(session_id: str): @app.delete("/api/sessions/{session_id}") async def remove_session(session_id: str): + _check_session_id(session_id) ok = delete_session(session_id) if not ok: raise HTTPException(status_code=404, detail="会话不存在或已删除") @@ -373,6 +385,8 @@ async def remove_session(session_id: str): @app.post("/api/upload") async def upload_file(file: UploadFile = File(...), session_id: str = ""): + if session_id: + _check_session_id(session_id) file_id = uuid.uuid4().hex[:12] _ensure_upload_dir(session_id) @@ -492,6 +506,7 @@ async def chat(session_id: str, payload: dict): Returns: text/event-stream (SSE) """ + _check_session_id(session_id) text = payload.get("text", "").strip() file_ids = payload.get("file_ids", []) @@ -577,6 +592,7 @@ async def chat(session_id: str, payload: dict): @app.get("/api/sessions/{session_id}/download/latest") async def download_latest(session_id: str): """下载最新 JRXML 文件。""" + _check_session_id(session_id) data = load_session(session_id) if data is None: raise HTTPException(status_code=404, detail="会话不存在") @@ -601,6 +617,7 @@ async def download_latest(session_id: str): @app.get("/api/sessions/{session_id}/download/{version}") async def download_version(session_id: str, version: int): """下载指定版本的 JRXML 文件。""" + _check_session_id(session_id) data = load_session(session_id) if data is None: raise HTTPException(status_code=404, detail="会话不存在") diff --git a/backend/session.py b/backend/session.py index 22472c0..fae4f53 100644 --- a/backend/session.py +++ b/backend/session.py @@ -5,6 +5,7 @@ import json import os +import re import uuid import tempfile from datetime import datetime, timezone @@ -26,12 +27,20 @@ def _ensure_dir(): SESSIONS_DIR.mkdir(parents=True, exist_ok=True) +_VALID_SESSION_ID_RE = re.compile(r'^[a-fA-F0-9]{12,}$') + +def validate_session_id(session_id: str) -> bool: + """校验 session_id 仅含合法 hex 字符(防路径穿越)。""" + return bool(_VALID_SESSION_ID_RE.match(session_id)) + def _session_path(session_id: str) -> Path: + if not validate_session_id(session_id): + raise ValueError(f"Invalid session_id: {session_id!r}") return SESSIONS_DIR / f"{session_id}.json" def generate_session_id() -> str: - return uuid.uuid4().hex[:12] + return uuid.uuid4().hex def create_session(name: str = "", agent_state: Optional[dict] = None, @@ -58,7 +67,10 @@ def create_session(name: str = "", agent_state: Optional[dict] = None, def load_session(session_id: str) -> Optional[dict]: """按 ID 加载会话数据。未找到则返回 None。""" _ensure_dir() - fp = _session_path(session_id) + try: + fp = _session_path(session_id) + except ValueError: + return None if not fp.exists(): return None with open(fp, "r", encoding="utf-8") as f: @@ -132,7 +144,10 @@ def list_all_sessions() -> list[dict]: def delete_session(session_id: str) -> bool: """按 ID 删除会话文件。""" _ensure_dir() - fp = _session_path(session_id) + try: + fp = _session_path(session_id) + except ValueError: + return False if fp.exists(): fp.unlink() _session_log.info("删除会话", extra={"session_id": session_id}) diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 7193cb8..f43c3cb 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -109,6 +109,11 @@ async function handleSend(text: string, files: File[]) { } catch (e: any) { chat.setError(e.message || '网络请求失败') chat.addMessage({ role: 'assistant', content: `请求失败: ${e.message}`, type: 'error' }) + chat.finishStreaming({ status: '' }) + } finally { + if (chat.streaming) { + chat.finishStreaming({ status: '' }) + } } } diff --git a/tests/test_api_integration.py b/tests/test_api_integration.py index 67b2c55..b85f559 100644 --- a/tests/test_api_integration.py +++ b/tests/test_api_integration.py @@ -59,7 +59,7 @@ class TestSessionCRUD: resp = client.post("/api/sessions") assert resp.status_code == 200 data = resp.json() - assert len(data["session_id"]) == 12 + assert len(data["session_id"]) == 32 assert "session_name" in data assert "created_at" in data @@ -78,8 +78,11 @@ class TestSessionCRUD: assert resp.json()["session_id"] == created["session_id"] assert "agent_state" in resp.json() + def test_get_session_invalid_id(self, client, temp_sessions): + assert client.get("/api/sessions/nonexistent").status_code == 400 + def test_get_session_not_found(self, client, temp_sessions): - assert client.get("/api/sessions/nonexistent").status_code == 404 + assert client.get("/api/sessions/aabbccddeeff0011223344").status_code == 404 def test_delete_session(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] @@ -89,7 +92,7 @@ class TestSessionCRUD: assert client.get(f"/api/sessions/{sid}").status_code == 404 def test_delete_nonexistent(self, client, temp_sessions): - assert client.delete("/api/sessions/ghost_id").status_code == 404 + assert client.delete("/api/sessions/aabbccddeeff0011223344").status_code == 404 def test_full_crud_lifecycle(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] @@ -116,7 +119,7 @@ class TestFileUpload: def test_upload_with_session_id_in_query(self, client, temp_sessions): resp = client.post( - "/api/upload?session_id=abc123", + "/api/upload?session_id=aabbccddeeff0011223344", files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")}, ) assert resp.status_code == 200 @@ -146,7 +149,7 @@ class TestFileUpload: class TestDownload: def test_download_missing_session_returns_404(self, client, temp_sessions): - assert client.get("/api/sessions/missing/download/latest").status_code == 404 + assert client.get("/api/sessions/aabbccddeeff0011223344/download/latest").status_code == 404 def test_download_no_jrxml_returns_404(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] @@ -212,7 +215,7 @@ class TestChatSSE: def test_auto_creates_session_on_chat(self, client, temp_sessions): with client.stream( "POST", - "/api/sessions/auto_new_session/chat", + "/api/sessions/aabbccddeeff0011223344/chat", json={"text": "生成报表", "file_ids": []}, ) as resp: assert resp.status_code == 200 @@ -231,16 +234,17 @@ class TestChatSSE: # ── 边界 & 安全测试 ──────────────────────────────────────────── class TestBoundaries: - def test_session_id_path_traversal_returns_404(self, client, temp_sessions): - assert client.get("/api/sessions/../etc/passwd").status_code == 404 + def test_session_id_invalid_format_returns_400(self, client, temp_sessions): + """非 hex 字符的 session_id 应被拒绝。""" + assert client.get("/api/sessions/not_valid_hex_id").status_code == 400 def test_upload_with_path_traversal_session_id(self, client, temp_sessions): - """路径穿越 session_id 仍正常处理(目录隔离在 UPLOADS_DIR 内)。""" + """路径穿越 session_id 被拒绝。""" resp = client.post( "/api/upload?session_id=../malicious", files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")}, ) - assert resp.status_code == 200 + assert resp.status_code == 400 def test_invalid_json_body_rejected(self, client, temp_sessions): sid = client.post("/api/sessions").json()["session_id"] diff --git a/tests/test_session.py b/tests/test_session.py index 6a0414b..13cf4f2 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -41,7 +41,7 @@ def temp_sessions_dir(monkeypatch): class TestCreateSession: def test_creates_with_defaults(self, temp_sessions_dir): s = create_session() - assert len(s["session_id"]) == 12 + assert len(s["session_id"]) == 32 assert "新建报表" in s["session_name"] assert s["created_at"] assert s["updated_at"] @@ -139,7 +139,7 @@ class TestSaveSession: assert load_session(created["session_id"])["session_name"] == "原名" def test_fills_missing_created_at(self, temp_sessions_dir): - sid = "test_no_created" + sid = "aaaabbbbccccddddeeeeffff" fp = temp_sessions_dir / f"{sid}.json" fp.write_text( json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"