# NOTE: the lines that were here ("Files", "417 lines", "15 KiB", "Python",
# "Raw Permalink Blame History" and an ambiguous-Unicode-characters warning)
# were residue from a web file viewer, not part of the program. They are kept
# only as this comment so that the module parses.
"""
Step 04: Memory - 记忆系统
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
🎓 本节内容:
1. 为什么 Agent 需要 Memory
2. 多层记忆架构
3. 上下文窗口管理
4. 对话压缩技术
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
"""
from typing import TypedDict, List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime
import json
# ═══════════════════════════════════════════════════════════════════════════════
# 第一部分:理解为什么需要 Memory
# ═══════════════════════════════════════════════════════════════════════════════
"""
回顾 Step 03 的 SimpleAgent
问题:对话历史都存在 state['messages'] 里
问题:如果对话很长,发送给 LLM 的 token 会越来越多
问题:LLM 有上下文长度限制,不能无限增长
解决方案:多层 Memory 系统
1. Working Memory(工作记忆):当前正在处理的任务
2. Short-Term Memory(短期记忆):最近的对话轮次
3. Long-Term Memory(长期记忆):持久化的知识
核心思想:
- 不是所有信息都需要保留
- 重要的信息保留,不重要的压缩或丢弃
- 平衡"记住"与"效率"
"""
# ═══════════════════════════════════════════════════════════════════════════════
# 第二部分:Memory 的数据类型
# ═══════════════════════════════════════════════════════════════════════════════
@dataclass
class Message:
    """A single conversation message.

    Attributes:
        role: Who produced the message: "user", "assistant" or "system".
        content: Text of the message.
        timestamp: ISO-8601 creation time. Filled in with the current local
            time when left empty.
        metadata: Free-form extra data. ``None`` is replaced by a fresh empty
            dict in ``__post_init__`` so instances never share one dict.
    """

    role: str  # "user" / "assistant" / "system"
    content: str
    timestamp: str = ""
    # ``None`` is only a sentinel: a literal ``{}`` default is rejected by
    # dataclasses because it would be shared between instances.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        """Fill in the timestamp and metadata defaults."""
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()
        if self.metadata is None:
            self.metadata = {}

    def to_dict(self) -> dict:
        """Return the message as a plain dict.

        The metadata is shallow-copied so that changing the returned dict
        does not change this message.
        """
        return {
            "role": self.role,
            "content": self.content,
            "timestamp": self.timestamp,
            "metadata": dict(self.metadata or {}),
        }
