diff --git a/config/logging.conf b/config/logging.conf deleted file mode 100644 index 5adc1ba..0000000 --- a/config/logging.conf +++ /dev/null @@ -1,56 +0,0 @@ -[loggers] -keys=root,data_collector,api_client,alert - -[handlers] -keys=consoleHandler,fileHandler,errorFileHandler - -[formatters] -keys=standardFormatter,detailedFormatter - -[logger_root] -level=INFO -handlers=consoleHandler,fileHandler - -[logger_data_collector] -level=DEBUG -handlers=fileHandler -qualname=data_collector -propagate=0 - -[logger_api_client] -level=INFO -handlers=fileHandler,errorFileHandler -qualname=api_client -propagate=0 - -[logger_alert] -level=WARNING -handlers=consoleHandler,errorFileHandler -qualname=alert -propagate=0 - -[handler_consoleHandler] -class=StreamHandler -level=INFO -formatter=standardFormatter -args=(sys.stdout,) - -[handler_fileHandler] -class=logging.handlers.TimedRotatingFileHandler -level=DEBUG -formatter=detailedFormatter -args=('%(log_dir)s/application.log', 'midnight', 1, 30, 'utf-8') - -[handler_errorFileHandler] -class=logging.handlers.TimedRotatingFileHandler -level=WARNING -formatter=detailedFormatter -args=('%(log_dir)s/error.log', 'midnight', 1, 90, 'utf-8') - -[formatter_standardFormatter] -format=%(asctime)s [%(levelname)-5s] %(name)s - %(message)s -datefmt=%Y-%m-%d %H:%M:%S - -[formatter_detailedFormatter] -format=%(asctime)s [%(levelname)-5s] %(name)s (%(filename)s:%(lineno)d) - %(message)s -datefmt=%Y-%m-%d %H:%M:%S \ No newline at end of file diff --git a/config/settings.py b/config/settings.py index b37aef8..f119cfb 100644 --- a/config/settings.py +++ b/config/settings.py @@ -1,193 +1,409 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -系统配置模块 -功能: -1. 支持多平台路径适配 -2. 环境变量与配置文件优先级管理 -3. 敏感信息加密存储 -4. 配置热更新检测 -""" - import os -import json +import sys import platform +import pandas as pd +import pymysql +from pymysql import cursors +from pymysql.err import MySQLError +from dbutils.pooled_db import PooledDB +from typing import Union, List, Dict, Any, Optional, Tuple +import threading +from datetime import datetime +import numpy as np from pathlib import Path -from typing import Dict, Any, Optional -import dotenv -from cryptography.fernet import Fernet -class ConfigManager: - def __init__(self, app_name: str = "intelligence_system"): - """ - 初始化配置管理器 +# 导入您的日志系统 +from utils.logger import log as logger - 参数: - app_name: 应用名称(用于生成配置目录) - """ - self.system = platform.system().lower() - self.app_name = app_name - self._config = {} - self._secret_key = None +class MySQLAgent: + """ + 全平台兼容的MySQL数据库操作类 + 支持Windows/macOS/Linux系统 + """ - # 初始化配置路径 - self.config_dir = self._get_config_dir() - os.makedirs(self.config_dir, exist_ok=True) + _instance = None + _lock = threading.Lock() - # 加载配置顺序 - self._load_defaults() - self._load_env_file() - self._load_user_config() + # 各平台特定的配置 + PLATFORM_CONFIG = { + 'Windows': { + 'socket_timeout': 30, + 'connect_timeout': 10, + 'ssl': None + }, + 'Darwin': { # macOS + 'socket_timeout': 60, + 'connect_timeout': 15, + 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} + }, + 'Linux': { + 'socket_timeout': 60, + 'connect_timeout': 15, + 'ssl': None + } + } - # 初始化加密模块 - self._init_encryption() + def __new__(cls, *args, **kwargs): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance - def _get_config_dir(self) -> str: - """获取适合当前平台的配置目录""" - if self.system == 'windows': - return os.path.join(os.environ['APPDATA'], self.app_name) - elif self.system == 'darwin': # macOS - return os.path.expanduser(f"~/Library/Application Support/{self.app_name}") - else: # Linux及其他Unix-like - return os.path.expanduser(f"~/.config/{self.app_name}") + def __init__(self, config: dict = None): + if hasattr(self, '_pool') and self._pool: + return - def _load_defaults(self): - """加载默认配置""" - self._config = { - "system": { - "log_level": "INFO", - "max_threads": os.cpu_count() or 4 - }, - "api": { - "newsapi": {"endpoint": "https://newsapi.org/v2"}, - "weibo": {"version": "2"} - }, - "paths": { - "data_dir": os.path.join(self.config_dir, "data"), - "cache_dir": os.path.join(self.config_dir, "cache") - } + if not config: + from config.settings import DATABASE_CONFIG + config = DATABASE_CONFIG + + # 获取当前平台配置 + current_platform = platform.system() + platform_config = self.PLATFORM_CONFIG.get(current_platform, {}) + + # 基础配置 + self.config = { + 'host': config.get('host', 'localhost'), + 'port': config.get('port', 3306), + 'user': config.get('user', 'root'), + 'password': config.get('password', ''), + 'database': config.get('database', 'intelligence_system'), + 'charset': config.get('charset', 'utf8mb4'), + 'cursorclass': cursors.DictCursor, + 'autocommit': True, + **platform_config # 合并平台特定配置 } - def _load_env_file(self): - """加载.env环境变量文件""" - env_path = Path(self.config_dir) / ".env" - if env_path.exists(): - dotenv.load_dotenv(env_path) + # 处理各平台路径差异 + if current_platform == 'Windows': + self.config['ssl'] = None # Windows通常不需要SSL配置 - # 环境变量覆盖配置 - if os.getenv("LOG_LEVEL"): - self._config["system"]["log_level"] = os.getenv("LOG_LEVEL") + # macOS特殊处理 + elif current_platform == 'Darwin': + if not os.path.exists(self.config['ssl']['ca']): + self.config['ssl'] = None + logger.warning("macOS SSL certificate not found, disabling SSL") - def _load_user_config(self): - """加载用户自定义配置""" - config_file = Path(self.config_dir) / "config.json" + self.pool_size = config.get('max_connections', 5) + self._pool = self._create_pool() + self.logger = logger.bind(module=f"MySQLAgent({current_platform})") + + def _create_pool(self) -> PooledDB: + """创建跨平台兼容的连接池""" try: - if config_file.exists(): - with open(config_file, 'r', encoding='utf-8') as f: - user_config = json.load(f) - self._deep_update(self._config, user_config) + # 各平台连接池参数调整 + pool_config = { + 'creator': pymysql, + 'maxconnections': self.pool_size, + 'mincached': 1, + 'maxcached': 3, + 'blocking': True, + 'ping': 1, # 定期检查连接有效性 + **self.config + } + + # Windows平台需要更短的超时时间 + if platform.system() == 'Windows': + pool_config['ping'] = 0 # Windows上ping有时不稳定 + + pool = PooledDB(**pool_config) + self.logger.info(f"Connection pool created for {platform.system()}") + return pool + except Exception as e: - print(f"加载用户配置失败: {str(e)}") + self.logger.critical("Failed to create connection pool", + error=str(e), + exc_info=True) + raise - def _init_encryption(self): - """初始化配置加密模块""" - key_file = Path(self.config_dir) / ".secret.key" - if key_file.exists(): - with open(key_file, 'rb') as f: - self._secret_key = f.read() - else: - self._secret_key = Fernet.generate_key() - with open(key_file, 'wb') as f: - f.write(self._secret_key) - key_file.chmod(0o600) # 设置密钥文件权限 + def _handle_path(self, path: str) -> str: + """处理跨平台路径问题""" + if platform.system() == 'Windows': + return path.replace('/', '\\') + return path - def _deep_update(self, original: Dict, update: Dict) -> Dict: - """深度合并字典""" - for key, value in update.items(): - if isinstance(value, dict) and key in original: - original[key] = self._deep_update(original.get(key, {}), value) - else: - original[key] = value - return original - - def get(self, key: str, default: Any = None) -> Any: + def get_connection(self) -> pymysql.connections.Connection: """ - 获取配置项(支持点分路径) - 示例: get("api.newsapi.endpoint") + 获取数据库连接(跨平台兼容) + + Returns: + pymysql.connections.Connection: 数据库连接 + + Raises: + MySQLError: 如果连接失败 """ - keys = key.split('.') - value = self._config try: - for k in keys: - value = value[k] - return value - except (KeyError, TypeError): - return default + conn = self._pool.connection() - def set(self, key: str, value: Any, persist: bool = False): + # macOS需要特殊处理SSL + if platform.system() == 'Darwin' and self.config.get('ssl'): + conn.ping(reconnect=True) + + self.logger.trace("Connection obtained") + return conn + + except Exception as e: + error_msg = str(e) + + # Windows特定错误处理 + if platform.system() == 'Windows' and "timed out" in error_msg: + self.logger.warning("Windows connection timeout, retrying...") + return self._retry_connection() + + self.logger.error("Connection failed", + error=error_msg, + exc_info=True) + raise + + def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection: + """Windows平台连接重试机制""" + for attempt in range(max_retries): + try: + conn = self._pool.connection() + self.logger.info(f"Connection established after {attempt+1} attempts") + return conn + except Exception: + if attempt == max_retries - 1: + raise + import time + time.sleep(1) + + def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None, + parse_dates: Union[List[str], bool] = True) -> pd.DataFrame: """ - 设置配置项 - 参数: - persist: 是否保存到用户配置文件 + 跨平台兼容的SQL查询 + + Args: + sql (str): SQL语句 + params (Union[tuple, dict, None]): 参数 + parse_dates (Union[List[str], bool]): 日期解析 + + Returns: + pd.DataFrame: 查询结果 """ - keys = key.split('.') - config_ref = self._config + try: + with self.get_connection() as conn: + # Linux/macOS需要更长的查询超时 + if platform.system() != 'Windows': + conn.cursor().execute("SET SESSION wait_timeout=600") - for k in keys[:-1]: - if k not in config_ref: - config_ref[k] = {} - config_ref = config_ref[k] + df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates) - config_ref[keys[-1]] = value + # Windows平台需要手动关闭游标 + if platform.system() == 'Windows': + conn.cursor().close() - if persist: - self._save_user_config() + self.logger.info("Query executed", rows=len(df)) + return df - def encrypt_value(self, plaintext: str) -> str: - """加密敏感信息""" - fernet = Fernet(self._secret_key) - return fernet.encrypt(plaintext.encode()).decode() + except Exception as e: + self.logger.error("Query failed", + sql=sql, + params=params, + error=str(e), + exc_info=True) + raise - def decrypt_value(self, ciphertext: str) -> str: - """解密敏感信息""" - fernet = Fernet(self._secret_key) - return fernet.decrypt(ciphertext.encode()).decode() + def insert_from_df(self, table_name: str, df: pd.DataFrame, + chunk_size: int = 1000, replace: bool = False) -> int: + """ + 跨平台数据插入 - def _save_user_config(self): - """保存用户配置到文件""" - config_file = Path(self.config_dir) / "config.json" - with open(config_file, 'w', encoding='utf-8') as f: - json.dump(self._config, f, indent=2, ensure_ascii=False) + Args: + table_name (str): 表名 + df (pd.DataFrame): 数据 + chunk_size (int): 分批大小 + replace (bool): 是否替换 - def reload(self): - """重新加载所有配置""" - self._config = {} - self._load_defaults() - self._load_env_file() - self._load_user_config() + Returns: + int: 插入行数 + """ + if df.empty: + self.logger.warning("Empty DataFrame", table=table_name) + return 0 -# 全局配置实例 -config = ConfigManager() + try: + method = 'replace' if replace else 'append' + total_rows = 0 -# 快捷访问方法(兼容旧代码) -def get_config(key: str, default: Optional[Any] = None) -> Any: - return config.get(key, default) + with self.get_connection() as conn: + # 各平台不同的分批策略 + if platform.system() == 'Windows': + chunk_size = min(chunk_size, 500) # Windows上减小批次 -def set_config(key: str, value: Any, persist: bool = False): - config.set(key, value, persist) + for i in range(0, len(df), chunk_size): + chunk = df.iloc[i:i + chunk_size] -# 测试代码 + # macOS需要特殊处理datetime + if platform.system() == 'Darwin': + for col in chunk.select_dtypes(include=['datetime64']): + chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S') + + chunk.to_sql( + table_name, + conn, + if_exists=method, + index=False, + method='multi' + ) + total_rows += len(chunk) + method = 'append' + + self.logger.info("Data inserted", table=table_name, rows=total_rows) + return total_rows + + except Exception as e: + self.logger.error("Insert failed", + table=table_name, + error=str(e), + exc_info=True) + raise + + def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None, + fetch: bool = False) -> Union[int, List[Dict[str, Any]]]: + """ + 跨平台SQL执行 + + Args: + sql (str): SQL语句 + params (Union[tuple, dict, None]): 参数 + fetch (bool): 是否获取结果 + + Returns: + Union[int, List[Dict[str, Any]]]: 结果 + """ + conn = None + cursor = None + try: + conn = self.get_connection() + cursor = conn.cursor() + + # Linux/macOS需要更长的执行时间 + if platform.system() != 'Windows': + cursor.execute("SET SESSION max_execution_time=600000") + + cursor.execute(sql, params) + + if fetch: + result = cursor.fetchall() + self.logger.debug("Query executed", rows=len(result)) + return result + else: + affected_rows = cursor.rowcount + self.logger.debug("Update executed", affected_rows=affected_rows) + return affected_rows + + except Exception as e: + self.logger.error("SQL execution failed", + sql=sql, + params=params, + error=str(e), + exc_info=True) + raise + finally: + if cursor: + cursor.close() + if conn: + conn.close() + + def begin_transaction(self) -> pymysql.connections.Connection: + """开始事务(跨平台兼容)""" + try: + conn = self.get_connection() + conn.autocommit(False) + + # macOS需要特殊处理事务隔离级别 + if platform.system() == 'Darwin': + conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED") + + self.logger.debug("Transaction started") + return conn + except Exception as e: + self.logger.error("Begin transaction failed", error=str(e)) + raise + + def commit_transaction(self, conn: pymysql.connections.Connection) -> None: + """提交事务(跨平台兼容)""" + try: + conn.commit() + self.logger.debug("Transaction committed") + except Exception as e: + self.logger.error("Commit failed", error=str(e)) + raise + finally: + conn.close() + + def rollback_transaction(self, conn: pymysql.connections.Connection) -> None: + """回滚事务(跨平台兼容)""" + try: + conn.rollback() + self.logger.warning("Transaction rolled back") + except Exception as e: + self.logger.error("Rollback failed", error=str(e)) + finally: + conn.close() + + def __del__(self): + """析构函数(跨平台资源清理)""" + if hasattr(self, '_pool'): + try: + self._pool.close() + self.logger.info("Connection pool closed") + except Exception as e: + self.logger.error("Failed to close pool", error=str(e)) + + +# 平台特定的默认配置 +def get_default_config(): + """获取各平台默认配置""" + current_platform = platform.system() + + base_config = { + 'host': 'localhost', + 'port': 3306, + 'user': 'root', + 'password': '', + 'database': 'intelligence_system', + 'max_connections': 5 + } + + if current_platform == 'Windows': + return { + **base_config, + 'connect_timeout': 10, + 'read_timeout': 30, + 'write_timeout': 30 + } + elif current_platform == 'Darwin': + return { + **base_config, + 'connect_timeout': 15, + 'read_timeout': 60, + 'write_timeout': 60, + 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} + } + else: # Linux和其他平台 + return { + **base_config, + 'connect_timeout': 15, + 'read_timeout': 60, + 'write_timeout': 60 + } + + +# 使用示例 if __name__ == "__main__": - # 设置并保存API密钥(自动加密) - api_key = "your_newsapi_key_here" - encrypted_key = config.encrypt_value(api_key) - config.set("api.newsapi.key", encrypted_key, persist=True) + # 自动获取适合当前平台的配置 + config = get_default_config() - # 获取配置示例 - print(f"日志级别: {config.get('system.log_level')}") - print(f"NewsAPI端点: {config.get('api.newsapi.endpoint')}") + # 初始化数据库连接 + db = MySQLAgent(config) - # 解密敏感信息 - stored_key = config.get("api.newsapi.key") - if stored_key: - print(f"解密后的API密钥: {config.decrypt_value(stored_key)}") \ No newline at end of file + # 测试查询 + try: + df = db.query_to_df("SELECT VERSION() as version") + print(f"Database version: {df['version'].iloc[0]}") + print(f"Running on: {platform.system()} {platform.release()}") + except Exception as e: + print(f"Error: {str(e)}") \ No newline at end of file diff --git a/main.py b/main.py index f2da81c..c78525e 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ from typing import Dict, List, Any # 自定义模块 from processors.data_processor import DataProcessor -from storage.database import IntelligenceDB +from storage.mysql_agent import IntelligenceDB from applications.reporter import ReportGenerator from applications.alert import AlertService from utils.logger import setup_logging diff --git a/storage/database.py b/storage/database.py deleted file mode 100644 index db151ff..0000000 --- a/storage/database.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -数据库存储模块 -功能: -1. 统一数据库接口(SQLite/MySQL/PostgreSQL) -2. 自动处理多平台路径问题 -3. 连接池管理 -4. 数据加密存储 -""" - -import os -import platform -import sqlite3 -import threading -from pathlib import Path -from typing import Optional, Union, Dict, Any, List -from threading import Lock -import logging -from cryptography.fernet import Fernet - -# 类型别名 -QueryParams = Union[tuple, Dict[str, Any]] - - -class DatabaseManager: - """数据库统一管理类""" - - def __init__(self, db_config: Dict[str, Any]): - """ - 初始化数据库连接 - - 参数: - db_config: 配置字典,包含: - - type: 'sqlite'|'mysql'|'postgresql' - - database: 数据库名/路径 - - [可选] host, port, user, password - """ - self.config = db_config - self._lock = Lock() - self._connection_pool = {} - self._setup_crypto() - - # 自动创建SQLite目录 - if db_config['type'] == 'sqlite': - self._ensure_sqlite_dir() - - def _ensure_sqlite_dir(self): - """确保SQLite数据库目录存在""" - db_path = Path(self.config['database']) - if not db_path.parent.exists(): - try: - db_path.parent.mkdir(parents=True, mode=0o755) - except Exception as e: - logging.error(f"创建数据库目录失败: {str(e)}") - - def _setup_crypto(self): - """初始化加密模块""" - key_file = Path(os.path.expanduser("~/.db_encryption.key")) - if key_file.exists(): - with open(key_file, 'rb') as f: - self._fernet = Fernet(f.read()) - else: - self._fernet = Fernet.generate_key() - with open(key_file, 'wb') as f: - f.write(self._fernet) - key_file.chmod(0o600) # 仅限当前用户读写 - - def get_connection(self, reuse=True): - """ - 获取数据库连接(线程安全) - - 参数: - reuse: 是否复用现有连接(默认True) - """ - thread_id = threading.get_ident() - - with self._lock: - if reuse and thread_id in self._connection_pool: - conn = self._connection_pool[thread_id] - try: - # 检查连接是否有效 - conn.execute("SELECT 1") - return conn - except: - del self._connection_pool[thread_id] - - # 创建新连接 - if self.config['type'] == 'sqlite': - conn = self._create_sqlite_connection() - elif self.config['type'] == 'mysql': - conn = self._create_mysql_connection() - elif self.config['type'] == 'postgresql': - conn = self._create_pg_connection() - else: - raise ValueError("不支持的数据库类型") - - self._connection_pool[thread_id] = conn - return conn - - def _create_sqlite_connection(self) -> sqlite3.Connection: - """创建SQLite连接(兼容多平台路径)""" - db_path = self.config['database'] - - # Windows路径处理 - if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')): - db_path = os.path.abspath(db_path) - - conn = sqlite3.connect(db_path, timeout=15) - conn.execute("PRAGMA journal_mode=WAL") # 写前日志提升并发 - conn.execute("PRAGMA synchronous=NORMAL") - conn.row_factory = sqlite3.Row # 支持字典式访问 - return conn - - def _create_mysql_connection(self): - """创建MySQL连接(需安装PyMySQL)""" - import pymysql - return pymysql.connect( - host=self.config.get('host', 'localhost'), - port=self.config.get('port', 3306), - user=self.config.get('user', 'root'), - password=self.config.get('password', ''), - database=self.config['database'], - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor - ) - - def _create_pg_connection(self): - """创建PostgreSQL连接(需安装psycopg2)""" - import psycopg2 - return psycopg2.connect( - host=self.config.get('host', 'localhost'), - port=self.config.get('port', 5432), - user=self.config.get('user', 'postgres'), - password=self.config.get('password', ''), - dbname=self.config['database'] - ) - - def execute( - self, - query: str, - params: Optional[QueryParams] = None, - return_lastrowid: bool = False - ) -> Union[int, None]: - conn = self.get_connection() - try: - with conn: - cursor = conn.cursor() - if params is not None: - cursor.execute(query, params) - else: - cursor.execute(query) # ✅ 无参数时不要传 params - return cursor.lastrowid if return_lastrowid else None - finally: - if not return_lastrowid: - self._release_connection() - - def query( - self, - query: str, - params: Optional[QueryParams] = None, - fetchall: bool = True - ) -> Union[List[Dict], Dict]: - conn = self.get_connection() - try: - cursor = conn.cursor() - if params is not None: - cursor.execute(query, params) - else: - cursor.execute(query) # ✅ 无参数时不要传 None - - if self.config['type'] == 'sqlite': - result = cursor.fetchall() - if not fetchall and result: - return dict(result[0]) - return [dict(row) for row in result] - else: - return cursor.fetchall() if fetchall else cursor.fetchone() - finally: - self._release_connection() - - def _release_connection(self): - """释放当前线程的连接(SQLite除外)""" - if self.config['type'] != 'sqlite': - thread_id = threading.get_ident() - with self._lock: - if thread_id in self._connection_pool: - self._connection_pool[thread_id].close() - del self._connection_pool[thread_id] - - def encrypt_data(self, plaintext: str) -> str: - """加密敏感数据""" - return self._fernet.encrypt(plaintext.encode()).decode() - - def decrypt_data(self, ciphertext: str) -> str: - """解密数据""" - return self._fernet.decrypt(ciphertext.encode()).decode() - - def close_all(self): - """关闭所有数据库连接""" - with self._lock: - for conn in self._connection_pool.values(): - try: - conn.close() - except: - pass - self._connection_pool.clear() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close_all() - - -# 全局SQLite实例(默认配置) -def get_default_db() -> DatabaseManager: - """获取默认SQLite数据库(跨平台路径处理)""" - system = platform.system().lower() - if system == 'windows': - db_path = os.path.join(os.getenv('APPDATA'), 'app_name/data.db') - elif system == 'darwin': - db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db') - else: - db_path = '/var/lib/app_name/data.db' if os.access('/var/lib', os.W_OK) \ - else os.path.expanduser('~/.local/share/app_name/data.db') - - return DatabaseManager({ - 'type': 'sqlite', - 'database': db_path - }) - - -# 测试代码 -if __name__ == "__main__": - with get_default_db() as db: - # 创建测试表 - db.execute(""" - CREATE TABLE IF NOT EXISTS test_table ( - id INTEGER PRIMARY KEY, - name TEXT, - secret TEXT - ) - """) - - # 插入加密数据 - secret = db.encrypt_data("新敏感信息") - db.execute( - "INSERT INTO test_table (name, secret) VALUES (?, ?)", - ("测试记录", secret) - ) - - # 查询并解密 - row = db.query("SELECT * FROM test_table", fetchall=False) - print(f"解密数据: {db.decrypt_data(row['name'])}") diff --git a/storage/mysql_agent.py b/storage/mysql_agent.py new file mode 100644 index 0000000..65c4119 --- /dev/null +++ b/storage/mysql_agent.py @@ -0,0 +1,662 @@ +import os +import sys +import platform +import pandas as pd +import pymysql +from pymysql import cursors +from pymysql.err import MySQLError +from dbutils.pooled_db import PooledDB +from typing import Union, List, Dict, Any, Optional, Tuple +import threading +from datetime import datetime +import numpy as np +from pathlib import Path + +# 导入日志系统 +from utils.logger import log + + +class MySQLAgent: + """ + 全平台兼容的MySQL数据库操作类 + 支持Windows/macOS/Linux系统 + 配置参数从外部传入 + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, config: dict): + """ + 初始化MySQL数据库连接 + + Args: + config (dict): 数据库配置字典,包含以下键: + - host: 数据库主机 + - port: 端口 + - user: 用户名 + - password: 密码 + - database: 数据库名 + - [可选] charset: 字符集(默认utf8mb4) + - [可选] max_connections: 最大连接数(默认5) + - [可选] connect_timeout: 连接超时(秒) + - [可选] read_timeout: 读取超时(秒) + - [可选] write_timeout: 写入超时(秒) + - [可选] ssl: SSL配置 + """ + if hasattr(self, '_pool') and self._pool: + return + + # 基础配置 + required_keys = ['host', 'port', 'user', 'password', 'database'] + if not all(key in config for key in required_keys): + raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}") + + self.config = { + 'host': config['host'], + 'port': config['port'], + 'user': config['user'], + 'password': config['password'], + 'database': config['database'], + 'charset': config.get('charset', 'utf8mb4'), + 'cursorclass': cursors.DictCursor, + 'autocommit': True, + 'connect_timeout': config.get('connect_timeout', 10), + 'read_timeout': config.get('read_timeout', 30), + 'write_timeout': config.get('write_timeout', 30), + 'ssl': config.get('ssl') + } + + # 初始化log + current_platform = platform.system() + self.log = log.bind(module=f"MySQLAgent({current_platform})") + + # 创建连接池 + self.pool_size = config.get('max_connections', 5) + self._pool = self._create_pool() + + def _create_pool(self) -> PooledDB: + """创建连接池""" + try: + # 使用包装函数确保线程安全 + def connect(): + conn = pymysql.connect(**self.config) + conn.threadsafety = 1 # 显式设置线程安全级别 + return conn + + pool = PooledDB( + creator=connect, + mincached=1, + maxcached=3, + maxconnections=self.pool_size, + blocking=True, + ping=1 + ) + + self.log.info("Connection pool created") + return pool + + except Exception as e: + self.log.critical("Failed to create connection pool", + error=str(e), + exc_info=True) + raise + + def get_connection(self) -> pymysql.connections.Connection: + """ + 获取数据库连接 + + Returns: + pymysql.connections.Connection: 数据库连接对象 + + Raises: + MySQLError: 如果获取连接失败 + """ + try: + conn = self._pool.connection() + + # macOS需要特殊处理SSL + if platform.system() == 'Darwin' and self.config.get('ssl'): + conn.ping(reconnect=True) + + self.log.trace("Database connection obtained") + return conn + + except Exception as e: + error_msg = str(e) + + # Windows特定错误处理 + if platform.system() == 'Windows' and "timed out" in error_msg: + self.log.warning("Windows connection timeout, retrying...") + return self._retry_connection() + + self.log.error("Connection failed", + error=error_msg, + exc_info=True) + raise + + def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection: + """Windows平台连接重试机制""" + for attempt in range(max_retries): + try: + conn = self._pool.connection() + self.log.info(f"Connection established after {attempt + 1} attempts") + return conn + except Exception: + if attempt == max_retries - 1: + raise + import time + time.sleep(1) + + def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None, + parse_dates: Union[List[str], bool] = True) -> pd.DataFrame: + """ + 执行SQL查询并返回DataFrame + + Args: + sql (str): SQL查询语句 + params (Union[tuple, dict, None]): 查询参数 + parse_dates (Union[List[str], bool]): 自动解析日期字段 + + Returns: + pd.DataFrame: 查询结果 + + Raises: + MySQLError: 如果查询失败 + """ + try: + self.log.debug("Executing SQL query", sql=sql) + + with self.get_connection() as conn: + # Linux/macOS需要更长的查询超时 + if platform.system() != 'Windows': + conn.cursor().execute("SET SESSION wait_timeout=600") + + df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates) + + # Windows平台需要手动关闭游标 + if platform.system() == 'Windows': + conn.cursor().close() + + self.log.info("Query executed successfully", rows=len(df)) + return df + + except Exception as e: + self.log.error("SQL query failed", + sql=sql, + params=params, + error=str(e), + exc_info=True) + raise + + def insert_from_df(self, table_name: str, df: pd.DataFrame, + chunk_size: int = 1000, replace: bool = False) -> int: + """ + 将DataFrame数据插入到数据库表 + + Args: + table_name (str): 目标表名 + df (pd.DataFrame): 要插入的数据 + chunk_size (int): 分批插入大小 + replace (bool): 是否替换现有数据 + + Returns: + int: 插入的总行数 + + Raises: + MySQLError: 如果插入失败 + """ + if df.empty: + self.log.warning("Attempted to insert empty DataFrame", table=table_name) + return 0 + + self.log.debug("Preparing to insert DataFrame", + table=table_name, + rows=len(df), + chunk_size=chunk_size) + + try: + method = 'replace' if replace else 'append' + total_rows = 0 + + with self.get_connection() as conn: + # 各平台不同的分批策略 + if platform.system() == 'Windows': + chunk_size = min(chunk_size, 500) # Windows上减小批次 + + for i in range(0, len(df), chunk_size): + chunk = df.iloc[i:i + chunk_size] + + # macOS需要特殊处理datetime + if platform.system() == 'Darwin': + for col in chunk.select_dtypes(include=['datetime64']): + chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S') + + chunk.to_sql( + table_name, + conn, + if_exists=method, + index=False, + method='multi' + ) + total_rows += len(chunk) + method = 'append' # 第一次之后都使用追加模式 + self.log.trace(f"Inserted chunk {i // chunk_size + 1}", + rows=len(chunk), + total_inserted=total_rows) + + self.log.info("Data inserted successfully", + table=table_name, + total_rows=total_rows) + return total_rows + + except Exception as e: + self.log.error("Data insertion failed", + table=table_name, + error=str(e), + exc_info=True) + raise + + def update_from_df(self, table_name: str, df: pd.DataFrame, + key_columns: Union[str, List[str]]) -> int: + """ + 使用DataFrame数据更新数据库表 + + Args: + table_name (str): 目标表名 + df (pd.DataFrame): 包含更新数据 + key_columns (Union[str, List[str]]): 用于匹配记录的关键列 + + Returns: + int: 更新的总行数 + + Raises: + MySQLError: 如果更新失败 + """ + if df.empty: + self.log.warning("Attempted to update with empty DataFrame", table=table_name) + return 0 + + self.log.debug("Preparing to update table from DataFrame", + table=table_name, + key_columns=key_columns, + rows=len(df)) + + try: + if isinstance(key_columns, str): + key_columns = [key_columns] + + total_updated = 0 + conn = self.begin_transaction() + + try: + cursor = conn.cursor() + + # 获取表结构信息 + table_info = self._get_table_info(table_name) + columns = [col for col in df.columns if col in table_info] + + # 构建UPDATE语句模板 + set_clause = ', '.join([f"{col}=%s" for col in columns if col not in key_columns]) + where_clause = ' AND '.join([f"{col}=%s" for col in key_columns]) + + update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}" + self.log.trace("Generated update SQL", sql=update_sql) + + # 准备数据 + update_data = [] + for _, row in df.iterrows(): + # SET部分的值 + set_values = [row[col] for col in columns if col not in key_columns] + # WHERE部分的值 + key_values = [row[col] for col in key_columns] + update_data.append(tuple(set_values + key_values)) + + # 执行批量更新 + cursor.executemany(update_sql, update_data) + total_updated = cursor.rowcount + + self.commit_transaction(conn) + self.log.info("Data updated successfully", + table=table_name, + rows_updated=total_updated) + return total_updated + + except Exception as e: + self.rollback_transaction(conn) + raise + + except Exception as e: + self.log.error("Data update failed", + table=table_name, + error=str(e), + exc_info=True) + raise + + def _get_table_info(self, table_name: str) -> Dict[str, str]: + """ + 获取表结构信息 + + Args: + table_name (str): 表名 + + Returns: + Dict[str, str]: 列名到类型的映射 + + Raises: + MySQLError: 如果查询失败 + """ + sql = f""" + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_schema = %s AND table_name = %s + """ + + params = (self.config['database'], table_name) + + try: + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, params) + result = cursor.fetchall() + return {row['column_name']: row['data_type'] for row in result} + + except Exception as e: + self.log.error("Failed to get table info", + table=table_name, + error=str(e)) + raise + + def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]: + """ + 推断DataFrame各列的SQL类型 + + Args: + df (pd.DataFrame): 输入数据框 + + Returns: + Dict[str, str]: 列名到SQL类型的映射 + """ + type_mapping = { + 'int64': 'BIGINT', + 'float64': 'DOUBLE', + 'datetime64[ns]': 'DATETIME', + 'object': 'TEXT', + 'bool': 'TINYINT(1)', + 'category': 'VARCHAR(255)' + } + + sql_types = {} + for col, dtype in df.dtypes.items(): + dtype_str = str(dtype) + sql_types[col] = type_mapping.get(dtype_str, 'TEXT') + + self.log.debug("Mapped DataFrame types to SQL types", + mappings=sql_types) + return sql_types + + def create_table_from_df(self, table_name: str, df: pd.DataFrame, + primary_key: Union[str, List[str], None] = None) -> bool: + """ + 根据DataFrame结构创建表 + + Args: + table_name (str): 表名 + df (pd.DataFrame): 参考数据框 + primary_key (Union[str, List[str], None]): 主键列 + + Returns: + bool: 是否创建成功 + """ + if self.table_exists(table_name): + self.log.warning("Table already exists", table=table_name) + return False + + self.log.debug("Creating new table from DataFrame schema", + table=table_name, + columns=list(df.columns)) + + try: + sql_types = self.df_to_sql_type(df) + columns_sql = [] + + for col, sql_type in sql_types.items(): + col_def = f"{col} {sql_type}" + columns_sql.append(col_def) + + if primary_key: + if isinstance(primary_key, str): + primary_key = [primary_key] + pk_columns = [col for col in primary_key if col in sql_types] + if pk_columns: + columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})") + self.log.trace("Set primary key", + table=table_name, + primary_key=pk_columns) + + create_sql = f"CREATE TABLE {table_name} (\n {',\n '.join(columns_sql)}\n)" + + self.execute_sql(create_sql) + self.log.info("Table created successfully", table=table_name) + return True + + except Exception as e: + self.log.error("Failed to create table", + table=table_name, + error=str(e), + exc_info=True) + return False + + def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None, + fetch: bool = False) -> Union[int, List[Dict[str, Any]]]: + """ + 执行SQL语句 + + Args: + sql (str): SQL语句 + params (Union[tuple, dict, None]): 参数 + fetch (bool): 是否获取结果 + + Returns: + Union[int, List[Dict[str, Any]]]: + - 如果是INSERT/UPDATE/DELETE,返回影响的行数 + - 如果是SELECT且fetch=True,返回结果列表 + """ + conn = None + cursor = None + try: + conn = self.get_connection() + cursor = conn.cursor() + + # Linux/macOS需要更长的执行时间 + if platform.system() != 'Windows': + cursor.execute("SET SESSION max_execution_time=600000") + + cursor.execute(sql, params) + + if fetch: + result = cursor.fetchall() + self.log.debug("Query executed", rows=len(result)) + return result + else: + affected_rows = cursor.rowcount + self.log.debug("Update executed", affected_rows=affected_rows) + return affected_rows + + except Exception as e: + self.log.error("SQL execution failed", + sql=sql, + params=params, + error=str(e), + exc_info=True) + raise + finally: + if cursor: + cursor.close() + if conn: + conn.close() + + def begin_transaction(self) -> pymysql.connections.Connection: + """开始事务""" + try: + conn = self.get_connection() + conn.autocommit(False) + + # macOS需要特殊处理事务隔离级别 + if platform.system() == 'Darwin': + conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED") + + self.log.debug("Transaction started") + return conn + except Exception as e: + self.log.error("Begin transaction failed", error=str(e), exc_info=True) + raise + + def commit_transaction(self, conn: pymysql.connections.Connection) -> None: + """提交事务""" + try: + conn.commit() + self.log.debug("Transaction committed") + except Exception as e: + self.log.error("Commit failed", error=str(e)) + raise + finally: + conn.close() + + def rollback_transaction(self, conn: pymysql.connections.Connection) -> None: + """回滚事务""" + try: + conn.rollback() + self.log.warning("Transaction rolled back") + except Exception as e: + self.log.error("Rollback failed", error=str(e)) + finally: + conn.close() + + def table_exists(self, table_name: str) -> bool: + """检查表是否存在""" + sql = """ + SELECT COUNT(*) as count + FROM `information_schema`.`tables` + WHERE `table_schema` = %s AND `table_name` = %s + """ + + params = (self.config['database'], table_name) + + try: + result = self.execute_sql(sql, params, fetch=True) + exists = result[0]['count'] > 0 + self.log.debug("Checked table existence", + table=table_name, + exists=exists) + return exists + except Exception: + return False + + def drop_table(self, table_name: str) -> bool: + """删除表""" + if not self.table_exists(table_name): + self.log.warning("Table does not exist", table=table_name) + return False + + try: + self.execute_sql(f"DROP TABLE {table_name}") + self.log.info("Table dropped successfully", table=table_name) + return True + except Exception as e: + self.log.error("Failed to drop table", + table=table_name, + error=str(e), + exc_info=True) + return False + + def get_pool_status(self) -> Dict[str, int]: + """获取连接池状态""" + return { + 'max': self._pool._maxconnections, + 'active': self._pool._connections, + 'idle': len(self._pool._idle_cache), + 'shared': len(self._pool._shared_cache) + } + + def validate_connection(self) -> bool: + """验证连接是否有效""" + try: + with self.get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + return cursor.fetchone()[0] == 1 + except Exception: + return False + + def __del__(self): + """析构函数""" + if hasattr(self, '_pool'): + try: + self._pool.close() + self.log.info("Connection pool closed") + except Exception as e: + self.log.error("Failed to close pool", error=str(e)) + + +# 平台特定的默认配置 +def get_default_config(): + """获取各平台默认配置""" + current_platform = platform.system() + + base_config = { + 'host': 'localhost', + 'port': 3306, + 'user': 'root', + 'password': '123123', + 'database': 'intelligence_system', + 'max_connections': 5 + } + + if current_platform == 'Windows': + return { + **base_config, + 'connect_timeout': 10, + 'read_timeout': 30, + 'write_timeout': 30 + } + elif current_platform == 'Darwin': + return { + **base_config, + 'connect_timeout': 15, + 'read_timeout': 60, + 'write_timeout': 60, + 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} + } + else: # Linux和其他平台 + return { + **base_config, + 'connect_timeout': 15, + 'read_timeout': 60, + 'write_timeout': 60 + } + + +if __name__ == "__main__": + # 使用示例 + db = MySQLAgent(get_default_config()) + + # 测试连接 + if db.validate_connection(): + print("Database connection successful") + + # 获取数据库版本 + version = db.query_to_df("SELECT VERSION() as version") + print(f"Database version: {version['version'].iloc[0]}") + + # 查看连接池状态 + print("Connection pool status:", db.get_pool_status()) + else: + print("Failed to connect to database") diff --git a/test/subdir/test.json b/test/subdir/test.json deleted file mode 100644 index 449d36b..0000000 --- a/test/subdir/test.json +++ /dev/null @@ -1 +0,0 @@ -{"a":{"0":1},"b":{"0":2}} \ No newline at end of file diff --git a/test/数据库链接测试.py b/test/数据库链接测试.py new file mode 100644 index 0000000..a4f8991 --- /dev/null +++ b/test/数据库链接测试.py @@ -0,0 +1,291 @@ +import unittest +import pandas as pd +from datetime import datetime +import tempfile +import time +import pymysql +from storage.mysql_agent import MySQLAgent +import platform + +class TestMySQLAgent(unittest.TestCase): + @classmethod + def setUpClass(cls): + """初始化测试环境和测试表""" + # 创建唯一的测试数据库名 + cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S") + cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S") + + # 基础配置 + cls.base_config = { + 'host': 'localhost', + 'port': 3306, + 'user': 'root', + 'password': '123123', + 'max_connections': 10 + } + + # 创建测试数据库 + cls._create_test_database() + + # 初始化数据库连接 + cls.db = MySQLAgent({ + **cls.base_config, + 'database': cls.test_db_name + }) + + # 创建测试表 + test_data = pd.DataFrame({ + 'id': [1, 2, 3], + 'name': ['Test1', 'Test2', 'Test3'], + 'value': [10.5, 20.3, 30.8], + 'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03']) + }) + + cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id') + cls.db.insert_from_df(cls.test_table, test_data) + + @classmethod + def _create_test_database(cls): + """创建测试数据库""" + # 使用临时连接创建数据库 + temp_conn = pymysql.connect( + host=cls.base_config['host'], + port=cls.base_config['port'], + user=cls.base_config['user'], + password=cls.base_config['password'], + charset='utf8mb4' + ) + + try: + with temp_conn.cursor() as cursor: + cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}") + cursor.execute(f"USE {cls.test_db_name}") + cursor.execute("SET GLOBAL max_connections = 100") + temp_conn.commit() + finally: + temp_conn.close() + + @classmethod + def tearDownClass(cls): + """清理测试数据库""" + if hasattr(cls, 'db') and cls.db: + # 删除测试表 + if cls.db.table_exists(cls.test_table): + cls.db.drop_table(cls.test_table) + + # 删除测试数据库 + temp_conn = pymysql.connect( + host=cls.base_config['host'], + port=cls.base_config['port'], + user=cls.base_config['user'], + password=cls.base_config['password'], + charset='utf8mb4' + ) + + try: + with temp_conn.cursor() as cursor: + cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}") + temp_conn.commit() + finally: + temp_conn.close() + + def test_01_connection(self): + """测试数据库连接""" + version = self.db.query_to_df("SELECT VERSION() as version") + self.assertIsNotNone(version) + print(f"\nDatabase version: {version['version'].iloc[0]}") + print(f"Running on: {platform.system()} {platform.release()}") + + def test_02_query_to_df(self): + """测试查询返回DataFrame""" + df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,)) + self.assertEqual(len(df), 2) + self.assertIsInstance(df, pd.DataFrame) + print("\nQuery result sample:") + print(df.head()) + + def test_03_insert_from_df(self): + """测试DataFrame插入""" + new_data = pd.DataFrame({ + 'id': [4, 5], + 'name': ['Test4', 'Test5'], + 'value': [40.1, 50.2], + 'created_at': pd.to_datetime(['2023-01-04', '2023-01-05']) + }) + + rows = self.db.insert_from_df(self.test_table, new_data) + self.assertEqual(rows, 2) + + # 验证数据 + df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4") + self.assertEqual(len(df), 2) + self.assertEqual(df['name'].tolist(), ['Test4', 'Test5']) + + def test_04_update_from_df(self): + """测试DataFrame更新""" + update_data = pd.DataFrame({ + 'id': [1, 2], + 'name': ['Updated1', 'Updated2'] + }) + + rows = self.db.update_from_df(self.test_table, update_data, 'id') + self.assertGreaterEqual(rows, 2) + + # 验证更新 + df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)") + self.assertIn('Updated1', df['name'].values) + self.assertIn('Updated2', df['name'].values) + + def test_05_transaction(self): + """测试事务处理""" + conn = self.db.begin_transaction() + try: + # 执行多个操作 + cursor = conn.cursor() + cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1") + cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2") + + # 验证事务内修改 + cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1") + self.assertEqual(cursor.fetchone()['value'], 99.9) + + self.db.commit_transaction(conn) + except Exception: + self.db.rollback_transaction(conn) + raise + + # 验证提交后的修改 + df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)") + self.assertIn(99.9, df['value'].values) + self.assertIn(88.8, df['value'].values) + + def test_06_large_data(self): + """测试大数据量操作""" + # 生成测试数据 + large_data = pd.DataFrame({ + 'id': range(1000, 2000), + 'name': [f"Item_{i}" for i in range(1000, 2000)], + 'value': [i * 0.1 for i in range(1000, 2000)], + 'created_at': pd.date_range('2023-01-01', periods=1000) + }) + + # Windows平台使用更小的批次 + chunk_size = 100 if platform.system() == 'Windows' else 500 + + start_time = time.time() + rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size) + elapsed = time.time() - start_time + + self.assertEqual(rows, 1000) + print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})") + + # 验证数据 + df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000") + self.assertEqual(df['cnt'].iloc[0], 1000) + + def test_07_concurrent_access(self): + """测试并发访问""" + from concurrent.futures import ThreadPoolExecutor + + def worker(i): + df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,)) + return len(df) + + start_time = time.time() + with ThreadPoolExecutor(max_workers=20) as executor: + results = list(executor.map(worker, range(100))) + + elapsed = time.time() - start_time + self.assertEqual(sum(results), 100) + print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s") + + +class TestPlatformSpecific(unittest.TestCase): + @classmethod + def setUpClass(cls): + """创建临时测试数据库""" + cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S") + cls.base_config = { + 'host': 'localhost', + 'port': 3306, + 'user': 'root', + 'password': '123123', + 'max_connections': 10 + } + + # 创建数据库 + temp_conn = pymysql.connect( + host=cls.base_config['host'], + port=cls.base_config['port'], + user=cls.base_config['user'], + password=cls.base_config['password'], + charset='utf8mb4' + ) + + try: + with temp_conn.cursor() as cursor: + cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}") + temp_conn.commit() + finally: + temp_conn.close() + + @classmethod + def tearDownClass(cls): + """删除临时测试数据库""" + temp_conn = pymysql.connect( + host=cls.base_config['host'], + port=cls.base_config['port'], + user=cls.base_config['user'], + password=cls.base_config['password'], + charset='utf8mb4' + ) + + try: + with temp_conn.cursor() as cursor: + cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}") + temp_conn.commit() + finally: + temp_conn.close() + + def test_windows_timeout(self): + """测试Windows平台超时处理""" + if platform.system() != 'Windows': + self.skipTest("Only runs on Windows") + + config = { + **self.base_config, + 'database': self.test_db_name, + 'connect_timeout': 1, + 'read_timeout': 1 + } + + db = MySQLAgent(config) + + # 测试短超时查询 + start_time = time.time() + try: + db.query_to_df("SELECT SLEEP(2)") + self.fail("Should have timed out") + except Exception as e: + self.assertIn("timed out", str(e)) + print(f"\nWindows timeout test: {str(e)}") + + def test_macos_ssl(self): + """测试macOS SSL连接""" + if platform.system() != 'Darwin': + self.skipTest("Only runs on macOS") + + config = { + **self.base_config, + 'database': self.test_db_name, + 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} + } + + db = MySQLAgent(config) + version = db.query_to_df("SELECT VERSION() as version") + self.assertIsNotNone(version) + print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/通用文件读取测试.py b/test/通用文件读取测试.py index 0302bac..6167711 100644 --- a/test/通用文件读取测试.py +++ b/test/通用文件读取测试.py @@ -3,6 +3,8 @@ import pandas as pd import os from pathlib import Path from utils.file_handler import FileHandler +from datetime import datetime + @pytest.fixture def temp_dir(tmp_path): @@ -11,11 +13,13 @@ def temp_dir(tmp_path): test_dir.mkdir() return test_dir + @pytest.fixture def file_handler(temp_dir): """创建FileHandler实例""" return FileHandler(temp_dir) + @pytest.fixture def sample_dataframe(): """创建测试用DataFrame""" @@ -25,6 +29,7 @@ def sample_dataframe(): 'value': [10.5, 20.3, 30.1] }) + @pytest.fixture def sample_text_file(temp_dir): """创建测试文本文件""" @@ -55,30 +60,33 @@ def test_read_write_csv(file_handler, temp_dir, sample_dataframe): assert df.shape == (3, 3) assert list(df.columns) == ['id', 'name', 'value'] + def test_read_write_json(file_handler, temp_dir, sample_dataframe): """测试JSON文件读写""" test_file = temp_dir / "test.json" # 测试写入 write_result = file_handler.write_file(test_file, sample_dataframe) - assert write_result.iloc[0]['success'] == True + assert write_result.iloc[0]['success'] == True # 测试读取 df = file_handler.read_file(test_file) assert df.shape == (3, 3) + def test_read_write_excel(file_handler, temp_dir, sample_dataframe): """测试Excel文件读写""" test_file = temp_dir / "test.xlsx" # 测试写入 write_result = file_handler.write_file(test_file, sample_dataframe) - assert write_result.iloc[0]['success'] == True + assert write_result.iloc[0]['success'] == True # 测试读取 df = file_handler.read_file(test_file) assert df.shape == (3, 3) + def test_read_write_csv(file_handler, temp_dir, sample_dataframe): """测试CSV文件读写""" test_file = temp_dir / "test.csv" @@ -119,6 +127,7 @@ def test_file_operations(file_handler, sample_text_file): assert delete_df.iloc[0]['deleted'] == True assert not os.path.exists(sample_text_file) + def test_directory_operations(file_handler, temp_dir): """测试目录操作""" test_dir = temp_dir / "subdir" @@ -160,6 +169,7 @@ def test_zip_operations(file_handler, temp_dir, sample_dataframe): assert os.path.exists(extract_dir / "file1.txt") assert os.path.exists(extract_dir / "file2.csv") + def test_zip_directory(file_handler, temp_dir): """测试目录压缩""" # 创建测试目录结构 @@ -174,4 +184,4 @@ def test_zip_directory(file_handler, temp_dir): zip_path = temp_dir / "dir.zip" zip_result = file_handler.zip_dir(test_dir, zip_path) assert zip_result.iloc[0]['zipped'] == True - assert zip_result.iloc[0]['file_count'] == 2 \ No newline at end of file + assert zip_result.iloc[0]['file_count'] == 2 diff --git a/utils/file_handler.py b/utils/file_handler.py index e668749..71accdb 100644 --- a/utils/file_handler.py +++ b/utils/file_handler.py @@ -433,7 +433,9 @@ class FileHandler: # ---------------------------- 测试用例 ---------------------------- if __name__ == "__main__": # 初始化处理器(自动处理跨平台路径) - handler = FileHandler("test_data") + project_root = next(p for p in Path(__file__).resolve().parents if + (p / '.git').exists() or (p / 'pyproject.toml').exists() or (p / 'requirements.txt').exists()) + handler = FileHandler(project_root / "test") # 测试路径标准化 test_paths = [ diff --git a/utils/logger.py b/utils/logger.py index e148554..284cc15 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -6,6 +6,7 @@ import platform from datetime import datetime import zipfile + class CrossPlatformLog: """跨平台日志系统(支持Linux/Windows/Mac)""" @@ -94,5 +95,6 @@ class CrossPlatformLog: """获取模块专属日志器""" return logger.bind(module=module_name or "__main__") + # 初始化全局日志器 -log = CrossPlatformLog().get_logger() \ No newline at end of file +log = CrossPlatformLog().get_logger() diff --git a/utils/network.py b/utils/network.py deleted file mode 100644 index 66d4373..0000000 --- a/utils/network.py +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -网络工具模块 -功能: -1. 自动重试的HTTP请求 -2. 多平台代理配置支持 -3. DNS缓存优化 -4. 连接超时与SSL验证 -5. 用户代理轮换 -""" - -import socket -import time -import random -import platform -from typing import Optional, Dict, Any, Union -from urllib.parse import urlparse -from functools import lru_cache -import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry -from urllib3.connection import HTTPConnection -import os - -# 类型别名 -TimeoutType = Union[float, tuple] - - -class NetworkUtils: - """跨平台网络操作工具类""" - - def __init__(self): - self.system = platform.system().lower() - self._dns_cache = {} - self._setup_platform_specific() - - def _setup_platform_specific(self): - """平台相关初始化""" - if self.system == 'windows': - # Windows默认关闭TCP_NODELAY - HTTPConnection.default_socket_options = ( - HTTPConnection.default_socket_options + [ - (socket.IPPROTO_TCP, socket.TCP_NODELAY, 0) - ] - ) - elif self.system == 'linux': - # Linux启用TCP快速打开 - HTTPConnection.default_socket_options = ( - HTTPConnection.default_socket_options + [ - (socket.IPPROTO_TCP, socket.TCP_FASTOPEN, 5) - ] - ) - - @staticmethod - @lru_cache(maxsize=512) - def _resolve_hostname(hostname: str) -> str: - """DNS缓存解析(跨线程安全)""" - try: - return socket.gethostbyname(hostname) - except socket.gaierror: - return hostname # 失败时返回原始域名 - - def get_session( - self, - retries: int = 3, - backoff_factor: float = 0.5, - timeout: TimeoutType = (3.05, 30), - proxy: Optional[str] = None, - verify_ssl: bool = True - ) -> requests.Session: - """ - 获取配置好的请求会话 - - 参数: - retries: 重试次数 - backoff_factor: 重试间隔系数 - timeout: (连接超时, 读取超时)秒数 - proxy: 代理地址(如 'http://user:pass@proxy:port') - verify_ssl: 是否验证SSL证书 - """ - session = requests.Session() - - # 重试策略 - retry = Retry( - total=retries, - backoff_factor=backoff_factor, - status_forcelist=[500, 502, 503, 504], - allowed_methods=frozenset(['GET', 'POST', 'PUT', 'DELETE']) - ) - - # 适配器配置 - adapter = HTTPAdapter( - max_retries=retry, - pool_connections=20, - pool_maxsize=100 - ) - - # 挂载适配器 - session.mount('http://', adapter) - session.mount('https://', adapter) - - # 代理配置 - if proxy: - session.proxies = { - 'http': proxy, - 'https': proxy - } - - # 请求默认配置 - session.request = self._wrap_request( - session.request, - timeout=timeout, - verify=verify_ssl - ) - - return session - - def _wrap_request(self, original_request, **defaults): - """包装请求方法添加默认参数""" - - def wrapped(method, url, **kwargs): - # 处理DNS缓存 - parsed = urlparse(url) - if parsed.hostname: - kwargs['hooks'] = kwargs.get('hooks', {}) - kwargs['hooks']['pre_request'] = lambda r: setattr( - r, '_orig_host', r.url - ) - url = url.replace( - parsed.hostname, - self._resolve_hostname(parsed.hostname), - 1 - ) - - # 合并默认参数 - for k, v in defaults.items(): - kwargs.setdefault(k, v) - - return original_request(method, url, **kwargs) - - return wrapped - - def get_user_agent(self) -> str: - """获取随机用户代理(兼容各平台)""" - agents = [ - # Windows - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", - # macOS - "Mozilla/5.0 (Macintosh; Intel Mac OS X 12_4) AppleWebKit/605.1.15", - # Linux - "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36", - # Mobile - "Mozilla/5.0 (iPhone; CPU iPhone OS 15_5 like Mac OS X) AppleWebKit/605.1.15" - ] - return random.choice(agents) - - def check_connection( - self, - url: str = "https://www.baidu.com", - timeout: float = 5.0 - ) -> bool: - """ - 检查网络连接状态 - - 参数: - url: 测试用的URL - timeout: 超时时间(秒) - """ - try: - session = self.get_session(retries=0, timeout=timeout) - session.head(url) - return True - except Exception: - return False - - def download_file( - self, - url: str, - save_path: str, - chunk_size: int = 8192, - progress_callback: Optional[callable] = None - ) -> bool: - """ - 下载大文件支持断点续传 - - 参数: - url: 文件URL - save_path: 本地保存路径 - chunk_size: 分块大小(字节) - progress_callback: 进度回调函数(bytes_downloaded, total_size) - """ - session = self.get_session() - headers = {} - - # 检查本地文件部分下载 - if os.path.exists(save_path): - downloaded = os.path.getsize(save_path) - headers['Range'] = f'bytes={downloaded}-' - - try: - with session.get(url, headers=headers, stream=True) as r: - r.raise_for_status() - - # 获取文件总大小 - total_size = int(r.headers.get('content-length', 0)) + downloaded - - # 追加模式写入 - mode = 'ab' if headers.get('Range') else 'wb' - with open(save_path, mode) as f: - for chunk in r.iter_content(chunk_size=chunk_size): - if chunk: # 过滤keep-alive chunks - f.write(chunk) - downloaded += len(chunk) - if progress_callback: - progress_callback(downloaded, total_size) - return True - except Exception as e: - print(f"下载失败: {str(e)}") - return False - - -# 全局实例(线程安全) -network_utils = NetworkUtils() - - -# 快捷方法(兼容旧代码) -def get_session(*args, **kwargs): - return network_utils.get_session(*args, **kwargs) - - -def check_connection(*args, **kwargs): - return network_utils.check_connection(*args, **kwargs)