Optimize model loading and prediction performance, implement the singleton pattern, and provide comprehensive error handling and error messages, along with confidence level display.
This commit is contained in:
+81
-11
@@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer
|
|||||||
from model_pro.classifier import FinalClassifier
|
from model_pro.classifier import FinalClassifier
|
||||||
from model_pro.BERT_CTM import BERT_CTM_Model
|
from model_pro.BERT_CTM import BERT_CTM_Model
|
||||||
|
|
||||||
# 设置设备
|
class ModelManager:
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
_instance = None
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(ModelManager, cls).__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not self._initialized:
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.classifier_model = None
|
||||||
|
self.attention_model = None
|
||||||
|
self.bert_ctm_model = None
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def load_models(self, model_save_path, bert_model_path, ctm_tokenizer_path):
|
||||||
|
"""加载所有需要的模型"""
|
||||||
|
try:
|
||||||
|
if self.classifier_model is None:
|
||||||
|
self.classifier_model = torch.load(model_save_path, map_location=self.device)
|
||||||
|
self.classifier_model.eval()
|
||||||
|
|
||||||
|
if self.attention_model is None:
|
||||||
|
self.attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
||||||
|
self.attention_model.to(self.device)
|
||||||
|
self.attention_model.eval()
|
||||||
|
|
||||||
|
if self.bert_ctm_model is None:
|
||||||
|
self.bert_ctm_model = BERT_CTM_Model(
|
||||||
|
bert_model_path=bert_model_path,
|
||||||
|
ctm_tokenizer_path=ctm_tokenizer_path
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"模型加载失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def predict_batch(self, texts, batch_size=32):
|
||||||
|
"""批量预测文本情感"""
|
||||||
|
try:
|
||||||
|
all_predictions = []
|
||||||
|
all_probabilities = []
|
||||||
|
|
||||||
|
# 分批处理文本
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch_texts = texts[i:i + batch_size]
|
||||||
|
|
||||||
|
# 获取文本嵌入
|
||||||
|
embeddings = self.bert_ctm_model.get_bert_embeddings(batch_texts)
|
||||||
|
|
||||||
|
# 转换为tensor
|
||||||
|
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
|
||||||
|
batch_x = torch.mean(batch_x, dim=1)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# 使用注意力机制
|
||||||
|
attention_output = self.attention_model(batch_x, batch_x, batch_x)
|
||||||
|
# 获取分类结果
|
||||||
|
outputs = self.classifier_model(attention_output)
|
||||||
|
outputs = torch.mean(outputs, dim=1)
|
||||||
|
# 获取预测概率
|
||||||
|
probabilities = torch.softmax(outputs, dim=1)
|
||||||
|
# 获取预测标签
|
||||||
|
_, predicted = torch.max(outputs, 1)
|
||||||
|
|
||||||
|
all_predictions.extend(predicted.cpu().numpy())
|
||||||
|
all_probabilities.extend(probabilities.cpu().numpy())
|
||||||
|
|
||||||
|
return all_predictions, all_probabilities
|
||||||
|
except Exception as e:
|
||||||
|
print(f"预测过程中出现错误: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# 创建全局的模型管理器实例
|
||||||
|
model_manager = ModelManager()
|
||||||
|
|
||||||
def detect_file_encoding(file_path, num_bytes=10000):
|
def detect_file_encoding(file_path, num_bytes=10000):
|
||||||
"""
|
"""
|
||||||
@@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
|
|||||||
try:
|
try:
|
||||||
# 加载模型
|
# 加载模型
|
||||||
print("加载模型...")
|
print("加载模型...")
|
||||||
classifier_model = torch.load(model_save_path, map_location=device)
|
if not model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path):
|
||||||
classifier_model.eval()
|
return False
|
||||||
|
|
||||||
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
|
||||||
attention_model.to(device)
|
|
||||||
attention_model.eval()
|
|
||||||
|
|
||||||
# 检测文件编码
|
# 检测文件编码
|
||||||
encoding = detect_file_encoding(input_data_path)
|
encoding = detect_file_encoding(input_data_path)
|
||||||
@@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
|
|||||||
print("开始预测...")
|
print("开始预测...")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in tqdm(data_loader, desc="预测进度"):
|
for batch in tqdm(data_loader, desc="预测进度"):
|
||||||
batch_x = batch[0].to(device)
|
batch_x = batch[0].to(model_manager.device)
|
||||||
batch_x = torch.mean(batch_x, dim=1)
|
batch_x = torch.mean(batch_x, dim=1)
|
||||||
|
|
||||||
# 使用注意力机制
|
# 使用注意力机制
|
||||||
attention_output = attention_model(batch_x, batch_x, batch_x)
|
attention_output = model_manager.attention_model(batch_x, batch_x, batch_x)
|
||||||
|
|
||||||
# 获取分类结果
|
# 获取分类结果
|
||||||
outputs = classifier_model(attention_output)
|
outputs = model_manager.classifier_model(attention_output)
|
||||||
outputs = torch.mean(outputs, dim=1)
|
outputs = torch.mean(outputs, dim=1)
|
||||||
|
|
||||||
# 获取预测概率
|
# 获取预测概率
|
||||||
|
|||||||
+21
-35
@@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval
|
|||||||
from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
|
from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
|
||||||
from collections import Counter # Import Counter for counting occurrences
|
from collections import Counter # Import Counter for counting occurrences
|
||||||
import torch
|
import torch
|
||||||
from model_pro.MHA import MultiHeadAttentionLayer
|
from BCAT_front.predict import model_manager
|
||||||
from model_pro.classifier import FinalClassifier
|
|
||||||
from model_pro.BERT_CTM import BERT_CTM_Model
|
|
||||||
|
|
||||||
articleList = getAllArticleData() # Retrieve all article data
|
articleList = getAllArticleData() # Retrieve all article data
|
||||||
commentList = getAllCommentsData() # Retrieve all comment data
|
commentList = getAllCommentsData() # Retrieve all comment data
|
||||||
@@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data
|
|||||||
# 设置设备
|
# 设置设备
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# 加载模型(全局变量,避免重复加载)
|
# 设置模型路径
|
||||||
model_save_path = 'model_pro/final_model.pt'
|
model_save_path = 'model_pro/final_model.pt'
|
||||||
bert_model_path = 'model_pro/bert_model'
|
bert_model_path = 'model_pro/bert_model'
|
||||||
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
|
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
try:
|
try:
|
||||||
classifier_model = torch.load(model_save_path, map_location=device)
|
model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
|
||||||
classifier_model.eval()
|
|
||||||
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
|
||||||
attention_model.to(device)
|
|
||||||
attention_model.eval()
|
|
||||||
bert_ctm_model = BERT_CTM_Model(
|
|
||||||
bert_model_path=bert_model_path,
|
|
||||||
ctm_tokenizer_path=ctm_tokenizer_path
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"模型加载失败: {e}")
|
print(f"模型加载失败: {e}")
|
||||||
|
|
||||||
def predict_sentiment(texts):
|
def predict_sentiment(texts):
|
||||||
"""使用改进版模型预测情感"""
|
"""使用改进版模型预测情感"""
|
||||||
try:
|
try:
|
||||||
# 获取文本嵌入
|
predictions, probabilities = model_manager.predict_batch(texts)
|
||||||
embeddings = bert_ctm_model.get_bert_embeddings(texts)
|
if predictions is not None:
|
||||||
|
return predictions, probabilities
|
||||||
# 转换为tensor
|
return None, None
|
||||||
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
|
|
||||||
batch_x = torch.mean(batch_x, dim=1)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# 使用注意力机制
|
|
||||||
attention_output = attention_model(batch_x, batch_x, batch_x)
|
|
||||||
# 获取分类结果
|
|
||||||
outputs = classifier_model(attention_output)
|
|
||||||
outputs = torch.mean(outputs, dim=1)
|
|
||||||
# 获取预测标签
|
|
||||||
_, predicted = torch.max(outputs, 1)
|
|
||||||
|
|
||||||
return predicted.cpu().numpy()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"预测过程中出现错误: {e}")
|
print(f"预测过程中出现错误: {e}")
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
def getTypeList():
|
def getTypeList():
|
||||||
# Return a list of unique article types
|
# Return a list of unique article types
|
||||||
@@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'):
|
|||||||
article_sentiments.append('不良')
|
article_sentiments.append('不良')
|
||||||
else:
|
else:
|
||||||
# 使用改进模型
|
# 使用改进模型
|
||||||
comment_predictions = predict_sentiment(comment_texts)
|
comment_predictions, comment_probs = predict_sentiment(comment_texts)
|
||||||
if comment_predictions is not None:
|
if comment_predictions is not None:
|
||||||
comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions]
|
comment_sentiments = []
|
||||||
|
for pred, prob in zip(comment_predictions, comment_probs):
|
||||||
|
label = '良好' if pred == 0 else '不良'
|
||||||
|
confidence = prob[pred]
|
||||||
|
comment_sentiments.append(f"{label} ({confidence:.2%})")
|
||||||
else:
|
else:
|
||||||
comment_sentiments = []
|
comment_sentiments = []
|
||||||
|
|
||||||
article_predictions = predict_sentiment(article_texts)
|
article_predictions, article_probs = predict_sentiment(article_texts)
|
||||||
if article_predictions is not None:
|
if article_predictions is not None:
|
||||||
article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions]
|
article_sentiments = []
|
||||||
|
for pred, prob in zip(article_predictions, article_probs):
|
||||||
|
label = '良好' if pred == 0 else '不良'
|
||||||
|
confidence = prob[pred]
|
||||||
|
article_sentiments.append(f"{label} ({confidence:.2%})")
|
||||||
else:
|
else:
|
||||||
article_sentiments = []
|
article_sentiments = []
|
||||||
|
|
||||||
|
|||||||
+52
-70
@@ -1,4 +1,4 @@
|
|||||||
from flask import Flask, session, render_template, redirect, Blueprint, request
|
from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify
|
||||||
from utils.mynlp import SnowNLP
|
from utils.mynlp import SnowNLP
|
||||||
from utils.getHomePageData import *
|
from utils.getHomePageData import *
|
||||||
from utils.getHotWordPageData import *
|
from utils.getHotWordPageData import *
|
||||||
@@ -9,9 +9,7 @@ from utils.getTopicPageData import *
|
|||||||
from utils.yuqingpredict import *
|
from utils.yuqingpredict import *
|
||||||
from utils.logger import app_logger as logging
|
from utils.logger import app_logger as logging
|
||||||
import torch
|
import torch
|
||||||
from model_pro.MHA import MultiHeadAttentionLayer
|
from BCAT_front.predict import model_manager
|
||||||
from model_pro.classifier import FinalClassifier
|
|
||||||
from model_pro.BERT_CTM import BERT_CTM_Model
|
|
||||||
|
|
||||||
pb = Blueprint('page',
|
pb = Blueprint('page',
|
||||||
__name__,
|
__name__,
|
||||||
@@ -21,47 +19,26 @@ pb = Blueprint('page',
|
|||||||
# 设置设备
|
# 设置设备
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# 加载模型(全局变量,避免重复加载)
|
# 设置模型路径
|
||||||
model_save_path = 'model_pro/final_model.pt'
|
model_save_path = 'model_pro/final_model.pt'
|
||||||
bert_model_path = 'model_pro/bert_model'
|
bert_model_path = 'model_pro/bert_model'
|
||||||
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
|
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
try:
|
try:
|
||||||
classifier_model = torch.load(model_save_path, map_location=device)
|
model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
|
||||||
classifier_model.eval()
|
|
||||||
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
|
||||||
attention_model.to(device)
|
|
||||||
attention_model.eval()
|
|
||||||
bert_ctm_model = BERT_CTM_Model(
|
|
||||||
bert_model_path=bert_model_path,
|
|
||||||
ctm_tokenizer_path=ctm_tokenizer_path
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"模型加载失败: {e}")
|
logging.error(f"模型加载失败: {e}")
|
||||||
|
|
||||||
def predict_sentiment(text):
|
def predict_sentiment(text):
|
||||||
"""使用改进版模型预测单个文本的情感"""
|
"""使用改进版模型预测单个文本的情感"""
|
||||||
try:
|
try:
|
||||||
# 获取文本嵌入
|
predictions, probabilities = model_manager.predict_batch([text])
|
||||||
embeddings = bert_ctm_model.get_bert_embeddings([text])
|
if predictions is not None and len(predictions) > 0:
|
||||||
|
return predictions[0], probabilities[0][predictions[0]]
|
||||||
# 转换为tensor
|
return None, None
|
||||||
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
|
|
||||||
batch_x = torch.mean(batch_x, dim=1)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# 使用注意力机制
|
|
||||||
attention_output = attention_model(batch_x, batch_x, batch_x)
|
|
||||||
# 获取分类结果
|
|
||||||
outputs = classifier_model(attention_output)
|
|
||||||
outputs = torch.mean(outputs, dim=1)
|
|
||||||
# 获取预测标签和概率
|
|
||||||
probabilities = torch.softmax(outputs, dim=1)
|
|
||||||
_, predicted = torch.max(outputs, 1)
|
|
||||||
|
|
||||||
return predicted.item(), probabilities[0][predicted.item()].item()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"预测过程中出现错误: {e}")
|
logging.error(f"预测过程中出现错误: {e}")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
@pb.route('/home')
|
@pb.route('/home')
|
||||||
@@ -218,46 +195,51 @@ def yuqingChar():
|
|||||||
|
|
||||||
@pb.route('/yuqingpredict')
|
@pb.route('/yuqingpredict')
|
||||||
def yuqingpredict():
|
def yuqingpredict():
|
||||||
username = session.get('username')
|
try:
|
||||||
TopicList = getAllTopicData()
|
username = session.get('username')
|
||||||
defaultTopic = TopicList[0][0]
|
TopicList = getAllTopicData()
|
||||||
if request.args.get('Topic'):
|
defaultTopic = TopicList[0][0]
|
||||||
defaultTopic = request.args.get('Topic')
|
if request.args.get('Topic'):
|
||||||
TopicLen = getTopicLen(defaultTopic)
|
defaultTopic = request.args.get('Topic')
|
||||||
X, Y = getTopicCreatedAtandpredictData(defaultTopic)
|
TopicLen = getTopicLen(defaultTopic)
|
||||||
|
X, Y = getTopicCreatedAtandpredictData(defaultTopic)
|
||||||
|
|
||||||
# 获取模型选择参数
|
# 获取模型选择参数
|
||||||
model_type = request.args.get('model', 'pro') # 默认使用改进模型
|
model_type = request.args.get('model', 'pro') # 默认使用改进模型
|
||||||
|
|
||||||
if model_type == 'basic':
|
if model_type == 'basic':
|
||||||
# 使用基础模型(SnowNLP)
|
# 使用基础模型(SnowNLP)
|
||||||
value = SnowNLP(defaultTopic).sentiments
|
value = SnowNLP(defaultTopic).sentiments
|
||||||
if value == 0.5:
|
if value == 0.5:
|
||||||
sentences = '中性'
|
sentences = '中性'
|
||||||
elif value > 0.5:
|
elif value > 0.5:
|
||||||
sentences = '正面'
|
sentences = '正面'
|
||||||
elif value < 0.5:
|
elif value < 0.5:
|
||||||
sentences = '负面'
|
sentences = '负面'
|
||||||
else:
|
|
||||||
# 使用改进模型
|
|
||||||
predicted_label, confidence = predict_sentiment(defaultTopic)
|
|
||||||
if predicted_label is not None:
|
|
||||||
sentences = '良好' if predicted_label == 0 else '不良'
|
|
||||||
sentences = f"{sentences} (置信度: {confidence:.2f})"
|
|
||||||
else:
|
else:
|
||||||
sentences = '预测失败'
|
# 使用改进模型
|
||||||
|
predicted_label, confidence = predict_sentiment(defaultTopic)
|
||||||
|
if predicted_label is not None:
|
||||||
|
sentences = '良好' if predicted_label == 0 else '不良'
|
||||||
|
sentences = f"{sentences} (置信度: {confidence:.2%})"
|
||||||
|
else:
|
||||||
|
sentences = '预测失败,请稍后重试'
|
||||||
|
logging.error(f"预测失败,话题: {defaultTopic}")
|
||||||
|
|
||||||
comments = getCommentFilterDataTopic(defaultTopic)
|
comments = getCommentFilterDataTopic(defaultTopic)
|
||||||
return render_template('yuqingpredict.html',
|
return render_template('yuqingpredict.html',
|
||||||
username=username,
|
username=username,
|
||||||
hotWordList=TopicList,
|
hotWordList=TopicList,
|
||||||
defaultHotWord=defaultTopic,
|
defaultHotWord=defaultTopic,
|
||||||
hotWordLen=TopicLen,
|
hotWordLen=TopicLen,
|
||||||
sentences=sentences,
|
sentences=sentences,
|
||||||
xData=X,
|
xData=X,
|
||||||
yData=Y,
|
yData=Y,
|
||||||
comments=comments,
|
comments=comments,
|
||||||
model_type=model_type)
|
model_type=model_type)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"舆情预测页面渲染失败: {e}")
|
||||||
|
return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试")
|
||||||
|
|
||||||
|
|
||||||
@pb.route('/articleCloud')
|
@pb.route('/articleCloud')
|
||||||
|
|||||||
Reference in New Issue
Block a user