Implement Chinese word segmentation
This commit is contained in:
@@ -3,6 +3,7 @@ from transformers.models.bert import BertTokenizer, BertModel
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import jieba
|
||||||
|
|
||||||
class BERT_CTM_Model:
|
class BERT_CTM_Model:
|
||||||
def __init__(self, bert_model_path):
|
def __init__(self, bert_model_path):
|
||||||
@@ -19,9 +20,13 @@ class BERT_CTM_Model:
|
|||||||
outputs = self.model(**inputs)
|
outputs = self.model(**inputs)
|
||||||
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
||||||
return np.vstack(embeddings)
|
return np.vstack(embeddings)
|
||||||
|
|
||||||
|
def chinese_tokenize(self, text):
|
||||||
|
"""使用jieba对中文文本进行分词"""
|
||||||
|
return " ".join(jieba.cut(text))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = BERT_CTM_Model('./bert_model')
|
model = BERT_CTM_Model('./bert_model')
|
||||||
texts = ["这是第一个文本", "这是第二个文本"]
|
text = "这是一个测试文本"
|
||||||
embeddings = model.get_bert_embeddings(texts)
|
tokenized_text = model.chinese_tokenize(text)
|
||||||
print(embeddings.shape)
|
print(tokenized_text)
|
||||||
|
|||||||
Reference in New Issue
Block a user