Streaming
This commit is contained in:
+130
-9
@@ -8,7 +8,7 @@ import os
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -23,6 +23,7 @@ from .llms import LLMClient
|
||||
from .nodes import (
|
||||
TemplateSelectionNode,
|
||||
ChapterGenerationNode,
|
||||
ChapterJsonParseError,
|
||||
DocumentLayoutNode,
|
||||
WordBudgetNode,
|
||||
)
|
||||
@@ -205,10 +206,11 @@ class ReportAgent:
|
||||
)
|
||||
|
||||
def generate_report(self, query: str, reports: List[Any], forum_logs: str = "",
|
||||
custom_template: str = "", save_report: bool = True) -> str:
|
||||
custom_template: str = "", save_report: bool = True,
|
||||
stream_handler: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> str:
|
||||
"""
|
||||
生成综合报告(章节JSON → IR → HTML)
|
||||
|
||||
|
||||
Returns:
|
||||
dict: HTML内容以及保存的文件路径信息
|
||||
"""
|
||||
@@ -220,15 +222,32 @@ class ReportAgent:
|
||||
self.state.mark_processing()
|
||||
|
||||
normalized_reports = self._normalize_reports(reports)
|
||||
|
||||
def emit(event_type: str, payload: Dict[str, Any]):
|
||||
"""面向Report Engine流通道的事件分发器,保证错误不外泄。"""
|
||||
if not stream_handler:
|
||||
return
|
||||
try:
|
||||
stream_handler(event_type, payload)
|
||||
except Exception as callback_error: # pragma: no cover - 仅记录
|
||||
logger.warning(f"流式事件回调失败: {callback_error}")
|
||||
|
||||
logger.info(f"开始生成报告 {report_id}: {query}")
|
||||
logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(str(forum_logs))}")
|
||||
emit('stage', {'stage': 'agent_start', 'report_id': report_id, 'query': query})
|
||||
|
||||
try:
|
||||
template_result = self._select_template(query, reports, forum_logs, custom_template)
|
||||
self.state.metadata.template_used = template_result.get('template_name', '')
|
||||
emit('stage', {
|
||||
'stage': 'template_selected',
|
||||
'template': template_result.get('template_name'),
|
||||
'reason': template_result.get('selection_reason')
|
||||
})
|
||||
sections = self._slice_template(template_result.get('template_content', ''))
|
||||
if not sections:
|
||||
raise ValueError("模板无法解析出章节,请检查模板内容。")
|
||||
emit('stage', {'stage': 'template_sliced', 'section_count': len(sections)})
|
||||
|
||||
template_text = template_result.get('template_content', '')
|
||||
template_overview = self._build_template_overview(template_text, sections)
|
||||
@@ -241,6 +260,11 @@ class ReportAgent:
|
||||
query,
|
||||
template_overview,
|
||||
)
|
||||
emit('stage', {
|
||||
'stage': 'layout_designed',
|
||||
'title': layout_design.get('title'),
|
||||
'toc': layout_design.get('tocTitle')
|
||||
})
|
||||
# 使用刚生成的设计稿对全书进行篇幅规划,约束各章字数与重点
|
||||
word_plan = self.word_budget_node.run(
|
||||
sections,
|
||||
@@ -250,6 +274,10 @@ class ReportAgent:
|
||||
query,
|
||||
template_overview,
|
||||
)
|
||||
emit('stage', {
|
||||
'stage': 'word_plan_ready',
|
||||
'chapter_targets': len(word_plan.get('chapters', []))
|
||||
})
|
||||
# 记录每个章节的目标字数/强调点,后续传给章节LLM
|
||||
chapter_targets = {
|
||||
entry.get("chapterId"): entry
|
||||
@@ -296,23 +324,97 @@ class ReportAgent:
|
||||
# 初始化章节输出目录并写入manifest,方便流式存盘
|
||||
run_dir = self.chapter_storage.start_session(report_id, manifest_meta)
|
||||
self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview)
|
||||
emit('stage', {'stage': 'storage_ready', 'run_dir': str(run_dir)})
|
||||
|
||||
chapters = []
|
||||
chapter_max_attempts = max(1, self.config.CHAPTER_JSON_MAX_ATTEMPTS)
|
||||
for section in sections:
|
||||
logger.info(f"生成章节: {section.title}")
|
||||
chapter = self.chapter_generation_node.run(
|
||||
section,
|
||||
generation_context,
|
||||
run_dir
|
||||
)
|
||||
chapters.append(chapter)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'running'
|
||||
})
|
||||
# 章节流式回调:把LLM返回的delta透传给SSE,便于前端实时渲染
|
||||
def chunk_callback(delta: str, meta: Dict[str, Any], section_ref: TemplateSection = section):
|
||||
emit('chapter_chunk', {
|
||||
'chapterId': meta.get('chapterId') or section_ref.chapter_id,
|
||||
'title': meta.get('title') or section_ref.title,
|
||||
'delta': delta
|
||||
})
|
||||
|
||||
chapter_payload: Dict[str, Any] | None = None
|
||||
attempt = 1
|
||||
while attempt <= chapter_max_attempts:
|
||||
try:
|
||||
chapter_payload = self.chapter_generation_node.run(
|
||||
section,
|
||||
generation_context,
|
||||
run_dir,
|
||||
stream_callback=chunk_callback
|
||||
)
|
||||
break
|
||||
except ChapterJsonParseError as parse_error:
|
||||
logger.warning(
|
||||
"章节 %s JSON解析失败(第 %s/%s 次尝试): %s",
|
||||
section.title,
|
||||
attempt,
|
||||
chapter_max_attempts,
|
||||
parse_error,
|
||||
)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
|
||||
'attempt': attempt,
|
||||
'error': str(parse_error),
|
||||
})
|
||||
if attempt >= chapter_max_attempts:
|
||||
raise
|
||||
attempt += 1
|
||||
continue
|
||||
except Exception as chapter_error:
|
||||
if not self._should_retry_inappropriate_content_error(chapter_error):
|
||||
raise
|
||||
logger.warning(
|
||||
"章节 %s 触发内容安全限制(第 %s/%s 次尝试),准备重新生成: %s",
|
||||
section.title,
|
||||
attempt,
|
||||
chapter_max_attempts,
|
||||
chapter_error,
|
||||
)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
|
||||
'attempt': attempt,
|
||||
'error': str(chapter_error),
|
||||
'reason': 'content_filter'
|
||||
})
|
||||
if attempt >= chapter_max_attempts:
|
||||
raise
|
||||
attempt += 1
|
||||
continue
|
||||
if chapter_payload is None:
|
||||
raise ChapterJsonParseError(
|
||||
f"{section.title} 章节JSON在 {chapter_max_attempts} 次尝试后仍无法解析"
|
||||
)
|
||||
chapters.append(chapter_payload)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'completed',
|
||||
'attempt': attempt,
|
||||
})
|
||||
|
||||
document_ir = self.document_composer.build_document(
|
||||
report_id,
|
||||
manifest_meta,
|
||||
chapters
|
||||
)
|
||||
emit('stage', {'stage': 'chapters_compiled', 'chapter_count': len(chapters)})
|
||||
html_report = self.renderer.render(document_ir)
|
||||
emit('stage', {'stage': 'html_rendered', 'html_length': len(html_report)})
|
||||
|
||||
self.state.html_content = html_report
|
||||
self.state.mark_completed()
|
||||
@@ -320,10 +422,12 @@ class ReportAgent:
|
||||
saved_files = {}
|
||||
if save_report:
|
||||
saved_files = self._save_report(html_report, document_ir, report_id)
|
||||
emit('stage', {'stage': 'report_saved', 'files': saved_files})
|
||||
|
||||
generation_time = (datetime.now() - start_time).total_seconds()
|
||||
self.state.metadata.generation_time = generation_time
|
||||
logger.info(f"报告生成完成,耗时: {generation_time:.2f} 秒")
|
||||
emit('metrics', {'generation_seconds': generation_time})
|
||||
return {
|
||||
'html_content': html_report,
|
||||
'report_id': report_id,
|
||||
@@ -333,6 +437,7 @@ class ReportAgent:
|
||||
except Exception as e:
|
||||
self.state.mark_failed(str(e))
|
||||
logger.exception(f"报告生成过程中发生错误: {str(e)}")
|
||||
emit('error', {'stage': 'agent_failed', 'message': str(e)})
|
||||
raise
|
||||
|
||||
def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str):
|
||||
@@ -444,6 +549,22 @@ class ReportAgent:
|
||||
normalized[key] = self._stringify(value)
|
||||
return normalized
|
||||
|
||||
def _should_retry_inappropriate_content_error(self, error: Exception) -> bool:
|
||||
"""
|
||||
判断LLM异常是否由内容安全/不当内容导致,满足时允许重新生成整章。
|
||||
"""
|
||||
message = str(error) if error else ""
|
||||
if not message:
|
||||
return False
|
||||
normalized = message.lower()
|
||||
keywords = [
|
||||
"inappropriate content",
|
||||
"content violation",
|
||||
"content moderation",
|
||||
"model-studio/error-code",
|
||||
]
|
||||
return any(keyword in normalized for keyword in keywords)
|
||||
|
||||
def _stringify(self, value: Any) -> str:
|
||||
"""安全地将对象转成字符串"""
|
||||
if value is None:
|
||||
|
||||
Reference in New Issue
Block a user