Optimize database connection management in spider module, fix WebSocket handling and parameter validation.

This commit is contained in:
戒酒的李白
2025-02-25 10:27:41 +08:00
parent 930046fd5c
commit e0719583fc
4 changed files with 104 additions and 16 deletions
+4
View File
@@ -9,6 +9,7 @@ from pytz import utc
from datetime import datetime, timedelta from datetime import datetime, timedelta
import time import time
from utils.logger import app_logger as logging from utils.logger import app_logger as logging
from utils.db_manager import DatabaseManager
def get_db_connection_interactive(): def get_db_connection_interactive():
""" """
@@ -232,6 +233,9 @@ DB_CONFIG = {
'charset': 'utf8mb4' 'charset': 'utf8mb4'
} }
# 初始化数据库管理器
DatabaseManager.initialize(DB_CONFIG)
# 主程序入口 # 主程序入口
if __name__ == '__main__': if __name__ == '__main__':
# 检测是否需要初始化数据库 # 检测是否需要初始化数据库
+38 -3
View File
@@ -10,6 +10,7 @@ import logging
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from datetime import datetime from datetime import datetime
from utils.logger import spider_logger as logging from utils.logger import spider_logger as logging
from utils.db_manager import DatabaseManager
def spiderData(): def spiderData():
if not os.path.exists(navAddr): if not os.path.exists(navAddr):
@@ -26,6 +27,7 @@ class SpiderData:
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
} }
self.base_url = 'https://s.weibo.com' self.base_url = 'https://s.weibo.com'
self.db = DatabaseManager()
def crawl_topic(self, topic, depth=3, interval=5, max_retries=3, timeout=30): def crawl_topic(self, topic, depth=3, interval=5, max_retries=3, timeout=30):
""" """
@@ -37,7 +39,17 @@ class SpiderData:
:param max_retries: 最大重试次数 :param max_retries: 最大重试次数
:param timeout: 请求超时时间(秒) :param timeout: 请求超时时间(秒)
""" """
logging.info(f"开始爬取话题: {topic}") # 参数验证
if not isinstance(depth, int) or depth < 1 or depth > 10:
raise ValueError("爬取深度必须在1-10页之间")
if not isinstance(interval, int) or interval < 3 or interval > 30:
raise ValueError("请求间隔必须在3-30秒之间")
if not isinstance(max_retries, int) or max_retries < 1 or max_retries > 5:
raise ValueError("最大重试次数必须在1-5次之间")
if not isinstance(timeout, int) or timeout < 10 or timeout > 60:
raise ValueError("请求超时时间必须在10-60秒之间")
logging.info(f"开始爬取话题: {topic}, 参数: depth={depth}, interval={interval}, max_retries={max_retries}, timeout={timeout}")
for page in range(1, depth + 1): for page in range(1, depth + 1):
retries = 0 retries = 0
@@ -140,11 +152,34 @@ class SpiderData:
:param data: 要保存的数据字典 :param data: 要保存的数据字典
""" """
connection = None
try: try:
# TODO: 实现数据库保存逻辑 connection = self.db.get_connection()
logging.info(f"保存数据: {data}")
with connection.cursor() as cursor:
# 插入文章数据
sql = """
INSERT INTO article (content, user_name, publish_time, forward_count,
comment_count, like_count, crawl_time)
VALUES (%s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(sql, (
data['content'],
data['user_name'],
data['publish_time'],
data['forward_count'],
data['comment_count'],
data['like_count'],
data['crawl_time']
))
connection.commit()
logging.info(f"成功保存微博数据: {data['content'][:30]}...")
except Exception as e: except Exception as e:
logging.error(f"保存数据时出错: {e}") logging.error(f"保存数据时出错: {e}")
if connection:
connection.rollback()
if __name__ == '__main__': if __name__ == '__main__':
spiderData() spiderData()
+36
View File
@@ -0,0 +1,36 @@
import pymysql
from pymysql.cursors import DictCursor
class DatabaseManager:
    """Process-wide holder for a single, lazily created PyMySQL connection.

    The class is a singleton: every ``DatabaseManager()`` call returns the
    same object, and the configuration and connection live on the class
    itself, so all callers share them.

    NOTE(review): one connection object is shared by every caller. PyMySQL
    connections are not meant to be used from several threads at once, so
    confirm that callers do not use the returned connection concurrently.
    """

    _instance = None    # the one DatabaseManager object returned by __new__
    _connection = None  # cached connection, created on first get_connection()
    _config = None      # keyword arguments for pymysql.connect, set by initialize()

    def __new__(cls):
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def initialize(cls, config):
        """Store the connection settings used by later get_connection() calls.

        :param config: mapping of keyword arguments for ``pymysql.connect``
        """
        # Keep a private copy so that later mutation of the caller's dict
        # cannot silently change where the next (re)connect goes.
        cls._config = dict(config) if config is not None else None

    @classmethod
    def get_connection(cls):
        """Return a usable database connection, (re)connecting when needed.

        :return: the shared PyMySQL connection (rows are returned as dicts)
        :raises ValueError: if a new connection is required but
            ``initialize`` has not been called yet
        """
        cached = cls._connection
        if cached is not None and cached.open:
            # ``open`` only reflects the local socket state; a connection
            # dropped by the server (e.g. idle timeout) still looks open.
            # ping(reconnect=True) verifies it and transparently reconnects.
            try:
                cached.ping(reconnect=True)
            except (pymysql.MySQLError, OSError):
                # The cached connection is beyond repair: forget it and
                # build a brand-new one below.
                cls._connection = None
            else:
                return cached

        if cls._config is None:
            raise ValueError("数据库未初始化,请先调用initialize方法设置配置")

        # Merge instead of passing ``cursorclass`` as a second keyword, which
        # would raise TypeError if the config already contained that key.
        connect_kwargs = dict(cls._config)
        connect_kwargs["cursorclass"] = DictCursor
        cls._connection = pymysql.connect(**connect_kwargs)
        return cls._connection

    @classmethod
    def close(cls):
        """Close the shared connection (if it is open) and forget it."""
        cached = cls._connection
        if cached is not None and cached.open:
            cached.close()
        # Always drop the reference so the next get_connection() starts clean.
        cls._connection = None
+26 -13
View File
@@ -73,6 +73,15 @@ def spider_worker(topics, parameters):
total_topics = len(topics) total_topics = len(topics)
completed_topics = 0 completed_topics = 0
async def send_message(message):
    """Await ``broadcast_message(message)`` on the loop that runs this coroutine.

    Every caller wraps this coroutine in ``asyncio.run(...)``, which already
    creates, installs and closes an event loop around it. Creating a second
    loop here would only allocate and close an unused loop, so none is made.

    NOTE(review): presumably the WebSocket connections belong to the web
    server's own event loop; if so, confirm that sending from a separate
    ``asyncio.run`` loop is supported by the server library.

    :param message: payload passed unchanged to ``broadcast_message``
    """
    await broadcast_message(message)
try: try:
spider = SpiderData() spider = SpiderData()
@@ -80,13 +89,13 @@ def spider_worker(topics, parameters):
try: try:
# 更新进度 # 更新进度
progress = int((completed_topics / total_topics) * 100) progress = int((completed_topics / total_topics) * 100)
asyncio.run(broadcast_message({ asyncio.run(send_message({
'type': 'progress', 'type': 'progress',
'value': progress 'value': progress
})) }))
# 发送开始爬取的日志 # 发送开始爬取的日志
asyncio.run(broadcast_message({ asyncio.run(send_message({
'type': 'log', 'type': 'log',
'message': f'开始爬取话题: {topic}' 'message': f'开始爬取话题: {topic}'
})) }))
@@ -103,33 +112,33 @@ def spider_worker(topics, parameters):
completed_topics += 1 completed_topics += 1
# 发送完成爬取的日志 # 发送完成爬取的日志
asyncio.run(broadcast_message({ asyncio.run(send_message({
'type': 'log', 'type': 'log',
'message': f'话题 {topic} 爬取完成' 'message': f'话题 {topic} 爬取完成'
})) }))
except Exception as e: except Exception as e:
# 发送错误日志 # 发送错误日志
asyncio.run(broadcast_message({ asyncio.run(send_message({
'type': 'log', 'type': 'log',
'message': f'爬取话题 {topic} 时出错: {str(e)}' 'message': f'爬取话题 {topic} 时出错: {str(e)}'
})) }))
# 更新最终进度 # 更新最终进度
asyncio.run(broadcast_message({ asyncio.run(send_message({
'type': 'progress', 'type': 'progress',
'value': 100 'value': 100
})) }))
# 发送完成消息 # 发送完成消息
asyncio.run(broadcast_message({ asyncio.run(send_message({
'type': 'log', 'type': 'log',
'message': '所有话题爬取完成' 'message': '所有话题爬取完成'
})) }))
except Exception as e: except Exception as e:
# 发送错误日志 # 发送错误日志
asyncio.run(broadcast_message({ asyncio.run(send_message({
'type': 'log', 'type': 'log',
'message': f'爬虫任务执行出错: {str(e)}' 'message': f'爬虫任务执行出错: {str(e)}'
})) }))
@@ -196,23 +205,27 @@ def save_spider_config():
}) })
@spider_bp.websocket('/ws/spider-status') @spider_bp.websocket('/ws/spider-status')
async def spider_status_socket(): async def spider_status_socket(websocket):
"""WebSocket连接处理""" """WebSocket连接处理"""
try: try:
websocket = websockets.WebSocketServerProtocol()
websocket_connections.add(websocket) websocket_connections.add(websocket)
logging.info("新的WebSocket连接已建立")
try: try:
while True: while True:
# 保持连接活跃 # 等待消息,保持连接活跃
await websocket.ping() message = await websocket.receive()
await asyncio.sleep(30) if message is None:
break
except websockets.exceptions.ConnectionClosed: except websockets.exceptions.ConnectionClosed:
pass logging.info("WebSocket连接已关闭")
finally: finally:
websocket_connections.remove(websocket) websocket_connections.remove(websocket)
logging.info("WebSocket连接已移除")
except Exception as e: except Exception as e:
logger.error(f"WebSocket连接处理失败: {e}") logger.error(f"WebSocket连接处理失败: {e}")
if websocket in websocket_connections:
websocket_connections.remove(websocket)
def get_ai_client(): def get_ai_client():
"""获取可用的AI客户端""" """获取可用的AI客户端"""