417 lines
15 KiB
Python
417 lines
15 KiB
Python
"""
|
||
Step 04: Memory - 记忆系统
|
||
|
||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||
|
||
🎓 本节内容:
|
||
1. 为什么 Agent 需要 Memory?
|
||
2. 多层记忆架构
|
||
3. 上下文窗口管理
|
||
4. 对话压缩技术
|
||
|
||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||
"""
|
||
|
||
from typing import TypedDict, List, Dict, Any, Optional
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
import json
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# 第一部分:理解为什么需要 Memory
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
"""
|
||
回顾 Step 03 的 SimpleAgent:
|
||
|
||
问题:对话历史都存在 state['messages'] 里
|
||
问题:如果对话很长,发送给 LLM 的 token 会越来越多
|
||
问题:LLM 有上下文长度限制,不能无限增长
|
||
|
||
解决方案:多层 Memory 系统
|
||
|
||
1. Working Memory(工作记忆):当前正在处理的任务
|
||
2. Short-Term Memory(短期记忆):最近的对话轮次
|
||
3. Long-Term Memory(长期记忆):持久化的知识
|
||
|
||
核心思想:
|
||
- 不是所有信息都需要保留
|
||
- 重要的信息保留,不重要的压缩或丢弃
|
||
- 平衡"记住"和"效率"
|
||
"""
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# 第二部分:Memory 的数据类型
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
@dataclass
class Message:
    """A single conversation message."""

    role: str  # "user" / "assistant" / "system"
    content: str
    # ISO-8601 local time; filled with datetime.now() when left empty.
    timestamp: str = ""
    # None is normalised to a fresh {} in __post_init__, so instances never share a dict.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()
        if self.metadata is None:
            self.metadata = {}

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of this message.

        ``metadata`` is shallow-copied so that editing the returned dict does
        not silently mutate the message itself.
        """
        return {
            "role": self.role,
            "content": self.content,
            "timestamp": self.timestamp,
            "metadata": dict(self.metadata),
        }
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# 第三部分:Working Memory(工作记忆)
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
class WorkingMemory:
    """
    Working memory - the task currently being handled.

    Characteristics:
    - Small capacity (only a handful of entries)
    - Read and written frequently
    - Volatile: held only in this object's dict, nothing is persisted
    - Holds the key facts of the current task
    """

    def __init__(self, capacity: int = 7):
        """
        Args:
            capacity: Maximum number of entries; exceeding it triggers an
                automatic cleanup of the oldest entries.
        """
        self.capacity = capacity
        # key -> {"value": ..., "timestamp": ISO-8601 string}. Dict order is
        # kept equal to write order (oldest first), so eviction never depends
        # on the wall clock (ties within a microsecond, DST jumps, ...).
        self._data: Dict[str, Any] = {}

    def remember(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, making it the most recent entry."""
        # Re-insert instead of overwriting in place so a rewritten key moves
        # to the end of the dict, i.e. becomes the newest entry.
        self._data.pop(key, None)
        self._data[key] = {
            "value": value,
            "timestamp": datetime.now().isoformat(),
        }
        self._cleanup()

    def recall(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key``, or ``default`` if absent."""
        item = self._data.get(key)
        if item is not None:
            return item["value"]
        return default

    def forget(self, key: str) -> None:
        """Remove ``key`` if present; missing keys are ignored."""
        self._data.pop(key, None)

    def clear(self) -> None:
        """Drop every entry."""
        self._data = {}

    def get_all(self) -> Dict[str, Any]:
        """Return a new ``{key: value}`` dict, oldest entry first."""
        return {k: v["value"] for k, v in self._data.items()}

    def _cleanup(self) -> None:
        """Evict the oldest entries once the capacity is exceeded."""
        overflow = len(self._data) - self.capacity
        if overflow <= 0:
            return
        # Drop roughly the oldest third in one go (as before), but always at
        # least enough to get back within capacity: capacity // 3 alone is 0
        # for capacity < 3, which let the memory grow without bound.
        delete_count = max(overflow, self.capacity // 3)
        for key in list(self._data)[:delete_count]:
            del self._data[key]
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# 第四部分:Short-Term Memory(短期记忆)
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
class ShortTermMemory:
|
||
"""
|
||
短期记忆 - 最近的对对话
|
||
|
||
特点:
|
||
- 保存最近 N 轮对话
|
||
- 超出部分要么压缩,要么丢弃
|
||
- 容易被遗忘
|
||
- 模拟人类的短期记忆
|
||
"""
|
||
|
||
def __init__(self, max_messages: int = 20):
|
||
"""
|
||
Args:
|
||
max_messages: 最大保存消息数
|
||
"""
|
||
self.max_messages = max_messages
|
||
self.messages: List[Message] = []
|
||
|
||
def add(self, role: str, content: str, **metadata) -> None:
|
||
"""添加消息"""
|
||
msg = Message(role=role, content=content, metadata=metadata)
|
||
self.messages.append(msg)
|
||
self._trim()
|
||
|
||
def get_recent(self, n: int = 10) -> List[Message]:
|
||
"""获取最近 N 条消息"""
|
||
return self.messages[-n:]
|
||
|
||
def get_all(self) -> List[Message]:
|
||
"""获取所有消息"""
|
||
return self.messages.copy()
|
||
|
||
def summarize_older(self, keep_recent: int = 5) -> str:
|
||
"""
|
||
将较早的消息压缩为摘要
|
||
|
||
Args:
|
||
keep_recent: 最近多少条保持不变
|
||
|
||
Returns:
|
||
压缩后的摘要文本
|
||
"""
|
||
if len(self.messages) <= keep_recent:
|
||
return ""
|
||
|
||
older = self.messages[:-keep_recent]
|
||
# 简化为摘要
|
||
summary_parts = []
|
||
for msg in older:
|
||
role_label = "用户" if msg.role == "user" else "助手"
|
||
content_preview = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
|
||
summary_parts.append(f"{role_label}: {content_preview}")
|
||
|
||
return "\n".join(summary_parts)
|
||
|
||
def clear(self) -> None:
|
||
"""清空记忆"""
|
||
self.messages = []
|
||
|
||
def _trim(self) -> None:
|
||
"""超出容量时,删除最旧的消息"""
|
||
while len(self.messages) > self.max_messages:
|
||
self.messages.pop(0)
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# 第五部分:Long-Term Memory(长期记忆)
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
class LongTermMemory:
    """
    Long-term memory - persistent knowledge.

    Characteristics:
    - In a real system stored on disk or in a database; this demo keeps
      everything in an in-process dict
    - Meant to persist
    - Must be retrieved explicitly
    - Large capacity
    - Corresponds to RAG (covered later in Step 05)
    """

    def __init__(self):
        """
        Initialise long-term memory.

        A real implementation would set up backends here, e.g.:
        - a vector database (ChromaDB/Pinecone)
        - a key-value store (Redis)
        - a graph database (Neo4j)
        """
        # key -> {"value", "tags", "created_at", "accessed_at"}
        self._storage: Dict[str, Any] = {}

    def memorize(self, key: str, value: Any, tags: Optional[List[str]] = None) -> None:
        """
        Store information in long-term memory (overwrites an existing key).

        Args:
            key: Unique identifier.
            value: Content to store.
            tags: Labels used for classification and retrieval.
        """
        # One timestamp so created_at == accessed_at for a fresh entry.
        now = datetime.now().isoformat()
        self._storage[key] = {
            "value": value,
            # Copy so later mutation of the caller's list cannot retag the entry.
            "tags": list(tags) if tags else [],
            "created_at": now,
            "accessed_at": now,
        }

    def recall(self, key: str) -> Optional[Any]:
        """
        Retrieve an entry from long-term memory.

        Args:
            key: Unique identifier.

        Returns:
            The stored value, or None if the key does not exist. A hit also
            refreshes the entry's ``accessed_at`` timestamp.
        """
        item = self._storage.get(key)
        if item is None:
            return None
        item["accessed_at"] = datetime.now().isoformat()
        return item["value"]

    def search(self, query: str, tags: Optional[List[str]] = None) -> List[Any]:
        """
        Search long-term memory.

        A real implementation would use vector-similarity search; this demo
        uses simple matching:
        - if ``tags`` is non-empty, return values sharing at least one tag
          (``query`` is then ignored);
        - otherwise return values whose ``str()`` contains ``query`` as a
          substring (so ``query=""`` matches everything).
        """
        if tags:
            return [
                item["value"]
                for item in self._storage.values()
                if any(tag in item["tags"] for tag in tags)
            ]
        return [
            item["value"]
            for item in self._storage.values()
            if query in str(item["value"])
        ]

    def forget(self, key: str) -> None:
        """Delete an entry; missing keys are ignored."""
        self._storage.pop(key, None)
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# 第六部分:完整的 Memory System
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
class MemorySystem:
    """
    Complete memory system.

    Combines three kinds of memory:
    - Working Memory: the current task
    - Short-Term Memory: recent conversation
    - Long-Term Memory: persistent knowledge
    """

    def __init__(self):
        self.working = WorkingMemory(capacity=7)
        self.short_term = ShortTermMemory(max_messages=20)
        self.long_term = LongTermMemory()

    def remember_task(self, task_id: str, task_info: Dict) -> None:
        """Record ``task_info`` as the current task, together with its ``task_id``."""
        self.working.remember("current_task", task_info)
        # task_id used to be accepted and silently discarded; keep it so
        # callers can tell which task is current.
        self.working.remember("current_task_id", task_id)

    def get_current_task(self) -> Optional[Dict]:
        """Return the current task info, or None if none was recorded."""
        return self.working.recall("current_task")

    def get_current_task_id(self) -> Optional[str]:
        """Return the id passed to the last ``remember_task`` call, or None."""
        return self.working.recall("current_task_id")

    def add_message(self, role: str, content: str, **metadata) -> None:
        """Append a conversation message; extra keyword args become its metadata."""
        self.short_term.add(role, content, **metadata)

    def get_context(self, include_older_summary: bool = True, recent_count: int = 10) -> str:
        """
        Build the context string to send to the LLM.

        Sections, in chronological order: the current task, a summary of the
        messages older than the recent window (optional), then the last
        ``recent_count`` messages (each cut at 100 characters, marked "...").

        Args:
            include_older_summary: Whether to include the older-message summary.
            recent_count: Size of the recent-message window.

        Returns:
            The sections joined by blank lines.
        """
        context_parts = []

        # Current task
        current_task = self.get_current_task()
        if current_task:
            context_parts.append(f"[当前任务]\n{json.dumps(current_task, ensure_ascii=False)}")

        # Summary of earlier turns. It uses the same window size as the recent
        # section below, so no message is sent to the LLM twice (previously the
        # summary kept 5 while the recent section showed 10, duplicating 5).
        if include_older_summary:
            older_summary = self.short_term.summarize_older(keep_recent=recent_count)
            if older_summary:
                context_parts.append(f"[较早对话摘要]\n{older_summary}")

        # Recent turns
        recent = self.short_term.get_recent(recent_count)
        if recent:
            context_parts.append("[最近对话]")
            for msg in recent:
                if len(msg.content) > 100:
                    preview = msg.content[:100] + "..."
                else:
                    preview = msg.content
                context_parts.append(f"- {msg.role}: {preview}")

        return "\n\n".join(context_parts)

    def clear_session(self) -> None:
        """Clear session-scoped memory (working + short-term); long-term is kept."""
        self.working.clear()
        self.short_term.clear()
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
# 演示代码
|
||
# ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
def demo():
    """Walk through the memory system end to end, printing each step."""
    print("=" * 60)
    print("Step 04: Memory - 记忆系统演示")
    print("=" * 60)

    # Build the memory system
    memory = MemorySystem()

    # 1. Remember the current task
    print("\n📝 记住当前任务")
    memory.remember_task("task_001", {
        "type": "生成报表",
        "requirement": "销售月报",
        "status": "进行中"
    })
    print(f" 当前任务: {memory.get_current_task()}")

    # 2. Add conversation messages
    print("\n💬 添加对话消息")
    messages = [
        ("user", "帮我生成一个销售报表"),
        ("assistant", "好的,请问你需要显示哪些数据?"),
        ("user", "显示月度汇总,包括销售额和数量"),
        ("assistant", "明白了,正在生成..."),
        ("user", "再添加一个增长率"),
        ("assistant", "好的,正在添加..."),
    ]

    for role, content in messages:
        memory.add_message(role, content)
        # Only mark the preview as truncated when it actually was cut.
        preview = content[:30] + "..." if len(content) > 30 else content
        print(f" 添加: [{role}] {preview}")

    # 3. Build the full context
    print("\n📋 获取完整上下文")
    context = memory.get_context()
    print(context)

    # 4. Summary of the earlier conversation
    print("\n📝 较早对话摘要")
    summary = memory.short_term.summarize_older(keep_recent=3)
    print(summary)

    # 5. Long-term memory
    print("\n💾 添加长期记忆")
    memory.long_term.memorize(
        key="user_preferences",
        value={"name": "张三", "default_report": "销售报表"},
        tags=["用户信息", "偏好"]
    )
    result = memory.long_term.recall("user_preferences")
    print(f" 检索结果: {result}")

    print("\n" + "=" * 60)
    print("✅ 演示完成")
    print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the walkthrough only when executed as a script, not on import.
    demo()
|