Files
bettafish-company/utils/model_loader.py
T

285 lines
9.4 KiB
Python

import os
import sys
import pickle
import marshal
import types
import logging
import torch
import numpy as np
import json
from pathlib import Path
logger = logging.getLogger('model_loader')
logger.setLevel(logging.INFO)
def load_sentiment_model(model_path, device=None):
"""
加载情感分析模型
参数:
model_path: 模型文件路径
device: 设备(可忽略,marshal模型不依赖设备)
返回:
加载好的模型对象
"""
try:
logger.info(f"加载情感分析模型: {model_path}")
if model_path.endswith('.marshal') or model_path.endswith('.marshal.3'):
with open(model_path, 'rb') as f:
model_data = marshal.load(f)
# 将marshal数据转换为可调用的函数对象
sentiment_func = types.FunctionType(model_data, globals(), "sentiment_func")
logger.info("情感分析模型加载成功")
return sentiment_func
else:
raise ValueError(f"不支持的情感模型格式: {model_path}")
except Exception as e:
logger.error(f"加载情感分析模型失败: {e}")
raise
def load_bert_ctm_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
"""
加载BERT-CTM模型
参数:
model_dir: 模型目录
device: 计算设备
返回:
包含模型和分词器的字典
"""
try:
logger.info(f"加载BERT-CTM模型: {model_dir}")
sys.path.append('model_pro')
from BERT_CTM import BERT_CTM
from transformers import BertTokenizer
# 加载模型
model_path = os.path.join(model_dir, 'final_model.pt') if not model_dir.endswith('.pt') else model_dir
model = BERT_CTM()
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
# 加载分词器
tokenizer_path = os.path.join(os.path.dirname(model_dir), 'bert_model')
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
logger.info("BERT-CTM模型加载成功")
return {
'model': model,
'tokenizer': tokenizer,
'device': device
}
except Exception as e:
logger.error(f"加载BERT-CTM模型失败: {e}")
raise
def load_bcat_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
"""
加载BCAT模型
参数:
model_dir: 模型目录
device: 计算设备
返回:
包含模型和分词器的字典
"""
try:
logger.info(f"加载BCAT模型: {model_dir}")
sys.path.append('model_pro')
from BCAT import BCAT
from transformers import BertTokenizer
# 加载模型配置
config_path = os.path.join(model_dir, 'config.json')
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
# 初始化模型
model = BCAT(**config)
# 加载模型权重
model_path = os.path.join(model_dir, 'model.pt')
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
# 加载分词器
tokenizer_path = os.path.join(model_dir, 'tokenizer')
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
logger.info("BCAT模型加载成功")
return {
'model': model,
'tokenizer': tokenizer,
'device': device,
'config': config
}
except Exception as e:
logger.error(f"加载BCAT模型失败: {e}")
raise
def load_topic_classifier(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
"""
加载话题分类模型
参数:
model_dir: 模型目录
device: 计算设备
返回:
包含模型、分词器和标签映射的字典
"""
try:
logger.info(f"加载话题分类模型: {model_dir}")
# 尝试加载transformers模型
try:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# 加载模型
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
model.to(device)
model.eval()
# 加载分词器
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# 加载标签映射
labels_path = os.path.join(model_dir, 'labels.json')
if os.path.exists(labels_path):
with open(labels_path, 'r', encoding='utf-8') as f:
labels_map = json.load(f)
else:
# 尝试从config中读取标签
if hasattr(model.config, 'id2label'):
labels_map = model.config.id2label
else:
labels_map = {}
logger.info("话题分类模型加载成功 (transformers)")
return {
'model': model,
'tokenizer': tokenizer,
'labels_map': labels_map,
'device': device
}
except Exception as e:
logger.warning(f"使用transformers加载失败,尝试其他方法: {e}")
# 尝试加载PyTorch模型
model_path = os.path.join(model_dir, 'model.pt')
if os.path.exists(model_path):
model = torch.load(model_path, map_location=device)
# 加载分词器
tokenizer_path = os.path.join(model_dir, 'tokenizer.pkl')
if os.path.exists(tokenizer_path):
with open(tokenizer_path, 'rb') as f:
tokenizer = pickle.load(f)
else:
tokenizer = None
# 加载标签映射
labels_path = os.path.join(model_dir, 'labels.json')
if os.path.exists(labels_path):
with open(labels_path, 'r', encoding='utf-8') as f:
labels_map = json.load(f)
else:
labels_map = {}
logger.info("话题分类模型加载成功 (PyTorch)")
return {
'model': model,
'tokenizer': tokenizer,
'labels_map': labels_map,
'device': device
}
raise ValueError(f"无法加载模型: {model_dir}")
except Exception as e:
logger.error(f"加载话题分类模型失败: {e}")
raise
def load_echarts_optimizer():
"""
加载ECharts优化器,用于提升大数据渲染性能
返回:
ECharts优化器对象
"""
try:
class EChartsOptimizer:
def __init__(self):
self.chunk_size = 1000 # 分块大小
logger.info("ECharts优化器初始化成功")
def optimize_option(self, option):
"""优化ECharts配置,提升大数据渲染性能"""
if not option:
return option
# 深拷贝以避免修改原始对象
import copy
option = copy.deepcopy(option)
# 添加渐进式渲染
if 'progressive' not in option:
option['progressive'] = 300 # 每帧渲染的数据点数量
if 'progressiveThreshold' not in option:
option['progressiveThreshold'] = 5000 # 启动渐进式渲染的阈值
if 'series' in option and isinstance(option['series'], list):
for series in option['series']:
# 对大数据系列应用优化
if 'data' in series and isinstance(series['data'], list) and len(series['data']) > 5000:
# 大数据采样
if series.get('type') in ['scatter', 'line']:
self._optimize_large_data_series(series)
return option
def _optimize_large_data_series(self, series):
"""优化大数据系列"""
# 添加大数据优化选项
series['large'] = True
series['largeThreshold'] = 2000
# 按需设置抽样
if len(series['data']) > 50000:
# 对非常大的数据集进行抽样
step = max(1, len(series['data']) // 50000)
series['data'] = series['data'][::step]
series['sampling'] = 'average'
return series
def chunk_process_data(self, data, process_func):
"""分块处理大数据"""
result = []
for i in range(0, len(data), self.chunk_size):
chunk = data[i:i + self.chunk_size]
result.extend(process_func(chunk))
return result
return EChartsOptimizer()
except Exception as e:
logger.error(f"加载ECharts优化器失败: {e}")
return None
# 导出所有加载函数
__all__ = [
'load_sentiment_model',
'load_bert_ctm_model',
'load_bcat_model',
'load_topic_classifier',
'load_echarts_optimizer'
]