diff --git a/model_pro/BERT_CTM.py b/model_pro/BERT_CTM.py index 8bf1c98..61128b0 100644 --- a/model_pro/BERT_CTM.py +++ b/model_pro/BERT_CTM.py @@ -8,41 +8,87 @@ from contextualized_topic_models.utils.data_preparation import TopicModelDataPre from contextualized_topic_models.models.ctm import CombinedTM class BERT_CTM_Model: - def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50): + def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, device=None): + # 确定设备 (CPU/GPU) + self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") + + # 检查模型路径是否存在 + if not os.path.exists(bert_model_path): + raise ValueError(f"BERT模型路径不存在: {bert_model_path}") + if not os.path.exists(ctm_tokenizer_path): + raise ValueError(f"CTM分词器路径不存在: {ctm_tokenizer_path}") + # 加载BERT模型和tokenizer self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) - self.model = BertModel.from_pretrained(bert_model_path) + self.model = BertModel.from_pretrained(bert_model_path).to(self.device) # 创建CTM数据预处理对象 self.tp = TopicModelDataPreparation(ctm_tokenizer_path) self.n_components = n_components self.num_epochs = num_epochs + self.ctm_model = None def get_bert_embeddings(self, texts): """使用BERT模型批量生成文本的嵌入向量""" embeddings = [] for text in tqdm(texts, desc="Processing texts with BERT"): - inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] return np.vstack(embeddings) - + def chinese_tokenize(self, text): """使用jieba对中文文本进行分词""" return " ".join(jieba.cut(text)) def train_ctm(self, texts): """训练CTM模型""" - bow_texts = [self.chinese_tokenize(text) for text in texts] - training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts) + try: + # 分词并准备BOW文本 + bow_texts = [self.chinese_tokenize(text) for text in texts] + training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts) - # 训练CTM - ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs) - ctm.fit(training_dataset) - print("CTM模型训练完成") + # 训练CTM + self.ctm_model = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, + n_components=self.n_components, num_epochs=self.num_epochs) + self.ctm_model.fit(training_dataset) + print("CTM模型训练完成") + except Exception as e: + print(f"训练CTM模型时发生错误: {e}") + + def save_model(self, path): + """保存训练后的CTM模型""" + if self.ctm_model: + self.ctm_model.save(path) + print(f"CTM模型已保存至: {path}") + else: + print("未找到已训练的CTM模型,无法保存") + + def load_model(self, path): + """加载已保存的CTM模型""" + if os.path.exists(path): + self.ctm_model = CombinedTM.load(path) + print(f"CTM模型已加载自: {path}") + else: + print(f"无法加载模型,路径不存在: {path}") if __name__ == "__main__": - model = BERT_CTM_Model('./bert_model', './sentence_bert_model') + # 设定BERT和CTM模型的路径 + bert_model_path = './bert_model' + ctm_tokenizer_path = './sentence_bert_model' + + # 初始化模型 + model = BERT_CTM_Model(bert_model_path, ctm_tokenizer_path) + + # 示例文本 texts = ["这是第一个文本", "这是第二个文本"] + + # 训练CTM模型 model.train_ctm(texts) + + # 保存CTM模型 + model.save_model('./trained_ctm_model') + + # 加载CTM模型 + model.load_model('./trained_ctm_model')