fix: address audit findings — session_id validation, streaming reset, state isolation
- Replace truncated 12-char UUID with full 32-char UUID hex string (128-bit UUID4, 122 random bits)
- Add validate_session_id() regex check to prevent path traversal
- Add _check_session_id() guard on all 6 API endpoints
- Change _step_counter from module global to contextvars.ContextVar
- Filter None values from node_state before merging into agent_state
- Log save_session failures instead of silently swallowing them
- Add finishStreaming() in catch/finally blocks to prevent UI lockup
- Fix broken multiline docstring in chat() endpoint
This commit is contained in:
+27
-10
@@ -16,6 +16,7 @@ Usage:
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import contextvars
|
||||||
import json
|
import json
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
@@ -97,25 +98,30 @@ SKIP_NODES = {"load_session", "process_input", "manage_context",
|
|||||||
_api_log = get_logger("api")
|
_api_log = get_logger("api")
|
||||||
UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads"))
|
UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads"))
|
||||||
|
|
||||||
|
def _check_session_id(session_id: str) -> None:
    """Reject a malformed ``session_id`` (path-traversal guard).

    The format check itself is delegated to
    ``backend.session.validate_session_id``.

    Args:
        session_id: The session identifier received from the request.

    Raises:
        HTTPException: With status code 400 when the identifier does not
            pass validation.
    """
    # Resolved at call time rather than at module import.
    from backend.session import validate_session_id

    if validate_session_id(session_id):
        return
    raise HTTPException(status_code=400, detail=f"Invalid session_id: {session_id!r}")
|
||||||
|
|
||||||
# ─────────────────────────────────────────────
|
# ─────────────────────────────────────────────
|
||||||
# 图编译(全局单例,带 node_start 回调)
|
# 图编译(全局单例,带 node_start 回调)
|
||||||
# ─────────────────────────────────────────────
|
# ─────────────────────────────────────────────
|
||||||
|
|
||||||
# 当前请求的事件队列(单个用户桌面应用,无并发问题)
|
# 当前请求的事件队列(单个用户桌面应用)
|
||||||
_current_event_queue: Optional[queue.Queue] = None
|
_current_event_queue: Optional[queue.Queue] = None
|
||||||
_step_counter: int = 0
|
_step_counter: contextvars.ContextVar[int] = contextvars.ContextVar('_step_counter', default=0)
|
||||||
|
|
||||||
|
|
||||||
def _on_node_start(node_name: str):
|
def _on_node_start(node_name: str):
|
||||||
"""全局 node_start 回调 — 将事件推入当前请求的事件队列。"""
|
"""全局 node_start 回调 — 将事件推入当前请求的事件队列。"""
|
||||||
global _step_counter
|
|
||||||
q = _current_event_queue
|
q = _current_event_queue
|
||||||
if q is not None:
|
if q is not None:
|
||||||
_step_counter += 1
|
_step_counter.set(_step_counter.get() + 1)
|
||||||
q.put(("node_start", {
|
q.put(("node_start", {
|
||||||
"node": node_name,
|
"node": node_name,
|
||||||
"label": NODE_LABELS.get(node_name, node_name),
|
"label": NODE_LABELS.get(node_name, node_name),
|
||||||
"step_index": _step_counter,
|
"step_index": _step_counter.get(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
@@ -180,14 +186,18 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
|
|||||||
if mode == "updates" and isinstance(data, dict):
|
if mode == "updates" and isinstance(data, dict):
|
||||||
for node_state in data.values():
|
for node_state in data.values():
|
||||||
if isinstance(node_state, dict):
|
if isinstance(node_state, dict):
|
||||||
agent_state.update(node_state)
|
agent_state.update({k: v for k, v in node_state.items() if v is not None})
|
||||||
# 在 graph 完成后立即保存 session,防止 SSE 流中断导致数据丢失
|
# 在 graph 完成后立即保存 session,防止 SSE 流中断导致数据丢失
|
||||||
sid = agent_state.get("session_id", "")
|
sid = agent_state.get("session_id", "")
|
||||||
if sid:
|
if sid:
|
||||||
try:
|
try:
|
||||||
save_session(sid, agent_state)
|
save_session(sid, agent_state)
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
pass # 静默失败,SSE 流中还有一次保存机会
|
_api_log.error("图运行中保存会话失败", extra={
|
||||||
|
"session_id": sid,
|
||||||
|
"error": str(exc),
|
||||||
|
"traceback": traceback.format_exc(),
|
||||||
|
})
|
||||||
event_q.put(("done", {"reason": "graph_completed"}))
|
event_q.put(("done", {"reason": "graph_completed"}))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
event_q.put(("error", {
|
event_q.put(("error", {
|
||||||
@@ -198,9 +208,9 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
|
|||||||
|
|
||||||
async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str:
|
async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str:
|
||||||
"""SSE 事件生成器 —— 在后台线程运行图,异步产出 SSE 字符串。"""
|
"""SSE 事件生成器 —— 在后台线程运行图,异步产出 SSE 字符串。"""
|
||||||
global _current_event_queue, _step_counter
|
global _current_event_queue
|
||||||
|
|
||||||
_step_counter = 0
|
_step_counter.set(0)
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
event_q: queue.Queue = queue.Queue()
|
event_q: queue.Queue = queue.Queue()
|
||||||
_current_event_queue = event_q
|
_current_event_queue = event_q
|
||||||
@@ -347,6 +357,7 @@ async def list_sessions():
|
|||||||
|
|
||||||
@app.get("/api/sessions/{session_id}")
|
@app.get("/api/sessions/{session_id}")
|
||||||
async def get_session(session_id: str):
|
async def get_session(session_id: str):
|
||||||
|
_check_session_id(session_id)
|
||||||
data = get_session_state(session_id)
|
data = get_session_state(session_id)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise HTTPException(status_code=404, detail="会话不存在")
|
raise HTTPException(status_code=404, detail="会话不存在")
|
||||||
@@ -361,6 +372,7 @@ async def get_session(session_id: str):
|
|||||||
|
|
||||||
@app.delete("/api/sessions/{session_id}")
|
@app.delete("/api/sessions/{session_id}")
|
||||||
async def remove_session(session_id: str):
|
async def remove_session(session_id: str):
|
||||||
|
_check_session_id(session_id)
|
||||||
ok = delete_session(session_id)
|
ok = delete_session(session_id)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise HTTPException(status_code=404, detail="会话不存在或已删除")
|
raise HTTPException(status_code=404, detail="会话不存在或已删除")
|
||||||
@@ -373,6 +385,8 @@ async def remove_session(session_id: str):
|
|||||||
|
|
||||||
@app.post("/api/upload")
|
@app.post("/api/upload")
|
||||||
async def upload_file(file: UploadFile = File(...), session_id: str = ""):
|
async def upload_file(file: UploadFile = File(...), session_id: str = ""):
|
||||||
|
if session_id:
|
||||||
|
_check_session_id(session_id)
|
||||||
file_id = uuid.uuid4().hex[:12]
|
file_id = uuid.uuid4().hex[:12]
|
||||||
_ensure_upload_dir(session_id)
|
_ensure_upload_dir(session_id)
|
||||||
|
|
||||||
@@ -492,6 +506,7 @@ async def chat(session_id: str, payload: dict):
|
|||||||
Returns:
|
Returns:
|
||||||
text/event-stream (SSE)
|
text/event-stream (SSE)
|
||||||
"""
|
"""
|
||||||
|
_check_session_id(session_id)
|
||||||
text = payload.get("text", "").strip()
|
text = payload.get("text", "").strip()
|
||||||
file_ids = payload.get("file_ids", [])
|
file_ids = payload.get("file_ids", [])
|
||||||
|
|
||||||
@@ -577,6 +592,7 @@ async def chat(session_id: str, payload: dict):
|
|||||||
@app.get("/api/sessions/{session_id}/download/latest")
|
@app.get("/api/sessions/{session_id}/download/latest")
|
||||||
async def download_latest(session_id: str):
|
async def download_latest(session_id: str):
|
||||||
"""下载最新 JRXML 文件。"""
|
"""下载最新 JRXML 文件。"""
|
||||||
|
_check_session_id(session_id)
|
||||||
data = load_session(session_id)
|
data = load_session(session_id)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise HTTPException(status_code=404, detail="会话不存在")
|
raise HTTPException(status_code=404, detail="会话不存在")
|
||||||
@@ -601,6 +617,7 @@ async def download_latest(session_id: str):
|
|||||||
@app.get("/api/sessions/{session_id}/download/{version}")
|
@app.get("/api/sessions/{session_id}/download/{version}")
|
||||||
async def download_version(session_id: str, version: int):
|
async def download_version(session_id: str, version: int):
|
||||||
"""下载指定版本的 JRXML 文件。"""
|
"""下载指定版本的 JRXML 文件。"""
|
||||||
|
_check_session_id(session_id)
|
||||||
data = load_session(session_id)
|
data = load_session(session_id)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise HTTPException(status_code=404, detail="会话不存在")
|
raise HTTPException(status_code=404, detail="会话不存在")
|
||||||
|
|||||||
+18
-3
@@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@@ -26,12 +27,20 @@ def _ensure_dir():
|
|||||||
SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
|
SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
# A session id is 12 or more ASCII hex digits (either case) and nothing else.
# \A and \Z anchor to the real start and end of the string: a trailing "$"
# would also match just before a final newline and accept "aabbccddeeff\n".
_VALID_SESSION_ID_RE = re.compile(r'\A[a-fA-F0-9]{12,}\Z')


def validate_session_id(session_id: str) -> bool:
    """Return True when ``session_id`` is made up solely of hex digits.

    Used as a path-traversal guard: an accepted id cannot contain path
    separators, dots, whitespace or newlines.

    Args:
        session_id: Candidate identifier. Values that are not ``str``
            are rejected rather than raising.

    Returns:
        True if the whole string is at least 12 hexadecimal characters,
        False otherwise.
    """
    if not isinstance(session_id, str):
        return False
    # fullmatch requires the entire string to be consumed by the pattern.
    return _VALID_SESSION_ID_RE.fullmatch(session_id) is not None
|
||||||
|
|
||||||
def _session_path(session_id: str) -> Path:
|
def _session_path(session_id: str) -> Path:
|
||||||
|
if not validate_session_id(session_id):
|
||||||
|
raise ValueError(f"Invalid session_id: {session_id!r}")
|
||||||
return SESSIONS_DIR / f"{session_id}.json"
|
return SESSIONS_DIR / f"{session_id}.json"
|
||||||
|
|
||||||
|
|
||||||
def generate_session_id() -> str:
|
def generate_session_id() -> str:
|
||||||
return uuid.uuid4().hex[:12]
|
return uuid.uuid4().hex
|
||||||
|
|
||||||
|
|
||||||
def create_session(name: str = "", agent_state: Optional[dict] = None,
|
def create_session(name: str = "", agent_state: Optional[dict] = None,
|
||||||
@@ -58,7 +67,10 @@ def create_session(name: str = "", agent_state: Optional[dict] = None,
|
|||||||
def load_session(session_id: str) -> Optional[dict]:
|
def load_session(session_id: str) -> Optional[dict]:
|
||||||
"""按 ID 加载会话数据。未找到则返回 None。"""
|
"""按 ID 加载会话数据。未找到则返回 None。"""
|
||||||
_ensure_dir()
|
_ensure_dir()
|
||||||
fp = _session_path(session_id)
|
try:
|
||||||
|
fp = _session_path(session_id)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
if not fp.exists():
|
if not fp.exists():
|
||||||
return None
|
return None
|
||||||
with open(fp, "r", encoding="utf-8") as f:
|
with open(fp, "r", encoding="utf-8") as f:
|
||||||
@@ -132,7 +144,10 @@ def list_all_sessions() -> list[dict]:
|
|||||||
def delete_session(session_id: str) -> bool:
|
def delete_session(session_id: str) -> bool:
|
||||||
"""按 ID 删除会话文件。"""
|
"""按 ID 删除会话文件。"""
|
||||||
_ensure_dir()
|
_ensure_dir()
|
||||||
fp = _session_path(session_id)
|
try:
|
||||||
|
fp = _session_path(session_id)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
if fp.exists():
|
if fp.exists():
|
||||||
fp.unlink()
|
fp.unlink()
|
||||||
_session_log.info("删除会话", extra={"session_id": session_id})
|
_session_log.info("删除会话", extra={"session_id": session_id})
|
||||||
|
|||||||
@@ -109,6 +109,11 @@ async function handleSend(text: string, files: File[]) {
|
|||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
chat.setError(e.message || '网络请求失败')
|
chat.setError(e.message || '网络请求失败')
|
||||||
chat.addMessage({ role: 'assistant', content: `请求失败: ${e.message}`, type: 'error' })
|
chat.addMessage({ role: 'assistant', content: `请求失败: ${e.message}`, type: 'error' })
|
||||||
|
chat.finishStreaming({ status: '' })
|
||||||
|
} finally {
|
||||||
|
if (chat.streaming) {
|
||||||
|
chat.finishStreaming({ status: '' })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class TestSessionCRUD:
|
|||||||
resp = client.post("/api/sessions")
|
resp = client.post("/api/sessions")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert len(data["session_id"]) == 12
|
assert len(data["session_id"]) == 32
|
||||||
assert "session_name" in data
|
assert "session_name" in data
|
||||||
assert "created_at" in data
|
assert "created_at" in data
|
||||||
|
|
||||||
@@ -78,8 +78,11 @@ class TestSessionCRUD:
|
|||||||
assert resp.json()["session_id"] == created["session_id"]
|
assert resp.json()["session_id"] == created["session_id"]
|
||||||
assert "agent_state" in resp.json()
|
assert "agent_state" in resp.json()
|
||||||
|
|
||||||
|
def test_get_session_invalid_id(self, client, temp_sessions):
    """A session id that is not a hex string is answered with HTTP 400."""
    response = client.get("/api/sessions/nonexistent")
    assert response.status_code == 400
|
||||||
|
|
||||||
def test_get_session_not_found(self, client, temp_sessions):
|
def test_get_session_not_found(self, client, temp_sessions):
|
||||||
assert client.get("/api/sessions/nonexistent").status_code == 404
|
assert client.get("/api/sessions/aabbccddeeff0011223344").status_code == 404
|
||||||
|
|
||||||
def test_delete_session(self, client, temp_sessions):
|
def test_delete_session(self, client, temp_sessions):
|
||||||
sid = client.post("/api/sessions").json()["session_id"]
|
sid = client.post("/api/sessions").json()["session_id"]
|
||||||
@@ -89,7 +92,7 @@ class TestSessionCRUD:
|
|||||||
assert client.get(f"/api/sessions/{sid}").status_code == 404
|
assert client.get(f"/api/sessions/{sid}").status_code == 404
|
||||||
|
|
||||||
def test_delete_nonexistent(self, client, temp_sessions):
|
def test_delete_nonexistent(self, client, temp_sessions):
|
||||||
assert client.delete("/api/sessions/ghost_id").status_code == 404
|
assert client.delete("/api/sessions/aabbccddeeff0011223344").status_code == 404
|
||||||
|
|
||||||
def test_full_crud_lifecycle(self, client, temp_sessions):
|
def test_full_crud_lifecycle(self, client, temp_sessions):
|
||||||
sid = client.post("/api/sessions").json()["session_id"]
|
sid = client.post("/api/sessions").json()["session_id"]
|
||||||
@@ -116,7 +119,7 @@ class TestFileUpload:
|
|||||||
|
|
||||||
def test_upload_with_session_id_in_query(self, client, temp_sessions):
|
def test_upload_with_session_id_in_query(self, client, temp_sessions):
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/upload?session_id=abc123",
|
"/api/upload?session_id=aabbccddeeff0011223344",
|
||||||
files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")},
|
files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")},
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
@@ -146,7 +149,7 @@ class TestFileUpload:
|
|||||||
|
|
||||||
class TestDownload:
|
class TestDownload:
|
||||||
def test_download_missing_session_returns_404(self, client, temp_sessions):
|
def test_download_missing_session_returns_404(self, client, temp_sessions):
|
||||||
assert client.get("/api/sessions/missing/download/latest").status_code == 404
|
assert client.get("/api/sessions/aabbccddeeff0011223344/download/latest").status_code == 404
|
||||||
|
|
||||||
def test_download_no_jrxml_returns_404(self, client, temp_sessions):
|
def test_download_no_jrxml_returns_404(self, client, temp_sessions):
|
||||||
sid = client.post("/api/sessions").json()["session_id"]
|
sid = client.post("/api/sessions").json()["session_id"]
|
||||||
@@ -212,7 +215,7 @@ class TestChatSSE:
|
|||||||
def test_auto_creates_session_on_chat(self, client, temp_sessions):
|
def test_auto_creates_session_on_chat(self, client, temp_sessions):
|
||||||
with client.stream(
|
with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
"/api/sessions/auto_new_session/chat",
|
"/api/sessions/aabbccddeeff0011223344/chat",
|
||||||
json={"text": "生成报表", "file_ids": []},
|
json={"text": "生成报表", "file_ids": []},
|
||||||
) as resp:
|
) as resp:
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
@@ -231,16 +234,17 @@ class TestChatSSE:
|
|||||||
# ── 边界 & 安全测试 ────────────────────────────────────────────
|
# ── 边界 & 安全测试 ────────────────────────────────────────────
|
||||||
|
|
||||||
class TestBoundaries:
|
class TestBoundaries:
|
||||||
def test_session_id_path_traversal_returns_404(self, client, temp_sessions):
|
def test_session_id_invalid_format_returns_400(self, client, temp_sessions):
|
||||||
assert client.get("/api/sessions/../etc/passwd").status_code == 404
|
"""非 hex 字符的 session_id 应被拒绝。"""
|
||||||
|
assert client.get("/api/sessions/not_valid_hex_id").status_code == 400
|
||||||
|
|
||||||
def test_upload_with_path_traversal_session_id(self, client, temp_sessions):
|
def test_upload_with_path_traversal_session_id(self, client, temp_sessions):
|
||||||
"""路径穿越 session_id 仍正常处理(目录隔离在 UPLOADS_DIR 内)。"""
|
"""路径穿越 session_id 被拒绝。"""
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/upload?session_id=../malicious",
|
"/api/upload?session_id=../malicious",
|
||||||
files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")},
|
files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")},
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 400
|
||||||
|
|
||||||
def test_invalid_json_body_rejected(self, client, temp_sessions):
|
def test_invalid_json_body_rejected(self, client, temp_sessions):
|
||||||
sid = client.post("/api/sessions").json()["session_id"]
|
sid = client.post("/api/sessions").json()["session_id"]
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def temp_sessions_dir(monkeypatch):
|
|||||||
class TestCreateSession:
|
class TestCreateSession:
|
||||||
def test_creates_with_defaults(self, temp_sessions_dir):
|
def test_creates_with_defaults(self, temp_sessions_dir):
|
||||||
s = create_session()
|
s = create_session()
|
||||||
assert len(s["session_id"]) == 12
|
assert len(s["session_id"]) == 32
|
||||||
assert "新建报表" in s["session_name"]
|
assert "新建报表" in s["session_name"]
|
||||||
assert s["created_at"]
|
assert s["created_at"]
|
||||||
assert s["updated_at"]
|
assert s["updated_at"]
|
||||||
@@ -139,7 +139,7 @@ class TestSaveSession:
|
|||||||
assert load_session(created["session_id"])["session_name"] == "原名"
|
assert load_session(created["session_id"])["session_name"] == "原名"
|
||||||
|
|
||||||
def test_fills_missing_created_at(self, temp_sessions_dir):
|
def test_fills_missing_created_at(self, temp_sessions_dir):
|
||||||
sid = "test_no_created"
|
sid = "aaaabbbbccccddddeeeeffff"
|
||||||
fp = temp_sessions_dir / f"{sid}.json"
|
fp = temp_sessions_dir / f"{sid}.json"
|
||||||
fp.write_text(
|
fp.write_text(
|
||||||
json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"
|
json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"
|
||||||
|
|||||||
Reference in New Issue
Block a user