From b49d16ab071c483e3f2de817bd88296f0b71403c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Fri, 4 Oct 2024 22:18:54 +0800 Subject: [PATCH] BCAT Preliminary --- model_pro/BCAT.py | 229 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 201 insertions(+), 28 deletions(-) diff --git a/model_pro/BCAT.py b/model_pro/BCAT.py index d3ac673..4f916cc 100644 --- a/model_pro/BCAT.py +++ b/model_pro/BCAT.py @@ -1,51 +1,224 @@ -import os +import torch +import pandas as pd import numpy as np -from BERT_CTM import BERT_CTM_Model # 假设BERT_CTM模型在这个文件中 +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score +from torch.utils.data import DataLoader, TensorDataset +from CNN import extract_CNN_features +from MHA import MultiHeadAttentionLayer +from classifier import FinalClassifier +from BERT_CTM import BERT_CTM_Model +import os +from tqdm import tqdm +from sklearn.metrics import confusion_matrix -# BERT_CTM 嵌入生成和加载函数 + +# BERT_CTM embeddings generation and loading function def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None): - """ - 获取或生成 BERT+CTM 嵌入,并保存到文件。 - - :param texts: 需要嵌入的文本 - :param bert_model_path: BERT 模型的路径 - :param ctm_tokenizer_path: CTM tokenizer 的路径 - :param n_components: 生成的主题数量 - :param num_epochs: 训练的epoch数 - :param save_path: 嵌入保存路径 - :return: 生成或加载的嵌入 - """ - # 检查是否已经存在保存的嵌入文件 + # Check if saved embeddings already exist if save_path and os.path.exists(save_path): - print(f"从文件 {save_path} 加载嵌入...") + print(f"Loading embeddings from {save_path}...") embeddings = np.load(save_path) else: - print("生成 BERT+CTM 嵌入...") + print("Generating BERT+CTM embeddings...") bert_ctm_model = BERT_CTM_Model( bert_model_path=bert_model_path, ctm_tokenizer_path=ctm_tokenizer_path, n_components=n_components, num_epochs=num_epochs ) - embeddings = bert_ctm_model.train(texts) # 生成嵌入 + embeddings = bert_ctm_model.train(texts) # Generate embeddings - # 保存嵌入到文件 + # Save embeddings to file if save_path: - print(f"保存嵌入到文件 {save_path}...") + print(f"Saving embeddings to file {save_path}...") np.save(save_path, embeddings) return embeddings +# Data loading and preparation function +def prepare_dataloader(features, labels, batch_size): + """Create DataLoader for training, validation, and testing""" + tensor_x = torch.tensor(features, dtype=torch.float32) + tensor_y = torch.tensor(labels, dtype=torch.long) + dataset = TensorDataset(tensor_x, tensor_y) + return DataLoader(dataset, batch_size=batch_size, shuffle=True) + + +# Model training function +def train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, + bert_model_path, ctm_tokenizer_path, num_heads=8, num_classes=2, epochs=10, batch_size=128, + learning_rate=5e-3, model_save_path='./final_model.pt'): + # Step 1: Get BERT+CTM embeddings + print("Step 1: Getting BERT+CTM embeddings...") + valid_features = get_bert_ctm_embeddings(valid_data_path, bert_model_path, ctm_tokenizer_path, + save_path='valid_embeddings.npy') + test_features = get_bert_ctm_embeddings(test_data_path, bert_model_path, ctm_tokenizer_path, + save_path='test_embeddings.npy') + train_features = get_bert_ctm_embeddings(train_data_path, bert_model_path, ctm_tokenizer_path, + save_path='train_embeddings.npy') + + # Save labels to .npy file + print("Saving labels to labels.npy file...") + np.save('train_labels.npy', train_labels) + np.save('valid_labels.npy', valid_labels) + np.save('test_labels.npy', test_labels) + + # Step 2: Validate label correctness + print("Step 2: Validating label correctness...") + unique_labels_train = np.unique(train_labels) + unique_labels_valid = np.unique(valid_labels) + unique_labels_test = np.unique(test_labels) + print(f"Unique train labels: {unique_labels_train}") + print(f"Train set class distribution: {np.bincount(train_labels)}") + print(f"Unique validation labels: {unique_labels_valid}") + print(f"Validation set class distribution: {np.bincount(valid_labels)}") + print(f"Unique test labels: {unique_labels_test}") + print(f"Test set class distribution: {np.bincount(test_labels)}") + + if len(unique_labels_train) != num_classes or len(unique_labels_valid) != num_classes or len( + unique_labels_test) != num_classes: + raise ValueError(f"Number of classes in labels does not match expected: expected {num_classes}, " + f"but found different classes in training, validation, or test sets") + + # Step 3: Create DataLoader + print("Step 3: Creating DataLoader...") + train_loader = prepare_dataloader(train_features, train_labels, batch_size) + valid_loader = prepare_dataloader(valid_features, valid_labels, batch_size) + test_loader = prepare_dataloader(test_features, test_labels, batch_size) + + # Step 4: Initialize CNN + print("Step 4: Initializing CNN...") + num_filters = 256 # Use 256 convolutional output channels + kernel_sizes = [2, 3, 4] # Kernel sizes for convolution + k = 3 * len(kernel_sizes) + cnn_output_dim = num_filters * (k + 1) # Calculate the output feature dimension of CNN + + # Step 5: Initialize attention mechanism + print("Step 5: Initializing multi-head attention...") + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) + + # Step 6: Initialize classifier + print("Step 6: Initializing classifier...") + classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes) + optimizer = torch.optim.Adam(classifier_model.parameters(), lr=learning_rate) + criterion = torch.nn.CrossEntropyLoss() + + # Step 7: Start training + print("Starting training...") + torch.autograd.set_detect_anomaly(True) + for epoch in range(epochs): + classifier_model.train() + epoch_loss = 0 + y_true = [] + y_pred = [] + + # Use tqdm to add progress bar for CNN feature extraction + for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"): + optimizer.zero_grad() + batch_x = torch.mean(batch_x, dim=1) + # Extract features from CNN + # cnn_output = extract_CNN_features(batch_x) + # batch_x = torch.mean(batch_x, dim=1) + # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) + attention_output = attention_model(batch_x, batch_x, batch_x) + outputs = classifier_model(attention_output) + outputs = torch.mean(outputs, dim=1) + loss = criterion(outputs, batch_y) # Compute loss + loss.backward() # Backpropagation + optimizer.step() # Optimize + + epoch_loss += loss.item() + + _, predicted = torch.max(outputs, 1) # Get predicted class + y_true.extend(batch_y.tolist()) + y_pred.extend(predicted.tolist()) + + # Calculate training accuracy, precision, recall, and F1 score + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, average='macro') + recall = recall_score(y_true, y_pred, average='macro') + f1 = f1_score(y_true, y_pred, average='macro') + + print( + f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") + print(confusion_matrix(y_true, y_pred)) + + # Save model + torch.save(classifier_model, model_save_path) + print(f"Trained model has been saved to {model_save_path}") + + # Validation set evaluation + classifier_model.eval() + y_true = [] + y_pred = [] + + with torch.no_grad(): + for batch_x, batch_y in valid_loader: + batch_x = torch.mean(batch_x, dim=1) + # cnn_output = extract_CNN_features(batch_x) + # batch_x = torch.mean(batch_x, dim=1) + # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) + attention_output = attention_model(batch_x, batch_x, batch_x) + outputs = classifier_model(attention_output) + outputs = torch.mean(outputs, dim=1) + _, predicted = torch.max(outputs, 1) + y_true.extend(batch_y.tolist()) + y_pred.extend(predicted.tolist()) + + # Validation accuracy, precision, recall, and F1 score + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, average='macro') + recall = recall_score(y_true, y_pred, average='macro') + f1 = f1_score(y_true, y_pred, average='macro') + + print(f"\nValidation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") + print(confusion_matrix(y_true, y_pred)) + + # Test set evaluation + y_true = [] + y_pred = [] + + with torch.no_grad(): + for batch_x, batch_y in test_loader: + batch_x = torch.mean(batch_x, dim=1) + # cnn_output = extract_CNN_features(batch_x) + # batch_x = torch.mean(batch_x, dim=1) + # cnn_output = torch.cat((batch_x, cnn_output), dim=-1) + attention_output = attention_model(batch_x, batch_x, batch_x) + outputs = classifier_model(attention_output) + outputs = torch.mean(outputs, dim=1) + _, predicted = torch.max(outputs, 1) + y_true.extend(batch_y.tolist()) + y_pred.extend(predicted.tolist()) + # Test accuracy, precision, recall, and F1 score + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, average='macro') + recall = recall_score(y_true, y_pred, average='macro') + f1 = f1_score(y_true, y_pred, average='macro') + + print(f"\nTest - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}") + print(confusion_matrix(y_true, y_pred)) + + if __name__ == "__main__": - # 示例调用 - sample_texts = ["This is a test text.", "Another example of text data."] + # Load and prepare data + train_data_path = './train.csv' + valid_data_path = './dev.csv' + test_data_path = './test.csv' + + train_data = pd.read_csv(train_data_path) + valid_data = pd.read_csv(valid_data_path) + test_data = pd.read_csv(test_data_path) + + train_labels = train_data['label'].values + valid_labels = valid_data['label'].values + test_labels = test_data['label'].values + + # Train model bert_model_path = './bert_model' ctm_tokenizer_path = './sentence_bert_model' - save_path = 'sample_embeddings.npy' - # 生成或加载 BERT+CTM 嵌入 - embeddings = get_bert_ctm_embeddings(sample_texts, bert_model_path, ctm_tokenizer_path, save_path=save_path) - - # 打印嵌入形状 - print(f"嵌入形状: {embeddings.shape}") + # Train model + train_model(train_data_path, valid_data_path, test_data_path, train_labels, valid_labels, test_labels, + bert_model_path, ctm_tokenizer_path, num_heads=12, num_classes=2, model_save_path='./final_model.pt')