"""Command-line batch prediction for the BCAT text classifier.

The script reads a CSV file that has a ``TEXT`` column, turns the texts into
embeddings through ``BERT_CTM_Model``, feeds them through the attention layer
and the final classifier restored from a checkpoint, and then writes

* the input rows plus a ``Predicted_Label`` column to a CSV file, and
* a per-label count/percentage summary to a JSON file.

Usage::

    python predict.py INPUT_CSV STATS_JSON
"""

# Reconstructed from the patch "predict.demo built" (commit a4fba83), which
# adds this file as BCAT_front/predict.py.

import json
import os
import sys
import traceback
from collections import Counter

import chardet  # character-encoding detection for the input CSV
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Project-local model definitions.
from MHA import MultiHeadAttentionLayer
from classifier import FinalClassifier
from BERT_CTM import BERT_CTM_Model

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Width passed to FinalClassifier(input_dim=...) and
# MultiHeadAttentionLayer(embed_size=...), and the head count of the latter.
EMBED_DIM = 768
NUM_HEADS = 8

# chardet guesses that are replaced by a superset codec before decoding:
# ASCII is a subset of UTF-8, and GB2312/GBK are subsets of GB18030. Only the
# first bytes of the file are sampled, so the narrower guess may fail on
# characters that appear later in the file.
_ENCODING_FALLBACKS = {
    'ascii': 'utf-8',
    'gb2312': 'gb18030',
    'gbk': 'gb18030',
}


def detect_file_encoding(file_path, num_bytes=10000):
    """Detect the text encoding of a file with chardet.

    :param file_path: path of the file to inspect
    :param num_bytes: number of leading bytes handed to chardet
    :return: name of an encoding usable for decoding the file. chardet's
        guess is returned unless it is missing (``None`` becomes ``'utf-8'``)
        or is a narrow codec listed in ``_ENCODING_FALLBACKS``, in which case
        the superset codec is returned instead.
    """
    with open(file_path, 'rb') as f:
        rawdata = f.read(num_bytes)
    result = chardet.detect(rawdata)
    encoding = result['encoding']
    confidence = result['confidence']
    print(f"Detected encoding: {encoding} with confidence {confidence}")
    if encoding is None:
        # chardet could not decide (e.g. an empty file).
        return 'utf-8'
    return _ENCODING_FALLBACKS.get(encoding.lower(), encoding)


def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20):
    """Return the embeddings produced by ``BERT_CTM_Model`` for *texts*.

    :param texts: list of input strings
    :param bert_model_path: forwarded to ``BERT_CTM_Model(bert_model_path=...)``
    :param ctm_tokenizer_path: forwarded to ``BERT_CTM_Model(ctm_tokenizer_path=...)``
    :param n_components: forwarded to ``BERT_CTM_Model(n_components=...)``
    :param num_epochs: forwarded to ``BERT_CTM_Model(num_epochs=...)``
    :return: whatever ``BERT_CTM_Model.get_bert_embeddings`` returns
    """
    bert_ctm_model = BERT_CTM_Model(
        bert_model_path=bert_model_path,
        ctm_tokenizer_path=ctm_tokenizer_path,
        n_components=n_components,
        num_epochs=num_epochs
    )
    # Load the previously saved CTM model before asking for embeddings.
    bert_ctm_model.load_model()
    embeddings = bert_ctm_model.get_bert_embeddings(texts)
    return embeddings


def prepare_dataloader(features, batch_size):
    """Wrap *features* in a non-shuffling ``DataLoader`` of float32 tensors.

    :param features: array-like (list, NumPy array or tensor) of features
    :param batch_size: number of rows per batch
    :return: ``DataLoader`` yielding one-element batches ``(tensor,)`` in the
        original row order
    """
    # as_tensor accepts an existing tensor without the copy-construct warning
    # that torch.tensor() emits for tensor input.
    tensor_x = torch.as_tensor(features, dtype=torch.float32)
    dataset = TensorDataset(tensor_x)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


def _load_models(model_save_path, num_classes):
    """Restore the classifier and attention models from one checkpoint file."""
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from a trusted source. weights_only=True would be safer
    # and silences torch's FutureWarning, but it requires the checkpoint to
    # hold nothing beyond tensors and plain containers -- confirm before
    # switching.
    checkpoint = torch.load(model_save_path, map_location=device, weights_only=False)

    classifier_model = FinalClassifier(input_dim=EMBED_DIM, num_classes=num_classes)
    classifier_model.load_state_dict(checkpoint['classifier_model_state_dict'])
    classifier_model.to(device)
    classifier_model.eval()

    attention_model = MultiHeadAttentionLayer(embed_size=EMBED_DIM, num_heads=NUM_HEADS)
    attention_model.load_state_dict(checkpoint['attention_model_state_dict'])
    attention_model.to(device)
    attention_model.eval()

    return attention_model, classifier_model


