Files

293 lines
11 KiB
Python

"""
Deep Search Agent状态管理
定义所有状态数据结构和操作方法
"""
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
import json
from datetime import datetime
@dataclass
class Search:
    """State of a single search hit, or of a search recorded as having no result.

    Serialized with ``to_dict`` and rebuilt with ``from_dict``.
    """

    query: str = ""  # the query that produced this hit
    url: str = ""  # link of the search result
    title: str = ""  # title of the search result
    content: str = ""  # text content returned by the search
    score: Optional[float] = None  # relevance score, None when not provided
    paragraph_title: str = ""  # title of the paragraph this hit belongs to (for display)
    search_tool: str = ""  # name of the search tool that was used
    has_result: bool = True  # False when this entry records a search with no results
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO 8601, naive local time

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict, keys in field-declaration order."""
        return dict(
            query=self.query,
            url=self.url,
            title=self.title,
            content=self.content,
            score=self.score,
            paragraph_title=self.paragraph_title,
            search_tool=self.search_tool,
            has_result=self.has_result,
            timestamp=self.timestamp,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Search":
        """Build a Search from ``data``.

        Missing keys fall back to the field defaults. A missing ``timestamp``
        is replaced by the current local time.
        """
        text_keys = ("query", "url", "title", "content", "paragraph_title", "search_tool")
        kwargs: Dict[str, Any] = {key: data.get(key, "") for key in text_keys}
        kwargs["score"] = data.get("score")
        kwargs["has_result"] = data.get("has_result", True)
        if "timestamp" in data:
            kwargs["timestamp"] = data["timestamp"]
        else:
            kwargs["timestamp"] = datetime.now().isoformat()
        return cls(**kwargs)
@dataclass
class Research:
    """Research progress of a single paragraph.

    Holds the ordered search trail, the latest summary and the reflection counter.
    """

    search_history: List[Search] = field(default_factory=list)  # search records, in insertion order
    latest_summary: str = ""  # most recent summary produced for the paragraph
    reflection_iteration: int = 0  # number of reflection rounds run so far
    is_completed: bool = False  # set to True by mark_completed()

    def add_search(self, search: Search):
        """Append a single search record to ``search_history``."""
        self.search_history.append(search)

    def add_search_results(self, query: str, results: List[Dict[str, Any]], search_tool: str = "", paragraph_title: str = ""):
        """Record a batch of search results produced by ``query``.

        Args:
            query: the query that produced ``results``; a falsy value is stored as "".
            results: result dicts. Keys read: "url", "title", "content" (falling
                back to "raw_content"), "score", "paragraph_title", "search_tool".
            search_tool: tool name; when non-empty it overrides each result's
                own "search_tool".
            paragraph_title: paragraph title; when non-empty it overrides each
                result's own "paragraph_title".

        An empty or None ``results`` is still recorded, as one placeholder entry
        with ``has_result=False``.
        """
        query_text = query or ""
        if not results:
            # Record a "no results" entry so the frontend can still show the search trail.
            self.add_search(
                Search(
                    query=query_text,
                    title="未找到结果",
                    content="本次搜索未返回结果或调用失败",
                    url="",
                    score=None,
                    paragraph_title=paragraph_title or "",
                    search_tool=search_tool or "",
                    has_result=False,
                )
            )
            return
        for result in results:
            url = result.get("url") or ""
            title = result.get("title") or ""
            content = result.get("content") or result.get("raw_content") or ""
            # Search declares these fields as str, so coerce any non-string payload.
            url, title, content = (v if isinstance(v, str) else str(v) for v in (url, title, content))
            self.add_search(
                Search(
                    query=query_text,
                    url=url,
                    title=title,
                    content=content,
                    score=result.get("score"),
                    # Explicit arguments win; otherwise use the per-result value.
                    # The trailing `or ""` also maps an explicit None in the result
                    # dict to "", which `.get(key, "")` alone would not do.
                    paragraph_title=paragraph_title or result.get("paragraph_title") or "",
                    search_tool=search_tool or result.get("search_tool") or "",
                    has_result=True,
                )
            )

    def get_search_count(self) -> int:
        """Return the number of entries in ``search_history``, "no result" placeholders included."""
        return len(self.search_history)

    def increment_reflection(self):
        """Increase the reflection counter by one."""
        self.reflection_iteration += 1

    def mark_completed(self):
        """Mark this paragraph's research as completed."""
        self.is_completed = True

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form, with each search serialized via ``Search.to_dict``."""
        return {
            "search_history": [search.to_dict() for search in self.search_history],
            "latest_summary": self.latest_summary,
            "reflection_iteration": self.reflection_iteration,
            "is_completed": self.is_completed,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Research":
        """Build a Research from ``data``; a missing or null "search_history" yields an empty list."""
        search_history = [Search.from_dict(search_data) for search_data in data.get("search_history") or []]
        return cls(
            search_history=search_history,
            latest_summary=data.get("latest_summary", ""),
            reflection_iteration=data.get("reflection_iteration", 0),
            is_completed=data.get("is_completed", False),
        )
@dataclass
class Paragraph:
    """State of one report paragraph: its planned content plus its research progress."""

    title: str = ""  # paragraph title
    content: str = ""  # expected content from the initial plan
    research: Research = field(default_factory=Research)  # research progress for this paragraph
    order: int = 0  # position of the paragraph in the report

    def is_completed(self) -> bool:
        """Return whether research is marked completed and a non-empty summary exists."""
        research = self.research
        return research.is_completed and bool(research.latest_summary)

    def get_final_content(self) -> str:
        """Return the latest research summary, falling back to the planned content when it is empty."""
        summary = self.research.latest_summary
        return summary if summary else self.content

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form, with research serialized via ``Research.to_dict``."""
        payload: Dict[str, Any] = {"title": self.title, "content": self.content}
        payload["research"] = self.research.to_dict()
        payload["order"] = self.order
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Paragraph":
        """Build a Paragraph from ``data``; missing or empty "research" yields a fresh Research."""
        raw_research = data.get("research")
        if raw_research:
            research = Research.from_dict(raw_research)
        else:
            research = Research()
        return cls(
            title=data.get("title", ""),
            content=data.get("content", ""),
            research=research,
            order=data.get("order", 0),
        )
