#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Network utilities module.

Features:
1. HTTP sessions with automatic retries
2. Proxy configuration
3. Platform-specific TCP socket options, scoped to the sessions created here
4. Connect/read timeouts and SSL verification defaults
5. User-agent rotation
6. Resumable file downloads
7. A memoized DNS lookup helper (``NetworkUtils._resolve_hostname``). It is
   deliberately NOT applied to outgoing requests: rewriting the URL host to
   an IP address breaks TLS certificate verification, SNI and the ``Host``
   header.
"""

import os
import platform
import random
import socket
import time
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import requests
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPConnection
from urllib3.util.retry import Retry

# Type aliases
TimeoutType = Union[float, tuple]
# (level, option, value) triple, as passed to socket.setsockopt.
SocketOption = Tuple[int, int, int]
# Called by download_file as callback(bytes_downloaded, total_size).
ProgressCallback = Callable[[int, int], Any]


class _SocketOptionsAdapter(HTTPAdapter):
    """HTTPAdapter that hands custom ``socket_options`` to its urllib3 pools.

    Scoping the options to one adapter replaces the old approach of
    reassigning ``HTTPConnection.default_socket_options``. That approach
    mutated a process-wide urllib3 setting, grew on every ``NetworkUtils()``
    construction, and is ignored by urllib3 2.x: there, the
    ``HTTPConnection.__init__`` default is bound when the class is defined.
    """

    # HTTPAdapter pickles only the attributes listed here and rebuilds its
    # pool manager in __setstate__, so our option list must be included.
    __attrs__ = HTTPAdapter.__attrs__ + ['_socket_options']

    def __init__(self, *args, socket_options=None, **kwargs):
        # Must be set before super().__init__(), which calls init_poolmanager().
        self._socket_options = socket_options
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Create the direct-connection pool manager with our socket options."""
        if self._socket_options is not None:
            kwargs.setdefault('socket_options', self._socket_options)
        super().init_poolmanager(*args, **kwargs)

    def proxy_manager_for(self, proxy, **proxy_kwargs):
        """Create/fetch a proxy pool manager that uses our socket options too."""
        if self._socket_options is not None:
            proxy_kwargs.setdefault('socket_options', self._socket_options)
        return super().proxy_manager_for(proxy, **proxy_kwargs)


