927 lines
32 KiB
Python
927 lines
32 KiB
Python
"""Streamlit multi-turn chat UI for the JRXML generation agent.

Features:
- Streaming output (LLM text is displayed incrementally as it arrives)
- Flat, per-node progress display (each processing stage is listed on its own)
- The node progress area is collapsed automatically once the run finishes
- A summary card describing the outcome of each run
"""

import os
import sys

# Set before any project module is imported below. Presumably consumed by the
# `transformers` library (pulled in indirectly) to keep its logging at
# error level — TODO confirm which dependency reads it.
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")

# NOTE(review): best-effort eager import of torchvision; any failure is
# deliberately ignored. Why it must be imported this early is not visible in
# this file — confirm intent before removing.
try:
    import torchvision
except Exception:
    pass

import base64
import tempfile
import time
from pathlib import Path

import streamlit as st
import streamlit.components.v1 as components

from dotenv import load_dotenv

# Load .env (values override variables already set in the process) before the
# agent/backend modules are imported, so the environment is already populated
# when they load.
load_dotenv(override=True)

from agent.graph import build_graph, create_initial_state
from backend.session import (
    create_session,
    load_session,
    delete_session,
    list_all_sessions,
)
from backend.logger import get_logger, set_trace_id, generate_trace_id

# Logger for app/UI-level events (run start/finish, session switches, uploads).
_app_log = get_logger("app")

# Page configuration; issued before any other st.* call in this script.
st.set_page_config(
    page_title="JRXML 代理",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)
|
||
|
||
# Keyboard guard: swallow Streamlit's bare "c" hotkey (clear cache) while
# leaving Ctrl/Cmd+C copy and typing into inputs/textareas/contenteditable
# elements untouched.
#
# The script is hosted in a components.html iframe: st.html does not execute
# <script> content, whereas a components.html iframe runs it and can reach the
# same-origin parent page through window.parent (which this script relies on).
# The listener is stored on the parent window so that, if the iframe is
# re-created on a rerun, the previous listener is removed instead of piling up.
components.html(
    """
<script>
(function() {
  const host = window.parent;
  const doc = host.document;
  if (host.__jrxmlBareCGuard) {
    doc.removeEventListener('keydown', host.__jrxmlBareCGuard, true);
  }
  const guard = function(e) {
    // Only the bare 'c' key (no Ctrl/Cmd/Alt modifier) is intercepted.
    if (e.key !== 'c' || e.ctrlKey || e.metaKey || e.altKey) return;
    const el = doc.activeElement;
    const tag = el ? el.tagName : '';
    const editable = !!(el && el.isContentEditable);
    if (tag !== 'INPUT' && tag !== 'TEXTAREA' && !editable) {
      e.stopImmediatePropagation();
      e.preventDefault();
    }
  };
  host.__jrxmlBareCGuard = guard;
  doc.addEventListener('keydown', guard, true);
})();
</script>
""",
    height=0,
)
|
||
|
||
# ---- Graph node name -> user-facing progress label ----
NODE_LABELS = dict(
    load_session="📂 加载会话",
    process_input="📝 记录输入",
    manage_context="🧠 管理上下文",
    save_state_snapshot="💾 保存快照",
    classify_intent="🔍 识别意图",
    retrieve="📚 检索模板",
    generate="⚙️ 生成 JRXML",
    modify_jrxml="🔧 修改 JRXML",
    validate="✅ 验证",
    explain_error="🔎 分析错误",
    correct_jrxml="🛠 自动修正",
    finalize="📋 完成",
    handle_consult="💬 咨询回答",
    handle_undo="↩ 撤销操作",
    handle_reset="🔄 重置会话",
    save_session="💾 保存会话",
    generate_skeleton="🏗 生成骨架",
    refine_layout="📐 精调布局",
    map_fields="🏷 映射字段",
)

# ---- Classified intent -> user-facing label ----
INTENT_LABELS = dict(
    initial_generation="新建报表",
    modify_report="修改报表",
    preview_report="预览报表",
    export_pdf="导出 PDF",
    export_jrxml="下载 JRXML",
    undo_modification="撤销修改",
    consult_question="咨询问题",
    reset_session="重置会话",
)

# Bookkeeping nodes that are not shown in the live progress list.
SKIP_NODES = set(
    (
        "load_session",
        "process_input",
        "manage_context",
        "save_state_snapshot",
        "save_session",
    )
)
|
||
|
||
|
||
def _render_jrxml(jrxml: str, max_lines: int = 30):
    """Show a JRXML document as an XML code block, capped at ``max_lines`` lines.

    Surrounding whitespace is stripped first. When the document is longer than
    the cap, a trailing marker line with the total line count is appended.
    """
    source_lines = jrxml.strip().split("\n")
    line_total = len(source_lines)
    is_truncated = line_total > max_lines
    shown = "\n".join(source_lines[:max_lines])
    tail = f"\n... (共 {line_total} 行)" if is_truncated else ""
    st.code(shown + tail, language="xml")
|
||
|
||
|
||
# ---- URL parameters ----
query_params = st.query_params
url_session_id = query_params.get("session_id", "")


