From 826de6184d134a7764007ff64dce0869243c44d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Tue, 4 Feb 2025 21:03:45 +0800 Subject: [PATCH] The old emotion recognition model has been replaced with the new model_pro, and the results have been integrated into the project. --- BCAT_front/predict.py | 91 +++++++++++++++++++++++++---------------- utils/getEchartsData.py | 91 +++++++++++++++++++++++++++++++---------- views/page/page.py | 66 ++++++++++++++++++++++++++---- 3 files changed, 184 insertions(+), 64 deletions(-) diff --git a/BCAT_front/predict.py b/BCAT_front/predict.py index acb99c7..5e0e1ad 100644 --- a/BCAT_front/predict.py +++ b/BCAT_front/predict.py @@ -6,12 +6,12 @@ from tqdm import tqdm import os import sys import json -import chardet # 导入 chardet +import chardet -# 导入您定义的模型和模块 -from MHA import MultiHeadAttentionLayer -from classifier import FinalClassifier -from BERT_CTM import BERT_CTM_Model +# 导入改进版模型的组件 +from model_pro.MHA import MultiHeadAttentionLayer +from model_pro.classifier import FinalClassifier +from model_pro.BERT_CTM import BERT_CTM_Model # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -30,7 +30,7 @@ def detect_file_encoding(file_path, num_bytes=10000): result = chardet.detect(rawdata) encoding = result['encoding'] confidence = result['confidence'] - print(f"Detected encoding: {encoding} with confidence {confidence}") + print(f"检测到的编码: {encoding}, 置信度: {confidence}") return encoding @@ -42,8 +42,6 @@ def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_compon n_components=n_components, num_epochs=num_epochs ) - # 加载已保存的CTM模型 - bert_ctm_model.load_model() # 获取嵌入 embeddings = bert_ctm_model.get_bert_embeddings(texts) return embeddings @@ -60,15 +58,11 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ num_classes=2): try: # 加载模型 - # 修改这里,设置 weights_only=True 以消除 FutureWarning - checkpoint = torch.load(model_save_path, map_location=device, weights_only=False) - classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes) - classifier_model.load_state_dict(checkpoint['classifier_model_state_dict']) - classifier_model.to(device) + print("加载模型...") + classifier_model = torch.load(model_save_path, map_location=device) classifier_model.eval() attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) - attention_model.load_state_dict(checkpoint['attention_model_state_dict']) attention_model.to(device) attention_model.eval() @@ -76,11 +70,12 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ encoding = detect_file_encoding(input_data_path) # 读取输入数据 + print("读取输入数据...") data = pd.read_csv(input_data_path, encoding=encoding) texts = data['TEXT'].tolist() # 生成嵌入 - print("Generating embeddings...") + print("生成文本嵌入...") embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path) # 准备DataLoader @@ -88,63 +83,89 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ # 存储预测结果 all_predictions = [] + all_probabilities = [] + print("开始预测...") with torch.no_grad(): - for batch in tqdm(data_loader, desc="Predicting"): + for batch in tqdm(data_loader, desc="预测进度"): batch_x = batch[0].to(device) batch_x = torch.mean(batch_x, dim=1) + + # 使用注意力机制 attention_output = attention_model(batch_x, batch_x, batch_x) + + # 获取分类结果 outputs = classifier_model(attention_output) outputs = torch.mean(outputs, dim=1) + + # 获取预测概率 + probabilities = torch.softmax(outputs, dim=1) + + # 获取预测标签 _, predicted = torch.max(outputs, 1) + all_predictions.extend(predicted.cpu().numpy()) + all_probabilities.extend(probabilities.cpu().numpy()) + + # 添加预测结果和概率到数据框 + data['Predicted_Label'] = all_predictions + data['Confidence'] = [prob[pred] for prob, pred in zip(all_probabilities, all_predictions)] # 保存预测结果 - data['Predicted_Label'] = all_predictions data.to_csv(output_path, index=False, encoding='utf-8') - print(f"Predictions saved to {output_path}") + print(f"预测结果已保存到 {output_path}") # 统计标签的个数和占比 label_counts = data['Predicted_Label'].value_counts() total_count = len(data) - stats = {} + stats = { + '统计信息': { + '总样本数': total_count, + '各类别统计': {} + } + } + for label, count in label_counts.items(): label_name = "良好" if label == 0 else "不良" percentage = (count / total_count) * 100 - stats[label_name] = { - 'count': count, - 'percentage': f"{percentage:.2f}%" + confidence_mean = data[data['Predicted_Label'] == label]['Confidence'].mean() + + stats['统计信息']['各类别统计'][label_name] = { + '数量': int(count), + '占比': f"{percentage:.2f}%", + '平均置信度': f"{confidence_mean:.2f}" } - print(f"Label: {label_name}, Count: {count}, Percentage: {percentage:.2f}%") + print(f"标签: {label_name}, 数量: {count}, 占比: {percentage:.2f}%, 平均置信度: {confidence_mean:.2f}") # 将统计信息保存到 JSON 文件 with open(stats_output_path, 'w', encoding='utf-8') as f: - json.dump(stats, f, ensure_ascii=False) + json.dump(stats, f, ensure_ascii=False, indent=4) - return True # 成功执行 + return True except Exception as e: - print(f"Error during prediction: {e}") - return False # 执行失败 + print(f"预测过程中出现错误: {e}") + return False if __name__ == "__main__": if len(sys.argv) != 3: - print("Usage: python using_example.py ") + print("使用方法: python predict.py ") sys.exit(1) input_data_path = sys.argv[1] stats_output_path = sys.argv[2] + # 定义路径 - model_save_path = 'BCAT/final_model.pt' - output_path = 'BCAT/predictions.csv' # 保存预测结果的文件 - bert_model_path = 'BCAT/bert_model' - ctm_tokenizer_path = 'BCAT/sentence_bert_model' + model_save_path = 'model_pro/final_model.pt' + output_path = 'model_pro/predictions.csv' + bert_model_path = 'model_pro/bert_model' + ctm_tokenizer_path = 'model_pro/sentence_bert_model' # 执行预测 success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path, - stats_output_path) + stats_output_path) if success: - sys.exit(0) # 成功 + sys.exit(0) else: - sys.exit(1) # 失败 + sys.exit(1) diff --git a/utils/getEchartsData.py b/utils/getEchartsData.py index e7a0f9c..6a89525 100644 --- a/utils/getEchartsData.py +++ b/utils/getEchartsData.py @@ -1,10 +1,59 @@ from utils.getPublicData import * # Import utility functions for data retrieval from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis from collections import Counter # Import Counter for counting occurrences +import torch +from model_pro.MHA import MultiHeadAttentionLayer +from model_pro.classifier import FinalClassifier +from model_pro.BERT_CTM import BERT_CTM_Model articleList = getAllArticleData() # Retrieve all article data commentList = getAllCommentsData() # Retrieve all comment data +# 设置设备 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# 加载模型(全局变量,避免重复加载) +model_save_path = 'model_pro/final_model.pt' +bert_model_path = 'model_pro/bert_model' +ctm_tokenizer_path = 'model_pro/sentence_bert_model' + +try: + classifier_model = torch.load(model_save_path, map_location=device) + classifier_model.eval() + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) + attention_model.to(device) + attention_model.eval() + bert_ctm_model = BERT_CTM_Model( + bert_model_path=bert_model_path, + ctm_tokenizer_path=ctm_tokenizer_path + ) +except Exception as e: + print(f"模型加载失败: {e}") + +def predict_sentiment(texts): + """使用改进版模型预测情感""" + try: + # 获取文本嵌入 + embeddings = bert_ctm_model.get_bert_embeddings(texts) + + # 转换为tensor + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) + batch_x = torch.mean(batch_x, dim=1) + + with torch.no_grad(): + # 使用注意力机制 + attention_output = attention_model(batch_x, batch_x, batch_x) + # 获取分类结果 + outputs = classifier_model(attention_output) + outputs = torch.mean(outputs, dim=1) + # 获取预测标签 + _, predicted = torch.max(outputs, 1) + + return predicted.cpu().numpy() + except Exception as e: + print(f"预测过程中出现错误: {e}") + return None + def getTypeList(): # Return a list of unique article types return list(set([x[8] for x in articleList])) @@ -119,32 +168,32 @@ def getYuQingCharDataOne(): return X, Y, biedata def getYuQingCharDataTwo(): - # Analyze sentiment of comments and articles - comment_sentiments = [] - for comment in commentList: - emotionValue = SnowNLP(comment[4]).sentiments - if emotionValue > 0.4: - comment_sentiments.append('正面') - elif emotionValue < 0.2: - comment_sentiments.append('负面') - else: - comment_sentiments.append('中性') - comment_counts = Counter(comment_sentiments) + # 分析评论和文章的情感 + comment_texts = [comment[4] for comment in commentList] + article_texts = [article[5] for article in articleList] - article_sentiments = [] - for article in articleList: - emotionValue = SnowNLP(article[5]).sentiments - if emotionValue > 0.4: - article_sentiments.append('正面') - elif emotionValue < 0.2: - article_sentiments.append('负面') - else: - article_sentiments.append('中性') + # 预测评论情感 + comment_predictions = predict_sentiment(comment_texts) + if comment_predictions is not None: + comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions] + else: + comment_sentiments = [] + + # 预测文章情感 + article_predictions = predict_sentiment(article_texts) + if article_predictions is not None: + article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions] + else: + article_sentiments = [] + + # 统计结果 + comment_counts = Counter(comment_sentiments) article_counts = Counter(article_sentiments) - X = ['正面', '中性', '负面'] + X = ['良好', '不良'] biedata1 = [{'name': x, 'value': comment_counts.get(x, 0)} for x in X] biedata2 = [{'name': x, 'value': article_counts.get(x, 0)} for x in X] + return biedata1, biedata2 def getYuQingCharDataThree(): diff --git a/views/page/page.py b/views/page/page.py index 0fb8d7b..658c601 100644 --- a/views/page/page.py +++ b/views/page/page.py @@ -8,12 +8,61 @@ from utils.getEchartsData import * from utils.getTopicPageData import * from utils.yuqingpredict import * from utils.logger import app_logger as logging +import torch +from model_pro.MHA import MultiHeadAttentionLayer +from model_pro.classifier import FinalClassifier +from model_pro.BERT_CTM import BERT_CTM_Model pb = Blueprint('page', __name__, url_prefix='/page', template_folder='templates') +# 设置设备 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# 加载模型(全局变量,避免重复加载) +model_save_path = 'model_pro/final_model.pt' +bert_model_path = 'model_pro/bert_model' +ctm_tokenizer_path = 'model_pro/sentence_bert_model' + +try: + classifier_model = torch.load(model_save_path, map_location=device) + classifier_model.eval() + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) + attention_model.to(device) + attention_model.eval() + bert_ctm_model = BERT_CTM_Model( + bert_model_path=bert_model_path, + ctm_tokenizer_path=ctm_tokenizer_path + ) +except Exception as e: + print(f"模型加载失败: {e}") + +def predict_sentiment(text): + """使用改进版模型预测单个文本的情感""" + try: + # 获取文本嵌入 + embeddings = bert_ctm_model.get_bert_embeddings([text]) + + # 转换为tensor + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) + batch_x = torch.mean(batch_x, dim=1) + + with torch.no_grad(): + # 使用注意力机制 + attention_output = attention_model(batch_x, batch_x, batch_x) + # 获取分类结果 + outputs = classifier_model(attention_output) + outputs = torch.mean(outputs, dim=1) + # 获取预测标签和概率 + probabilities = torch.softmax(outputs, dim=1) + _, predicted = torch.max(outputs, 1) + + return predicted.item(), probabilities[0][predicted.item()].item() + except Exception as e: + print(f"预测过程中出现错误: {e}") + return None, None @pb.route('/home') def home(): @@ -172,14 +221,15 @@ def yuqingpredict(): defaultTopic = request.args.get('Topic') TopicLen = getTopicLen(defaultTopic) X, Y = getTopicCreatedAtandpredictData(defaultTopic) - sentences = '' - value = SnowNLP(defaultTopic).sentiments - if value == 0.5: - sentences = '中性' - elif value > 0.5: - sentences = '正面' - elif value < 0.5: - sentences = '负面' + + # 使用改进版模型进行情感预测 + predicted_label, confidence = predict_sentiment(defaultTopic) + if predicted_label is not None: + sentences = '良好' if predicted_label == 0 else '不良' + sentences = f"{sentences} (置信度: {confidence:.2f})" + else: + sentences = '预测失败' + comments = getCommentFilterDataTopic(defaultTopic) return render_template('yuqingpredict.html', username=username,