Fixed the Issue of Charts being Repeatedly Repaired
This commit is contained in:
@@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional, Tuple, Callable
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
@@ -383,6 +384,30 @@ class ChartRepairer:
|
||||
"""
|
||||
self.validator = validator
|
||||
self.llm_repair_fns = llm_repair_fns or []
|
||||
# 缓存修复结果,避免同一个图表在多处被重复调用LLM
|
||||
self._result_cache: Dict[str, RepairResult] = {}
|
||||
|
||||
def build_cache_key(self, widget_block: Dict[str, Any]) -> str:
|
||||
"""
|
||||
为图表生成稳定的缓存key,保证同样的数据不会重复触发修复。
|
||||
|
||||
- 优先使用widgetId;
|
||||
- 结合数据内容的哈希,避免同ID但内容变化时误用旧结果。
|
||||
"""
|
||||
widget_id = ""
|
||||
if isinstance(widget_block, dict):
|
||||
widget_id = widget_block.get('widgetId') or widget_block.get('id') or ""
|
||||
try:
|
||||
serialized = json.dumps(
|
||||
widget_block,
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
default=str
|
||||
)
|
||||
except Exception:
|
||||
serialized = repr(widget_block)
|
||||
digest = hashlib.md5(serialized.encode('utf-8', errors='ignore')).hexdigest()
|
||||
return f"{widget_id}:{digest}"
|
||||
|
||||
def repair(
|
||||
self,
|
||||
@@ -399,6 +424,20 @@ class ChartRepairer:
|
||||
Returns:
|
||||
RepairResult: 修复结果
|
||||
"""
|
||||
cache_key = self.build_cache_key(widget_block)
|
||||
|
||||
cached = self._result_cache.get(cache_key)
|
||||
if cached:
|
||||
# 返回缓存的深拷贝,避免外部修改影响缓存
|
||||
return copy.deepcopy(cached)
|
||||
|
||||
def _cache_and_return(res: RepairResult) -> RepairResult:
|
||||
try:
|
||||
self._result_cache[cache_key] = copy.deepcopy(res)
|
||||
except Exception:
|
||||
self._result_cache[cache_key] = res
|
||||
return res
|
||||
|
||||
# 1. 如果没有验证结果,先验证
|
||||
if validation_result is None:
|
||||
validation_result = self.validator.validate(widget_block)
|
||||
@@ -412,7 +451,9 @@ class ChartRepairer:
|
||||
repaired_validation = self.validator.validate(local_result.repaired_block)
|
||||
if repaired_validation.is_valid:
|
||||
logger.info(f"本地修复成功: {local_result.changes}")
|
||||
return RepairResult(True, local_result.repaired_block, 'local', local_result.changes)
|
||||
return _cache_and_return(
|
||||
RepairResult(True, local_result.repaired_block, 'local', local_result.changes)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"本地修复后仍然无效: {repaired_validation.errors}")
|
||||
|
||||
@@ -426,20 +467,22 @@ class ChartRepairer:
|
||||
repaired_validation = self.validator.validate(api_result.repaired_block)
|
||||
if repaired_validation.is_valid:
|
||||
logger.info(f"API修复成功: {api_result.changes}")
|
||||
return api_result
|
||||
return _cache_and_return(api_result)
|
||||
else:
|
||||
logger.warning(f"API修复后仍然无效: {repaired_validation.errors}")
|
||||
|
||||
# 5. 如果验证通过,返回原始或修复后的数据
|
||||
if validation_result.is_valid:
|
||||
if local_result.has_changes():
|
||||
return RepairResult(True, local_result.repaired_block, 'local', local_result.changes)
|
||||
return _cache_and_return(
|
||||
RepairResult(True, local_result.repaired_block, 'local', local_result.changes)
|
||||
)
|
||||
else:
|
||||
return RepairResult(True, widget_block, 'none', [])
|
||||
return _cache_and_return(RepairResult(True, widget_block, 'none', []))
|
||||
|
||||
# 6. 所有修复都失败,返回原始数据
|
||||
logger.warning("所有修复尝试失败,保持原始数据")
|
||||
return RepairResult(False, widget_block, 'none', [])
|
||||
return _cache_and_return(RepairResult(False, widget_block, 'none', []))
|
||||
|
||||
def repair_locally(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user