364 lines
14 KiB
Python
364 lines
14 KiB
Python
import os
|
||
import time
|
||
import threading
|
||
import logging
|
||
import gc
|
||
import torch
|
||
import numpy as np
|
||
from collections import OrderedDict
|
||
from datetime import datetime, timedelta
|
||
|
||
logger = logging.getLogger('model_manager')
|
||
logger.setLevel(logging.INFO)
|
||
|
||
class ModelManager:
|
||
"""
|
||
模型管理器 - 实现模型预加载和按需卸载技术
|
||
|
||
功能:
|
||
1. 预加载经常使用的模型,减少加载等待时间
|
||
2. 使用LRU (Least Recently Used) 策略管理内存中加载的模型
|
||
3. 支持模型的异步加载和监控
|
||
4. 自动检测并释放长时间未使用的模型内存
|
||
5. 提供模型使用统计
|
||
"""
|
||
|
||
_instance = None
|
||
_lock = threading.Lock()
|
||
|
||
def __new__(cls):
|
||
with cls._lock:
|
||
if cls._instance is None:
|
||
cls._instance = super(ModelManager, cls).__new__(cls)
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
if hasattr(self, 'initialized'):
|
||
return
|
||
|
||
# 已加载模型的缓存,使用OrderedDict实现LRU
|
||
self.loaded_models = OrderedDict()
|
||
# 模型使用统计
|
||
self.model_stats = {}
|
||
# 模型预热配置
|
||
self.preload_config = {}
|
||
# 最大内存占用(GB)
|
||
self.max_memory_usage = float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0'))
|
||
# 模型加载中的锁
|
||
self.loading_locks = {}
|
||
# 模型卸载超时(分钟)
|
||
self.unload_timeout = int(os.getenv('MODEL_UNLOAD_TIMEOUT', '30'))
|
||
|
||
# 启动模型监控线程
|
||
self.monitor_thread = threading.Thread(target=self._monitor_models, daemon=True)
|
||
self.monitor_thread.start()
|
||
|
||
self.initialized = True
|
||
logger.info(f"模型管理器初始化完成,最大内存占用: {self.max_memory_usage}GB")
|
||
|
||
def register_model(self, model_id, model_path, preload=False, model_size_gb=0.5,
|
||
load_function=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
||
"""
|
||
注册模型,可选设置为预加载
|
||
|
||
参数:
|
||
model_id: 模型唯一标识符
|
||
model_path: 模型路径
|
||
preload: 是否预加载
|
||
model_size_gb: 模型估计大小(GB)
|
||
load_function: 自定义加载函数,签名为 load_function(model_path, device) -> model
|
||
device: 加载模型的设备
|
||
"""
|
||
self.preload_config[model_id] = {
|
||
'model_path': model_path,
|
||
'preload': preload,
|
||
'model_size_gb': model_size_gb,
|
||
'load_function': load_function,
|
||
'device': device
|
||
}
|
||
|
||
self.model_stats[model_id] = {
|
||
'load_count': 0,
|
||
'use_count': 0,
|
||
'total_load_time': 0,
|
||
'last_used': None,
|
||
'avg_load_time': 0
|
||
}
|
||
|
||
if preload:
|
||
logger.info(f"模型 {model_id} 已注册并标记为预加载")
|
||
# 启动预加载线程
|
||
threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()
|
||
else:
|
||
logger.info(f"模型 {model_id} 已注册")
|
||
|
||
return True
|
||
|
||
def get_model(self, model_id):
|
||
"""
|
||
获取模型,如果未加载则加载
|
||
|
||
参数:
|
||
model_id: 模型唯一标识符
|
||
|
||
返回:
|
||
加载好的模型对象
|
||
"""
|
||
if model_id not in self.preload_config:
|
||
raise ValueError(f"模型 {model_id} 未注册")
|
||
|
||
# 更新最后使用时间
|
||
self.model_stats[model_id]['last_used'] = datetime.now()
|
||
self.model_stats[model_id]['use_count'] += 1
|
||
|
||
# 检查模型是否已加载
|
||
if model_id in self.loaded_models:
|
||
# 将模型移至OrderedDict末尾,表示最近使用
|
||
model = self.loaded_models.pop(model_id)
|
||
self.loaded_models[model_id] = model
|
||
logger.debug(f"使用已加载的模型: {model_id}")
|
||
return model
|
||
|
||
# 获取模型加载锁,防止并发加载同一模型
|
||
if model_id not in self.loading_locks:
|
||
self.loading_locks[model_id] = threading.Lock()
|
||
|
||
# 加锁加载模型
|
||
with self.loading_locks[model_id]:
|
||
# 再次检查模型是否已被其他线程加载
|
||
if model_id in self.loaded_models:
|
||
return self.loaded_models[model_id]
|
||
|
||
# 检查是否有足够内存
|
||
self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])
|
||
|
||
# 加载模型
|
||
start_time = time.time()
|
||
model = self._load_model(model_id)
|
||
load_time = time.time() - start_time
|
||
|
||
# 更新统计
|
||
self.model_stats[model_id]['load_count'] += 1
|
||
self.model_stats[model_id]['total_load_time'] += load_time
|
||
self.model_stats[model_id]['avg_load_time'] = (
|
||
self.model_stats[model_id]['total_load_time'] /
|
||
self.model_stats[model_id]['load_count']
|
||
)
|
||
|
||
logger.info(f"模型 {model_id} 加载完成,耗时: {load_time:.2f}秒")
|
||
|
||
# 存储模型
|
||
self.loaded_models[model_id] = model
|
||
return model
|
||
|
||
def unload_model(self, model_id):
|
||
"""
|
||
手动卸载模型
|
||
|
||
参数:
|
||
model_id: 模型唯一标识符
|
||
"""
|
||
if model_id in self.loaded_models:
|
||
logger.info(f"手动卸载模型: {model_id}")
|
||
del self.loaded_models[model_id]
|
||
# 强制垃圾回收
|
||
gc.collect()
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
return True
|
||
return False
|
||
|
||
def get_model_stats(self):
|
||
"""获取所有模型的使用统计"""
|
||
result = {}
|
||
for model_id, stats in self.model_stats.items():
|
||
is_loaded = model_id in self.loaded_models
|
||
result[model_id] = {
|
||
**stats,
|
||
'is_loaded': is_loaded,
|
||
'preload': self.preload_config[model_id]['preload'],
|
||
'model_size_gb': self.preload_config[model_id]['model_size_gb'],
|
||
'device': self.preload_config[model_id]['device'],
|
||
}
|
||
return result
|
||
|
||
def preload_all(self):
|
||
"""预加载所有标记为预加载的模型"""
|
||
for model_id, config in self.preload_config.items():
|
||
if config['preload'] and model_id not in self.loaded_models:
|
||
threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()
|
||
|
||
def _preload_model(self, model_id):
|
||
"""预加载单个模型的内部方法"""
|
||
try:
|
||
logger.info(f"开始预加载模型: {model_id}")
|
||
# 确保有足够内存
|
||
self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])
|
||
|
||
# 加载模型
|
||
start_time = time.time()
|
||
model = self._load_model(model_id)
|
||
load_time = time.time() - start_time
|
||
|
||
# 更新统计
|
||
self.model_stats[model_id]['load_count'] += 1
|
||
self.model_stats[model_id]['total_load_time'] += load_time
|
||
self.model_stats[model_id]['avg_load_time'] = (
|
||
self.model_stats[model_id]['total_load_time'] /
|
||
self.model_stats[model_id]['load_count']
|
||
)
|
||
|
||
# 存储模型
|
||
self.loaded_models[model_id] = model
|
||
logger.info(f"模型 {model_id} 预加载完成,耗时: {load_time:.2f}秒")
|
||
|
||
except Exception as e:
|
||
logger.error(f"预加载模型 {model_id} 失败: {e}")
|
||
|
||
def _load_model(self, model_id):
|
||
"""加载模型的内部方法"""
|
||
config = self.preload_config[model_id]
|
||
|
||
if config['load_function'] is not None:
|
||
# 使用自定义加载函数
|
||
return config['load_function'](config['model_path'], config['device'])
|
||
|
||
# 默认加载逻辑 - 根据文件扩展名确定加载方式
|
||
model_path = config['model_path']
|
||
device = config['device']
|
||
|
||
if model_path.endswith('.pt') or model_path.endswith('.pth'):
|
||
# PyTorch模型
|
||
return torch.load(model_path, map_location=device)
|
||
elif model_path.endswith('.pkl'):
|
||
# Pickle模型
|
||
import pickle
|
||
with open(model_path, 'rb') as f:
|
||
return pickle.load(f)
|
||
else:
|
||
# 尝试作为目录加载
|
||
if os.path.isdir(model_path):
|
||
# 如果是目录,尝试加载预训练模型
|
||
try:
|
||
from transformers import AutoModel, AutoTokenizer
|
||
model = AutoModel.from_pretrained(model_path)
|
||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||
return {'model': model.to(device), 'tokenizer': tokenizer}
|
||
except ImportError:
|
||
logger.error("transformers库未安装,无法加载预训练模型")
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"加载预训练模型失败: {e}")
|
||
raise
|
||
|
||
raise ValueError(f"无法确定如何加载模型: {model_path}")
|
||
|
||
def _ensure_memory_available(self, required_gb):
|
||
"""确保有足够的内存来加载新模型"""
|
||
# 如果当前没有加载的模型,直接返回
|
||
if not self.loaded_models:
|
||
return
|
||
|
||
# 计算当前已加载模型的总内存
|
||
current_usage = sum(
|
||
self.preload_config[model_id]['model_size_gb']
|
||
for model_id in self.loaded_models
|
||
)
|
||
|
||
# 如果添加新模型后超过限制,需要卸载一些模型
|
||
while current_usage + required_gb > self.max_memory_usage and self.loaded_models:
|
||
# 卸载最久未使用的模型(OrderedDict的首项)
|
||
oldest_model_id, _ = next(iter(self.loaded_models.items()))
|
||
# 检查是否是预加载且最近使用过的模型
|
||
if (self.preload_config[oldest_model_id]['preload'] and
|
||
self.model_stats[oldest_model_id]['last_used'] and
|
||
(datetime.now() - self.model_stats[oldest_model_id]['last_used']) <
|
||
timedelta(minutes=self.unload_timeout)):
|
||
# 跳过预加载且最近使用过的模型
|
||
# 将该模型移至OrderedDict末尾
|
||
model = self.loaded_models.pop(oldest_model_id)
|
||
self.loaded_models[oldest_model_id] = model
|
||
# 如果所有模型都是预加载的且最近使用过,允许超过限制
|
||
if len(self.loaded_models) <= 1:
|
||
break
|
||
continue
|
||
|
||
# 卸载模型并更新内存使用
|
||
model_size = self.preload_config[oldest_model_id]['model_size_gb']
|
||
del self.loaded_models[oldest_model_id]
|
||
current_usage -= model_size
|
||
logger.info(f"自动卸载模型以释放内存: {oldest_model_id} ({model_size}GB)")
|
||
|
||
# 强制垃圾回收
|
||
gc.collect()
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
|
||
def _monitor_models(self):
|
||
"""监控并管理模型的内部线程方法"""
|
||
while True:
|
||
try:
|
||
# 检查长时间未使用的非预加载模型
|
||
current_time = datetime.now()
|
||
for model_id in list(self.loaded_models.keys()):
|
||
if (not self.preload_config[model_id]['preload'] and
|
||
self.model_stats[model_id]['last_used'] and
|
||
(current_time - self.model_stats[model_id]['last_used']) >
|
||
timedelta(minutes=self.unload_timeout)):
|
||
# 卸载长时间未使用的非预加载模型
|
||
logger.info(f"卸载长时间未使用的模型: {model_id}")
|
||
del self.loaded_models[model_id]
|
||
# 强制垃圾回收
|
||
gc.collect()
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
|
||
# 每5分钟检查一次
|
||
time.sleep(300)
|
||
except Exception as e:
|
||
logger.error(f"模型监控线程出错: {e}")
|
||
time.sleep(300)
|
||
|
||
# 创建全局模型管理器实例
|
||
model_manager = ModelManager()
|
||
|
||
# 注册示例函数
|
||
def register_sentiment_model():
|
||
"""注册情感分析模型示例"""
|
||
from utils.model_loader import load_sentiment_model # 假设您有一个加载情感模型的函数
|
||
|
||
try:
|
||
model_path = os.path.join('model', 'sentiment.marshal.3')
|
||
model_manager.register_model(
|
||
model_id='sentiment_basic',
|
||
model_path=model_path,
|
||
preload=True,
|
||
model_size_gb=0.2,
|
||
load_function=load_sentiment_model
|
||
)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"注册情感分析模型失败: {e}")
|
||
return False
|
||
|
||
def register_bert_model():
|
||
"""注册BERT模型示例"""
|
||
try:
|
||
model_path = os.path.join('model_pro', 'bert_model')
|
||
model_manager.register_model(
|
||
model_id='bert_classifier',
|
||
model_path=model_path,
|
||
preload=True,
|
||
model_size_gb=0.8
|
||
)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"注册BERT模型失败: {e}")
|
||
return False
|
||
|
||
# 自动注册常用模型(在导入时执行)
|
||
try:
|
||
register_sentiment_model()
|
||
register_bert_model()
|
||
except Exception as e:
|
||
logger.error(f"自动注册模型失败: {e}") |