import os

# Disable HuggingFace tokenizers parallelism; kept ahead of the
# `transformers` import, as in the original statement order.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
from tqdm import tqdm
from transformers.models.bert import BertTokenizer, BertModel
from contextualized_topic_models.models.ctm import CombinedTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing
import numpy as np
import torch
import jieba
import pickle  # used to save and load the model
from utils.logger import model_logger as logging


class BERT_CTM:
    """Train, persist and query a CombinedTM topic model fed with BERT embeddings.

    Documents are embedded with the ``[CLS]`` vector of ``bert-base-chinese``
    and paired with a bag-of-words built from jieba-segmented text. The
    trained CTM model, its vocabulary and its vectorizer are pickled to
    ``model_save_path``.
    """

    def __init__(self, model_save_path='model_pro/saved_models/ctm_model.pkl'):
        """Create an empty wrapper; nothing is loaded until train/load_model.

        Args:
            model_save_path: File the pickled model bundle is written to and
                read from.
        """
        self.model_save_path = model_save_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert_model = None
        self.tokenizer = None
        self.ctm_model = None
        self.vocab = None
        self.vectorizer = None

    def save_model(self):
        """Pickle the CTM model, vocabulary and vectorizer to ``model_save_path``.

        Failures are logged rather than raised, so a finished training run is
        not lost to an I/O error; the model stays available on the instance.

        Returns:
            bool: True if the bundle was written, False otherwise.
        """
        try:
            # open() does not create missing parent directories.
            save_dir = os.path.dirname(self.model_save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            with open(self.model_save_path, 'wb') as f:
                pickle.dump({
                    'ctm_model': self.ctm_model,
                    'vocab': self.vocab,
                    'vectorizer': self.vectorizer,
                }, f)
        except Exception as e:
            logging.error(f"保存模型时发生错误: {e}")
            return False
        logging.info(f"CTM模型和词袋保存到: {self.model_save_path}")
        return True

    def load_model(self):
        """Restore the CTM model, vocabulary and vectorizer from ``model_save_path``.

        Raises:
            Exception: Any error from opening or unpickling the file, or from
                a missing key in the bundle, is logged and re-raised.
        """
        try:
            # SECURITY: pickle.load can execute arbitrary code. Only load
            # model files that come from a trusted source.
            with open(self.model_save_path, 'rb') as f:
                saved_data = pickle.load(f)
            self.ctm_model = saved_data['ctm_model']
            self.vocab = saved_data['vocab']
            self.vectorizer = saved_data['vectorizer']
            logging.info("CTM模型、词袋和vectorizer加载成功")
        except Exception as e:
            logging.error(f"加载模型时发生错误: {e}")
            raise

    @staticmethod
    def _load_texts(texts, text_column=None):
        """Return the training corpus as a list of strings.

        A ``str`` / ``os.PathLike`` value is treated as a CSV path (a bare
        string is never a usable corpus: iterating it yields characters).
        Any other value is treated as an iterable of documents.
        """
        if isinstance(texts, (str, os.PathLike)):
            frame = pd.read_csv(texts)
            # Assumption: when no column is named, the text is in the first one.
            column = frame.columns[0] if text_column is None else text_column
            if column not in frame.columns:
                raise ValueError(f"CSV文件中不存在列: {column}")
            return frame[column].dropna().astype(str).tolist()
        return [str(text) for text in texts if text is not None]

    @staticmethod
    def _segment(text):
        """Join jieba word segments with spaces so whitespace tokenizing works."""
        return " ".join(token for token in jieba.lcut(text) if token.strip())

    @classmethod
    def _prepare_documents(cls, documents):
        """Build aligned (raw documents, bag-of-words documents) lists.

        Documents that end up empty after preprocessing are dropped from both
        lists so embeddings and bag-of-words rows stay in step.
        """
        segmented = [cls._segment(doc) for doc in documents]
        result = WhiteSpacePreprocessing(segmented).preprocess()
        bow_docs = list(result[0])
        if len(result) > 3:
            # Releases that return a 4-tuple report which input indices were
            # kept; use them to recover the unsegmented originals.
            raw_docs = [documents[i] for i in result[3]]
        else:
            # Releases returning a 3-tuple: fall back to the retained
            # (segmented) inputs returned by the preprocessor.
            raw_docs = list(result[1])
        return raw_docs, bow_docs

    def _ensure_bert_loaded(self):
        """Load the tokenizer and BERT encoder once, on first use."""
        if self.bert_model is None or self.tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
            self.bert_model = BertModel.from_pretrained('bert-base-chinese').to(self.device)
            self.bert_model.eval()  # inference only

    def train(self, texts, num_topics=10, num_epochs=100, text_column=None):
        """Train the CTM model and save it.

        Args:
            texts: Iterable of document strings, or a path to a CSV file
                holding the documents.
            num_topics: Number of topics (``n_components``) to learn.
            num_epochs: Number of training epochs.
            text_column: CSV column holding the text; defaults to the first
                column. Ignored when ``texts`` is not a path.

        Raises:
            ValueError: If no usable documents are available.
            Exception: Any other failure is logged and re-raised.
        """
        try:
            documents = self._load_texts(texts, text_column)
            if not documents:
                raise ValueError("训练文本为空")

            # Initialize BERT.
            self._ensure_bert_loaded()

            # Prepare the bag-of-words side first so only retained documents
            # are embedded.
            logging.info("正在准备CTM训练数据...")
            raw_docs, bow_docs = self._prepare_documents(documents)
            if not bow_docs:
                raise ValueError("预处理后没有可用于训练的文本")

            # Extract BERT embeddings.
            logging.info("正在提取BERT嵌入...")
            embeddings = self._get_bert_embeddings(raw_docs)

            data_prep = TopicModelDataPreparation()
            dataset = data_prep.fit(
                text_for_contextual=raw_docs,
                text_for_bow=bow_docs,
                custom_embeddings=embeddings,
            )
            # Plain list: safe to pickle and to test for emptiness.
            vocab = list(data_prep.vocab)

            # Train the CTM model.
            logging.info("正在训练CTM模型...")
            ctm_model = CombinedTM(
                bow_size=len(vocab),
                contextual_size=embeddings.shape[1],  # BERT output dimension
                n_components=num_topics,
                num_epochs=num_epochs,
            )
            ctm_model.fit(dataset)

            # Commit state only after fitting succeeded, so a failed run does
            # not leave an untrained model on the instance.
            self.ctm_model = ctm_model
            self.vocab = vocab
            self.vectorizer = data_prep.vectorizer

            # save_model logs its own error on failure.
            if self.save_model():
                logging.info("模型训练完成并保存")
        except Exception as e:
            logging.error(f"训练模型时发生错误: {e}")
            raise

    def _get_bert_embeddings(self, texts):
        """Return one ``[CLS]`` embedding per text.

        Args:
            texts: Iterable of document strings.

        Returns:
            numpy.ndarray: One row per input text and ``hidden_size`` columns
            (zero rows for empty input).
        """
        try:
            self._ensure_bert_loaded()
            embeddings = []
            with torch.no_grad():
                for text in texts:
                    inputs = self.tokenizer(
                        text,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=512,
                    )
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    outputs = self.bert_model(**inputs)
                    # Use the [CLS] token output as the document representation.
                    embeddings.append(outputs.last_hidden_state[0, 0, :].cpu().numpy())
            if not embeddings:
                return np.empty((0, self.bert_model.config.hidden_size), dtype=np.float32)
            return np.array(embeddings)
        except Exception as e:
            logging.error(f"获取BERT嵌入时发生错误: {e}")
            raise

    def get_topics(self, num_words=10):
        """Return the top words of every topic.

        Args:
            num_words: Number of words to return per topic.

        Returns:
            list: One list of words per topic.

        Raises:
            ValueError: If no model has been trained or loaded.
        """
        try:
            if self.ctm_model is None or self.vocab is None or len(self.vocab) == 0:
                raise ValueError("模型未训练或未加载")
            # One call yields every topic; the count is passed positionally
            # because the library's parameter is not named `top_n`.
            return list(self.ctm_model.get_topic_lists(num_words))
        except Exception as e:
            logging.error(f"获取主题词时发生错误: {e}")
            raise


if __name__ == "__main__":
    # Create the BERT_CTM instance.
    model = BERT_CTM(
        model_save_path='model_pro/saved_models/ctm_model.pkl',  # save path
    )
    # Train from a CSV file path (text is read from the first column).
    model.train("./train.csv")