diff --git a/model_pro/BERT_CTM.py b/model_pro/BERT_CTM.py index 5aa12ed..3cd29cd 100644 --- a/model_pro/BERT_CTM.py +++ b/model_pro/BERT_CTM.py @@ -1,6 +1,8 @@ import os from transformers.models.bert import BertTokenizer, BertModel import torch +from tqdm import tqdm +import numpy as np class BERT_CTM_Model: def __init__(self, bert_model_path): @@ -8,15 +10,18 @@ class BERT_CTM_Model: self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) self.model = BertModel.from_pretrained(bert_model_path) - def get_bert_embeddings(self, text): - """使用BERT模型生成文本的嵌入向量""" - inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) - with torch.no_grad(): - outputs = self.model(**inputs) - return outputs.last_hidden_state.cpu().numpy() # [batch_size, sequence_length, hidden_size] + def get_bert_embeddings(self, texts): + """使用BERT模型批量生成文本的嵌入向量""" + embeddings = [] + for text in tqdm(texts, desc="Processing texts with BERT"): + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) + with torch.no_grad(): + outputs = self.model(**inputs) + embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] + return np.vstack(embeddings) if __name__ == "__main__": model = BERT_CTM_Model('./bert_model') - text = "这是一个测试文本" - embedding = model.get_bert_embeddings(text) - print(embedding.shape) + texts = ["这是第一个文本", "这是第二个文本"] + embeddings = model.get_bert_embeddings(texts) + print(embeddings.shape)