From 8249ec004828630456ca34c5d0ccc38198ea3d14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Sun, 9 Feb 2025 22:06:46 +0800 Subject: [PATCH] Implement a two-level caching system (memory + disk) to optimize topic switch response speed, support asynchronous writing, and automatically clean up expired data. --- utils/cache_manager.py | 116 +++++++++++++++++++++++++++++++++++++++++ views/page/page.py | 43 +++++++++------ 2 files changed, 143 insertions(+), 16 deletions(-) create mode 100644 utils/cache_manager.py diff --git a/utils/cache_manager.py b/utils/cache_manager.py new file mode 100644 index 0000000..4955a5d --- /dev/null +++ b/utils/cache_manager.py @@ -0,0 +1,116 @@ +import json +import os +import time +from datetime import datetime, timedelta +import threading +import queue + +class PredictionCache: + _instance = None + _lock = threading.Lock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super(PredictionCache, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, 'initialized'): + self.cache_dir = 'cache/predictions' + self.cache_duration = timedelta(hours=24) # 缓存24小时 + self.cache = {} + self.cache_queue = queue.Queue() + self.initialized = True + + # 确保缓存目录存在 + os.makedirs(self.cache_dir, exist_ok=True) + + # 启动缓存清理线程 + self.cleanup_thread = threading.Thread(target=self._cleanup_old_cache, daemon=True) + self.cleanup_thread.start() + + # 加载现有缓存 + self._load_cache() + + def _load_cache(self): + """加载磁盘上的缓存文件""" + try: + for filename in os.listdir(self.cache_dir): + if filename.endswith('.json'): + filepath = os.path.join(self.cache_dir, filename) + with open(filepath, 'r', encoding='utf-8') as f: + cache_data = json.load(f) + # 检查缓存是否过期 + if self._is_cache_valid(cache_data['timestamp']): + topic = filename[:-5] # 移除.json后缀 + self.cache[topic] = cache_data + else: + # 删除过期缓存文件 + os.remove(filepath) + except Exception as e: + print(f"加载缓存失败: {e}") + + def _cleanup_old_cache(self): + """定期清理过期缓存的后台线程""" + while True: + try: + # 检查并清理内存缓存 + current_time = datetime.now() + expired_topics = [] + + for topic, cache_data in self.cache.items(): + if not self._is_cache_valid(cache_data['timestamp']): + expired_topics.append(topic) + + # 删除过期缓存 + for topic in expired_topics: + del self.cache[topic] + cache_file = os.path.join(self.cache_dir, f"{topic}.json") + if os.path.exists(cache_file): + os.remove(cache_file) + + # 休眠1小时后再次检查 + time.sleep(3600) + except Exception as e: + print(f"清理缓存时出错: {e}") + time.sleep(3600) # 发生错误时也等待1小时 + + def _is_cache_valid(self, timestamp): + """检查缓存是否有效""" + cache_time = datetime.fromtimestamp(timestamp) + return datetime.now() - cache_time < self.cache_duration + + def get(self, topic): + """获取话题的预测缓存""" + if topic in self.cache and self._is_cache_valid(self.cache[topic]['timestamp']): + return self.cache[topic]['prediction'] + return None + + def set(self, topic, prediction): + """设置话题的预测缓存""" + cache_data = { + 'prediction': prediction, + 'timestamp': datetime.now().timestamp() + } + + # 更新内存缓存 + self.cache[topic] = cache_data + + # 异步保存到磁盘 + self.cache_queue.put((topic, cache_data)) + threading.Thread(target=self._save_cache_to_disk, daemon=True).start() + + def _save_cache_to_disk(self): + """异步保存缓存到磁盘""" + try: + while not self.cache_queue.empty(): + topic, cache_data = self.cache_queue.get() + cache_file = os.path.join(self.cache_dir, f"{topic}.json") + with open(cache_file, 'w', encoding='utf-8') as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"保存缓存到磁盘失败: {e}") + +# 创建全局缓存实例 +prediction_cache = PredictionCache() \ No newline at end of file diff --git a/views/page/page.py b/views/page/page.py index c788e87..7cd95a0 100644 --- a/views/page/page.py +++ b/views/page/page.py @@ -8,6 +8,7 @@ from utils.getEchartsData import * from utils.getTopicPageData import * from utils.yuqingpredict import * from utils.logger import app_logger as logging +from utils.cache_manager import prediction_cache import torch from BCAT_front.predict import model_manager @@ -207,24 +208,34 @@ def yuqingpredict(): # 获取模型选择参数 model_type = request.args.get('model', 'pro') # 默认使用改进模型 - if model_type == 'basic': - # 使用基础模型(SnowNLP) - value = SnowNLP(defaultTopic).sentiments - if value == 0.5: - sentences = '中性' - elif value > 0.5: - sentences = '正面' - elif value < 0.5: - sentences = '负面' + # 尝试从缓存获取预测结果 + cache_key = f"{defaultTopic}_{model_type}" + cached_result = prediction_cache.get(cache_key) + + if cached_result is not None: + sentences = cached_result else: - # 使用改进模型 - predicted_label, confidence = predict_sentiment(defaultTopic) - if predicted_label is not None: - sentences = '良好' if predicted_label == 0 else '不良' - sentences = f"{sentences} (置信度: {confidence:.2%})" + if model_type == 'basic': + # 使用基础模型(SnowNLP) + value = SnowNLP(defaultTopic).sentiments + if value == 0.5: + sentences = '中性' + elif value > 0.5: + sentences = '正面' + elif value < 0.5: + sentences = '负面' else: - sentences = '预测失败,请稍后重试' - logging.error(f"预测失败,话题: {defaultTopic}") + # 使用改进模型 + predicted_label, confidence = predict_sentiment(defaultTopic) + if predicted_label is not None: + sentences = '良好' if predicted_label == 0 else '不良' + sentences = f"{sentences} (置信度: {confidence:.2%})" + else: + sentences = '预测失败,请稍后重试' + logging.error(f"预测失败,话题: {defaultTopic}") + + # 将结果存入缓存 + prediction_cache.set(cache_key, sentences) comments = getCommentFilterDataTopic(defaultTopic) return render_template('yuqingpredict.html',