Streaming

This commit is contained in:
马一丁
2025-11-13 22:30:36 +08:00
parent 1c2f82e285
commit fa787af135
3 changed files with 402 additions and 39 deletions
+130 -9
View File
@@ -8,7 +8,7 @@ import os
from pathlib import Path
from uuid import uuid4
from datetime import datetime
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Callable
from loguru import logger
@@ -23,6 +23,7 @@ from .llms import LLMClient
from .nodes import (
TemplateSelectionNode,
ChapterGenerationNode,
ChapterJsonParseError,
DocumentLayoutNode,
WordBudgetNode,
)
@@ -205,10 +206,11 @@ class ReportAgent:
)
def generate_report(self, query: str, reports: List[Any], forum_logs: str = "",
custom_template: str = "", save_report: bool = True) -> str:
custom_template: str = "", save_report: bool = True,
stream_handler: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> str:
"""
生成综合报告(章节JSON → IR → HTML
Returns:
dict: HTML内容以及保存的文件路径信息
"""
@@ -220,15 +222,32 @@ class ReportAgent:
self.state.mark_processing()
normalized_reports = self._normalize_reports(reports)
def emit(event_type: str, payload: Dict[str, Any]):
"""面向Report Engine流通道的事件分发器,保证错误不外泄。"""
if not stream_handler:
return
try:
stream_handler(event_type, payload)
except Exception as callback_error: # pragma: no cover - 仅记录
logger.warning(f"流式事件回调失败: {callback_error}")
logger.info(f"开始生成报告 {report_id}: {query}")
logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(str(forum_logs))}")
emit('stage', {'stage': 'agent_start', 'report_id': report_id, 'query': query})
try:
template_result = self._select_template(query, reports, forum_logs, custom_template)
self.state.metadata.template_used = template_result.get('template_name', '')
emit('stage', {
'stage': 'template_selected',
'template': template_result.get('template_name'),
'reason': template_result.get('selection_reason')
})
sections = self._slice_template(template_result.get('template_content', ''))
if not sections:
raise ValueError("模板无法解析出章节,请检查模板内容。")
emit('stage', {'stage': 'template_sliced', 'section_count': len(sections)})
template_text = template_result.get('template_content', '')
template_overview = self._build_template_overview(template_text, sections)
@@ -241,6 +260,11 @@ class ReportAgent:
query,
template_overview,
)
emit('stage', {
'stage': 'layout_designed',
'title': layout_design.get('title'),
'toc': layout_design.get('tocTitle')
})
# 使用刚生成的设计稿对全书进行篇幅规划,约束各章字数与重点
word_plan = self.word_budget_node.run(
sections,
@@ -250,6 +274,10 @@ class ReportAgent:
query,
template_overview,
)
emit('stage', {
'stage': 'word_plan_ready',
'chapter_targets': len(word_plan.get('chapters', []))
})
# 记录每个章节的目标字数/强调点,后续传给章节LLM
chapter_targets = {
entry.get("chapterId"): entry
@@ -296,23 +324,97 @@ class ReportAgent:
# 初始化章节输出目录并写入manifest,方便流式存盘
run_dir = self.chapter_storage.start_session(report_id, manifest_meta)
self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview)
emit('stage', {'stage': 'storage_ready', 'run_dir': str(run_dir)})
chapters = []
chapter_max_attempts = max(1, self.config.CHAPTER_JSON_MAX_ATTEMPTS)
for section in sections:
logger.info(f"生成章节: {section.title}")
chapter = self.chapter_generation_node.run(
section,
generation_context,
run_dir
)
chapters.append(chapter)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'running'
})
# 章节流式回调:把LLM返回的delta透传给SSE,便于前端实时渲染
def chunk_callback(delta: str, meta: Dict[str, Any], section_ref: TemplateSection = section):
emit('chapter_chunk', {
'chapterId': meta.get('chapterId') or section_ref.chapter_id,
'title': meta.get('title') or section_ref.title,
'delta': delta
})
chapter_payload: Dict[str, Any] | None = None
attempt = 1
while attempt <= chapter_max_attempts:
try:
chapter_payload = self.chapter_generation_node.run(
section,
generation_context,
run_dir,
stream_callback=chunk_callback
)
break
except ChapterJsonParseError as parse_error:
logger.warning(
"章节 %s JSON解析失败(第 %s/%s 次尝试): %s",
section.title,
attempt,
chapter_max_attempts,
parse_error,
)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
'attempt': attempt,
'error': str(parse_error),
})
if attempt >= chapter_max_attempts:
raise
attempt += 1
continue
except Exception as chapter_error:
if not self._should_retry_inappropriate_content_error(chapter_error):
raise
logger.warning(
"章节 %s 触发内容安全限制(第 %s/%s 次尝试),准备重新生成: %s",
section.title,
attempt,
chapter_max_attempts,
chapter_error,
)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
'attempt': attempt,
'error': str(chapter_error),
'reason': 'content_filter'
})
if attempt >= chapter_max_attempts:
raise
attempt += 1
continue
if chapter_payload is None:
raise ChapterJsonParseError(
f"{section.title} 章节JSON在 {chapter_max_attempts} 次尝试后仍无法解析"
)
chapters.append(chapter_payload)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'completed',
'attempt': attempt,
})
document_ir = self.document_composer.build_document(
report_id,
manifest_meta,
chapters
)
emit('stage', {'stage': 'chapters_compiled', 'chapter_count': len(chapters)})
html_report = self.renderer.render(document_ir)
emit('stage', {'stage': 'html_rendered', 'html_length': len(html_report)})
self.state.html_content = html_report
self.state.mark_completed()
@@ -320,10 +422,12 @@ class ReportAgent:
saved_files = {}
if save_report:
saved_files = self._save_report(html_report, document_ir, report_id)
emit('stage', {'stage': 'report_saved', 'files': saved_files})
generation_time = (datetime.now() - start_time).total_seconds()
self.state.metadata.generation_time = generation_time
logger.info(f"报告生成完成,耗时: {generation_time:.2f}")
emit('metrics', {'generation_seconds': generation_time})
return {
'html_content': html_report,
'report_id': report_id,
@@ -333,6 +437,7 @@ class ReportAgent:
except Exception as e:
self.state.mark_failed(str(e))
logger.exception(f"报告生成过程中发生错误: {str(e)}")
emit('error', {'stage': 'agent_failed', 'message': str(e)})
raise
def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str):
@@ -444,6 +549,22 @@ class ReportAgent:
normalized[key] = self._stringify(value)
return normalized
def _should_retry_inappropriate_content_error(self, error: Exception) -> bool:
"""
判断LLM异常是否由内容安全/不当内容导致,满足时允许重新生成整章。
"""
message = str(error) if error else ""
if not message:
return False
normalized = message.lower()
keywords = [
"inappropriate content",
"content violation",
"content moderation",
"model-studio/error-code",
]
return any(keyword in normalized for keyword in keywords)
def _stringify(self, value: Any) -> str:
"""安全地将对象转成字符串"""
if value is None: