Reconfiguration of the basic multi-agent architecture.
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
节点处理模块
|
||||
实现Deep Search Agent的各个处理步骤
|
||||
"""
|
||||
|
||||
from .base_node import BaseNode
|
||||
from .report_structure_node import ReportStructureNode
|
||||
from .search_node import FirstSearchNode, ReflectionNode
|
||||
from .summary_node import FirstSummaryNode, ReflectionSummaryNode
|
||||
from .formatting_node import ReportFormattingNode
|
||||
|
||||
# Names exported by a star-import of this package; each one is imported
# from its node module by the import statements of this file.
__all__ = [
    "BaseNode",
    "ReportStructureNode",
    "FirstSearchNode",
    "ReflectionNode",
    "FirstSummaryNode",
    "ReflectionSummaryNode",
    "ReportFormattingNode"
]
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
节点基类
|
||||
定义所有处理节点的基础接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from ..llms.base import BaseLLM
|
||||
from ..state.state import State
|
||||
|
||||
|
||||
class BaseNode(ABC):
    """Abstract base class for processing nodes.

    Stores the LLM client and a display name, declares the abstract ``run``
    entry point, and supplies pass-through validation / post-processing
    hooks together with ``print``-based logging helpers for three levels
    (info, warning, error).
    """

    def __init__(self, llm_client: BaseLLM, node_name: str = ""):
        """Initialise the node.

        Args:
            llm_client: LLM client, kept on ``self.llm_client``.
            node_name: Name used as the log prefix. When empty, the concrete
                class name is used instead.
        """
        self.llm_client = llm_client
        self.node_name = node_name or self.__class__.__name__

    @abstractmethod
    def run(self, input_data: Any, **kwargs) -> Any:
        """Execute the node's processing logic.

        Args:
            input_data: Input for the node.
            **kwargs: Extra, node-specific arguments.

        Returns:
            The processing result.
        """
        pass

    def validate_input(self, input_data: Any) -> bool:
        """Check the input before processing.

        The base implementation accepts everything; subclasses override it.

        Args:
            input_data: Input to check.

        Returns:
            True when the input is acceptable.
        """
        return True

    def process_output(self, output: Any) -> Any:
        """Post-process a raw result.

        The base implementation returns ``output`` unchanged.

        Args:
            output: Raw output.

        Returns:
            The processed output.
        """
        return output

    def log_info(self, message: str):
        """Print an informational message prefixed with the node name."""
        print(f"[{self.node_name}] {message}")

    def log_warning(self, message: str):
        """Print a warning message prefixed with the node name.

        Same format as ``log_error`` but with a warning label, so that nodes
        can report recoverable problems without labelling them as errors.
        """
        print(f"[{self.node_name}] 警告: {message}")

    def log_error(self, message: str):
        """Print an error message prefixed with the node name."""
        print(f"[{self.node_name}] 错误: {message}")
|
||||
|
||||
|
||||
class StateMutationNode(BaseNode):
    """Base class for nodes that additionally write their result into a State."""

    @abstractmethod
    def mutate_state(self, input_data: Any, state: State, **kwargs) -> State:
        """Apply this node's result to ``state``.

        Args:
            input_data: Input for the node.
            state: The current state object.
            **kwargs: Extra, node-specific arguments.

        Returns:
            The state after modification.
        """
        ...
|
||||
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
报告格式化节点
|
||||
负责将最终研究结果格式化为美观的Markdown报告
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_markdown_tags
|
||||
)
|
||||
|
||||
|
||||
class ReportFormattingNode(BaseNode):
    """Node that formats the finished paragraphs into a Markdown report."""

    # Keys every paragraph dict must contain to pass validate_input.
    _REQUIRED_KEYS = ("title", "paragraph_latest_state")

    def __init__(self, llm_client):
        """Initialise the report-formatting node.

        Args:
            llm_client: LLM client used to produce the formatted report.
        """
        super().__init__(llm_client, "ReportFormattingNode")

    @classmethod
    def _is_paragraph_list(cls, data: Any) -> bool:
        """Return True when ``data`` is a list of dicts that each hold the required keys."""
        return isinstance(data, list) and all(
            isinstance(item, dict) and all(key in item for key in cls._REQUIRED_KEYS)
            for item in data
        )

    def validate_input(self, input_data: Any) -> bool:
        """Validate the input.

        Accepts either a list of paragraph dicts or a JSON string that
        decodes to such a list. Every dict must contain ``title`` and
        ``paragraph_latest_state``.

        Args:
            input_data: Candidate input.

        Returns:
            True when the input has the expected shape, False otherwise.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except (ValueError, RecursionError):
                # JSONDecodeError is a ValueError; RecursionError covers
                # pathologically nested documents. Anything else (e.g.
                # KeyboardInterrupt) must not be swallowed here.
                return False
        return self._is_paragraph_list(input_data)

    def run(self, input_data: Any, **kwargs) -> str:
        """Call the LLM to produce the Markdown report.

        Args:
            input_data: List of paragraph dicts, or its JSON string form.
            **kwargs: Extra arguments (unused here).

        Returns:
            The formatted Markdown report.

        Raises:
            ValueError: When ``input_data`` fails validation.
            Exception: Any error raised by the LLM call is logged and re-raised.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误,需要包含title和paragraph_latest_state的列表")

            # The LLM receives the paragraphs as a JSON string.
            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在格式化最终报告")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message)
            processed_response = self.process_output(response)

            self.log_info("成功生成格式化报告")
            return processed_response

        except Exception as e:
            self.log_error(f"报告格式化失败: {str(e)}")
            raise

    def process_output(self, output: str) -> str:
        """Clean the LLM output into a Markdown report.

        Args:
            output: Raw LLM output.

        Returns:
            The cleaned report. A fixed failure report is returned when the
            cleaned text is empty or when cleaning raises.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            # Strip once up front so a prepended default title is never
            # followed by stray leading whitespace from the model output.
            cleaned_output = clean_markdown_tags(cleaned_output).strip()

            if not cleaned_output:
                return "# 报告生成失败\n\n无法生成有效的报告内容。"

            # Make sure the report opens with a heading.
            if not cleaned_output.startswith('#'):
                cleaned_output = "# 深度研究报告\n\n" + cleaned_output

            return cleaned_output

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return "# 报告处理失败\n\n报告格式化过程中发生错误。"

    def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
                               report_title: str = "深度研究报告") -> str:
        """Build the report without the LLM (fallback method).

        Args:
            paragraphs_data: List of paragraph dicts.
            report_title: Title used for the top-level heading.

        Returns:
            The Markdown report, or a fixed failure report when building raises.
        """
        try:
            self.log_info("使用手动格式化方法")

            report_lines = [
                f"# {report_title}",
                "",
                "---",
                ""
            ]

            # One section per paragraph; paragraphs without content are skipped.
            for i, paragraph in enumerate(paragraphs_data, 1):
                title = paragraph.get("title", f"段落 {i}")
                content = paragraph.get("paragraph_latest_state", "")

                if content:
                    report_lines.extend([
                        f"## {title}",
                        "",
                        content,
                        "",
                        "---",
                        ""
                    ])

            # A generic conclusion is appended only for multi-paragraph input.
            if len(paragraphs_data) > 1:
                report_lines.extend([
                    "## 结论",
                    "",
                    # The two literals are concatenated into a single line.
                    (
                        "本报告通过深度搜索和研究,对相关主题进行了全面分析。"
                        "以上各个方面的内容为理解该主题提供了重要参考。"
                    ),
                    ""
                ])

            return "\n".join(report_lines)

        except Exception as e:
            self.log_error(f"手动格式化失败: {str(e)}")
            return "# 报告生成失败\n\n无法完成报告格式化。"
