The old emotion recognition model has been replaced with the new model_pro, and the results have been integrated into the project.

This commit is contained in:
戒酒的李白
2025-02-04 21:03:45 +08:00
parent a9108a909c
commit 826de6184d
3 changed files with 184 additions and 64 deletions
+56 -35
View File
@@ -6,12 +6,12 @@ from tqdm import tqdm
import os
import sys
import json
import chardet # 导入 chardet
import chardet
# 导入您定义的模型和模块
from MHA import MultiHeadAttentionLayer
from classifier import FinalClassifier
from BERT_CTM import BERT_CTM_Model
# 导入改进版模型的组件
from model_pro.MHA import MultiHeadAttentionLayer
from model_pro.classifier import FinalClassifier
from model_pro.BERT_CTM import BERT_CTM_Model
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -30,7 +30,7 @@ def detect_file_encoding(file_path, num_bytes=10000):
result = chardet.detect(rawdata)
encoding = result['encoding']
confidence = result['confidence']
print(f"Detected encoding: {encoding} with confidence {confidence}")
print(f"检测到的编码: {encoding}, 置信度: {confidence}")
return encoding
@@ -42,8 +42,6 @@ def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_compon
n_components=n_components,
num_epochs=num_epochs
)
# 加载已保存的CTM模型
bert_ctm_model.load_model()
# 获取嵌入
embeddings = bert_ctm_model.get_bert_embeddings(texts)
return embeddings
@@ -60,15 +58,11 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
num_classes=2):
try:
# 加载模型
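# NOTE: weights_only is passed explicitly to address torch.load's FutureWarning; the call below passes weights_only=False (not True), so the checkpoint is loaded via full unpickling.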
checkpoint = torch.load(model_save_path, map_location=device, weights_only=False)
classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes)
classifier_model.load_state_dict(checkpoint['classifier_model_state_dict'])
classifier_model.to(device)
print("加载模型...")
classifier_model = torch.load(model_save_path, map_location=device)
classifier_model.eval()
attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
attention_model.load_state_dict(checkpoint['attention_model_state_dict'])
attention_model.to(device)
attention_model.eval()
@@ -76,11 +70,12 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
encoding = detect_file_encoding(input_data_path)
# 读取输入数据
print("读取输入数据...")
data = pd.read_csv(input_data_path, encoding=encoding)
texts = data['TEXT'].tolist()
# 生成嵌入
print("Generating embeddings...")
print("生成文本嵌入...")
embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path)
# 准备DataLoader
@@ -88,63 +83,89 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
# 存储预测结果
all_predictions = []
all_probabilities = []
print("开始预测...")
with torch.no_grad():
for batch in tqdm(data_loader, desc="Predicting"):
for batch in tqdm(data_loader, desc="预测进度"):
batch_x = batch[0].to(device)
batch_x = torch.mean(batch_x, dim=1)
# 使用注意力机制
attention_output = attention_model(batch_x, batch_x, batch_x)
# 获取分类结果
outputs = classifier_model(attention_output)
outputs = torch.mean(outputs, dim=1)
# 获取预测概率
probabilities = torch.softmax(outputs, dim=1)
# 获取预测标签
_, predicted = torch.max(outputs, 1)
all_predictions.extend(predicted.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# 添加预测结果和概率到数据框
data['Predicted_Label'] = all_predictions
data['Confidence'] = [prob[pred] for prob, pred in zip(all_probabilities, all_predictions)]
# 保存预测结果
data['Predicted_Label'] = all_predictions
data.to_csv(output_path, index=False, encoding='utf-8')
print(f"Predictions saved to {output_path}")
print(f"预测结果已保存到 {output_path}")
# 统计标签的个数和占比
label_counts = data['Predicted_Label'].value_counts()
total_count = len(data)
stats = {}
stats = {
'统计信息': {
'总样本数': total_count,
'各类别统计': {}
}
}
for label, count in label_counts.items():
label_name = "良好" if label == 0 else "不良"
percentage = (count / total_count) * 100
stats[label_name] = {
'count': count,
'percentage': f"{percentage:.2f}%"
confidence_mean = data[data['Predicted_Label'] == label]['Confidence'].mean()
stats['统计信息']['各类别统计'][label_name] = {
'数量': int(count),
'占比': f"{percentage:.2f}%",
'平均置信度': f"{confidence_mean:.2f}"
}
print(f"Label: {label_name}, Count: {count}, Percentage: {percentage:.2f}%")
print(f"标签: {label_name}, 数量: {count}, 占比: {percentage:.2f}%, 平均置信度: {confidence_mean:.2f}")
# 将统计信息保存到 JSON 文件
with open(stats_output_path, 'w', encoding='utf-8') as f:
json.dump(stats, f, ensure_ascii=False)
json.dump(stats, f, ensure_ascii=False, indent=4)
return True # 成功执行
return True
except Exception as e:
print(f"Error during prediction: {e}")
return False # 执行失败
print(f"预测过程中出现错误: {e}")
return False
# Script entry point: requires exactly two command-line arguments
# (<input_data_path> and <stats_output_path>), calls predict() with fixed
# model/output paths, and exits with status 0 on success or 1 on failure
# (also 1 when the argument count is wrong).
# NOTE(review): several statements below appear twice in near-identical form
# (e.g. 'BCAT/...' vs 'model_pro/...' paths); this looks like the old and new
# sides of a diff rendered without +/- markers — confirm which lines are live.
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: python using_example.py <input_data_path> <stats_output_path>")
print("使用方法: python predict.py <input_data_path> <stats_output_path>")
sys.exit(1)
input_data_path = sys.argv[1]
stats_output_path = sys.argv[2]
# Define paths
model_save_path = 'BCAT/final_model.pt'
output_path = 'BCAT/predictions.csv' # file where the prediction results are saved
bert_model_path = 'BCAT/bert_model'
ctm_tokenizer_path = 'BCAT/sentence_bert_model'
model_save_path = 'model_pro/final_model.pt'
output_path = 'model_pro/predictions.csv'
bert_model_path = 'model_pro/bert_model'
ctm_tokenizer_path = 'model_pro/sentence_bert_model'
# Run prediction
success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path,
stats_output_path)
stats_output_path)
if success:
sys.exit(0) # success
sys.exit(0)
else:
sys.exit(1) # failure
sys.exit(1)