diff --git a/model_pro/BERT_CTM.py b/model_pro/BERT_CTM.py index 3cd29cd..9be30cd 100644 --- a/model_pro/BERT_CTM.py +++ b/model_pro/BERT_CTM.py @@ -3,6 +3,7 @@ from transformers.models.bert import BertTokenizer, BertModel import torch from tqdm import tqdm import numpy as np +import jieba class BERT_CTM_Model: def __init__(self, bert_model_path): @@ -19,9 +20,13 @@ class BERT_CTM_Model: outputs = self.model(**inputs) embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] return np.vstack(embeddings) + + def chinese_tokenize(self, text): + """Segment Chinese text with jieba and return the tokens joined by single spaces.""" + return " ".join(jieba.cut(text)) if __name__ == "__main__": model = BERT_CTM_Model('./bert_model') - texts = ["这是第一个文本", "这是第二个文本"] - embeddings = model.get_bert_embeddings(texts) - print(embeddings.shape) + text = "这是一个测试文本" + tokenized_text = model.chinese_tokenize(text) + print(tokenized_text)