|
||||
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
报告结构生成节点
|
||||
负责根据查询生成报告的整体结构
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
from ..prompts import SYSTEM_PROMPT_REPORT_STRUCTURE
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_json_tags,
|
||||
extract_clean_response,
|
||||
fix_incomplete_json
|
||||
)
|
||||
|
||||
|
||||
class ReportStructureNode(StateMutationNode):
    """Node that asks the LLM for the overall structure of the report."""

    def __init__(self, llm_client, query: str):
        """Initialise the report-structure node.

        Args:
            llm_client: LLM client.
            query: The user's query; it is what gets sent to the LLM.
        """
        super().__init__(llm_client, "ReportStructureNode")
        self.query = query

    def validate_input(self, input_data: Any) -> bool:
        """Return True when the stored query is a non-blank string.

        ``input_data`` is ignored; only ``self.query`` is checked.
        """
        return isinstance(self.query, str) and len(self.query.strip()) > 0

    def run(self, input_data: Any = None, **kwargs) -> List[Dict[str, str]]:
        """Call the LLM to generate the report structure.

        Args:
            input_data: Unused; the query given at construction is sent.
            **kwargs: Extra arguments (unused here).

        Returns:
            List of ``{"title", "content"}`` dicts, one per planned paragraph.

        Raises:
            Exception: Any error from the LLM call is logged and re-raised.
        """
        try:
            self.log_info(f"正在为查询生成报告结构: {self.query}")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
            processed_response = self.process_output(response)

            self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
            return processed_response

        except Exception as e:
            self.log_error(f"生成报告结构失败: {str(e)}")
            raise

    def process_output(self, output: str) -> List[Dict[str, str]]:
        """Extract the report structure from the raw LLM output.

        Parsing is attempted in three stages: plain ``json.loads``, then
        ``extract_clean_response``, then ``fix_incomplete_json``. Whenever no
        usable structure can be obtained, the default structure is returned.

        Args:
            output: Raw LLM output.

        Returns:
            List of validated ``{"title", "content"}`` dicts, or the default
            structure.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            # Log a short prefix of the cleaned text for debugging.
            self.log_info(f"清理后的输出: {cleaned_output[:200]}...")

            try:
                report_structure = json.loads(cleaned_output)
                self.log_info("JSON解析成功")
            except JSONDecodeError as e:
                self.log_info(f"JSON解析失败: {str(e)}")
                # NOTE(review): extract_clean_response appears to signal
                # failure through an "error" key - confirm against the helper.
                report_structure = extract_clean_response(cleaned_output)
                if "error" in report_structure:
                    self.log_error("JSON解析失败,尝试修复...")
                    fixed_json = fix_incomplete_json(cleaned_output)
                    if not fixed_json:
                        self.log_error("无法修复JSON,使用默认结构")
                        return self._generate_default_structure()
                    try:
                        report_structure = json.loads(fixed_json)
                        self.log_info("JSON修复成功")
                    except JSONDecodeError:
                        self.log_error("JSON修复失败")
                        return self._generate_default_structure()

            # A single object is wrapped into a one-element list; any other
            # non-list value is unusable.
            if not isinstance(report_structure, list):
                self.log_info("报告结构不是列表,尝试转换...")
                if isinstance(report_structure, dict):
                    report_structure = [report_structure]
                else:
                    self.log_error("报告结构格式无效,使用默认结构")
                    return self._generate_default_structure()

            # Keep only well-formed paragraphs. Problems are reported through
            # log_error (a logger this class is known to have), so that one
            # bad entry is skipped instead of aborting the whole method.
            validated_structure = []
            for position, paragraph in enumerate(report_structure, 1):
                if not isinstance(paragraph, dict):
                    self.log_error(f"段落 {position} 不是字典格式,跳过")
                    continue

                title = paragraph.get("title", f"段落 {position}")
                content = paragraph.get("content", "")

                if not title or not content:
                    self.log_error(f"段落 {position} 缺少标题或内容,跳过")
                    continue

                validated_structure.append({
                    "title": title,
                    "content": content
                })

            if not validated_structure:
                self.log_error("没有有效的段落结构,使用默认结构")
                return self._generate_default_structure()

            self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
            return validated_structure

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return self._generate_default_structure()

    def _generate_default_structure(self) -> List[Dict[str, str]]:
        """Return the fixed two-paragraph fallback structure.

        Returns:
            A new list holding the two default paragraph dicts.
        """
        self.log_info("生成默认报告结构")
        return [
            {
                "title": "研究概述",
                "content": "对查询主题进行总体概述和分析"
            },
            {
                "title": "深度分析",
                "content": "深入分析查询主题的各个方面"
            }
        ]

    def mutate_state(self, input_data: Any = None, state: State = None, **kwargs) -> State:
        """Generate the report structure and write it into the state.

        Args:
            input_data: Forwarded to ``run`` (which ignores it).
            state: Current state; a new ``State`` is created when None.
            **kwargs: Forwarded to ``run``.

        Returns:
            The updated state.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        if state is None:
            state = State()

        try:
            report_structure = self.run(input_data, **kwargs)

            # Record the query; an already-set report title is kept.
            state.query = self.query
            if not state.report_title:
                state.report_title = f"关于'{self.query}'的深度研究报告"

            for paragraph_data in report_structure:
                state.add_paragraph(
                    title=paragraph_data["title"],
                    content=paragraph_data["content"]
                )

            self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
            return state

        except Exception as e:
            self.log_error(f"状态更新失败: {str(e)}")
            raise
