The old emotion recognition model has been replaced with the new model_pro, and the results have been integrated into the project.
This commit is contained in:
+70
-21
@@ -1,10 +1,59 @@
|
||||
from utils.getPublicData import * # Import utility functions for data retrieval
|
||||
from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
|
||||
from collections import Counter # Import Counter for counting occurrences
|
||||
import torch
|
||||
from model_pro.MHA import MultiHeadAttentionLayer
|
||||
from model_pro.classifier import FinalClassifier
|
||||
from model_pro.BERT_CTM import BERT_CTM_Model
|
||||
|
||||
articleList = getAllArticleData() # Retrieve all article data
|
||||
commentList = getAllCommentsData() # Retrieve all comment data
|
||||
|
||||
# 设置设备
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载模型(全局变量,避免重复加载)
|
||||
model_save_path = 'model_pro/final_model.pt'
|
||||
bert_model_path = 'model_pro/bert_model'
|
||||
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
|
||||
|
||||
try:
|
||||
classifier_model = torch.load(model_save_path, map_location=device)
|
||||
classifier_model.eval()
|
||||
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
|
||||
attention_model.to(device)
|
||||
attention_model.eval()
|
||||
bert_ctm_model = BERT_CTM_Model(
|
||||
bert_model_path=bert_model_path,
|
||||
ctm_tokenizer_path=ctm_tokenizer_path
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"模型加载失败: {e}")
|
||||
|
||||
def predict_sentiment(texts):
|
||||
"""使用改进版模型预测情感"""
|
||||
try:
|
||||
# 获取文本嵌入
|
||||
embeddings = bert_ctm_model.get_bert_embeddings(texts)
|
||||
|
||||
# 转换为tensor
|
||||
batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
|
||||
batch_x = torch.mean(batch_x, dim=1)
|
||||
|
||||
with torch.no_grad():
|
||||
# 使用注意力机制
|
||||
attention_output = attention_model(batch_x, batch_x, batch_x)
|
||||
# 获取分类结果
|
||||
outputs = classifier_model(attention_output)
|
||||
outputs = torch.mean(outputs, dim=1)
|
||||
# 获取预测标签
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
|
||||
return predicted.cpu().numpy()
|
||||
except Exception as e:
|
||||
print(f"预测过程中出现错误: {e}")
|
||||
return None
|
||||
|
||||
def getTypeList():
|
||||
# Return a list of unique article types
|
||||
return list(set([x[8] for x in articleList]))
|
||||
@@ -119,32 +168,32 @@ def getYuQingCharDataOne():
|
||||
return X, Y, biedata
|
||||
|
||||
def getYuQingCharDataTwo():
|
||||
# Analyze sentiment of comments and articles
|
||||
comment_sentiments = []
|
||||
for comment in commentList:
|
||||
emotionValue = SnowNLP(comment[4]).sentiments
|
||||
if emotionValue > 0.4:
|
||||
comment_sentiments.append('正面')
|
||||
elif emotionValue < 0.2:
|
||||
comment_sentiments.append('负面')
|
||||
else:
|
||||
comment_sentiments.append('中性')
|
||||
comment_counts = Counter(comment_sentiments)
|
||||
# 分析评论和文章的情感
|
||||
comment_texts = [comment[4] for comment in commentList]
|
||||
article_texts = [article[5] for article in articleList]
|
||||
|
||||
article_sentiments = []
|
||||
for article in articleList:
|
||||
emotionValue = SnowNLP(article[5]).sentiments
|
||||
if emotionValue > 0.4:
|
||||
article_sentiments.append('正面')
|
||||
elif emotionValue < 0.2:
|
||||
article_sentiments.append('负面')
|
||||
else:
|
||||
article_sentiments.append('中性')
|
||||
# 预测评论情感
|
||||
comment_predictions = predict_sentiment(comment_texts)
|
||||
if comment_predictions is not None:
|
||||
comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions]
|
||||
else:
|
||||
comment_sentiments = []
|
||||
|
||||
# 预测文章情感
|
||||
article_predictions = predict_sentiment(article_texts)
|
||||
if article_predictions is not None:
|
||||
article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions]
|
||||
else:
|
||||
article_sentiments = []
|
||||
|
||||
# 统计结果
|
||||
comment_counts = Counter(comment_sentiments)
|
||||
article_counts = Counter(article_sentiments)
|
||||
|
||||
X = ['正面', '中性', '负面']
|
||||
X = ['良好', '不良']
|
||||
biedata1 = [{'name': x, 'value': comment_counts.get(x, 0)} for x in X]
|
||||
biedata2 = [{'name': x, 'value': article_counts.get(x, 0)} for x in X]
|
||||
|
||||
return biedata1, biedata2
|
||||
|
||||
def getYuQingCharDataThree():
|
||||
|
||||
Reference in New Issue
Block a user