# ═══════════════════════════════════════════════════════════════════════════════
# 第三部分:Working Memory(工作记忆)
# ═══════════════════════════════════════════════════════════════════════════════
class WorkingMemory:
    """Working memory: a small key/value store for the task in progress.

    Every entry records the time it was written. Once the number of entries
    exceeds ``capacity`` the oldest entries are evicted. Data lives in a
    plain in-process dict; nothing is persisted.
    """

    def __init__(self, capacity: int = 7):
        """
        Args:
            capacity: Maximum number of entries; exceeding it triggers an
                automatic cleanup.
        """
        self.capacity = capacity
        self._data: Dict[str, Any] = {}

    def remember(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` (overwriting it) and stamp the time."""
        self._data[key] = {
            "value": value,
            "timestamp": datetime.now().isoformat(),
        }
        self._cleanup()

    def recall(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key``, or ``default`` if absent."""
        item = self._data.get(key)
        if item is None:
            return default
        return item["value"]

    def forget(self, key: str) -> None:
        """Remove ``key``; a missing key is ignored."""
        self._data.pop(key, None)

    def clear(self) -> None:
        """Drop every entry."""
        self._data = {}

    def get_all(self) -> Dict[str, Any]:
        """Return a new ``{key: value}`` dict without the timestamps."""
        return {k: v["value"] for k, v in self._data.items()}

    def _cleanup(self) -> None:
        """Evict the oldest entries once the capacity is exceeded."""
        overflow = len(self._data) - self.capacity
        if overflow <= 0:
            return
        # Evict a third of the capacity in one go so cleanup does not run on
        # every write, but never fewer entries than the overflow: for a
        # capacity below 3 the integer division is 0 and nothing would ever
        # be removed.
        delete_count = max(self.capacity // 3, overflow)
        # ISO-8601 strings sort chronologically; the sort is stable, so
        # entries with equal timestamps keep their insertion order.
        oldest_first = sorted(
            self._data,
            key=lambda k: self._data[k]["timestamp"],
        )
        for key in oldest_first[:delete_count]:
            del self._data[key]
# ═══════════════════════════════════════════════════════════════════════════════
# 第四部分:Short-Term Memory(短期记忆)
# ═══════════════════════════════════════════════════════════════════════════════
class ShortTermMemory:
    """Short-term memory: a bounded, ordered log of the latest messages.

    At most ``max_messages`` messages are kept; adding more drops the oldest
    ones. Older messages can be condensed into a text summary with
    ``summarize_older``.
    """

    def __init__(self, max_messages: int = 20):
        """
        Args:
            max_messages: Maximum number of messages to keep.
        """
        self.max_messages = max_messages
        self.messages: List["Message"] = []

    def add(self, role: str, content: str, **metadata) -> None:
        """Append a message and trim the log to ``max_messages``.

        Extra keyword arguments are stored as the message's metadata.
        """
        msg = Message(role=role, content=content, metadata=metadata)
        self.messages.append(msg)
        self._trim()

    def get_recent(self, n: int = 10) -> List["Message"]:
        """Return the last ``n`` messages (all of them if fewer exist).

        ``n <= 0`` yields an empty list.
        """
        if n <= 0:
            # ``self.messages[-0:]`` would be the whole list, not an empty one.
            return []
        return self.messages[-n:]

    def get_all(self) -> List["Message"]:
        """Return a shallow copy of the message list."""
        return self.messages.copy()

    def summarize_older(self, keep_recent: int = 5) -> str:
        """Condense every message except the latest ``keep_recent`` ones.

        Args:
            keep_recent: Number of latest messages left out of the summary.
                Values below 0 are treated as 0.

        Returns:
            One line per older message ("<role label>: <first 50 characters>"),
            joined by newlines, or "" when there is nothing to summarize.
        """
        # Use an explicit end index: ``self.messages[:-0]`` is empty, which
        # made ``keep_recent=0`` summarize nothing instead of everything.
        older_count = len(self.messages) - max(keep_recent, 0)
        if older_count <= 0:
            return ""
        summary_parts = []
        for msg in self.messages[:older_count]:
            # Every role other than "user" gets the assistant label.
            role_label = "用户" if msg.role == "user" else "助手"
            if len(msg.content) > 50:
                content_preview = msg.content[:50] + "..."
            else:
                content_preview = msg.content
            summary_parts.append(f"{role_label}: {content_preview}")
        return "\n".join(summary_parts)

    def clear(self) -> None:
        """Drop every message."""
        self.messages = []

    def _trim(self) -> None:
        """Drop the oldest messages beyond ``max_messages``."""
        excess = len(self.messages) - self.max_messages
        if excess > 0:
            # One in-place slice deletion instead of repeated ``pop(0)``;
            # it also cannot raise when ``max_messages`` is negative.
            del self.messages[:excess]
# ═══════════════════════════════════════════════════════════════════════════════
# 第五部分:Long-Term Memory(长期记忆)
# ═══════════════════════════════════════════════════════════════════════════════
class LongTermMemory:
    """Long-term memory: keyed knowledge that must be looked up explicitly.

    This teaching version keeps everything in an in-memory dict. The intended
    production backing is a disk or database store, and retrieval corresponds
    to RAG (covered later, in Step 05).
    """

    def __init__(self):
        """Create an empty store.

        A real application would initialise a backend here instead, such as a
        vector database (ChromaDB/Pinecone), a key-value store (Redis) or a
        graph database (Neo4j).
        """
        self._storage: Dict[str, Any] = {}

    def memorize(self, key: str, value: Any, tags: Optional[List[str]] = None) -> None:
        """Store ``value`` under ``key``, replacing any previous entry.

        Args:
            key: Unique identifier.
            value: Content to store.
            tags: Labels used for classification and retrieval.
        """
        # One clock read, so a new entry has identical created/accessed times.
        now = datetime.now().isoformat()
        self._storage[key] = {
            "value": value,
            # Copy the list so later changes by the caller do not alter
            # the stored entry.
            "tags": list(tags) if tags else [],
            "created_at": now,
            "accessed_at": now,
        }

    def recall(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key`` and refresh its access time.

        Args:
            key: Unique identifier.

        Returns:
            The stored content, or ``None`` if the key does not exist.
        """
        entry = self._storage.get(key)
        if entry is None:
            return None
        entry["accessed_at"] = datetime.now().isoformat()
        return entry["value"]

    def search(self, query: str, tags: Optional[List[str]] = None) -> List[Any]:
        """Return the stored values that match ``tags`` or ``query``.

        A real application would use vector similarity search; this is a
        simple demonstration. When ``tags`` is non-empty, an entry matches if
        it carries at least one of them and ``query`` is ignored. Otherwise an
        entry matches if ``query`` is a substring of ``str(value)``.
        """
        results = []
        for item in self._storage.values():
            if tags:
                if any(tag in item["tags"] for tag in tags):
                    results.append(item["value"])
            elif query in str(item["value"]):
                results.append(item["value"])
        return results

    def forget(self, key: str) -> None:
        """Delete ``key``; a missing key is ignored."""
        self._storage.pop(key, None)
# ═══════════════════════════════════════════════════════════════════════════════
# 第六部分:完整的 Memory System
# ═══════════════════════════════════════════════════════════════════════════════
class MemorySystem:
    """Facade over the three memory layers.

    - ``working``: the task in progress (WorkingMemory)
    - ``short_term``: the latest conversation (ShortTermMemory)
    - ``long_term``: persistent knowledge (LongTermMemory)
    """

    # Number of latest messages quoted in the context; anything older is only
    # represented by the summary.
    _RECENT_WINDOW = 10

    def __init__(self):
        self.working = WorkingMemory(capacity=7)
        self.short_term = ShortTermMemory(max_messages=20)
        self.long_term = LongTermMemory()

    def remember_task(self, task_id: str, task_info: Dict) -> None:
        """Record the current task in working memory.

        ``task_info`` is stored under "current_task" and ``task_id`` under
        "current_task_id" (the id used to be discarded).
        """
        self.working.remember("current_task_id", task_id)
        self.working.remember("current_task", task_info)

    def get_current_task(self) -> Optional[Dict]:
        """Return the info of the current task, or ``None`` if none is set."""
        return self.working.recall("current_task")

    def add_message(self, role: str, content: str) -> None:
        """Append a conversation message to short-term memory."""
        self.short_term.add(role, content)

    def get_context(self, include_older_summary: bool = True) -> str:
        """Build the context text to send to the LLM.

        Sections, separated by blank lines and omitted when empty:
        the current task as JSON, the latest messages (each cut to 100
        characters), and, if ``include_older_summary`` is true, a summary of
        the messages older than those.
        """
        context_parts = []

        # Current task
        current_task = self.get_current_task()
        if current_task:
            context_parts.append(f"[当前任务]\n{json.dumps(current_task, ensure_ascii=False)}")

        # Latest conversation, kept together as a single section like the
        # other two (header and lines separated by single newlines).
        recent = self.short_term.get_recent(self._RECENT_WINDOW)
        if recent:
            lines = ["[最近对话]"]
            lines.extend(f"- {msg.role}: {msg.content[:100]}" for msg in recent)
            context_parts.append("\n".join(lines))

        # Summary of older conversation. It uses the same window as above so
        # no message is both quoted and summarized.
        if include_older_summary:
            older_summary = self.short_term.summarize_older(keep_recent=self._RECENT_WINDOW)
            if older_summary:
                context_parts.append(f"[较早对话摘要]\n{older_summary}")

        return "\n\n".join(context_parts)

    def clear_session(self) -> None:
        """Clear the session memories; long-term memory is kept."""
        self.working.clear()
        self.short_term.clear()
# ═══════════════════════════════════════════════════════════════════════════════
# 演示代码
# ═══════════════════════════════════════════════════════════════════════════════
def demo():
    """Walk through the memory system end to end, printing each step."""
    banner = "=" * 60
    print(banner)
    print("Step 04: Memory - 记忆系统演示")
    print(banner)

    # Build the memory system
    mem = MemorySystem()

    # 1. Register the current task
    print("\n📝 记住当前任务")
    task_info = {
        "type": "生成报表",
        "requirement": "销售月报",
        "status": "进行中"
    }
    mem.remember_task("task_001", task_info)
    print(f" 当前任务: {mem.get_current_task()}")

    # 2. Feed in a short conversation
    print("\n💬 添加对话消息")
    dialogue = [
        ("user", "帮我生成一个销售报表"),
        ("assistant", "好的,请问你需要显示哪些数据?"),
        ("user", "显示月度汇总,包括销售额和数量"),
        ("assistant", "明白了,正在生成..."),
        ("user", "再添加一个增长率"),
        ("assistant", "好的,正在添加..."),
    ]
    for speaker, text in dialogue:
        mem.add_message(speaker, text)
        print(f" 添加: [{speaker}] {text[:30]}...")

    # 3. Assemble the context
    print("\n📋 获取完整上下文")
    print(mem.get_context())

    # 4. Summarize the older part of the conversation
    print("\n📝 较早对话摘要")
    print(mem.short_term.summarize_older(keep_recent=3))

    # 5. Long-term memory round trip
    print("\n💾 添加长期记忆")
    preferences = {"name": "张三", "default_report": "销售报表"}
    mem.long_term.memorize(
        key="user_preferences",
        value=preferences,
        tags=["用户信息", "偏好"],
    )
    recalled = mem.long_term.recall("user_preferences")
    print(f" 检索结果: {recalled}")

    print("\n" + banner)
    print("✅ 演示完成")
    print(banner)


# Run the demonstration only when executed as a script.
if __name__ == "__main__":
    demo()