234 lines
6.8 KiB
Python
234 lines
6.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
网络工具模块
|
||
功能:
|
||
1. 自动重试的HTTP请求
|
||
2. 多平台代理配置支持
|
||
3. DNS缓存优化
|
||
4. 连接超时与SSL验证
|
||
5. 用户代理轮换
|
||
"""
|
||
|
||
import os
import platform
import random
import socket
import time
from functools import lru_cache
from typing import Any, Callable, Dict, Optional, Tuple, Union
from urllib.parse import urlparse

import requests
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPConnection
from urllib3.util.retry import Retry
|
||
|
||
# Type alias for the requests ``timeout`` argument: either a single number of
# seconds, or a (connect timeout, read timeout) pair in seconds.
TimeoutType = Union[float, Tuple[float, float]]
|
||
|
||
|
||
class NetworkUtils:
    """Cross-platform helpers for HTTP sessions, connectivity checks and downloads."""

    # Candidate User-Agent strings returned by get_user_agent().
    _USER_AGENTS = (
        # Windows
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        # macOS
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 12_4) AppleWebKit/605.1.15",
        # Linux
        "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
        # Mobile
        "Mozilla/5.0 (iPhone; CPU iPhone OS 15_5 like Mac OS X) AppleWebKit/605.1.15",
    )

    def __init__(self):
        # Lower-cased OS name, e.g. 'windows' or 'linux'.
        self.system = platform.system().lower()
        # Not used inside this class (hostname lookups are memoised by the
        # lru_cache on _resolve_hostname); kept in case external code reads it.
        self._dns_cache = {}
        self._setup_platform_specific()

    def _setup_platform_specific(self):
        """Add a platform-specific socket option to urllib3 connections.

        NOTE: this reassigns the class attribute
        ``HTTPConnection.default_socket_options``, so it affects every urllib3
        connection created afterwards in this process, not only sessions
        built here. The option is added only if it is not already present,
        so creating several instances does not keep growing the list.
        """
        if self.system == 'windows':
            # Windows: turn TCP_NODELAY off (0), which re-enables Nagle's
            # algorithm. NOTE(review): confirm this is intended.
            option = (socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
        elif self.system == 'linux' and hasattr(socket, 'TCP_FASTOPEN'):
            # Linux: set TCP_FASTOPEN with value 5. NOTE(review): on a client
            # socket this likely does not enable client-side TFO (that needs
            # TCP_FASTOPEN_CONNECT / MSG_FASTOPEN) - confirm it has an effect.
            option = (socket.IPPROTO_TCP, socket.TCP_FASTOPEN, 5)
        else:
            return

        current = list(HTTPConnection.default_socket_options)
        if option not in current:
            HTTPConnection.default_socket_options = current + [option]

    @staticmethod
    @lru_cache(maxsize=512)
    def _resolve_hostname(hostname: str) -> str:
        """Resolve *hostname* to an IPv4 address string, memoised.

        Returns *hostname* unchanged if resolution raises ``socket.gaierror``.
        Up to 512 results (including such failures) are cached with no
        expiry for the lifetime of the process.
        """
        try:
            return socket.gethostbyname(hostname)
        except socket.gaierror:
            return hostname  # fall back to the original name on failure

    def get_session(
        self,
        retries: int = 3,
        backoff_factor: float = 0.5,
        timeout: TimeoutType = (3.05, 30),
        proxy: Optional[str] = None,
        verify_ssl: bool = True
    ) -> requests.Session:
        """Build a configured ``requests.Session``.

        Args:
            retries: Total retry count (urllib3 ``Retry.total``).
            backoff_factor: Backoff factor between retries.
            timeout: Default timeout - a number of seconds or a
                (connect, read) tuple - used when a request passes none.
            proxy: Proxy URL (e.g. 'http://user:pass@proxy:port') used for
                both http and https.
            verify_ssl: Default for the request ``verify`` argument.

        Returns:
            A new session. The caller owns it and should close it
            (``requests.Session`` supports ``with``).
        """
        session = requests.Session()

        # Retry on these 5xx statuses. NOTE(review): POST is retryable here,
        # so a non-idempotent request may be sent more than once.
        retry = Retry(
            total=retries,
            backoff_factor=backoff_factor,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=frozenset(['GET', 'POST', 'PUT', 'DELETE'])
        )
        adapter = HTTPAdapter(
            max_retries=retry,
            pool_connections=20,
            pool_maxsize=100
        )
        session.mount('http://', adapter)
        session.mount('https://', adapter)

        if proxy:
            session.proxies = {'http': proxy, 'https': proxy}

        # Shadow Session.request on this instance so the defaults also apply
        # to session.get/head/post, which dispatch through self.request.
        session.request = self._wrap_request(
            session.request,
            timeout=timeout,
            verify=verify_ssl
        )
        return session

    def _wrap_request(self, original_request, **defaults):
        """Wrap a bound ``Session.request`` so *defaults* fill missing kwargs.

        Keyword arguments passed explicitly by the caller always win.
        """
        # The URL is passed through untouched on purpose. Replacing the host
        # with a cached IP breaks TLS certificate/SNI checks and Host-based
        # virtual hosting. requests also rejects hook events other than
        # 'response' with ValueError, so a 'pre_request' hook made every call
        # fail.
        def wrapped(method, url, *args, **kwargs):
            for key, value in defaults.items():
                kwargs.setdefault(key, value)
            return original_request(method, url, *args, **kwargs)

        return wrapped

    def get_user_agent(self) -> str:
        """Return a User-Agent string chosen at random from ``_USER_AGENTS``.

        Sessions from get_session() do not set this header automatically.
        """
        return random.choice(self._USER_AGENTS)

    def check_connection(
        self,
        url: str = "https://www.baidu.com",
        timeout: float = 5.0
    ) -> bool:
        """Return True if a HEAD request to *url* completes without error.

        Args:
            url: URL used for the probe.
            timeout: Timeout in seconds.

        The response status code is not inspected; any exception (DNS,
        connect, TLS, timeout, ...) yields False.
        """
        try:
            with self.get_session(retries=0, timeout=timeout) as session:
                session.head(url)
            return True
        except Exception:
            return False

    @staticmethod
    def _range_already_complete(response, local_size: int) -> bool:
        """Return True if a 416 reply shows the local file is already complete.

        Servers answer an out-of-range ``Range: bytes=N-`` with
        ``Content-Range: bytes */TOTAL``. The file is complete when TOTAL
        equals *local_size*.
        """
        content_range = response.headers.get('Content-Range', '')
        total = content_range.rpartition('/')[2].strip()
        return total.isdigit() and int(total) == local_size

    def download_file(
        self,
        url: str,
        save_path: str,
        chunk_size: int = 8192,
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> bool:
        """Download *url* to *save_path*, resuming an existing partial file.

        Args:
            url: File URL.
            save_path: Local destination path.
            chunk_size: Size in bytes of each streamed chunk.
            progress_callback: Optional ``callback(bytes_downloaded, total_size)``
                invoked after each written chunk. total_size is 0 when the
                server does not send Content-Length.

        Returns:
            True on success, including when the local file is already
            complete. False on any error; the error is printed.
        """
        downloaded = os.path.getsize(save_path) if os.path.isfile(save_path) else 0

        # Byte offsets must refer to the stored bytes, so request the
        # representation without content-encoding (e.g. gzip).
        headers = {'Accept-Encoding': 'identity'}
        if downloaded:
            headers['Range'] = f'bytes={downloaded}-'

        try:
            with self.get_session() as session:
                with session.get(url, headers=headers, stream=True) as r:
                    if r.status_code == 416 and self._range_already_complete(r, downloaded):
                        return True
                    r.raise_for_status()

                    if r.status_code == 206:
                        # Partial content: append to the existing bytes.
                        mode = 'ab'
                    else:
                        # No Range sent, or the server ignored it and is
                        # sending the whole file: start over.
                        mode = 'wb'
                        downloaded = 0

                    length = r.headers.get('content-length')
                    total_size = downloaded + int(length) if length else 0

                    with open(save_path, mode) as f:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            if chunk:  # skip keep-alive chunks
                                f.write(chunk)
                                downloaded += len(chunk)
                                if progress_callback:
                                    progress_callback(downloaded, total_size)
            return True
        except Exception as e:
            print(f"下载失败: {str(e)}")
            return False
||
|
||
|
||
# Shared module-level instance used by the shortcut functions below.
# Creating it runs _setup_platform_specific(), which (on Windows/Linux)
# reassigns urllib3's HTTPConnection.default_socket_options at import time.
# NOTE(review): previously documented as thread-safe; it holds no
# per-request state and each get_session() call returns a new Session.
network_utils: NetworkUtils = NetworkUtils()
|
||
|
||
|
||
# Shortcut functions (kept for compatibility with older code).
def get_session(*positional, **options):
    """Shortcut for ``network_utils.get_session(...)``.

    Forwards every argument unchanged to the shared module-level
    ``NetworkUtils`` instance and returns the new ``requests.Session``.
    """
    utils = network_utils
    return utils.get_session(*positional, **options)
|
||
|
||
|
||
def check_connection(*positional, **options):
    """Shortcut for ``network_utils.check_connection(...)``.

    Forwards every argument unchanged to the shared module-level
    ``NetworkUtils`` instance and returns its bool result.
    """
    utils = network_utils
    return utils.check_connection(*positional, **options)
|