"""
Step 05-07 exercise answers.

⚠️ Try the exercises yourself before reading the answers!
⚠️ Answers are not unique; this is just one possible implementation.
"""

import re
import time
import types
from typing import Callable

from step_05_07_advanced.concept import (
    Agent,
    MultiAgentSystem,
    SelfCorrectingAgent,
    SimpleRAG,
    ValidationResult,
)


# ═══════════════════════════════════════════════════════════════════════════════
# Exercise 1 answer: upgrade SimpleRAG tokenization
# ═══════════════════════════════════════════════════════════════════════════════

# One token per CJK ideograph (basic block U+4E00..U+9FFF), or one token per run
# of ASCII letters. Digits and punctuation are not tokens.
_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]|[A-Za-z]+")


def _tokenize(text: str) -> set[str]:
    """Return the set of lower-cased tokens found in *text*.

    CJK characters become single-character tokens (Chinese text has no spaces
    to split on); ASCII letter runs become whole-word tokens.
    """
    return set(_TOKEN_PATTERN.findall(text.lower()))


def upgrade_retrieve(rag: SimpleRAG) -> None:
    """Replace ``rag.retrieve`` with a Jaccard-similarity retriever.

    The new method is bound to the given instance only, so other ``SimpleRAG``
    objects keep their original ``retrieve``.
    """

    def retrieve(self, query: str, top_k: int = 3) -> list:
        """Return up to *top_k* documents ranked by Jaccard similarity to *query*.

        Similarity is |query ∩ doc| / |query ∪ doc| over ``_tokenize`` tokens.
        Documents with zero overlap still get a score of 0 and may be returned.
        Documents are skipped only when both they and the query have no tokens.
        """
        q_words = _tokenize(query)
        scored = []
        for doc in self.documents:
            d_words = _tokenize(doc["text"])
            union = q_words | d_words
            if not union:
                # Neither side has any tokens, so the similarity is undefined.
                continue
            score = len(q_words & d_words) / len(union)
            scored.append((score, doc))
        # list.sort is stable even with reverse=True, so ties keep the order of
        # self.documents.
        scored.sort(key=lambda x: x[0], reverse=True)
        return [doc for _, doc in scored[:top_k]]

    # Bind to this instance rather than assigning SimpleRAG.retrieve. Patching
    # the class would silently change every SimpleRAG and ignore `rag`.
    rag.retrieve = types.MethodType(retrieve, rag)


# ═══════════════════════════════════════════════════════════════════════════════
# Exercise 2 answer: Self-Correction main loop
# ═══════════════════════════════════════════════════════════════════════════════

def install_self_correction_run() -> None:
    """Install a generate → validate → feedback loop as ``SelfCorrectingAgent.run``.

    This patches the class, so every ``SelfCorrectingAgent`` instance is affected.
    """

    def run(
        self: SelfCorrectingAgent,
        requirement: str,
        generate_fn: Callable,
        validate_fn: Callable,
        max_retries: int = 3,
    ):
        """Generate output until it passes validation or attempts run out.

        Args:
            requirement: Passed through unchanged to ``generate_fn``.
            generate_fn: Called as ``generate_fn(requirement, attempt, feedback)``.
                ``feedback`` is ``None`` on the first attempt and otherwise
                whatever ``self.build_feedback`` returned for the previous
                failure.
            validate_fn: Called with each output; its result is used as a
                ``ValidationResult`` (only ``.passed`` is read here).
            max_retries: Total number of attempts, including the first one.

        Returns:
            The first output whose validation passed. If no output passes, the
            last output generated. ``None`` if ``max_retries < 1``, because no
            attempt is made.
        """
        feedback = None
        output = None
        for attempt in range(max_retries):
            output = generate_fn(requirement, attempt, feedback)
            validation: ValidationResult = validate_fn(output)
            if validation.passed:
                return output
            # Turn this failure into guidance for the next generation attempt.
            feedback = self.build_feedback(validation, output, attempt)
        return output

    SelfCorrectingAgent.run = run


# ═══════════════════════════════════════════════════════════════════════════════
# Exercise 3 answer: MultiAgentSystem timeout and fallback
# ═══════════════════════════════════════════════════════════════════════════════

