The old emotion recognition model has been replaced with the new model_pro, and the results have been integrated into the project.

This commit is contained in:
戒酒的李白
2025-02-04 21:03:45 +08:00
parent a9108a909c
commit 826de6184d
3 changed files with 184 additions and 64 deletions
+70 -21
View File
@@ -1,10 +1,59 @@
from utils.getPublicData import * # Import utility functions for data retrieval
from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
from collections import Counter # Import Counter for counting occurrences
import torch
from model_pro.MHA import MultiHeadAttentionLayer
from model_pro.classifier import FinalClassifier
from model_pro.BERT_CTM import BERT_CTM_Model
# Load every article and comment row once at import time; the helper
# functions in this module read these module-level lists.
articleList = getAllArticleData()
commentList = getAllCommentsData()

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the model artifacts (relative paths).
model_save_path = 'model_pro/final_model.pt'
bert_model_path = 'model_pro/bert_model'
ctm_tokenizer_path = 'model_pro/sentence_bert_model'

# Pre-bind the model globals so the names always exist, even when loading
# fails part-way; a failed load then leaves them as None instead of unbound.
classifier_model = None
attention_model = None
bert_ctm_model = None

# Load the models once as module globals to avoid reloading on every call.
try:
    # SECURITY NOTE: this unpickles a complete model object, which can run
    # arbitrary code; only ever load a checkpoint from a trusted source.
    # weights_only=False is explicit because PyTorch >= 2.6 defaults to
    # True, which refuses to unpickle full model objects.
    classifier_model = torch.load(
        model_save_path, map_location=device, weights_only=False
    )
    classifier_model.eval()

    # NOTE(review): the attention layer is constructed here and no saved
    # weights are loaded into it in this file — confirm this is intended.
    attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
    attention_model.to(device)
    attention_model.eval()

    bert_ctm_model = BERT_CTM_Model(
        bert_model_path=bert_model_path,
        ctm_tokenizer_path=ctm_tokenizer_path
    )
except Exception as e:
    # Best effort: report the failure and let the module finish importing.
    print(f"模型加载失败: {e}")
def predict_sentiment(texts):
    """Predict a sentiment class index for each text with the model_pro models.

    Args:
        texts: A sized collection of strings to classify.

    Returns:
        A NumPy array of predicted class indices, one per text. An empty
        array is returned for empty input. ``None`` is returned when any
        step fails; the error is printed rather than raised.
    """
    try:
        # Nothing to classify: skip the models entirely instead of letting
        # an empty batch fall through to the error path.
        if len(texts) == 0:
            return torch.empty(0, dtype=torch.long).numpy()

        # Get the text embeddings.
        embeddings = bert_ctm_model.get_bert_embeddings(texts)

        # Convert to a float tensor on the inference device.
        batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
        # Average over dim 1 — presumably the token axis of a
        # (batch, seq, dim) tensor; TODO confirm against get_bert_embeddings.
        batch_x = torch.mean(batch_x, dim=1)

        with torch.no_grad():
            # Self-attention: the same tensor is passed as all three inputs.
            attention_output = attention_model(batch_x, batch_x, batch_x)

            # Classify, then average the classifier output over dim 1.
            outputs = classifier_model(attention_output)
            outputs = torch.mean(outputs, dim=1)

            # The index of the largest score is the predicted label.
            _, predicted = torch.max(outputs, 1)

        return predicted.cpu().numpy()
    except Exception as e:
        # Best effort: callers treat None as "no predictions available".
        print(f"预测过程中出现错误: {e}")
        return None
def getTypeList():
    """Return the distinct article types as a list (order is unspecified)."""
    # Field 8 of every article row is collected into a set to deduplicate.
    distinct_types = {row[8] for row in articleList}
    return list(distinct_types)
@@ -119,32 +168,32 @@ def getYuQingCharDataOne():
return X, Y, biedata
def getYuQingCharDataTwo():
    """Build sentiment distribution data for comments and for articles.

    Field 4 of every comment row and field 5 of every article row are
    classified with ``predict_sentiment``; the predictions are mapped to the
    labels '良好' and '不良' and counted per label.

    Returns:
        A ``(biedata1, biedata2)`` tuple. Each element is a list of
        ``{'name': label, 'value': count}`` dicts, one per label, for the
        comments and the articles respectively. If prediction fails for a
        group, every count in that group is 0.
    """
    comment_texts = [comment[4] for comment in commentList]
    article_texts = [article[5] for article in articleList]

    # The earlier SnowNLP three-way scoring is gone: every value it produced
    # was overwritten by the model predictions below before being used.
    comment_counts = Counter(_label_predictions(predict_sentiment(comment_texts)))
    article_counts = Counter(_label_predictions(predict_sentiment(article_texts)))

    X = ['良好', '不良']
    biedata1 = [{'name': x, 'value': comment_counts.get(x, 0)} for x in X]
    biedata2 = [{'name': x, 'value': article_counts.get(x, 0)} for x in X]
    return biedata1, biedata2


def _label_predictions(predictions):
    """Map predicted class indices to labels; ``None`` yields an empty list."""
    if predictions is None:
        return []
    # Index 0 is labelled '良好'; any other index is labelled '不良'.
    return ['良好' if pred == 0 else '不良' for pred in predictions]
def getYuQingCharDataThree():