fix: session persistence, multi-turn memory, OCR pipeline, download UX (v7)
- graph.stream() state fix: agent_state now properly accumulates node updates - atomic session save (tempfile + os.replace) - uploaded_file_path injection for OcrExtractor + annotation_detector - download section always visible; refreshFromApi auto-reloads after generation - node_start/complete unfiltered for full progress visibility - modification_request without status=='pass' check
This commit is contained in:
+25
-14
@@ -166,10 +166,21 @@ def _extract_detail(node_name: str, node_state: dict) -> str:
|
||||
|
||||
|
||||
def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
|
||||
"""在后台线程中运行 graph.stream(),将所有事件推入队列。"""
|
||||
"""在后台线程中运行 graph.stream(),将所有事件推入队列。
|
||||
|
||||
graph.stream() 只产出事件,不修改传入的 agent_state。
|
||||
因此需要手动收集每个节点的返回并合并到 agent_state。
|
||||
"""
|
||||
try:
|
||||
for event in _graph.stream(agent_state, stream_mode=["updates", "custom"]):
|
||||
event_q.put(event)
|
||||
# 将节点更新合并到 agent_state
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
if mode == "updates" and isinstance(data, dict):
|
||||
for node_state in data.values():
|
||||
if isinstance(node_state, dict):
|
||||
agent_state.update(node_state)
|
||||
event_q.put(("done", {"reason": "graph_completed"}))
|
||||
except Exception as exc:
|
||||
event_q.put(("error", {
|
||||
@@ -178,7 +189,7 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
|
||||
}))
|
||||
|
||||
|
||||
async def _sse_generator(agent_state: AgentState) -> str:
|
||||
async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str:
|
||||
"""SSE 事件生成器 —— 在后台线程运行图,异步产出 SSE 字符串。"""
|
||||
global _current_event_queue, _step_counter
|
||||
|
||||
@@ -205,6 +216,8 @@ async def _sse_generator(agent_state: AgentState) -> str:
|
||||
if kind == "done":
|
||||
_current_event_queue = None
|
||||
total_ms = round((time.time() - t_start) * 1000)
|
||||
if session_id:
|
||||
save_session(session_id, agent_state)
|
||||
yield _sse_line("agent_complete", {
|
||||
"reason": "done",
|
||||
"intent": agent_state.get("intent", ""),
|
||||
@@ -233,13 +246,12 @@ async def _sse_generator(agent_state: AgentState) -> str:
|
||||
mode, data = item
|
||||
if mode == "updates":
|
||||
for node_name, node_state in data.items():
|
||||
if node_name not in SKIP_NODES:
|
||||
detail = _extract_detail(node_name, node_state)
|
||||
yield _sse_line("node_complete", {
|
||||
"node": node_name,
|
||||
"label": NODE_LABELS.get(node_name, node_name),
|
||||
"detail": detail,
|
||||
})
|
||||
detail = _extract_detail(node_name, node_state)
|
||||
yield _sse_line("node_complete", {
|
||||
"node": node_name,
|
||||
"label": NODE_LABELS.get(node_name, node_name),
|
||||
"detail": detail,
|
||||
})
|
||||
elif mode == "custom":
|
||||
cd = data
|
||||
if cd.get("type") == "stream":
|
||||
@@ -500,9 +512,11 @@ async def chat(session_id: str, payload: dict):
|
||||
for line in file_result["ocr_text"].split("\n") if line.strip()]
|
||||
if ocr_rows:
|
||||
agent_state["ocr_elements"] = ocr_rows
|
||||
if file_result.get("uploaded_paths"):
|
||||
agent_state["uploaded_file_path"] = file_result["uploaded_paths"][0]
|
||||
|
||||
# ── 设置本轮输入 ──
|
||||
if agent_state.get("current_jrxml") and agent_state.get("status") == "pass":
|
||||
if agent_state.get("current_jrxml"):
|
||||
agent_state["user_modification_request"] = full_prompt
|
||||
|
||||
agent_state["user_input"] = full_prompt
|
||||
@@ -519,12 +533,9 @@ async def chat(session_id: str, payload: dict):
|
||||
# ── 返回 SSE 流 ──
|
||||
async def stream_and_save():
|
||||
final_state = None
|
||||
async for sse_chunk in _sse_generator(agent_state):
|
||||
async for sse_chunk in _sse_generator(agent_state, session_id):
|
||||
yield sse_chunk
|
||||
|
||||
# 图执行完成后保存会话状态
|
||||
save_session(session_id, agent_state)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_and_save(),
|
||||
media_type="text/event-stream",
|
||||
|
||||
Reference in New Issue
Block a user