数据库操作
This commit is contained in:
@@ -1,56 +0,0 @@
|
|||||||
[loggers]
|
|
||||||
keys=root,data_collector,api_client,alert
|
|
||||||
|
|
||||||
[handlers]
|
|
||||||
keys=consoleHandler,fileHandler,errorFileHandler
|
|
||||||
|
|
||||||
[formatters]
|
|
||||||
keys=standardFormatter,detailedFormatter
|
|
||||||
|
|
||||||
[logger_root]
|
|
||||||
level=INFO
|
|
||||||
handlers=consoleHandler,fileHandler
|
|
||||||
|
|
||||||
[logger_data_collector]
|
|
||||||
level=DEBUG
|
|
||||||
handlers=fileHandler
|
|
||||||
qualname=data_collector
|
|
||||||
propagate=0
|
|
||||||
|
|
||||||
[logger_api_client]
|
|
||||||
level=INFO
|
|
||||||
handlers=fileHandler,errorFileHandler
|
|
||||||
qualname=api_client
|
|
||||||
propagate=0
|
|
||||||
|
|
||||||
[logger_alert]
|
|
||||||
level=WARNING
|
|
||||||
handlers=consoleHandler,errorFileHandler
|
|
||||||
qualname=alert
|
|
||||||
propagate=0
|
|
||||||
|
|
||||||
[handler_consoleHandler]
|
|
||||||
class=StreamHandler
|
|
||||||
level=INFO
|
|
||||||
formatter=standardFormatter
|
|
||||||
args=(sys.stdout,)
|
|
||||||
|
|
||||||
[handler_fileHandler]
|
|
||||||
class=logging.handlers.TimedRotatingFileHandler
|
|
||||||
level=DEBUG
|
|
||||||
formatter=detailedFormatter
|
|
||||||
args=('%(log_dir)s/application.log', 'midnight', 1, 30, 'utf-8')
|
|
||||||
|
|
||||||
[handler_errorFileHandler]
|
|
||||||
class=logging.handlers.TimedRotatingFileHandler
|
|
||||||
level=WARNING
|
|
||||||
formatter=detailedFormatter
|
|
||||||
args=('%(log_dir)s/error.log', 'midnight', 1, 90, 'utf-8')
|
|
||||||
|
|
||||||
[formatter_standardFormatter]
|
|
||||||
format=%(asctime)s [%(levelname)-5s] %(name)s - %(message)s
|
|
||||||
datefmt=%Y-%m-%d %H:%M:%S
|
|
||||||
|
|
||||||
[formatter_detailedFormatter]
|
|
||||||
format=%(asctime)s [%(levelname)-5s] %(name)s (%(filename)s:%(lineno)d) - %(message)s
|
|
||||||
datefmt=%Y-%m-%d %H:%M:%S
|
|
||||||
+371
-155
@@ -1,193 +1,409 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
系统配置模块
|
|
||||||
功能:
|
|
||||||
1. 支持多平台路径适配
|
|
||||||
2. 环境变量与配置文件优先级管理
|
|
||||||
3. 敏感信息加密存储
|
|
||||||
4. 配置热更新检测
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import sys
|
||||||
import platform
|
import platform
|
||||||
|
import pandas as pd
|
||||||
|
import pymysql
|
||||||
|
from pymysql import cursors
|
||||||
|
from pymysql.err import MySQLError
|
||||||
|
from dbutils.pooled_db import PooledDB
|
||||||
|
from typing import Union, List, Dict, Any, Optional, Tuple
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
import dotenv
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
|
|
||||||
class ConfigManager:
|
# 导入您的日志系统
|
||||||
def __init__(self, app_name: str = "intelligence_system"):
|
from utils.logger import log as logger
|
||||||
"""
|
|
||||||
初始化配置管理器
|
|
||||||
|
|
||||||
参数:
|
class MySQLAgent:
|
||||||
app_name: 应用名称(用于生成配置目录)
|
"""
|
||||||
"""
|
全平台兼容的MySQL数据库操作类
|
||||||
self.system = platform.system().lower()
|
支持Windows/macOS/Linux系统
|
||||||
self.app_name = app_name
|
"""
|
||||||
self._config = {}
|
|
||||||
self._secret_key = None
|
|
||||||
|
|
||||||
# 初始化配置路径
|
_instance = None
|
||||||
self.config_dir = self._get_config_dir()
|
_lock = threading.Lock()
|
||||||
os.makedirs(self.config_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 加载配置顺序
|
# 各平台特定的配置
|
||||||
self._load_defaults()
|
PLATFORM_CONFIG = {
|
||||||
self._load_env_file()
|
'Windows': {
|
||||||
self._load_user_config()
|
'socket_timeout': 30,
|
||||||
|
'connect_timeout': 10,
|
||||||
|
'ssl': None
|
||||||
|
},
|
||||||
|
'Darwin': { # macOS
|
||||||
|
'socket_timeout': 60,
|
||||||
|
'connect_timeout': 15,
|
||||||
|
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||||
|
},
|
||||||
|
'Linux': {
|
||||||
|
'socket_timeout': 60,
|
||||||
|
'connect_timeout': 15,
|
||||||
|
'ssl': None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# 初始化加密模块
|
def __new__(cls, *args, **kwargs):
|
||||||
self._init_encryption()
|
if not cls._instance:
|
||||||
|
with cls._lock:
|
||||||
|
if not cls._instance:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
def _get_config_dir(self) -> str:
|
def __init__(self, config: dict = None):
|
||||||
"""获取适合当前平台的配置目录"""
|
if hasattr(self, '_pool') and self._pool:
|
||||||
if self.system == 'windows':
|
return
|
||||||
return os.path.join(os.environ['APPDATA'], self.app_name)
|
|
||||||
elif self.system == 'darwin': # macOS
|
|
||||||
return os.path.expanduser(f"~/Library/Application Support/{self.app_name}")
|
|
||||||
else: # Linux及其他Unix-like
|
|
||||||
return os.path.expanduser(f"~/.config/{self.app_name}")
|
|
||||||
|
|
||||||
def _load_defaults(self):
|
if not config:
|
||||||
"""加载默认配置"""
|
from config.settings import DATABASE_CONFIG
|
||||||
self._config = {
|
config = DATABASE_CONFIG
|
||||||
"system": {
|
|
||||||
"log_level": "INFO",
|
# 获取当前平台配置
|
||||||
"max_threads": os.cpu_count() or 4
|
current_platform = platform.system()
|
||||||
},
|
platform_config = self.PLATFORM_CONFIG.get(current_platform, {})
|
||||||
"api": {
|
|
||||||
"newsapi": {"endpoint": "https://newsapi.org/v2"},
|
# 基础配置
|
||||||
"weibo": {"version": "2"}
|
self.config = {
|
||||||
},
|
'host': config.get('host', 'localhost'),
|
||||||
"paths": {
|
'port': config.get('port', 3306),
|
||||||
"data_dir": os.path.join(self.config_dir, "data"),
|
'user': config.get('user', 'root'),
|
||||||
"cache_dir": os.path.join(self.config_dir, "cache")
|
'password': config.get('password', ''),
|
||||||
}
|
'database': config.get('database', 'intelligence_system'),
|
||||||
|
'charset': config.get('charset', 'utf8mb4'),
|
||||||
|
'cursorclass': cursors.DictCursor,
|
||||||
|
'autocommit': True,
|
||||||
|
**platform_config # 合并平台特定配置
|
||||||
}
|
}
|
||||||
|
|
||||||
def _load_env_file(self):
|
# 处理各平台路径差异
|
||||||
"""加载.env环境变量文件"""
|
if current_platform == 'Windows':
|
||||||
env_path = Path(self.config_dir) / ".env"
|
self.config['ssl'] = None # Windows通常不需要SSL配置
|
||||||
if env_path.exists():
|
|
||||||
dotenv.load_dotenv(env_path)
|
|
||||||
|
|
||||||
# 环境变量覆盖配置
|
# macOS特殊处理
|
||||||
if os.getenv("LOG_LEVEL"):
|
elif current_platform == 'Darwin':
|
||||||
self._config["system"]["log_level"] = os.getenv("LOG_LEVEL")
|
if not os.path.exists(self.config['ssl']['ca']):
|
||||||
|
self.config['ssl'] = None
|
||||||
|
logger.warning("macOS SSL certificate not found, disabling SSL")
|
||||||
|
|
||||||
def _load_user_config(self):
|
self.pool_size = config.get('max_connections', 5)
|
||||||
"""加载用户自定义配置"""
|
self._pool = self._create_pool()
|
||||||
config_file = Path(self.config_dir) / "config.json"
|
self.logger = logger.bind(module=f"MySQLAgent({current_platform})")
|
||||||
|
|
||||||
|
def _create_pool(self) -> PooledDB:
|
||||||
|
"""创建跨平台兼容的连接池"""
|
||||||
try:
|
try:
|
||||||
if config_file.exists():
|
# 各平台连接池参数调整
|
||||||
with open(config_file, 'r', encoding='utf-8') as f:
|
pool_config = {
|
||||||
user_config = json.load(f)
|
'creator': pymysql,
|
||||||
self._deep_update(self._config, user_config)
|
'maxconnections': self.pool_size,
|
||||||
|
'mincached': 1,
|
||||||
|
'maxcached': 3,
|
||||||
|
'blocking': True,
|
||||||
|
'ping': 1, # 定期检查连接有效性
|
||||||
|
**self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
# Windows平台需要更短的超时时间
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
pool_config['ping'] = 0 # Windows上ping有时不稳定
|
||||||
|
|
||||||
|
pool = PooledDB(**pool_config)
|
||||||
|
self.logger.info(f"Connection pool created for {platform.system()}")
|
||||||
|
return pool
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"加载用户配置失败: {str(e)}")
|
self.logger.critical("Failed to create connection pool",
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
def _init_encryption(self):
|
def _handle_path(self, path: str) -> str:
|
||||||
"""初始化配置加密模块"""
|
"""处理跨平台路径问题"""
|
||||||
key_file = Path(self.config_dir) / ".secret.key"
|
if platform.system() == 'Windows':
|
||||||
if key_file.exists():
|
return path.replace('/', '\\')
|
||||||
with open(key_file, 'rb') as f:
|
return path
|
||||||
self._secret_key = f.read()
|
|
||||||
else:
|
|
||||||
self._secret_key = Fernet.generate_key()
|
|
||||||
with open(key_file, 'wb') as f:
|
|
||||||
f.write(self._secret_key)
|
|
||||||
key_file.chmod(0o600) # 设置密钥文件权限
|
|
||||||
|
|
||||||
def _deep_update(self, original: Dict, update: Dict) -> Dict:
|
def get_connection(self) -> pymysql.connections.Connection:
|
||||||
"""深度合并字典"""
|
|
||||||
for key, value in update.items():
|
|
||||||
if isinstance(value, dict) and key in original:
|
|
||||||
original[key] = self._deep_update(original.get(key, {}), value)
|
|
||||||
else:
|
|
||||||
original[key] = value
|
|
||||||
return original
|
|
||||||
|
|
||||||
def get(self, key: str, default: Any = None) -> Any:
|
|
||||||
"""
|
"""
|
||||||
获取配置项(支持点分路径)
|
获取数据库连接(跨平台兼容)
|
||||||
示例: get("api.newsapi.endpoint")
|
|
||||||
|
Returns:
|
||||||
|
pymysql.connections.Connection: 数据库连接
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MySQLError: 如果连接失败
|
||||||
"""
|
"""
|
||||||
keys = key.split('.')
|
|
||||||
value = self._config
|
|
||||||
try:
|
try:
|
||||||
for k in keys:
|
conn = self._pool.connection()
|
||||||
value = value[k]
|
|
||||||
return value
|
|
||||||
except (KeyError, TypeError):
|
|
||||||
return default
|
|
||||||
|
|
||||||
def set(self, key: str, value: Any, persist: bool = False):
|
# macOS需要特殊处理SSL
|
||||||
|
if platform.system() == 'Darwin' and self.config.get('ssl'):
|
||||||
|
conn.ping(reconnect=True)
|
||||||
|
|
||||||
|
self.logger.trace("Connection obtained")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
|
||||||
|
# Windows特定错误处理
|
||||||
|
if platform.system() == 'Windows' and "timed out" in error_msg:
|
||||||
|
self.logger.warning("Windows connection timeout, retrying...")
|
||||||
|
return self._retry_connection()
|
||||||
|
|
||||||
|
self.logger.error("Connection failed",
|
||||||
|
error=error_msg,
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
|
||||||
|
"""Windows平台连接重试机制"""
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
conn = self._pool.connection()
|
||||||
|
self.logger.info(f"Connection established after {attempt+1} attempts")
|
||||||
|
return conn
|
||||||
|
except Exception:
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
raise
|
||||||
|
import time
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||||
|
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
设置配置项
|
跨平台兼容的SQL查询
|
||||||
参数:
|
|
||||||
persist: 是否保存到用户配置文件
|
Args:
|
||||||
|
sql (str): SQL语句
|
||||||
|
params (Union[tuple, dict, None]): 参数
|
||||||
|
parse_dates (Union[List[str], bool]): 日期解析
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: 查询结果
|
||||||
"""
|
"""
|
||||||
keys = key.split('.')
|
try:
|
||||||
config_ref = self._config
|
with self.get_connection() as conn:
|
||||||
|
# Linux/macOS需要更长的查询超时
|
||||||
|
if platform.system() != 'Windows':
|
||||||
|
conn.cursor().execute("SET SESSION wait_timeout=600")
|
||||||
|
|
||||||
for k in keys[:-1]:
|
df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)
|
||||||
if k not in config_ref:
|
|
||||||
config_ref[k] = {}
|
|
||||||
config_ref = config_ref[k]
|
|
||||||
|
|
||||||
config_ref[keys[-1]] = value
|
# Windows平台需要手动关闭游标
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
conn.cursor().close()
|
||||||
|
|
||||||
if persist:
|
self.logger.info("Query executed", rows=len(df))
|
||||||
self._save_user_config()
|
return df
|
||||||
|
|
||||||
def encrypt_value(self, plaintext: str) -> str:
|
except Exception as e:
|
||||||
"""加密敏感信息"""
|
self.logger.error("Query failed",
|
||||||
fernet = Fernet(self._secret_key)
|
sql=sql,
|
||||||
return fernet.encrypt(plaintext.encode()).decode()
|
params=params,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
def decrypt_value(self, ciphertext: str) -> str:
|
def insert_from_df(self, table_name: str, df: pd.DataFrame,
|
||||||
"""解密敏感信息"""
|
chunk_size: int = 1000, replace: bool = False) -> int:
|
||||||
fernet = Fernet(self._secret_key)
|
"""
|
||||||
return fernet.decrypt(ciphertext.encode()).decode()
|
跨平台数据插入
|
||||||
|
|
||||||
def _save_user_config(self):
|
Args:
|
||||||
"""保存用户配置到文件"""
|
table_name (str): 表名
|
||||||
config_file = Path(self.config_dir) / "config.json"
|
df (pd.DataFrame): 数据
|
||||||
with open(config_file, 'w', encoding='utf-8') as f:
|
chunk_size (int): 分批大小
|
||||||
json.dump(self._config, f, indent=2, ensure_ascii=False)
|
replace (bool): 是否替换
|
||||||
|
|
||||||
def reload(self):
|
Returns:
|
||||||
"""重新加载所有配置"""
|
int: 插入行数
|
||||||
self._config = {}
|
"""
|
||||||
self._load_defaults()
|
if df.empty:
|
||||||
self._load_env_file()
|
self.logger.warning("Empty DataFrame", table=table_name)
|
||||||
self._load_user_config()
|
return 0
|
||||||
|
|
||||||
# 全局配置实例
|
try:
|
||||||
config = ConfigManager()
|
method = 'replace' if replace else 'append'
|
||||||
|
total_rows = 0
|
||||||
|
|
||||||
# 快捷访问方法(兼容旧代码)
|
with self.get_connection() as conn:
|
||||||
def get_config(key: str, default: Optional[Any] = None) -> Any:
|
# 各平台不同的分批策略
|
||||||
return config.get(key, default)
|
if platform.system() == 'Windows':
|
||||||
|
chunk_size = min(chunk_size, 500) # Windows上减小批次
|
||||||
|
|
||||||
def set_config(key: str, value: Any, persist: bool = False):
|
for i in range(0, len(df), chunk_size):
|
||||||
config.set(key, value, persist)
|
chunk = df.iloc[i:i + chunk_size]
|
||||||
|
|
||||||
# 测试代码
|
# macOS需要特殊处理datetime
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
for col in chunk.select_dtypes(include=['datetime64']):
|
||||||
|
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
|
chunk.to_sql(
|
||||||
|
table_name,
|
||||||
|
conn,
|
||||||
|
if_exists=method,
|
||||||
|
index=False,
|
||||||
|
method='multi'
|
||||||
|
)
|
||||||
|
total_rows += len(chunk)
|
||||||
|
method = 'append'
|
||||||
|
|
||||||
|
self.logger.info("Data inserted", table=table_name, rows=total_rows)
|
||||||
|
return total_rows
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Insert failed",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||||
|
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
跨平台SQL执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql (str): SQL语句
|
||||||
|
params (Union[tuple, dict, None]): 参数
|
||||||
|
fetch (bool): 是否获取结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[int, List[Dict[str, Any]]]: 结果
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = self.get_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Linux/macOS需要更长的执行时间
|
||||||
|
if platform.system() != 'Windows':
|
||||||
|
cursor.execute("SET SESSION max_execution_time=600000")
|
||||||
|
|
||||||
|
cursor.execute(sql, params)
|
||||||
|
|
||||||
|
if fetch:
|
||||||
|
result = cursor.fetchall()
|
||||||
|
self.logger.debug("Query executed", rows=len(result))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
affected_rows = cursor.rowcount
|
||||||
|
self.logger.debug("Update executed", affected_rows=affected_rows)
|
||||||
|
return affected_rows
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("SQL execution failed",
|
||||||
|
sql=sql,
|
||||||
|
params=params,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def begin_transaction(self) -> pymysql.connections.Connection:
|
||||||
|
"""开始事务(跨平台兼容)"""
|
||||||
|
try:
|
||||||
|
conn = self.get_connection()
|
||||||
|
conn.autocommit(False)
|
||||||
|
|
||||||
|
# macOS需要特殊处理事务隔离级别
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
|
||||||
|
|
||||||
|
self.logger.debug("Transaction started")
|
||||||
|
return conn
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Begin transaction failed", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
|
||||||
|
"""提交事务(跨平台兼容)"""
|
||||||
|
try:
|
||||||
|
conn.commit()
|
||||||
|
self.logger.debug("Transaction committed")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Commit failed", error=str(e))
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
|
||||||
|
"""回滚事务(跨平台兼容)"""
|
||||||
|
try:
|
||||||
|
conn.rollback()
|
||||||
|
self.logger.warning("Transaction rolled back")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Rollback failed", error=str(e))
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""析构函数(跨平台资源清理)"""
|
||||||
|
if hasattr(self, '_pool'):
|
||||||
|
try:
|
||||||
|
self._pool.close()
|
||||||
|
self.logger.info("Connection pool closed")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Failed to close pool", error=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# 平台特定的默认配置
|
||||||
|
def get_default_config():
|
||||||
|
"""获取各平台默认配置"""
|
||||||
|
current_platform = platform.system()
|
||||||
|
|
||||||
|
base_config = {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 3306,
|
||||||
|
'user': 'root',
|
||||||
|
'password': '',
|
||||||
|
'database': 'intelligence_system',
|
||||||
|
'max_connections': 5
|
||||||
|
}
|
||||||
|
|
||||||
|
if current_platform == 'Windows':
|
||||||
|
return {
|
||||||
|
**base_config,
|
||||||
|
'connect_timeout': 10,
|
||||||
|
'read_timeout': 30,
|
||||||
|
'write_timeout': 30
|
||||||
|
}
|
||||||
|
elif current_platform == 'Darwin':
|
||||||
|
return {
|
||||||
|
**base_config,
|
||||||
|
'connect_timeout': 15,
|
||||||
|
'read_timeout': 60,
|
||||||
|
'write_timeout': 60,
|
||||||
|
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||||
|
}
|
||||||
|
else: # Linux和其他平台
|
||||||
|
return {
|
||||||
|
**base_config,
|
||||||
|
'connect_timeout': 15,
|
||||||
|
'read_timeout': 60,
|
||||||
|
'write_timeout': 60
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 设置并保存API密钥(自动加密)
|
# 自动获取适合当前平台的配置
|
||||||
api_key = "your_newsapi_key_here"
|
config = get_default_config()
|
||||||
encrypted_key = config.encrypt_value(api_key)
|
|
||||||
config.set("api.newsapi.key", encrypted_key, persist=True)
|
|
||||||
|
|
||||||
# 获取配置示例
|
# 初始化数据库连接
|
||||||
print(f"日志级别: {config.get('system.log_level')}")
|
db = MySQLAgent(config)
|
||||||
print(f"NewsAPI端点: {config.get('api.newsapi.endpoint')}")
|
|
||||||
|
|
||||||
# 解密敏感信息
|
# 测试查询
|
||||||
stored_key = config.get("api.newsapi.key")
|
try:
|
||||||
if stored_key:
|
df = db.query_to_df("SELECT VERSION() as version")
|
||||||
print(f"解密后的API密钥: {config.decrypt_value(stored_key)}")
|
print(f"Database version: {df['version'].iloc[0]}")
|
||||||
|
print(f"Running on: {platform.system()} {platform.release()}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {str(e)}")
|
||||||
@@ -17,7 +17,7 @@ from typing import Dict, List, Any
|
|||||||
|
|
||||||
# 自定义模块
|
# 自定义模块
|
||||||
from processors.data_processor import DataProcessor
|
from processors.data_processor import DataProcessor
|
||||||
from storage.database import IntelligenceDB
|
from storage.mysql_agent import IntelligenceDB
|
||||||
from applications.reporter import ReportGenerator
|
from applications.reporter import ReportGenerator
|
||||||
from applications.alert import AlertService
|
from applications.alert import AlertService
|
||||||
from utils.logger import setup_logging
|
from utils.logger import setup_logging
|
||||||
|
|||||||
@@ -1,255 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
数据库存储模块
|
|
||||||
功能:
|
|
||||||
1. 统一数据库接口(SQLite/MySQL/PostgreSQL)
|
|
||||||
2. 自动处理多平台路径问题
|
|
||||||
3. 连接池管理
|
|
||||||
4. 数据加密存储
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import sqlite3
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Union, Dict, Any, List
|
|
||||||
from threading import Lock
|
|
||||||
import logging
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
|
|
||||||
# 类型别名
|
|
||||||
QueryParams = Union[tuple, Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
|
||||||
"""数据库统一管理类"""
|
|
||||||
|
|
||||||
def __init__(self, db_config: Dict[str, Any]):
|
|
||||||
"""
|
|
||||||
初始化数据库连接
|
|
||||||
|
|
||||||
参数:
|
|
||||||
db_config: 配置字典,包含:
|
|
||||||
- type: 'sqlite'|'mysql'|'postgresql'
|
|
||||||
- database: 数据库名/路径
|
|
||||||
- [可选] host, port, user, password
|
|
||||||
"""
|
|
||||||
self.config = db_config
|
|
||||||
self._lock = Lock()
|
|
||||||
self._connection_pool = {}
|
|
||||||
self._setup_crypto()
|
|
||||||
|
|
||||||
# 自动创建SQLite目录
|
|
||||||
if db_config['type'] == 'sqlite':
|
|
||||||
self._ensure_sqlite_dir()
|
|
||||||
|
|
||||||
def _ensure_sqlite_dir(self):
|
|
||||||
"""确保SQLite数据库目录存在"""
|
|
||||||
db_path = Path(self.config['database'])
|
|
||||||
if not db_path.parent.exists():
|
|
||||||
try:
|
|
||||||
db_path.parent.mkdir(parents=True, mode=0o755)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"创建数据库目录失败: {str(e)}")
|
|
||||||
|
|
||||||
def _setup_crypto(self):
|
|
||||||
"""初始化加密模块"""
|
|
||||||
key_file = Path(os.path.expanduser("~/.db_encryption.key"))
|
|
||||||
if key_file.exists():
|
|
||||||
with open(key_file, 'rb') as f:
|
|
||||||
self._fernet = Fernet(f.read())
|
|
||||||
else:
|
|
||||||
self._fernet = Fernet.generate_key()
|
|
||||||
with open(key_file, 'wb') as f:
|
|
||||||
f.write(self._fernet)
|
|
||||||
key_file.chmod(0o600) # 仅限当前用户读写
|
|
||||||
|
|
||||||
def get_connection(self, reuse=True):
|
|
||||||
"""
|
|
||||||
获取数据库连接(线程安全)
|
|
||||||
|
|
||||||
参数:
|
|
||||||
reuse: 是否复用现有连接(默认True)
|
|
||||||
"""
|
|
||||||
thread_id = threading.get_ident()
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if reuse and thread_id in self._connection_pool:
|
|
||||||
conn = self._connection_pool[thread_id]
|
|
||||||
try:
|
|
||||||
# 检查连接是否有效
|
|
||||||
conn.execute("SELECT 1")
|
|
||||||
return conn
|
|
||||||
except:
|
|
||||||
del self._connection_pool[thread_id]
|
|
||||||
|
|
||||||
# 创建新连接
|
|
||||||
if self.config['type'] == 'sqlite':
|
|
||||||
conn = self._create_sqlite_connection()
|
|
||||||
elif self.config['type'] == 'mysql':
|
|
||||||
conn = self._create_mysql_connection()
|
|
||||||
elif self.config['type'] == 'postgresql':
|
|
||||||
conn = self._create_pg_connection()
|
|
||||||
else:
|
|
||||||
raise ValueError("不支持的数据库类型")
|
|
||||||
|
|
||||||
self._connection_pool[thread_id] = conn
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def _create_sqlite_connection(self) -> sqlite3.Connection:
|
|
||||||
"""创建SQLite连接(兼容多平台路径)"""
|
|
||||||
db_path = self.config['database']
|
|
||||||
|
|
||||||
# Windows路径处理
|
|
||||||
if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')):
|
|
||||||
db_path = os.path.abspath(db_path)
|
|
||||||
|
|
||||||
conn = sqlite3.connect(db_path, timeout=15)
|
|
||||||
conn.execute("PRAGMA journal_mode=WAL") # 写前日志提升并发
|
|
||||||
conn.execute("PRAGMA synchronous=NORMAL")
|
|
||||||
conn.row_factory = sqlite3.Row # 支持字典式访问
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def _create_mysql_connection(self):
|
|
||||||
"""创建MySQL连接(需安装PyMySQL)"""
|
|
||||||
import pymysql
|
|
||||||
return pymysql.connect(
|
|
||||||
host=self.config.get('host', 'localhost'),
|
|
||||||
port=self.config.get('port', 3306),
|
|
||||||
user=self.config.get('user', 'root'),
|
|
||||||
password=self.config.get('password', ''),
|
|
||||||
database=self.config['database'],
|
|
||||||
charset='utf8mb4',
|
|
||||||
cursorclass=pymysql.cursors.DictCursor
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_pg_connection(self):
|
|
||||||
"""创建PostgreSQL连接(需安装psycopg2)"""
|
|
||||||
import psycopg2
|
|
||||||
return psycopg2.connect(
|
|
||||||
host=self.config.get('host', 'localhost'),
|
|
||||||
port=self.config.get('port', 5432),
|
|
||||||
user=self.config.get('user', 'postgres'),
|
|
||||||
password=self.config.get('password', ''),
|
|
||||||
dbname=self.config['database']
|
|
||||||
)
|
|
||||||
|
|
||||||
def execute(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
params: Optional[QueryParams] = None,
|
|
||||||
return_lastrowid: bool = False
|
|
||||||
) -> Union[int, None]:
|
|
||||||
conn = self.get_connection()
|
|
||||||
try:
|
|
||||||
with conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
if params is not None:
|
|
||||||
cursor.execute(query, params)
|
|
||||||
else:
|
|
||||||
cursor.execute(query) # ✅ 无参数时不要传 params
|
|
||||||
return cursor.lastrowid if return_lastrowid else None
|
|
||||||
finally:
|
|
||||||
if not return_lastrowid:
|
|
||||||
self._release_connection()
|
|
||||||
|
|
||||||
def query(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
params: Optional[QueryParams] = None,
|
|
||||||
fetchall: bool = True
|
|
||||||
) -> Union[List[Dict], Dict]:
|
|
||||||
conn = self.get_connection()
|
|
||||||
try:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
if params is not None:
|
|
||||||
cursor.execute(query, params)
|
|
||||||
else:
|
|
||||||
cursor.execute(query) # ✅ 无参数时不要传 None
|
|
||||||
|
|
||||||
if self.config['type'] == 'sqlite':
|
|
||||||
result = cursor.fetchall()
|
|
||||||
if not fetchall and result:
|
|
||||||
return dict(result[0])
|
|
||||||
return [dict(row) for row in result]
|
|
||||||
else:
|
|
||||||
return cursor.fetchall() if fetchall else cursor.fetchone()
|
|
||||||
finally:
|
|
||||||
self._release_connection()
|
|
||||||
|
|
||||||
def _release_connection(self):
|
|
||||||
"""释放当前线程的连接(SQLite除外)"""
|
|
||||||
if self.config['type'] != 'sqlite':
|
|
||||||
thread_id = threading.get_ident()
|
|
||||||
with self._lock:
|
|
||||||
if thread_id in self._connection_pool:
|
|
||||||
self._connection_pool[thread_id].close()
|
|
||||||
del self._connection_pool[thread_id]
|
|
||||||
|
|
||||||
def encrypt_data(self, plaintext: str) -> str:
|
|
||||||
"""加密敏感数据"""
|
|
||||||
return self._fernet.encrypt(plaintext.encode()).decode()
|
|
||||||
|
|
||||||
def decrypt_data(self, ciphertext: str) -> str:
|
|
||||||
"""解密数据"""
|
|
||||||
return self._fernet.decrypt(ciphertext.encode()).decode()
|
|
||||||
|
|
||||||
def close_all(self):
|
|
||||||
"""关闭所有数据库连接"""
|
|
||||||
with self._lock:
|
|
||||||
for conn in self._connection_pool.values():
|
|
||||||
try:
|
|
||||||
conn.close()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
self._connection_pool.clear()
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.close_all()
|
|
||||||
|
|
||||||
|
|
||||||
# 全局SQLite实例(默认配置)
|
|
||||||
def get_default_db() -> DatabaseManager:
|
|
||||||
"""获取默认SQLite数据库(跨平台路径处理)"""
|
|
||||||
system = platform.system().lower()
|
|
||||||
if system == 'windows':
|
|
||||||
db_path = os.path.join(os.getenv('APPDATA'), 'app_name/data.db')
|
|
||||||
elif system == 'darwin':
|
|
||||||
db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db')
|
|
||||||
else:
|
|
||||||
db_path = '/var/lib/app_name/data.db' if os.access('/var/lib', os.W_OK) \
|
|
||||||
else os.path.expanduser('~/.local/share/app_name/data.db')
|
|
||||||
|
|
||||||
return DatabaseManager({
|
|
||||||
'type': 'sqlite',
|
|
||||||
'database': db_path
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
# 测试代码
|
|
||||||
if __name__ == "__main__":
|
|
||||||
with get_default_db() as db:
|
|
||||||
# 创建测试表
|
|
||||||
db.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS test_table (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
name TEXT,
|
|
||||||
secret TEXT
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
# 插入加密数据
|
|
||||||
secret = db.encrypt_data("新敏感信息")
|
|
||||||
db.execute(
|
|
||||||
"INSERT INTO test_table (name, secret) VALUES (?, ?)",
|
|
||||||
("测试记录", secret)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 查询并解密
|
|
||||||
row = db.query("SELECT * FROM test_table", fetchall=False)
|
|
||||||
print(f"解密数据: {db.decrypt_data(row['name'])}")
|
|
||||||
@@ -0,0 +1,662 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import platform
|
||||||
|
import pandas as pd
|
||||||
|
import pymysql
|
||||||
|
from pymysql import cursors
|
||||||
|
from pymysql.err import MySQLError
|
||||||
|
from dbutils.pooled_db import PooledDB
|
||||||
|
from typing import Union, List, Dict, Any, Optional, Tuple
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 导入日志系统
|
||||||
|
from utils.logger import log
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLAgent:
|
||||||
|
"""
|
||||||
|
全平台兼容的MySQL数据库操作类
|
||||||
|
支持Windows/macOS/Linux系统
|
||||||
|
配置参数从外部传入
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if not cls._instance:
|
||||||
|
with cls._lock:
|
||||||
|
if not cls._instance:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
"""
|
||||||
|
初始化MySQL数据库连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): 数据库配置字典,包含以下键:
|
||||||
|
- host: 数据库主机
|
||||||
|
- port: 端口
|
||||||
|
- user: 用户名
|
||||||
|
- password: 密码
|
||||||
|
- database: 数据库名
|
||||||
|
- [可选] charset: 字符集(默认utf8mb4)
|
||||||
|
- [可选] max_connections: 最大连接数(默认5)
|
||||||
|
- [可选] connect_timeout: 连接超时(秒)
|
||||||
|
- [可选] read_timeout: 读取超时(秒)
|
||||||
|
- [可选] write_timeout: 写入超时(秒)
|
||||||
|
- [可选] ssl: SSL配置
|
||||||
|
"""
|
||||||
|
if hasattr(self, '_pool') and self._pool:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
required_keys = ['host', 'port', 'user', 'password', 'database']
|
||||||
|
if not all(key in config for key in required_keys):
|
||||||
|
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
|
||||||
|
|
||||||
|
self.config = {
|
||||||
|
'host': config['host'],
|
||||||
|
'port': config['port'],
|
||||||
|
'user': config['user'],
|
||||||
|
'password': config['password'],
|
||||||
|
'database': config['database'],
|
||||||
|
'charset': config.get('charset', 'utf8mb4'),
|
||||||
|
'cursorclass': cursors.DictCursor,
|
||||||
|
'autocommit': True,
|
||||||
|
'connect_timeout': config.get('connect_timeout', 10),
|
||||||
|
'read_timeout': config.get('read_timeout', 30),
|
||||||
|
'write_timeout': config.get('write_timeout', 30),
|
||||||
|
'ssl': config.get('ssl')
|
||||||
|
}
|
||||||
|
|
||||||
|
# 初始化log
|
||||||
|
current_platform = platform.system()
|
||||||
|
self.log = log.bind(module=f"MySQLAgent({current_platform})")
|
||||||
|
|
||||||
|
# 创建连接池
|
||||||
|
self.pool_size = config.get('max_connections', 5)
|
||||||
|
self._pool = self._create_pool()
|
||||||
|
|
||||||
|
def _create_pool(self) -> PooledDB:
|
||||||
|
"""创建连接池"""
|
||||||
|
try:
|
||||||
|
# 使用包装函数确保线程安全
|
||||||
|
def connect():
|
||||||
|
conn = pymysql.connect(**self.config)
|
||||||
|
conn.threadsafety = 1 # 显式设置线程安全级别
|
||||||
|
return conn
|
||||||
|
|
||||||
|
pool = PooledDB(
|
||||||
|
creator=connect,
|
||||||
|
mincached=1,
|
||||||
|
maxcached=3,
|
||||||
|
maxconnections=self.pool_size,
|
||||||
|
blocking=True,
|
||||||
|
ping=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log.info("Connection pool created")
|
||||||
|
return pool
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.critical("Failed to create connection pool",
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_connection(self) -> pymysql.connections.Connection:
|
||||||
|
"""
|
||||||
|
获取数据库连接
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pymysql.connections.Connection: 数据库连接对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MySQLError: 如果获取连接失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
conn = self._pool.connection()
|
||||||
|
|
||||||
|
# macOS需要特殊处理SSL
|
||||||
|
if platform.system() == 'Darwin' and self.config.get('ssl'):
|
||||||
|
conn.ping(reconnect=True)
|
||||||
|
|
||||||
|
self.log.trace("Database connection obtained")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
|
||||||
|
# Windows特定错误处理
|
||||||
|
if platform.system() == 'Windows' and "timed out" in error_msg:
|
||||||
|
self.log.warning("Windows connection timeout, retrying...")
|
||||||
|
return self._retry_connection()
|
||||||
|
|
||||||
|
self.log.error("Connection failed",
|
||||||
|
error=error_msg,
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
|
||||||
|
"""Windows平台连接重试机制"""
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
conn = self._pool.connection()
|
||||||
|
self.log.info(f"Connection established after {attempt + 1} attempts")
|
||||||
|
return conn
|
||||||
|
except Exception:
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
raise
|
||||||
|
import time
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||||
|
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
执行SQL查询并返回DataFrame
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql (str): SQL查询语句
|
||||||
|
params (Union[tuple, dict, None]): 查询参数
|
||||||
|
parse_dates (Union[List[str], bool]): 自动解析日期字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: 查询结果
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MySQLError: 如果查询失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.log.debug("Executing SQL query", sql=sql)
|
||||||
|
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
# Linux/macOS需要更长的查询超时
|
||||||
|
if platform.system() != 'Windows':
|
||||||
|
conn.cursor().execute("SET SESSION wait_timeout=600")
|
||||||
|
|
||||||
|
df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)
|
||||||
|
|
||||||
|
# Windows平台需要手动关闭游标
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
conn.cursor().close()
|
||||||
|
|
||||||
|
self.log.info("Query executed successfully", rows=len(df))
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("SQL query failed",
|
||||||
|
sql=sql,
|
||||||
|
params=params,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def insert_from_df(self, table_name: str, df: pd.DataFrame,
|
||||||
|
chunk_size: int = 1000, replace: bool = False) -> int:
|
||||||
|
"""
|
||||||
|
将DataFrame数据插入到数据库表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): 目标表名
|
||||||
|
df (pd.DataFrame): 要插入的数据
|
||||||
|
chunk_size (int): 分批插入大小
|
||||||
|
replace (bool): 是否替换现有数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 插入的总行数
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MySQLError: 如果插入失败
|
||||||
|
"""
|
||||||
|
if df.empty:
|
||||||
|
self.log.warning("Attempted to insert empty DataFrame", table=table_name)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
self.log.debug("Preparing to insert DataFrame",
|
||||||
|
table=table_name,
|
||||||
|
rows=len(df),
|
||||||
|
chunk_size=chunk_size)
|
||||||
|
|
||||||
|
try:
|
||||||
|
method = 'replace' if replace else 'append'
|
||||||
|
total_rows = 0
|
||||||
|
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
# 各平台不同的分批策略
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
chunk_size = min(chunk_size, 500) # Windows上减小批次
|
||||||
|
|
||||||
|
for i in range(0, len(df), chunk_size):
|
||||||
|
chunk = df.iloc[i:i + chunk_size]
|
||||||
|
|
||||||
|
# macOS需要特殊处理datetime
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
for col in chunk.select_dtypes(include=['datetime64']):
|
||||||
|
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
|
chunk.to_sql(
|
||||||
|
table_name,
|
||||||
|
conn,
|
||||||
|
if_exists=method,
|
||||||
|
index=False,
|
||||||
|
method='multi'
|
||||||
|
)
|
||||||
|
total_rows += len(chunk)
|
||||||
|
method = 'append' # 第一次之后都使用追加模式
|
||||||
|
self.log.trace(f"Inserted chunk {i // chunk_size + 1}",
|
||||||
|
rows=len(chunk),
|
||||||
|
total_inserted=total_rows)
|
||||||
|
|
||||||
|
self.log.info("Data inserted successfully",
|
||||||
|
table=table_name,
|
||||||
|
total_rows=total_rows)
|
||||||
|
return total_rows
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Data insertion failed",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def update_from_df(self, table_name: str, df: pd.DataFrame,
|
||||||
|
key_columns: Union[str, List[str]]) -> int:
|
||||||
|
"""
|
||||||
|
使用DataFrame数据更新数据库表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): 目标表名
|
||||||
|
df (pd.DataFrame): 包含更新数据
|
||||||
|
key_columns (Union[str, List[str]]): 用于匹配记录的关键列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 更新的总行数
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MySQLError: 如果更新失败
|
||||||
|
"""
|
||||||
|
if df.empty:
|
||||||
|
self.log.warning("Attempted to update with empty DataFrame", table=table_name)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
self.log.debug("Preparing to update table from DataFrame",
|
||||||
|
table=table_name,
|
||||||
|
key_columns=key_columns,
|
||||||
|
rows=len(df))
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(key_columns, str):
|
||||||
|
key_columns = [key_columns]
|
||||||
|
|
||||||
|
total_updated = 0
|
||||||
|
conn = self.begin_transaction()
|
||||||
|
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 获取表结构信息
|
||||||
|
table_info = self._get_table_info(table_name)
|
||||||
|
columns = [col for col in df.columns if col in table_info]
|
||||||
|
|
||||||
|
# 构建UPDATE语句模板
|
||||||
|
set_clause = ', '.join([f"{col}=%s" for col in columns if col not in key_columns])
|
||||||
|
where_clause = ' AND '.join([f"{col}=%s" for col in key_columns])
|
||||||
|
|
||||||
|
update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
|
||||||
|
self.log.trace("Generated update SQL", sql=update_sql)
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
update_data = []
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
# SET部分的值
|
||||||
|
set_values = [row[col] for col in columns if col not in key_columns]
|
||||||
|
# WHERE部分的值
|
||||||
|
key_values = [row[col] for col in key_columns]
|
||||||
|
update_data.append(tuple(set_values + key_values))
|
||||||
|
|
||||||
|
# 执行批量更新
|
||||||
|
cursor.executemany(update_sql, update_data)
|
||||||
|
total_updated = cursor.rowcount
|
||||||
|
|
||||||
|
self.commit_transaction(conn)
|
||||||
|
self.log.info("Data updated successfully",
|
||||||
|
table=table_name,
|
||||||
|
rows_updated=total_updated)
|
||||||
|
return total_updated
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.rollback_transaction(conn)
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Data update failed",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_table_info(self, table_name: str) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
获取表结构信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): 表名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, str]: 列名到类型的映射
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MySQLError: 如果查询失败
|
||||||
|
"""
|
||||||
|
sql = f"""
|
||||||
|
SELECT column_name, data_type
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_schema = %s AND table_name = %s
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = (self.config['database'], table_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(sql, params)
|
||||||
|
result = cursor.fetchall()
|
||||||
|
return {row['column_name']: row['data_type'] for row in result}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Failed to get table info",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
推断DataFrame各列的SQL类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df (pd.DataFrame): 输入数据框
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, str]: 列名到SQL类型的映射
|
||||||
|
"""
|
||||||
|
type_mapping = {
|
||||||
|
'int64': 'BIGINT',
|
||||||
|
'float64': 'DOUBLE',
|
||||||
|
'datetime64[ns]': 'DATETIME',
|
||||||
|
'object': 'TEXT',
|
||||||
|
'bool': 'TINYINT(1)',
|
||||||
|
'category': 'VARCHAR(255)'
|
||||||
|
}
|
||||||
|
|
||||||
|
sql_types = {}
|
||||||
|
for col, dtype in df.dtypes.items():
|
||||||
|
dtype_str = str(dtype)
|
||||||
|
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
|
||||||
|
|
||||||
|
self.log.debug("Mapped DataFrame types to SQL types",
|
||||||
|
mappings=sql_types)
|
||||||
|
return sql_types
|
||||||
|
|
||||||
|
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
|
||||||
|
primary_key: Union[str, List[str], None] = None) -> bool:
|
||||||
|
"""
|
||||||
|
根据DataFrame结构创建表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): 表名
|
||||||
|
df (pd.DataFrame): 参考数据框
|
||||||
|
primary_key (Union[str, List[str], None]): 主键列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否创建成功
|
||||||
|
"""
|
||||||
|
if self.table_exists(table_name):
|
||||||
|
self.log.warning("Table already exists", table=table_name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.log.debug("Creating new table from DataFrame schema",
|
||||||
|
table=table_name,
|
||||||
|
columns=list(df.columns))
|
||||||
|
|
||||||
|
try:
|
||||||
|
sql_types = self.df_to_sql_type(df)
|
||||||
|
columns_sql = []
|
||||||
|
|
||||||
|
for col, sql_type in sql_types.items():
|
||||||
|
col_def = f"{col} {sql_type}"
|
||||||
|
columns_sql.append(col_def)
|
||||||
|
|
||||||
|
if primary_key:
|
||||||
|
if isinstance(primary_key, str):
|
||||||
|
primary_key = [primary_key]
|
||||||
|
pk_columns = [col for col in primary_key if col in sql_types]
|
||||||
|
if pk_columns:
|
||||||
|
columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
|
||||||
|
self.log.trace("Set primary key",
|
||||||
|
table=table_name,
|
||||||
|
primary_key=pk_columns)
|
||||||
|
|
||||||
|
create_sql = f"CREATE TABLE {table_name} (\n {',\n '.join(columns_sql)}\n)"
|
||||||
|
|
||||||
|
self.execute_sql(create_sql)
|
||||||
|
self.log.info("Table created successfully", table=table_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Failed to create table",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||||
|
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
执行SQL语句
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql (str): SQL语句
|
||||||
|
params (Union[tuple, dict, None]): 参数
|
||||||
|
fetch (bool): 是否获取结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[int, List[Dict[str, Any]]]:
|
||||||
|
- 如果是INSERT/UPDATE/DELETE,返回影响的行数
|
||||||
|
- 如果是SELECT且fetch=True,返回结果列表
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = self.get_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Linux/macOS需要更长的执行时间
|
||||||
|
if platform.system() != 'Windows':
|
||||||
|
cursor.execute("SET SESSION max_execution_time=600000")
|
||||||
|
|
||||||
|
cursor.execute(sql, params)
|
||||||
|
|
||||||
|
if fetch:
|
||||||
|
result = cursor.fetchall()
|
||||||
|
self.log.debug("Query executed", rows=len(result))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
affected_rows = cursor.rowcount
|
||||||
|
self.log.debug("Update executed", affected_rows=affected_rows)
|
||||||
|
return affected_rows
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("SQL execution failed",
|
||||||
|
sql=sql,
|
||||||
|
params=params,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def begin_transaction(self) -> pymysql.connections.Connection:
|
||||||
|
"""开始事务"""
|
||||||
|
try:
|
||||||
|
conn = self.get_connection()
|
||||||
|
conn.autocommit(False)
|
||||||
|
|
||||||
|
# macOS需要特殊处理事务隔离级别
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
|
||||||
|
|
||||||
|
self.log.debug("Transaction started")
|
||||||
|
return conn
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Begin transaction failed", error=str(e), exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
|
||||||
|
"""提交事务"""
|
||||||
|
try:
|
||||||
|
conn.commit()
|
||||||
|
self.log.debug("Transaction committed")
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Commit failed", error=str(e))
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
|
||||||
|
"""回滚事务"""
|
||||||
|
try:
|
||||||
|
conn.rollback()
|
||||||
|
self.log.warning("Transaction rolled back")
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Rollback failed", error=str(e))
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def table_exists(self, table_name: str) -> bool:
|
||||||
|
"""检查表是否存在"""
|
||||||
|
sql = """
|
||||||
|
SELECT COUNT(*) as count
|
||||||
|
FROM `information_schema`.`tables`
|
||||||
|
WHERE `table_schema` = %s AND `table_name` = %s
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = (self.config['database'], table_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.execute_sql(sql, params, fetch=True)
|
||||||
|
exists = result[0]['count'] > 0
|
||||||
|
self.log.debug("Checked table existence",
|
||||||
|
table=table_name,
|
||||||
|
exists=exists)
|
||||||
|
return exists
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def drop_table(self, table_name: str) -> bool:
|
||||||
|
"""删除表"""
|
||||||
|
if not self.table_exists(table_name):
|
||||||
|
self.log.warning("Table does not exist", table=table_name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.execute_sql(f"DROP TABLE {table_name}")
|
||||||
|
self.log.info("Table dropped successfully", table=table_name)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Failed to drop table",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_pool_status(self) -> Dict[str, int]:
|
||||||
|
"""获取连接池状态"""
|
||||||
|
return {
|
||||||
|
'max': self._pool._maxconnections,
|
||||||
|
'active': self._pool._connections,
|
||||||
|
'idle': len(self._pool._idle_cache),
|
||||||
|
'shared': len(self._pool._shared_cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate_connection(self) -> bool:
|
||||||
|
"""验证连接是否有效"""
|
||||||
|
try:
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute("SELECT 1")
|
||||||
|
return cursor.fetchone()[0] == 1
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""析构函数"""
|
||||||
|
if hasattr(self, '_pool'):
|
||||||
|
try:
|
||||||
|
self._pool.close()
|
||||||
|
self.log.info("Connection pool closed")
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Failed to close pool", error=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# 平台特定的默认配置
|
||||||
|
def get_default_config():
|
||||||
|
"""获取各平台默认配置"""
|
||||||
|
current_platform = platform.system()
|
||||||
|
|
||||||
|
base_config = {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 3306,
|
||||||
|
'user': 'root',
|
||||||
|
'password': '123123',
|
||||||
|
'database': 'intelligence_system',
|
||||||
|
'max_connections': 5
|
||||||
|
}
|
||||||
|
|
||||||
|
if current_platform == 'Windows':
|
||||||
|
return {
|
||||||
|
**base_config,
|
||||||
|
'connect_timeout': 10,
|
||||||
|
'read_timeout': 30,
|
||||||
|
'write_timeout': 30
|
||||||
|
}
|
||||||
|
elif current_platform == 'Darwin':
|
||||||
|
return {
|
||||||
|
**base_config,
|
||||||
|
'connect_timeout': 15,
|
||||||
|
'read_timeout': 60,
|
||||||
|
'write_timeout': 60,
|
||||||
|
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||||
|
}
|
||||||
|
else: # Linux和其他平台
|
||||||
|
return {
|
||||||
|
**base_config,
|
||||||
|
'connect_timeout': 15,
|
||||||
|
'read_timeout': 60,
|
||||||
|
'write_timeout': 60
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 使用示例
|
||||||
|
db = MySQLAgent(get_default_config())
|
||||||
|
|
||||||
|
# 测试连接
|
||||||
|
if db.validate_connection():
|
||||||
|
print("Database connection successful")
|
||||||
|
|
||||||
|
# 获取数据库版本
|
||||||
|
version = db.query_to_df("SELECT VERSION() as version")
|
||||||
|
print(f"Database version: {version['version'].iloc[0]}")
|
||||||
|
|
||||||
|
# 查看连接池状态
|
||||||
|
print("Connection pool status:", db.get_pool_status())
|
||||||
|
else:
|
||||||
|
print("Failed to connect to database")
|
||||||
@@ -1 +0,0 @@
|
|||||||
{"a":{"0":1},"b":{"0":2}}
|
|
||||||
+291
@@ -0,0 +1,291 @@
|
|||||||
|
import unittest
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import pymysql
|
||||||
|
from storage.mysql_agent import MySQLAgent
|
||||||
|
import platform
|
||||||
|
|
||||||
|
class TestMySQLAgent(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
"""初始化测试环境和测试表"""
|
||||||
|
# 创建唯一的测试数据库名
|
||||||
|
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
cls.base_config = {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 3306,
|
||||||
|
'user': 'root',
|
||||||
|
'password': '123123',
|
||||||
|
'max_connections': 10
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建测试数据库
|
||||||
|
cls._create_test_database()
|
||||||
|
|
||||||
|
# 初始化数据库连接
|
||||||
|
cls.db = MySQLAgent({
|
||||||
|
**cls.base_config,
|
||||||
|
'database': cls.test_db_name
|
||||||
|
})
|
||||||
|
|
||||||
|
# 创建测试表
|
||||||
|
test_data = pd.DataFrame({
|
||||||
|
'id': [1, 2, 3],
|
||||||
|
'name': ['Test1', 'Test2', 'Test3'],
|
||||||
|
'value': [10.5, 20.3, 30.8],
|
||||||
|
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
|
||||||
|
})
|
||||||
|
|
||||||
|
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
|
||||||
|
cls.db.insert_from_df(cls.test_table, test_data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_test_database(cls):
|
||||||
|
"""创建测试数据库"""
|
||||||
|
# 使用临时连接创建数据库
|
||||||
|
temp_conn = pymysql.connect(
|
||||||
|
host=cls.base_config['host'],
|
||||||
|
port=cls.base_config['port'],
|
||||||
|
user=cls.base_config['user'],
|
||||||
|
password=cls.base_config['password'],
|
||||||
|
charset='utf8mb4'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with temp_conn.cursor() as cursor:
|
||||||
|
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
||||||
|
cursor.execute(f"USE {cls.test_db_name}")
|
||||||
|
cursor.execute("SET GLOBAL max_connections = 100")
|
||||||
|
temp_conn.commit()
|
||||||
|
finally:
|
||||||
|
temp_conn.close()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
"""清理测试数据库"""
|
||||||
|
if hasattr(cls, 'db') and cls.db:
|
||||||
|
# 删除测试表
|
||||||
|
if cls.db.table_exists(cls.test_table):
|
||||||
|
cls.db.drop_table(cls.test_table)
|
||||||
|
|
||||||
|
# 删除测试数据库
|
||||||
|
temp_conn = pymysql.connect(
|
||||||
|
host=cls.base_config['host'],
|
||||||
|
port=cls.base_config['port'],
|
||||||
|
user=cls.base_config['user'],
|
||||||
|
password=cls.base_config['password'],
|
||||||
|
charset='utf8mb4'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with temp_conn.cursor() as cursor:
|
||||||
|
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
||||||
|
temp_conn.commit()
|
||||||
|
finally:
|
||||||
|
temp_conn.close()
|
||||||
|
|
||||||
|
def test_01_connection(self):
|
||||||
|
"""测试数据库连接"""
|
||||||
|
version = self.db.query_to_df("SELECT VERSION() as version")
|
||||||
|
self.assertIsNotNone(version)
|
||||||
|
print(f"\nDatabase version: {version['version'].iloc[0]}")
|
||||||
|
print(f"Running on: {platform.system()} {platform.release()}")
|
||||||
|
|
||||||
|
def test_02_query_to_df(self):
|
||||||
|
"""测试查询返回DataFrame"""
|
||||||
|
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,))
|
||||||
|
self.assertEqual(len(df), 2)
|
||||||
|
self.assertIsInstance(df, pd.DataFrame)
|
||||||
|
print("\nQuery result sample:")
|
||||||
|
print(df.head())
|
||||||
|
|
||||||
|
def test_03_insert_from_df(self):
|
||||||
|
"""测试DataFrame插入"""
|
||||||
|
new_data = pd.DataFrame({
|
||||||
|
'id': [4, 5],
|
||||||
|
'name': ['Test4', 'Test5'],
|
||||||
|
'value': [40.1, 50.2],
|
||||||
|
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
|
||||||
|
})
|
||||||
|
|
||||||
|
rows = self.db.insert_from_df(self.test_table, new_data)
|
||||||
|
self.assertEqual(rows, 2)
|
||||||
|
|
||||||
|
# 验证数据
|
||||||
|
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4")
|
||||||
|
self.assertEqual(len(df), 2)
|
||||||
|
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5'])
|
||||||
|
|
||||||
|
def test_04_update_from_df(self):
|
||||||
|
"""测试DataFrame更新"""
|
||||||
|
update_data = pd.DataFrame({
|
||||||
|
'id': [1, 2],
|
||||||
|
'name': ['Updated1', 'Updated2']
|
||||||
|
})
|
||||||
|
|
||||||
|
rows = self.db.update_from_df(self.test_table, update_data, 'id')
|
||||||
|
self.assertGreaterEqual(rows, 2)
|
||||||
|
|
||||||
|
# 验证更新
|
||||||
|
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)")
|
||||||
|
self.assertIn('Updated1', df['name'].values)
|
||||||
|
self.assertIn('Updated2', df['name'].values)
|
||||||
|
|
||||||
|
def test_05_transaction(self):
|
||||||
|
"""测试事务处理"""
|
||||||
|
conn = self.db.begin_transaction()
|
||||||
|
try:
|
||||||
|
# 执行多个操作
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
|
||||||
|
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
|
||||||
|
|
||||||
|
# 验证事务内修改
|
||||||
|
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
|
||||||
|
self.assertEqual(cursor.fetchone()['value'], 99.9)
|
||||||
|
|
||||||
|
self.db.commit_transaction(conn)
|
||||||
|
except Exception:
|
||||||
|
self.db.rollback_transaction(conn)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 验证提交后的修改
|
||||||
|
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)")
|
||||||
|
self.assertIn(99.9, df['value'].values)
|
||||||
|
self.assertIn(88.8, df['value'].values)
|
||||||
|
|
||||||
|
def test_06_large_data(self):
|
||||||
|
"""测试大数据量操作"""
|
||||||
|
# 生成测试数据
|
||||||
|
large_data = pd.DataFrame({
|
||||||
|
'id': range(1000, 2000),
|
||||||
|
'name': [f"Item_{i}" for i in range(1000, 2000)],
|
||||||
|
'value': [i * 0.1 for i in range(1000, 2000)],
|
||||||
|
'created_at': pd.date_range('2023-01-01', periods=1000)
|
||||||
|
})
|
||||||
|
|
||||||
|
# Windows平台使用更小的批次
|
||||||
|
chunk_size = 100 if platform.system() == 'Windows' else 500
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size)
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
|
self.assertEqual(rows, 1000)
|
||||||
|
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})")
|
||||||
|
|
||||||
|
# 验证数据
|
||||||
|
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
|
||||||
|
self.assertEqual(df['cnt'].iloc[0], 1000)
|
||||||
|
|
||||||
|
def test_07_concurrent_access(self):
|
||||||
|
"""测试并发访问"""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
def worker(i):
|
||||||
|
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,))
|
||||||
|
return len(df)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with ThreadPoolExecutor(max_workers=20) as executor:
|
||||||
|
results = list(executor.map(worker, range(100)))
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
self.assertEqual(sum(results), 100)
|
||||||
|
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlatformSpecific(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
"""创建临时测试数据库"""
|
||||||
|
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
cls.base_config = {
|
||||||
|
'host': 'localhost',
|
||||||
|
'port': 3306,
|
||||||
|
'user': 'root',
|
||||||
|
'password': '123123',
|
||||||
|
'max_connections': 10
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建数据库
|
||||||
|
temp_conn = pymysql.connect(
|
||||||
|
host=cls.base_config['host'],
|
||||||
|
port=cls.base_config['port'],
|
||||||
|
user=cls.base_config['user'],
|
||||||
|
password=cls.base_config['password'],
|
||||||
|
charset='utf8mb4'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with temp_conn.cursor() as cursor:
|
||||||
|
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
||||||
|
temp_conn.commit()
|
||||||
|
finally:
|
||||||
|
temp_conn.close()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
"""删除临时测试数据库"""
|
||||||
|
temp_conn = pymysql.connect(
|
||||||
|
host=cls.base_config['host'],
|
||||||
|
port=cls.base_config['port'],
|
||||||
|
user=cls.base_config['user'],
|
||||||
|
password=cls.base_config['password'],
|
||||||
|
charset='utf8mb4'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with temp_conn.cursor() as cursor:
|
||||||
|
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
||||||
|
temp_conn.commit()
|
||||||
|
finally:
|
||||||
|
temp_conn.close()
|
||||||
|
|
||||||
|
def test_windows_timeout(self):
|
||||||
|
"""测试Windows平台超时处理"""
|
||||||
|
if platform.system() != 'Windows':
|
||||||
|
self.skipTest("Only runs on Windows")
|
||||||
|
|
||||||
|
config = {
|
||||||
|
**self.base_config,
|
||||||
|
'database': self.test_db_name,
|
||||||
|
'connect_timeout': 1,
|
||||||
|
'read_timeout': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
db = MySQLAgent(config)
|
||||||
|
|
||||||
|
# 测试短超时查询
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
db.query_to_df("SELECT SLEEP(2)")
|
||||||
|
self.fail("Should have timed out")
|
||||||
|
except Exception as e:
|
||||||
|
self.assertIn("timed out", str(e))
|
||||||
|
print(f"\nWindows timeout test: {str(e)}")
|
||||||
|
|
||||||
|
def test_macos_ssl(self):
|
||||||
|
"""测试macOS SSL连接"""
|
||||||
|
if platform.system() != 'Darwin':
|
||||||
|
self.skipTest("Only runs on macOS")
|
||||||
|
|
||||||
|
config = {
|
||||||
|
**self.base_config,
|
||||||
|
'database': self.test_db_name,
|
||||||
|
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||||
|
}
|
||||||
|
|
||||||
|
db = MySQLAgent(config)
|
||||||
|
version = db.query_to_df("SELECT VERSION() as version")
|
||||||
|
self.assertIsNotNone(version)
|
||||||
|
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
+13
-3
@@ -3,6 +3,8 @@ import pandas as pd
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from utils.file_handler import FileHandler
|
from utils.file_handler import FileHandler
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def temp_dir(tmp_path):
|
def temp_dir(tmp_path):
|
||||||
@@ -11,11 +13,13 @@ def temp_dir(tmp_path):
|
|||||||
test_dir.mkdir()
|
test_dir.mkdir()
|
||||||
return test_dir
|
return test_dir
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def file_handler(temp_dir):
|
def file_handler(temp_dir):
|
||||||
"""创建FileHandler实例"""
|
"""创建FileHandler实例"""
|
||||||
return FileHandler(temp_dir)
|
return FileHandler(temp_dir)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_dataframe():
|
def sample_dataframe():
|
||||||
"""创建测试用DataFrame"""
|
"""创建测试用DataFrame"""
|
||||||
@@ -25,6 +29,7 @@ def sample_dataframe():
|
|||||||
'value': [10.5, 20.3, 30.1]
|
'value': [10.5, 20.3, 30.1]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_text_file(temp_dir):
|
def sample_text_file(temp_dir):
|
||||||
"""创建测试文本文件"""
|
"""创建测试文本文件"""
|
||||||
@@ -55,30 +60,33 @@ def test_read_write_csv(file_handler, temp_dir, sample_dataframe):
|
|||||||
assert df.shape == (3, 3)
|
assert df.shape == (3, 3)
|
||||||
assert list(df.columns) == ['id', 'name', 'value']
|
assert list(df.columns) == ['id', 'name', 'value']
|
||||||
|
|
||||||
|
|
||||||
def test_read_write_json(file_handler, temp_dir, sample_dataframe):
|
def test_read_write_json(file_handler, temp_dir, sample_dataframe):
|
||||||
"""测试JSON文件读写"""
|
"""测试JSON文件读写"""
|
||||||
test_file = temp_dir / "test.json"
|
test_file = temp_dir / "test.json"
|
||||||
|
|
||||||
# 测试写入
|
# 测试写入
|
||||||
write_result = file_handler.write_file(test_file, sample_dataframe)
|
write_result = file_handler.write_file(test_file, sample_dataframe)
|
||||||
assert write_result.iloc[0]['success'] == True
|
assert write_result.iloc[0]['success'] == True
|
||||||
|
|
||||||
# 测试读取
|
# 测试读取
|
||||||
df = file_handler.read_file(test_file)
|
df = file_handler.read_file(test_file)
|
||||||
assert df.shape == (3, 3)
|
assert df.shape == (3, 3)
|
||||||
|
|
||||||
|
|
||||||
def test_read_write_excel(file_handler, temp_dir, sample_dataframe):
|
def test_read_write_excel(file_handler, temp_dir, sample_dataframe):
|
||||||
"""测试Excel文件读写"""
|
"""测试Excel文件读写"""
|
||||||
test_file = temp_dir / "test.xlsx"
|
test_file = temp_dir / "test.xlsx"
|
||||||
|
|
||||||
# 测试写入
|
# 测试写入
|
||||||
write_result = file_handler.write_file(test_file, sample_dataframe)
|
write_result = file_handler.write_file(test_file, sample_dataframe)
|
||||||
assert write_result.iloc[0]['success'] == True
|
assert write_result.iloc[0]['success'] == True
|
||||||
|
|
||||||
# 测试读取
|
# 测试读取
|
||||||
df = file_handler.read_file(test_file)
|
df = file_handler.read_file(test_file)
|
||||||
assert df.shape == (3, 3)
|
assert df.shape == (3, 3)
|
||||||
|
|
||||||
|
|
||||||
def test_read_write_csv(file_handler, temp_dir, sample_dataframe):
|
def test_read_write_csv(file_handler, temp_dir, sample_dataframe):
|
||||||
"""测试CSV文件读写"""
|
"""测试CSV文件读写"""
|
||||||
test_file = temp_dir / "test.csv"
|
test_file = temp_dir / "test.csv"
|
||||||
@@ -119,6 +127,7 @@ def test_file_operations(file_handler, sample_text_file):
|
|||||||
assert delete_df.iloc[0]['deleted'] == True
|
assert delete_df.iloc[0]['deleted'] == True
|
||||||
assert not os.path.exists(sample_text_file)
|
assert not os.path.exists(sample_text_file)
|
||||||
|
|
||||||
|
|
||||||
def test_directory_operations(file_handler, temp_dir):
|
def test_directory_operations(file_handler, temp_dir):
|
||||||
"""测试目录操作"""
|
"""测试目录操作"""
|
||||||
test_dir = temp_dir / "subdir"
|
test_dir = temp_dir / "subdir"
|
||||||
@@ -160,6 +169,7 @@ def test_zip_operations(file_handler, temp_dir, sample_dataframe):
|
|||||||
assert os.path.exists(extract_dir / "file1.txt")
|
assert os.path.exists(extract_dir / "file1.txt")
|
||||||
assert os.path.exists(extract_dir / "file2.csv")
|
assert os.path.exists(extract_dir / "file2.csv")
|
||||||
|
|
||||||
|
|
||||||
def test_zip_directory(file_handler, temp_dir):
|
def test_zip_directory(file_handler, temp_dir):
|
||||||
"""测试目录压缩"""
|
"""测试目录压缩"""
|
||||||
# 创建测试目录结构
|
# 创建测试目录结构
|
||||||
@@ -174,4 +184,4 @@ def test_zip_directory(file_handler, temp_dir):
|
|||||||
zip_path = temp_dir / "dir.zip"
|
zip_path = temp_dir / "dir.zip"
|
||||||
zip_result = file_handler.zip_dir(test_dir, zip_path)
|
zip_result = file_handler.zip_dir(test_dir, zip_path)
|
||||||
assert zip_result.iloc[0]['zipped'] == True
|
assert zip_result.iloc[0]['zipped'] == True
|
||||||
assert zip_result.iloc[0]['file_count'] == 2
|
assert zip_result.iloc[0]['file_count'] == 2
|
||||||
|
|||||||
@@ -433,7 +433,9 @@ class FileHandler:
|
|||||||
# ---------------------------- 测试用例 ----------------------------
|
# ---------------------------- 测试用例 ----------------------------
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 初始化处理器(自动处理跨平台路径)
|
# 初始化处理器(自动处理跨平台路径)
|
||||||
handler = FileHandler("test_data")
|
project_root = next(p for p in Path(__file__).resolve().parents if
|
||||||
|
(p / '.git').exists() or (p / 'pyproject.toml').exists() or (p / 'requirements.txt').exists())
|
||||||
|
handler = FileHandler(project_root / "test")
|
||||||
|
|
||||||
# 测试路径标准化
|
# 测试路径标准化
|
||||||
test_paths = [
|
test_paths = [
|
||||||
|
|||||||
+3
-1
@@ -6,6 +6,7 @@ import platform
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
|
|
||||||
class CrossPlatformLog:
|
class CrossPlatformLog:
|
||||||
"""跨平台日志系统(支持Linux/Windows/Mac)"""
|
"""跨平台日志系统(支持Linux/Windows/Mac)"""
|
||||||
|
|
||||||
@@ -94,5 +95,6 @@ class CrossPlatformLog:
|
|||||||
"""获取模块专属日志器"""
|
"""获取模块专属日志器"""
|
||||||
return logger.bind(module=module_name or "__main__")
|
return logger.bind(module=module_name or "__main__")
|
||||||
|
|
||||||
|
|
||||||
# 初始化全局日志器
|
# 初始化全局日志器
|
||||||
log = CrossPlatformLog().get_logger()
|
log = CrossPlatformLog().get_logger()
|
||||||
|
|||||||
@@ -1,233 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
网络工具模块
|
|
||||||
功能:
|
|
||||||
1. 自动重试的HTTP请求
|
|
||||||
2. 多平台代理配置支持
|
|
||||||
3. DNS缓存优化
|
|
||||||
4. 连接超时与SSL验证
|
|
||||||
5. 用户代理轮换
|
|
||||||
"""
|
|
||||||
|
|
||||||
import socket
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import platform
|
|
||||||
from typing import Optional, Dict, Any, Union
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from functools import lru_cache
|
|
||||||
import requests
|
|
||||||
from requests.adapters import HTTPAdapter
|
|
||||||
from urllib3.util.retry import Retry
|
|
||||||
from urllib3.connection import HTTPConnection
|
|
||||||
import os
|
|
||||||
|
|
||||||
# 类型别名
|
|
||||||
TimeoutType = Union[float, tuple]
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkUtils:
|
|
||||||
"""跨平台网络操作工具类"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.system = platform.system().lower()
|
|
||||||
self._dns_cache = {}
|
|
||||||
self._setup_platform_specific()
|
|
||||||
|
|
||||||
def _setup_platform_specific(self):
|
|
||||||
"""平台相关初始化"""
|
|
||||||
if self.system == 'windows':
|
|
||||||
# Windows默认关闭TCP_NODELAY
|
|
||||||
HTTPConnection.default_socket_options = (
|
|
||||||
HTTPConnection.default_socket_options + [
|
|
||||||
(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif self.system == 'linux':
|
|
||||||
# Linux启用TCP快速打开
|
|
||||||
HTTPConnection.default_socket_options = (
|
|
||||||
HTTPConnection.default_socket_options + [
|
|
||||||
(socket.IPPROTO_TCP, socket.TCP_FASTOPEN, 5)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@lru_cache(maxsize=512)
|
|
||||||
def _resolve_hostname(hostname: str) -> str:
|
|
||||||
"""DNS缓存解析(跨线程安全)"""
|
|
||||||
try:
|
|
||||||
return socket.gethostbyname(hostname)
|
|
||||||
except socket.gaierror:
|
|
||||||
return hostname # 失败时返回原始域名
|
|
||||||
|
|
||||||
def get_session(
|
|
||||||
self,
|
|
||||||
retries: int = 3,
|
|
||||||
backoff_factor: float = 0.5,
|
|
||||||
timeout: TimeoutType = (3.05, 30),
|
|
||||||
proxy: Optional[str] = None,
|
|
||||||
verify_ssl: bool = True
|
|
||||||
) -> requests.Session:
|
|
||||||
"""
|
|
||||||
获取配置好的请求会话
|
|
||||||
|
|
||||||
参数:
|
|
||||||
retries: 重试次数
|
|
||||||
backoff_factor: 重试间隔系数
|
|
||||||
timeout: (连接超时, 读取超时)秒数
|
|
||||||
proxy: 代理地址(如 'http://user:pass@proxy:port')
|
|
||||||
verify_ssl: 是否验证SSL证书
|
|
||||||
"""
|
|
||||||
session = requests.Session()
|
|
||||||
|
|
||||||
# 重试策略
|
|
||||||
retry = Retry(
|
|
||||||
total=retries,
|
|
||||||
backoff_factor=backoff_factor,
|
|
||||||
status_forcelist=[500, 502, 503, 504],
|
|
||||||
allowed_methods=frozenset(['GET', 'POST', 'PUT', 'DELETE'])
|
|
||||||
)
|
|
||||||
|
|
||||||
# 适配器配置
|
|
||||||
adapter = HTTPAdapter(
|
|
||||||
max_retries=retry,
|
|
||||||
pool_connections=20,
|
|
||||||
pool_maxsize=100
|
|
||||||
)
|
|
||||||
|
|
||||||
# 挂载适配器
|
|
||||||
session.mount('http://', adapter)
|
|
||||||
session.mount('https://', adapter)
|
|
||||||
|
|
||||||
# 代理配置
|
|
||||||
if proxy:
|
|
||||||
session.proxies = {
|
|
||||||
'http': proxy,
|
|
||||||
'https': proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
# 请求默认配置
|
|
||||||
session.request = self._wrap_request(
|
|
||||||
session.request,
|
|
||||||
timeout=timeout,
|
|
||||||
verify=verify_ssl
|
|
||||||
)
|
|
||||||
|
|
||||||
return session
|
|
||||||
|
|
||||||
def _wrap_request(self, original_request, **defaults):
|
|
||||||
"""包装请求方法添加默认参数"""
|
|
||||||
|
|
||||||
def wrapped(method, url, **kwargs):
|
|
||||||
# 处理DNS缓存
|
|
||||||
parsed = urlparse(url)
|
|
||||||
if parsed.hostname:
|
|
||||||
kwargs['hooks'] = kwargs.get('hooks', {})
|
|
||||||
kwargs['hooks']['pre_request'] = lambda r: setattr(
|
|
||||||
r, '_orig_host', r.url
|
|
||||||
)
|
|
||||||
url = url.replace(
|
|
||||||
parsed.hostname,
|
|
||||||
self._resolve_hostname(parsed.hostname),
|
|
||||||
1
|
|
||||||
)
|
|
||||||
|
|
||||||
# 合并默认参数
|
|
||||||
for k, v in defaults.items():
|
|
||||||
kwargs.setdefault(k, v)
|
|
||||||
|
|
||||||
return original_request(method, url, **kwargs)
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
def get_user_agent(self) -> str:
|
|
||||||
"""获取随机用户代理(兼容各平台)"""
|
|
||||||
agents = [
|
|
||||||
# Windows
|
|
||||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
|
||||||
# macOS
|
|
||||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 12_4) AppleWebKit/605.1.15",
|
|
||||||
# Linux
|
|
||||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
|
|
||||||
# Mobile
|
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 15_5 like Mac OS X) AppleWebKit/605.1.15"
|
|
||||||
]
|
|
||||||
return random.choice(agents)
|
|
||||||
|
|
||||||
def check_connection(
|
|
||||||
self,
|
|
||||||
url: str = "https://www.baidu.com",
|
|
||||||
timeout: float = 5.0
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
检查网络连接状态
|
|
||||||
|
|
||||||
参数:
|
|
||||||
url: 测试用的URL
|
|
||||||
timeout: 超时时间(秒)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session = self.get_session(retries=0, timeout=timeout)
|
|
||||||
session.head(url)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def download_file(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
save_path: str,
|
|
||||||
chunk_size: int = 8192,
|
|
||||||
progress_callback: Optional[callable] = None
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
下载大文件支持断点续传
|
|
||||||
|
|
||||||
参数:
|
|
||||||
url: 文件URL
|
|
||||||
save_path: 本地保存路径
|
|
||||||
chunk_size: 分块大小(字节)
|
|
||||||
progress_callback: 进度回调函数(bytes_downloaded, total_size)
|
|
||||||
"""
|
|
||||||
session = self.get_session()
|
|
||||||
headers = {}
|
|
||||||
|
|
||||||
# 检查本地文件部分下载
|
|
||||||
if os.path.exists(save_path):
|
|
||||||
downloaded = os.path.getsize(save_path)
|
|
||||||
headers['Range'] = f'bytes={downloaded}-'
|
|
||||||
|
|
||||||
try:
|
|
||||||
with session.get(url, headers=headers, stream=True) as r:
|
|
||||||
r.raise_for_status()
|
|
||||||
|
|
||||||
# 获取文件总大小
|
|
||||||
total_size = int(r.headers.get('content-length', 0)) + downloaded
|
|
||||||
|
|
||||||
# 追加模式写入
|
|
||||||
mode = 'ab' if headers.get('Range') else 'wb'
|
|
||||||
with open(save_path, mode) as f:
|
|
||||||
for chunk in r.iter_content(chunk_size=chunk_size):
|
|
||||||
if chunk: # 过滤keep-alive chunks
|
|
||||||
f.write(chunk)
|
|
||||||
downloaded += len(chunk)
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(downloaded, total_size)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"下载失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# 全局实例(线程安全)
|
|
||||||
network_utils = NetworkUtils()
|
|
||||||
|
|
||||||
|
|
||||||
# 快捷方法(兼容旧代码)
|
|
||||||
def get_session(*args, **kwargs):
|
|
||||||
return network_utils.get_session(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def check_connection(*args, **kwargs):
|
|
||||||
return network_utils.check_connection(*args, **kwargs)
|
|
||||||
Reference in New Issue
Block a user