def install_safe_process() -> None:
    """Patch ``MultiAgentSystem`` with a time-bounded, exception-safe ``process``.

    Installs ``process``, ``_timed_process`` and ``_check_timeout`` on the class.
    """

    def process(self: MultiAgentSystem, requirement: str, timeout_seconds: float = 2.0):
        """Run the search → generate → validate pipeline without ever raising.

        Any exception, including the timeout, is deliberately converted into
        ``{"error": str(e)}`` so callers get a fallback value instead of a crash.
        """
        try:
            return self._timed_process(requirement, timeout_seconds)
        except Exception as e:
            return {"error": str(e)}

    def _timed_process(self, requirement: str, timeout_seconds: float):
        """Run each available agent in turn, checking the deadline before each one.

        The timeout is cooperative. It is checked only *before* a stage starts,
        so a stage that is already running is not interrupted. A slow final
        validator can therefore overrun the deadline without raising.
        Missing agents are skipped:
        - no searcher → empty context;
        - no generator → the requirement is passed through as the result;
        - no validator → the result is returned unchecked.
        """
        deadline = time.monotonic() + timeout_seconds

        searcher = self.agents.get("searcher")
        if searcher:
            self._check_timeout(deadline)
            context = searcher.process(requirement)
        else:
            context = ""

        generator = self.agents.get("generator")
        if generator:
            self._check_timeout(deadline)
            result = generator.process({"requirement": requirement, "context": context})
        else:
            result = requirement

        validator = self.agents.get("validator")
        if validator:
            self._check_timeout(deadline)
            validation = validator.process(result)
            # A validation dict without "passed" is treated as a pass.
            if not validation.get("passed", True):
                return {"error": "验证失败", "validation": validation}

        return result

    def _check_timeout(self, deadline: float):
        """Raise TimeoutError if the monotonic clock is past *deadline*."""
        if time.monotonic() > deadline:
            raise TimeoutError("Multi-Agent 处理超时")

    MultiAgentSystem.process = process
    MultiAgentSystem._timed_process = _timed_process
    MultiAgentSystem._check_timeout = _check_timeout


# ═══════════════════════════════════════════════════════════════════════════════
# Tests
# ═══════════════════════════════════════════════════════════════════════════════

def test_answers():
    """Smoke-run all three exercise answers and print their results."""
    print("\n" + "=" * 60)
    print("Step 05-07 练习答案测试")
    print("=" * 60)

    print("\n📝 练习 1: 升级 SimpleRAG")
    rag = SimpleRAG()
    rag.add_document("JasperReports 是一个 Java 报表库", {"source": "doc1"})
    rag.add_document("JRXML 是 JasperReports 模板格式", {"source": "doc2"})
    upgrade_retrieve(rag)
    hits = rag.retrieve("JasperReports")
    print(f" 检索命中 {len(hits)} 条")
    for d in hits:
        print(f" - {d['text']}")

    print("\n📝 练习 2: Self-Correction run()")
    install_self_correction_run()
    sc = SelfCorrectingAgent()

    def fake_generate(req, attempt, feedback):
        # Attempt 0 produces "v0" (fails validation); attempt 1 produces "v1" (passes).
        return f"v{attempt}"

    def fake_validate(output):
        passed = output == "v1"
        return ValidationResult(passed=passed, score=1.0 if passed else 0.2, issues=[] if passed else ["不达标"])

    final = sc.run("测试", fake_generate, fake_validate, max_retries=3)
    print(f" 最终结果 = {final}")

    print("\n📝 练习 3: Multi-Agent 安全 process()")
    install_safe_process()

    class BoomValidator(Agent):
        name = "validator"

        def process(self, input_data):
            raise RuntimeError("故意崩溃")

    # Named `mas` rather than `sys` to avoid shadowing the stdlib module name.
    mas = MultiAgentSystem()
    mas.agents["validator"] = BoomValidator()
    res = mas.process("任何需求")
    print(f" 异常被吞掉: {res}")


if __name__ == "__main__":
    test_answers()