c9f003e1b7
- 新增 backend/ocr_extractor.py: 两阶段提取流水线 (文档分析 + 字段提取) - 四种提取策略: 精确KV匹配/模糊KV匹配/正则模式/表格结构匹配 - agent/state.py: 新增 ocr_extraction_result 和 uploaded_file_path 字段 - agent/nodes.py: process_input() 中自动触发 OCR 提取钩子 - app.py: 文件上传时保留图片路径, 总结卡片中展示提取结果 - .env.example: 新增 OCR_USE_GPU / OCR_CONFIDENCE_THRESHOLD 配置项 - tests/test_ocr_extraction.py: 48 个单元测试全部通过
666 lines
26 KiB
Python
666 lines
26 KiB
Python
"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。
|
||
|
||
支持:
|
||
- 流式输出(LLM 逐字展示)
|
||
- 节点平铺展开(每个处理阶段独立展示)
|
||
- 完成后自动折叠节点区
|
||
- 过程总结卡片
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
|
||
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
||
|
||
try:
|
||
import torchvision
|
||
except Exception:
|
||
pass
|
||
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import streamlit as st
|
||
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
|
||
from agent.graph import build_graph, create_initial_state
|
||
from backend.session import (
|
||
create_session,
|
||
load_session,
|
||
delete_session,
|
||
list_all_sessions,
|
||
)
|
||
from backend.logger import get_logger, set_trace_id, generate_trace_id
|
||
|
||
_app_log = get_logger("app")
|
||
|
||
st.set_page_config(
    page_title="JRXML 代理",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Block Streamlit's bare 'c' shortcut (clear cache) while keeping Ctrl/Cmd+C copy.
# st.html sanitises its content and does not execute <script> tags, so the guard is
# injected through a zero-height component iframe instead. From inside that iframe
# ``window.parent.document`` is the app document where the shortcut is handled.
import streamlit.components.v1 as _components

_components.html(
    """
<script>
(function() {
  const parent = window.parent.document;
  parent.addEventListener('keydown', function(e) {
    // 仅拦截裸 'c' 键(非 Ctrl/Cmd 组合)
    if (e.key === 'c' && !e.ctrlKey && !e.metaKey && !e.altKey) {
      const active = parent.activeElement;
      const tag = active ? active.tagName : '';
      // activeElement may be null; never let the guard itself throw.
      if (tag !== 'INPUT' && tag !== 'TEXTAREA' && !(active && active.isContentEditable)) {
        e.stopImmediatePropagation();
        e.preventDefault();
      }
    }
  }, true);
})();
</script>
""",
    height=0,
)


# ---- Graph node name -> display label (shown in the live progress list) ----
NODE_LABELS = dict(
    load_session="📂 加载会话",
    process_input="📝 记录输入",
    manage_context="🧠 管理上下文",
    save_state_snapshot="💾 保存快照",
    classify_intent="🔍 识别意图",
    retrieve="📚 检索模板",
    generate="⚙️ 生成 JRXML",
    modify_jrxml="🔧 修改 JRXML",
    validate="✅ 验证",
    explain_error="🔎 分析错误",
    correct_jrxml="🛠 自动修正",
    finalize="📋 完成",
    handle_consult="💬 咨询回答",
    handle_undo="↩ 撤销操作",
    handle_reset="🔄 重置会话",
    save_session="💾 保存会话",
)

# Intent code (the ``intent`` value the agent state carries) -> human-readable label.
INTENT_LABELS = dict(
    initial_generation="新建报表",
    modify_report="修改报表",
    preview_report="预览报表",
    export_pdf="导出 PDF",
    export_jrxml="下载 JRXML",
    undo_modification="撤销修改",
    consult_question="咨询问题",
    reset_session="重置会话",
)

# Housekeeping nodes that run on every turn but are not listed in the progress view.
SKIP_NODES = {
    "load_session",
    "process_input",
    "manage_context",
    "save_state_snapshot",
    "save_session",
}


def _render_jrxml(jrxml: str, max_lines: int = 30):
    """Show *jrxml* as an XML code block limited to its first *max_lines* lines.

    Leading/trailing whitespace is stripped first. When the document is longer than
    *max_lines*, a trailing "... (共 N 行)" marker reports the full line count.
    """
    all_lines = jrxml.strip().split("\n")
    overflow = len(all_lines) > max_lines
    shown = "\n".join(all_lines[:max_lines])
    suffix = f"\n... (共 {len(all_lines)} 行)" if overflow else ""
    st.code(shown + suffix, language="xml")


# ---- URL parameters ----
# An optional ?session_id=... selects which persisted session to resume.
query_params = st.query_params
url_session_id = query_params.get("session_id", "")


# ---- Session-state initialisation ----
def _bootstrap_fresh_session() -> None:
    """Create a brand-new persisted session and make it the active agent state."""
    state = create_initial_state()
    st.session_state.agent_state = state
    record = create_session(name="", agent_state=state)
    for field_name in ("session_id", "session_name", "created_at"):
        state[field_name] = record[field_name]


# Per-browser-session singletons, created lazily on the first script run.
for _state_key, _make_default in (
    ("messages", list),
    ("graph", build_graph),
    ("pending_action", lambda: None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

if "agent_state" not in st.session_state:
    # Resume the session named in the URL when it exists and has saved state;
    # otherwise start (and persist) a fresh one.
    restored = load_session(url_session_id) if url_session_id else None
    if restored and restored.get("agent_state"):
        st.session_state.agent_state = restored["agent_state"]
        st.session_state.agent_state["session_id"] = url_session_id
    else:
        _bootstrap_fresh_session()

current_session_id = st.session_state.agent_state.get("session_id", "")


def _node_progress_detail(node_name: str, node_state: dict) -> str:
    """Return the one-line progress detail for *node_name* ("" when it has none).

    *node_state* is the update payload the graph emitted for that node. Missing or
    ``None`` fields are treated as empty so a partial update cannot crash the view.
    """
    if node_name == "classify_intent":
        intent = node_state.get("intent") or ""
        return f"意图: {INTENT_LABELS.get(intent, intent)}"
    if node_name == "retrieve":
        ctx = node_state.get("retrieved_context") or ""
        return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
    if node_name in ("generate", "modify_jrxml", "correct_jrxml"):
        jrxml = node_state.get("current_jrxml") or ""
        return f"生成 {len(jrxml)} 字符 JRXML"
    if node_name == "validate":
        if node_state.get("status", "") == "pass":
            return "验证通过 ✓"
        err = node_state.get("error_msg") or ""
        return f"验证失败: {err[:80]}"
    if node_name == "explain_error":
        return (node_state.get("natural_explanation") or "")[:120]
    if node_name == "handle_consult":
        return (node_state.get("consult_answer") or "")[:150]
    return ""


def _cleanup_uploaded_temp_files() -> None:
    """Delete temp files kept for OCR and forget the state path that points at them.

    The uploader keeps image temp files alive until a run consumes them and records
    one of them in ``agent_state["uploaded_file_path"]``. Once the files are deleted
    that path is cleared, so later runs are not handed a missing file.
    """
    removed = set(st.session_state.get("uploaded_temp_paths", []))
    for p in removed:
        try:
            Path(p).unlink(missing_ok=True)
        except OSError as exc:
            # Best-effort: a leftover temp file must not break the chat turn.
            _app_log.warning(f"临时文件清理失败: {p}", extra={"error": str(exc)})
    st.session_state.uploaded_temp_paths = []

    agent_state = st.session_state.get("agent_state")
    if agent_state and agent_state.get("uploaded_file_path") in removed:
        agent_state["uploaded_file_path"] = ""


def _render_summary(final_state: dict) -> None:
    """Render the outcome of a run and append the assistant reply to chat history."""
    intent = final_state.get("intent", "")
    status = final_state.get("status", "")
    messages = st.session_state.messages

    if intent == "consult_question":
        answer = final_state.get("consult_answer", "")
        st.info(answer)
        messages.append({"role": "assistant", "content": answer, "type": "consult"})

    elif intent in ("undo_modification", "reset_session"):
        st.success("操作已完成")

    elif intent in ("preview_report", "export_pdf", "export_jrxml"):
        jrxml = final_state.get("current_jrxml", "")
        if jrxml:
            st.success("✅ 当前报表")
            _render_jrxml(jrxml)
            messages.append({"role": "assistant", "content": jrxml, "type": "jrxml"})
        else:
            st.warning("⚠ 当前没有报表可以展示。")

    elif status == "pass":
        jrxml = final_state.get("current_jrxml", "")
        st.success("✅ JRXML 生成成功")
        st.markdown("**生成结果:**")
        _render_jrxml(jrxml)
        st.caption("您可以从侧边栏下载文件,或继续对话进行修改。")
        messages.append({"role": "assistant", "content": jrxml, "type": "jrxml"})
        messages.append({
            "role": "assistant",
            "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。",
            "type": "success",
        })

    else:
        jrxml = final_state.get("current_jrxml", "")
        error_msg = final_state.get("error_msg", "未知错误")
        explanation = final_state.get("natural_explanation", "")
        retries = final_state.get("retry_count", 0)
        st.error(f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML")
        st.markdown(f"**错误:** {error_msg}")
        if explanation:
            st.markdown(f"**原因:** {explanation}")
        if jrxml:
            with st.expander("查看当前 JRXML"):
                _render_jrxml(jrxml, max_lines=80)
        st.caption("💡 下次输入修改需求时,系统会自动加载失败上下文继续修复。")
        messages.append({
            "role": "assistant",
            "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。",
            "type": "error_explanation",
        })


def _render_ocr_result(state: dict) -> None:
    """Show OCR field-extraction results stored in ``state["ocr_extraction_result"]``.

    Nothing is rendered unless the result reports ``ocr_available`` and has fields.
    """
    ocr_result = state.get("ocr_extraction_result") or {}
    fields = ocr_result.get("fields") or []
    if not (ocr_result.get("ocr_available") and fields):
        return

    with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
        filled = [fld for fld in fields if fld.get("field_value")]
        missing = [fld for fld in fields if not fld.get("field_value")]
        if filled:
            st.markdown("**已提取字段:**")
            for fld in filled:
                method = fld.get("extraction_method", "")
                # A None confidence would break the percent format below.
                conf = fld.get("confidence") or 0
                st.markdown(
                    f"- **{fld.get('field_name', '')}**: `{fld['field_value']}` "
                    f"(置信度: {conf:.0%}, 方法: {method})"
                )
        if missing:
            names = ", ".join(str(fld.get("field_name", "")) for fld in missing)
            st.caption(f"未提取到值的字段: {names}")
        st.caption(f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素")


def run_agent(user_input: str):
    """Run the agent graph for *user_input*, streaming node progress and LLM text.

    Renders three placeholders (live progress, streamed text, summary card), appends
    the assistant reply to ``st.session_state.messages`` and logs start/finish.
    Temp files saved by the uploader are removed when the run ends, whether or not
    it succeeded. Graph exceptions are shown to the user rather than re-raised.
    """
    trace_id = generate_trace_id()
    set_trace_id(trace_id)
    agent_state = st.session_state.agent_state
    session_id = agent_state.get("session_id", "")

    _app_log.info(
        "代理执行开始",
        extra={
            "session_id": session_id,
            "trace_id": trace_id,
            "user_input_preview": user_input[:200],
            "user_input_length": len(user_input),
            "has_jrxml": bool((agent_state.get("current_jrxml") or "").strip()),
            "intent": agent_state.get("intent", ""),
        },
    )

    # With a validated report already present, the new input is a modification request.
    if agent_state.get("current_jrxml") and agent_state.get("status") == "pass":
        agent_state["user_modification_request"] = user_input

    agent_state["user_input"] = user_input
    agent_state["retry_count"] = 0

    # ---- UI placeholders ----
    progress_placeholder = st.empty()   # live node progress
    streaming_placeholder = st.empty()  # streamed text
    summary_placeholder = st.empty()    # summary card

    progress_placeholder.info("⏳ 正在分析您的需求...")

    executed_nodes: list[dict] = []
    stream_text = ""
    stream_active = False

    def _render_progress(nodes: list[dict]):
        """Render the executed-node list; the last entry is marked as current."""
        if not nodes:
            return
        lines = []
        for i, node in enumerate(nodes):
            icon = "●" if i == len(nodes) - 1 else "✓"
            detail = f" — {node['detail']}" if node.get("detail") else ""
            lines.append(f"{icon} {node['label']}{detail}")
        progress_placeholder.markdown("\n\n".join(lines))

    try:
        for mode, data in st.session_state.graph.stream(
            agent_state, stream_mode=["updates", "custom"]
        ):
            if mode == "updates":
                for node_name, node_state in data.items():
                    if node_name not in SKIP_NODES:
                        executed_nodes.append({
                            "name": node_name,
                            "label": NODE_LABELS.get(node_name, node_name),
                            # A node that returns nothing yields a None update.
                            "detail": _node_progress_detail(node_name, node_state or {}),
                        })
                    # Refresh progress as soon as each node finishes.
                    _render_progress(executed_nodes)
            elif mode == "custom" and isinstance(data, dict) and data.get("type") == "stream":
                stream_text += data.get("text", "")
                stream_active = True
                streaming_placeholder.code(stream_text, language="xml")
    except Exception as e:
        progress_placeholder.empty()
        _app_log.error(
            f"代理执行异常: {e}",
            extra={"session_id": session_id, "error": str(e)},
            exc_info=True,
        )
        st.error(f"工作流异常: {e}")
        return
    finally:
        # Runs on success and failure alike, so uploaded temp files never leak.
        _cleanup_uploaded_temp_files()

    # ---- Clear transient placeholders ----
    progress_placeholder.empty()
    if stream_active:
        streaming_placeholder.empty()

    # ---- Summary card ----
    # Each "updates" payload only carries the fields a node changed, so the full
    # state is read from agent_state.
    # NOTE(review): nothing here merges node updates into agent_state; this relies
    # on the graph mutating agent_state in place, as the original comment claimed.
    # If the graph works on a copy of its input, this card shows pre-run values.
    # Confirm against agent.graph.
    final_state = agent_state
    if final_state:
        st.session_state.agent_state = final_state
        with summary_placeholder.container(border=True):
            _render_summary(final_state)
            _render_ocr_result(final_state)
    else:
        st.error("未产生结果,请重试。")

    _app_log.info(
        "代理执行完成",
        extra={
            "session_id": session_id,
            "intent": final_state.get("intent", ""),
            "status": final_state.get("status", ""),
            "jrxml_length": len(final_state.get("current_jrxml") or ""),
            "retry_count": final_state.get("retry_count", 0),
        },
    )


# ---- Sidebar ----
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
# Uploads that get A4 layout analysis: images plus PDF.
_LAYOUT_SUFFIXES = _IMAGE_SUFFIXES + (".pdf",)


def _start_new_session() -> dict:
    """Create and persist a fresh session, make it active and clear the chat.

    Returns the record produced by ``create_session``.
    """
    state = create_initial_state()
    record = create_session(name="", agent_state=state)
    state["session_id"] = record["session_id"]
    state["session_name"] = record["session_name"]
    state["created_at"] = record["created_at"]
    st.session_state.agent_state = state
    st.session_state.messages = []
    return record


def _render_session_manager() -> None:
    """Session picker plus new/delete buttons and the current-session caption."""
    st.markdown("### 会话管理")

    session_options = {}
    for s in list_all_sessions():
        sid = s["session_id"]
        name = s.get("session_name", sid)
        updated = (s.get("updated_at") or "")[:16]
        session_options[f"{name} ({updated})"] = sid

    labels = list(session_options)
    selected_label = next(
        (label for label, sid in session_options.items() if sid == current_session_id),
        None,
    )
    selected = st.selectbox(
        "切换会话",
        options=labels,
        index=labels.index(selected_label) if selected_label else 0,
        key="session_selector",
    )

    if selected and session_options.get(selected) != current_session_id:
        new_sid = session_options[selected]
        data = load_session(new_sid)
        if data and data.get("agent_state"):
            _app_log.info(
                "切换会话",
                extra={"from_session": current_session_id, "to_session": new_sid},
            )
            st.session_state.agent_state = data["agent_state"]
            # Same as the URL-restore path: make sure the restored state carries its id.
            st.session_state.agent_state["session_id"] = new_sid
            st.session_state.messages = []
            st.rerun()

    col_new, col_delete = st.columns(2)
    with col_new:
        if st.button("➕ 新建", use_container_width=True):
            record = _start_new_session()
            _app_log.info("新建会话", extra={"session_id": record["session_id"]})
            st.rerun()
    with col_delete:
        if st.button("🗑 删除", use_container_width=True):
            if current_session_id:
                _app_log.info("删除会话", extra={"session_id": current_session_id})
                delete_session(current_session_id)
                _start_new_session()
                st.rerun()

    current_name = st.session_state.agent_state.get("session_name", "")
    st.caption(f"当前: {current_name} (`{current_session_id}`)")


def _render_quick_actions() -> None:
    """Preview / undo / reset buttons, each of which runs the agent with a canned prompt."""
    st.markdown("### 快捷操作")

    state = st.session_state.agent_state
    has_jrxml = bool((state.get("current_jrxml") or "").strip())
    has_history = bool(state.get("history_states", []))

    qcol1, qcol2 = st.columns(2)
    with qcol1:
        if st.button("👁 预览", use_container_width=True, disabled=not has_jrxml):
            with st.spinner("正在准备预览..."):
                run_agent("预览报表")
            st.rerun()
    with qcol2:
        if st.button("↩ 撤销", use_container_width=True, disabled=not has_history):
            with st.spinner("正在撤销..."):
                run_agent("撤销上一步修改")
            st.rerun()

    if st.button("🔄 重置会话", use_container_width=True):
        with st.spinner("正在重置..."):
            run_agent("重新来,清空当前报表")
        st.rerun()


def _upload_key(uf) -> str:
    """Identity of one uploader entry: Streamlit's per-upload ``file_id`` when available."""
    file_id = getattr(uf, "file_id", None)
    return str(file_id) if file_id else f"{uf.name}:{uf.size}"


def _describe_layout(tmp_path: str, suffix: str, parsed_text: str, parsed_type: str, has_ocr: bool):
    """Run A4 layout analysis on an image/PDF and return the ``(text, type)`` to attach."""
    from backend.layout_analyzer import analyze_layout

    layout = analyze_layout(tmp_path)
    template_type = layout.get("template_type", "unknown")

    if template_type == "full_a4":
        return layout["description"], "a4_template"

    if template_type == "partial_rows":
        current_jrxml = st.session_state.agent_state.get("current_jrxml") or ""
        if not current_jrxml.strip():
            # New-report mode: treat the rows as an A4 template.
            return layout["description"], "a4_partial"
        # Modification mode: try to match the rows to the current JRXML.
        from backend.layout_analyzer import match_rows_to_jrxml

        match = match_rows_to_jrxml(layout, current_jrxml)
        text = (
            f"[行片段修改] 上传图片包含 {layout['total_rows']} 行,"
            f"视为 A4 报表的一部分。\n\n"
            f"{match['description']}\n\n"
            f"--- 行结构 ---\n{layout['description']}"
        )
        return text, "a4_partial"

    # Any other template type ("unknown"): OCR unavailable or no text elements found.
    if suffix == ".pdf" and parsed_text:
        # Keep a PDF's own extracted text rather than replacing it with an image note.
        return parsed_text, parsed_type

    img_w, img_h = layout["image_size"]
    ratio = layout["aspect_ratio"]
    if has_ocr:
        text = (
            f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。"
            f"未检测到 A4 报表结构,图片将被视为参考样式。\n"
            f"请根据用户的文字描述生成报表。"
        )
    else:
        text = (
            f"[图片上传] 尺寸 {img_w}x{img_h}px, 比例 {ratio}。\n"
            f"⚠ OCR 引擎未安装,无法识别图片中的文字内容。\n"
            f"请严格根据用户的文字描述来推断图片中的报表需求。\n"
            f"(提示:如需图片文字识别,请运行 pip install paddleocr)"
        )
    return text, "image_reference"


def _ingest_upload(uf) -> None:
    """Parse one uploaded file and queue its text for the next chat message.

    The temp file of an image that OCR can read is kept and recorded as
    ``agent_state["uploaded_file_path"]`` for the OCR field-extraction step. Every
    other temp file is deleted before returning, including when parsing fails.
    """
    import tempfile

    from backend.file_parser import parse_file

    suffix = Path(uf.name).suffix.lower()
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uf.getvalue())
        tmp_path = tmp.name

    keep_temp = False
    try:
        result = parse_file(tmp_path, suffix)
        parsed_text = result["text"]
        parsed_type = result["file_type"]
        has_ocr = result.get("method") not in ("metadata_only", None)
        ocr_image = suffix in _IMAGE_SUFFIXES and has_ocr

        if suffix in _LAYOUT_SUFFIXES:
            parsed_text, parsed_type = _describe_layout(
                tmp_path, suffix, parsed_text, parsed_type, has_ocr
            )

        if parsed_text:
            st.session_state.uploaded_files.append({
                "name": uf.name,
                "text": parsed_text,
                "type": parsed_type,
                # Lets the remove button discard the OCR temp file too.
                "temp_path": tmp_path if ocr_image else None,
            })

        if ocr_image:
            # Deferred OCR field extraction happens during the process_input stage.
            st.session_state.agent_state["uploaded_file_path"] = tmp_path
            st.session_state.uploaded_temp_paths.append(tmp_path)
            keep_temp = True
    except Exception as exc:  # a bad file must not take down the whole page
        _app_log.error(f"文件解析失败: {uf.name}: {exc}", exc_info=True)
        st.warning(f"⚠ 无法解析文件 {uf.name}: {exc}")
    finally:
        if not keep_temp:
            Path(tmp_path).unlink(missing_ok=True)


def _discard_upload(index: int) -> None:
    """Remove a pending attachment, including any OCR temp file it kept alive."""
    entry = st.session_state.uploaded_files.pop(index)
    tmp_path = entry.get("temp_path")
    if not tmp_path:
        return
    try:
        Path(tmp_path).unlink(missing_ok=True)
    except OSError:
        pass  # best-effort; run_agent's cleanup retries anything still listed
    else:
        paths = st.session_state.uploaded_temp_paths
        if tmp_path in paths:
            paths.remove(tmp_path)
    if st.session_state.agent_state.get("uploaded_file_path") == tmp_path:
        st.session_state.agent_state["uploaded_file_path"] = ""


def _render_file_upload() -> None:
    """File uploader plus the list of attachments queued for the next message."""
    st.markdown("### 上传文件")
    st.caption("支持图片 (OCR)、PDF、Word、文本文件。内容将附加到您的下一条消息中。")

    if "uploaded_files" not in st.session_state:
        st.session_state.uploaded_files = []  # [{name, text, type, temp_path}]
    if "uploaded_temp_paths" not in st.session_state:
        st.session_state.uploaded_temp_paths = []  # temp files awaiting cleanup
    if "processed_upload_keys" not in st.session_state:
        # The uploader keeps returning its files on every rerun. Remembering which
        # uploads were already ingested stops a consumed or removed attachment from
        # being re-parsed and silently re-attached.
        st.session_state.processed_upload_keys = set()

    uploaded = st.file_uploader(
        "选择文件",
        type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"],
        accept_multiple_files=True,
        key="file_uploader",
        label_visibility="collapsed",
    )

    processed = st.session_state.processed_upload_keys
    for uf in uploaded or []:
        upload_key = _upload_key(uf)
        if upload_key in processed:
            continue
        processed.add(upload_key)
        _ingest_upload(uf)

    for i, f in enumerate(st.session_state.uploaded_files):
        cols = st.columns([5, 1])
        with cols[0]:
            st.caption(f"📎 {f['name']} ({f['type']}, {len(f['text'])} 字符)")
        with cols[1]:
            if st.button("✕", key=f"rm_uf_{i}", help="移除"):
                _discard_upload(i)
                st.rerun()


def _render_config() -> None:
    """Read-only view of the LLM / retry / validation-service configuration."""
    st.markdown("### 配置")
    llm_backend = os.getenv("LLM_BACKEND", "cloud")
    llm_model = os.getenv("LLM_MODEL", os.getenv("LOCAL_LLM_MODEL", "gpt-4o"))
    st.caption(f"大语言模型: {llm_backend} / {llm_model}")
    st.caption(f"最大重试次数: {os.getenv('MAX_RETRY', '3')}")
    st.caption(f"验证服务: {os.getenv('VALIDATION_SERVICE_URL', 'http://localhost:8001/validate')}")


def _render_downloads() -> None:
    """Download buttons for the final JRXML and every stored historical version."""
    st.markdown("### 下载")

    state = st.session_state.agent_state
    final = state.get("final_jrxml", "")
    versions = state.get("jrxml_versions", [])

    if final:
        st.download_button(
            label="📥 下载最新 JRXML",
            data=final,
            file_name="report.jrxml",
            mime="application/xml",
            use_container_width=True,
        )

    if versions:
        with st.expander("📋 历史版本", expanded=False):
            total = len(versions)
            # Newest first; version numbers still count up from the oldest (v1).
            for i, v in enumerate(reversed(versions)):
                number = total - i
                ts = (v.get("ts") or "")[:16]
                label = v.get("label", "版本")
                icon = "✅" if v.get("status", "") == "pass" else "❌"
                st.download_button(
                    label=f"{icon} v{number} — {label} ({ts})",
                    data=v.get("jrxml", ""),
                    file_name=f"report_v{number}.jrxml",
                    mime="application/xml",
                    use_container_width=True,
                    key=f"dl_v{i}",
                )


with st.sidebar:
    st.title("📊 JRXML 代理")
    st.markdown("通过自然语言生成 JasperReports 模板。")
    st.divider()
    _render_session_manager()
    st.divider()
    _render_quick_actions()
    st.divider()
    _render_file_upload()
    st.divider()
    _render_config()
    st.divider()
    _render_downloads()


# ---- Page title ----
st.title("📝 JRXML 报表生成器")
st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。")


# ---- Chat history replay ----
# Plain-text message kinds and the Streamlit callout that shows each one; any
# other kind (user prompts, untyped messages) is rendered as Markdown.
_CALLOUT_BY_TYPE = {
    "error_explanation": st.warning,
    "success": st.success,
    "consult": st.info,
}

for history_entry in st.session_state.messages:
    with st.chat_message(history_entry["role"]):
        entry_type = history_entry.get("type")
        if entry_type == "jrxml":
            with st.expander("查看生成的 JRXML", expanded=False):
                st.code(history_entry["content"], language="xml")
        else:
            _CALLOUT_BY_TYPE.get(entry_type, st.markdown)(history_entry["content"])


# ---- Chat input ----
if prompt := st.chat_input("描述您的报表需求..."):
    # Prepend the text of every file attached since the previous message.
    pending_files = st.session_state.get("uploaded_files") or []
    uploaded_texts = [f"[上传文件: {item['name']}]\n{item['text']}" for item in pending_files]
    uploaded_files_info = [
        {"name": item["name"], "type": item["type"], "length": len(item["text"])}
        for item in pending_files
    ]

    if not uploaded_texts:
        full_prompt = prompt
    else:
        full_prompt = "\n\n".join(uploaded_texts) + "\n\n---\n用户需求:\n" + prompt
        st.session_state.uploaded_files = []  # attachments are consumed by this message

    _app_log.info(
        "收到用户输入",
        extra={
            "session_id": current_session_id,
            "prompt_preview": prompt[:200],
            "prompt_length": len(prompt),
            "has_uploaded_files": bool(uploaded_files_info),
            "uploaded_files": uploaded_files_info,
        },
    )

    # History stores only the user's own text; the agent receives the full prompt.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    run_agent(full_prompt)
    st.rerun()