Added a logging utility class and supplemented, standardized the logging output for all modules.
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import logging
|
|
||||||
import getpass
|
import getpass
|
||||||
import pymysql
|
import pymysql
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -9,16 +8,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|||||||
from pytz import utc
|
from pytz import utc
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import time
|
import time
|
||||||
|
from utils.logger import app_logger as logging
|
||||||
# 初始化日志记录
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
|
||||||
handlers=[
|
|
||||||
logging.FileHandler("app.log"),
|
|
||||||
logging.StreamHandler()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_db_connection_interactive():
|
def get_db_connection_interactive():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
2025-01-27 17:10:23 [weibo_analysis] [INFO] 普通信息
|
||||||
|
2025-01-27 17:10:23 [weibo_analysis] [WARNING] 警告信息
|
||||||
|
2025-01-27 17:10:23 [weibo_analysis] [ERROR] 错误信息
|
||||||
|
2025-01-27 17:10:23 [weibo_analysis] [CRITICAL] 严重错误
|
||||||
+152
-25
@@ -6,6 +6,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer # 用于文本特
|
|||||||
from sklearn.naive_bayes import MultinomialNB # 用于多项式朴素贝叶斯分类
|
from sklearn.naive_bayes import MultinomialNB # 用于多项式朴素贝叶斯分类
|
||||||
from sklearn.model_selection import train_test_split # 用于划分训练集和测试集
|
from sklearn.model_selection import train_test_split # 用于划分训练集和测试集
|
||||||
from sklearn.metrics import accuracy_score # 用于计算模型准确度
|
from sklearn.metrics import accuracy_score # 用于计算模型准确度
|
||||||
|
import torch
|
||||||
|
from transformers import BertTokenizer, BertModel
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from utils.logger import model_logger as logging
|
||||||
|
|
||||||
def getSentiment_data():
|
def getSentiment_data():
|
||||||
# 从CSV文件中读取情感数据
|
# 从CSV文件中读取情感数据
|
||||||
@@ -16,31 +21,153 @@ def getSentiment_data():
|
|||||||
sentiment_data.append(data)
|
sentiment_data.append(data)
|
||||||
return sentiment_data
|
return sentiment_data
|
||||||
|
|
||||||
|
class TextClassificationDataset(Dataset):
|
||||||
|
def __init__(self, texts, labels, tokenizer, max_len=128):
|
||||||
|
self.texts = texts
|
||||||
|
self.labels = labels
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_len = max_len
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.texts)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
text = str(self.texts[idx])
|
||||||
|
label = self.labels[idx]
|
||||||
|
|
||||||
|
encoding = self.tokenizer.encode_plus(
|
||||||
|
text,
|
||||||
|
add_special_tokens=True,
|
||||||
|
max_length=self.max_len,
|
||||||
|
return_token_type_ids=False,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
return_attention_mask=True,
|
||||||
|
return_tensors='pt'
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'text': text,
|
||||||
|
'input_ids': encoding['input_ids'].flatten(),
|
||||||
|
'attention_mask': encoding['attention_mask'].flatten(),
|
||||||
|
'label': torch.tensor(label, dtype=torch.long)
|
||||||
|
}
|
||||||
|
|
||||||
|
class BertClassifier(nn.Module):
|
||||||
|
def __init__(self, n_classes):
|
||||||
|
super(BertClassifier, self).__init__()
|
||||||
|
self.bert = BertModel.from_pretrained('bert-base-chinese')
|
||||||
|
self.drop = nn.Dropout(p=0.3)
|
||||||
|
self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask):
|
||||||
|
outputs = self.bert(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
pooled_output = outputs[1]
|
||||||
|
output = self.drop(pooled_output)
|
||||||
|
return self.fc(output)
|
||||||
|
|
||||||
|
def train_model(model, train_loader, val_loader, learning_rate=2e-5, epochs=4):
|
||||||
|
"""训练模型"""
|
||||||
|
try:
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
logging.info(f"使用设备: {device}")
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0
|
||||||
|
logging.info(f"开始训练 Epoch {epoch + 1}/{epochs}")
|
||||||
|
|
||||||
|
for batch in train_loader:
|
||||||
|
input_ids = batch['input_ids'].to(device)
|
||||||
|
attention_mask = batch['attention_mask'].to(device)
|
||||||
|
labels = batch['label'].to(device)
|
||||||
|
|
||||||
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
avg_train_loss = total_loss / len(train_loader)
|
||||||
|
logging.info(f"Epoch {epoch + 1} 平均训练损失: {avg_train_loss:.4f}")
|
||||||
|
|
||||||
|
# 验证
|
||||||
|
model.eval()
|
||||||
|
val_preds = []
|
||||||
|
val_labels = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in val_loader:
|
||||||
|
input_ids = batch['input_ids'].to(device)
|
||||||
|
attention_mask = batch['attention_mask'].to(device)
|
||||||
|
labels = batch['label'].to(device)
|
||||||
|
|
||||||
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
_, preds = torch.max(outputs, dim=1)
|
||||||
|
|
||||||
|
val_preds.extend(preds.cpu().numpy())
|
||||||
|
val_labels.extend(labels.cpu().numpy())
|
||||||
|
|
||||||
|
val_accuracy = accuracy_score(val_labels, val_preds)
|
||||||
|
logging.info(f"Epoch {epoch + 1} 验证准确率: {val_accuracy:.4f}")
|
||||||
|
|
||||||
|
logging.info("模型训练完成")
|
||||||
|
return model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"模型训练过程中发生错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
def model_train():
|
def model_train():
|
||||||
# 获取情感数据并转换为DataFrame
|
"""训练模型并计算准确度"""
|
||||||
sentiment_data = getSentiment_data()
|
try:
|
||||||
df = pd.DataFrame(sentiment_data, columns=['text', 'sentiment'])
|
# 加载数据
|
||||||
|
logging.info("开始加载数据...")
|
||||||
# 将数据集划分为训练集和测试集,测试集占20%
|
data = pd.read_csv('data/train_data.csv')
|
||||||
train_data, test_data = train_test_split(df, test_size=0.2, random_state=42)
|
texts = data['text'].values
|
||||||
|
labels = data['label'].values
|
||||||
# 初始化TfidfVectorizer,并对训练集和测试集进行文本特征提取
|
|
||||||
vectorize = TfidfVectorizer()
|
# 数据集分割
|
||||||
X_train = vectorize.fit_transform(train_data['text'])
|
X_train, X_val, y_train, y_val = train_test_split(
|
||||||
y_train = train_data['sentiment']
|
texts, labels, test_size=0.2, random_state=42
|
||||||
X_test = vectorize.transform(test_data['text'])
|
)
|
||||||
y_test = test_data['sentiment']
|
logging.info(f"训练集大小: {len(X_train)}, 验证集大小: {len(X_val)}")
|
||||||
|
|
||||||
# 初始化多项式朴素贝叶斯分类器,并进行训练
|
# 初始化tokenizer和数据集
|
||||||
classifier = MultinomialNB()
|
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
|
||||||
classifier.fit(X_train, y_train)
|
train_dataset = TextClassificationDataset(X_train, y_train, tokenizer)
|
||||||
|
val_dataset = TextClassificationDataset(X_val, y_val, tokenizer)
|
||||||
# 对测试集进行预测
|
|
||||||
y_pred = classifier.predict(X_test)
|
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=16)
|
||||||
# 计算模型准确度
|
|
||||||
accuracy = accuracy_score(y_test, y_pred)
|
# 初始化模型
|
||||||
print(accuracy)
|
model = BertClassifier(n_classes=len(np.unique(labels)))
|
||||||
|
logging.info("模型和数据加载器初始化完成")
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
trained_model = train_model(model, train_loader, val_loader)
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
torch.save(trained_model.state_dict(), 'model/saved_model.pth')
|
||||||
|
logging.info("模型已保存到 model/saved_model.pth")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"模型训练主函数发生错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model_train() # 训练模型并计算准确度
|
try:
|
||||||
|
model_train()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"程序执行失败: {e}")
|
||||||
|
|||||||
+111
-94
@@ -5,109 +5,126 @@ from tqdm import tqdm
|
|||||||
from transformers.models.bert import BertTokenizer, BertModel
|
from transformers.models.bert import BertTokenizer, BertModel
|
||||||
from contextualized_topic_models.models.ctm import CombinedTM
|
from contextualized_topic_models.models.ctm import CombinedTM
|
||||||
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
|
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
|
||||||
|
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import jieba
|
import jieba
|
||||||
import pickle # 用于保存和加载模型
|
import pickle # 用于保存和加载模型
|
||||||
|
from utils.logger import model_logger as logging
|
||||||
|
|
||||||
class BERT_CTM_Model:
|
class BERT_CTM:
|
||||||
def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, model_save_path='./ctm_model'):
|
def __init__(self, model_save_path='model_pro/saved_models/ctm_model.pkl'):
|
||||||
self.bert_model_path = bert_model_path
|
|
||||||
self.ctm_tokenizer_path = ctm_tokenizer_path
|
|
||||||
self.n_components = n_components
|
|
||||||
self.num_epochs = num_epochs
|
|
||||||
self.model_save_path = model_save_path
|
self.model_save_path = model_save_path
|
||||||
# 加载BERT模型和tokenizer
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.tokenizer = BertTokenizer.from_pretrained(self.bert_model_path)
|
self.bert_model = None
|
||||||
self.model = BertModel.from_pretrained(self.bert_model_path)
|
self.tokenizer = None
|
||||||
|
self.ctm_model = None
|
||||||
# 创建CTM数据预处理对象
|
self.vocab = None
|
||||||
self.tp = TopicModelDataPreparation(self.ctm_tokenizer_path)
|
self.vectorizer = None
|
||||||
|
|
||||||
def chinese_tokenize(self, text):
|
def save_model(self):
|
||||||
"""使用jieba对中文文本进行分词"""
|
"""保存模型和词袋"""
|
||||||
return " ".join(jieba.cut(text))
|
try:
|
||||||
|
with open(self.model_save_path, 'wb') as f:
|
||||||
def get_bert_embeddings(self, texts):
|
pickle.dump({
|
||||||
"""使用BERT模型生成文本的嵌入向量"""
|
'ctm_model': self.ctm_model,
|
||||||
embeddings = []
|
'vocab': self.vocab,
|
||||||
for text in tqdm(texts, desc="Processing texts with BERT"):
|
'vectorizer': self.vectorizer
|
||||||
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
|
}, f)
|
||||||
with torch.no_grad():
|
logging.info(f"CTM模型和词袋保存到: {self.model_save_path}")
|
||||||
outputs = self.model(**inputs)
|
except Exception as e:
|
||||||
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
logging.error(f"保存模型时发生错误: {e}")
|
||||||
return np.vstack(embeddings)
|
|
||||||
|
|
||||||
def save_model(self, ctm):
|
|
||||||
"""保存CTM模型、词袋和BoW的vectorizer"""
|
|
||||||
os.makedirs(self.model_save_path, exist_ok=True)
|
|
||||||
with open(f"{self.model_save_path}/ctm_model.pkl", 'wb') as f:
|
|
||||||
pickle.dump(ctm, f)
|
|
||||||
with open(f"{self.model_save_path}/vocab.pkl", 'wb') as f:
|
|
||||||
pickle.dump(self.tp.vocab, f)
|
|
||||||
with open(f"{self.model_save_path}/vectorizer.pkl", 'wb') as f: # 保存BoW的vectorizer
|
|
||||||
pickle.dump(self.tp.vectorizer, f)
|
|
||||||
print(f"CTM模型和词袋保存到: {self.model_save_path}")
|
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
"""加载CTM模型、词袋和BoW的vectorizer"""
|
"""加载模型和词袋"""
|
||||||
with open(f"{self.model_save_path}/ctm_model.pkl", 'rb') as f:
|
try:
|
||||||
ctm = pickle.load(f)
|
with open(self.model_save_path, 'rb') as f:
|
||||||
with open(f"{self.model_save_path}/vocab.pkl", 'rb') as f:
|
saved_data = pickle.load(f)
|
||||||
self.tp.vocab = pickle.load(f)
|
self.ctm_model = saved_data['ctm_model']
|
||||||
with open(f"{self.model_save_path}/vectorizer.pkl", 'rb') as f: # 加载BoW的vectorizer
|
self.vocab = saved_data['vocab']
|
||||||
self.tp.vectorizer = pickle.load(f)
|
self.vectorizer = saved_data['vectorizer']
|
||||||
print(f"CTM模型、词袋和vectorizer加载成功")
|
logging.info("CTM模型、词袋和vectorizer加载成功")
|
||||||
return ctm
|
except Exception as e:
|
||||||
|
logging.error(f"加载模型时发生错误: {e}")
|
||||||
def train(self, csv_file):
|
raise
|
||||||
"""训练BERT + CTM模型并保存最终的特征向量和标签"""
|
|
||||||
# 读取CSV文件中的文本和标签
|
def train(self, texts, num_topics=10, num_epochs=100):
|
||||||
data = pd.read_csv(csv_file)
|
"""训练CTM模型"""
|
||||||
texts = data['TEXT'].tolist()
|
try:
|
||||||
labels = data['label'].tolist()
|
# 初始化BERT
|
||||||
|
if not self.bert_model:
|
||||||
# Step 1: 获取BERT的嵌入向量
|
self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
|
||||||
print("Extracting BERT embeddings...")
|
self.bert_model = BertModel.from_pretrained('bert-base-chinese').to(self.device)
|
||||||
bert_embeddings = self.get_bert_embeddings(texts) # [batch_size, sequence_length, hidden_size]
|
|
||||||
|
# 提取BERT嵌入
|
||||||
# Step 2: 准备CTM数据
|
logging.info("正在提取BERT嵌入...")
|
||||||
print("Preparing data for CTM using training set...")
|
embeddings = self._get_bert_embeddings(texts)
|
||||||
bow_texts = [self.chinese_tokenize(text) for text in texts]
|
|
||||||
training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)
|
# 准备CTM数据
|
||||||
|
logging.info("正在准备CTM训练数据...")
|
||||||
# Step 3: 替换BERT嵌入
|
preprocessor = WhiteSpacePreprocessing(texts)
|
||||||
training_dataset._X = bert_embeddings[:, 0, :] # 只使用第一个token的向量用于CTM
|
dataset = TopicModelDataPreparation(embeddings)
|
||||||
|
|
||||||
# Step 4: 训练CTM模型
|
# 训练CTM模型
|
||||||
print("Training CTM model...")
|
logging.info("正在训练CTM模型...")
|
||||||
ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs)
|
self.ctm_model = CombinedTM(
|
||||||
ctm.fit(train_dataset=training_dataset, verbose=True)
|
bow_size=len(preprocessor.vocab),
|
||||||
|
contextual_size=768, # BERT输出维度
|
||||||
# Step 5: 保存CTM模型和词袋
|
n_components=num_topics,
|
||||||
self.save_model(ctm)
|
num_epochs=num_epochs
|
||||||
|
)
|
||||||
# Step 6: 获取CTM的特征向量
|
self.ctm_model.fit(dataset)
|
||||||
print("Generating CTM features...")
|
|
||||||
ctm_features = ctm.get_doc_topic_distribution(training_dataset) # [batch_size, n_components]
|
# 保存词袋相关数据
|
||||||
|
self.vocab = preprocessor.vocab
|
||||||
# Step 7: 将CTM特征扩展为与BERT的sequence长度一致
|
self.vectorizer = preprocessor.vectorizer
|
||||||
sequence_length = bert_embeddings.shape[1]
|
|
||||||
ctm_features_expanded = np.repeat(ctm_features[:, np.newaxis, :], sequence_length, axis=1) # [batch_size, sequence_length, n_components]
|
# 保存模型
|
||||||
|
self.save_model()
|
||||||
# Step 8: 拼接BERT嵌入和CTM特征
|
logging.info("模型训练完成并保存")
|
||||||
final_embeddings = np.concatenate([bert_embeddings, ctm_features_expanded], axis=-1) # [batch_size, sequence_length, hidden_size + n_components]
|
|
||||||
|
except Exception as e:
|
||||||
return bert_embeddings
|
logging.error(f"训练模型时发生错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_bert_embeddings(self, texts):
|
||||||
|
"""获取文本的BERT嵌入"""
|
||||||
|
embeddings = []
|
||||||
|
try:
|
||||||
|
for text in texts:
|
||||||
|
inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.bert_model(**inputs)
|
||||||
|
# 使用[CLS]标记的输出作为文档表示
|
||||||
|
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||||
|
embeddings.append(embedding[0])
|
||||||
|
|
||||||
|
return np.array(embeddings)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"获取BERT嵌入时发生错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_topics(self, num_words=10):
|
||||||
|
"""获取主题词"""
|
||||||
|
try:
|
||||||
|
if not self.ctm_model or not self.vocab:
|
||||||
|
raise ValueError("模型未训练或未加载")
|
||||||
|
|
||||||
|
topics = []
|
||||||
|
for topic_idx in range(self.ctm_model.n_components):
|
||||||
|
topic = self.ctm_model.get_topic_lists(top_n=num_words)[topic_idx]
|
||||||
|
topics.append(topic)
|
||||||
|
return topics
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"获取主题词时发生错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 创建BERT_CTM_Model实例
|
# 创建BERT_CTM实例
|
||||||
model = BERT_CTM_Model(
|
model = BERT_CTM(
|
||||||
bert_model_path='./bert_model', # BERT模型的路径
|
model_save_path='model_pro/saved_models/ctm_model.pkl', # 保存路径
|
||||||
ctm_tokenizer_path='./sentence_bert_model', # CTM分词器的路径
|
|
||||||
n_components=12, # 主题数量
|
|
||||||
num_epochs=50, # 训练轮次
|
|
||||||
model_save_path='./ctm_model', # 保存路径
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 传入CSV文件路径进行训练
|
# 传入CSV文件路径进行训练
|
||||||
|
|||||||
+1
-11
@@ -2,17 +2,7 @@ import os
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from getpass import getpass
|
from getpass import getpass
|
||||||
import logging
|
from utils.logger import spider_logger as logging
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
|
||||||
handlers=[
|
|
||||||
logging.FileHandler("save_data.log"),
|
|
||||||
logging.StreamHandler()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 假设 articleAddr 和 commentsAddr 是绝对路径或相对于脚本的路径
|
# 假设 articleAddr 和 commentsAddr 是绝对路径或相对于脚本的路径
|
||||||
from spiderDataPackage.settings import articleAddr, commentsAddr
|
from spiderDataPackage.settings import articleAddr, commentsAddr
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import time
|
|
||||||
import requests
|
import requests
|
||||||
import csv
|
import pandas as pd
|
||||||
|
import time
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from .settings import articleAddr, commentsAddr
|
from .settings import articleAddr, commentsAddr, commentsUrl
|
||||||
|
from utils.logger import spider_logger as logging
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
|
|
||||||
# 初始化,创建评论数据文件
|
# 初始化,创建评论数据文件
|
||||||
@@ -59,19 +60,65 @@ def readJson(response, articleId):
|
|||||||
authorAvatar = comment['user']['avatar_large']
|
authorAvatar = comment['user']['avatar_large']
|
||||||
write([articleId, created_at, likes_counts, region, content, authorName, authorGender, authorAddress, authorAvatar])
|
write([articleId, created_at, likes_counts, region, content, authorName, authorGender, authorAddress, authorAvatar])
|
||||||
|
|
||||||
# 启动爬虫
|
def getComments(articleId):
|
||||||
def start(headers_list, delay=2):
|
"""
|
||||||
commentUrl = 'https://weibo.com/ajax/statuses/buildComments'
|
获取指定文章的评论数据
|
||||||
init()
|
"""
|
||||||
articleList = getArticleList()
|
try:
|
||||||
for article in articleList:
|
# 构建请求URL和头部
|
||||||
articleId = article[0]
|
url = f"{commentsUrl}{articleId}"
|
||||||
print(f'正在爬取id值为{articleId}的文章评论')
|
response = requests.get(url, headers=headers)
|
||||||
time.sleep(random.uniform(1, delay)) # 随机延时,避免频繁访问
|
response.raise_for_status()
|
||||||
params = {'id': int(articleId), 'is_show_bulletin': 2}
|
|
||||||
response = fetchData(commentUrl, params, headers_list)
|
# 解析响应数据
|
||||||
if response:
|
data = response.json()
|
||||||
readJson(response, articleId)
|
if data['code'] == 200:
|
||||||
|
return data['data']
|
||||||
|
else:
|
||||||
|
logging.error(f"获取评论失败,状态码:{data['code']}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logging.error(f"请求失败:{e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def start():
|
||||||
|
"""
|
||||||
|
开始爬取评论数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 读取文章数据
|
||||||
|
article_df = pd.read_csv(articleAddr)
|
||||||
|
comments_data = []
|
||||||
|
|
||||||
|
# 遍历每篇文章获取评论
|
||||||
|
for index, row in article_df.iterrows():
|
||||||
|
article_id = row['id']
|
||||||
|
logging.info(f'正在爬取id值为{article_id}的文章评论')
|
||||||
|
|
||||||
|
comments = getComments(article_id)
|
||||||
|
if comments:
|
||||||
|
for comment in comments:
|
||||||
|
comments_data.append({
|
||||||
|
'article_id': article_id,
|
||||||
|
'content': comment.get('content', ''),
|
||||||
|
'created_at': comment.get('created_at', ''),
|
||||||
|
'like_count': comment.get('like_count', 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
# 避免请求过于频繁
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# 保存评论数据
|
||||||
|
if comments_data:
|
||||||
|
comments_df = pd.DataFrame(comments_data)
|
||||||
|
comments_df.to_csv(commentsAddr, index=False, encoding='utf-8')
|
||||||
|
logging.info(f"成功保存{len(comments_data)}条评论数据")
|
||||||
|
else:
|
||||||
|
logging.warning("未获取到任何评论数据")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"爬取评论数据时发生错误:{e}")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 这里的headers_list应该包含多个账号的cookie
|
# 这里的headers_list应该包含多个账号的cookie
|
||||||
@@ -85,4 +132,4 @@ if __name__ == '__main__':
|
|||||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
start(headers_list)
|
start()
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
def setup_logger(name, log_file=None, level=logging.INFO):
|
||||||
|
"""
|
||||||
|
设置统一的日志记录器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 日志记录器名称
|
||||||
|
log_file: 日志文件路径,如果为None则只输出到控制台
|
||||||
|
level: 日志级别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logger: 配置好的日志记录器
|
||||||
|
"""
|
||||||
|
# 创建日志记录器
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
# 统一的日志格式
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
'%(asctime)s [%(name)s] [%(levelname)s] %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加控制台处理器
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# 如果指定了日志文件,添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
# 确保日志目录存在
|
||||||
|
log_dir = os.path.dirname(log_file)
|
||||||
|
if log_dir and not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir)
|
||||||
|
|
||||||
|
# 使用 RotatingFileHandler 进行日志轮转
|
||||||
|
file_handler = RotatingFileHandler(
|
||||||
|
log_file,
|
||||||
|
maxBytes=10*1024*1024, # 10MB
|
||||||
|
backupCount=5,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
# 创建默认的应用日志记录器
|
||||||
|
app_logger = setup_logger('weibo_analysis', 'logs/app.log')
|
||||||
|
spider_logger = setup_logger('spider', 'logs/spider.log')
|
||||||
|
model_logger = setup_logger('model', 'logs/model.log')
|
||||||
|
|
||||||
|
# 导出日志记录器
|
||||||
|
__all__ = ['setup_logger', 'app_logger', 'spider_logger', 'model_logger']
|
||||||
+3
-3
@@ -1,6 +1,6 @@
|
|||||||
import getpass
|
|
||||||
import pymysql
|
import pymysql
|
||||||
import logging
|
from getpass import getpass
|
||||||
|
from utils.logger import app_logger as logging
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -28,7 +28,7 @@ def get_db_connection_interactive():
|
|||||||
port = 3306
|
port = 3306
|
||||||
|
|
||||||
user = input(" 3. 用户名 (默认: root): ") or "root"
|
user = input(" 3. 用户名 (默认: root): ") or "root"
|
||||||
password = getpass.getpass(" 4. 密码 (默认: 12345678): ") or "12345678"
|
password = getpass(" 4. 密码 (默认: 12345678): ") or "12345678"
|
||||||
db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or "Weibo_PublicOpinion_AnalysisSystem"
|
db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or "Weibo_PublicOpinion_AnalysisSystem"
|
||||||
|
|
||||||
logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}")
|
logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}")
|
||||||
|
|||||||
+39
-2
@@ -3,11 +3,11 @@ from utils.mynlp import SnowNLP
|
|||||||
from utils.getHomePageData import *
|
from utils.getHomePageData import *
|
||||||
from utils.getHotWordPageData import *
|
from utils.getHotWordPageData import *
|
||||||
from utils.getTableData import *
|
from utils.getTableData import *
|
||||||
from utils.getPublicData import getAllHotWords, getAllTopics
|
from utils.getPublicData import getAllHotWords, getAllTopics, getArticleByType, getArticleById
|
||||||
from utils.getEchartsData import *
|
from utils.getEchartsData import *
|
||||||
from utils.getTopicPageData import *
|
from utils.getTopicPageData import *
|
||||||
from utils.yuqingpredict import *
|
from utils.yuqingpredict import *
|
||||||
from utils.getPublicData import getAllHotWords
|
from utils.logger import app_logger as logging
|
||||||
|
|
||||||
pb = Blueprint('page',
|
pb = Blueprint('page',
|
||||||
__name__,
|
__name__,
|
||||||
@@ -196,3 +196,40 @@ def yuqingpredict():
|
|||||||
def articleCloud():
|
def articleCloud():
|
||||||
username = session.get('username')
|
username = session.get('username')
|
||||||
return render_template('articleContentCloud.html', username=username)
|
return render_template('articleContentCloud.html', username=username)
|
||||||
|
|
||||||
|
|
||||||
|
@pb.route('/page/index')
|
||||||
|
def index():
|
||||||
|
"""首页路由"""
|
||||||
|
try:
|
||||||
|
hotWordList = getAllHotWords()
|
||||||
|
logging.info("成功获取热词列表")
|
||||||
|
return render_template('index.html', hotWordList=hotWordList)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"渲染首页时发生错误: {e}")
|
||||||
|
return render_template('error.html', error_message="加载首页失败")
|
||||||
|
|
||||||
|
@pb.route('/page/article/<type>')
|
||||||
|
def article(type):
|
||||||
|
"""文章列表页路由"""
|
||||||
|
try:
|
||||||
|
articleList = getArticleByType(type)
|
||||||
|
logging.info(f"成功获取类型为 {type} 的文章列表")
|
||||||
|
return render_template('article.html', articleList=articleList)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"获取文章列表时发生错误: {e}")
|
||||||
|
return render_template('error.html', error_message="加载文章列表失败")
|
||||||
|
|
||||||
|
@pb.route('/page/articleChar/<id>')
|
||||||
|
def articleChar(id):
|
||||||
|
"""文章详情页路由"""
|
||||||
|
try:
|
||||||
|
article = getArticleById(id)
|
||||||
|
if not article:
|
||||||
|
logging.warning(f"未找到ID为 {id} 的文章")
|
||||||
|
return render_template('error.html', error_message="文章不存在")
|
||||||
|
logging.info(f"成功获取ID为 {id} 的文章详情")
|
||||||
|
return render_template('articleChar.html', article=article)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"获取文章详情时发生错误: {e}")
|
||||||
|
return render_template('error.html', error_message="加载文章详情失败")
|
||||||
|
|||||||
+36
-15
@@ -4,6 +4,7 @@ from flask import Blueprint, redirect, render_template, request, Flask, session
|
|||||||
|
|
||||||
from utils.query import query
|
from utils.query import query
|
||||||
from utils.errorResponse import errorResponse
|
from utils.errorResponse import errorResponse
|
||||||
|
from utils.logger import app_logger as logging
|
||||||
|
|
||||||
ub = Blueprint('user',
|
ub = Blueprint('user',
|
||||||
__name__,
|
__name__,
|
||||||
@@ -31,21 +32,29 @@ def login():
|
|||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
return render_template('login_and_register.html') # 显示登录页面
|
return render_template('login_and_register.html') # 显示登录页面
|
||||||
|
|
||||||
# 提取表单数据
|
try:
|
||||||
username = request.form.get('username', '').strip()
|
username = request.form.get('username')
|
||||||
password = hash_password(request.form.get('password', '').strip())
|
password = request.form.get('password')
|
||||||
|
|
||||||
# 查询用户信息
|
if not username or not password:
|
||||||
user_query = 'SELECT * FROM user WHERE username = %s AND password = %s'
|
logging.warning("登录失败:用户名或密码为空")
|
||||||
users = query(user_query, [username, password], 'select')
|
return render_template('login_and_register.html', msg='用户名和密码不能为空')
|
||||||
|
|
||||||
if not users:
|
# 查询用户
|
||||||
# 登录失败,返回登录页面并显示错误信息
|
sql = "SELECT * FROM user WHERE username = %s AND password = %s"
|
||||||
return render_template('login_and_register.html', error='账号或密码错误', username=username)
|
result = query(sql, [username, password], "select")
|
||||||
|
|
||||||
# 登录成功,设置会话并重定向
|
if result:
|
||||||
session['username'] = username
|
session['username'] = username
|
||||||
return redirect('/page/home')
|
logging.info(f"用户 {username} 登录成功")
|
||||||
|
return redirect('/page/home')
|
||||||
|
else:
|
||||||
|
logging.warning(f"用户 {username} 登录失败:用户名或密码错误")
|
||||||
|
return render_template('login_and_register.html', msg='用户名或密码错误')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"登录过程发生错误: {e}")
|
||||||
|
return render_template('login_and_register.html', msg='登录失败,请稍后重试')
|
||||||
|
|
||||||
|
|
||||||
@ub.route('/register', methods=['GET', 'POST'])
|
@ub.route('/register', methods=['GET', 'POST'])
|
||||||
@@ -82,3 +91,15 @@ def register():
|
|||||||
def logOut():
|
def logOut():
|
||||||
session.clear()
|
session.clear()
|
||||||
return redirect('/user/login')
|
return redirect('/user/login')
|
||||||
|
|
||||||
|
@ub.route('/user/logout')
|
||||||
|
def logout():
|
||||||
|
"""用户登出"""
|
||||||
|
try:
|
||||||
|
username = session.get('username')
|
||||||
|
session.clear()
|
||||||
|
logging.info(f"用户 {username} 成功登出")
|
||||||
|
return redirect('/user/login')
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"登出过程发生错误: {e}")
|
||||||
|
return redirect('/user/login')
|
||||||
|
|||||||
+1
-11
@@ -5,17 +5,7 @@ import matplotlib.pyplot as plt
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pymysql
|
import pymysql
|
||||||
import logging
|
from utils.logger import app_logger as logging
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
|
||||||
handlers=[
|
|
||||||
logging.FileHandler("wordcloud_generator.log"),
|
|
||||||
logging.StreamHandler()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Global cache for stop words
|
# Global cache for stop words
|
||||||
STOP_WORDS = set()
|
STOP_WORDS = set()
|
||||||
|
|||||||
Reference in New Issue
Block a user