132 lines
5.0 KiB
Python
132 lines
5.0 KiB
Python
import os
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from transformers.models.bert import BertTokenizer, BertModel
|
|
from contextualized_topic_models.models.ctm import CombinedTM
|
|
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
|
|
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing
|
|
import numpy as np
|
|
import torch
|
|
import jieba
|
|
import pickle # 用于保存和加载模型
|
|
from utils.logger import model_logger as logging
|
|
|
|
class BERT_CTM:
    """Train, persist and query a Combined Topic Model (CTM) on BERT embeddings.

    Each document is represented twice: as a bag of words and as the [CLS]
    vector produced by a BERT encoder. Both views are fed to ``CombinedTM``.
    The trained model, its vocabulary and the fitted vectorizer are pickled
    together at ``model_save_path``.
    """

    # Hugging Face checkpoint used for the contextual embeddings.
    BERT_CHECKPOINT = 'bert-base-chinese'

    def __init__(self, model_save_path='model_pro/saved_models/ctm_model.pkl'):
        """Create an empty wrapper; nothing is loaded until ``train``/``load_model``.

        Args:
            model_save_path: File the pickled model bundle is written to and
                read from.
        """
        self.model_save_path = model_save_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert_model = None   # BERT encoder, loaded lazily
        self.tokenizer = None    # tokenizer matching the encoder, loaded lazily
        self.ctm_model = None    # trained CombinedTM instance
        self.vocab = None        # bag-of-words vocabulary the CTM was trained on
        self.vectorizer = None   # fitted bag-of-words vectorizer

    def save_model(self):
        """Pickle the CTM model, vocabulary and vectorizer to ``model_save_path``.

        Saving is best-effort: failures are logged, not raised.

        Returns:
            True if the bundle was written, False if saving failed.
        """
        try:
            # open() does not create intermediate directories.
            parent_dir = os.path.dirname(self.model_save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            bundle = {
                'ctm_model': self.ctm_model,
                'vocab': self.vocab,
                'vectorizer': self.vectorizer,
            }
            with open(self.model_save_path, 'wb') as f:
                pickle.dump(bundle, f)
        except Exception as e:
            logging.error(f"保存模型时发生错误: {e}")
            return False

        logging.info(f"CTM模型和词袋保存到: {self.model_save_path}")
        return True

    def load_model(self):
        """Restore the CTM model, vocabulary and vectorizer from ``model_save_path``.

        Raises:
            Exception: Whatever opening, unpickling or key lookup raises; the
                error is logged and re-raised.
        """
        try:
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file. Only load bundles written by save_model() from a trusted
            # location.
            with open(self.model_save_path, 'rb') as f:
                saved_data = pickle.load(f)

            self.ctm_model = saved_data['ctm_model']
            self.vocab = saved_data['vocab']
            self.vectorizer = saved_data['vectorizer']
            logging.info("CTM模型、词袋和vectorizer加载成功")
        except Exception as e:
            logging.error(f"加载模型时发生错误: {e}")
            raise

    def train(self, texts, num_topics=10, num_epochs=100, text_column=None):
        """Train the CTM on a corpus and save the result.

        Args:
            texts: Either an iterable of document strings, or a path
                (``str`` / ``os.PathLike``) to a CSV file holding the documents.
                Note that a single ``str`` is always treated as a path.
            num_topics: Number of topics (``n_components`` of ``CombinedTM``).
            num_epochs: Number of training epochs.
            text_column: CSV column containing the documents. Only used when
                ``texts`` is a path; may be omitted if the CSV has exactly one
                column.

        Raises:
            ValueError: If no usable document remains after cleaning or
                preprocessing, or the CSV text column cannot be determined.
            Exception: Any other error is logged and re-raised.
        """
        try:
            documents = self._load_texts(texts, text_column)

            logging.info("正在准备CTM训练数据...")
            # WhiteSpacePreprocessing splits tokens on whitespace, which
            # Chinese text does not contain, so segment with jieba first.
            segmented = [
                " ".join(tok for tok in jieba.cut(doc) if tok.strip())
                for doc in documents
            ]
            # The first two items of preprocess()'s result are the cleaned
            # documents and the matching untouched inputs (documents that end
            # up empty are dropped from both). Later items vary between
            # library versions, so index instead of unpacking.
            preprocessed = WhiteSpacePreprocessing(segmented).preprocess()
            bow_docs, kept_docs = preprocessed[0], preprocessed[1]
            if len(bow_docs) == 0:
                raise ValueError("预处理后没有可用于训练的文本")

            # Embed the original (unsegmented) text of exactly the documents
            # that survived preprocessing so both views stay aligned.
            original_of = dict(zip(segmented, documents))
            raw_docs = [original_of[doc] for doc in kept_docs]

            logging.info("正在提取BERT嵌入...")
            embeddings = self._get_bert_embeddings(raw_docs)

            # The training dataset is produced by fit(); precomputed vectors
            # are supplied through custom_embeddings.
            data_prep = TopicModelDataPreparation()
            dataset = data_prep.fit(
                text_for_contextual=raw_docs,
                text_for_bow=bow_docs,
                custom_embeddings=embeddings,
            )

            logging.info("正在训练CTM模型...")
            self.ctm_model = CombinedTM(
                bow_size=len(data_prep.vocab),
                contextual_size=embeddings.shape[1],  # encoder hidden size
                n_components=num_topics,
                num_epochs=num_epochs,
            )
            self.ctm_model.fit(dataset)

            # Keep the bag-of-words artefacts needed to interpret the model.
            self.vocab = list(data_prep.vocab)
            self.vectorizer = data_prep.vectorizer

            # save_model() logs its own error on failure; only claim success
            # when the bundle was really written.
            if self.save_model():
                logging.info("模型训练完成并保存")
        except Exception as e:
            logging.error(f"训练模型时发生错误: {e}")
            raise

    def _load_texts(self, texts, text_column=None):
        """Normalize ``train``'s ``texts`` argument into a list of non-empty strings.

        Missing values are dropped, entries are converted to ``str`` and
        stripped, and blank entries are removed.
        """
        if isinstance(texts, (str, os.PathLike)):
            frame = pd.read_csv(texts)
            if text_column is None:
                # Refuse to guess which column holds the documents.
                if frame.shape[1] != 1:
                    raise ValueError(
                        f"CSV文件包含多列 {list(frame.columns)}，请通过text_column指定文本列"
                    )
                text_column = frame.columns[0]
            elif text_column not in frame.columns:
                raise ValueError(f"CSV文件中不存在列: {text_column}")
            series = frame[text_column]
        else:
            series = pd.Series(list(texts), dtype=object)

        cleaned = series.dropna().astype(str).str.strip()
        documents = cleaned[cleaned != ""].tolist()
        if not documents:
            raise ValueError("没有可用于训练的文本")
        return documents

    def _ensure_bert_loaded(self):
        """Load the tokenizer and encoder on first use; move the encoder to ``self.device``."""
        if self.tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained(self.BERT_CHECKPOINT)
        if self.bert_model is None:
            self.bert_model = BertModel.from_pretrained(self.BERT_CHECKPOINT).to(self.device)

    def _get_bert_embeddings(self, texts, batch_size=16):
        """Return one BERT [CLS] embedding per document.

        Args:
            texts: Iterable of document strings.
            batch_size: Number of documents encoded per forward pass.

        Returns:
            ``numpy.ndarray`` with one row per document (an empty array when
            ``texts`` is empty).

        Raises:
            Exception: Any tokenizer/encoder error is logged and re-raised.
        """
        try:
            self._ensure_bert_loaded()
            documents = list(texts)
            chunks = []
            for start in range(0, len(documents), batch_size):
                batch = documents[start:start + batch_size]
                inputs = self.tokenizer(
                    batch,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=512,
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = self.bert_model(**inputs)

                # Position 0 is the [CLS] token, used as the document vector.
                chunks.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

            if not chunks:
                return np.array([])
            return np.concatenate(chunks, axis=0)
        except Exception as e:
            logging.error(f"获取BERT嵌入时发生错误: {e}")
            raise

    def get_topics(self, num_words=10):
        """Return the top words of every topic.

        Args:
            num_words: Number of words to return per topic.

        Returns:
            A list with one list of words per topic.

        Raises:
            ValueError: If no model has been trained or loaded.
        """
        try:
            # Explicit None/length checks: truthiness of an array-like vocab
            # is ambiguous.
            if self.ctm_model is None or self.vocab is None or len(self.vocab) == 0:
                raise ValueError("模型未训练或未加载")

            # One call returns the word lists for all topics. The count is
            # passed positionally because the library names this parameter
            # ``k``.
            return list(self.ctm_model.get_topic_lists(num_words))
        except Exception as e:
            logging.error(f"获取主题词时发生错误: {e}")
            raise
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build a BERT_CTM that stores its bundle at the
    # default location, then train it from the CSV file path.
    save_path = 'model_pro/saved_models/ctm_model.pkl'
    model = BERT_CTM(model_save_path=save_path)
    model.train("./train.csv")
|