259 lines
9.2 KiB
Python
259 lines
9.2 KiB
Python
"""
State management for the Deep Search Agent.

Defines all state data structures and the methods that operate on them.
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import List, Dict, Any, Optional
|
|
import json
|
|
from datetime import datetime
|
|
|
|
|
|
@dataclass
class Search:
    """State of a single search result."""

    query: str = ""  # query text that produced this result
    url: str = ""  # link of the search result
    title: str = ""  # title of the search result
    content: str = ""  # content returned by the search
    score: Optional[float] = None  # relevance score, if one was provided
    # Creation time as an ISO-format string taken from datetime.now().
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields of this result as a plain dictionary."""
        exported = ("query", "url", "title", "content", "score", "timestamp")
        return {name: getattr(self, name) for name in exported}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Search":
        """Build a Search from a dictionary, defaulting any missing key."""
        # Missing text fields fall back to the empty string.
        text_values = {
            key: data.get(key, "")
            for key in ("query", "url", "title", "content")
        }
        # A missing timestamp is replaced by the current time.
        created = data.get("timestamp", datetime.now().isoformat())
        return cls(score=data.get("score"), timestamp=created, **text_values)
|
|
|
|
|
|
@dataclass
class Research:
    """State of the research process for one paragraph."""

    search_history: List[Search] = field(default_factory=list)  # recorded searches
    latest_summary: str = ""  # most recent summary of the paragraph
    reflection_iteration: int = 0  # number of reflection rounds so far
    is_completed: bool = False  # whether research has finished

    def add_search(self, search: Search):
        """Append one search record to the history."""
        self.search_history.append(search)

    def add_search_results(self, query: str, results: List[Dict[str, Any]]):
        """Record every entry of ``results`` as a Search made for ``query``.

        Each result is read with ``.get``, so missing "url", "title" and
        "content" keys become empty strings and a missing "score" becomes
        None.
        """
        for hit in results:
            self.add_search(
                Search(
                    query=query,
                    url=hit.get("url", ""),
                    title=hit.get("title", ""),
                    content=hit.get("content", ""),
                    score=hit.get("score"),
                )
            )

    def get_search_count(self) -> int:
        """Return how many search records are stored."""
        return len(self.search_history)

    def increment_reflection(self):
        """Advance the reflection counter by one."""
        self.reflection_iteration += 1

    def mark_completed(self):
        """Flag this research as completed."""
        self.is_completed = True

    def to_dict(self) -> Dict[str, Any]:
        """Return this research state as a plain dictionary."""
        history = [entry.to_dict() for entry in self.search_history]
        return {
            "search_history": history,
            "latest_summary": self.latest_summary,
            "reflection_iteration": self.reflection_iteration,
            "is_completed": self.is_completed,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Research":
        """Build a Research from a dictionary, defaulting any missing key."""
        return cls(
            search_history=[
                Search.from_dict(entry)
                for entry in data.get("search_history", [])
            ],
            latest_summary=data.get("latest_summary", ""),
            reflection_iteration=data.get("reflection_iteration", 0),
            is_completed=data.get("is_completed", False),
        )
|
|
|
|
|
|
@dataclass
class Paragraph:
    """State of a single paragraph of the report."""

    title: str = ""  # paragraph title
    content: str = ""  # expected content of the paragraph (initial plan)
    research: Research = field(default_factory=Research)  # research progress
    order: int = 0  # position of the paragraph in the report

    def is_completed(self) -> bool:
        """Tell whether research is flagged complete and has a summary."""
        progress = self.research
        flagged_done = progress.is_completed
        return flagged_done and bool(progress.latest_summary)

    def get_final_content(self) -> str:
        """Return the latest research summary, or the planned content if empty."""
        summary = self.research.latest_summary
        return summary if summary else self.content

    def to_dict(self) -> Dict[str, Any]:
        """Return this paragraph as a plain dictionary."""
        nested_research = self.research.to_dict()
        return {
            "title": self.title,
            "content": self.content,
            "research": nested_research,
            "order": self.order,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Paragraph":
        """Build a Paragraph from a dictionary, defaulting any missing key."""
        raw_research = data.get("research", {})
        # A missing or empty "research" entry yields a fresh Research.
        if raw_research:
            restored = Research.from_dict(raw_research)
        else:
            restored = Research()

        return cls(
            title=data.get("title", ""),
            content=data.get("content", ""),
            research=restored,
            order=data.get("order", 0),
        )
|
|
|
|
|
|
@dataclass
class State:
    """State of the whole report."""

    query: str = ""  # original query
    report_title: str = ""  # title of the report
    paragraphs: List[Paragraph] = field(default_factory=list)  # paragraph list
    final_report: str = ""  # final report text
    is_completed: bool = False  # whether the report is finished
    # Both timestamps are ISO-format strings taken from datetime.now().
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    def add_paragraph(self, title: str, content: str) -> int:
        """
        Append a new paragraph and refresh ``updated_at``.

        Args:
            title: Paragraph title.
            content: Paragraph content.

        Returns:
            Index of the new paragraph, which is also stored as its ``order``.
        """
        order = len(self.paragraphs)
        self.paragraphs.append(Paragraph(title=title, content=content, order=order))
        self.update_timestamp()
        return order

    def get_paragraph(self, index: int) -> Optional[Paragraph]:
        """Return the paragraph at ``index``, or None when out of range.

        Negative indexes are treated as out of range.
        """
        if 0 <= index < len(self.paragraphs):
            return self.paragraphs[index]
        return None

    def get_completed_paragraphs_count(self) -> int:
        """Return the number of paragraphs whose ``is_completed()`` is true."""
        return sum(1 for p in self.paragraphs if p.is_completed())

    def get_total_paragraphs_count(self) -> int:
        """Return the total number of paragraphs."""
        return len(self.paragraphs)

    def is_all_paragraphs_completed(self) -> bool:
        """Tell whether every paragraph is completed.

        Returns False when there are no paragraphs at all.
        """
        return all(p.is_completed() for p in self.paragraphs) if self.paragraphs else False

    def mark_completed(self):
        """Flag the whole report as completed and refresh ``updated_at``."""
        self.is_completed = True
        self.update_timestamp()

    def update_timestamp(self):
        """Set ``updated_at`` to the current time in ISO format."""
        self.updated_at = datetime.now().isoformat()

    def get_progress_summary(self) -> Dict[str, Any]:
        """Return paragraph counts, completion percentage, flag and timestamps."""
        completed = self.get_completed_paragraphs_count()
        total = self.get_total_paragraphs_count()

        return {
            "total_paragraphs": total,
            "completed_paragraphs": completed,
            # Guard against division by zero when no paragraph exists yet.
            "progress_percentage": (completed / total * 100) if total > 0 else 0,
            "is_completed": self.is_completed,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    def to_dict(self) -> Dict[str, Any]:
        """Return the whole state, paragraphs included, as a plain dictionary."""
        return {
            "query": self.query,
            "report_title": self.report_title,
            "paragraphs": [p.to_dict() for p in self.paragraphs],
            "final_report": self.final_report,
            "is_completed": self.is_completed,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize the state to a JSON string, keeping non-ASCII text as is."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "State":
        """Build a State from a dictionary, defaulting any missing key.

        Missing timestamps are replaced by the current time.
        """
        paragraphs = [Paragraph.from_dict(p_data) for p_data in data.get("paragraphs", [])]

        return cls(
            query=data.get("query", ""),
            report_title=data.get("report_title", ""),
            paragraphs=paragraphs,
            final_report=data.get("final_report", ""),
            is_completed=data.get("is_completed", False),
            created_at=data.get("created_at", datetime.now().isoformat()),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
        )

    @classmethod
    def from_json(cls, json_str: str) -> "State":
        """Build a State from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def save_to_file(self, filepath: str):
        """Write the state as UTF-8 JSON to ``filepath``."""
        # Serialize before opening: mode 'w' truncates the target, so a
        # serialization error must surface before an existing file is touched.
        payload = self.to_json()
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

    @classmethod
    def load_from_file(cls, filepath: str) -> "State":
        """Read UTF-8 JSON from ``filepath`` and build a State from it."""
        with open(filepath, 'r', encoding='utf-8') as f:
            json_str = f.read()
        return cls.from_json(json_str)
|