Optimize model loading and prediction performance, implement the singleton pattern, and provide comprehensive error handling and error messages, along with confidence level display.
This commit is contained in:
+81
-11
@@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer
|
||||
from model_pro.classifier import FinalClassifier
|
||||
from model_pro.BERT_CTM import BERT_CTM_Model
|
||||
|
||||
# 设置设备
|
||||
# Module-level compute device: CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
class ModelManager:
    """Single shared holder for the models used at prediction time.

    ``ModelManager()`` always returns the same object, so models loaded once
    through :meth:`load_models` are reused by every later caller instead of
    being reloaded per prediction.

    Attributes:
        device: ``torch.device`` the models are placed on (CUDA if available,
            otherwise CPU).
        classifier_model: Object returned by ``torch.load`` for the classifier
            checkpoint, or ``None`` until loaded.
        attention_model: ``MultiHeadAttentionLayer`` instance, or ``None``
            until loaded.
        bert_ctm_model: ``BERT_CTM_Model`` instance, or ``None`` until loaded.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Create the instance on first use only; afterwards hand back the same one.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ``ModelManager()`` call, even though __new__
        # returns the existing object; bail out so cached models are not reset.
        if self._initialized:
            return
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier_model = None
        self.attention_model = None
        self.bert_ctm_model = None
        self._initialized = True

    def load_models(self, model_save_path, bert_model_path, ctm_tokenizer_path):
        """Load every model that is not already cached on this manager.

        Each model is fully prepared in a local variable and only stored on
        the manager once preparation succeeded, so a failed attempt never
        leaves a half-initialised object cached for later calls.

        Args:
            model_save_path: Path of the classifier checkpoint given to
                ``torch.load``.
            bert_model_path: Forwarded to ``BERT_CTM_Model`` as
                ``bert_model_path``.
            ctm_tokenizer_path: Forwarded to ``BERT_CTM_Model`` as
                ``ctm_tokenizer_path``.

        Returns:
            ``True`` when all three models are available afterwards, ``False``
            if any step raised (the error is printed, not re-raised).

        Note:
            Models that are already cached are kept as-is; the paths passed on
            later calls are ignored for them.
        """
        try:
            if self.classifier_model is None:
                # NOTE(security): torch.load unpickles the checkpoint, which can
                # execute arbitrary code; only load files from a trusted source.
                # NOTE(review): recent PyTorch releases default to
                # weights_only=True, which rejects fully pickled modules —
                # confirm against the torch version in use.
                classifier = torch.load(model_save_path, map_location=self.device)
                classifier.eval()
                self.classifier_model = classifier

            if self.attention_model is None:
                # NOTE(review): the layer is constructed here without loading a
                # checkpoint in this method — confirm its weights are meant to
                # come from the constructor.
                attention = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
                attention.to(self.device)
                attention.eval()
                self.attention_model = attention

            if self.bert_ctm_model is None:
                self.bert_ctm_model = BERT_CTM_Model(
                    bert_model_path=bert_model_path,
                    ctm_tokenizer_path=ctm_tokenizer_path,
                )
            return True
        except Exception as e:
            print(f"模型加载失败: {e}")
            return False

    def predict_batch(self, texts, batch_size=32):
        """Predict labels and class probabilities for ``texts`` in batches.

        Args:
            texts: Sliceable sequence of input texts.
            batch_size: Number of texts embedded and classified per step.

        Returns:
            ``(predictions, probabilities)``: two lists with one entry per
            input text, built from the per-batch NumPy results. Returns
            ``(None, None)`` if the models are not loaded or any step raised
            (the error is printed, not re-raised).
        """
        if (
            self.classifier_model is None
            or self.attention_model is None
            or self.bert_ctm_model is None
        ):
            # Report the real cause instead of an AttributeError on None.
            print("预测过程中出现错误: 模型尚未加载，请先调用 load_models")
            return None, None

        try:
            all_predictions = []
            all_probabilities = []

            # Process the texts one slice at a time.
            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]

                # Embed the texts with the BERT/CTM model.
                embeddings = self.bert_ctm_model.get_bert_embeddings(batch_texts)

                # Move to the manager's device as float32, then average over
                # dim 1 (presumably the token axis — TODO confirm).
                batch_x = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
                batch_x = torch.mean(batch_x, dim=1)

                with torch.no_grad():
                    # Self-attention: the same tensor is query, key and value.
                    attention_output = self.attention_model(batch_x, batch_x, batch_x)
                    outputs = self.classifier_model(attention_output)
                    outputs = torch.mean(outputs, dim=1)
                    # Class probabilities and the index of the largest score.
                    probabilities = torch.softmax(outputs, dim=1)
                    _, predicted = torch.max(outputs, 1)

                all_predictions.extend(predicted.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())

            return all_predictions, all_probabilities
        except Exception as e:
            print(f"预测过程中出现错误: {e}")
            return None, None
|
||||
|
||||
# 创建全局的模型管理器实例
|
||||
# Shared module-level manager; ModelManager() returns the same object on every call.
model_manager = ModelManager()
|
||||
|
||||
def detect_file_encoding(file_path, num_bytes=10000):
|
||||
"""
|
||||
@@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
|
||||
try:
|
||||
# 加载模型
|
||||
print("加载模型...")
|
||||
classifier_model = torch.load(model_save_path, map_location=device)
|
||||
classifier_model.eval()
|
||||
|
||||
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
||||
attention_model.to(device)
|
||||
attention_model.eval()
|
||||
if not model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path):
|
||||
return False
|
||||
|
||||
# 检测文件编码
|
||||
encoding = detect_file_encoding(input_data_path)
|
||||
@@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
|
||||
print("开始预测...")
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(data_loader, desc="预测进度"):
|
||||
batch_x = batch[0].to(device)
|
||||
batch_x = batch[0].to(model_manager.device)
|
||||
batch_x = torch.mean(batch_x, dim=1)
|
||||
|
||||
# 使用注意力机制
|
||||
attention_output = attention_model(batch_x, batch_x, batch_x)
|
||||
attention_output = model_manager.attention_model(batch_x, batch_x, batch_x)
|
||||
|
||||
# 获取分类结果
|
||||
outputs = classifier_model(attention_output)
|
||||
outputs = model_manager.classifier_model(attention_output)
|
||||
outputs = torch.mean(outputs, dim=1)
|
||||
|
||||
# 获取预测概率
|
||||
|
||||
Reference in New Issue
Block a user