class NetworkUtils:
    """Cross-platform helpers for building HTTP sessions and simple network tasks."""

    # Sample user-agent strings (Windows, macOS, Linux, iPhone).
    USER_AGENTS = (
        # Windows
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        # macOS
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 12_4) AppleWebKit/605.1.15",
        # Linux
        "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
        # Mobile
        "Mozilla/5.0 (iPhone; CPU iPhone OS 15_5 like Mac OS X) AppleWebKit/605.1.15",
    )

    def __init__(self):
        # Lower-cased platform.system() value, e.g. 'windows', 'linux', 'darwin'.
        self.system = platform.system().lower()
        # Socket options for every adapter built by get_session().
        self._socket_options: List[SocketOption] = []
        self._setup_platform_specific()

    def _setup_platform_specific(self):
        """Compute the socket options used by sessions created by this instance.

        Only ``self._socket_options`` is modified; urllib3's global defaults
        are left untouched. Calling this again recomputes the list instead of
        appending to it.
        """
        # Start from urllib3's defaults. Later entries are applied after
        # earlier ones by setsockopt, so they override them.
        options = list(HTTPConnection.default_socket_options)
        if self.system == 'windows':
            # Explicitly turn TCP_NODELAY off, i.e. re-enable Nagle's algorithm.
            # NOTE(review): this overrides urllib3's TCP_NODELAY=1 and can add
            # latency to small requests; confirm it is still intended.
            options.append((socket.IPPROTO_TCP, socket.TCP_NODELAY, 0))
        elif self.system == 'linux' and hasattr(socket, 'TCP_FASTOPEN'):
            # Intended to enable TCP Fast Open (value 5). The constant is not
            # present in every Python build, hence the hasattr guard.
            # NOTE(review): on Linux, TCP_FASTOPEN configures the listen-side
            # queue; client-side TFO usually needs TCP_FASTOPEN_CONNECT, so
            # this may have no effect on outgoing connections. Verify.
            options.append((socket.IPPROTO_TCP, socket.TCP_FASTOPEN, 5))
        self._socket_options = options

    @staticmethod
    @lru_cache(maxsize=512)
    def _resolve_hostname(hostname: str) -> str:
        """Resolve *hostname* to an IPv4 address string, memoized.

        Uses ``socket.gethostbyname``, so only IPv4 is returned. Results,
        including failures, are cached for the life of the process (DNS TTLs
        are ignored), up to 512 distinct hostnames. If resolution fails, the
        original *hostname* is returned unchanged.
        """
        try:
            return socket.gethostbyname(hostname)
        except (socket.gaierror, UnicodeError):
            # UnicodeError: IDNA encoding failed (e.g. a label longer than 63 chars).
            return hostname

    def get_session(
        self,
        retries: int = 3,
        backoff_factor: float = 0.5,
        timeout: TimeoutType = (3.05, 30),
        proxy: Optional[str] = None,
        verify_ssl: bool = True
    ) -> requests.Session:
        """Build a ``requests.Session`` with retries, pooling and request defaults.

        Args:
            retries: Total retry budget for urllib3 ``Retry``. Status-based
                retries happen on 500/502/503/504 for GET, POST, PUT and DELETE.
            backoff_factor: Backoff factor between retries.
            timeout: Default timeout in seconds, either one value or a
                ``(connect, read)`` tuple. Used only when a request does not
                pass ``timeout`` itself.
            proxy: Proxy URL applied to both http and https traffic,
                e.g. ``'http://user:pass@proxy:port'``.
            verify_ssl: Default ``verify`` value when a request does not pass one.

        Returns:
            A new session. The caller owns it and should close it (or use it
            in a ``with`` statement).
        """
        session = requests.Session()

        # Retry policy.
        # NOTE: POST is not idempotent; retrying it on a 5xx may repeat
        # server-side effects. It is kept for backward compatibility.
        retry = Retry(
            total=retries,
            backoff_factor=backoff_factor,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=frozenset(['GET', 'POST', 'PUT', 'DELETE'])
        )

        # One adapter (and therefore one set of connection pools) serves both schemes.
        adapter = _SocketOptionsAdapter(
            max_retries=retry,
            pool_connections=20,
            pool_maxsize=100,
            socket_options=list(self._socket_options),
        )
        session.mount('http://', adapter)
        session.mount('https://', adapter)

        if proxy:
            session.proxies = {
                'http': proxy,
                'https': proxy
            }

        # requests.Session has no session-wide timeout setting, so wrap
        # request() to inject the defaults. session.get/post/head all call
        # self.request, so they pick up the wrapper too.
        session.request = self._wrap_request(
            session.request,
            timeout=timeout,
            verify=verify_ssl
        )
        return session

    def _wrap_request(self, original_request, **defaults):
        """Return a ``Session.request`` wrapper that fills in *defaults*.

        Each key in *defaults* is added to the call's keyword arguments only
        when the caller did not supply it; an explicit value, even ``None``,
        wins. The URL is passed through untouched. Earlier versions replaced
        the host with a cached IP and registered an unsupported
        ``pre_request`` hook, which made requests raise ``ValueError``.
        """
        def wrapped(method, url, *args, **kwargs):
            for key, value in defaults.items():
                kwargs.setdefault(key, value)
            return original_request(method, url, *args, **kwargs)

        return wrapped

    def get_user_agent(self) -> str:
        """Return a user-agent string picked at random from ``USER_AGENTS``."""
        return random.choice(self.USER_AGENTS)

    def check_connection(
        self,
        url: str = "https://www.baidu.com",
        timeout: float = 5.0
    ) -> bool:
        """Check network reachability by sending a HEAD request to *url*.

        Args:
            url: URL to probe.
            timeout: Timeout in seconds, used for both connect and read.

        Returns:
            True if any HTTP response was received, whatever its status code.
            False if the request failed (connection error, timeout, invalid URL).
        """
        try:
            with self.get_session(retries=0, timeout=timeout) as session:
                session.head(url)
        except requests.RequestException:
            return False
        return True

    @staticmethod
    def _parse_content_range(value: Optional[str]) -> Tuple[Optional[int], Optional[int]]:
        """Parse a ``Content-Range`` header into ``(first_byte, complete_length)``.

        Handles ``bytes 100-199/1000`` -> ``(100, 1000)``,
        ``bytes */1000`` -> ``(None, 1000)`` and ``bytes 100-199/*`` -> ``(100, None)``.
        Missing or malformed parts come back as ``None``.
        """
        if not value:
            return None, None
        unit, _, rest = value.strip().partition(' ')
        if unit.lower() != 'bytes':
            return None, None
        range_part, _, total_part = rest.partition('/')
        first_text = range_part.split('-', 1)[0].strip()
        total_text = total_part.strip()
        first = int(first_text) if first_text.isdecimal() else None
        total = int(total_text) if total_text.isdecimal() else None
        return first, total

    def download_file(
        self,
        url: str,
        save_path: str,
        chunk_size: int = 8192,
        progress_callback: Optional[ProgressCallback] = None
    ) -> bool:
        """Download *url* to *save_path*, resuming a partial file if present.

        Resume behavior:
        - If *save_path* already holds ``N > 0`` bytes, ``Range: bytes=N-`` is
          sent. A 206 reply is appended, after checking that it starts at ``N``.
        - If the server ignores the range (plain 200), the file is rewritten
          from the start rather than appending a duplicate copy.
        - A 416 reply whose ``Content-Range`` total equals ``N`` means the file
          is already complete.

        Args:
            url: File URL.
            save_path: Local destination path.
            chunk_size: Streaming chunk size in bytes.
            progress_callback: Optional ``callback(bytes_downloaded, total_size)``
                called after each written chunk. ``total_size`` is 0 when the
                server does not report it. Exceptions raised by the callback
                propagate to the caller.

        Returns:
            True on success. False on network, HTTP or file-system errors;
            the error is printed.
        """
        # Resume only from an existing, non-empty regular file.
        downloaded = os.path.getsize(save_path) if os.path.isfile(save_path) else 0

        # Ask for the unencoded body. requests transparently decodes
        # gzip/deflate, which would make the on-disk size disagree with the
        # byte offsets used by Range.
        headers = {'Accept-Encoding': 'identity'}
        if downloaded > 0:
            headers['Range'] = f'bytes={downloaded}-'

        try:
            with self.get_session() as session, session.get(
                url, headers=headers, stream=True
            ) as r:
                if r.status_code == 416 and downloaded > 0:
                    # The requested range starts at or after the remote EOF.
                    _, remote_total = self._parse_content_range(
                        r.headers.get('Content-Range'))
                    if remote_total == downloaded:
                        if progress_callback:
                            progress_callback(downloaded, remote_total)
                        return True
                r.raise_for_status()

                length_header = r.headers.get('content-length', '')
                content_length = int(length_header) if length_header.isdecimal() else 0

                if r.status_code == 206:
                    first, remote_total = self._parse_content_range(
                        r.headers.get('Content-Range'))
                    if first is not None and first != downloaded:
                        raise ValueError(
                            f'server resumed at byte {first}, expected {downloaded}')
                    mode = 'ab'
                    if remote_total is not None:
                        total_size = remote_total
                    elif content_length:
                        total_size = downloaded + content_length
                    else:
                        total_size = 0
                else:
                    # Full body (Range not sent, or ignored): start over.
                    downloaded = 0
                    mode = 'wb'
                    total_size = content_length

                with open(save_path, mode) as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        if not chunk:  # skip keep-alive chunks
                            continue
                        f.write(chunk)
                        downloaded += len(chunk)
                        if progress_callback:
                            progress_callback(downloaded, total_size)
            return True
        except (requests.RequestException, OSError, ValueError) as e:
            print(f"下载失败: {str(e)}")
            return False


# Shared module-level instance. Each helper call builds its own Session;
# the instance holds only the platform name and socket-option list.
network_utils = NetworkUtils()


# Shortcut functions (kept for compatibility with older code)
def get_session(*args, **kwargs) -> requests.Session:
    """Module-level alias for ``network_utils.get_session``."""
    return network_utils.get_session(*args, **kwargs)


def check_connection(*args, **kwargs) -> bool:
    """Module-level alias for ``network_utils.check_connection``."""
    return network_utils.check_connection(*args, **kwargs)