119 lines
5.2 KiB
Python
119 lines
5.2 KiB
Python
"""
Step 04 练习题答案

⚠️ 先自己思考，再看答案！
⚠️ 答案不是唯一的，这里只是其中一种实现。
"""
|
|
|
|
import math
|
|
import re
|
|
|
|
from step_04_memory.concept import MemorySystem, Message, ShortTermMemory
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 练习 1 答案:Token 估算
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
# One match per CJK Unified Ideograph in the basic block U+4E00..U+9FFF.
_CN_PATTERN = re.compile(r"[\u4e00-\u9fff]")
# One match per maximal run of ASCII letters, i.e. one "English word".
_EN_WORD = re.compile(r"[A-Za-z]+")


def estimate_tokens_for_text(text: str) -> int:
    """Return a rough token estimate for *text*.

    Heuristic weights: each CJK ideograph counts 1.5 tokens and each run of
    ASCII letters counts 1.3 tokens. Anything matched by neither pattern
    (digits, punctuation, whitespace, other scripts) contributes nothing.
    The weighted total is rounded up to the next integer.
    """
    han_chars = sum(1 for _ in _CN_PATTERN.finditer(text))
    latin_words = sum(1 for _ in _EN_WORD.finditer(text))
    weighted_total = han_chars * 1.5 + latin_words * 1.3
    return math.ceil(weighted_total)
|
|
|
|
|
|
def install_estimate_tokens() -> None:
    """Attach an ``estimate_tokens()`` method to ``ShortTermMemory``.

    The installed method adds up ``estimate_tokens_for_text`` over the
    ``content`` of every message in ``self.messages``. Each call simply
    assigns a fresh copy of that method to the class.
    """

    def estimate_tokens(self: ShortTermMemory) -> int:
        """Return the estimated token total of all stored messages."""
        running_total = 0
        for message in self.messages:
            running_total += estimate_tokens_for_text(message.content)
        return running_total

    ShortTermMemory.estimate_tokens = estimate_tokens
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 练习 2 答案:基于 Token 阈值的自动压缩
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
def install_maybe_compress() -> None:
    """Attach a ``maybe_compress()`` method to ``ShortTermMemory``.

    The installed method relies on ``self.estimate_tokens()``, so
    ``install_estimate_tokens()`` must have been called first.
    """

    def maybe_compress(
        self: ShortTermMemory, max_tokens: int = 800, keep_recent: int = 5
    ) -> bool:
        """Fold older history into a summary once it exceeds a token budget.

        Args:
            max_tokens: Token budget; compression happens only when the
                estimate is strictly greater than this value.
            keep_recent: Number of newest messages kept verbatim. This same
                value is passed to ``summarize_older`` and used for the
                slice, so the two always agree. It defaults to 5 (the value
                that used to be hard-coded).

        Returns:
            True if ``self.messages`` was replaced by a system summary message
            followed by the last ``keep_recent`` messages. False if the memory
            is within budget or ``summarize_older`` returned nothing.

        Raises:
            RuntimeError: if ``install_estimate_tokens()`` was not called.
            ValueError: if ``keep_recent`` is negative.
        """
        if not hasattr(self, "estimate_tokens"):
            raise RuntimeError("请先调用 install_estimate_tokens()")
        if keep_recent < 0:
            raise ValueError(f"keep_recent must be >= 0, got {keep_recent}")
        if self.estimate_tokens() <= max_tokens:
            return False

        summary = self.summarize_older(keep_recent=keep_recent)
        if not summary:
            return False

        # Keep the newest keep_recent messages. Guard 0 explicitly:
        # messages[-0:] is the WHOLE list, which would defeat compression.
        recent = self.messages[-keep_recent:] if keep_recent > 0 else []
        # Put the summary first as a system message, followed by the kept turns.
        summary_msg = Message(role="system", content=f"[历史摘要]\n{summary}")
        self.messages = [summary_msg, *recent]
        return True

    ShortTermMemory.maybe_compress = maybe_compress
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 练习 3 答案:把 MemorySystem 接到 SimpleAgent
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
def install_memory_to_agent() -> None:
    """Give step 03's ``SimpleAgent`` a ``MemorySystem``.

    This patches ``SimpleAgent.__init__`` so every new agent gets
    ``self.memory = MemorySystem()``. It also wraps ``SimpleAgent.process``
    so that each call:

    1. writes ``self.memory.get_context()`` into ``self.state["context"]``
       and then delegates to the original ``process``;
    2. after the original returns, records the user input and the response
       in memory.

    Calling this more than once is safe: the patch is applied only the
    first time. Without that guard, a second call would wrap the
    already-wrapped ``process`` and every turn would be stored twice.

    Raises:
        ImportError: if ``step_03_simple_agent.concept`` cannot be imported.
    """
    import functools

    from step_03_simple_agent.concept import SimpleAgent

    # The sentinel lives on the class, so every caller sees it.
    if getattr(SimpleAgent, "_step04_memory_installed", False):
        return

    orig_init = SimpleAgent.__init__
    orig_process = SimpleAgent.process

    @functools.wraps(orig_init)
    def patched_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        self.memory = MemorySystem()

    @functools.wraps(orig_process)
    def patched_process(self, user_input: str) -> str:
        # Agents created before the patch was installed have no memory yet;
        # create it lazily instead of failing with AttributeError.
        if not hasattr(self, "memory"):
            self.memory = MemorySystem()
        # Inject the memory context into state before the original runs.
        self.state["context"] = self.memory.get_context()
        response = orig_process(self, user_input)
        # Record this turn only after the original returned successfully.
        self.memory.add_message("user", user_input)
        self.memory.add_message("assistant", response)
        return response

    SimpleAgent.__init__ = patched_init
    SimpleAgent.process = patched_process
    SimpleAgent._step04_memory_installed = True
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 测试
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
def test_answers():
    """Smoke-run the three exercise answers and print what they produce.

    Exercises 1 and 2 use a fresh ``MemorySystem`` filled with 20 synthetic
    user turns, which is then compressed with ``max_tokens=200``.
    Exercise 3 patches ``SimpleAgent`` and sends it one input. Any
    ``Exception`` raised there (e.g. step_03 not importable from the
    current directory) is printed and otherwise ignored.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Step 04 练习答案测试")
    print(rule)

    # Exercises 1 & 2: install the monkey-patches on ShortTermMemory.
    install_estimate_tokens()
    install_maybe_compress()

    mem = MemorySystem()
    short_term = mem.short_term
    for turn in range(20):
        short_term.add("user", f"第 {turn} 轮对话内容,包含中文与 english words " * 5)

    tokens_before = short_term.estimate_tokens()
    print(f"\n📝 练习 1: 注入 20 条后估算 token = {tokens_before}")

    compressed = short_term.maybe_compress(max_tokens=200)
    remaining = len(short_term.messages)
    print(f" maybe_compress() = {compressed}, 压缩后消息数 = {remaining}")
    print(f" 压缩后估算 token = {short_term.estimate_tokens()}")

    # Exercise 3: best-effort only, since step_03 may not be importable here.
    print("\n📝 练习 3: SimpleAgent 接入 Memory")
    try:
        install_memory_to_agent()
        from step_03_simple_agent.concept import SimpleAgent

        agent = SimpleAgent()
        agent.process("1 + 2")
        stored = len(agent.memory.short_term.messages)
        print(f" agent.memory 工作正常,消息数 = {stored}")
    except Exception as e:
        print(f" 接入失败(可忽略,需在 step_03 父目录运行): {e}")


if __name__ == "__main__":
    # Executed directly (not imported): run the demo above.
    test_answers()
|