From 8c0479a978c62c98560df0689e8ec865a62930e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Mon, 30 Sep 2024 00:14:40 +0800 Subject: [PATCH] Test the BERT model for Chinese simulation embedding --- model_pro/BERT_CTM.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 model_pro/BERT_CTM.py diff --git a/model_pro/BERT_CTM.py b/model_pro/BERT_CTM.py new file mode 100644 index 0000000..5aa12ed --- /dev/null +++ b/model_pro/BERT_CTM.py @@ -0,0 +1,22 @@ +import os +from transformers.models.bert import BertTokenizer, BertModel +import torch + +class BERT_CTM_Model: + def __init__(self, bert_model_path): + # 加载BERT模型和tokenizer + self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) + self.model = BertModel.from_pretrained(bert_model_path) + + def get_bert_embeddings(self, text): + """使用BERT模型生成文本的嵌入向量""" + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) + with torch.no_grad(): + outputs = self.model(**inputs) + return outputs.last_hidden_state.cpu().numpy() # [batch_size, sequence_length, hidden_size] + +if __name__ == "__main__": + model = BERT_CTM_Model('./bert_model') + text = "这是一个测试文本" + embedding = model.get_bert_embeddings(text) + print(embedding.shape)