Reconfiguration of the basic multi-agent architecture.
This commit is contained in:
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
搜索节点实现
|
||||
负责生成搜索查询和反思查询
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_json_tags,
|
||||
extract_clean_response,
|
||||
fix_incomplete_json
|
||||
)
|
||||
|
||||
|
||||
class FirstSearchNode(BaseNode):
|
||||
"""为段落生成首次搜索查询的节点"""
|
||||
|
||||
def __init__(self, llm_client):
|
||||
"""
|
||||
初始化首次搜索节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
"""
|
||||
super().__init__(llm_client, "FirstSearchNode")
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
if isinstance(input_data, str):
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
return "title" in data and "content" in data
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
elif isinstance(input_data, dict):
|
||||
return "title" in input_data and "content" in input_data
|
||||
return False
|
||||
|
||||
def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
|
||||
"""
|
||||
调用LLM生成搜索查询和理由
|
||||
|
||||
Args:
|
||||
input_data: 包含title和content的字符串或字典
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
if not self.validate_input(input_data):
|
||||
raise ValueError("输入数据格式错误,需要包含title和content字段")
|
||||
|
||||
# 准备输入数据
|
||||
if isinstance(input_data, str):
|
||||
message = input_data
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在生成首次搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成首次搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
"""
|
||||
处理LLM输出,提取搜索查询和推理
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_search_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
# 验证和清理结果
|
||||
search_query = result.get("search_query", "")
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
return {
|
||||
"search_query": search_query,
|
||||
"reasoning": reasoning
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
# 返回默认查询
|
||||
return self._get_default_search_query()
|
||||
|
||||
def _get_default_search_query(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取默认搜索查询
|
||||
|
||||
Returns:
|
||||
默认的搜索查询字典
|
||||
"""
|
||||
return {
|
||||
"search_query": "相关主题研究",
|
||||
"reasoning": "由于解析失败,使用默认搜索查询"
|
||||
}
|
||||
|
||||
|
||||
class ReflectionNode(BaseNode):
|
||||
"""反思段落并生成新搜索查询的节点"""
|
||||
|
||||
def __init__(self, llm_client):
|
||||
"""
|
||||
初始化反思节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
"""
|
||||
super().__init__(llm_client, "ReflectionNode")
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
if isinstance(input_data, str):
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
required_fields = ["title", "content", "paragraph_latest_state"]
|
||||
return all(field in data for field in required_fields)
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
elif isinstance(input_data, dict):
|
||||
required_fields = ["title", "content", "paragraph_latest_state"]
|
||||
return all(field in input_data for field in required_fields)
|
||||
return False
|
||||
|
||||
def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
|
||||
"""
|
||||
调用LLM反思并生成搜索查询
|
||||
|
||||
Args:
|
||||
input_data: 包含title、content和paragraph_latest_state的字符串或字典
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
if not self.validate_input(input_data):
|
||||
raise ValueError("输入数据格式错误,需要包含title、content和paragraph_latest_state字段")
|
||||
|
||||
# 准备输入数据
|
||||
if isinstance(input_data, str):
|
||||
message = input_data
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在进行反思并生成新搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"反思生成搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
"""
|
||||
处理LLM输出,提取搜索查询和推理
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
# 验证和清理结果
|
||||
search_query = result.get("search_query", "")
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
return {
|
||||
"search_query": search_query,
|
||||
"reasoning": reasoning
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
def _get_default_reflection_query(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取默认反思搜索查询
|
||||
|
||||
Returns:
|
||||
默认的反思搜索查询字典
|
||||
"""
|
||||
return {
|
||||
"search_query": "深度研究补充信息",
|
||||
"reasoning": "由于解析失败,使用默认反思搜索查询"
|
||||
}
|
||||
Reference in New Issue
Block a user