fix: 修复 NameError/状态污染/类型标注/统计; 补全练习与 main; 新增 config/.gitignore/requirements; 文档统一
This commit is contained in:
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Step 05-07 练习题答案
|
||||
|
||||
⚠️ 先自己思考,再看答案!
|
||||
⚠️ 答案不是唯一的,这里只是其中一种实现
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
from step_05_07_advanced.concept import (
|
||||
Agent,
|
||||
MultiAgentSystem,
|
||||
SelfCorrectingAgent,
|
||||
SimpleRAG,
|
||||
ValidationResult,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# 练习 1 答案:升级 SimpleRAG 分词
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]|[A-Za-z]+")
|
||||
|
||||
|
||||
def _tokenize(text: str) -> set[str]:
|
||||
return set(_TOKEN_PATTERN.findall(text.lower()))
|
||||
|
||||
|
||||
def upgrade_retrieve(rag: SimpleRAG) -> None:
|
||||
def retrieve(self, query: str, top_k: int = 3):
|
||||
q_words = _tokenize(query)
|
||||
scored = []
|
||||
for doc in self.documents:
|
||||
d_words = _tokenize(doc["text"])
|
||||
union = q_words | d_words
|
||||
if not union:
|
||||
continue
|
||||
score = len(q_words & d_words) / len(union)
|
||||
scored.append((score, doc))
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [doc for _, doc in scored[:top_k]]
|
||||
|
||||
SimpleRAG.retrieve = retrieve
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# 练习 2 答案:Self-Correction 主循环
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def install_self_correction_run() -> None:
|
||||
def run(
|
||||
self: SelfCorrectingAgent,
|
||||
requirement: str,
|
||||
generate_fn: Callable,
|
||||
validate_fn: Callable,
|
||||
max_retries: int = 3,
|
||||
):
|
||||
feedback = None
|
||||
output = None
|
||||
for attempt in range(max_retries):
|
||||
output = generate_fn(requirement, attempt, feedback)
|
||||
validation: ValidationResult = validate_fn(output)
|
||||
if validation.passed:
|
||||
return output
|
||||
feedback = self.build_feedback(validation, output, attempt)
|
||||
return output
|
||||
|
||||
SelfCorrectingAgent.run = run
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# 练习 3 答案:MultiAgentSystem 超时与回退
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def install_safe_process() -> None:
|
||||
def process(self: MultiAgentSystem, requirement: str, timeout_seconds: float = 2.0):
|
||||
try:
|
||||
return self._timed_process(requirement, timeout_seconds)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def _timed_process(self, requirement: str, timeout_seconds: float):
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
searcher = self.agents.get("searcher")
|
||||
if searcher:
|
||||
self._check_timeout(deadline)
|
||||
context = searcher.process(requirement)
|
||||
else:
|
||||
context = ""
|
||||
|
||||
generator = self.agents.get("generator")
|
||||
if generator:
|
||||
self._check_timeout(deadline)
|
||||
result = generator.process({"requirement": requirement, "context": context})
|
||||
else:
|
||||
result = requirement
|
||||
|
||||
validator = self.agents.get("validator")
|
||||
if validator:
|
||||
self._check_timeout(deadline)
|
||||
validation = validator.process(result)
|
||||
if not validation.get("passed", True):
|
||||
return {"error": "验证失败", "validation": validation}
|
||||
|
||||
return result
|
||||
|
||||
def _check_timeout(self, deadline: float):
|
||||
if time.monotonic() > deadline:
|
||||
raise TimeoutError("Multi-Agent 处理超时")
|
||||
|
||||
MultiAgentSystem.process = process
|
||||
MultiAgentSystem._timed_process = _timed_process
|
||||
MultiAgentSystem._check_timeout = _check_timeout
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# 测试
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def test_answers():
|
||||
print("\n" + "=" * 60)
|
||||
print("Step 05-07 练习答案测试")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n📝 练习 1: 升级 SimpleRAG")
|
||||
rag = SimpleRAG()
|
||||
rag.add_document("JasperReports 是一个 Java 报表库", {"source": "doc1"})
|
||||
rag.add_document("JRXML 是 JasperReports 模板格式", {"source": "doc2"})
|
||||
upgrade_retrieve(rag)
|
||||
hits = rag.retrieve("JasperReports")
|
||||
print(f" 检索命中 {len(hits)} 条")
|
||||
for d in hits:
|
||||
print(f" - {d['text']}")
|
||||
|
||||
print("\n📝 练习 2: Self-Correction run()")
|
||||
install_self_correction_run()
|
||||
sc = SelfCorrectingAgent()
|
||||
|
||||
def fake_generate(req, attempt, feedback):
|
||||
# 第一次失败,第二次成功
|
||||
return f"v{attempt}"
|
||||
|
||||
def fake_validate(output):
|
||||
passed = output == "v1"
|
||||
return ValidationResult(passed=passed, score=1.0 if passed else 0.2, issues=[] if passed else ["不达标"])
|
||||
|
||||
final = sc.run("测试", fake_generate, fake_validate, max_retries=3)
|
||||
print(f" 最终结果 = {final}")
|
||||
|
||||
print("\n📝 练习 3: Multi-Agent 安全 process()")
|
||||
install_safe_process()
|
||||
|
||||
class BoomValidator(Agent):
|
||||
name = "validator"
|
||||
|
||||
def process(self, input_data):
|
||||
raise RuntimeError("故意崩溃")
|
||||
|
||||
sys = MultiAgentSystem()
|
||||
sys.agents["validator"] = BoomValidator()
|
||||
res = sys.process("任何需求")
|
||||
print(f" 异常被吞掉: {res}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_answers()
|
||||
Reference in New Issue
Block a user