Optimize model loading and prediction performance, implement the singleton pattern, and provide comprehensive error handling and error messages, along with confidence level display.
This commit is contained in:
+55
-73
@@ -1,4 +1,4 @@
|
||||
from flask import Flask, session, render_template, redirect, Blueprint, request
|
||||
from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify
|
||||
from utils.mynlp import SnowNLP
|
||||
from utils.getHomePageData import *
|
||||
from utils.getHotWordPageData import *
|
||||
@@ -9,9 +9,7 @@ from utils.getTopicPageData import *
|
||||
from utils.yuqingpredict import *
|
||||
from utils.logger import app_logger as logging
|
||||
import torch
|
||||
from model_pro.MHA import MultiHeadAttentionLayer
|
||||
from model_pro.classifier import FinalClassifier
|
||||
from model_pro.BERT_CTM import BERT_CTM_Model
|
||||
from BCAT_front.predict import model_manager
|
||||
|
||||
pb = Blueprint('page',
|
||||
__name__,
|
||||
@@ -21,47 +19,26 @@ pb = Blueprint('page',
|
||||
# 设置设备
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载模型(全局变量,避免重复加载)
|
||||
# 设置模型路径
|
||||
model_save_path = 'model_pro/final_model.pt'
|
||||
bert_model_path = 'model_pro/bert_model'
|
||||
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
|
||||
|
||||
# 初始化模型
|
||||
try:
|
||||
classifier_model = torch.load(model_save_path, map_location=device)
|
||||
classifier_model.eval()
|
||||
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
||||
attention_model.to(device)
|
||||
attention_model.eval()
|
||||
bert_ctm_model = BERT_CTM_Model(
|
||||
bert_model_path=bert_model_path,
|
||||
ctm_tokenizer_path=ctm_tokenizer_path
|
||||
)
|
||||
model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
|
||||
except Exception as e:
|
||||
print(f"模型加载失败: {e}")
|
||||
logging.error(f"模型加载失败: {e}")
|
||||
|
||||
def predict_sentiment(text):
|
||||
"""使用改进版模型预测单个文本的情感"""
|
||||
try:
|
||||
# 获取文本嵌入
|
||||
embeddings = bert_ctm_model.get_bert_embeddings([text])
|
||||
|
||||
# 转换为tensor
|
||||
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
|
||||
batch_x = torch.mean(batch_x, dim=1)
|
||||
|
||||
with torch.no_grad():
|
||||
# 使用注意力机制
|
||||
attention_output = attention_model(batch_x, batch_x, batch_x)
|
||||
# 获取分类结果
|
||||
outputs = classifier_model(attention_output)
|
||||
outputs = torch.mean(outputs, dim=1)
|
||||
# 获取预测标签和概率
|
||||
probabilities = torch.softmax(outputs, dim=1)
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
|
||||
return predicted.item(), probabilities[0][predicted.item()].item()
|
||||
predictions, probabilities = model_manager.predict_batch([text])
|
||||
if predictions is not None and len(predictions) > 0:
|
||||
return predictions[0], probabilities[0][predictions[0]]
|
||||
return None, None
|
||||
except Exception as e:
|
||||
print(f"预测过程中出现错误: {e}")
|
||||
logging.error(f"预测过程中出现错误: {e}")
|
||||
return None, None
|
||||
|
||||
@pb.route('/home')
|
||||
@@ -218,46 +195,51 @@ def yuqingChar():
|
||||
|
||||
@pb.route('/yuqingpredict')
|
||||
def yuqingpredict():
|
||||
username = session.get('username')
|
||||
TopicList = getAllTopicData()
|
||||
defaultTopic = TopicList[0][0]
|
||||
if request.args.get('Topic'):
|
||||
defaultTopic = request.args.get('Topic')
|
||||
TopicLen = getTopicLen(defaultTopic)
|
||||
X, Y = getTopicCreatedAtandpredictData(defaultTopic)
|
||||
|
||||
# 获取模型选择参数
|
||||
model_type = request.args.get('model', 'pro') # 默认使用改进模型
|
||||
|
||||
if model_type == 'basic':
|
||||
# 使用基础模型(SnowNLP)
|
||||
value = SnowNLP(defaultTopic).sentiments
|
||||
if value == 0.5:
|
||||
sentences = '中性'
|
||||
elif value > 0.5:
|
||||
sentences = '正面'
|
||||
elif value < 0.5:
|
||||
sentences = '负面'
|
||||
else:
|
||||
# 使用改进模型
|
||||
predicted_label, confidence = predict_sentiment(defaultTopic)
|
||||
if predicted_label is not None:
|
||||
sentences = '良好' if predicted_label == 0 else '不良'
|
||||
sentences = f"{sentences} (置信度: {confidence:.2f})"
|
||||
try:
|
||||
username = session.get('username')
|
||||
TopicList = getAllTopicData()
|
||||
defaultTopic = TopicList[0][0]
|
||||
if request.args.get('Topic'):
|
||||
defaultTopic = request.args.get('Topic')
|
||||
TopicLen = getTopicLen(defaultTopic)
|
||||
X, Y = getTopicCreatedAtandpredictData(defaultTopic)
|
||||
|
||||
# 获取模型选择参数
|
||||
model_type = request.args.get('model', 'pro') # 默认使用改进模型
|
||||
|
||||
if model_type == 'basic':
|
||||
# 使用基础模型(SnowNLP)
|
||||
value = SnowNLP(defaultTopic).sentiments
|
||||
if value == 0.5:
|
||||
sentences = '中性'
|
||||
elif value > 0.5:
|
||||
sentences = '正面'
|
||||
elif value < 0.5:
|
||||
sentences = '负面'
|
||||
else:
|
||||
sentences = '预测失败'
|
||||
|
||||
comments = getCommentFilterDataTopic(defaultTopic)
|
||||
return render_template('yuqingpredict.html',
|
||||
username=username,
|
||||
hotWordList=TopicList,
|
||||
defaultHotWord=defaultTopic,
|
||||
hotWordLen=TopicLen,
|
||||
sentences=sentences,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
comments=comments,
|
||||
model_type=model_type)
|
||||
# 使用改进模型
|
||||
predicted_label, confidence = predict_sentiment(defaultTopic)
|
||||
if predicted_label is not None:
|
||||
sentences = '良好' if predicted_label == 0 else '不良'
|
||||
sentences = f"{sentences} (置信度: {confidence:.2%})"
|
||||
else:
|
||||
sentences = '预测失败,请稍后重试'
|
||||
logging.error(f"预测失败,话题: {defaultTopic}")
|
||||
|
||||
comments = getCommentFilterDataTopic(defaultTopic)
|
||||
return render_template('yuqingpredict.html',
|
||||
username=username,
|
||||
hotWordList=TopicList,
|
||||
defaultHotWord=defaultTopic,
|
||||
hotWordLen=TopicLen,
|
||||
sentences=sentences,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
comments=comments,
|
||||
model_type=model_type)
|
||||
except Exception as e:
|
||||
logging.error(f"舆情预测页面渲染失败: {e}")
|
||||
return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试")
|
||||
|
||||
|
||||
@pb.route('/articleCloud')
|
||||
|
||||
Reference in New Issue
Block a user