|
||||
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
搜索节点实现
|
||||
负责生成搜索查询和反思查询
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_json_tags,
|
||||
extract_clean_response,
|
||||
fix_incomplete_json
|
||||
)
|
||||
|
||||
|
||||
class FirstSearchNode(BaseNode):
    """Node that generates the first search query for a paragraph."""

    # Fields the input object must contain.
    _REQUIRED_FIELDS = ("title", "content")

    def __init__(self, llm_client):
        """Initialise the first-search node.

        Args:
            llm_client: LLM client.
        """
        super().__init__(llm_client, "FirstSearchNode")

    def validate_input(self, input_data: Any) -> bool:
        """Validate the input.

        Accepts a dict, or a JSON string that decodes to a dict, containing
        ``title`` and ``content``.

        Args:
            input_data: Candidate input.

        Returns:
            True when the input is acceptable, False otherwise.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # Require a mapping: on a decoded JSON string or list, ``in`` would
        # test substrings / elements, and on a number it would raise TypeError.
        return isinstance(input_data, dict) and all(
            field in input_data for field in self._REQUIRED_FIELDS
        )

    def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
        """Call the LLM to generate a search query and its reasoning.

        Args:
            input_data: Dict or JSON string containing ``title`` and ``content``.
            **kwargs: Extra arguments (unused here).

        Returns:
            Dict with ``search_query`` and ``reasoning``.

        Raises:
            ValueError: When ``input_data`` fails validation.
            Exception: Any error from the LLM call is logged and re-raised.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误,需要包含title和content字段")

            # The LLM receives the input as a JSON string.
            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在生成首次搜索查询")

            response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
            processed_response = self.process_output(response)

            self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
            return processed_response

        except Exception as e:
            self.log_error(f"生成首次搜索查询失败: {str(e)}")
            raise

    def process_output(self, output: str) -> Dict[str, str]:
        """Extract the search query and reasoning from the raw LLM output.

        Parsing is attempted in three stages: plain ``json.loads``, then
        ``extract_clean_response``, then ``fix_incomplete_json``. The default
        query is returned whenever no query can be obtained.

        Args:
            output: Raw LLM output.

        Returns:
            Dict with ``search_query`` and ``reasoning``.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            # Log a short prefix of the cleaned text for debugging.
            self.log_info(f"清理后的输出: {cleaned_output[:200]}...")

            try:
                result = json.loads(cleaned_output)
                self.log_info("JSON解析成功")
            except JSONDecodeError as e:
                self.log_info(f"JSON解析失败: {str(e)}")
                # NOTE(review): extract_clean_response appears to signal
                # failure through an "error" key - confirm against the helper.
                result = extract_clean_response(cleaned_output)
                if "error" in result:
                    self.log_error("JSON解析失败,尝试修复...")
                    fixed_json = fix_incomplete_json(cleaned_output)
                    if not fixed_json:
                        self.log_error("无法修复JSON,使用默认查询")
                        return self._get_default_search_query()
                    try:
                        result = json.loads(fixed_json)
                        self.log_info("JSON修复成功")
                    except JSONDecodeError:
                        self.log_error("JSON修复失败")
                        return self._get_default_search_query()

            # Only a JSON object can carry the fields; anything else is
            # treated as "no query found".
            if isinstance(result, dict):
                search_query = result.get("search_query", "")
                reasoning = result.get("reasoning", "")
            else:
                search_query, reasoning = "", ""

            if not search_query:
                # Reported through log_error, like the other fallback
                # branches of this method.
                self.log_error("未找到搜索查询,使用默认查询")
                return self._get_default_search_query()

            return {
                "search_query": search_query,
                "reasoning": reasoning
            }

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return self._get_default_search_query()

    def _get_default_search_query(self) -> Dict[str, str]:
        """Return the fixed fallback search query.

        Returns:
            Dict with the default ``search_query`` and ``reasoning``.
        """
        return {
            "search_query": "相关主题研究",
            "reasoning": "由于解析失败,使用默认搜索查询"
        }
|
||||
|
||||
|
||||
class ReflectionNode(BaseNode):
    """Node that reflects on a paragraph and generates a follow-up search query."""

    # Fields the input object must contain.
    _REQUIRED_FIELDS = ("title", "content", "paragraph_latest_state")

    def __init__(self, llm_client):
        """Initialise the reflection node.

        Args:
            llm_client: LLM client.
        """
        super().__init__(llm_client, "ReflectionNode")

    def validate_input(self, input_data: Any) -> bool:
        """Validate the input.

        Accepts a dict, or a JSON string that decodes to a dict, containing
        ``title``, ``content`` and ``paragraph_latest_state``.

        Args:
            input_data: Candidate input.

        Returns:
            True when the input is acceptable, False otherwise.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # Require a mapping: on a decoded JSON string or list, ``in`` would
        # test substrings / elements, and on a number it would raise TypeError.
        return isinstance(input_data, dict) and all(
            field in input_data for field in self._REQUIRED_FIELDS
        )

    def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
        """Call the LLM to reflect and generate a new search query.

        Args:
            input_data: Dict or JSON string containing ``title``, ``content``
                and ``paragraph_latest_state``.
            **kwargs: Extra arguments (unused here).

        Returns:
            Dict with ``search_query`` and ``reasoning``.

        Raises:
            ValueError: When ``input_data`` fails validation.
            Exception: Any error from the LLM call is logged and re-raised.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误,需要包含title、content和paragraph_latest_state字段")

            # The LLM receives the input as a JSON string.
            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在进行反思并生成新搜索查询")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
            processed_response = self.process_output(response)

            self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
            return processed_response

        except Exception as e:
            self.log_error(f"反思生成搜索查询失败: {str(e)}")
            raise

    def process_output(self, output: str) -> Dict[str, str]:
        """Extract the search query and reasoning from the raw LLM output.

        Parsing is attempted in three stages: plain ``json.loads``, then
        ``extract_clean_response``, then ``fix_incomplete_json``. The default
        reflection query is returned whenever no query can be obtained.

        Args:
            output: Raw LLM output.

        Returns:
            Dict with ``search_query`` and ``reasoning``.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            # Log a short prefix of the cleaned text for debugging.
            self.log_info(f"清理后的输出: {cleaned_output[:200]}...")

            try:
                result = json.loads(cleaned_output)
                self.log_info("JSON解析成功")
            except JSONDecodeError as e:
                self.log_info(f"JSON解析失败: {str(e)}")
                # NOTE(review): extract_clean_response appears to signal
                # failure through an "error" key - confirm against the helper.
                result = extract_clean_response(cleaned_output)
                if "error" in result:
                    self.log_error("JSON解析失败,尝试修复...")
                    fixed_json = fix_incomplete_json(cleaned_output)
                    if not fixed_json:
                        self.log_error("无法修复JSON,使用默认查询")
                        return self._get_default_reflection_query()
                    try:
                        result = json.loads(fixed_json)
                        self.log_info("JSON修复成功")
                    except JSONDecodeError:
                        self.log_error("JSON修复失败")
                        return self._get_default_reflection_query()

            # Only a JSON object can carry the fields; anything else is
            # treated as "no query found".
            if isinstance(result, dict):
                search_query = result.get("search_query", "")
                reasoning = result.get("reasoning", "")
            else:
                search_query, reasoning = "", ""

            if not search_query:
                # Reported through log_error, like the other fallback
                # branches of this method.
                self.log_error("未找到搜索查询,使用默认查询")
                return self._get_default_reflection_query()

            return {
                "search_query": search_query,
                "reasoning": reasoning
            }

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return self._get_default_reflection_query()

    def _get_default_reflection_query(self) -> Dict[str, str]:
        """Return the fixed fallback reflection query.

        Returns:
            Dict with the default ``search_query`` and ``reasoning``.
        """
        return {
            "search_query": "深度研究补充信息",
            "reasoning": "由于解析失败,使用默认反思搜索查询"
        }
