"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。
支持:
- 流式输出(LLM 逐字展示)
- 节点平铺展开(每个处理阶段独立展示)
- 完成后自动折叠节点区
- 过程总结卡片
"""
import os
import sys
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
try:
import torchvision
except Exception:
pass
import time
from pathlib import Path
import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv
load_dotenv()
from agent.graph import build_graph, create_initial_state
from backend.session import (
create_session,
load_session,
delete_session,
list_all_sessions,
)
from backend.logger import get_logger, set_trace_id, generate_trace_id
# Module-wide logger for the Streamlit front end.
_app_log = get_logger("app")

# Page-level settings: title, icon, wide layout, sidebar expanded on load.
st.set_page_config(
    page_title="JRXML 代理",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Intended to stop Streamlit's bare "c" shortcut from clearing the cache while
# keeping Ctrl+C copy working.
# NOTE(review): the injected markup is empty in this revision, so no handler is
# actually installed — confirm whether the script body was lost.
st.html("""
""")
# ---- Graph node name -> label shown in the progress list ----
NODE_LABELS = {
    # Session bookkeeping
    "load_session": "📂 加载会话",
    "process_input": "📝 记录输入",
    "manage_context": "🧠 管理上下文",
    "save_state_snapshot": "💾 保存快照",
    # Generation pipeline
    "classify_intent": "🔍 识别意图",
    "retrieve": "📚 检索模板",
    "generate": "⚙️ 生成 JRXML",
    "modify_jrxml": "🔧 修改 JRXML",
    "validate": "✅ 验证",
    "explain_error": "🔎 分析错误",
    "correct_jrxml": "🛠 自动修正",
    "finalize": "📋 完成",
    # Other handlers
    "handle_consult": "💬 咨询回答",
    "handle_undo": "↩ 撤销操作",
    "handle_reset": "🔄 重置会话",
    "save_session": "💾 保存会话",
}
# Intent identifier -> label shown next to the "classify_intent" progress entry.
INTENT_LABELS = {
    "initial_generation": "新建报表",
    "modify_report": "修改报表",
    "preview_report": "预览报表",
    "export_pdf": "导出 PDF",
    "export_jrxml": "下载 JRXML",
    "undo_modification": "撤销修改",
    "consult_question": "咨询问题",
    "reset_session": "重置会话",
}
# Nodes that are left out of the progress list.
SKIP_NODES = {
    "load_session",
    "process_input",
    "manage_context",
    "save_state_snapshot",
    "save_session",
}
def _render_jrxml(jrxml: str, max_lines: int = 30):
    """Show JRXML source as an XML code block, cut off after ``max_lines`` lines.

    Surrounding whitespace is stripped first. When the text has more than
    ``max_lines`` lines, a trailing marker line carrying the total line count
    is appended to the preview.
    """
    all_lines = jrxml.strip().split("\n")
    total = len(all_lines)
    body = "\n".join(all_lines[:max_lines])
    marker = f"\n... (共 {total} 行)" if total > max_lines else ""
    st.code(body + marker, language="xml")
# ---- URL parameters ----
query_params = st.query_params
url_session_id = query_params.get("session_id", "")

# ---- Session-state initialisation ----
# Each key is only created when it is missing from st.session_state.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "graph" not in st.session_state:
    st.session_state.graph = build_graph()
if "pending_action" not in st.session_state:
    st.session_state.pending_action = None
if "chat_attached_files" not in st.session_state:
    st.session_state.chat_attached_files = []  # [{name, text, type, path}]
if "_paste_processed_ts" not in st.session_state:
    st.session_state._paste_processed_ts = 0

if "agent_state" not in st.session_state:
    # Try to restore the session named in the URL; anything else (no id,
    # unknown id, stored record without an agent state) starts a new session.
    restored = load_session(url_session_id) if url_session_id else None
    if restored and restored.get("agent_state"):
        st.session_state.agent_state = restored["agent_state"]
        st.session_state.agent_state["session_id"] = url_session_id
    else:
        st.session_state.agent_state = create_initial_state()
        new_data = create_session(name="", agent_state=st.session_state.agent_state)
        # Copy the identifiers assigned by the session store into the state.
        for _session_key in ("session_id", "session_name", "created_at"):
            st.session_state.agent_state[_session_key] = new_data[_session_key]

current_session_id = st.session_state.agent_state.get("session_id", "")
def _node_detail(node_name: str, node_state: dict) -> str:
    """Return the short detail text shown next to a finished node.

    Nodes without a dedicated summary yield an empty string. Missing or
    ``None`` fields are treated as empty so a sparse update cannot raise.
    """
    if node_name == "classify_intent":
        intent = node_state.get("intent") or ""
        return f"意图: {INTENT_LABELS.get(intent, intent)}"
    if node_name == "retrieve":
        ctx = node_state.get("retrieved_context") or ""
        return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
    if node_name in ("generate", "modify_jrxml", "correct_jrxml"):
        jrxml = node_state.get("current_jrxml") or ""
        return f"生成 {len(jrxml)} 字符 JRXML"
    if node_name == "validate":
        if node_state.get("status", "") == "pass":
            return "验证通过 ✓"
        err = node_state.get("error_msg") or ""
        return f"验证失败: {err[:80]}"
    if node_name == "explain_error":
        return (node_state.get("natural_explanation") or "")[:120]
    if node_name == "handle_consult":
        return (node_state.get("consult_answer") or "")[:150]
    return ""


def _render_ocr_result(ocr_result: dict) -> None:
    """Render the OCR field-extraction expander; no-op without usable fields."""
    if not (ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields")):
        return
    with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
        fields = ocr_result.get("fields", [])
        non_empty = [f for f in fields if f.get("field_value")]
        empty = [f for f in fields if not f.get("field_value")]
        if non_empty:
            st.markdown("**已提取字段:**")
            for f in non_empty:
                method = f.get("extraction_method", "")
                # A missing/None confidence would break the percent format.
                conf = f.get("confidence") or 0
                st.markdown(
                    f"- **{f.get('field_name', '')}**: `{f['field_value']}` "
                    f"(置信度: {conf:.0%}, 方法: {method})"
                )
        if empty:
            empty_names = ", ".join(str(f.get("field_name", "")) for f in empty)
            st.caption(f"未提取到值的字段: {empty_names}")
        st.caption(
            f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素"
        )


def run_agent(user_input: str):
    """Run the agent graph for one user turn and render progress and outcome.

    Streams the graph stored in ``st.session_state.graph`` with the current
    ``st.session_state.agent_state``. Node updates are shown as a live
    progress list and streamed text chunks as an XML code block. Afterwards
    the temporary placeholders are cleared, the paths listed in
    ``st.session_state.uploaded_temp_paths`` are deleted, a summary card is
    rendered and the matching assistant messages are appended to
    ``st.session_state.messages``.

    Args:
        user_input: Full prompt for this turn (may already include the text
            of attached files).

    Returns:
        None. Any exception raised while streaming is logged and shown with
        ``st.error`` and the function returns early.
    """
    trace_id = generate_trace_id()
    set_trace_id(trace_id)
    agent_state = st.session_state.agent_state
    session_id = agent_state.get("session_id", "")
    _app_log.info(
        "代理执行开始",
        extra={
            "session_id": session_id,
            "trace_id": trace_id,
            "user_input_preview": user_input[:200],
            "user_input_length": len(user_input),
            "has_jrxml": bool((agent_state.get("current_jrxml") or "").strip()),
            "intent": agent_state.get("intent", ""),
        },
    )
    # When a validated report already exists, the input is additionally
    # recorded as the modification request.
    if agent_state.get("current_jrxml") and agent_state.get("status") == "pass":
        agent_state["user_modification_request"] = user_input
    agent_state["user_input"] = user_input
    agent_state["retry_count"] = 0

    # ---- UI placeholders ----
    progress_placeholder = st.empty()   # live node progress
    streaming_placeholder = st.empty()  # streamed text
    summary_placeholder = st.empty()    # summary card
    progress_placeholder.info("⏳ 正在分析您的需求...")

    executed_nodes: list[dict] = []
    stream_text = ""
    stream_active = False

    def _render_progress(nodes: list[dict]):
        """Write the node list to the progress placeholder (last entry = current)."""
        if not nodes:
            return
        lines = []
        for i, node in enumerate(nodes):
            icon = "●" if i == len(nodes) - 1 else "✓"
            detail = f" — {node['detail']}" if node.get("detail") else ""
            lines.append(f"{icon} {node['label']}{detail}")
        progress_placeholder.markdown("\n\n".join(lines))

    try:
        for mode, data in st.session_state.graph.stream(
            agent_state, stream_mode=["updates", "custom"]
        ):
            if mode == "updates":
                for node_name, node_state in data.items():
                    # An update is not guaranteed to be a dict (e.g. a node
                    # that returns nothing); treat those as "no fields changed".
                    if not isinstance(node_state, dict):
                        node_state = {}
                    if node_name not in SKIP_NODES:
                        executed_nodes.append({
                            "name": node_name,
                            "label": NODE_LABELS.get(node_name, node_name),
                            "detail": _node_detail(node_name, node_state),
                        })
                    # Refresh the progress list after every finished node.
                    _render_progress(executed_nodes)
            elif mode == "custom":
                if isinstance(data, dict) and data.get("type") == "stream":
                    stream_text += data.get("text", "") or ""
                    stream_active = True
                    streaming_placeholder.code(stream_text, language="xml")
    except Exception as e:
        progress_placeholder.empty()
        _app_log.error(
            f"代理执行异常: {e}",
            extra={"session_id": session_id, "error": str(e)},
            exc_info=True,  # keep the traceback in the log
        )
        st.error(f"工作流异常: {e}")
        return

    # ---- Clear temporary placeholders ----
    progress_placeholder.empty()
    if stream_active:
        streaming_placeholder.empty()

    # Remove temporary upload files that have now been processed.
    for p in st.session_state.get("uploaded_temp_paths", []):
        try:
            Path(p).unlink(missing_ok=True)
        except Exception:
            pass  # best-effort cleanup
    st.session_state.uploaded_temp_paths = []

    # ---- Summary card ----
    # Per-node updates only carry changed fields, so the full state is read
    # from agent_state.
    # NOTE(review): this relies on the graph nodes mutating agent_state in
    # place — confirm against the graph implementation.
    final_state = agent_state
    if final_state:
        st.session_state.agent_state = final_state
        intent = final_state.get("intent", "")
        status = final_state.get("status", "")
        with summary_placeholder.container(border=True):
            if intent == "consult_question":
                answer = final_state.get("consult_answer", "")
                st.info(answer)
                st.session_state.messages.append({
                    "role": "assistant", "content": answer, "type": "consult",
                })
            elif intent in ("undo_modification", "reset_session"):
                st.success("操作已完成")
            elif intent in ("preview_report", "export_pdf", "export_jrxml"):
                jrxml = final_state.get("current_jrxml", "")
                if jrxml:
                    st.success("✅ 当前报表")
                    _render_jrxml(jrxml)
                    st.session_state.messages.append({
                        "role": "assistant", "content": jrxml, "type": "jrxml",
                    })
                else:
                    st.warning("⚠ 当前没有报表可以展示。")
            elif status == "pass":
                jrxml = final_state.get("current_jrxml") or ""
                st.success("✅ JRXML 生成成功")
                st.markdown("**生成结果:**")
                _render_jrxml(jrxml)
                st.caption("您可以从侧边栏下载文件,或继续对话进行修改。")
                st.session_state.messages.append({
                    "role": "assistant", "content": jrxml, "type": "jrxml",
                })
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。",
                    "type": "success",
                })
            else:
                jrxml = final_state.get("current_jrxml", "")
                error_msg = final_state.get("error_msg", "未知错误")
                explanation = final_state.get("natural_explanation", "")
                retries = final_state.get("retry_count", 0)
                st.error(f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML")
                st.markdown(f"**错误:** {error_msg}")
                if explanation:
                    st.markdown(f"**原因:** {explanation}")
                if jrxml:
                    with st.expander("查看当前 JRXML"):
                        _render_jrxml(jrxml, max_lines=80)
                st.caption("💡 下次输入修改需求时,系统会自动加载失败上下文继续修复。")
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。",
                    "type": "error_explanation",
                })
            # OCR field-extraction results, when the state carries any.
            _render_ocr_result(agent_state.get("ocr_extraction_result") or {})
    else:
        st.error("未产生结果,请重试。")

    _app_log.info(
        "代理执行完成",
        extra={
            "session_id": session_id,
            "intent": final_state.get("intent", ""),
            "status": final_state.get("status", ""),
            "jrxml_length": len(final_state.get("current_jrxml") or ""),
            "retry_count": final_state.get("retry_count", 0),
        },
    )
# ---- Sidebar ----
with st.sidebar:
    st.title("📊 JRXML 代理")
    st.markdown("通过自然语言生成 JasperReports 模板。")
    st.divider()

    # ---- Session management ----
    st.markdown("### 会话管理")
    sessions = list_all_sessions()
    session_options = {}  # display label -> session id
    for s in sessions:
        sid = s["session_id"]
        name = s.get("session_name", sid)
        updated = (s.get("updated_at") or "")[:16]
        label = f"{name} ({updated})"
        if label in session_options:
            # Two sessions with the same name and timestamp would overwrite
            # each other in the mapping and one would become unreachable.
            label = f"{label} [{sid}]"
        session_options[label] = sid

    option_labels = list(session_options)
    selected_label = next(
        (lbl for lbl, sid in session_options.items() if sid == current_session_id),
        None,
    )
    selected = st.selectbox(
        "切换会话",
        options=option_labels,
        index=option_labels.index(selected_label) if selected_label else 0,
        key="session_selector",
    )
    if selected and session_options.get(selected) != current_session_id:
        new_sid = session_options[selected]
        if st.session_state.get("_last_switched_to") == new_sid:
            # Guard against switching to the same session repeatedly, which
            # would otherwise cause an endless rerun loop.
            st.session_state._last_switched_to = ""
        else:
            data = load_session(new_sid)
            if data and data.get("agent_state"):
                _app_log.info(
                    "切换会话",
                    extra={"from_session": current_session_id, "to_session": new_sid},
                )
                data["agent_state"]["session_id"] = new_sid
                st.session_state.agent_state = data["agent_state"]
                st.session_state.messages = []
                st.session_state._last_switched_to = new_sid
                st.rerun()

    col1, col2 = st.columns(2)
    with col1:
        if st.button("➕ 新建", use_container_width=True):
            new_data = create_session(name="", agent_state=create_initial_state())
            _app_log.info(
                "新建会话",
                extra={"session_id": new_data["session_id"]},
            )
            st.session_state.agent_state = create_initial_state()
            st.session_state.agent_state["session_id"] = new_data["session_id"]
            st.session_state.agent_state["session_name"] = new_data["session_name"]
            st.session_state.agent_state["created_at"] = new_data["created_at"]
            st.session_state.messages = []
            st.rerun()
    with col2:
        if st.button("🗑 删除", use_container_width=True):
            if current_session_id:
                _app_log.info(
                    "删除会话",
                    extra={"session_id": current_session_id},
                )
                delete_session(current_session_id)
                # Replace the deleted session with a freshly created one.
                st.session_state.agent_state = create_initial_state()
                new_data = create_session(name="", agent_state=st.session_state.agent_state)
                st.session_state.agent_state["session_id"] = new_data["session_id"]
                st.session_state.agent_state["session_name"] = new_data["session_name"]
                st.session_state.agent_state["created_at"] = new_data["created_at"]
                st.session_state.messages = []
                st.rerun()

    current_name = st.session_state.agent_state.get("session_name", "")
    st.caption(f"当前: {current_name} (`{current_session_id}`)")
    st.divider()

    # ---- Quick actions ----
    st.markdown("### 快捷操作")
    has_jrxml = bool((st.session_state.agent_state.get("current_jrxml") or "").strip())
    has_history = bool(st.session_state.agent_state.get("history_states", []))
    qcol1, qcol2 = st.columns(2)
    with qcol1:
        if st.button("👁 预览", use_container_width=True, disabled=not has_jrxml):
            with st.spinner("正在准备预览..."):
                run_agent("预览报表")
            st.rerun()
    with qcol2:
        if st.button("↩ 撤销", use_container_width=True, disabled=not has_history):
            with st.spinner("正在撤销..."):
                run_agent("撤销上一步修改")
            st.rerun()
    if st.button("🔄 重置会话", use_container_width=True):
        with st.spinner("正在重置..."):
            run_agent("重新来,清空当前报表")
        st.rerun()
    st.divider()

    # ---- File upload ----
    st.markdown("### 上传文件")
    st.caption("支持图片 (OCR)、PDF、Word、文本文件。内容将附加到您的下一条消息中。")
    if "uploaded_files" not in st.session_state:
        st.session_state.uploaded_files = []  # [{name, text, type}]
    if "uploaded_temp_paths" not in st.session_state:
        st.session_state.uploaded_temp_paths = []  # temp files awaiting cleanup
    # The uploader keeps its files across reruns. With a fixed key, every
    # rerun after a message was sent would parse the same files again and
    # attach them to the following message too, and the remove button below
    # could never really remove an entry. Rotating the widget key after each
    # batch gives an empty uploader once the files have been taken over.
    if "sidebar_uploader_nonce" not in st.session_state:
        st.session_state.sidebar_uploader_nonce = 0
    uploaded = st.file_uploader(
        "选择文件",
        type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "xlsx", "txt", "csv", "json", "xml"],
        accept_multiple_files=True,
        key=f"file_uploader_{st.session_state.sidebar_uploader_nonce}",
        label_visibility="collapsed",
    )
    if uploaded:
        # Lazy imports: only needed once something is uploaded.
        import tempfile
        from backend.file_parser import parse_file
        from backend.layout_analyzer import analyze_layout

        img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
        for uf in uploaded:
            # De-duplicate by file name.
            if any(f["name"] == uf.name for f in st.session_state.uploaded_files):
                continue
            suffix = Path(uf.name).suffix.lower()
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(uf.getvalue())
                tmp_path = tmp.name
            keep_tmp = False
            try:
                result = parse_file(tmp_path, suffix)
                parsed_text = result["text"]
                parsed_type = result["file_type"]
                # Images and PDFs additionally go through A4 layout analysis.
                if suffix in img_suffixes + (".pdf",):
                    layout = analyze_layout(tmp_path)
                    tt = layout.get("template_type", "unknown")
                    current_jrxml = st.session_state.agent_state.get("current_jrxml", "")
                    if tt == "full_a4":
                        parsed_text = layout["description"]
                        parsed_type = "a4_template"
                    elif tt == "partial_rows":
                        parsed_type = "a4_partial"
                        if current_jrxml.strip():
                            # Modification mode: try to match the rows
                            # against the current report.
                            from backend.layout_analyzer import match_rows_to_jrxml
                            row_match = match_rows_to_jrxml(layout, current_jrxml)
                            parsed_text = (
                                f"[行片段修改] 上传图片包含 {layout['total_rows']} 行,"
                                f"视为 A4 报表的一部分。\n\n"
                                f"{row_match['description']}\n\n"
                                f"--- 行结构 ---\n{layout['description']}"
                            )
                        else:
                            # New-report mode: handle like an A4 template.
                            parsed_text = layout["description"]
                    else:
                        # tt == "unknown": OCR unavailable or no text
                        # elements detected.
                        has_ocr = result.get("method") not in ("metadata_only", None)
                        img_w, img_h = layout["image_size"]
                        ratio = layout["aspect_ratio"]
                        if has_ocr:
                            parsed_text = (
                                f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。"
                                f"未检测到 A4 报表结构,图片将被视为参考样式。\n"
                                f"请根据用户的文字描述生成报表。"
                            )
                        else:
                            parsed_text = (
                                f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。\n"
                                f"⚠ OCR 引擎未安装,无法识别图片中的文字内容。\n"
                                f"请严格根据用户的文字描述来推断图片中的报表需求。\n"
                                f"(提示:如需图片文字识别,请运行 pip install paddleocr)"
                            )
                        parsed_type = "image_reference"
                if parsed_text:
                    st.session_state.uploaded_files.append({
                        "name": uf.name,
                        "text": parsed_text,
                        "type": parsed_type,
                    })
                # For images, keep the file so OCR field extraction can run
                # later; run_agent deletes it after the turn.
                if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
                    st.session_state.agent_state["uploaded_file_path"] = tmp_path
                    st.session_state.uploaded_temp_paths.append(tmp_path)
                    keep_tmp = True
            finally:
                # Also runs when parsing/analysis raises, so the temp file
                # is never left behind.
                if not keep_tmp:
                    Path(tmp_path).unlink(missing_ok=True)
        # Hand over to a fresh (empty) uploader widget; see comment above.
        st.session_state.sidebar_uploader_nonce += 1
        st.rerun()

    if st.session_state.uploaded_files:
        for i, f in enumerate(st.session_state.uploaded_files):
            cols = st.columns([5, 1])
            with cols[0]:
                st.caption(f"📎 {f['name']} ({f['type']}, {len(f['text'])} 字符)")
            with cols[1]:
                if st.button("✕", key=f"rm_uf_{i}", help="移除"):
                    st.session_state.uploaded_files.pop(i)
                    st.rerun()
    st.divider()

    # ---- Configuration (read from environment variables) ----
    st.markdown("### 配置")
    llm_backend = os.getenv("LLM_BACKEND", "cloud")
    llm_model = os.getenv("LLM_MODEL", os.getenv("LOCAL_LLM_MODEL", "gpt-4o"))
    st.caption(f"大语言模型: {llm_backend} / {llm_model}")
    st.caption(f"最大重试次数: {os.getenv('MAX_RETRY', '3')}")
    st.caption(f"验证服务: {os.getenv('VALIDATION_SERVICE_URL', 'http://localhost:8001/validate')}")
    st.divider()

    # ---- Downloads ----
    st.markdown("### 下载")
    final = st.session_state.agent_state.get("final_jrxml", "")
    versions = st.session_state.agent_state.get("jrxml_versions", [])
    if final:
        st.download_button(
            label="📥 下载最新 JRXML",
            data=final,
            file_name="report.jrxml",
            mime="application/xml",
            use_container_width=True,
        )
    if versions:
        with st.expander("📋 历史版本", expanded=False):
            # Newest first; the version number counts down from len(versions).
            for i, v in enumerate(reversed(versions)):
                version_no = len(versions) - i
                ts = (v.get("ts") or "")[:16]
                label = v.get("label", "版本")
                status = v.get("status", "")
                icon = "✅" if status == "pass" else "❌"
                dl_label = f"{icon} v{version_no} — {label} ({ts})"
                st.download_button(
                    label=dl_label,
                    data=v.get("jrxml", ""),
                    file_name=f"report_v{version_no}.jrxml",
                    mime="application/xml",
                    use_container_width=True,
                    key=f"dl_v{i}",
                )
# ---- Global paste / drag-and-drop handler ----
# NOTE(review): the injected markup is empty in this revision — confirm whether
# the handler script was lost.
st.html("""
""")

# ---- Paste bridge component ----
# The value read back here is consumed by the paste handler that follows.
# NOTE(review): the documented `components.html` signature has no `default`
# parameter and does not return data to Python — confirm this call works with
# the installed Streamlit version.
paste_data = components.html("""
""", height=0, default=0)
# Take over files delivered by the paste bridge. The payload comes from the
# browser, so its shape is validated instead of trusted.
if isinstance(paste_data, dict) and paste_data:
    pts = paste_data.get("ts", 0)
    # Each payload carries a timestamp; only payloads newer than the last
    # processed one are handled, so a rerun does not re-import the same paste.
    if pts > st.session_state._paste_processed_ts:
        st.session_state._paste_processed_ts = pts
        import base64
        import binascii
        import tempfile
        from backend.file_parser import parse_file
        from backend.layout_analyzer import analyze_layout

        img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
        for fi in paste_data.get("files", []):
            # De-duplicate by file name.
            if any(f["name"] == fi["name"] for f in st.session_state.chat_attached_files):
                continue
            # Expected form: "<header>,<base64 data>".
            _header, sep, b64 = str(fi.get("data", "")).partition(",")
            try:
                if not sep:
                    raise ValueError("missing ',' separator in data URL")
                raw = base64.b64decode(b64)
            except (binascii.Error, ValueError) as exc:
                _app_log.error(
                    f"粘贴文件数据无效: {exc}",
                    extra={"file_name": fi["name"], "error": str(exc)},
                )
                continue
            suffix = Path(fi["name"]).suffix.lower()
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(raw)
                tmp_path = tmp.name
            try:
                result = parse_file(tmp_path, suffix)
            except Exception:
                # Do not leave the temp file behind when parsing fails.
                Path(tmp_path).unlink(missing_ok=True)
                raise
            text = result["text"]
            file_type = result["file_type"]
            if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
                try:
                    layout = analyze_layout(tmp_path)
                    tt = layout.get("template_type", "unknown")
                    if tt == "full_a4":
                        text = layout["description"]
                        file_type = "a4_template"
                    elif tt == "partial_rows":
                        file_type = "a4_partial"
                except Exception as exc:
                    # Layout analysis is best-effort: fall back to the plain
                    # parse result, but leave a trace in the log.
                    _app_log.error(
                        f"布局分析失败,使用原始解析结果: {exc}",
                        extra={"file_name": fi["name"], "error": str(exc)},
                    )
            st.session_state.chat_attached_files.append({
                "name": fi["name"], "text": text, "type": file_type, "path": tmp_path
            })
        st.rerun()
# ---- Title ----
st.title("📝 JRXML 报表生成器")
st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。")

# ---- Chat history ----
# Message type -> Streamlit call used to render its content; any other type
# (or none) is rendered as markdown. JRXML goes into a collapsed expander.
history_renderers = {
    "error_explanation": st.warning,
    "success": st.success,
    "consult": st.info,
}
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        msg_type = msg.get("type")
        if msg_type == "jrxml":
            with st.expander("查看生成的 JRXML", expanded=False):
                st.code(msg["content"], language="xml")
        else:
            history_renderers.get(msg_type, st.markdown)(msg["content"])
# ---- Preview of the files attached in the chat area ----
if st.session_state.chat_attached_files:
    attached = st.session_state.chat_attached_files
    # At most four chips per row; further files wrap around the columns.
    chip_cols = st.columns(min(len(attached), 4))
    emoji_map = {"a4_template": "📷", "image": "🖼", "pdf": "📄", "docx": "📝", "xlsx": "📊"}
    files_to_remove = []
    for i, f in enumerate(attached):
        with chip_cols[i % len(chip_cols)]:
            c1, c2 = st.columns([5, 1])
            with c1:
                name = f["name"]
                # Names longer than 16 characters are shortened with an ellipsis.
                short_name = f"{name[:16]}…" if len(name) > 16 else name
                emoji = emoji_map.get(f["type"], "📎")
                st.caption(f"{emoji} {short_name}")
            with c2:
                if st.button("✕", key=f"rm_chip_{i}"):
                    files_to_remove.append(i)
    if files_to_remove:
        # Pop from the back so the remaining indices stay valid.
        for i in sorted(files_to_remove, reverse=True):
            try:
                Path(attached[i]["path"]).unlink(missing_ok=True)
            except Exception:
                pass
            attached.pop(i)
        st.rerun()
# ---- Chat-area file upload ----
# Streamlit does not allow assigning to a file-uploader's session-state key
# (and no widget key may be modified after the widget was instantiated), so
# the uploader cannot be emptied by writing to st.session_state. Instead the
# widget key is rotated: a new key yields a new, empty uploader on the rerun.
if "chat_uploader_nonce" not in st.session_state:
    st.session_state.chat_uploader_nonce = 0
col_fu, col_hint = st.columns([5, 1])
with col_fu:
    chat_uploads = st.file_uploader(
        "附加文件",
        type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "xlsx", "txt", "csv", "json", "xml"],
        accept_multiple_files=True,
        key=f"chat_file_uploader_{st.session_state.chat_uploader_nonce}",
        label_visibility="visible",
    )
with col_hint:
    st.caption("Ctrl+V 粘贴\n或拖拽到页面")
if chat_uploads:
    import tempfile
    from backend.file_parser import parse_file
    from backend.layout_analyzer import analyze_layout

    img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
    for uf in chat_uploads:
        # De-duplicate by file name.
        if any(f["name"] == uf.name for f in st.session_state.chat_attached_files):
            continue
        suffix = Path(uf.name).suffix.lower()
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(uf.getvalue())
            tmp_path = tmp.name
        try:
            result = parse_file(tmp_path, suffix)
        except Exception:
            # Do not leave the temp file behind when parsing fails.
            Path(tmp_path).unlink(missing_ok=True)
            raise
        text = result["text"]
        file_type = result["file_type"]
        if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
            try:
                layout = analyze_layout(tmp_path)
                tt = layout.get("template_type", "unknown")
                if tt == "full_a4":
                    text = layout["description"]
                    file_type = "a4_template"
                elif tt == "partial_rows":
                    file_type = "a4_partial"
            except Exception as exc:
                # Layout analysis is best-effort: fall back to the plain
                # parse result, but leave a trace in the log.
                _app_log.error(
                    f"布局分析失败,使用原始解析结果: {exc}",
                    extra={"file_name": uf.name, "error": str(exc)},
                )
        st.session_state.chat_attached_files.append({
            "name": uf.name, "text": text, "type": file_type, "path": tmp_path
        })
    # The files now live in chat_attached_files; switch to a fresh uploader so
    # they are not picked up again after the next message has been sent.
    st.session_state.chat_uploader_nonce += 1
    st.rerun()
# ---- Chat input ----
if prompt := st.chat_input("描述您的报表需求..."):
    # Snapshot the attachments of this turn; the session list is reset below.
    chat_files = list(st.session_state.chat_attached_files)

    # Prepend the text of the files attached in the chat area ...
    file_texts = []
    attached_info = []
    for f in chat_files:
        file_texts.append(f"[附加文件: {f['name']} ({f['type']})]\n{f['text']}")
        attached_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])})
    # ... and of the files uploaded through the sidebar (backward compatible).
    if st.session_state.get("uploaded_files"):
        for f in st.session_state.uploaded_files:
            file_texts.append(f"[上传文件: {f['name']}]\n{f['text']}")
            attached_info.append({"name": f["name"], "type": f["type"], "length": len(f["text"])})
    if file_texts:
        full_prompt = "\n\n".join(file_texts) + "\n\n---\n用户需求:\n" + prompt
    else:
        full_prompt = prompt

    # Pass the path of the first image-like attachment to the agent for OCR
    # field extraction.
    for f in chat_files:
        if f["type"] in ("image", "a4_template", "a4_partial"):
            st.session_state.agent_state["uploaded_file_path"] = f["path"]
            break

    # Reset the per-message attachment state. The temp files themselves are
    # deleted only after the agent has run (see the finally block): removing
    # them here would hand the agent a path to a file that no longer exists.
    st.session_state.uploaded_files = []
    st.session_state.chat_attached_files = []

    _app_log.info(
        "收到用户输入",
        extra={
            "session_id": current_session_id,
            "prompt_preview": prompt[:200],
            "prompt_length": len(prompt),
            "has_uploaded_files": bool(attached_info),
            "uploaded_files": attached_info,
        },
    )
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    try:
        run_agent(full_prompt)
    finally:
        # Best-effort removal of this turn's temp files, also when the run
        # is interrupted.
        # NOTE(review): agent_state["uploaded_file_path"] still names the
        # removed file afterwards — confirm the graph tolerates a stale path
        # on later turns.
        for f in chat_files:
            try:
                Path(f["path"]).unlink(missing_ok=True)
            except OSError as exc:
                _app_log.error(
                    f"临时文件清理失败: {exc}",
                    extra={"path": str(f["path"]), "error": str(exc)},
                )
    st.rerun()