def _attach_new_session(state: dict) -> dict:
    """Persist ``state`` as a brand-new session and copy its identity back into it.

    Args:
        state: A fresh agent state (from ``create_initial_state``); mutated in place.

    Returns:
        The same ``state`` dict, now carrying ``session_id``, ``session_name``
        and ``created_at`` from the record returned by ``create_session``.
    """
    new_data = create_session(name="", agent_state=state)
    state["session_id"] = new_data["session_id"]
    state["session_name"] = new_data["session_name"]
    state["created_at"] = new_data["created_at"]
    return state


# ---- Per-browser-session initialisation (runs once per Streamlit session) ----
if "messages" not in st.session_state:
    st.session_state.messages = []
if "graph" not in st.session_state:
    st.session_state.graph = build_graph()
if "pending_action" not in st.session_state:
    st.session_state.pending_action = None
if "agent_state" not in st.session_state:
    restored = load_session(url_session_id) if url_session_id else None
    if restored and restored.get("agent_state"):
        st.session_state.agent_state = restored["agent_state"]
        st.session_state.agent_state["session_id"] = url_session_id
    else:
        if url_session_id:
            # The URL named a session that no longer exists (or has no state).
            _app_log.warning(
                "URL 会话不存在,已新建会话",
                extra={"session_id": url_session_id},
            )
        st.session_state.agent_state = _attach_new_session(create_initial_state())

current_session_id = st.session_state.agent_state.get("session_id", "")

# Mirror the active session into the URL so a page reload (or a shared link)
# reopens the same session, instead of silently creating and orphaning a new
# one. Runs on every rerun, so it also follows sidebar switches/new/delete.
if current_session_id and url_session_id != current_session_id:
    st.query_params["session_id"] = current_session_id
|
||
|
||
|
||
# Code-producing nodes whose progress detail reports the size of the JRXML.
_JRXML_PRODUCING_NODES = (
    "generate",
    "modify_jrxml",
    "correct_jrxml",
    "generate_skeleton",
    "refine_layout",
    "map_fields",
)


def _node_detail(node_name: str, node_state: dict) -> str:
    """Return the one-line progress detail shown next to a finished node.

    Args:
        node_name: Graph node that just emitted an update.
        node_state: That node's partial update (only the keys it changed).

    Returns:
        A short human-readable detail, or "" when the node has none.
    """
    if node_name == "classify_intent":
        intent = node_state.get("intent", "")
        return f"意图: {INTENT_LABELS.get(intent, intent)}"
    if node_name == "retrieve":
        ctx = node_state.get("retrieved_context", "")
        return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
    if node_name in _JRXML_PRODUCING_NODES:
        jrxml = node_state.get("current_jrxml") or ""
        return f"生成 {len(jrxml)} 字符 JRXML"
    if node_name == "validate":
        if node_state.get("status", "") == "pass":
            return "验证通过 ✓"
        err = node_state.get("error_msg") or ""
        return f"验证失败: {err[:80]}"
    if node_name == "explain_error":
        return (node_state.get("natural_explanation") or "")[:120]
    if node_name == "handle_consult":
        return (node_state.get("consult_answer") or "")[:150]
    return ""