def _read_input(input_data_path):
    """Read the input CSV and return ``(dataframe, list_of_texts)``."""
    encoding = detect_file_encoding(input_data_path)
    data = pd.read_csv(input_data_path, encoding=encoding)
    if 'TEXT' not in data.columns:
        raise ValueError(
            f"Input file {input_data_path} has no 'TEXT' column "
            f"(found: {list(data.columns)})"
        )
    # Empty cells are read as NaN floats; hand the embedding model empty
    # strings instead so that every row still receives a prediction.
    texts = data['TEXT'].fillna('').astype(str).tolist()
    return data, texts


def _predict_labels(embeddings, attention_model, classifier_model, batch_size):
    """Run the attention layer and classifier over *embeddings*, batch by batch."""
    data_loader = prepare_dataloader(embeddings, batch_size)
    all_predictions = []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Predicting"):
            batch_x = batch[0].to(device)
            # Average over dim 1 -- presumably the token axis; TODO confirm
            # against the shape returned by get_bert_embeddings.
            batch_x = torch.mean(batch_x, dim=1)
            # The same tensor is passed for all three attention inputs.
            attention_output = attention_model(batch_x, batch_x, batch_x)
            outputs = classifier_model(attention_output)
            outputs = torch.mean(outputs, dim=1)
            # Index of the largest score along dim 1 is the predicted label.
            _, predicted = torch.max(outputs, 1)
            all_predictions.extend(predicted.cpu().tolist())
    return all_predictions


def _label_statistics(predictions):
    """Summarise *predictions* as ``{label_name: {'count', 'percentage'}}``.

    Label 0 is reported as "良好" (good) and every other label as "不良" (bad).
    Counts are accumulated per name, so several non-zero labels add up rather
    than overwrite each other. Each entry is also printed.
    """
    total_count = len(predictions)
    name_counts = Counter("良好" if label == 0 else "不良" for label in predictions)
    stats = {}
    # most_common() yields the names in descending order of count.
    for label_name, count in name_counts.most_common():
        percentage = (count / total_count) * 100
        stats[label_name] = {
            'count': count,
            'percentage': f"{percentage:.2f}%"
        }
        print(f"Label: {label_name}, Count: {count}, Percentage: {percentage:.2f}%")
    return stats


def _ensure_parent_dir(path):
    """Create the directory that will contain *path* if it does not exist."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path, stats_output_path,
            batch_size=128,
            num_classes=2):
    """Classify every row of a CSV file and write predictions plus statistics.

    :param model_save_path: checkpoint holding ``classifier_model_state_dict``
        and ``attention_model_state_dict``
    :param input_data_path: CSV file with a ``TEXT`` column
    :param output_path: CSV file that receives the input rows plus a
        ``Predicted_Label`` column (written as UTF-8)
    :param bert_model_path: forwarded to ``get_bert_ctm_embeddings``
    :param ctm_tokenizer_path: forwarded to ``get_bert_ctm_embeddings``
    :param stats_output_path: JSON file that receives the per-label summary
    :param batch_size: number of rows per inference batch
    :param num_classes: number of classes of the restored classifier
    :return: ``True`` on success, ``False`` if any step raised an exception
        (the error is printed and the traceback goes to stderr)
    """
    try:
        attention_model, classifier_model = _load_models(model_save_path, num_classes)

        data, texts = _read_input(input_data_path)

        if texts:
            print("Generating embeddings...")
            embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path)
            all_predictions = _predict_labels(embeddings, attention_model, classifier_model, batch_size)
        else:
            # Nothing to classify: still produce (empty) output files.
            all_predictions = []

        # Save the predictions next to the original columns.
        data['Predicted_Label'] = all_predictions
        _ensure_parent_dir(output_path)
        data.to_csv(output_path, index=False, encoding='utf-8')
        print(f"Predictions saved to {output_path}")

        # Count and share of each label, saved as JSON.
        stats = _label_statistics(all_predictions)
        _ensure_parent_dir(stats_output_path)
        with open(stats_output_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False)

        return True  # success
    except Exception as e:
        # Boundary of the script: report the failure through the return value,
        # but keep the traceback available for debugging.
        print(f"Error during prediction: {e}")
        traceback.print_exc()
        return False  # failure


def main(argv):
    """Run the prediction for the paths given on the command line.

    :param argv: full argument vector: script name, input CSV path and
        statistics JSON path
    :return: process exit code, 0 on success and 1 on bad usage or failure
    """
    if len(argv) != 3:
        print("Usage: python predict.py INPUT_CSV STATS_JSON")
        return 1

    input_data_path = argv[1]
    stats_output_path = argv[2]
    # Fixed locations of the model artefacts and of the prediction output.
    model_save_path = 'BCAT/final_model.pt'
    output_path = 'BCAT/predictions.csv'  # file that receives the predictions
    bert_model_path = 'BCAT/bert_model'
    ctm_tokenizer_path = 'BCAT/sentence_bert_model'

    success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path,
                      stats_output_path)
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main(sys.argv))