Optimize model loading and prediction performance, implement the singleton pattern, and provide comprehensive error handling and error messages, along with confidence level display.
This commit is contained in:
+21
-35
@@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval
|
||||
from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
|
||||
from collections import Counter # Import Counter for counting occurrences
|
||||
import torch
|
||||
from model_pro.MHA import MultiHeadAttentionLayer
|
||||
from model_pro.classifier import FinalClassifier
|
||||
from model_pro.BERT_CTM import BERT_CTM_Model
|
||||
from BCAT_front.predict import model_manager
|
||||
|
||||
articleList = getAllArticleData() # Retrieve all article data
|
||||
commentList = getAllCommentsData() # Retrieve all comment data
|
||||
@@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data
|
||||
# 设置设备
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载模型(全局变量,避免重复加载)
|
||||
# 设置模型路径
|
||||
model_save_path = 'model_pro/final_model.pt'
|
||||
bert_model_path = 'model_pro/bert_model'
|
||||
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
|
||||
|
||||
# 初始化模型
|
||||
try:
|
||||
classifier_model = torch.load(model_save_path, map_location=device)
|
||||
classifier_model.eval()
|
||||
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
||||
attention_model.to(device)
|
||||
attention_model.eval()
|
||||
bert_ctm_model = BERT_CTM_Model(
|
||||
bert_model_path=bert_model_path,
|
||||
ctm_tokenizer_path=ctm_tokenizer_path
|
||||
)
|
||||
model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
|
||||
except Exception as e:
|
||||
print(f"模型加载失败: {e}")
|
||||
|
||||
def predict_sentiment(texts):
|
||||
"""使用改进版模型预测情感"""
|
||||
try:
|
||||
# 获取文本嵌入
|
||||
embeddings = bert_ctm_model.get_bert_embeddings(texts)
|
||||
|
||||
# 转换为tensor
|
||||
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
|
||||
batch_x = torch.mean(batch_x, dim=1)
|
||||
|
||||
with torch.no_grad():
|
||||
# 使用注意力机制
|
||||
attention_output = attention_model(batch_x, batch_x, batch_x)
|
||||
# 获取分类结果
|
||||
outputs = classifier_model(attention_output)
|
||||
outputs = torch.mean(outputs, dim=1)
|
||||
# 获取预测标签
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
|
||||
return predicted.cpu().numpy()
|
||||
predictions, probabilities = model_manager.predict_batch(texts)
|
||||
if predictions is not None:
|
||||
return predictions, probabilities
|
||||
return None, None
|
||||
except Exception as e:
|
||||
print(f"预测过程中出现错误: {e}")
|
||||
return None
|
||||
return None, None
|
||||
|
||||
def getTypeList():
|
||||
# Return a list of unique article types
|
||||
@@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'):
|
||||
article_sentiments.append('不良')
|
||||
else:
|
||||
# 使用改进模型
|
||||
comment_predictions = predict_sentiment(comment_texts)
|
||||
comment_predictions, comment_probs = predict_sentiment(comment_texts)
|
||||
if comment_predictions is not None:
|
||||
comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions]
|
||||
comment_sentiments = []
|
||||
for pred, prob in zip(comment_predictions, comment_probs):
|
||||
label = '良好' if pred == 0 else '不良'
|
||||
confidence = prob[pred]
|
||||
comment_sentiments.append(f"{label} ({confidence:.2%})")
|
||||
else:
|
||||
comment_sentiments = []
|
||||
|
||||
article_predictions = predict_sentiment(article_texts)
|
||||
article_predictions, article_probs = predict_sentiment(article_texts)
|
||||
if article_predictions is not None:
|
||||
article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions]
|
||||
article_sentiments = []
|
||||
for pred, prob in zip(article_predictions, article_probs):
|
||||
label = '良好' if pred == 0 else '不良'
|
||||
confidence = prob[pred]
|
||||
article_sentiments.append(f"{label} ({confidence:.2%})")
|
||||
else:
|
||||
article_sentiments = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user