Implement Chinese word segmentation

This commit is contained in:
戒酒的李白
2024-09-30 01:06:03 +08:00
parent 48af69dace
commit 5108ae1254
+8 -3
View File
@@ -3,6 +3,7 @@ from transformers.models.bert import BertTokenizer, BertModel
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
import jieba
class BERT_CTM_Model: class BERT_CTM_Model:
def __init__(self, bert_model_path): def __init__(self, bert_model_path):
@@ -19,9 +20,13 @@ class BERT_CTM_Model:
outputs = self.model(**inputs) outputs = self.model(**inputs)
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
return np.vstack(embeddings) return np.vstack(embeddings)
def chinese_tokenize(self, text):
    """Segment Chinese text with jieba and return the tokens joined by single spaces.

    Args:
        text: The string passed to ``jieba.cut``.

    Returns:
        A single string with the tokens produced by ``jieba.cut``
        separated by one space each.
    """
    # jieba.cut yields tokens lazily; str.join consumes the iterable in one pass.
    segments = jieba.cut(text)
    return " ".join(segments)
if __name__ == "__main__": if __name__ == "__main__":
model = BERT_CTM_Model('./bert_model') model = BERT_CTM_Model('./bert_model')
texts = ["这是一个文本", "这是第二个文本"] text = "这是一个测试文本"
embeddings = model.get_bert_embeddings(texts) tokenized_text = model.chinese_tokenize(text)
print(embeddings.shape) print(tokenized_text)