Implement Chinese word segmentation
This commit is contained in:
@@ -3,6 +3,7 @@ from transformers.models.bert import BertTokenizer, BertModel
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import jieba
|
||||
|
||||
class BERT_CTM_Model:
|
||||
def __init__(self, bert_model_path):
|
||||
@@ -20,8 +21,12 @@ class BERT_CTM_Model:
|
||||
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
||||
return np.vstack(embeddings)
|
||||
|
||||
def chinese_tokenize(self, text):
|
||||
"""使用jieba对中文文本进行分词"""
|
||||
return " ".join(jieba.cut(text))
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = BERT_CTM_Model('./bert_model')
|
||||
texts = ["这是第一个文本", "这是第二个文本"]
|
||||
embeddings = model.get_bert_embeddings(texts)
|
||||
print(embeddings.shape)
|
||||
text = "这是一个测试文本"
|
||||
tokenized_text = model.chinese_tokenize(text)
|
||||
print(tokenized_text)
|
||||
|
||||
Reference in New Issue
Block a user