def _render_summary(final_state: dict) -> None:
    """Render the outcome card for one run and record it in the chat history.

    Call inside the container the card should appear in. Appends assistant
    messages to ``st.session_state.messages`` for consult answers, JRXML
    results, success notices and failure explanations.
    """
    intent = final_state.get("intent", "")
    status = final_state.get("status", "")
    jrxml = final_state.get("current_jrxml") or ""
    messages = st.session_state.messages

    if intent == "consult_question":
        answer = final_state.get("consult_answer") or ""
        st.info(answer)
        messages.append({"role": "assistant", "content": answer, "type": "consult"})

    elif intent in ("undo_modification", "reset_session"):
        st.success("操作已完成")

    elif intent in ("preview_report", "export_pdf", "export_jrxml"):
        if jrxml:
            st.success("✅ 当前报表")
            _render_jrxml(jrxml)
            messages.append({"role": "assistant", "content": jrxml, "type": "jrxml"})
        else:
            st.warning("⚠ 当前没有报表可以展示。")

    elif status == "pass":
        st.success("✅ JRXML 生成成功")
        st.markdown("**生成结果:**")
        _render_jrxml(jrxml)
        st.caption("您可以从侧边栏下载文件,或继续对话进行修改。")
        messages.append({"role": "assistant", "content": jrxml, "type": "jrxml"})
        messages.append({
            "role": "assistant",
            "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。",
            "type": "success",
        })

    else:
        # An empty error message is treated like a missing one.
        error_msg = final_state.get("error_msg") or "未知错误"
        explanation = final_state.get("natural_explanation", "")
        retries = final_state.get("retry_count", 0)
        st.error(f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML")
        st.markdown(f"**错误:** {error_msg}")
        if explanation:
            st.markdown(f"**原因:** {explanation}")
        if jrxml:
            with st.expander("查看当前 JRXML"):
                _render_jrxml(jrxml, max_lines=80)
        st.caption("💡 下次输入修改需求时,系统会自动加载失败上下文继续修复。")
        messages.append({
            "role": "assistant",
            "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。",
            "type": "error_explanation",
        })


def _render_ocr_result(ocr_result: dict) -> None:
    """Show OCR field-extraction results in a collapsed expander, if present.

    Does nothing unless ``ocr_result`` is truthy, has ``ocr_available`` set and
    a non-empty ``fields`` list.
    """
    if not ocr_result:
        return
    fields = ocr_result.get("fields") or []
    if not (ocr_result.get("ocr_available") and fields):
        return
    with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
        non_empty = [f for f in fields if f.get("field_value")]
        empty = [f for f in fields if not f.get("field_value")]
        if non_empty:
            st.markdown("**已提取字段:**")
            for f in non_empty:
                method = f.get("extraction_method", "")
                conf = f.get("confidence", 0)
                # Tolerate a missing/non-numeric confidence instead of crashing
                # the whole summary on the percent format.
                try:
                    conf_text = f"{float(conf):.0%}"
                except (TypeError, ValueError):
                    conf_text = str(conf)
                st.markdown(
                    f"- **{f.get('field_name', '')}**: `{f['field_value']}` "
                    f"(置信度: {conf_text}, 方法: {method})"
                )
        if empty:
            names = ", ".join(str(f.get("field_name", "")) for f in empty)
            st.caption(f"未提取到值的字段: {names}")
        st.caption(f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素")


def run_agent(user_input: str):
    """Run the agent graph for one user turn, streaming progress and the result.

    Renders into the current Streamlit container:
    - a live list of executed nodes;
    - the streamed LLM text (from ``{"type": "stream"}`` custom events);
    - a summary card, plus OCR details when present.

    On completion the final state replaces ``st.session_state.agent_state``, and
    assistant messages are appended to ``st.session_state.messages``. Graph
    exceptions are logged with a traceback and shown via ``st.error``. They are
    not re-raised.

    Args:
        user_input: Full text for this turn (may include extracted attachment text).
    """
    trace_id = generate_trace_id()
    set_trace_id(trace_id)
    agent_state = st.session_state.agent_state
    session_id = agent_state.get("session_id", "")

    _app_log.info(
        "代理执行开始",
        extra={
            "session_id": session_id,
            "trace_id": trace_id,
            "user_input_preview": user_input[:200],
            "user_input_length": len(user_input),
            "has_jrxml": bool((agent_state.get("current_jrxml") or "").strip()),
            "intent": agent_state.get("intent", ""),
        },
    )

    # A valid report already exists, so this turn is treated as a modification request.
    if agent_state.get("current_jrxml") and agent_state.get("status") == "pass":
        agent_state["user_modification_request"] = user_input
    agent_state["user_input"] = user_input
    agent_state["retry_count"] = 0

    # ---- UI placeholders ----
    progress_placeholder = st.empty()   # live node progress
    streaming_placeholder = st.empty()  # streamed LLM text
    summary_placeholder = st.empty()    # summary card
    progress_placeholder.info("⏳ 正在分析您的需求...")

    executed_nodes: list[dict] = []
    stream_text = ""
    stream_active = False
    latest_values = None      # last full-state snapshot from the "values" stream
    produced_update = False   # whether any node emitted an update at all

    def _render_progress(nodes: list[dict]):
        """Render the node list; the last node is marked as in progress."""
        if not nodes:
            return
        lines = []
        for i, node in enumerate(nodes):
            icon = "●" if i == len(nodes) - 1 else "✓"
            detail = f" — {node['detail']}" if node.get("detail") else ""
            lines.append(f"{icon} {node['label']}{detail}")
        progress_placeholder.markdown("\n\n".join(lines))

    try:
        # "values" is streamed alongside "updates"/"custom": per-node updates
        # only carry changed keys, and (per LangGraph's streaming docs) the graph
        # does not write results back into the input dict, so the last "values"
        # snapshot is the authoritative final state.
        for mode, data in st.session_state.graph.stream(
            agent_state, stream_mode=["updates", "custom", "values"]
        ):
            if mode == "values":
                if isinstance(data, dict):
                    latest_values = data

            elif mode == "updates":
                for node_name, node_state in data.items():
                    produced_update = True
                    # A node that returns nothing yields None here; treat it as
                    # an empty update instead of crashing on .get().
                    if not isinstance(node_state, dict):
                        node_state = {}
                    if node_name not in SKIP_NODES:
                        executed_nodes.append({
                            "name": node_name,
                            "label": NODE_LABELS.get(node_name, node_name),
                            "detail": _node_detail(node_name, node_state),
                        })
                # Refresh progress as soon as each step completes.
                _render_progress(executed_nodes)

            elif mode == "custom" and isinstance(data, dict):
                if data.get("type") == "stream":
                    stream_text += data.get("text") or ""
                    stream_active = True
                    streaming_placeholder.code(stream_text, language="xml")

    except Exception as e:
        progress_placeholder.empty()
        _app_log.error(
            f"代理执行异常: {e}",
            extra={"session_id": session_id, "error": str(e)},
            exc_info=True,
        )
        st.error(f"工作流异常: {e}")
        return

    # ---- Clear transient placeholders ----
    progress_placeholder.empty()
    if stream_active:
        streaming_placeholder.empty()

    # Start from the input (keeps UI-only keys the graph schema may not carry),
    # then overlay the graph's final snapshot.
    final_state = dict(agent_state)
    if latest_values:
        final_state.update(latest_values)

    if produced_update:
        st.session_state.agent_state = final_state
        with summary_placeholder.container(border=True):
            _render_summary(final_state)
        _render_ocr_result(final_state.get("ocr_extraction_result") or {})
    else:
        st.error("未产生结果,请重试。")

    _app_log.info(
        "代理执行完成",
        extra={
            "session_id": session_id,
            "intent": final_state.get("intent", ""),
            "status": final_state.get("status", ""),
            "jrxml_length": len(final_state.get("current_jrxml") or ""),
            "retry_count": final_state.get("retry_count", 0),
        },
    )
|
||
|
||
|
||
# ---- Sidebar ----


def _start_fresh_session() -> dict:
    """Create and persist a new empty session, make it current, and clear the chat.

    Returns:
        The record returned by ``create_session`` (``session_id``,
        ``session_name``, ``created_at``).
    """
    state = create_initial_state()
    new_data = create_session(name="", agent_state=state)
    state["session_id"] = new_data["session_id"]
    state["session_name"] = new_data["session_name"]
    state["created_at"] = new_data["created_at"]
    st.session_state.agent_state = state
    st.session_state.messages = []
    return new_data


def _build_session_options(sessions: list) -> dict:
    """Map a display label to each session id for the session switcher.

    Labels are ``"<name> (<updated_at, to the minute>)"``. If two sessions would
    share a label, the later one gets its id appended, so that no session is
    silently dropped from the list.
    """
    options = {}
    for s in sessions:
        sid = s["session_id"]
        name = s.get("session_name", sid)
        updated = (s.get("updated_at") or "")[:16]
        label = f"{name} ({updated})"
        if label in options:
            label = f"{name} ({updated}) [{sid}]"
        options[label] = sid
    return options


def _on_session_selected(options: dict) -> None:
    """``on_change`` callback for the switcher: load and activate the chosen session."""
    new_sid = options.get(st.session_state.get("session_selector"))
    current_sid = st.session_state.agent_state.get("session_id", "")
    if not new_sid or new_sid == current_sid:
        return
    data = load_session(new_sid)
    if not (data and data.get("agent_state")):
        # Keep the current session; the selector is re-synced to it on the
        # rerun that follows this callback.
        _app_log.warning(
            "切换会话失败: 会话数据不存在",
            extra={"from_session": current_sid, "to_session": new_sid},
        )
        return
    _app_log.info(
        "切换会话",
        extra={"from_session": current_sid, "to_session": new_sid},
    )
    data["agent_state"]["session_id"] = new_sid
    st.session_state.agent_state = data["agent_state"]
    st.session_state.messages = []


def _render_session_manager() -> None:
    """Session switcher, New/Delete buttons and a caption naming the current session."""
    st.markdown("### 会话管理")
    session_options = _build_session_options(list_all_sessions())
    current_sid = st.session_state.agent_state.get("session_id", "")
    current_label = next(
        (label for label, sid in session_options.items() if sid == current_sid),
        None,
    )

    # A keyed widget keeps its own value across reruns, so after "New"/"Delete"
    # it would still show (and later switch back to) the old session. Drive it
    # from the active session instead; this must happen before the widget exists.
    if current_label is not None:
        st.session_state["session_selector"] = current_label
    else:
        st.session_state.pop("session_selector", None)

    st.selectbox(
        "切换会话",
        options=list(session_options.keys()),
        key="session_selector",
        on_change=_on_session_selected,
        args=(session_options,),
    )

    col_new, col_delete = st.columns(2)
    with col_new:
        if st.button("➕ 新建", use_container_width=True):
            new_data = _start_fresh_session()
            _app_log.info(
                "新建会话",
                extra={"session_id": new_data["session_id"]},
            )
            st.rerun()
    with col_delete:
        if st.button("🗑 删除", use_container_width=True):
            if current_sid:
                _app_log.info(
                    "删除会话",
                    extra={"session_id": current_sid},
                )
                delete_session(current_sid)
            _start_fresh_session()
            st.rerun()

    current_name = st.session_state.agent_state.get("session_name", "")
    st.caption(f"当前: {current_name} (`{current_sid}`)")


def _render_quick_actions() -> None:
    """Preview / Undo / Reset buttons; each runs the agent with a canned request."""
    st.markdown("### 快捷操作")
    agent_state = st.session_state.agent_state
    has_jrxml = bool((agent_state.get("current_jrxml") or "").strip())
    has_history = bool(agent_state.get("history_states"))

    qcol1, qcol2 = st.columns(2)
    with qcol1:
        if st.button("👁 预览", use_container_width=True, disabled=not has_jrxml):
            with st.spinner("正在准备预览..."):
                run_agent("预览报表")
            st.rerun()
    with qcol2:
        if st.button("↩ 撤销", use_container_width=True, disabled=not has_history):
            with st.spinner("正在撤销..."):
                run_agent("撤销上一步修改")
            st.rerun()

    if st.button("🔄 重置会话", use_container_width=True):
        with st.spinner("正在重置..."):
            run_agent("重新来,清空当前报表")
        st.rerun()


def _render_config() -> None:
    """Read-only display of the LLM / retry / validation-service configuration."""
    st.markdown("### 配置")
    llm_backend = os.getenv("LLM_BACKEND", "cloud")
    llm_model = os.getenv("LLM_MODEL", os.getenv("LOCAL_LLM_MODEL", "gpt-4o"))
    st.caption(f"大语言模型: {llm_backend} / {llm_model}")
    st.caption(f"最大重试次数: {os.getenv('MAX_RETRY', '5')}")
    st.caption(f"验证服务: {os.getenv('VALIDATION_SERVICE_URL', 'http://localhost:8001/validate')}")


def _render_downloads() -> None:
    """Download buttons for the final JRXML and for each saved version (newest first)."""
    st.markdown("### 下载")
    agent_state = st.session_state.agent_state
    final = agent_state.get("final_jrxml", "")
    versions = agent_state.get("jrxml_versions") or []

    if final:
        st.download_button(
            label="📥 下载最新 JRXML",
            data=final,
            file_name="report.jrxml",
            mime="application/xml",
            use_container_width=True,
        )

    if versions:
        total = len(versions)
        with st.expander("📋 历史版本", expanded=False):
            for i, v in enumerate(reversed(versions)):
                ts = (v.get("ts") or "")[:16]
                label = v.get("label", "版本")
                icon = "✅" if v.get("status", "") == "pass" else "❌"
                st.download_button(
                    label=f"{icon} v{total - i} — {label} ({ts})",
                    data=v.get("jrxml") or "",
                    file_name=f"report_v{total - i}.jrxml",
                    mime="application/xml",
                    use_container_width=True,
                    key=f"dl_v{i}",
                )


with st.sidebar:
    st.title("📊 JRXML 代理")
    st.markdown("通过自然语言生成 JasperReports 模板。")
    st.divider()
    _render_session_manager()
    st.divider()
    _render_quick_actions()
    st.divider()
    _render_config()
    st.divider()
    _render_downloads()
|
||
|
||
# ---- Page title ----
st.title("📝 JRXML 报表生成器")
st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。")


def _show_history_message(msg: dict) -> None:
    """Render one stored chat message body according to its ``type`` tag.

    ``jrxml`` messages go into a collapsed code expander. ``error_explanation``,
    ``success`` and ``consult`` use warning/success/info boxes. Anything else is
    plain markdown.
    """
    kind = msg.get("type")
    if kind == "jrxml":
        with st.expander("查看生成的 JRXML", expanded=False):
            st.code(msg["content"], language="xml")
        return
    boxed = {
        "error_explanation": st.warning,
        "success": st.success,
        "consult": st.info,
    }
    show = boxed.get(kind, st.markdown)
    show(msg["content"])


# ---- Replay chat history ----
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        _show_history_message(msg)
|
||
|
||
# ---- Unified chat input (text + attachments) ----
#
# This must be a *declared* (bidirectional) component: components.html is
# static and always returns None, and no `Streamlit` object exists inside it,
# so a components.html version can never deliver input back to Python. The
# page below talks to Streamlit through the custom-component postMessage
# protocol directly (streamlit-component-lib is not bundled into this file).
UNIFIED_CHAT_HTML = r"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<style>
  * { box-sizing: border-box; margin: 0; padding: 0; }
  body {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
    background: transparent;
    padding: 4px 0;
  }
  .chat-container {
    position: relative;
    border: 1px solid #d1d5db;
    border-radius: 12px;
    padding: 8px 12px;
    background: #ffffff;
    transition: border-color 0.2s, box-shadow 0.2s;
  }
  .chat-container:focus-within {
    border-color: #3b82f6;
    box-shadow: 0 0 0 2px rgba(59,130,246,0.15);
  }
  .chat-container.drag-active {
    border-color: #3b82f6;
    background: rgba(59,130,246,0.04);
  }
  .file-chips {
    display: flex;
    flex-wrap: wrap;
    gap: 6px;
    margin-bottom: 6px;
  }
  .file-chips:empty { display: none; }
  .file-chip {
    display: inline-flex;
    align-items: center;
    gap: 4px;
    padding: 2px 8px;
    background: #f3f4f6;
    border-radius: 14px;
    font-size: 12px;
    color: #374151;
    max-width: 200px;
  }
  .file-chip .chip-icon { font-size: 13px; }
  .file-chip .chip-name {
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
  }
  .file-chip .chip-remove {
    border: none;
    background: none;
    cursor: pointer;
    color: #9ca3af;
    font-size: 14px;
    line-height: 1;
    padding: 0 2px;
    flex-shrink: 0;
  }
  .file-chip .chip-remove:hover { color: #ef4444; }
  .input-row {
    display: flex;
    align-items: flex-end;
    gap: 8px;
  }
  .attach-btn {
    border: none;
    background: none;
    cursor: pointer;
    padding: 4px 6px;
    font-size: 20px;
    line-height: 1;
    color: #6b7280;
    border-radius: 6px;
    transition: background 0.15s, color 0.15s;
    flex-shrink: 0;
  }
  .attach-btn:hover { background: #f3f4f6; color: #374151; }
  textarea {
    flex: 1;
    border: none;
    outline: none;
    resize: none;
    font-size: 15px;
    line-height: 1.5;
    font-family: inherit;
    color: #111827;
    background: transparent;
    padding: 4px 0;
    min-height: 24px;
    max-height: 120px;
    overflow-y: auto;
  }
  textarea::placeholder { color: #9ca3af; }
  .send-btn {
    border: none;
    cursor: pointer;
    padding: 4px 10px;
    font-size: 16px;
    background: #e5e7eb;
    color: #9ca3af;
    border-radius: 8px;
    transition: all 0.15s;
    flex-shrink: 0;
  }
  .send-btn.active { background: #3b82f6; color: #fff; }
  .send-btn.active:hover { background: #2563eb; }
  .send-btn:disabled { opacity: 0.5; cursor: default; }
  .error-toast {
    position: fixed;
    bottom: 12px;
    left: 50%;
    transform: translateX(-50%);
    background: #ef4444;
    color: #fff;
    padding: 6px 16px;
    border-radius: 8px;
    font-size: 13px;
    z-index: 9999;
    animation: toastOut 2.5s forwards;
    pointer-events: none;
  }
  @keyframes toastOut {
    0%, 70% { opacity: 1; }
    100% { opacity: 0; }
  }

  @media (prefers-color-scheme: dark) {
    .chat-container { background: #1f2937; border-color: #374151; }
    .chat-container:focus-within { border-color: #3b82f6; }
    .file-chip { background: #374151; color: #e5e7eb; }
    .file-chip .chip-remove { color: #6b7280; }
    .attach-btn { color: #9ca3af; }
    .attach-btn:hover { background: #374151; color: #e5e7eb; }
    textarea { color: #f9fafb; }
    textarea::placeholder { color: #6b7280; }
    .send-btn { background: #374151; }
  }
</style>
</head>
<body>
<div class="chat-container" id="container">
  <div class="file-chips" id="chips"></div>
  <div class="input-row">
    <button class="attach-btn" id="attachBtn" title="附加文件">📎</button>
    <textarea id="textInput" placeholder="描述您的报表需求..." rows="1"></textarea>
    <button class="send-btn" id="sendBtn" title="发送">➤</button>
  </div>
  <input type="file" id="fileInput" multiple hidden
         accept=".png,.jpg,.jpeg,.bmp,.webp,.pdf,.docx,.xlsx,.xls,.doc,.txt">
</div>
<script>
// Minimal client for Streamlit's custom-component postMessage protocol.
const Streamlit = {
  send: function(type, extra) {
    window.parent.postMessage(
      Object.assign({isStreamlitMessage: true, type: type}, extra || {}), '*');
  },
  setComponentReady: function() { Streamlit.send('streamlit:componentReady', {apiVersion: 1}); },
  setFrameHeight: function(height) { Streamlit.send('streamlit:setFrameHeight', {height: height}); },
  setComponentValue: function(value) {
    Streamlit.send('streamlit:setComponentValue', {value: value, dataType: 'json'});
  }
};

const container = document.getElementById('container');
const chipsEl = document.getElementById('chips');
const textInput = document.getElementById('textInput');
const sendBtn = document.getElementById('sendBtn');
const attachBtn = document.getElementById('attachBtn');
const fileInput = document.getElementById('fileInput');

let attachedFiles = [];
const MAX_FILES = 10;
const MAX_SIZE = 20 * 1024 * 1024;

// Declared components start at height 0; keep the frame sized to the content.
function syncHeight() {
  Streamlit.setFrameHeight(document.documentElement.scrollHeight);
}

function getIcon(type) {
  type = type || '';
  if (type.startsWith('image/')) return '🖼';
  if (type.includes('pdf')) return '📄';
  if (type.includes('document')) return '📝';
  if (type.includes('spreadsheet') || type.includes('excel')) return '📊';
  return '📎';
}

function updateSendBtn() {
  var canSend = !!(textInput.value.trim() || attachedFiles.length > 0);
  sendBtn.classList.toggle('active', canSend);
}

function renderChips() {
  chipsEl.innerHTML = '';
  attachedFiles.forEach(function(f, i) {
    var chip = document.createElement('span');
    chip.className = 'file-chip';
    var icon = document.createElement('span');
    icon.className = 'chip-icon';
    icon.textContent = getIcon(f.type);
    // textContent (not innerHTML): file names must never be parsed as markup.
    var nameEl = document.createElement('span');
    nameEl.className = 'chip-name';
    nameEl.textContent = f.name.length > 16 ? f.name.slice(0, 14) + '..' : f.name;
    nameEl.title = f.name;
    var removeBtn = document.createElement('button');
    removeBtn.className = 'chip-remove';
    removeBtn.textContent = '×';
    removeBtn.onclick = function() {
      attachedFiles.splice(i, 1);
      renderChips();
    };
    chip.appendChild(icon);
    chip.appendChild(nameEl);
    chip.appendChild(removeBtn);
    chipsEl.appendChild(chip);
  });
  updateSendBtn();
  syncHeight();
}

function addFiles(fileList) {
  for (var i = 0; i < fileList.length; i++) {
    var file = fileList[i];
    if (!file) continue;
    if (attachedFiles.length >= MAX_FILES) { showToast('最多附加 '+MAX_FILES+' 个文件'); break; }
    if (file.size > MAX_SIZE) { showToast(file.name+' 超过 20MB 限制'); continue; }
    if (attachedFiles.some(function(f) { return f.name === file.name && f.size === file.size; })) continue;
    attachedFiles.push({name: file.name, type: file.type, file: file});
  }
  renderChips();
}

function showToast(msg) {
  var t = document.createElement('div');
  t.className = 'error-toast';
  t.textContent = msg;
  document.body.appendChild(t);
  setTimeout(function() { t.remove(); }, 2600);
}

function readFile(file) {
  return new Promise(function(resolve, reject) {
    var reader = new FileReader();
    reader.onload = function() { resolve(reader.result); };
    reader.onerror = reject;
    reader.readAsDataURL(file);
  });
}

async function handleSend() {
  // Enter bypasses the disabled button, so guard against re-entry explicitly.
  if (sendBtn.disabled) return;
  var text = textInput.value.trim();
  if (!text && attachedFiles.length === 0) return;

  sendBtn.disabled = true;
  try {
    var files = [];
    for (var i = 0; i < attachedFiles.length; i++) {
      var f = attachedFiles[i];
      try {
        var dataUrl = await readFile(f.file);
        files.push({name: f.name, type: f.type, data: dataUrl, size: f.file.size});
      } catch(e) {
        showToast(f.name+' 读取失败');
      }
    }
    if (!text && files.length === 0) return;

    Streamlit.setComponentValue({
      text: text,
      files: files,
      // Unique per send: Streamlit keeps returning the last value on every
      // rerun, so Python uses this to process each submission exactly once.
      nonce: Date.now().toString(36) + '-' + Math.random().toString(36).slice(2)
    });

    textInput.value = '';
    attachedFiles = [];
    renderChips();
    textInput.style.height = 'auto';
  } finally {
    sendBtn.disabled = false;
  }
}

attachBtn.onclick = function() { fileInput.click(); };
fileInput.onchange = function() { addFiles(fileInput.files); fileInput.value = ''; };

textInput.oninput = function() {
  updateSendBtn();
  textInput.style.height = 'auto';
  textInput.style.height = Math.min(textInput.scrollHeight, 120) + 'px';
  syncHeight();
};

textInput.onkeydown = function(e) {
  if (e.key === 'Enter' && !e.shiftKey) {
    e.preventDefault();
    handleSend();
  }
};

sendBtn.onclick = handleSend;

document.addEventListener('paste', function(e) {
  var items = e.clipboardData && e.clipboardData.items;
  if (!items) return;
  var files = [];
  for (var i = 0; i < items.length; i++) {
    if (items[i].kind === 'file') {
      var file = items[i].getAsFile();
      if (file) files.push(file);
    }
  }
  if (files.length) { e.preventDefault(); addFiles(files); }
});

container.addEventListener('dragover', function(e) {
  e.preventDefault();
  container.classList.add('drag-active');
});
container.addEventListener('dragleave', function() {
  container.classList.remove('drag-active');
});
container.addEventListener('drop', function(e) {
  e.preventDefault();
  container.classList.remove('drag-active');
  addFiles(e.dataTransfer.files);
});

window.addEventListener('message', function(e) {
  if (e.data && e.data.type === 'streamlit:render') syncHeight();
});
if (window.ResizeObserver) new ResizeObserver(syncHeight).observe(document.body);

Streamlit.setComponentReady();
updateSendBtn();
syncHeight();
</script>
</body>
</html>
"""

# Directory the component's index.html is served from.
_CHAT_COMPONENT_DIR = Path(tempfile.gettempdir()) / "jrxml_unified_chat_input"

# Browser MIME type -> suffix handed to the parser; any other type falls back
# to the uploaded file's own extension.
_MIME_TO_SUFFIX = {
    "image/png": ".png", "image/jpeg": ".jpg", "image/bmp": ".bmp",
    "image/webp": ".webp", "application/pdf": ".pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
    "application/vnd.ms-excel": ".xls", "application/msword": ".doc",
    "text/plain": ".txt",
}
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp", ".webp")


def _chat_input_component():
    """Return the callable for the bidirectional chat-input component.

    Writes ``UNIFIED_CHAT_HTML`` to ``_CHAT_COMPONENT_DIR/index.html`` when the
    on-disk copy is missing or differs, then declares the component from that
    directory.
    """
    _CHAT_COMPONENT_DIR.mkdir(parents=True, exist_ok=True)
    index_html = _CHAT_COMPONENT_DIR / "index.html"
    try:
        on_disk = index_html.read_text(encoding="utf-8")
    except OSError:
        on_disk = None
    if on_disk != UNIFIED_CHAT_HTML:
        index_html.write_text(UNIFIED_CHAT_HTML, encoding="utf-8")
    return components.declare_component(
        "jrxml_unified_chat_input", path=str(_CHAT_COMPONENT_DIR)
    )


def _ingest_attachments(files: list, temp_paths: list) -> tuple:
    """Decode uploaded files, extract their text and detect A4 layout templates.

    Each file is written to a temp file whose path is appended to ``temp_paths``
    as soon as it exists; the caller is responsible for deleting them. A file
    that fails to decode or parse is reported and skipped. Image layout
    analysis is best-effort and may store ``layout_schema`` / ``ocr_elements``
    in ``st.session_state.agent_state``.

    Args:
        files: Payload entries ``{"name", "type", "data" (data URL), "size"}``.
        temp_paths: Caller-owned list that collects every temp path created.

    Returns:
        ``(file_texts, attached_info, first_image_path)``, where
        ``first_image_path`` may be None.
    """
    from backend.file_parser import parse_file
    from backend.layout_analyzer import analyze_layout, extract_layout_schema

    file_texts = []
    attached_info = []
    first_image_path = None

    for f in files:
        name = f.get("name", "")
        try:
            _header, b64data = (f.get("data") or "").split(",", 1)
            raw = base64.b64decode(b64data)
            suffix = _MIME_TO_SUFFIX.get(f.get("type", ""), Path(name).suffix.lower())
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                temp_paths.append(tmp.name)
                tmp.write(raw)
                tmp_path = tmp.name
            result = parse_file(tmp_path, suffix)
        except Exception as exc:
            _app_log.warning(
                "附件解析失败",
                extra={"file_name": name, "error": str(exc)},
                exc_info=True,
            )
            st.warning(f"⚠ 附件 {name} 解析失败: {exc}")
            continue

        text = result.get("text") or ""
        file_type = result.get("file_type", "")

        if suffix in _IMAGE_SUFFIXES and result.get("method") not in ("metadata_only", None):
            try:
                layout = analyze_layout(tmp_path)
                template_type = layout.get("template_type", "unknown")
                if template_type == "full_a4":
                    text = layout["description"]
                    file_type = "a4_template"
                    schema = extract_layout_schema(layout)
                    st.session_state.agent_state["layout_schema"] = schema
                    st.session_state.agent_state["ocr_elements"] = layout.get("rows", [])
                elif template_type == "partial_rows":
                    file_type = "a4_partial"
            except Exception as exc:
                # Best effort: keep the plain parser output, but leave a trace.
                _app_log.warning(
                    "版面分析失败,使用普通解析结果",
                    extra={"file_name": name, "error": str(exc)},
                    exc_info=True,
                )

        file_texts.append(f"[附加文件: {name} ({file_type})]\n{text}")
        attached_info.append({"name": name, "type": file_type, "length": len(text)})

        if not first_image_path and file_type in ("image", "a4_template", "a4_partial"):
            first_image_path = tmp_path

    return file_texts, attached_info, first_image_path


chat_result = _chat_input_component()(key="unified_chat_input", default=None)

if (
    isinstance(chat_result, dict)
    and chat_result.get("nonce")
    and chat_result["nonce"] != st.session_state.get("_last_chat_nonce")
):
    # Mark the submission handled before doing any work, so neither the
    # st.rerun() below nor a failure can replay it on the next run.
    st.session_state["_last_chat_nonce"] = chat_result["nonce"]
    prompt = chat_result.get("text") or ""
    temp_paths = []
    try:
        file_texts, attached_info, first_image_path = _ingest_attachments(
            chat_result.get("files") or [], temp_paths
        )

        if file_texts:
            full_prompt = "\n\n".join(file_texts) + "\n\n---\n用户需求:\n" + prompt
        else:
            full_prompt = prompt

        if first_image_path:
            st.session_state.agent_state["uploaded_file_path"] = first_image_path

        _app_log.info(
            "收到用户输入",
            extra={
                "session_id": current_session_id,
                "prompt_preview": prompt[:200],
                "prompt_length": len(prompt),
                "has_uploaded_files": bool(attached_info),
                "uploaded_files": attached_info,
            },
        )

        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        run_agent(full_prompt)
    finally:
        # Always remove temp files, even when parsing or the agent run failed.
        for p in temp_paths:
            try:
                Path(p).unlink(missing_ok=True)
            except OSError as exc:
                _app_log.warning(
                    "临时文件删除失败",
                    extra={"path": p, "error": str(exc)},
                )
        # Don't leave the state pointing at a file that was just deleted.
        if st.session_state.agent_state.get("uploaded_file_path") in temp_paths:
            st.session_state.agent_state["uploaded_file_path"] = ""

    st.rerun()
|