|
||||
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
总结节点实现
|
||||
负责根据搜索结果生成和更新段落内容
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
from ..prompts import SYSTEM_PROMPT_FIRST_SUMMARY, SYSTEM_PROMPT_REFLECTION_SUMMARY
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_json_tags,
|
||||
extract_clean_response,
|
||||
fix_incomplete_json,
|
||||
format_search_results_for_prompt
|
||||
)
|
||||
|
||||
|
||||
class FirstSummaryNode(StateMutationNode):
    """Node that writes a paragraph's first summary from search results."""

    # Fields the input object must contain.
    _REQUIRED_FIELDS = ("title", "content", "search_query", "search_results")

    def __init__(self, llm_client):
        """Initialise the first-summary node.

        Args:
            llm_client: LLM client.
        """
        super().__init__(llm_client, "FirstSummaryNode")

    def validate_input(self, input_data: Any) -> bool:
        """Validate the input.

        Accepts a dict, or a JSON string that decodes to a dict, containing
        ``title``, ``content``, ``search_query`` and ``search_results``.

        Args:
            input_data: Candidate input.

        Returns:
            True when the input is acceptable, False otherwise.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # Require a mapping: on a decoded JSON string or list, ``in`` would
        # test substrings / elements, and on a number it would raise TypeError.
        return isinstance(input_data, dict) and all(
            field in input_data for field in self._REQUIRED_FIELDS
        )

    def run(self, input_data: Any, **kwargs) -> str:
        """Call the LLM to generate the paragraph summary.

        Args:
            input_data: Dict or JSON string with ``title``, ``content``,
                ``search_query`` and ``search_results``.
            **kwargs: Extra arguments (unused here).

        Returns:
            The paragraph summary text.

        Raises:
            ValueError: When ``input_data`` fails validation.
            Exception: Any error from the LLM call is logged and re-raised.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误")

            # The LLM receives the input as a JSON string.
            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在生成首次段落总结")

            response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
            processed_response = self.process_output(response)

            self.log_info("成功生成首次段落总结")
            return processed_response

        except Exception as e:
            self.log_error(f"生成首次总结失败: {str(e)}")
            raise

    def process_output(self, output: str) -> str:
        """Extract the paragraph text from the raw LLM output.

        When the output is a JSON object with a non-empty
        ``paragraph_latest_state``, that value is returned. Otherwise the
        cleaned text itself is returned.

        Args:
            output: Raw LLM output.

        Returns:
            The paragraph content, or a fixed failure message when cleaning
            raises.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            # Log a short prefix of the cleaned text for debugging.
            self.log_info(f"清理后的输出: {cleaned_output[:200]}...")

            try:
                result = json.loads(cleaned_output)
                self.log_info("JSON解析成功")
            except JSONDecodeError as e:
                self.log_info(f"JSON解析失败: {str(e)}")
                fixed_json = fix_incomplete_json(cleaned_output)
                if not fixed_json:
                    # Not JSON at all: use the cleaned text as the summary.
                    self.log_info("无法修复JSON,直接使用清理后的文本")
                    return cleaned_output
                try:
                    result = json.loads(fixed_json)
                    self.log_info("JSON修复成功")
                except JSONDecodeError:
                    self.log_info("JSON修复失败,直接使用清理后的文本")
                    return cleaned_output

            if isinstance(result, dict):
                paragraph_content = result.get("paragraph_latest_state", "")
                if paragraph_content:
                    return paragraph_content

            # Extraction failed: fall back to the cleaned text.
            return cleaned_output

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return "段落总结生成失败"

    def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
        """Generate the summary and store it on the indexed paragraph.

        Args:
            input_data: Input forwarded to ``run``.
            state: Current state.
            paragraph_index: Index into ``state.paragraphs``.
            **kwargs: Forwarded to ``run``.

        Returns:
            The updated state.

        Raises:
            ValueError: When ``paragraph_index`` is out of range.
            Exception: Any other error is logged and re-raised.
        """
        try:
            # Check the index before calling the LLM so that an invalid
            # index fails immediately instead of after a wasted request.
            if not 0 <= paragraph_index < len(state.paragraphs):
                raise ValueError(f"段落索引 {paragraph_index} 超出范围")

            summary = self.run(input_data, **kwargs)

            state.paragraphs[paragraph_index].research.latest_summary = summary
            self.log_info(f"已更新段落 {paragraph_index} 的首次总结")

            state.update_timestamp()
            return state

        except Exception as e:
            self.log_error(f"状态更新失败: {str(e)}")
            raise
|
||||
|
||||
|
||||
class ReflectionSummaryNode(StateMutationNode):
    """Node that updates a paragraph summary from reflection search results."""

    # Fields the input object must contain.
    _REQUIRED_FIELDS = (
        "title",
        "content",
        "search_query",
        "search_results",
        "paragraph_latest_state",
    )

    def __init__(self, llm_client):
        """Initialise the reflection-summary node.

        Args:
            llm_client: LLM client.
        """
        super().__init__(llm_client, "ReflectionSummaryNode")

    def validate_input(self, input_data: Any) -> bool:
        """Validate the input.

        Accepts a dict, or a JSON string that decodes to a dict, containing
        ``title``, ``content``, ``search_query``, ``search_results`` and
        ``paragraph_latest_state``.

        Args:
            input_data: Candidate input.

        Returns:
            True when the input is acceptable, False otherwise.
        """
        if isinstance(input_data, str):
            try:
                input_data = json.loads(input_data)
            except JSONDecodeError:
                return False
        # Require a mapping: on a decoded JSON string or list, ``in`` would
        # test substrings / elements, and on a number it would raise TypeError.
        return isinstance(input_data, dict) and all(
            field in input_data for field in self._REQUIRED_FIELDS
        )

    def run(self, input_data: Any, **kwargs) -> str:
        """Call the LLM to update the paragraph content.

        Args:
            input_data: Dict or JSON string with the full reflection input.
            **kwargs: Extra arguments (unused here).

        Returns:
            The updated paragraph text.

        Raises:
            ValueError: When ``input_data`` fails validation.
            Exception: Any error from the LLM call is logged and re-raised.
        """
        try:
            if not self.validate_input(input_data):
                raise ValueError("输入数据格式错误")

            # The LLM receives the input as a JSON string.
            if isinstance(input_data, str):
                message = input_data
            else:
                message = json.dumps(input_data, ensure_ascii=False)

            self.log_info("正在生成反思总结")

            response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
            processed_response = self.process_output(response)

            self.log_info("成功生成反思总结")
            return processed_response

        except Exception as e:
            self.log_error(f"生成反思总结失败: {str(e)}")
            raise

    def process_output(self, output: str) -> str:
        """Extract the updated paragraph text from the raw LLM output.

        When the output is a JSON object with a non-empty
        ``updated_paragraph_latest_state``, that value is returned. Otherwise
        the cleaned text itself is returned.

        Args:
            output: Raw LLM output.

        Returns:
            The updated paragraph content, or a fixed failure message when
            cleaning raises.
        """
        try:
            cleaned_output = remove_reasoning_from_output(output)
            cleaned_output = clean_json_tags(cleaned_output)

            # Log a short prefix of the cleaned text for debugging.
            self.log_info(f"清理后的输出: {cleaned_output[:200]}...")

            try:
                result = json.loads(cleaned_output)
                self.log_info("JSON解析成功")
            except JSONDecodeError as e:
                self.log_info(f"JSON解析失败: {str(e)}")
                fixed_json = fix_incomplete_json(cleaned_output)
                if not fixed_json:
                    # Not JSON at all: use the cleaned text as the summary.
                    self.log_info("无法修复JSON,直接使用清理后的文本")
                    return cleaned_output
                try:
                    result = json.loads(fixed_json)
                    self.log_info("JSON修复成功")
                except JSONDecodeError:
                    self.log_info("JSON修复失败,直接使用清理后的文本")
                    return cleaned_output

            if isinstance(result, dict):
                updated_content = result.get("updated_paragraph_latest_state", "")
                if updated_content:
                    return updated_content

            # Extraction failed: fall back to the cleaned text.
            return cleaned_output

        except Exception as e:
            self.log_error(f"处理输出失败: {str(e)}")
            return "反思总结生成失败"

    def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
        """Generate the updated summary and store it on the indexed paragraph.

        Besides storing the summary, the paragraph's reflection counter is
        advanced through ``research.increment_reflection()``.

        Args:
            input_data: Input forwarded to ``run``.
            state: Current state.
            paragraph_index: Index into ``state.paragraphs``.
            **kwargs: Forwarded to ``run``.

        Returns:
            The updated state.

        Raises:
            ValueError: When ``paragraph_index`` is out of range.
            Exception: Any other error is logged and re-raised.
        """
        try:
            # Check the index before calling the LLM so that an invalid
            # index fails immediately instead of after a wasted request.
            if not 0 <= paragraph_index < len(state.paragraphs):
                raise ValueError(f"段落索引 {paragraph_index} 超出范围")

            updated_summary = self.run(input_data, **kwargs)

            research = state.paragraphs[paragraph_index].research
            research.latest_summary = updated_summary
            research.increment_reflection()
            self.log_info(f"已更新段落 {paragraph_index} 的反思总结")

            state.update_timestamp()
            return state

        except Exception as e:
            self.log_error(f"状态更新失败: {str(e)}")
            raise
|
||||
Reference in New Issue
Block a user