Files
agent_jrxml/app.py
T

927 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。
支持:
- 流式输出(LLM 逐字展示)
- 节点平铺展开(每个处理阶段独立展示)
- 完成后自动折叠节点区
- 过程总结卡片
"""
import os
import sys
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
try:
import torchvision
except Exception:
pass
import base64
import tempfile
import time
from pathlib import Path
import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv
load_dotenv(override=True)
from agent.graph import build_graph, create_initial_state
from backend.session import (
create_session,
load_session,
delete_session,
list_all_sessions,
)
from backend.logger import get_logger, set_trace_id, generate_trace_id
_app_log = get_logger("app")
# Page-level configuration for the app (title, icon, wide layout, sidebar open).
st.set_page_config(
    page_title="JRXML 代理",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Swallow a bare "c" key press on the host document (the original author's
# stated goal: stop Streamlit's cache-clear shortcut) while leaving Ctrl/Cmd/Alt
# combinations and typing inside inputs / editable elements untouched.
# The active element is read once and null-checked before BOTH property reads;
# previously only the tagName read was guarded, so a null activeElement threw.
# NOTE(review): st.html is documented as sanitising its markup — confirm that
# this <script> is actually executed by the Streamlit version in use.
st.html("""
<script>
(function() {
    const parent = window.parent.document;
    parent.addEventListener('keydown', function(e) {
        // 仅拦截裸 'c' 键(非 Ctrl/Cmd 组合)
        if (e.key === 'c' && !e.ctrlKey && !e.metaKey && !e.altKey) {
            const active = parent.activeElement;
            const tag = active ? active.tagName : '';
            const editable = !!(active && active.isContentEditable);
            if (tag !== 'INPUT' && tag !== 'TEXTAREA' && !editable) {
                e.stopImmediatePropagation();
                e.preventDefault();
            }
        }
    }, true);
})();
</script>
""")
# ---- Graph node name -> label shown in the live progress list ----
NODE_LABELS: dict[str, str] = {
    # session / context bookkeeping
    "load_session": "📂 加载会话",
    "process_input": "📝 记录输入",
    "manage_context": "🧠 管理上下文",
    "save_state_snapshot": "💾 保存快照",
    # intent routing and generation pipeline
    "classify_intent": "🔍 识别意图",
    "retrieve": "📚 检索模板",
    "generate": "⚙️ 生成 JRXML",
    "modify_jrxml": "🔧 修改 JRXML",
    "validate": "✅ 验证",
    "explain_error": "🔎 分析错误",
    "correct_jrxml": "🛠 自动修正",
    "finalize": "📋 完成",
    # non-generation intents
    "handle_consult": "💬 咨询回答",
    "handle_undo": "↩ 撤销操作",
    "handle_reset": "🔄 重置会话",
    "save_session": "💾 保存会话",
    # staged layout generation
    "generate_skeleton": "🏗 生成骨架",
    "refine_layout": "📐 精调布局",
    "map_fields": "🏷 映射字段",
}

# ---- Intent identifier -> label shown next to the intent node ----
INTENT_LABELS: dict[str, str] = {
    "initial_generation": "新建报表",
    "modify_report": "修改报表",
    "preview_report": "预览报表",
    "export_pdf": "导出 PDF",
    "export_jrxml": "下载 JRXML",
    "undo_modification": "撤销修改",
    "consult_question": "咨询问题",
    "reset_session": "重置会话",
}

# Node names that are left out of the progress list.
SKIP_NODES: set[str] = {
    "load_session",
    "process_input",
    "manage_context",
    "save_state_snapshot",
    "save_session",
}
def _render_jrxml(jrxml: str, max_lines: int = 30):
    """Show JRXML text as an XML code block, capped at ``max_lines`` lines.

    The text is stripped before splitting. When it has more lines than
    ``max_lines``, a trailing marker carrying the total line count is appended
    after the shown lines.
    """
    source_lines = jrxml.strip().split("\n")
    total = len(source_lines)
    overflow_note = f"\n... (共 {total} 行)" if total > max_lines else ""
    st.code("\n".join(source_lines[:max_lines]) + overflow_note, language="xml")
# ---- URL parameters ----
query_params = st.query_params
url_session_id = query_params.get("session_id", "")

# ---- Session-state initialisation ----
# Each slot is filled only when absent; factories are called lazily so that
# build_graph() runs at most once per browser session.
for _slot, _make_default in (
    ("messages", list),
    ("graph", build_graph),
    ("pending_action", lambda: None),
):
    if _slot not in st.session_state:
        st.session_state[_slot] = _make_default()

if "agent_state" not in st.session_state:
    # Try to restore the session named in the URL; anything that cannot be
    # restored (no id, unknown id, stored record without agent_state) falls
    # through to a freshly created session.
    restored = load_session(url_session_id) if url_session_id else None
    if restored and restored.get("agent_state"):
        st.session_state.agent_state = restored["agent_state"]
        st.session_state.agent_state["session_id"] = url_session_id
    else:
        st.session_state.agent_state = create_initial_state()
        created = create_session(name="", agent_state=st.session_state.agent_state)
        # Copy the identifiers assigned by create_session into the live state.
        for _field in ("session_id", "session_name", "created_at"):
            st.session_state.agent_state[_field] = created[_field]

current_session_id = st.session_state.agent_state.get("session_id", "")
# Nodes whose state update carries a newly produced JRXML document.
_JRXML_PRODUCING_NODES = (
    "generate", "modify_jrxml", "correct_jrxml",
    "generate_skeleton", "refine_layout", "map_fields",
)


def _node_detail(node_name, node_state):
    """Return the short detail text shown after a node's label, or None.

    ``node_state`` is the update dict reported for ``node_name``. Values that
    are missing or None are treated as empty so a sparse update cannot raise.
    Nodes without a dedicated summary yield None.
    """
    if node_name == "classify_intent":
        intent = node_state.get("intent", "")
        return f"意图: {INTENT_LABELS.get(intent, intent)}"
    if node_name == "retrieve":
        ctx = node_state.get("retrieved_context") or ""
        return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
    if node_name in _JRXML_PRODUCING_NODES:
        jrxml = node_state.get("current_jrxml") or ""
        return f"生成 {len(jrxml)} 字符 JRXML"
    if node_name == "validate":
        if node_state.get("status", "") == "pass":
            return "验证通过 ✓"
        err = node_state.get("error_msg") or ""
        return f"验证失败: {err[:80]}"
    if node_name == "explain_error":
        return (node_state.get("natural_explanation") or "")[:120]
    if node_name == "handle_consult":
        return (node_state.get("consult_answer") or "")[:150]
    return None


def _render_ocr_fields(ocr_result):
    """Render the collapsed OCR field-extraction expander.

    Nothing is rendered unless ``ocr_result`` is non-empty, flags
    ``ocr_available`` and has a non-empty ``fields`` list.
    """
    if not (ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields")):
        return
    fields = ocr_result.get("fields", [])
    extracted = [f for f in fields if f.get("field_value")]
    missing = [f for f in fields if not f.get("field_value")]
    with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
        if extracted:
            st.markdown("**已提取字段:**")
            for f in extracted:
                method = f.get("extraction_method", "")
                # A present-but-None confidence would break the % format spec.
                conf = f.get("confidence") or 0
                st.markdown(
                    f"- **{f['field_name']}**: `{f['field_value']}` "
                    f"(置信度: {conf:.0%}, 方法: {method})"
                )
        if missing:
            st.caption(
                f"未提取到值的字段: {', '.join(f['field_name'] for f in missing)}"
            )
        st.caption(
            f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素"
        )


def _render_summary_card(final_state, placeholder):
    """Fill ``placeholder`` with the result card and record chat messages.

    The branch is chosen from ``final_state["intent"]`` first and
    ``final_state["status"]`` second; every branch that produces user-visible
    output other than undo/reset/no-report also appends to
    ``st.session_state.messages`` so the result survives the next rerun.
    """
    intent = final_state.get("intent", "")
    status = final_state.get("status", "")
    with placeholder.container(border=True):
        if intent == "consult_question":
            answer = final_state.get("consult_answer", "")
            st.info(answer)
            st.session_state.messages.append({
                "role": "assistant", "content": answer, "type": "consult",
            })
        elif intent in ("undo_modification", "reset_session"):
            st.success("操作已完成")
        elif intent in ("preview_report", "export_pdf", "export_jrxml"):
            jrxml = final_state.get("current_jrxml", "")
            if jrxml:
                st.success("✅ 当前报表")
                _render_jrxml(jrxml)
                st.session_state.messages.append({
                    "role": "assistant", "content": jrxml, "type": "jrxml",
                })
            else:
                st.warning("⚠ 当前没有报表可以展示。")
        elif status == "pass":
            jrxml = final_state.get("current_jrxml", "")
            st.success("✅ JRXML 生成成功")
            st.markdown("**生成结果:**")
            _render_jrxml(jrxml)
            st.caption("您可以从侧边栏下载文件,或继续对话进行修改。")
            st.session_state.messages.append({
                "role": "assistant", "content": jrxml, "type": "jrxml",
            })
            st.session_state.messages.append({
                "role": "assistant",
                "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。",
                "type": "success",
            })
        else:
            jrxml = final_state.get("current_jrxml", "")
            error_msg = final_state.get("error_msg", "未知错误")
            explanation = final_state.get("natural_explanation", "")
            retries = final_state.get("retry_count", 0)
            st.error(f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML")
            st.markdown(f"**错误:** {error_msg}")
            if explanation:
                st.markdown(f"**原因:** {explanation}")
            if jrxml:
                with st.expander("查看当前 JRXML"):
                    _render_jrxml(jrxml, max_lines=80)
            st.caption("💡 下次输入修改需求时,系统会自动加载失败上下文继续修复。")
            st.session_state.messages.append({
                "role": "assistant",
                "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。",
                "type": "error_explanation",
            })
        # OCR field-extraction results, shown for any intent when available.
        _render_ocr_fields(final_state.get("ocr_extraction_result", {}))


def run_agent(user_input: str):
    """Run the agent graph for one user turn and render its progress and result.

    Streams the graph stored in ``st.session_state.graph`` in "updates" and
    "custom" modes: each finished node (except those in ``SKIP_NODES``) is added
    to a live progress list, and "custom" events of type "stream" are appended
    to a live XML code block. When the stream ends both temporary areas are
    cleared and a summary card is rendered.

    Args:
        user_input: The full prompt for this turn.

    Any exception raised while streaming is logged, shown via ``st.error`` and
    swallowed; the function then returns without rendering a summary.
    """
    trace_id = generate_trace_id()
    set_trace_id(trace_id)
    agent_state = st.session_state.agent_state
    session_id = agent_state.get("session_id", "")
    _app_log.info(
        "代理执行开始",
        extra={
            "session_id": session_id,
            "trace_id": trace_id,
            "user_input_preview": user_input[:200],
            "user_input_length": len(user_input),
            "has_jrxml": bool((agent_state.get("current_jrxml") or "").strip()),
            "intent": agent_state.get("intent", ""),
        },
    )
    # With a validated report already present, the input is also recorded as a
    # modification request.
    if agent_state.get("current_jrxml") and agent_state.get("status") == "pass":
        agent_state["user_modification_request"] = user_input
    agent_state["user_input"] = user_input
    agent_state["retry_count"] = 0

    # ---- UI placeholders ----
    progress_placeholder = st.empty()   # live node progress
    streaming_placeholder = st.empty()  # streamed text
    summary_placeholder = st.empty()    # summary card
    progress_placeholder.info("⏳ 正在分析您的需求...")

    executed_nodes: list[dict] = []
    stream_text = ""

    def _render_progress(nodes: list[dict]):
        """Redraw the progress placeholder with one paragraph per node."""
        if not nodes:
            return
        lines = []
        for i, node in enumerate(nodes):
            # NOTE(review): both icon branches are empty strings in this
            # revision — presumably status glyphs were intended; confirm.
            icon = "" if i == len(nodes) - 1 else ""
            detail = f"{node['detail']}" if node.get("detail") else ""
            lines.append(f"{icon} {node['label']}{detail}")
        progress_placeholder.markdown("\n\n".join(lines))

    try:
        for mode, data in st.session_state.graph.stream(
            agent_state, stream_mode=["updates", "custom"]
        ):
            if mode == "updates":
                for node_name, node_state in data.items():
                    if node_name not in SKIP_NODES:
                        entry = {
                            "name": node_name,
                            "label": NODE_LABELS.get(node_name, node_name),
                        }
                        # An update payload that is not a dict (e.g. None) is
                        # summarised as if it were empty instead of raising.
                        detail = _node_detail(
                            node_name,
                            node_state if isinstance(node_state, dict) else {},
                        )
                        if detail is not None:
                            entry["detail"] = detail
                        executed_nodes.append(entry)
                    # Refresh the progress list as soon as a node finishes.
                    _render_progress(executed_nodes)
            elif mode == "custom":
                if isinstance(data, dict) and data.get("type") == "stream":
                    stream_text += data.get("text") or ""
                    streaming_placeholder.code(stream_text, language="xml")
    except Exception as e:
        progress_placeholder.empty()
        _app_log.error(
            f"代理执行异常: {e}",
            extra={"session_id": session_id, "error": str(e)},
        )
        st.error(f"工作流异常: {e}")
        return

    # ---- Clear the temporary areas ----
    progress_placeholder.empty()
    streaming_placeholder.empty()

    # ---- Summary card ----
    # Per-node updates only carry changed fields, so the summary is read from
    # agent_state itself.
    # NOTE(review): this relies on the graph nodes mutating agent_state in
    # place (as the original comment states) — confirm against the graph.
    final_state = agent_state
    if final_state:
        st.session_state.agent_state = final_state
        _render_summary_card(final_state, summary_placeholder)
    else:
        st.error("未产生结果,请重试。")
    _app_log.info(
        "代理执行完成",
        extra={
            "session_id": session_id,
            "intent": final_state.get("intent", ""),
            "status": final_state.get("status", ""),
            "jrxml_length": len(final_state.get("current_jrxml") or ""),
            "retry_count": final_state.get("retry_count", 0),
        },
    )
# ---- Sidebar: session management, quick actions, configuration, downloads ----
with st.sidebar:
    st.title("📊 JRXML 代理")
    st.markdown("通过自然语言生成 JasperReports 模板。")
    st.divider()

    # -- Session management --
    st.markdown("### 会话管理")
    sessions = list_all_sessions()
    # The selector is keyed by session id and only *displays* the label.
    # Keying by label made two sessions with the same name and minute collide:
    # one became unselectable and the current session could fail to resolve,
    # falling back to index 0 and triggering an unintended switch.
    session_labels = {}
    for s in sessions:
        sid = s["session_id"]
        name = s.get("session_name", sid)
        updated = (s.get("updated_at") or "")[:16]
        session_labels[sid] = f"{name} ({updated})"
    session_ids = list(session_labels)
    selected = st.selectbox(
        "切换会话",
        options=session_ids,
        index=session_ids.index(current_session_id) if current_session_id in session_labels else 0,
        format_func=lambda option: session_labels.get(option, option),
        key="session_selector",
    )
    if selected and selected != current_session_id:
        new_sid = selected
        if st.session_state.get("_last_switched_to") == new_sid:
            # Guard against an endless rerun loop when the same session is
            # selected again right after a switch.
            st.session_state._last_switched_to = ""
        else:
            data = load_session(new_sid)
            if data and data.get("agent_state"):
                _app_log.info(
                    "切换会话",
                    extra={"from_session": current_session_id, "to_session": new_sid},
                )
                data["agent_state"]["session_id"] = new_sid
                st.session_state.agent_state = data["agent_state"]
                st.session_state.messages = []
                st.session_state._last_switched_to = new_sid
                st.rerun()

    col1, col2 = st.columns(2)
    with col1:
        if st.button(" 新建", use_container_width=True):
            new_data = create_session(name="", agent_state=create_initial_state())
            _app_log.info(
                "新建会话",
                extra={"session_id": new_data["session_id"]},
            )
            st.session_state.agent_state = create_initial_state()
            st.session_state.agent_state["session_id"] = new_data["session_id"]
            st.session_state.agent_state["session_name"] = new_data["session_name"]
            st.session_state.agent_state["created_at"] = new_data["created_at"]
            st.session_state.messages = []
            st.rerun()
    with col2:
        if st.button("🗑 删除", use_container_width=True):
            if current_session_id:
                _app_log.info(
                    "删除会话",
                    extra={"session_id": current_session_id},
                )
                delete_session(current_session_id)
            # Replace the deleted session with a fresh one so the UI always
            # has a current session.
            st.session_state.agent_state = create_initial_state()
            new_data = create_session(name="", agent_state=st.session_state.agent_state)
            st.session_state.agent_state["session_id"] = new_data["session_id"]
            st.session_state.agent_state["session_name"] = new_data["session_name"]
            st.session_state.agent_state["created_at"] = new_data["created_at"]
            st.session_state.messages = []
            st.rerun()

    current_name = st.session_state.agent_state.get("session_name", "")
    st.caption(f"当前: {current_name} (`{current_session_id}`)")
    st.divider()

    # -- Quick actions: each sends a canned prompt through the agent --
    st.markdown("### 快捷操作")
    has_jrxml = bool((st.session_state.agent_state.get("current_jrxml") or "").strip())
    has_history = bool(st.session_state.agent_state.get("history_states", []))
    qcol1, qcol2 = st.columns(2)
    with qcol1:
        if st.button("👁 预览", use_container_width=True, disabled=not has_jrxml):
            with st.spinner("正在准备预览..."):
                run_agent("预览报表")
            st.rerun()
    with qcol2:
        if st.button("↩ 撤销", use_container_width=True, disabled=not has_history):
            with st.spinner("正在撤销..."):
                run_agent("撤销上一步修改")
            st.rerun()
    if st.button("🔄 重置会话", use_container_width=True):
        with st.spinner("正在重置..."):
            run_agent("重新来,清空当前报表")
        st.rerun()
    st.divider()

    # -- Read-only view of the environment-driven configuration --
    st.markdown("### 配置")
    llm_backend = os.getenv("LLM_BACKEND", "cloud")
    llm_model = os.getenv("LLM_MODEL", os.getenv("LOCAL_LLM_MODEL", "gpt-4o"))
    st.caption(f"大语言模型: {llm_backend} / {llm_model}")
    st.caption(f"最大重试次数: {os.getenv('MAX_RETRY', '5')}")
    st.caption(f"验证服务: {os.getenv('VALIDATION_SERVICE_URL', 'http://localhost:8001/validate')}")
    st.divider()

    # -- Downloads: latest JRXML plus every recorded version --
    st.markdown("### 下载")
    final = st.session_state.agent_state.get("final_jrxml", "")
    versions = st.session_state.agent_state.get("jrxml_versions", [])
    if final:
        st.download_button(
            label="📥 下载最新 JRXML",
            data=final,
            file_name="report.jrxml",
            mime="application/xml",
            use_container_width=True,
        )
    if versions:
        with st.expander("📋 历史版本", expanded=False):
            # Newest first; version numbers count down from len(versions).
            for i, v in enumerate(reversed(versions)):
                ts = v.get("ts", "")[:16]
                label = v.get("label", "版本")
                status = v.get("status", "")
                # NOTE(review): both icon branches are empty strings in this
                # revision — presumably pass/fail glyphs were intended; confirm.
                icon = "" if status == "pass" else ""
                dl_label = f"{icon} v{len(versions)-i}{label} ({ts})"
                st.download_button(
                    label=dl_label,
                    data=v.get("jrxml", ""),
                    file_name=f"report_v{len(versions)-i}.jrxml",
                    mime="application/xml",
                    use_container_width=True,
                    key=f"dl_v{i}",
                )
# ---- Page heading ----
st.title("📝 JRXML 报表生成器")
st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。")

# ---- Chat history ----
# Message "type" -> Streamlit call used to replay it. JRXML messages are
# handled separately (collapsed code block); anything else is plain markdown.
_HISTORY_RENDERERS = {
    "error_explanation": st.warning,
    "success": st.success,
    "consult": st.info,
}
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        msg_kind = msg.get("type")
        if msg_kind == "jrxml":
            with st.expander("查看生成的 JRXML", expanded=False):
                st.code(msg["content"], language="xml")
        else:
            _HISTORY_RENDERERS.get(msg_kind, st.markdown)(msg["content"])
# ---- Unified chat input widget ----
# Self-contained HTML/CSS/JS for a chat box with a text area, an attach button,
# paste and drag-and-drop file intake, and removable file chips. On send, each
# attached file is read as a data URL and {text, files:[{name,type,data,size}]}
# is handed to Streamlit.setComponentValue.
#
# File names are user-controlled, so chips are built with textContent rather
# than string-concatenated innerHTML. handleSend re-enables the send button in
# a finally block and ignores re-entrant calls while a send is in flight.
#
# NOTE(review): the script calls a global `Streamlit` object that this markup
# neither defines nor loads — confirm the hosting iframe actually provides it.
UNIFIED_CHAT_HTML = r"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<style>
  * { box-sizing: border-box; margin: 0; padding: 0; }
  body {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
    background: transparent;
    padding: 4px 0;
  }
  .chat-container {
    position: relative;
    border: 1px solid #d1d5db;
    border-radius: 12px;
    padding: 8px 12px;
    background: #ffffff;
    transition: border-color 0.2s, box-shadow 0.2s;
  }
  .chat-container:focus-within {
    border-color: #3b82f6;
    box-shadow: 0 0 0 2px rgba(59,130,246,0.15);
  }
  .chat-container.drag-active {
    border-color: #3b82f6;
    background: rgba(59,130,246,0.04);
  }
  .file-chips {
    display: flex;
    flex-wrap: wrap;
    gap: 6px;
    margin-bottom: 6px;
  }
  .file-chips:empty { display: none; }
  .file-chip {
    display: inline-flex;
    align-items: center;
    gap: 4px;
    padding: 2px 8px;
    background: #f3f4f6;
    border-radius: 14px;
    font-size: 12px;
    color: #374151;
    max-width: 200px;
  }
  .file-chip .chip-icon { font-size: 13px; }
  .file-chip .chip-name {
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
  }
  .file-chip .chip-remove {
    border: none;
    background: none;
    cursor: pointer;
    color: #9ca3af;
    font-size: 14px;
    line-height: 1;
    padding: 0 2px;
    flex-shrink: 0;
  }
  .file-chip .chip-remove:hover { color: #ef4444; }
  .input-row {
    display: flex;
    align-items: flex-end;
    gap: 8px;
  }
  .attach-btn {
    border: none;
    background: none;
    cursor: pointer;
    padding: 4px 6px;
    font-size: 20px;
    line-height: 1;
    color: #6b7280;
    border-radius: 6px;
    transition: background 0.15s, color 0.15s;
    flex-shrink: 0;
  }
  .attach-btn:hover { background: #f3f4f6; color: #374151; }
  textarea {
    flex: 1;
    border: none;
    outline: none;
    resize: none;
    font-size: 15px;
    line-height: 1.5;
    font-family: inherit;
    color: #111827;
    background: transparent;
    padding: 4px 0;
    min-height: 24px;
    max-height: 120px;
    overflow-y: auto;
  }
  textarea::placeholder { color: #9ca3af; }
  .send-btn {
    border: none;
    cursor: pointer;
    padding: 4px 10px;
    font-size: 16px;
    background: #e5e7eb;
    color: #9ca3af;
    border-radius: 8px;
    transition: all 0.15s;
    flex-shrink: 0;
  }
  .send-btn.active { background: #3b82f6; color: #fff; }
  .send-btn.active:hover { background: #2563eb; }
  .send-btn:disabled { opacity: 0.5; cursor: default; }
  .error-toast {
    position: fixed;
    bottom: 12px;
    left: 50%;
    transform: translateX(-50%);
    background: #ef4444;
    color: #fff;
    padding: 6px 16px;
    border-radius: 8px;
    font-size: 13px;
    z-index: 9999;
    animation: toastOut 2.5s forwards;
    pointer-events: none;
  }
  @keyframes toastOut {
    0%, 70% { opacity: 1; }
    100% { opacity: 0; }
  }
  @media (prefers-color-scheme: dark) {
    .chat-container { background: #1f2937; border-color: #374151; }
    .chat-container:focus-within { border-color: #3b82f6; }
    .file-chip { background: #374151; color: #e5e7eb; }
    .file-chip .chip-remove { color: #6b7280; }
    .attach-btn { color: #9ca3af; }
    .attach-btn:hover { background: #374151; color: #e5e7eb; }
    textarea { color: #f9fafb; }
    textarea::placeholder { color: #6b7280; }
    .send-btn { background: #374151; }
  }
</style>
</head>
<body>
<div class="chat-container" id="container">
  <div class="file-chips" id="chips"></div>
  <div class="input-row">
    <button class="attach-btn" id="attachBtn" title="附加文件">&#x1F4CE;</button>
    <textarea id="textInput" placeholder="描述您的报表需求..." rows="1"></textarea>
    <button class="send-btn" id="sendBtn" title="发送">&#x27A4;</button>
  </div>
  <input type="file" id="fileInput" multiple hidden
         accept=".png,.jpg,.jpeg,.bmp,.webp,.pdf,.docx,.xlsx,.xls,.doc,.txt">
</div>
<script>
  const container = document.getElementById('container');
  const chipsEl = document.getElementById('chips');
  const textInput = document.getElementById('textInput');
  const sendBtn = document.getElementById('sendBtn');
  const attachBtn = document.getElementById('attachBtn');
  const fileInput = document.getElementById('fileInput');
  let attachedFiles = [];
  const MAX_FILES = 10;
  const MAX_SIZE = 20 * 1024 * 1024;

  function getIcon(type) {
    if (type.startsWith('image/')) return '🖼';
    if (type.includes('pdf')) return '📄';
    if (type.includes('document')) return '📝';
    if (type.includes('spreadsheet') || type.includes('excel')) return '📊';
    return '📎';
  }

  function updateSendBtn() {
    var canSend = !!textInput.value.trim() || attachedFiles.length > 0;
    sendBtn.classList.toggle('active', canSend);
  }

  function renderChips() {
    chipsEl.innerHTML = '';
    attachedFiles.forEach(function(f, i) {
      var chip = document.createElement('span');
      chip.className = 'file-chip';

      var icon = document.createElement('span');
      icon.className = 'chip-icon';
      icon.textContent = getIcon(f.type);

      // textContent keeps characters such as < and & in a file name literal.
      var label = document.createElement('span');
      label.className = 'chip-name';
      label.textContent = f.name.length > 16 ? f.name.slice(0, 14) + '..' : f.name;

      var remove = document.createElement('button');
      remove.className = 'chip-remove';
      remove.textContent = '×';
      remove.onclick = function() {
        attachedFiles.splice(i, 1);
        renderChips();
        updateSendBtn();
      };

      chip.appendChild(icon);
      chip.appendChild(label);
      chip.appendChild(remove);
      chipsEl.appendChild(chip);
    });
    updateSendBtn();
  }

  function addFiles(fileList) {
    for (var i = 0; i < fileList.length; i++) {
      var file = fileList[i];
      if (attachedFiles.length >= MAX_FILES) { showToast('最多附加 '+MAX_FILES+' 个文件'); break; }
      if (file.size > MAX_SIZE) { showToast(file.name+' 超过 20MB 限制'); continue; }
      if (attachedFiles.some(function(f) { return f.name === file.name && f.size === file.size; })) continue;
      attachedFiles.push({name: file.name, type: file.type, file: file});
    }
    renderChips();
  }

  function showToast(msg) {
    var t = document.createElement('div');
    t.className = 'error-toast';
    t.textContent = msg;
    document.body.appendChild(t);
    setTimeout(function() { t.remove(); }, 2600);
  }

  function readFile(file) {
    return new Promise(function(resolve, reject) {
      var reader = new FileReader();
      reader.onload = function() { resolve(reader.result); };
      reader.onerror = reject;
      reader.readAsDataURL(file);
    });
  }

  async function handleSend() {
    // Ignore Enter / clicks while a previous send is still reading files.
    if (sendBtn.disabled) return;
    var text = textInput.value.trim();
    if (!text && attachedFiles.length === 0) return;
    sendBtn.disabled = true;
    try {
      var files = [];
      for (var i = 0; i < attachedFiles.length; i++) {
        var f = attachedFiles[i];
        try {
          var dataUrl = await readFile(f.file);
          files.push({name: f.name, type: f.type, data: dataUrl, size: f.file.size});
        } catch(e) {
          showToast(f.name+' 读取失败');
        }
      }
      Streamlit.setComponentValue({text: text, files: files});
      // Only clear the draft once the value has actually been handed over.
      textInput.value = '';
      attachedFiles = [];
      renderChips();
      textInput.style.height = 'auto';
    } catch(e) {
      showToast('发送失败');
    } finally {
      sendBtn.disabled = false;
    }
  }

  attachBtn.onclick = function() { fileInput.click(); };
  fileInput.onchange = function() { addFiles(fileInput.files); fileInput.value = ''; };
  textInput.oninput = function() {
    updateSendBtn();
    textInput.style.height = 'auto';
    textInput.style.height = Math.min(textInput.scrollHeight, 120) + 'px';
  };
  textInput.onkeydown = function(e) {
    if (e.key === 'Enter' && !e.shiftKey) {
      e.preventDefault();
      handleSend();
    }
  };
  sendBtn.onclick = handleSend;

  document.addEventListener('paste', function(e) {
    var items = e.clipboardData && e.clipboardData.items;
    if (!items) return;
    var files = [];
    for (var i = 0; i < items.length; i++) {
      if (items[i].kind === 'file') files.push(items[i].getAsFile());
    }
    if (files.length) { e.preventDefault(); addFiles(files); }
  });

  container.addEventListener('dragover', function(e) {
    e.preventDefault();
    container.classList.add('drag-active');
  });
  container.addEventListener('dragleave', function() {
    container.classList.remove('drag-active');
  });
  container.addEventListener('drop', function(e) {
    e.preventDefault();
    container.classList.remove('drag-active');
    addFiles(e.dataTransfer.files);
  });

  updateSendBtn();
</script>
</body>
</html>
"""
# Render the chat widget and, when it yields a submission, parse attachments,
# run the agent on the combined prompt, and rerun the script.
# NOTE(review): st.components.v1.html is documented as a static HTML embed —
# confirm that chat_result can actually be a dict here (bi-directional values
# normally require components.declare_component).
chat_result = components.html(UNIFIED_CHAT_HTML, height=180)
if chat_result and isinstance(chat_result, dict):
    # Imported here rather than at module top, as in the original; they are
    # only needed once a message has been submitted.
    from backend.file_parser import parse_file
    from backend.layout_analyzer import analyze_layout, extract_layout_schema

    prompt = chat_result.get("text", "")
    files = chat_result.get("files") or []

    # Loop-invariant lookup tables (previously rebuilt for every file).
    mime_to_suffix = {
        "image/png": ".png", "image/jpeg": ".jpg", "image/bmp": ".bmp",
        "image/webp": ".webp", "application/pdf": ".pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.ms-excel": ".xls", "application/msword": ".doc",
        "text/plain": ".txt",
    }
    img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

    file_texts = []
    attached_info = []
    first_image_path = None
    temp_paths = []
    try:
        for f in files:
            file_name = f.get("name", "")
            # Expected shape is a data URL: "<header>,<base64 payload>".
            _header, comma, b64data = (f.get("data") or ",").partition(",")
            try:
                raw = base64.b64decode(b64data) if comma else None
            except ValueError:  # binascii.Error is a ValueError subclass
                raw = None
            if raw is None:
                _app_log.warning(
                    "附件数据无法解码,已跳过",
                    extra={"session_id": current_session_id, "attachment": file_name},
                )
                continue

            mime = f.get("type", "")
            suffix = mime_to_suffix.get(mime, Path(file_name).suffix.lower())
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp_path = tmp.name
                # Registered before writing so cleanup also covers a failed write.
                temp_paths.append(tmp_path)
                tmp.write(raw)

            result = parse_file(tmp_path, suffix)
            text = result["text"]
            file_type = result["file_type"]

            if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
                # Layout analysis is best-effort: on failure the plain parse
                # result is kept, but the failure is now logged.
                try:
                    layout = analyze_layout(tmp_path)
                    tt = layout.get("template_type", "unknown")
                    if tt == "full_a4":
                        text = layout["description"]
                        file_type = "a4_template"
                        schema = extract_layout_schema(layout)
                        st.session_state.agent_state["layout_schema"] = schema
                        st.session_state.agent_state["ocr_elements"] = layout.get("rows", [])
                    elif tt == "partial_rows":
                        file_type = "a4_partial"
                except Exception as exc:
                    _app_log.warning(
                        f"附件布局分析失败: {exc}",
                        extra={"session_id": current_session_id, "attachment": file_name},
                    )

            file_texts.append(f"[附加文件: {file_name} ({file_type})]\n{text}")
            attached_info.append({"name": file_name, "type": file_type, "length": len(text)})
            if not first_image_path and file_type in ("image", "a4_template", "a4_partial"):
                first_image_path = tmp_path

        if file_texts:
            full_prompt = "\n\n".join(file_texts) + "\n\n---\n用户需求:\n" + prompt
        else:
            full_prompt = prompt
        if first_image_path:
            # NOTE(review): this path is deleted in the finally block below, so
            # it is only valid while run_agent executes — confirm no later turn
            # reads it.
            st.session_state.agent_state["uploaded_file_path"] = first_image_path

        _app_log.info(
            "收到用户输入",
            extra={
                "session_id": current_session_id,
                "prompt_preview": prompt[:200],
                "prompt_length": len(prompt),
                "has_uploaded_files": bool(attached_info),
                "uploaded_files": attached_info,
            },
        )
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        run_agent(full_prompt)
    finally:
        # Always remove the temporary copies, including when parsing raised.
        for p in temp_paths:
            try:
                Path(p).unlink(missing_ok=True)
            except OSError as exc:
                _app_log.warning(
                    f"临时文件清理失败: {exc}",
                    extra={"session_id": current_session_id, "temp_path": p},
                )
    st.rerun()