数据库操作

This commit is contained in:
2025-08-06 16:24:17 +08:00
parent aa0b71a90b
commit c8d268647f
11 changed files with 1344 additions and 706 deletions
-56
View File
@@ -1,56 +0,0 @@
[loggers]
keys=root,data_collector,api_client,alert
[handlers]
keys=consoleHandler,fileHandler,errorFileHandler
[formatters]
keys=standardFormatter,detailedFormatter
[logger_root]
level=INFO
handlers=consoleHandler,fileHandler
[logger_data_collector]
level=DEBUG
handlers=fileHandler
qualname=data_collector
propagate=0
[logger_api_client]
level=INFO
handlers=fileHandler,errorFileHandler
qualname=api_client
propagate=0
[logger_alert]
level=WARNING
handlers=consoleHandler,errorFileHandler
qualname=alert
propagate=0
[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=standardFormatter
args=(sys.stdout,)
[handler_fileHandler]
class=logging.handlers.TimedRotatingFileHandler
level=DEBUG
formatter=detailedFormatter
args=('%(log_dir)s/application.log', 'midnight', 1, 30, 'utf-8')
[handler_errorFileHandler]
class=logging.handlers.TimedRotatingFileHandler
level=WARNING
formatter=detailedFormatter
args=('%(log_dir)s/error.log', 'midnight', 1, 90, 'utf-8')
[formatter_standardFormatter]
format=%(asctime)s [%(levelname)-5s] %(name)s - %(message)s
datefmt=%Y-%m-%d %H:%M:%S
[formatter_detailedFormatter]
format=%(asctime)s [%(levelname)-5s] %(name)s (%(filename)s:%(lineno)d) - %(message)s
datefmt=%Y-%m-%d %H:%M:%S
+371 -155
View File
@@ -1,193 +1,409 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
系统配置模块
功能:
1. 支持多平台路径适配
2. 环境变量与配置文件优先级管理
3. 敏感信息加密存储
4. 配置热更新检测
"""
import os
import json
import sys
import platform
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional
import dotenv
from cryptography.fernet import Fernet
class ConfigManager:
def __init__(self, app_name: str = "intelligence_system"):
"""
初始化配置管理器
# 导入您的日志系统
from utils.logger import log as logger
参数:
app_name: 应用名称(用于生成配置目录)
"""
self.system = platform.system().lower()
self.app_name = app_name
self._config = {}
self._secret_key = None
class MySQLAgent:
"""
全平台兼容的MySQL数据库操作类
支持Windows/macOS/Linux系统
"""
# 初始化配置路径
self.config_dir = self._get_config_dir()
os.makedirs(self.config_dir, exist_ok=True)
_instance = None
_lock = threading.Lock()
# 加载配置顺序
self._load_defaults()
self._load_env_file()
self._load_user_config()
# 各平台特定的配置
PLATFORM_CONFIG = {
'Windows': {
'socket_timeout': 30,
'connect_timeout': 10,
'ssl': None
},
'Darwin': { # macOS
'socket_timeout': 60,
'connect_timeout': 15,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
},
'Linux': {
'socket_timeout': 60,
'connect_timeout': 15,
'ssl': None
}
}
# 初始化加密模块
self._init_encryption()
def __new__(cls, *args, **kwargs):
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
def _get_config_dir(self) -> str:
"""获取适合当前平台的配置目录"""
if self.system == 'windows':
return os.path.join(os.environ['APPDATA'], self.app_name)
elif self.system == 'darwin': # macOS
return os.path.expanduser(f"~/Library/Application Support/{self.app_name}")
else: # Linux及其他Unix-like
return os.path.expanduser(f"~/.config/{self.app_name}")
def __init__(self, config: dict = None):
if hasattr(self, '_pool') and self._pool:
return
def _load_defaults(self):
"""加载默认配置"""
self._config = {
"system": {
"log_level": "INFO",
"max_threads": os.cpu_count() or 4
},
"api": {
"newsapi": {"endpoint": "https://newsapi.org/v2"},
"weibo": {"version": "2"}
},
"paths": {
"data_dir": os.path.join(self.config_dir, "data"),
"cache_dir": os.path.join(self.config_dir, "cache")
}
if not config:
from config.settings import DATABASE_CONFIG
config = DATABASE_CONFIG
# 获取当前平台配置
current_platform = platform.system()
platform_config = self.PLATFORM_CONFIG.get(current_platform, {})
# 基础配置
self.config = {
'host': config.get('host', 'localhost'),
'port': config.get('port', 3306),
'user': config.get('user', 'root'),
'password': config.get('password', ''),
'database': config.get('database', 'intelligence_system'),
'charset': config.get('charset', 'utf8mb4'),
'cursorclass': cursors.DictCursor,
'autocommit': True,
**platform_config # 合并平台特定配置
}
def _load_env_file(self):
"""加载.env环境变量文件"""
env_path = Path(self.config_dir) / ".env"
if env_path.exists():
dotenv.load_dotenv(env_path)
# 处理各平台路径差异
if current_platform == 'Windows':
self.config['ssl'] = None # Windows通常不需要SSL配置
# 环境变量覆盖配置
if os.getenv("LOG_LEVEL"):
self._config["system"]["log_level"] = os.getenv("LOG_LEVEL")
# macOS特殊处理
elif current_platform == 'Darwin':
if not os.path.exists(self.config['ssl']['ca']):
self.config['ssl'] = None
logger.warning("macOS SSL certificate not found, disabling SSL")
def _load_user_config(self):
"""加载用户自定义配置"""
config_file = Path(self.config_dir) / "config.json"
self.pool_size = config.get('max_connections', 5)
self._pool = self._create_pool()
self.logger = logger.bind(module=f"MySQLAgent({current_platform})")
def _create_pool(self) -> PooledDB:
"""创建跨平台兼容的连接池"""
try:
if config_file.exists():
with open(config_file, 'r', encoding='utf-8') as f:
user_config = json.load(f)
self._deep_update(self._config, user_config)
# 各平台连接池参数调整
pool_config = {
'creator': pymysql,
'maxconnections': self.pool_size,
'mincached': 1,
'maxcached': 3,
'blocking': True,
'ping': 1, # 定期检查连接有效性
**self.config
}
# Windows平台需要更短的超时时间
if platform.system() == 'Windows':
pool_config['ping'] = 0 # Windows上ping有时不稳定
pool = PooledDB(**pool_config)
self.logger.info(f"Connection pool created for {platform.system()}")
return pool
except Exception as e:
print(f"加载用户配置失败: {str(e)}")
self.logger.critical("Failed to create connection pool",
error=str(e),
exc_info=True)
raise
def _init_encryption(self):
"""初始化配置加密模块"""
key_file = Path(self.config_dir) / ".secret.key"
if key_file.exists():
with open(key_file, 'rb') as f:
self._secret_key = f.read()
else:
self._secret_key = Fernet.generate_key()
with open(key_file, 'wb') as f:
f.write(self._secret_key)
key_file.chmod(0o600) # 设置密钥文件权限
def _handle_path(self, path: str) -> str:
"""处理跨平台路径问题"""
if platform.system() == 'Windows':
return path.replace('/', '\\')
return path
def _deep_update(self, original: Dict, update: Dict) -> Dict:
"""深度合并字典"""
for key, value in update.items():
if isinstance(value, dict) and key in original:
original[key] = self._deep_update(original.get(key, {}), value)
else:
original[key] = value
return original
def get(self, key: str, default: Any = None) -> Any:
def get_connection(self) -> pymysql.connections.Connection:
"""
获取配置项(支持点分路径
示例: get("api.newsapi.endpoint")
获取数据库连接(跨平台兼容
Returns:
pymysql.connections.Connection: 数据库连接
Raises:
MySQLError: 如果连接失败
"""
keys = key.split('.')
value = self._config
try:
for k in keys:
value = value[k]
return value
except (KeyError, TypeError):
return default
conn = self._pool.connection()
def set(self, key: str, value: Any, persist: bool = False):
# macOS需要特殊处理SSL
if platform.system() == 'Darwin' and self.config.get('ssl'):
conn.ping(reconnect=True)
self.logger.trace("Connection obtained")
return conn
except Exception as e:
error_msg = str(e)
# Windows特定错误处理
if platform.system() == 'Windows' and "timed out" in error_msg:
self.logger.warning("Windows connection timeout, retrying...")
return self._retry_connection()
self.logger.error("Connection failed",
error=error_msg,
exc_info=True)
raise
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
"""Windows平台连接重试机制"""
for attempt in range(max_retries):
try:
conn = self._pool.connection()
self.logger.info(f"Connection established after {attempt+1} attempts")
return conn
except Exception:
if attempt == max_retries - 1:
raise
import time
time.sleep(1)
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
"""
设置配置项
参数:
persist: 是否保存到用户配置文件
跨平台兼容的SQL查询
Args:
sql (str): SQL语句
params (Union[tuple, dict, None]): 参数
parse_dates (Union[List[str], bool]): 日期解析
Returns:
pd.DataFrame: 查询结果
"""
keys = key.split('.')
config_ref = self._config
try:
with self.get_connection() as conn:
# Linux/macOS需要更长的查询超时
if platform.system() != 'Windows':
conn.cursor().execute("SET SESSION wait_timeout=600")
for k in keys[:-1]:
if k not in config_ref:
config_ref[k] = {}
config_ref = config_ref[k]
df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)
config_ref[keys[-1]] = value
# Windows平台需要手动关闭游标
if platform.system() == 'Windows':
conn.cursor().close()
if persist:
self._save_user_config()
self.logger.info("Query executed", rows=len(df))
return df
def encrypt_value(self, plaintext: str) -> str:
"""加密敏感信息"""
fernet = Fernet(self._secret_key)
return fernet.encrypt(plaintext.encode()).decode()
except Exception as e:
self.logger.error("Query failed",
sql=sql,
params=params,
error=str(e),
exc_info=True)
raise
def decrypt_value(self, ciphertext: str) -> str:
"""解密敏感信息"""
fernet = Fernet(self._secret_key)
return fernet.decrypt(ciphertext.encode()).decode()
def insert_from_df(self, table_name: str, df: pd.DataFrame,
chunk_size: int = 1000, replace: bool = False) -> int:
"""
跨平台数据插入
def _save_user_config(self):
"""保存用户配置到文件"""
config_file = Path(self.config_dir) / "config.json"
with open(config_file, 'w', encoding='utf-8') as f:
json.dump(self._config, f, indent=2, ensure_ascii=False)
Args:
table_name (str): 表名
df (pd.DataFrame): 数据
chunk_size (int): 分批大小
replace (bool): 是否替换
def reload(self):
"""重新加载所有配置"""
self._config = {}
self._load_defaults()
self._load_env_file()
self._load_user_config()
Returns:
int: 插入行数
"""
if df.empty:
self.logger.warning("Empty DataFrame", table=table_name)
return 0
# 全局配置实例
config = ConfigManager()
try:
method = 'replace' if replace else 'append'
total_rows = 0
# 快捷访问方法(兼容旧代码)
def get_config(key: str, default: Optional[Any] = None) -> Any:
return config.get(key, default)
with self.get_connection() as conn:
# 各平台不同的分批策略
if platform.system() == 'Windows':
chunk_size = min(chunk_size, 500) # Windows上减小批次
def set_config(key: str, value: Any, persist: bool = False):
config.set(key, value, persist)
for i in range(0, len(df), chunk_size):
chunk = df.iloc[i:i + chunk_size]
# 测试代码
# macOS需要特殊处理datetime
if platform.system() == 'Darwin':
for col in chunk.select_dtypes(include=['datetime64']):
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
chunk.to_sql(
table_name,
conn,
if_exists=method,
index=False,
method='multi'
)
total_rows += len(chunk)
method = 'append'
self.logger.info("Data inserted", table=table_name, rows=total_rows)
return total_rows
except Exception as e:
self.logger.error("Insert failed",
table=table_name,
error=str(e),
exc_info=True)
raise
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
"""
跨平台SQL执行
Args:
sql (str): SQL语句
params (Union[tuple, dict, None]): 参数
fetch (bool): 是否获取结果
Returns:
Union[int, List[Dict[str, Any]]]: 结果
"""
conn = None
cursor = None
try:
conn = self.get_connection()
cursor = conn.cursor()
# Linux/macOS需要更长的执行时间
if platform.system() != 'Windows':
cursor.execute("SET SESSION max_execution_time=600000")
cursor.execute(sql, params)
if fetch:
result = cursor.fetchall()
self.logger.debug("Query executed", rows=len(result))
return result
else:
affected_rows = cursor.rowcount
self.logger.debug("Update executed", affected_rows=affected_rows)
return affected_rows
except Exception as e:
self.logger.error("SQL execution failed",
sql=sql,
params=params,
error=str(e),
exc_info=True)
raise
finally:
if cursor:
cursor.close()
if conn:
conn.close()
def begin_transaction(self) -> pymysql.connections.Connection:
"""开始事务(跨平台兼容)"""
try:
conn = self.get_connection()
conn.autocommit(False)
# macOS需要特殊处理事务隔离级别
if platform.system() == 'Darwin':
conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
self.logger.debug("Transaction started")
return conn
except Exception as e:
self.logger.error("Begin transaction failed", error=str(e))
raise
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
"""提交事务(跨平台兼容)"""
try:
conn.commit()
self.logger.debug("Transaction committed")
except Exception as e:
self.logger.error("Commit failed", error=str(e))
raise
finally:
conn.close()
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
"""回滚事务(跨平台兼容)"""
try:
conn.rollback()
self.logger.warning("Transaction rolled back")
except Exception as e:
self.logger.error("Rollback failed", error=str(e))
finally:
conn.close()
def __del__(self):
"""析构函数(跨平台资源清理)"""
if hasattr(self, '_pool'):
try:
self._pool.close()
self.logger.info("Connection pool closed")
except Exception as e:
self.logger.error("Failed to close pool", error=str(e))
# 平台特定的默认配置
def get_default_config():
"""获取各平台默认配置"""
current_platform = platform.system()
base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '',
'database': 'intelligence_system',
'max_connections': 5
}
if current_platform == 'Windows':
return {
**base_config,
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30
}
elif current_platform == 'Darwin':
return {
**base_config,
'connect_timeout': 15,
'read_timeout': 60,
'write_timeout': 60,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
else: # Linux和其他平台
return {
**base_config,
'connect_timeout': 15,
'read_timeout': 60,
'write_timeout': 60
}
# 使用示例
if __name__ == "__main__":
# 设置并保存API密钥(自动加密)
api_key = "your_newsapi_key_here"
encrypted_key = config.encrypt_value(api_key)
config.set("api.newsapi.key", encrypted_key, persist=True)
# 自动获取适合当前平台的配置
config = get_default_config()
# 获取配置示例
print(f"日志级别: {config.get('system.log_level')}")
print(f"NewsAPI端点: {config.get('api.newsapi.endpoint')}")
# 初始化数据库连接
db = MySQLAgent(config)
# 解密敏感信息
stored_key = config.get("api.newsapi.key")
if stored_key:
print(f"解密后的API密钥: {config.decrypt_value(stored_key)}")
# 测试查询
try:
df = db.query_to_df("SELECT VERSION() as version")
print(f"Database version: {df['version'].iloc[0]}")
print(f"Running on: {platform.system()} {platform.release()}")
except Exception as e:
print(f"Error: {str(e)}")
+1 -1
View File
@@ -17,7 +17,7 @@ from typing import Dict, List, Any
# 自定义模块
from processors.data_processor import DataProcessor
from storage.database import IntelligenceDB
from storage.mysql_agent import IntelligenceDB
from applications.reporter import ReportGenerator
from applications.alert import AlertService
from utils.logger import setup_logging
-255
View File
@@ -1,255 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库存储模块
功能:
1. 统一数据库接口(SQLite/MySQL/PostgreSQL
2. 自动处理多平台路径问题
3. 连接池管理
4. 数据加密存储
"""
import os
import platform
import sqlite3
import threading
from pathlib import Path
from typing import Optional, Union, Dict, Any, List
from threading import Lock
import logging
from cryptography.fernet import Fernet
# 类型别名
QueryParams = Union[tuple, Dict[str, Any]]
class DatabaseManager:
"""数据库统一管理类"""
def __init__(self, db_config: Dict[str, Any]):
"""
初始化数据库连接
参数:
db_config: 配置字典,包含:
- type: 'sqlite'|'mysql'|'postgresql'
- database: 数据库名/路径
- [可选] host, port, user, password
"""
self.config = db_config
self._lock = Lock()
self._connection_pool = {}
self._setup_crypto()
# 自动创建SQLite目录
if db_config['type'] == 'sqlite':
self._ensure_sqlite_dir()
def _ensure_sqlite_dir(self):
"""确保SQLite数据库目录存在"""
db_path = Path(self.config['database'])
if not db_path.parent.exists():
try:
db_path.parent.mkdir(parents=True, mode=0o755)
except Exception as e:
logging.error(f"创建数据库目录失败: {str(e)}")
def _setup_crypto(self):
"""初始化加密模块"""
key_file = Path(os.path.expanduser("~/.db_encryption.key"))
if key_file.exists():
with open(key_file, 'rb') as f:
self._fernet = Fernet(f.read())
else:
self._fernet = Fernet.generate_key()
with open(key_file, 'wb') as f:
f.write(self._fernet)
key_file.chmod(0o600) # 仅限当前用户读写
def get_connection(self, reuse=True):
"""
获取数据库连接(线程安全)
参数:
reuse: 是否复用现有连接(默认True)
"""
thread_id = threading.get_ident()
with self._lock:
if reuse and thread_id in self._connection_pool:
conn = self._connection_pool[thread_id]
try:
# 检查连接是否有效
conn.execute("SELECT 1")
return conn
except:
del self._connection_pool[thread_id]
# 创建新连接
if self.config['type'] == 'sqlite':
conn = self._create_sqlite_connection()
elif self.config['type'] == 'mysql':
conn = self._create_mysql_connection()
elif self.config['type'] == 'postgresql':
conn = self._create_pg_connection()
else:
raise ValueError("不支持的数据库类型")
self._connection_pool[thread_id] = conn
return conn
def _create_sqlite_connection(self) -> sqlite3.Connection:
"""创建SQLite连接(兼容多平台路径)"""
db_path = self.config['database']
# Windows路径处理
if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')):
db_path = os.path.abspath(db_path)
conn = sqlite3.connect(db_path, timeout=15)
conn.execute("PRAGMA journal_mode=WAL") # 写前日志提升并发
conn.execute("PRAGMA synchronous=NORMAL")
conn.row_factory = sqlite3.Row # 支持字典式访问
return conn
def _create_mysql_connection(self):
"""创建MySQL连接(需安装PyMySQL"""
import pymysql
return pymysql.connect(
host=self.config.get('host', 'localhost'),
port=self.config.get('port', 3306),
user=self.config.get('user', 'root'),
password=self.config.get('password', ''),
database=self.config['database'],
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
def _create_pg_connection(self):
"""创建PostgreSQL连接(需安装psycopg2"""
import psycopg2
return psycopg2.connect(
host=self.config.get('host', 'localhost'),
port=self.config.get('port', 5432),
user=self.config.get('user', 'postgres'),
password=self.config.get('password', ''),
dbname=self.config['database']
)
def execute(
self,
query: str,
params: Optional[QueryParams] = None,
return_lastrowid: bool = False
) -> Union[int, None]:
conn = self.get_connection()
try:
with conn:
cursor = conn.cursor()
if params is not None:
cursor.execute(query, params)
else:
cursor.execute(query) # ✅ 无参数时不要传 params
return cursor.lastrowid if return_lastrowid else None
finally:
if not return_lastrowid:
self._release_connection()
def query(
self,
query: str,
params: Optional[QueryParams] = None,
fetchall: bool = True
) -> Union[List[Dict], Dict]:
conn = self.get_connection()
try:
cursor = conn.cursor()
if params is not None:
cursor.execute(query, params)
else:
cursor.execute(query) # ✅ 无参数时不要传 None
if self.config['type'] == 'sqlite':
result = cursor.fetchall()
if not fetchall and result:
return dict(result[0])
return [dict(row) for row in result]
else:
return cursor.fetchall() if fetchall else cursor.fetchone()
finally:
self._release_connection()
def _release_connection(self):
"""释放当前线程的连接(SQLite除外)"""
if self.config['type'] != 'sqlite':
thread_id = threading.get_ident()
with self._lock:
if thread_id in self._connection_pool:
self._connection_pool[thread_id].close()
del self._connection_pool[thread_id]
def encrypt_data(self, plaintext: str) -> str:
"""加密敏感数据"""
return self._fernet.encrypt(plaintext.encode()).decode()
def decrypt_data(self, ciphertext: str) -> str:
"""解密数据"""
return self._fernet.decrypt(ciphertext.encode()).decode()
def close_all(self):
"""关闭所有数据库连接"""
with self._lock:
for conn in self._connection_pool.values():
try:
conn.close()
except:
pass
self._connection_pool.clear()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close_all()
# 全局SQLite实例(默认配置)
def get_default_db() -> DatabaseManager:
"""获取默认SQLite数据库(跨平台路径处理)"""
system = platform.system().lower()
if system == 'windows':
db_path = os.path.join(os.getenv('APPDATA'), 'app_name/data.db')
elif system == 'darwin':
db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db')
else:
db_path = '/var/lib/app_name/data.db' if os.access('/var/lib', os.W_OK) \
else os.path.expanduser('~/.local/share/app_name/data.db')
return DatabaseManager({
'type': 'sqlite',
'database': db_path
})
# 测试代码
if __name__ == "__main__":
with get_default_db() as db:
# 创建测试表
db.execute("""
CREATE TABLE IF NOT EXISTS test_table (
id INTEGER PRIMARY KEY,
name TEXT,
secret TEXT
)
""")
# 插入加密数据
secret = db.encrypt_data("新敏感信息")
db.execute(
"INSERT INTO test_table (name, secret) VALUES (?, ?)",
("测试记录", secret)
)
# 查询并解密
row = db.query("SELECT * FROM test_table", fetchall=False)
print(f"解密数据: {db.decrypt_data(row['name'])}")
+662
View File
@@ -0,0 +1,662 @@
import os
import sys
import platform
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
"""
全平台兼容的MySQL数据库操作类
支持Windows/macOS/Linux系统
配置参数从外部传入
"""
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config: dict):
"""
初始化MySQL数据库连接
Args:
config (dict): 数据库配置字典,包含以下键:
- host: 数据库主机
- port: 端口
- user: 用户名
- password: 密码
- database: 数据库名
- [可选] charset: 字符集(默认utf8mb4)
- [可选] max_connections: 最大连接数(默认5)
- [可选] connect_timeout: 连接超时(秒)
- [可选] read_timeout: 读取超时(秒)
- [可选] write_timeout: 写入超时(秒)
- [可选] ssl: SSL配置
"""
if hasattr(self, '_pool') and self._pool:
return
# 基础配置
required_keys = ['host', 'port', 'user', 'password', 'database']
if not all(key in config for key in required_keys):
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
self.config = {
'host': config['host'],
'port': config['port'],
'user': config['user'],
'password': config['password'],
'database': config['database'],
'charset': config.get('charset', 'utf8mb4'),
'cursorclass': cursors.DictCursor,
'autocommit': True,
'connect_timeout': config.get('connect_timeout', 10),
'read_timeout': config.get('read_timeout', 30),
'write_timeout': config.get('write_timeout', 30),
'ssl': config.get('ssl')
}
# 初始化log
current_platform = platform.system()
self.log = log.bind(module=f"MySQLAgent({current_platform})")
# 创建连接池
self.pool_size = config.get('max_connections', 5)
self._pool = self._create_pool()
def _create_pool(self) -> PooledDB:
"""创建连接池"""
try:
# 使用包装函数确保线程安全
def connect():
conn = pymysql.connect(**self.config)
conn.threadsafety = 1 # 显式设置线程安全级别
return conn
pool = PooledDB(
creator=connect,
mincached=1,
maxcached=3,
maxconnections=self.pool_size,
blocking=True,
ping=1
)
self.log.info("Connection pool created")
return pool
except Exception as e:
self.log.critical("Failed to create connection pool",
error=str(e),
exc_info=True)
raise
def get_connection(self) -> pymysql.connections.Connection:
"""
获取数据库连接
Returns:
pymysql.connections.Connection: 数据库连接对象
Raises:
MySQLError: 如果获取连接失败
"""
try:
conn = self._pool.connection()
# macOS需要特殊处理SSL
if platform.system() == 'Darwin' and self.config.get('ssl'):
conn.ping(reconnect=True)
self.log.trace("Database connection obtained")
return conn
except Exception as e:
error_msg = str(e)
# Windows特定错误处理
if platform.system() == 'Windows' and "timed out" in error_msg:
self.log.warning("Windows connection timeout, retrying...")
return self._retry_connection()
self.log.error("Connection failed",
error=error_msg,
exc_info=True)
raise
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
"""Windows平台连接重试机制"""
for attempt in range(max_retries):
try:
conn = self._pool.connection()
self.log.info(f"Connection established after {attempt + 1} attempts")
return conn
except Exception:
if attempt == max_retries - 1:
raise
import time
time.sleep(1)
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
"""
执行SQL查询并返回DataFrame
Args:
sql (str): SQL查询语句
params (Union[tuple, dict, None]): 查询参数
parse_dates (Union[List[str], bool]): 自动解析日期字段
Returns:
pd.DataFrame: 查询结果
Raises:
MySQLError: 如果查询失败
"""
try:
self.log.debug("Executing SQL query", sql=sql)
with self.get_connection() as conn:
# Linux/macOS需要更长的查询超时
if platform.system() != 'Windows':
conn.cursor().execute("SET SESSION wait_timeout=600")
df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)
# Windows平台需要手动关闭游标
if platform.system() == 'Windows':
conn.cursor().close()
self.log.info("Query executed successfully", rows=len(df))
return df
except Exception as e:
self.log.error("SQL query failed",
sql=sql,
params=params,
error=str(e),
exc_info=True)
raise
def insert_from_df(self, table_name: str, df: pd.DataFrame,
chunk_size: int = 1000, replace: bool = False) -> int:
"""
将DataFrame数据插入到数据库表
Args:
table_name (str): 目标表名
df (pd.DataFrame): 要插入的数据
chunk_size (int): 分批插入大小
replace (bool): 是否替换现有数据
Returns:
int: 插入的总行数
Raises:
MySQLError: 如果插入失败
"""
if df.empty:
self.log.warning("Attempted to insert empty DataFrame", table=table_name)
return 0
self.log.debug("Preparing to insert DataFrame",
table=table_name,
rows=len(df),
chunk_size=chunk_size)
try:
method = 'replace' if replace else 'append'
total_rows = 0
with self.get_connection() as conn:
# 各平台不同的分批策略
if platform.system() == 'Windows':
chunk_size = min(chunk_size, 500) # Windows上减小批次
for i in range(0, len(df), chunk_size):
chunk = df.iloc[i:i + chunk_size]
# macOS需要特殊处理datetime
if platform.system() == 'Darwin':
for col in chunk.select_dtypes(include=['datetime64']):
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
chunk.to_sql(
table_name,
conn,
if_exists=method,
index=False,
method='multi'
)
total_rows += len(chunk)
method = 'append' # 第一次之后都使用追加模式
self.log.trace(f"Inserted chunk {i // chunk_size + 1}",
rows=len(chunk),
total_inserted=total_rows)
self.log.info("Data inserted successfully",
table=table_name,
total_rows=total_rows)
return total_rows
except Exception as e:
self.log.error("Data insertion failed",
table=table_name,
error=str(e),
exc_info=True)
raise
def update_from_df(self, table_name: str, df: pd.DataFrame,
key_columns: Union[str, List[str]]) -> int:
"""
使用DataFrame数据更新数据库表
Args:
table_name (str): 目标表名
df (pd.DataFrame): 包含更新数据
key_columns (Union[str, List[str]]): 用于匹配记录的关键列
Returns:
int: 更新的总行数
Raises:
MySQLError: 如果更新失败
"""
if df.empty:
self.log.warning("Attempted to update with empty DataFrame", table=table_name)
return 0
self.log.debug("Preparing to update table from DataFrame",
table=table_name,
key_columns=key_columns,
rows=len(df))
try:
if isinstance(key_columns, str):
key_columns = [key_columns]
total_updated = 0
conn = self.begin_transaction()
try:
cursor = conn.cursor()
# 获取表结构信息
table_info = self._get_table_info(table_name)
columns = [col for col in df.columns if col in table_info]
# 构建UPDATE语句模板
set_clause = ', '.join([f"{col}=%s" for col in columns if col not in key_columns])
where_clause = ' AND '.join([f"{col}=%s" for col in key_columns])
update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
self.log.trace("Generated update SQL", sql=update_sql)
# 准备数据
update_data = []
for _, row in df.iterrows():
# SET部分的值
set_values = [row[col] for col in columns if col not in key_columns]
# WHERE部分的值
key_values = [row[col] for col in key_columns]
update_data.append(tuple(set_values + key_values))
# 执行批量更新
cursor.executemany(update_sql, update_data)
total_updated = cursor.rowcount
self.commit_transaction(conn)
self.log.info("Data updated successfully",
table=table_name,
rows_updated=total_updated)
return total_updated
except Exception as e:
self.rollback_transaction(conn)
raise
except Exception as e:
self.log.error("Data update failed",
table=table_name,
error=str(e),
exc_info=True)
raise
def _get_table_info(self, table_name: str) -> Dict[str, str]:
"""
获取表结构信息
Args:
table_name (str): 表名
Returns:
Dict[str, str]: 列名到类型的映射
Raises:
MySQLError: 如果查询失败
"""
sql = f"""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
"""
params = (self.config['database'], table_name)
try:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
result = cursor.fetchall()
return {row['column_name']: row['data_type'] for row in result}
except Exception as e:
self.log.error("Failed to get table info",
table=table_name,
error=str(e))
raise
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
"""
推断DataFrame各列的SQL类型
Args:
df (pd.DataFrame): 输入数据框
Returns:
Dict[str, str]: 列名到SQL类型的映射
"""
type_mapping = {
'int64': 'BIGINT',
'float64': 'DOUBLE',
'datetime64[ns]': 'DATETIME',
'object': 'TEXT',
'bool': 'TINYINT(1)',
'category': 'VARCHAR(255)'
}
sql_types = {}
for col, dtype in df.dtypes.items():
dtype_str = str(dtype)
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
self.log.debug("Mapped DataFrame types to SQL types",
mappings=sql_types)
return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
primary_key: Union[str, List[str], None] = None) -> bool:
"""
根据DataFrame结构创建表
Args:
table_name (str): 表名
df (pd.DataFrame): 参考数据框
primary_key (Union[str, List[str], None]): 主键列
Returns:
bool: 是否创建成功
"""
if self.table_exists(table_name):
self.log.warning("Table already exists", table=table_name)
return False
self.log.debug("Creating new table from DataFrame schema",
table=table_name,
columns=list(df.columns))
try:
sql_types = self.df_to_sql_type(df)
columns_sql = []
for col, sql_type in sql_types.items():
col_def = f"{col} {sql_type}"
columns_sql.append(col_def)
if primary_key:
if isinstance(primary_key, str):
primary_key = [primary_key]
pk_columns = [col for col in primary_key if col in sql_types]
if pk_columns:
columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
self.log.trace("Set primary key",
table=table_name,
primary_key=pk_columns)
create_sql = f"CREATE TABLE {table_name} (\n {',\n '.join(columns_sql)}\n)"
self.execute_sql(create_sql)
self.log.info("Table created successfully", table=table_name)
return True
except Exception as e:
self.log.error("Failed to create table",
table=table_name,
error=str(e),
exc_info=True)
return False
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
"""
执行SQL语句
Args:
sql (str): SQL语句
params (Union[tuple, dict, None]): 参数
fetch (bool): 是否获取结果
Returns:
Union[int, List[Dict[str, Any]]]:
- 如果是INSERT/UPDATE/DELETE,返回影响的行数
- 如果是SELECT且fetch=True,返回结果列表
"""
conn = None
cursor = None
try:
conn = self.get_connection()
cursor = conn.cursor()
# Linux/macOS需要更长的执行时间
if platform.system() != 'Windows':
cursor.execute("SET SESSION max_execution_time=600000")
cursor.execute(sql, params)
if fetch:
result = cursor.fetchall()
self.log.debug("Query executed", rows=len(result))
return result
else:
affected_rows = cursor.rowcount
self.log.debug("Update executed", affected_rows=affected_rows)
return affected_rows
except Exception as e:
self.log.error("SQL execution failed",
sql=sql,
params=params,
error=str(e),
exc_info=True)
raise
finally:
if cursor:
cursor.close()
if conn:
conn.close()
def begin_transaction(self) -> pymysql.connections.Connection:
"""开始事务"""
try:
conn = self.get_connection()
conn.autocommit(False)
# macOS需要特殊处理事务隔离级别
if platform.system() == 'Darwin':
conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
self.log.debug("Transaction started")
return conn
except Exception as e:
self.log.error("Begin transaction failed", error=str(e), exc_info=True)
raise
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
"""提交事务"""
try:
conn.commit()
self.log.debug("Transaction committed")
except Exception as e:
self.log.error("Commit failed", error=str(e))
raise
finally:
conn.close()
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
"""回滚事务"""
try:
conn.rollback()
self.log.warning("Transaction rolled back")
except Exception as e:
self.log.error("Rollback failed", error=str(e))
finally:
conn.close()
def table_exists(self, table_name: str) -> bool:
"""检查表是否存在"""
sql = """
SELECT COUNT(*) as count
FROM `information_schema`.`tables`
WHERE `table_schema` = %s AND `table_name` = %s
"""
params = (self.config['database'], table_name)
try:
result = self.execute_sql(sql, params, fetch=True)
exists = result[0]['count'] > 0
self.log.debug("Checked table existence",
table=table_name,
exists=exists)
return exists
except Exception:
return False
def drop_table(self, table_name: str) -> bool:
"""删除表"""
if not self.table_exists(table_name):
self.log.warning("Table does not exist", table=table_name)
return False
try:
self.execute_sql(f"DROP TABLE {table_name}")
self.log.info("Table dropped successfully", table=table_name)
return True
except Exception as e:
self.log.error("Failed to drop table",
table=table_name,
error=str(e),
exc_info=True)
return False
def get_pool_status(self) -> Dict[str, int]:
"""获取连接池状态"""
return {
'max': self._pool._maxconnections,
'active': self._pool._connections,
'idle': len(self._pool._idle_cache),
'shared': len(self._pool._shared_cache)
}
def validate_connection(self) -> bool:
"""验证连接是否有效"""
try:
with self.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
return cursor.fetchone()[0] == 1
except Exception:
return False
def __del__(self):
"""析构函数"""
if hasattr(self, '_pool'):
try:
self._pool.close()
self.log.info("Connection pool closed")
except Exception as e:
self.log.error("Failed to close pool", error=str(e))
# 平台特定的默认配置
def get_default_config():
"""获取各平台默认配置"""
current_platform = platform.system()
base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '123123',
'database': 'intelligence_system',
'max_connections': 5
}
if current_platform == 'Windows':
return {
**base_config,
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30
}
elif current_platform == 'Darwin':
return {
**base_config,
'connect_timeout': 15,
'read_timeout': 60,
'write_timeout': 60,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
else: # Linux和其他平台
return {
**base_config,
'connect_timeout': 15,
'read_timeout': 60,
'write_timeout': 60
}
if __name__ == "__main__":
# 使用示例
db = MySQLAgent(get_default_config())
# 测试连接
if db.validate_connection():
print("Database connection successful")
# 获取数据库版本
version = db.query_to_df("SELECT VERSION() as version")
print(f"Database version: {version['version'].iloc[0]}")
# 查看连接池状态
print("Connection pool status:", db.get_pool_status())
else:
print("Failed to connect to database")
-1
View File
@@ -1 +0,0 @@
{"a":{"0":1},"b":{"0":2}}
+291
View File
@@ -0,0 +1,291 @@
import unittest
import pandas as pd
from datetime import datetime
import tempfile
import time
import pymysql
from storage.mysql_agent import MySQLAgent
import platform
class TestMySQLAgent(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""初始化测试环境和测试表"""
# 创建唯一的测试数据库名
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S")
# 基础配置
cls.base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '123123',
'max_connections': 10
}
# 创建测试数据库
cls._create_test_database()
# 初始化数据库连接
cls.db = MySQLAgent({
**cls.base_config,
'database': cls.test_db_name
})
# 创建测试表
test_data = pd.DataFrame({
'id': [1, 2, 3],
'name': ['Test1', 'Test2', 'Test3'],
'value': [10.5, 20.3, 30.8],
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
})
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
cls.db.insert_from_df(cls.test_table, test_data)
@classmethod
def _create_test_database(cls):
"""创建测试数据库"""
# 使用临时连接创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
cursor.execute(f"USE {cls.test_db_name}")
cursor.execute("SET GLOBAL max_connections = 100")
temp_conn.commit()
finally:
temp_conn.close()
@classmethod
def tearDownClass(cls):
"""清理测试数据库"""
if hasattr(cls, 'db') and cls.db:
# 删除测试表
if cls.db.table_exists(cls.test_table):
cls.db.drop_table(cls.test_table)
# 删除测试数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
temp_conn.commit()
finally:
temp_conn.close()
def test_01_connection(self):
"""测试数据库连接"""
version = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nDatabase version: {version['version'].iloc[0]}")
print(f"Running on: {platform.system()} {platform.release()}")
def test_02_query_to_df(self):
"""测试查询返回DataFrame"""
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,))
self.assertEqual(len(df), 2)
self.assertIsInstance(df, pd.DataFrame)
print("\nQuery result sample:")
print(df.head())
def test_03_insert_from_df(self):
"""测试DataFrame插入"""
new_data = pd.DataFrame({
'id': [4, 5],
'name': ['Test4', 'Test5'],
'value': [40.1, 50.2],
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
})
rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(rows, 2)
# 验证数据
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4")
self.assertEqual(len(df), 2)
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5'])
def test_04_update_from_df(self):
"""测试DataFrame更新"""
update_data = pd.DataFrame({
'id': [1, 2],
'name': ['Updated1', 'Updated2']
})
rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(rows, 2)
# 验证更新
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn('Updated1', df['name'].values)
self.assertIn('Updated2', df['name'].values)
def test_05_transaction(self):
"""测试事务处理"""
conn = self.db.begin_transaction()
try:
# 执行多个操作
cursor = conn.cursor()
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
# 验证事务内修改
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
self.assertEqual(cursor.fetchone()['value'], 99.9)
self.db.commit_transaction(conn)
except Exception:
self.db.rollback_transaction(conn)
raise
# 验证提交后的修改
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn(99.9, df['value'].values)
self.assertIn(88.8, df['value'].values)
def test_06_large_data(self):
"""测试大数据量操作"""
# 生成测试数据
large_data = pd.DataFrame({
'id': range(1000, 2000),
'name': [f"Item_{i}" for i in range(1000, 2000)],
'value': [i * 0.1 for i in range(1000, 2000)],
'created_at': pd.date_range('2023-01-01', periods=1000)
})
# Windows平台使用更小的批次
chunk_size = 100 if platform.system() == 'Windows' else 500
start_time = time.time()
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size)
elapsed = time.time() - start_time
self.assertEqual(rows, 1000)
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})")
# 验证数据
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
self.assertEqual(df['cnt'].iloc[0], 1000)
def test_07_concurrent_access(self):
"""测试并发访问"""
from concurrent.futures import ThreadPoolExecutor
def worker(i):
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,))
return len(df)
start_time = time.time()
with ThreadPoolExecutor(max_workers=20) as executor:
results = list(executor.map(worker, range(100)))
elapsed = time.time() - start_time
self.assertEqual(sum(results), 100)
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s")
class TestPlatformSpecific(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""创建临时测试数据库"""
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '123123',
'max_connections': 10
}
# 创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
temp_conn.commit()
finally:
temp_conn.close()
@classmethod
def tearDownClass(cls):
"""删除临时测试数据库"""
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
temp_conn.commit()
finally:
temp_conn.close()
def test_windows_timeout(self):
"""测试Windows平台超时处理"""
if platform.system() != 'Windows':
self.skipTest("Only runs on Windows")
config = {
**self.base_config,
'database': self.test_db_name,
'connect_timeout': 1,
'read_timeout': 1
}
db = MySQLAgent(config)
# 测试短超时查询
start_time = time.time()
try:
db.query_to_df("SELECT SLEEP(2)")
self.fail("Should have timed out")
except Exception as e:
self.assertIn("timed out", str(e))
print(f"\nWindows timeout test: {str(e)}")
def test_macos_ssl(self):
"""测试macOS SSL连接"""
if platform.system() != 'Darwin':
self.skipTest("Only runs on macOS")
config = {
**self.base_config,
'database': self.test_db_name,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
db = MySQLAgent(config)
version = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
if __name__ == '__main__':
unittest.main()
+13 -3
View File
@@ -3,6 +3,8 @@ import pandas as pd
import os
from pathlib import Path
from utils.file_handler import FileHandler
from datetime import datetime
@pytest.fixture
def temp_dir(tmp_path):
@@ -11,11 +13,13 @@ def temp_dir(tmp_path):
test_dir.mkdir()
return test_dir
@pytest.fixture
def file_handler(temp_dir):
"""创建FileHandler实例"""
return FileHandler(temp_dir)
@pytest.fixture
def sample_dataframe():
"""创建测试用DataFrame"""
@@ -25,6 +29,7 @@ def sample_dataframe():
'value': [10.5, 20.3, 30.1]
})
@pytest.fixture
def sample_text_file(temp_dir):
"""创建测试文本文件"""
@@ -55,30 +60,33 @@ def test_read_write_csv(file_handler, temp_dir, sample_dataframe):
assert df.shape == (3, 3)
assert list(df.columns) == ['id', 'name', 'value']
def test_read_write_json(file_handler, temp_dir, sample_dataframe):
"""测试JSON文件读写"""
test_file = temp_dir / "test.json"
# 测试写入
write_result = file_handler.write_file(test_file, sample_dataframe)
assert write_result.iloc[0]['success'] == True
assert write_result.iloc[0]['success'] == True
# 测试读取
df = file_handler.read_file(test_file)
assert df.shape == (3, 3)
def test_read_write_excel(file_handler, temp_dir, sample_dataframe):
"""测试Excel文件读写"""
test_file = temp_dir / "test.xlsx"
# 测试写入
write_result = file_handler.write_file(test_file, sample_dataframe)
assert write_result.iloc[0]['success'] == True
assert write_result.iloc[0]['success'] == True
# 测试读取
df = file_handler.read_file(test_file)
assert df.shape == (3, 3)
def test_read_write_csv(file_handler, temp_dir, sample_dataframe):
"""测试CSV文件读写"""
test_file = temp_dir / "test.csv"
@@ -119,6 +127,7 @@ def test_file_operations(file_handler, sample_text_file):
assert delete_df.iloc[0]['deleted'] == True
assert not os.path.exists(sample_text_file)
def test_directory_operations(file_handler, temp_dir):
"""测试目录操作"""
test_dir = temp_dir / "subdir"
@@ -160,6 +169,7 @@ def test_zip_operations(file_handler, temp_dir, sample_dataframe):
assert os.path.exists(extract_dir / "file1.txt")
assert os.path.exists(extract_dir / "file2.csv")
def test_zip_directory(file_handler, temp_dir):
"""测试目录压缩"""
# 创建测试目录结构
@@ -174,4 +184,4 @@ def test_zip_directory(file_handler, temp_dir):
zip_path = temp_dir / "dir.zip"
zip_result = file_handler.zip_dir(test_dir, zip_path)
assert zip_result.iloc[0]['zipped'] == True
assert zip_result.iloc[0]['file_count'] == 2
assert zip_result.iloc[0]['file_count'] == 2
+3 -1
View File
@@ -433,7 +433,9 @@ class FileHandler:
# ---------------------------- 测试用例 ----------------------------
if __name__ == "__main__":
# 初始化处理器(自动处理跨平台路径)
handler = FileHandler("test_data")
project_root = next(p for p in Path(__file__).resolve().parents if
(p / '.git').exists() or (p / 'pyproject.toml').exists() or (p / 'requirements.txt').exists())
handler = FileHandler(project_root / "test")
# 测试路径标准化
test_paths = [
+3 -1
View File
@@ -6,6 +6,7 @@ import platform
from datetime import datetime
import zipfile
class CrossPlatformLog:
"""跨平台日志系统(支持Linux/Windows/Mac"""
@@ -94,5 +95,6 @@ class CrossPlatformLog:
"""获取模块专属日志器"""
return logger.bind(module=module_name or "__main__")
# 初始化全局日志器
log = CrossPlatformLog().get_logger()
log = CrossPlatformLog().get_logger()
-233
View File
@@ -1,233 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
网络工具模块
功能:
1. 自动重试的HTTP请求
2. 多平台代理配置支持
3. DNS缓存优化
4. 连接超时与SSL验证
5. 用户代理轮换
"""
import socket
import time
import random
import platform
from typing import Optional, Dict, Any, Union
from urllib.parse import urlparse
from functools import lru_cache
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from urllib3.connection import HTTPConnection
import os
# 类型别名
TimeoutType = Union[float, tuple]
class NetworkUtils:
"""跨平台网络操作工具类"""
def __init__(self):
self.system = platform.system().lower()
self._dns_cache = {}
self._setup_platform_specific()
def _setup_platform_specific(self):
"""平台相关初始化"""
if self.system == 'windows':
# Windows默认关闭TCP_NODELAY
HTTPConnection.default_socket_options = (
HTTPConnection.default_socket_options + [
(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
]
)
elif self.system == 'linux':
# Linux启用TCP快速打开
HTTPConnection.default_socket_options = (
HTTPConnection.default_socket_options + [
(socket.IPPROTO_TCP, socket.TCP_FASTOPEN, 5)
]
)
@staticmethod
@lru_cache(maxsize=512)
def _resolve_hostname(hostname: str) -> str:
"""DNS缓存解析(跨线程安全)"""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return hostname # 失败时返回原始域名
def get_session(
self,
retries: int = 3,
backoff_factor: float = 0.5,
timeout: TimeoutType = (3.05, 30),
proxy: Optional[str] = None,
verify_ssl: bool = True
) -> requests.Session:
"""
获取配置好的请求会话
参数:
retries: 重试次数
backoff_factor: 重试间隔系数
timeout: (连接超时, 读取超时)秒数
proxy: 代理地址(如 'http://user:pass@proxy:port'
verify_ssl: 是否验证SSL证书
"""
session = requests.Session()
# 重试策略
retry = Retry(
total=retries,
backoff_factor=backoff_factor,
status_forcelist=[500, 502, 503, 504],
allowed_methods=frozenset(['GET', 'POST', 'PUT', 'DELETE'])
)
# 适配器配置
adapter = HTTPAdapter(
max_retries=retry,
pool_connections=20,
pool_maxsize=100
)
# 挂载适配器
session.mount('http://', adapter)
session.mount('https://', adapter)
# 代理配置
if proxy:
session.proxies = {
'http': proxy,
'https': proxy
}
# 请求默认配置
session.request = self._wrap_request(
session.request,
timeout=timeout,
verify=verify_ssl
)
return session
def _wrap_request(self, original_request, **defaults):
"""包装请求方法添加默认参数"""
def wrapped(method, url, **kwargs):
# 处理DNS缓存
parsed = urlparse(url)
if parsed.hostname:
kwargs['hooks'] = kwargs.get('hooks', {})
kwargs['hooks']['pre_request'] = lambda r: setattr(
r, '_orig_host', r.url
)
url = url.replace(
parsed.hostname,
self._resolve_hostname(parsed.hostname),
1
)
# 合并默认参数
for k, v in defaults.items():
kwargs.setdefault(k, v)
return original_request(method, url, **kwargs)
return wrapped
def get_user_agent(self) -> str:
"""获取随机用户代理(兼容各平台)"""
agents = [
# Windows
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
# macOS
"Mozilla/5.0 (Macintosh; Intel Mac OS X 12_4) AppleWebKit/605.1.15",
# Linux
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
# Mobile
"Mozilla/5.0 (iPhone; CPU iPhone OS 15_5 like Mac OS X) AppleWebKit/605.1.15"
]
return random.choice(agents)
def check_connection(
self,
url: str = "https://www.baidu.com",
timeout: float = 5.0
) -> bool:
"""
检查网络连接状态
参数:
url: 测试用的URL
timeout: 超时时间(秒)
"""
try:
session = self.get_session(retries=0, timeout=timeout)
session.head(url)
return True
except Exception:
return False
def download_file(
self,
url: str,
save_path: str,
chunk_size: int = 8192,
progress_callback: Optional[callable] = None
) -> bool:
"""
下载大文件支持断点续传
参数:
url: 文件URL
save_path: 本地保存路径
chunk_size: 分块大小(字节)
progress_callback: 进度回调函数(bytes_downloaded, total_size)
"""
session = self.get_session()
headers = {}
# 检查本地文件部分下载
if os.path.exists(save_path):
downloaded = os.path.getsize(save_path)
headers['Range'] = f'bytes={downloaded}-'
try:
with session.get(url, headers=headers, stream=True) as r:
r.raise_for_status()
# 获取文件总大小
total_size = int(r.headers.get('content-length', 0)) + downloaded
# 追加模式写入
mode = 'ab' if headers.get('Range') else 'wb'
with open(save_path, mode) as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # 过滤keep-alive chunks
f.write(chunk)
downloaded += len(chunk)
if progress_callback:
progress_callback(downloaded, total_size)
return True
except Exception as e:
print(f"下载失败: {str(e)}")
return False
# 全局实例(线程安全)
network_utils = NetworkUtils()
# 快捷方法(兼容旧代码)
def get_session(*args, **kwargs):
return network_utils.get_session(*args, **kwargs)
def check_connection(*args, **kwargs):
return network_utils.check_connection(*args, **kwargs)