Batch processing text embedding tests
This commit is contained in:
+14
-9
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
from transformers.models.bert import BertTokenizer, BertModel
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
class BERT_CTM_Model:
|
||||
def __init__(self, bert_model_path):
|
||||
@@ -8,15 +10,18 @@ class BERT_CTM_Model:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||
self.model = BertModel.from_pretrained(bert_model_path)
|
||||
|
||||
def get_bert_embeddings(self, text):
|
||||
"""使用BERT模型生成文本的嵌入向量"""
|
||||
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
return outputs.last_hidden_state.cpu().numpy() # [batch_size, sequence_length, hidden_size]
|
||||
def get_bert_embeddings(self, texts):
|
||||
"""使用BERT模型批量生成文本的嵌入向量"""
|
||||
embeddings = []
|
||||
for text in tqdm(texts, desc="Processing texts with BERT"):
|
||||
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
||||
return np.vstack(embeddings)
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = BERT_CTM_Model('./bert_model')
|
||||
text = "这是一个测试文本"
|
||||
embedding = model.get_bert_embeddings(text)
|
||||
print(embedding.shape)
|
||||
texts = ["这是第一个文本", "这是第二个文本"]
|
||||
embeddings = model.get_bert_embeddings(texts)
|
||||
print(embeddings.shape)
|
||||
|
||||
Reference in New Issue
Block a user