Files

170 lines
6.5 KiB
Python

"""
Step 05-07 练习题答案
⚠️ 先自己思考,再看答案!
⚠️ 答案不是唯一的,这里只是其中一种实现
"""
import re
import time
from typing import Callable
from step_05_07_advanced.concept import (
Agent,
MultiAgentSystem,
SelfCorrectingAgent,
SimpleRAG,
ValidationResult,
)
# ═══════════════════════════════════════════════════════════════════════════════
# 练习 1 答案:升级 SimpleRAG 分词
# ═══════════════════════════════════════════════════════════════════════════════
# A token is either a single CJK ideograph (Chinese text has no spaces, so each
# character is its own token) or a run of ASCII letters/digits. Digits are kept
# so queries such as "Java 8" or "ISO 8601" can match on the number instead of
# having it silently dropped.
_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]|[A-Za-z0-9]+")


def _tokenize(text: str) -> set[str]:
    """Return the set of lower-cased tokens found in *text* (see ``_TOKEN_PATTERN``)."""
    return set(_TOKEN_PATTERN.findall(text.lower()))


def upgrade_retrieve(rag: "SimpleRAG") -> None:
    """Replace ``rag.retrieve`` with a Jaccard-similarity retriever.

    Only the instance passed in is upgraded; the ``SimpleRAG`` class and any
    other instances keep their original ``retrieve``.

    The installed ``retrieve(query, top_k=3)`` scores each entry of
    ``rag.documents`` (each is read as ``doc["text"]``) by
    ``|query_tokens & doc_tokens| / |query_tokens | doc_tokens|`` and returns
    up to ``top_k`` documents, highest score first. Ties keep insertion order
    (the sort is stable). Documents with no overlap score 0 and can still fill
    the remaining slots. A non-positive ``top_k`` returns ``[]``.
    """

    def retrieve(query: str, top_k: int = 3) -> list:
        q_words = _tokenize(query)
        scored = []
        for doc in rag.documents:
            d_words = _tokenize(doc["text"])
            union = q_words | d_words
            if not union:
                # Both sides tokenize to nothing; Jaccard is undefined, skip.
                continue
            scored.append((len(q_words & d_words) / len(union), doc))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        # Clamp so a negative top_k doesn't turn into "all but the last N".
        return [doc for _, doc in scored[: max(top_k, 0)]]

    # Bind to this instance only: a plain function stored on the instance is
    # not turned into a bound method, so the closure over `rag` supplies "self".
    rag.retrieve = retrieve
# ═══════════════════════════════════════════════════════════════════════════════
# 练习 2 答案:Self-Correction 主循环
# ═══════════════════════════════════════════════════════════════════════════════
def install_self_correction_run() -> None:
    """Attach the answer implementation of ``run`` to ``SelfCorrectingAgent``.

    The function is assigned on the class, so every instance uses it.
    """

    def run(
        self: SelfCorrectingAgent,
        requirement: str,
        generate_fn: Callable,
        validate_fn: Callable,
        max_retries: int = 3,
    ):
        """Generate, validate and retry with feedback until validation passes.

        Args:
            requirement: passed unchanged to ``generate_fn`` on every attempt.
            generate_fn: called as ``generate_fn(requirement, attempt, feedback)``;
                ``feedback`` is ``None`` on the first attempt and afterwards is
                whatever ``self.build_feedback`` returned for the previous one.
            validate_fn: called with each output; the ``passed`` attribute of
                its result decides whether to stop.
            max_retries: total number of attempts (attempt indices
                ``0 .. max_retries - 1``).

        Returns:
            The first output that passes validation. Otherwise the output of the
            last attempt, or ``None`` when ``max_retries`` is 0 or negative.
        """
        hint = None
        candidate = None
        for attempt in range(max_retries):
            candidate = generate_fn(requirement, attempt, hint)
            verdict: ValidationResult = validate_fn(candidate)
            if verdict.passed:
                break
            # Turn the failed validation into feedback for the next attempt
            # (this also runs after the final failed attempt).
            hint = self.build_feedback(verdict, candidate, attempt)
        return candidate

    SelfCorrectingAgent.run = run
