From f0f43c8e985236887bfdd3a5cda4f2abd5bf6aa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Mon, 30 Sep 2024 09:24:08 +0800 Subject: [PATCH] Integrated training function of CTM model --- model_pro/BERT_CTM.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/model_pro/BERT_CTM.py b/model_pro/BERT_CTM.py index 9be30cd..8bf1c98 100644 --- a/model_pro/BERT_CTM.py +++ b/model_pro/BERT_CTM.py @@ -4,13 +4,20 @@ import torch from tqdm import tqdm import numpy as np import jieba +from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation +from contextualized_topic_models.models.ctm import CombinedTM class BERT_CTM_Model: - def __init__(self, bert_model_path): + def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50): # 加载BERT模型和tokenizer self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) self.model = BertModel.from_pretrained(bert_model_path) + # 创建CTM数据预处理对象 + self.tp = TopicModelDataPreparation(ctm_tokenizer_path) + self.n_components = n_components + self.num_epochs = num_epochs + def get_bert_embeddings(self, texts): """使用BERT模型批量生成文本的嵌入向量""" embeddings = [] @@ -25,8 +32,17 @@ class BERT_CTM_Model: """使用jieba对中文文本进行分词""" return " ".join(jieba.cut(text)) + def train_ctm(self, texts): + """训练CTM模型""" + bow_texts = [self.chinese_tokenize(text) for text in texts] + training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts) + + # 训练CTM + ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs) + ctm.fit(training_dataset) + print("CTM模型训练完成") + if __name__ == "__main__": - model = BERT_CTM_Model('./bert_model') - text = "这是一个测试文本" - tokenized_text = model.chinese_tokenize(text) - print(tokenized_text) + model = BERT_CTM_Model('./bert_model', './sentence_bert_model') + texts = ["这是第一个文本", "这是第二个文本"] + model.train_ctm(texts)