@dataclass
class State:
    """State of the whole report: query, paragraph plan and progress, final report.

    Can be round-tripped through a dict, a JSON string, or a JSON file.
    """

    query: str = ""  # original user query
    report_title: str = ""  # report title
    paragraphs: List[Paragraph] = field(default_factory=list)  # paragraphs in report order
    final_report: str = ""  # final report text
    is_completed: bool = False  # set by mark_completed()
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO 8601, naive local time
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())  # refreshed by update_timestamp()

    def add_paragraph(self, title: str, content: str) -> int:
        """Append a new paragraph and refresh ``updated_at``.

        Args:
            title: paragraph title.
            content: planned content of the paragraph.

        Returns:
            Index of the new paragraph in ``paragraphs`` (also stored as its ``order``).
        """
        order = len(self.paragraphs)
        paragraph = Paragraph(title=title, content=content, order=order)
        self.paragraphs.append(paragraph)
        self.update_timestamp()
        return order

    def get_paragraph(self, index: int) -> Optional[Paragraph]:
        """Return the paragraph at ``index``, or None when out of range (negative indices included)."""
        if 0 <= index < len(self.paragraphs):
            return self.paragraphs[index]
        return None

    def get_completed_paragraphs_count(self) -> int:
        """Return how many paragraphs report ``is_completed()``."""
        return sum(1 for p in self.paragraphs if p.is_completed())

    def get_total_paragraphs_count(self) -> int:
        """Return the total number of paragraphs."""
        return len(self.paragraphs)

    def is_all_paragraphs_completed(self) -> bool:
        """Return True when there is at least one paragraph and all of them are completed."""
        return all(p.is_completed() for p in self.paragraphs) if self.paragraphs else False

    def mark_completed(self):
        """Mark the whole report as completed and refresh ``updated_at``."""
        self.is_completed = True
        self.update_timestamp()

    def update_timestamp(self):
        """Set ``updated_at`` to the current local time (ISO 8601)."""
        self.updated_at = datetime.now().isoformat()

    def get_progress_summary(self) -> Dict[str, Any]:
        """Return paragraph counts, completion percentage (0 when there are no paragraphs) and timestamps."""
        completed = self.get_completed_paragraphs_count()
        total = self.get_total_paragraphs_count()
        return {
            "total_paragraphs": total,
            "completed_paragraphs": completed,
            "progress_percentage": (completed / total * 100) if total > 0 else 0,
            "is_completed": self.is_completed,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form, with paragraphs serialized via ``Paragraph.to_dict``."""
        return {
            "query": self.query,
            "report_title": self.report_title,
            "paragraphs": [p.to_dict() for p in self.paragraphs],
            "final_report": self.final_report,
            "is_completed": self.is_completed,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize to a JSON string; non-ASCII text is kept as-is (``ensure_ascii=False``)."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "State":
        """Build a State from ``data``.

        A missing or null "paragraphs" yields an empty list. Missing timestamps
        default to the current local time.
        """
        paragraphs = [Paragraph.from_dict(p_data) for p_data in data.get("paragraphs") or []]
        return cls(
            query=data.get("query", ""),
            report_title=data.get("report_title", ""),
            paragraphs=paragraphs,
            final_report=data.get("final_report", ""),
            is_completed=data.get("is_completed", False),
            created_at=data.get("created_at", datetime.now().isoformat()),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
        )

    @classmethod
    def from_json(cls, json_str: str) -> "State":
        """Build a State from a JSON string produced by ``to_json``."""
        data = json.loads(json_str)
        return cls.from_dict(data)

    def save_to_file(self, filepath: str):
        """Write the state as UTF-8 JSON to ``filepath`` without risking the previous copy.

        The JSON is serialized first, so a serialization error never touches the
        file. It is then written and fsync'ed to a sibling ``<filepath>.tmp`` and
        moved over ``filepath`` with ``os.replace``. If the process is interrupted,
        the old file is still intact rather than truncated.
        """
        import os

        payload = self.to_json()
        tmp_path = f"{filepath}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(payload)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, filepath)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original error is re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    @classmethod
    def load_from_file(cls, filepath: str) -> "State":
        """Read a UTF-8 JSON file written by ``save_to_file`` and rebuild the State."""
        with open(filepath, 'r', encoding='utf-8') as f:
            json_str = f.read()
        return cls.from_json(json_str)