# ═══════════════════════════════════════════════════════════════════════════════
# 练习 3 答案:MultiAgentSystem 超时与回退
# ═══════════════════════════════════════════════════════════════════════════════
def install_safe_process() -> None:
    """Patch ``MultiAgentSystem`` with a timeout-aware, exception-safe ``process``.

    Installs three methods on the class, so they apply to every instance:
    ``process`` (public entry point), ``_timed_process`` (the staged pipeline)
    and ``_check_timeout`` (the deadline guard).
    """

    def process(self: MultiAgentSystem, requirement: str, timeout_seconds: float = 2.0):
        """Run the pipeline and convert any exception into an error dict.

        Returns:
            Whatever ``_timed_process`` returns on success; ``{"error": message}``
            if it raises. ``message`` is ``str(exc)``, falling back to the
            exception class name so callers never receive an empty error.
        """
        try:
            return self._timed_process(requirement, timeout_seconds)
        except Exception as exc:  # deliberate boundary: this is the fallback path
            return {"error": str(exc) or type(exc).__name__}

    def _timed_process(self, requirement: str, timeout_seconds: float):
        """Run searcher -> generator -> validator under a cooperative deadline.

        The deadline is checked only before each stage starts; a stage that is
        already running is not interrupted. A stage that overruns after the last
        check still has its result returned.

        Fallbacks when an agent is absent from ``self.agents``:
          * no searcher  -> context is ``""``
          * no generator -> the raw ``requirement`` is the result
          * no validator -> the result is returned unvalidated

        Raises:
            TimeoutError: via ``_check_timeout`` once the deadline is reached.
        """
        deadline = time.monotonic() + timeout_seconds

        searcher = self.agents.get("searcher")
        if searcher:
            self._check_timeout(deadline)
            context = searcher.process(requirement)
        else:
            context = ""

        generator = self.agents.get("generator")
        if generator:
            self._check_timeout(deadline)
            result = generator.process({"requirement": requirement, "context": context})
        else:
            result = requirement

        validator = self.agents.get("validator")
        if validator:
            self._check_timeout(deadline)
            validation = validator.process(result)
            # Validator output is read as a dict-like; a missing "passed" key counts as a pass.
            if not validation.get("passed", True):
                return {"error": "验证失败", "validation": validation}

        return result

    def _check_timeout(self, deadline: float):
        """Raise ``TimeoutError`` once the monotonic clock reaches *deadline*."""
        # ">=" rather than ">": with a coarse monotonic clock, a zero or tiny
        # budget could otherwise read back as exactly `deadline` and never fire.
        if time.monotonic() >= deadline:
            raise TimeoutError("Multi-Agent 处理超时")

    MultiAgentSystem.process = process
    MultiAgentSystem._timed_process = _timed_process
    MultiAgentSystem._check_timeout = _check_timeout
# ═══════════════════════════════════════════════════════════════════════════════
# 测试
# ═══════════════════════════════════════════════════════════════════════════════
def test_answers():
    """Smoke-run the three exercise answers and print what each produces.

    This only prints; it makes no assertions. Each exercise installs its
    answer implementation before exercising it.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Step 05-07 练习答案测试")
    print(rule)

    # Exercise 1: retrieval over two small mixed Chinese/English documents.
    print("\n📝 练习 1: 升级 SimpleRAG")
    store = SimpleRAG()
    for text, source in (
        ("JasperReports 是一个 Java 报表库", "doc1"),
        ("JRXML 是 JasperReports 模板格式", "doc2"),
    ):
        store.add_document(text, {"source": source})
    upgrade_retrieve(store)
    matches = store.retrieve("JasperReports")
    print(f" 检索命中 {len(matches)}")
    for match in matches:
        print(f" - {match['text']}")

    # Exercise 2: attempt 0 yields "v0" (fails), attempt 1 yields "v1" (passes).
    print("\n📝 练习 2: Self-Correction run()")
    install_self_correction_run()
    agent = SelfCorrectingAgent()

    def fake_generate(req, attempt, feedback):
        return f"v{attempt}"

    def fake_validate(candidate):
        if candidate == "v1":
            return ValidationResult(passed=True, score=1.0, issues=[])
        return ValidationResult(passed=False, score=0.2, issues=["不达标"])

    outcome = agent.run("测试", fake_generate, fake_validate, max_retries=3)
    print(f" 最终结果 = {outcome}")

    # Exercise 3: a validator that always raises; the patched process() is
    # expected to return an error value instead of propagating it.
    print("\n📝 练习 3: Multi-Agent 安全 process()")
    install_safe_process()

    class BoomValidator(Agent):
        name = "validator"

        def process(self, input_data):
            raise RuntimeError("故意崩溃")

    system = MultiAgentSystem()
    system.agents["validator"] = BoomValidator()
    response = system.process("任何需求")
    print(f" 异常被吞掉: {response}")


if __name__ == "__main__":
    test_answers()