Optimize model loading and prediction performance, implement the singleton pattern, and provide comprehensive error handling and error messages, along with confidence level display.

This commit is contained in:
戒酒的李白
2025-02-08 23:00:11 +08:00
parent 1707c2c3de
commit 607db7317e
3 changed files with 157 additions and 119 deletions
+81 -11
View File
@@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer
from model_pro.classifier import FinalClassifier from model_pro.classifier import FinalClassifier
from model_pro.BERT_CTM import BERT_CTM_Model from model_pro.BERT_CTM import BERT_CTM_Model
# 设置设备 class ModelManager:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") _instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super(ModelManager, cls).__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.classifier_model = None
self.attention_model = None
self.bert_ctm_model = None
self._initialized = True
def load_models(self, model_save_path, bert_model_path, ctm_tokenizer_path):
"""加载所有需要的模型"""
try:
if self.classifier_model is None:
self.classifier_model = torch.load(model_save_path, map_location=self.device)
self.classifier_model.eval()
if self.attention_model is None:
self.attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
self.attention_model.to(self.device)
self.attention_model.eval()
if self.bert_ctm_model is None:
self.bert_ctm_model = BERT_CTM_Model(
bert_model_path=bert_model_path,
ctm_tokenizer_path=ctm_tokenizer_path
)
return True
except Exception as e:
print(f"模型加载失败: {e}")
return False
def predict_batch(self, texts, batch_size=32):
"""批量预测文本情感"""
try:
all_predictions = []
all_probabilities = []
# 分批处理文本
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
# 获取文本嵌入
embeddings = self.bert_ctm_model.get_bert_embeddings(batch_texts)
# 转换为tensor
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
batch_x = torch.mean(batch_x, dim=1)
with torch.no_grad():
# 使用注意力机制
attention_output = self.attention_model(batch_x, batch_x, batch_x)
# 获取分类结果
outputs = self.classifier_model(attention_output)
outputs = torch.mean(outputs, dim=1)
# 获取预测概率
probabilities = torch.softmax(outputs, dim=1)
# 获取预测标签
_, predicted = torch.max(outputs, 1)
all_predictions.extend(predicted.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
return all_predictions, all_probabilities
except Exception as e:
print(f"预测过程中出现错误: {e}")
return None, None
# 创建全局的模型管理器实例
model_manager = ModelManager()
def detect_file_encoding(file_path, num_bytes=10000): def detect_file_encoding(file_path, num_bytes=10000):
""" """
@@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
try: try:
# 加载模型 # 加载模型
print("加载模型...") print("加载模型...")
classifier_model = torch.load(model_save_path, map_location=device) if not model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path):
classifier_model.eval() return False
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
attention_model.to(device)
attention_model.eval()
# 检测文件编码 # 检测文件编码
encoding = detect_file_encoding(input_data_path) encoding = detect_file_encoding(input_data_path)
@@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
print("开始预测...") print("开始预测...")
with torch.no_grad(): with torch.no_grad():
for batch in tqdm(data_loader, desc="预测进度"): for batch in tqdm(data_loader, desc="预测进度"):
batch_x = batch[0].to(device) batch_x = batch[0].to(model_manager.device)
batch_x = torch.mean(batch_x, dim=1) batch_x = torch.mean(batch_x, dim=1)
# 使用注意力机制 # 使用注意力机制
attention_output = attention_model(batch_x, batch_x, batch_x) attention_output = model_manager.attention_model(batch_x, batch_x, batch_x)
# 获取分类结果 # 获取分类结果
outputs = classifier_model(attention_output) outputs = model_manager.classifier_model(attention_output)
outputs = torch.mean(outputs, dim=1) outputs = torch.mean(outputs, dim=1)
# 获取预测概率 # 获取预测概率
+21 -35
View File
@@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval
from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
from collections import Counter # Import Counter for counting occurrences from collections import Counter # Import Counter for counting occurrences
import torch import torch
from model_pro.MHA import MultiHeadAttentionLayer from BCAT_front.predict import model_manager
from model_pro.classifier import FinalClassifier
from model_pro.BERT_CTM import BERT_CTM_Model
articleList = getAllArticleData() # Retrieve all article data articleList = getAllArticleData() # Retrieve all article data
commentList = getAllCommentsData() # Retrieve all comment data commentList = getAllCommentsData() # Retrieve all comment data
@@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data
# 设置设备 # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载模型(全局变量,避免重复加载) # 设置模型路径
model_save_path = 'model_pro/final_model.pt' model_save_path = 'model_pro/final_model.pt'
bert_model_path = 'model_pro/bert_model' bert_model_path = 'model_pro/bert_model'
ctm_tokenizer_path = 'model_pro/sentence_bert_model' ctm_tokenizer_path = 'model_pro/sentence_bert_model'
# 初始化模型
try: try:
classifier_model = torch.load(model_save_path, map_location=device) model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
classifier_model.eval()
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
attention_model.to(device)
attention_model.eval()
bert_ctm_model = BERT_CTM_Model(
bert_model_path=bert_model_path,
ctm_tokenizer_path=ctm_tokenizer_path
)
except Exception as e: except Exception as e:
print(f"模型加载失败: {e}") print(f"模型加载失败: {e}")
def predict_sentiment(texts): def predict_sentiment(texts):
"""使用改进版模型预测情感""" """使用改进版模型预测情感"""
try: try:
# 获取文本嵌入 predictions, probabilities = model_manager.predict_batch(texts)
embeddings = bert_ctm_model.get_bert_embeddings(texts) if predictions is not None:
return predictions, probabilities
# 转换为tensor return None, None
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
batch_x = torch.mean(batch_x, dim=1)
with torch.no_grad():
# 使用注意力机制
attention_output = attention_model(batch_x, batch_x, batch_x)
# 获取分类结果
outputs = classifier_model(attention_output)
outputs = torch.mean(outputs, dim=1)
# 获取预测标签
_, predicted = torch.max(outputs, 1)
return predicted.cpu().numpy()
except Exception as e: except Exception as e:
print(f"预测过程中出现错误: {e}") print(f"预测过程中出现错误: {e}")
return None return None, None
def getTypeList(): def getTypeList():
# Return a list of unique article types # Return a list of unique article types
@@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'):
article_sentiments.append('不良') article_sentiments.append('不良')
else: else:
# 使用改进模型 # 使用改进模型
comment_predictions = predict_sentiment(comment_texts) comment_predictions, comment_probs = predict_sentiment(comment_texts)
if comment_predictions is not None: if comment_predictions is not None:
comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions] comment_sentiments = []
for pred, prob in zip(comment_predictions, comment_probs):
label = '良好' if pred == 0 else '不良'
confidence = prob[pred]
comment_sentiments.append(f"{label} ({confidence:.2%})")
else: else:
comment_sentiments = [] comment_sentiments = []
article_predictions = predict_sentiment(article_texts) article_predictions, article_probs = predict_sentiment(article_texts)
if article_predictions is not None: if article_predictions is not None:
article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions] article_sentiments = []
for pred, prob in zip(article_predictions, article_probs):
label = '良好' if pred == 0 else '不良'
confidence = prob[pred]
article_sentiments.append(f"{label} ({confidence:.2%})")
else: else:
article_sentiments = [] article_sentiments = []
+55 -73
View File
@@ -1,4 +1,4 @@
from flask import Flask, session, render_template, redirect, Blueprint, request from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify
from utils.mynlp import SnowNLP from utils.mynlp import SnowNLP
from utils.getHomePageData import * from utils.getHomePageData import *
from utils.getHotWordPageData import * from utils.getHotWordPageData import *
@@ -9,9 +9,7 @@ from utils.getTopicPageData import *
from utils.yuqingpredict import * from utils.yuqingpredict import *
from utils.logger import app_logger as logging from utils.logger import app_logger as logging
import torch import torch
from model_pro.MHA import MultiHeadAttentionLayer from BCAT_front.predict import model_manager
from model_pro.classifier import FinalClassifier
from model_pro.BERT_CTM import BERT_CTM_Model
pb = Blueprint('page', pb = Blueprint('page',
__name__, __name__,
@@ -21,47 +19,26 @@ pb = Blueprint('page',
# 设置设备 # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载模型(全局变量,避免重复加载) # 设置模型路径
model_save_path = 'model_pro/final_model.pt' model_save_path = 'model_pro/final_model.pt'
bert_model_path = 'model_pro/bert_model' bert_model_path = 'model_pro/bert_model'
ctm_tokenizer_path = 'model_pro/sentence_bert_model' ctm_tokenizer_path = 'model_pro/sentence_bert_model'
# 初始化模型
try: try:
classifier_model = torch.load(model_save_path, map_location=device) model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
classifier_model.eval()
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
attention_model.to(device)
attention_model.eval()
bert_ctm_model = BERT_CTM_Model(
bert_model_path=bert_model_path,
ctm_tokenizer_path=ctm_tokenizer_path
)
except Exception as e: except Exception as e:
print(f"模型加载失败: {e}") logging.error(f"模型加载失败: {e}")
def predict_sentiment(text): def predict_sentiment(text):
"""使用改进版模型预测单个文本的情感""" """使用改进版模型预测单个文本的情感"""
try: try:
# 获取文本嵌入 predictions, probabilities = model_manager.predict_batch([text])
embeddings = bert_ctm_model.get_bert_embeddings([text]) if predictions is not None and len(predictions) > 0:
return predictions[0], probabilities[0][predictions[0]]
# 转换为tensor return None, None
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
batch_x = torch.mean(batch_x, dim=1)
with torch.no_grad():
# 使用注意力机制
attention_output = attention_model(batch_x, batch_x, batch_x)
# 获取分类结果
outputs = classifier_model(attention_output)
outputs = torch.mean(outputs, dim=1)
# 获取预测标签和概率
probabilities = torch.softmax(outputs, dim=1)
_, predicted = torch.max(outputs, 1)
return predicted.item(), probabilities[0][predicted.item()].item()
except Exception as e: except Exception as e:
print(f"预测过程中出现错误: {e}") logging.error(f"预测过程中出现错误: {e}")
return None, None return None, None
@pb.route('/home') @pb.route('/home')
@@ -218,46 +195,51 @@ def yuqingChar():
@pb.route('/yuqingpredict') @pb.route('/yuqingpredict')
def yuqingpredict(): def yuqingpredict():
username = session.get('username') try:
TopicList = getAllTopicData() username = session.get('username')
defaultTopic = TopicList[0][0] TopicList = getAllTopicData()
if request.args.get('Topic'): defaultTopic = TopicList[0][0]
defaultTopic = request.args.get('Topic') if request.args.get('Topic'):
TopicLen = getTopicLen(defaultTopic) defaultTopic = request.args.get('Topic')
X, Y = getTopicCreatedAtandpredictData(defaultTopic) TopicLen = getTopicLen(defaultTopic)
X, Y = getTopicCreatedAtandpredictData(defaultTopic)
# 获取模型选择参数
model_type = request.args.get('model', 'pro') # 默认使用改进模型 # 获取模型选择参数
model_type = request.args.get('model', 'pro') # 默认使用改进模型
if model_type == 'basic':
# 使用基础模型(SnowNLP if model_type == 'basic':
value = SnowNLP(defaultTopic).sentiments # 使用基础模型(SnowNLP
if value == 0.5: value = SnowNLP(defaultTopic).sentiments
sentences = '中性' if value == 0.5:
elif value > 0.5: sentences = '中性'
sentences = '正面' elif value > 0.5:
elif value < 0.5: sentences = '正面'
sentences = '负面' elif value < 0.5:
else: sentences = '负面'
# 使用改进模型
predicted_label, confidence = predict_sentiment(defaultTopic)
if predicted_label is not None:
sentences = '良好' if predicted_label == 0 else '不良'
sentences = f"{sentences} (置信度: {confidence:.2f})"
else: else:
sentences = '预测失败' # 使用改进模型
predicted_label, confidence = predict_sentiment(defaultTopic)
comments = getCommentFilterDataTopic(defaultTopic) if predicted_label is not None:
return render_template('yuqingpredict.html', sentences = '良好' if predicted_label == 0 else '不良'
username=username, sentences = f"{sentences} (置信度: {confidence:.2%})"
hotWordList=TopicList, else:
defaultHotWord=defaultTopic, sentences = '预测失败,请稍后重试'
hotWordLen=TopicLen, logging.error(f"预测失败,话题: {defaultTopic}")
sentences=sentences,
xData=X, comments = getCommentFilterDataTopic(defaultTopic)
yData=Y, return render_template('yuqingpredict.html',
comments=comments, username=username,
model_type=model_type) hotWordList=TopicList,
defaultHotWord=defaultTopic,
hotWordLen=TopicLen,
sentences=sentences,
xData=X,
yData=Y,
comments=comments,
model_type=model_type)
except Exception as e:
logging.error(f"舆情预测页面渲染失败: {e}")
return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试")
@pb.route('/articleCloud') @pb.route('/articleCloud')