From 607db7317ebd2ecc29181c161465cc24e221a497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Sat, 8 Feb 2025 23:00:11 +0800 Subject: [PATCH] Optimize model loading and prediction performance, implement the singleton pattern, and provide comprehensive error handling and error messages, along with confidence level display. --- BCAT_front/predict.py | 92 +++++++++++++++++++++++++---- utils/getEchartsData.py | 56 +++++++----------- views/page/page.py | 128 +++++++++++++++++----------------------- 3 files changed, 157 insertions(+), 119 deletions(-) diff --git a/BCAT_front/predict.py b/BCAT_front/predict.py index 5e0e1ad..4562d94 100644 --- a/BCAT_front/predict.py +++ b/BCAT_front/predict.py @@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer from model_pro.classifier import FinalClassifier from model_pro.BERT_CTM import BERT_CTM_Model -# 设置设备 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +class ModelManager: + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ModelManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not self._initialized: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.classifier_model = None + self.attention_model = None + self.bert_ctm_model = None + self._initialized = True + + def load_models(self, model_save_path, bert_model_path, ctm_tokenizer_path): + """加载所有需要的模型""" + try: + if self.classifier_model is None: + self.classifier_model = torch.load(model_save_path, map_location=self.device) + self.classifier_model.eval() + + if self.attention_model is None: + self.attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) + self.attention_model.to(self.device) + self.attention_model.eval() + + if self.bert_ctm_model is None: + self.bert_ctm_model = BERT_CTM_Model( + bert_model_path=bert_model_path, + ctm_tokenizer_path=ctm_tokenizer_path + ) + return True + except Exception as e: + print(f"模型加载失败: {e}") + return False + + def predict_batch(self, texts, batch_size=32): + """批量预测文本情感""" + try: + all_predictions = [] + all_probabilities = [] + + # 分批处理文本 + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i + batch_size] + + # 获取文本嵌入 + embeddings = self.bert_ctm_model.get_bert_embeddings(batch_texts) + + # 转换为tensor + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(self.device) + batch_x = torch.mean(batch_x, dim=1) + + with torch.no_grad(): + # 使用注意力机制 + attention_output = self.attention_model(batch_x, batch_x, batch_x) + # 获取分类结果 + outputs = self.classifier_model(attention_output) + outputs = torch.mean(outputs, dim=1) + # 获取预测概率 + probabilities = torch.softmax(outputs, dim=1) + # 获取预测标签 + _, predicted = torch.max(outputs, 1) + + all_predictions.extend(predicted.cpu().numpy()) + all_probabilities.extend(probabilities.cpu().numpy()) + + return all_predictions, all_probabilities + except Exception as e: + print(f"预测过程中出现错误: {e}") + return None, None +# 创建全局的模型管理器实例 +model_manager = ModelManager() def detect_file_encoding(file_path, num_bytes=10000): """ @@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ try: # 加载模型 print("加载模型...") - classifier_model = torch.load(model_save_path, map_location=device) - classifier_model.eval() - - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) - attention_model.to(device) - attention_model.eval() + if not model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path): + return False # 检测文件编码 encoding = detect_file_encoding(input_data_path) @@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ print("开始预测...") with torch.no_grad(): for batch in tqdm(data_loader, desc="预测进度"): - batch_x = batch[0].to(device) + batch_x = batch[0].to(model_manager.device) batch_x = torch.mean(batch_x, dim=1) # 使用注意力机制 - attention_output = attention_model(batch_x, batch_x, batch_x) + attention_output = model_manager.attention_model(batch_x, batch_x, batch_x) # 获取分类结果 - outputs = classifier_model(attention_output) + outputs = model_manager.classifier_model(attention_output) outputs = torch.mean(outputs, dim=1) # 获取预测概率 diff --git a/utils/getEchartsData.py b/utils/getEchartsData.py index 2e633b0..72352a1 100644 --- a/utils/getEchartsData.py +++ b/utils/getEchartsData.py @@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis from collections import Counter # Import Counter for counting occurrences import torch -from model_pro.MHA import MultiHeadAttentionLayer -from model_pro.classifier import FinalClassifier -from model_pro.BERT_CTM import BERT_CTM_Model +from BCAT_front.predict import model_manager articleList = getAllArticleData() # Retrieve all article data commentList = getAllCommentsData() # Retrieve all comment data @@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# 加载模型(全局变量,避免重复加载) +# 设置模型路径 model_save_path = 'model_pro/final_model.pt' bert_model_path = 'model_pro/bert_model' ctm_tokenizer_path = 'model_pro/sentence_bert_model' +# 初始化模型 try: - classifier_model = torch.load(model_save_path, map_location=device) - classifier_model.eval() - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) - attention_model.to(device) - attention_model.eval() - bert_ctm_model = BERT_CTM_Model( - bert_model_path=bert_model_path, - ctm_tokenizer_path=ctm_tokenizer_path - ) + model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path) except Exception as e: print(f"模型加载失败: {e}") def predict_sentiment(texts): """使用改进版模型预测情感""" try: - # 获取文本嵌入 - embeddings = bert_ctm_model.get_bert_embeddings(texts) - - # 转换为tensor - batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) - batch_x = torch.mean(batch_x, dim=1) - - with torch.no_grad(): - # 使用注意力机制 - attention_output = attention_model(batch_x, batch_x, batch_x) - # 获取分类结果 - outputs = classifier_model(attention_output) - outputs = torch.mean(outputs, dim=1) - # 获取预测标签 - _, predicted = torch.max(outputs, 1) - - return predicted.cpu().numpy() + predictions, probabilities = model_manager.predict_batch(texts) + if predictions is not None: + return predictions, probabilities + return None, None except Exception as e: print(f"预测过程中出现错误: {e}") - return None + return None, None def getTypeList(): # Return a list of unique article types @@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'): article_sentiments.append('不良') else: # 使用改进模型 - comment_predictions = predict_sentiment(comment_texts) + comment_predictions, comment_probs = predict_sentiment(comment_texts) if comment_predictions is not None: - comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions] + comment_sentiments = [] + for pred, prob in zip(comment_predictions, comment_probs): + label = '良好' if pred == 0 else '不良' + confidence = prob[pred] + comment_sentiments.append(f"{label} ({confidence:.2%})") else: comment_sentiments = [] - article_predictions = predict_sentiment(article_texts) + article_predictions, article_probs = predict_sentiment(article_texts) if article_predictions is not None: - article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions] + article_sentiments = [] + for pred, prob in zip(article_predictions, article_probs): + label = '良好' if pred == 0 else '不良' + confidence = prob[pred] + article_sentiments.append(f"{label} ({confidence:.2%})") else: article_sentiments = [] diff --git a/views/page/page.py b/views/page/page.py index 0fa3cc5..c788e87 100644 --- a/views/page/page.py +++ b/views/page/page.py @@ -1,4 +1,4 @@ -from flask import Flask, session, render_template, redirect, Blueprint, request +from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify from utils.mynlp import SnowNLP from utils.getHomePageData import * from utils.getHotWordPageData import * @@ -9,9 +9,7 @@ from utils.getTopicPageData import * from utils.yuqingpredict import * from utils.logger import app_logger as logging import torch -from model_pro.MHA import MultiHeadAttentionLayer -from model_pro.classifier import FinalClassifier -from model_pro.BERT_CTM import BERT_CTM_Model +from BCAT_front.predict import model_manager pb = Blueprint('page', __name__, @@ -21,47 +19,26 @@ pb = Blueprint('page', # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# 加载模型(全局变量,避免重复加载) +# 设置模型路径 model_save_path = 'model_pro/final_model.pt' bert_model_path = 'model_pro/bert_model' ctm_tokenizer_path = 'model_pro/sentence_bert_model' +# 初始化模型 try: - classifier_model = torch.load(model_save_path, map_location=device) - classifier_model.eval() - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) - attention_model.to(device) - attention_model.eval() - bert_ctm_model = BERT_CTM_Model( - bert_model_path=bert_model_path, - ctm_tokenizer_path=ctm_tokenizer_path - ) + model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path) except Exception as e: - print(f"模型加载失败: {e}") + logging.error(f"模型加载失败: {e}") def predict_sentiment(text): """使用改进版模型预测单个文本的情感""" try: - # 获取文本嵌入 - embeddings = bert_ctm_model.get_bert_embeddings([text]) - - # 转换为tensor - batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) - batch_x = torch.mean(batch_x, dim=1) - - with torch.no_grad(): - # 使用注意力机制 - attention_output = attention_model(batch_x, batch_x, batch_x) - # 获取分类结果 - outputs = classifier_model(attention_output) - outputs = torch.mean(outputs, dim=1) - # 获取预测标签和概率 - probabilities = torch.softmax(outputs, dim=1) - _, predicted = torch.max(outputs, 1) - - return predicted.item(), probabilities[0][predicted.item()].item() + predictions, probabilities = model_manager.predict_batch([text]) + if predictions is not None and len(predictions) > 0: + return predictions[0], probabilities[0][predictions[0]] + return None, None except Exception as e: - print(f"预测过程中出现错误: {e}") + logging.error(f"预测过程中出现错误: {e}") return None, None @pb.route('/home') @@ -218,46 +195,51 @@ def yuqingChar(): @pb.route('/yuqingpredict') def yuqingpredict(): - username = session.get('username') - TopicList = getAllTopicData() - defaultTopic = TopicList[0][0] - if request.args.get('Topic'): - defaultTopic = request.args.get('Topic') - TopicLen = getTopicLen(defaultTopic) - X, Y = getTopicCreatedAtandpredictData(defaultTopic) - - # 获取模型选择参数 - model_type = request.args.get('model', 'pro') # 默认使用改进模型 - - if model_type == 'basic': - # 使用基础模型(SnowNLP) - value = SnowNLP(defaultTopic).sentiments - if value == 0.5: - sentences = '中性' - elif value > 0.5: - sentences = '正面' - elif value < 0.5: - sentences = '负面' - else: - # 使用改进模型 - predicted_label, confidence = predict_sentiment(defaultTopic) - if predicted_label is not None: - sentences = '良好' if predicted_label == 0 else '不良' - sentences = f"{sentences} (置信度: {confidence:.2f})" + try: + username = session.get('username') + TopicList = getAllTopicData() + defaultTopic = TopicList[0][0] + if request.args.get('Topic'): + defaultTopic = request.args.get('Topic') + TopicLen = getTopicLen(defaultTopic) + X, Y = getTopicCreatedAtandpredictData(defaultTopic) + + # 获取模型选择参数 + model_type = request.args.get('model', 'pro') # 默认使用改进模型 + + if model_type == 'basic': + # 使用基础模型(SnowNLP) + value = SnowNLP(defaultTopic).sentiments + if value == 0.5: + sentences = '中性' + elif value > 0.5: + sentences = '正面' + elif value < 0.5: + sentences = '负面' else: - sentences = '预测失败' - - comments = getCommentFilterDataTopic(defaultTopic) - return render_template('yuqingpredict.html', - username=username, - hotWordList=TopicList, - defaultHotWord=defaultTopic, - hotWordLen=TopicLen, - sentences=sentences, - xData=X, - yData=Y, - comments=comments, - model_type=model_type) + # 使用改进模型 + predicted_label, confidence = predict_sentiment(defaultTopic) + if predicted_label is not None: + sentences = '良好' if predicted_label == 0 else '不良' + sentences = f"{sentences} (置信度: {confidence:.2%})" + else: + sentences = '预测失败,请稍后重试' + logging.error(f"预测失败,话题: {defaultTopic}") + + comments = getCommentFilterDataTopic(defaultTopic) + return render_template('yuqingpredict.html', + username=username, + hotWordList=TopicList, + defaultHotWord=defaultTopic, + hotWordLen=TopicLen, + sentences=sentences, + xData=X, + yData=Y, + comments=comments, + model_type=model_type) + except Exception as e: + logging.error(f"舆情预测页面渲染失败: {e}") + return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试") @pb.route('/articleCloud')