Enhance Repair Capabilities

This commit is contained in:
马一丁
2025-11-15 15:22:31 +08:00
parent fa1ebc07ec
commit 90d12a092d
5 changed files with 365 additions and 7 deletions
+45 -2
View File
@@ -13,7 +13,7 @@ import os
from pathlib import Path
from uuid import uuid4
from datetime import datetime
from typing import Optional, Dict, Any, List, Callable
from typing import Optional, Dict, Any, List, Callable, Tuple
from loguru import logger
@@ -199,6 +199,7 @@ class ReportAgent:
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
self.json_rescue_clients = self._initialize_rescue_llms()
# 初始化章级存储/校验/渲染组件
self.chapter_storage = ChapterStorage(self.config.CHAPTER_OUTPUT_DIR)
@@ -263,6 +264,46 @@ class ReportAgent:
model_name=self.config.REPORT_ENGINE_MODEL_NAME,
base_url=self.config.REPORT_ENGINE_BASE_URL,
)
def _initialize_rescue_llms(self) -> List[Tuple[str, LLMClient]]:
"""
初始化跨引擎章节修复所需的LLM客户端列表。
顺序遵循“Report → Forum → Insight → Media”,缺失配置会被自动跳过。
"""
clients: List[Tuple[str, LLMClient]] = []
if self.llm_client:
clients.append(("report_engine", self.llm_client))
fallback_specs = [
(
"forum_engine",
self.config.FORUM_HOST_API_KEY,
self.config.FORUM_HOST_MODEL_NAME,
self.config.FORUM_HOST_BASE_URL,
),
(
"insight_engine",
self.config.INSIGHT_ENGINE_API_KEY,
self.config.INSIGHT_ENGINE_MODEL_NAME,
self.config.INSIGHT_ENGINE_BASE_URL,
),
(
"media_engine",
self.config.MEDIA_ENGINE_API_KEY,
self.config.MEDIA_ENGINE_MODEL_NAME,
self.config.MEDIA_ENGINE_BASE_URL,
),
]
for label, api_key, model_name, base_url in fallback_specs:
if not api_key or not model_name:
continue
try:
client = LLMClient(api_key=api_key, model_name=model_name, base_url=base_url)
except Exception as exc:
logger.warning(f"{label} LLM初始化失败,跳过该修复通道: {exc}")
continue
clients.append((label, client))
return clients
def _initialize_nodes(self):
"""
@@ -280,7 +321,9 @@ class ReportAgent:
self.chapter_generation_node = ChapterGenerationNode(
self.llm_client,
self.validator,
self.chapter_storage
self.chapter_storage,
fallback_llm_clients=self.json_rescue_clients,
error_log_dir=self.config.JSON_ERROR_LOG_DIR,
)
def generate_report(self, query: str, reports: List[Any], forum_logs: str = "",