170 lines
6.5 KiB
Python
170 lines
6.5 KiB
Python
"""
|
|
Step 05-07 练习题答案
|
|
|
|
⚠️ 先自己思考,再看答案!
|
|
⚠️ 答案不是唯一的,这里只是其中一种实现
|
|
"""
|
|
|
|
import re
|
|
import time
|
|
from typing import Callable
|
|
|
|
from step_05_07_advanced.concept import (
|
|
Agent,
|
|
MultiAgentSystem,
|
|
SelfCorrectingAgent,
|
|
SimpleRAG,
|
|
ValidationResult,
|
|
)
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 练习 1 答案:升级 SimpleRAG 分词
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]|[A-Za-z]+")
|
|
|
|
|
|
def _tokenize(text: str) -> set[str]:
|
|
return set(_TOKEN_PATTERN.findall(text.lower()))
|
|
|
|
|
|
def upgrade_retrieve(rag: SimpleRAG) -> None:
    """Replace ``SimpleRAG.retrieve`` with a Jaccard-similarity retriever.

    NOTE: the method is patched on the ``SimpleRAG`` *class*, so every
    instance is affected, not only *rag*. The *rag* argument itself is not
    used; it is kept so existing call sites keep working.

    The installed ``retrieve(query, top_k=3)`` tokenizes the query and each
    document's ``"text"`` with ``_tokenize``. It scores each document by
    ``|q & d| / |q | d|`` (Jaccard similarity). It returns up to *top_k*
    document dicts in descending score order; equal scores keep their order
    in ``self.documents``. Documents that share no token with the query
    are not returned, and a non-positive *top_k* yields an empty list.
    """

    def retrieve(self, query: str, top_k: int = 3):
        # Without this guard a negative top_k would slice from the end
        # (``scored[:-1]``) and return almost everything.
        if top_k <= 0:
            return []

        q_words = _tokenize(query)
        scored = []
        for doc in self.documents:
            d_words = _tokenize(doc["text"])
            overlap = q_words & d_words
            # Zero overlap means zero relevance: don't report it as a hit.
            # Skipping here also guarantees the union below is non-empty.
            if not overlap:
                continue
            scored.append((len(overlap) / len(q_words | d_words), doc))

        # list.sort is stable even with reverse=True, so ties keep doc order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in scored[:top_k]]

    SimpleRAG.retrieve = retrieve
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 练习 2 答案:Self-Correction 主循环
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
def install_self_correction_run() -> None:
    """Attach a generate -> validate -> feedback loop as ``SelfCorrectingAgent.run``.

    The installed ``run`` calls ``generate_fn(requirement, attempt, feedback)``
    for ``attempt`` in ``0 .. max_retries - 1`` and checks each output with
    ``validate_fn``. The first output whose validation result has ``passed``
    set is returned immediately. After a failed attempt,
    ``self.build_feedback(validation, output, attempt)`` produces the
    feedback handed to the next attempt.

    Note that ``max_retries`` is the *total* number of attempts. If every
    attempt fails, the last output is returned unchanged. If
    ``max_retries <= 0``, nothing is generated and ``None`` is returned.
    """

    def run(
        self: SelfCorrectingAgent,
        requirement: str,
        generate_fn: Callable,
        validate_fn: Callable,
        max_retries: int = 3,
    ):
        latest = None
        hint = None
        for attempt in range(max_retries):
            latest = generate_fn(requirement, attempt, hint)
            verdict: ValidationResult = validate_fn(latest)
            if verdict.passed:
                break
            # Turn this failure into guidance for the next generation round.
            hint = self.build_feedback(verdict, latest, attempt)
        return latest

    SelfCorrectingAgent.run = run
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 练习 3 答案:MultiAgentSystem 超时与回退
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
def install_safe_process() -> None:
    """Patch ``MultiAgentSystem`` with a timeout-aware, exception-safe ``process``.

    Installs three methods on the class:

    * ``process(requirement, timeout_seconds=2.0)``: runs the pipeline and
      converts any ``Exception`` (including the timeout) into
      ``{"error": <message>}`` instead of raising.
    * ``_timed_process(requirement, timeout_seconds)``: runs the optional
      ``searcher`` -> ``generator`` -> ``validator`` agents in that order.
    * ``_check_timeout(deadline)``: raises ``TimeoutError`` once the
      ``time.monotonic()`` deadline has passed.

    The timeout is cooperative. It is checked between stages and once after
    the last stage, so a single stage that blocks is never interrupted. It is
    only reported once that stage returns.
    """

    def process(self: MultiAgentSystem, requirement: str, timeout_seconds: float = 2.0):
        try:
            return self._timed_process(requirement, timeout_seconds)
        except Exception as e:
            # Fallback: a failing agent must not crash the caller. Some
            # exceptions (e.g. ``RuntimeError()``) have an empty message, so
            # fall back to the exception type name to keep the error useful.
            return {"error": str(e) or type(e).__name__}

    def _timed_process(self, requirement: str, timeout_seconds: float):
        deadline = time.monotonic() + timeout_seconds

        searcher = self.agents.get("searcher")
        if searcher:
            self._check_timeout(deadline)
            context = searcher.process(requirement)
        else:
            context = ""

        generator = self.agents.get("generator")
        if generator:
            self._check_timeout(deadline)
            result = generator.process({"requirement": requirement, "context": context})
        else:
            # No generator registered: the raw requirement passes through.
            result = requirement

        validator = self.agents.get("validator")
        if validator:
            self._check_timeout(deadline)
            validation = validator.process(result)
            # A validation dict without a "passed" key is treated as passing.
            if not validation.get("passed", True):
                return {"error": "验证失败", "validation": validation}

        # The checks above run only *before* each stage; check once more so
        # that time spent in the last stage also counts against the budget.
        self._check_timeout(deadline)
        return result

    def _check_timeout(self, deadline: float):
        if time.monotonic() > deadline:
            raise TimeoutError("Multi-Agent 处理超时")

    MultiAgentSystem.process = process
    MultiAgentSystem._timed_process = _timed_process
    MultiAgentSystem._check_timeout = _check_timeout
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# 测试
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
def test_answers():
    """Smoke-run the three exercise answers and print what each produces.

    Each section installs its patch and exercises it with small in-memory
    fixtures. Nothing is asserted; the printed output is meant to be read.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Step 05-07 练习答案测试")
    print(banner)

    # Exercise 1: tokenizer-based retrieval over two tiny documents.
    print("\n📝 练习 1: 升级 SimpleRAG")
    rag = SimpleRAG()
    for text, source in (
        ("JasperReports 是一个 Java 报表库", "doc1"),
        ("JRXML 是 JasperReports 模板格式", "doc2"),
    ):
        rag.add_document(text, {"source": source})
    upgrade_retrieve(rag)
    matches = rag.retrieve("JasperReports")
    print(f" 检索命中 {len(matches)} 条")
    for doc in matches:
        print(f" - {doc['text']}")

    # Exercise 2: the self-correction loop should settle on the second attempt.
    print("\n📝 练习 2: Self-Correction run()")
    install_self_correction_run()
    agent = SelfCorrectingAgent()

    def fake_generate(req, attempt, feedback):
        # Emits "v0", "v1", ...: attempt 0 fails validation, attempt 1 passes.
        return f"v{attempt}"

    def fake_validate(output):
        if output == "v1":
            return ValidationResult(passed=True, score=1.0, issues=[])
        return ValidationResult(passed=False, score=0.2, issues=["不达标"])

    final = agent.run("测试", fake_generate, fake_validate, max_retries=3)
    print(f" 最终结果 = {final}")

    # Exercise 3: a crashing validator must be turned into an error dict.
    print("\n📝 练习 3: Multi-Agent 安全 process()")
    install_safe_process()

    class BoomValidator(Agent):
        name = "validator"

        def process(self, input_data):
            raise RuntimeError("故意崩溃")

    system = MultiAgentSystem()
    system.agents["validator"] = BoomValidator()
    outcome = system.process("任何需求")
    print(f" 异常被吞掉: {outcome}")


if __name__ == "__main__":
    test_answers()
|