import os import time import threading import logging import gc import torch import numpy as np from collections import OrderedDict from datetime import datetime, timedelta logger = logging.getLogger('model_manager') logger.setLevel(logging.INFO) class ModelManager: """ 模型管理器 - 实现模型预加载和按需卸载技术 功能: 1. 预加载经常使用的模型,减少加载等待时间 2. 使用LRU (Least Recently Used) 策略管理内存中加载的模型 3. 支持模型的异步加载和监控 4. 自动检测并释放长时间未使用的模型内存 5. 提供模型使用统计 """ _instance = None _lock = threading.Lock() def __new__(cls): with cls._lock: if cls._instance is None: cls._instance = super(ModelManager, cls).__new__(cls) return cls._instance def __init__(self): if hasattr(self, 'initialized'): return # 已加载模型的缓存,使用OrderedDict实现LRU self.loaded_models = OrderedDict() # 模型使用统计 self.model_stats = {} # 模型预热配置 self.preload_config = {} # 最大内存占用(GB) self.max_memory_usage = float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0')) # 模型加载中的锁 self.loading_locks = {} # 模型卸载超时(分钟) self.unload_timeout = int(os.getenv('MODEL_UNLOAD_TIMEOUT', '30')) # 启动模型监控线程 self.monitor_thread = threading.Thread(target=self._monitor_models, daemon=True) self.monitor_thread.start() self.initialized = True logger.info(f"模型管理器初始化完成,最大内存占用: {self.max_memory_usage}GB") def register_model(self, model_id, model_path, preload=False, model_size_gb=0.5, load_function=None, device='cuda' if torch.cuda.is_available() else 'cpu'): """ 注册模型,可选设置为预加载 参数: model_id: 模型唯一标识符 model_path: 模型路径 preload: 是否预加载 model_size_gb: 模型估计大小(GB) load_function: 自定义加载函数,签名为 load_function(model_path, device) -> model device: 加载模型的设备 """ self.preload_config[model_id] = { 'model_path': model_path, 'preload': preload, 'model_size_gb': model_size_gb, 'load_function': load_function, 'device': device } self.model_stats[model_id] = { 'load_count': 0, 'use_count': 0, 'total_load_time': 0, 'last_used': None, 'avg_load_time': 0 } if preload: logger.info(f"模型 {model_id} 已注册并标记为预加载") # 启动预加载线程 threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start() else: logger.info(f"模型 {model_id} 已注册") return True def get_model(self, model_id): """ 获取模型,如果未加载则加载 参数: model_id: 模型唯一标识符 返回: 加载好的模型对象 """ if model_id not in self.preload_config: raise ValueError(f"模型 {model_id} 未注册") # 更新最后使用时间 self.model_stats[model_id]['last_used'] = datetime.now() self.model_stats[model_id]['use_count'] += 1 # 检查模型是否已加载 if model_id in self.loaded_models: # 将模型移至OrderedDict末尾,表示最近使用 model = self.loaded_models.pop(model_id) self.loaded_models[model_id] = model logger.debug(f"使用已加载的模型: {model_id}") return model # 获取模型加载锁,防止并发加载同一模型 if model_id not in self.loading_locks: self.loading_locks[model_id] = threading.Lock() # 加锁加载模型 with self.loading_locks[model_id]: # 再次检查模型是否已被其他线程加载 if model_id in self.loaded_models: return self.loaded_models[model_id] # 检查是否有足够内存 self._ensure_memory_available(self.preload_config[model_id]['model_size_gb']) # 加载模型 start_time = time.time() model = self._load_model(model_id) load_time = time.time() - start_time # 更新统计 self.model_stats[model_id]['load_count'] += 1 self.model_stats[model_id]['total_load_time'] += load_time self.model_stats[model_id]['avg_load_time'] = ( self.model_stats[model_id]['total_load_time'] / self.model_stats[model_id]['load_count'] ) logger.info(f"模型 {model_id} 加载完成,耗时: {load_time:.2f}秒") # 存储模型 self.loaded_models[model_id] = model return model def unload_model(self, model_id): """ 手动卸载模型 参数: model_id: 模型唯一标识符 """ if model_id in self.loaded_models: logger.info(f"手动卸载模型: {model_id}") del self.loaded_models[model_id] # 强制垃圾回收 gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return True return False def get_model_stats(self): """获取所有模型的使用统计""" result = {} for model_id, stats in self.model_stats.items(): is_loaded = model_id in self.loaded_models result[model_id] = { **stats, 'is_loaded': is_loaded, 'preload': self.preload_config[model_id]['preload'], 'model_size_gb': self.preload_config[model_id]['model_size_gb'], 'device': self.preload_config[model_id]['device'], } return result def preload_all(self): """预加载所有标记为预加载的模型""" for model_id, config in self.preload_config.items(): if config['preload'] and model_id not in self.loaded_models: threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start() def _preload_model(self, model_id): """预加载单个模型的内部方法""" try: logger.info(f"开始预加载模型: {model_id}") # 确保有足够内存 self._ensure_memory_available(self.preload_config[model_id]['model_size_gb']) # 加载模型 start_time = time.time() model = self._load_model(model_id) load_time = time.time() - start_time # 更新统计 self.model_stats[model_id]['load_count'] += 1 self.model_stats[model_id]['total_load_time'] += load_time self.model_stats[model_id]['avg_load_time'] = ( self.model_stats[model_id]['total_load_time'] / self.model_stats[model_id]['load_count'] ) # 存储模型 self.loaded_models[model_id] = model logger.info(f"模型 {model_id} 预加载完成,耗时: {load_time:.2f}秒") except Exception as e: logger.error(f"预加载模型 {model_id} 失败: {e}") def _load_model(self, model_id): """加载模型的内部方法""" config = self.preload_config[model_id] if config['load_function'] is not None: # 使用自定义加载函数 return config['load_function'](config['model_path'], config['device']) # 默认加载逻辑 - 根据文件扩展名确定加载方式 model_path = config['model_path'] device = config['device'] if model_path.endswith('.pt') or model_path.endswith('.pth'): # PyTorch模型 return torch.load(model_path, map_location=device) elif model_path.endswith('.pkl'): # Pickle模型 import pickle with open(model_path, 'rb') as f: return pickle.load(f) else: # 尝试作为目录加载 if os.path.isdir(model_path): # 如果是目录,尝试加载预训练模型 try: from transformers import AutoModel, AutoTokenizer model = AutoModel.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path) return {'model': model.to(device), 'tokenizer': tokenizer} except ImportError: logger.error("transformers库未安装,无法加载预训练模型") raise except Exception as e: logger.error(f"加载预训练模型失败: {e}") raise raise ValueError(f"无法确定如何加载模型: {model_path}") def _ensure_memory_available(self, required_gb): """确保有足够的内存来加载新模型""" # 如果当前没有加载的模型,直接返回 if not self.loaded_models: return # 计算当前已加载模型的总内存 current_usage = sum( self.preload_config[model_id]['model_size_gb'] for model_id in self.loaded_models ) # 如果添加新模型后超过限制,需要卸载一些模型 while current_usage + required_gb > self.max_memory_usage and self.loaded_models: # 卸载最久未使用的模型(OrderedDict的首项) oldest_model_id, _ = next(iter(self.loaded_models.items())) # 检查是否是预加载且最近使用过的模型 if (self.preload_config[oldest_model_id]['preload'] and self.model_stats[oldest_model_id]['last_used'] and (datetime.now() - self.model_stats[oldest_model_id]['last_used']) < timedelta(minutes=self.unload_timeout)): # 跳过预加载且最近使用过的模型 # 将该模型移至OrderedDict末尾 model = self.loaded_models.pop(oldest_model_id) self.loaded_models[oldest_model_id] = model # 如果所有模型都是预加载的且最近使用过,允许超过限制 if len(self.loaded_models) <= 1: break continue # 卸载模型并更新内存使用 model_size = self.preload_config[oldest_model_id]['model_size_gb'] del self.loaded_models[oldest_model_id] current_usage -= model_size logger.info(f"自动卸载模型以释放内存: {oldest_model_id} ({model_size}GB)") # 强制垃圾回收 gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def _monitor_models(self): """监控并管理模型的内部线程方法""" while True: try: # 检查长时间未使用的非预加载模型 current_time = datetime.now() for model_id in list(self.loaded_models.keys()): if (not self.preload_config[model_id]['preload'] and self.model_stats[model_id]['last_used'] and (current_time - self.model_stats[model_id]['last_used']) > timedelta(minutes=self.unload_timeout)): # 卸载长时间未使用的非预加载模型 logger.info(f"卸载长时间未使用的模型: {model_id}") del self.loaded_models[model_id] # 强制垃圾回收 gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # 每5分钟检查一次 time.sleep(300) except Exception as e: logger.error(f"模型监控线程出错: {e}") time.sleep(300) # 创建全局模型管理器实例 model_manager = ModelManager() # 注册示例函数 def register_sentiment_model(): """注册情感分析模型示例""" from utils.model_loader import load_sentiment_model # 假设您有一个加载情感模型的函数 try: model_path = os.path.join('model', 'sentiment.marshal.3') model_manager.register_model( model_id='sentiment_basic', model_path=model_path, preload=True, model_size_gb=0.2, load_function=load_sentiment_model ) return True except Exception as e: logger.error(f"注册情感分析模型失败: {e}") return False def register_bert_model(): """注册BERT模型示例""" try: model_path = os.path.join('model_pro', 'bert_model') model_manager.register_model( model_id='bert_classifier', model_path=model_path, preload=True, model_size_gb=0.8 ) return True except Exception as e: logger.error(f"注册BERT模型失败: {e}") return False # 自动注册常用模型(在导入时执行) try: register_sentiment_model() register_bert_model() except Exception as e: logger.error(f"自动注册模型失败: {e}")