import os
import platform
import sys
import threading
import time
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pymysql
from dbutils.pooled_db import PooledDB
from pymysql import cursors
from pymysql.err import MySQLError

# Project logging facility
from utils.logger import log as logger


class MySQLAgent:
    """
    Process-wide (singleton) MySQL helper backed by a DBUtils connection pool.

    Applies per-platform connection defaults for Windows, macOS and Linux.
    The first construction creates the pool; later constructions return the
    same instance and ignore their ``config`` argument.
    """

    _instance = None
    _lock = threading.Lock()

    # Per-platform connection defaults. ``socket_timeout`` is not a PyMySQL
    # connect argument; __init__ translates it to read_timeout/write_timeout.
    PLATFORM_CONFIG = {
        'Windows': {
            'socket_timeout': 30,
            'connect_timeout': 10,
            'ssl': None
        },
        'Darwin': {  # macOS
            'socket_timeout': 60,
            'connect_timeout': 15,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
        },
        'Linux': {
            'socket_timeout': 60,
            'connect_timeout': 15,
            'ssl': None
        }
    }

    # Timeout keys a caller may override through the ``config`` dict.
    _TIMEOUT_KEYS = ('connect_timeout', 'read_timeout', 'write_timeout')

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: Optional[dict] = None):
        """
        Build the connection settings and create the pool (first call only).

        Args:
            config: Connection settings (host, port, user, password, database,
                charset, max_connections and optional connect/read/write
                timeouts). When empty, ``config.settings.DATABASE_CONFIG``
                is used.
        """
        # __new__ releases the lock before __init__ runs, so taking it here
        # cannot deadlock; it keeps two threads from each building a pool.
        with self._lock:
            if getattr(self, '_pool', None):
                return

            if not config:
                from config.settings import DATABASE_CONFIG
                config = DATABASE_CONFIG

            current_platform = platform.system()

            # The logger must exist before _create_pool(), which logs.
            self.logger = logger.bind(module=f"MySQLAgent({current_platform})")

            # Copy so the class-level defaults are never mutated.
            platform_config = dict(self.PLATFORM_CONFIG.get(current_platform, {}))
            socket_timeout = platform_config.pop('socket_timeout', None)
            if socket_timeout is not None:
                platform_config.setdefault('read_timeout', socket_timeout)
                platform_config.setdefault('write_timeout', socket_timeout)

            self.config = {
                'host': config.get('host', 'localhost'),
                'port': config.get('port', 3306),
                'user': config.get('user', 'root'),
                'password': config.get('password', ''),
                'database': config.get('database', 'intelligence_system'),
                'charset': config.get('charset', 'utf8mb4'),
                'cursorclass': cursors.DictCursor,
                'autocommit': True,
                **platform_config  # merge platform-specific settings
            }

            # Explicit caller timeouts win over the platform defaults.
            for key in self._TIMEOUT_KEYS:
                if config.get(key) is not None:
                    self.config[key] = config[key]

            if current_platform == 'Windows':
                # SSL is not configured on Windows.
                self.config['ssl'] = None
            elif current_platform == 'Darwin':
                # Fall back to a non-SSL connection when the CA bundle is absent.
                if not os.path.exists(self.config['ssl']['ca']):
                    self.config['ssl'] = None
                    self.logger.warning("macOS SSL certificate not found, disabling SSL")

            self.pool_size = config.get('max_connections', 5)
            self._pool = self._create_pool()

    def _create_pool(self) -> PooledDB:
        """
        Create the DBUtils connection pool from ``self.config``.

        Returns:
            PooledDB: The new pool.

        Raises:
            Exception: Whatever PooledDB / PyMySQL raises; logged, then re-raised.
        """
        try:
            pool_config = {
                'creator': pymysql,
                'maxconnections': self.pool_size,
                'mincached': 1,
                'maxcached': 3,
                'blocking': True,
                'ping': 1,  # check a connection whenever it is fetched from the pool
                **self.config
            }

            # Connection ping checks are disabled on Windows.
            if platform.system() == 'Windows':
                pool_config['ping'] = 0

            pool = PooledDB(**pool_config)
            self.logger.info(f"Connection pool created for {platform.system()}")
            return pool
        except Exception as e:
            self.logger.critical("Failed to create connection pool",
                                 error=str(e), exc_info=True)
            raise

    def _handle_path(self, path: str) -> str:
        """Return ``path`` with '/' replaced by a backslash on Windows; unchanged elsewhere."""
        if platform.system() == 'Windows':
            return path.replace('/', '\\')
        return path

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Backtick-quote a MySQL identifier, doubling embedded backticks."""
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _to_db_value(value: Any) -> Any:
        """Convert one DataFrame cell to a value the driver can bind (missing -> None)."""
        if value is None or value is pd.NaT or value is pd.NA:
            return None
        if isinstance(value, np.generic):
            value = value.item()
        if isinstance(value, float) and value != value:  # NaN
            return None
        if isinstance(value, pd.Timestamp):
            return value.to_pydatetime()
        return value

    @staticmethod
    def _apply_parse_dates(df: pd.DataFrame,
                           parse_dates: Union[List[str], Dict[str, Any], str, bool, None]
                           ) -> pd.DataFrame:
        """Convert the requested columns of ``df`` to datetime, in place."""
        if not parse_dates or parse_dates is True:
            return df
        if isinstance(parse_dates, str):
            parse_dates = [parse_dates]
        if isinstance(parse_dates, dict):
            spec = parse_dates
        else:
            spec = {column: None for column in parse_dates}

        for column, option in spec.items():
            if column not in df.columns:
                continue
            if isinstance(option, dict):
                df[column] = pd.to_datetime(df[column], **option)
            elif option is None:
                df[column] = pd.to_datetime(df[column])
            else:
                df[column] = pd.to_datetime(df[column], format=option)
        return df

    def _finish_transaction(self, conn: pymysql.connections.Connection) -> None:
        """Put a transactional connection back into autocommit mode and return it to the pool."""
        try:
            # Enabling autocommit implicitly commits pending work, so make
            # sure nothing is pending first (a no-op after commit/rollback).
            conn.rollback()
            conn.autocommit(True)
        except Exception as e:
            self.logger.warning("Failed to restore autocommit", error=str(e))
        finally:
            conn.close()

    def get_connection(self) -> pymysql.connections.Connection:
        """
        Fetch a connection from the pool.

        On macOS with SSL configured the connection is pinged (with reconnect)
        before being handed out. On Windows a "timed out" failure triggers
        ``_retry_connection``.

        Returns:
            pymysql.connections.Connection: A pooled connection; ``close()``
            returns it to the pool.

        Raises:
            Exception: If no connection can be obtained (logged, re-raised).
        """
        conn = None
        try:
            conn = self._pool.connection()

            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)

            self.logger.trace("Connection obtained")
            return conn
        except Exception as e:
            # Do not leak a connection that was fetched but failed the ping.
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

            error_msg = str(e)

            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.logger.warning("Windows connection timeout, retrying...")
                return self._retry_connection()

            self.logger.error("Connection failed", error=error_msg, exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
        """
        Retry fetching a pooled connection, sleeping one second between attempts.

        Args:
            max_retries: Maximum number of attempts; must be at least 1.

        Returns:
            pymysql.connections.Connection: The first connection obtained.

        Raises:
            ValueError: If ``max_retries`` is smaller than 1.
            Exception: The error of the last attempt when all attempts fail.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        for attempt in range(1, max_retries + 1):
            try:
                conn = self._pool.connection()
            except Exception:
                if attempt == max_retries:
                    raise
                time.sleep(1)
            else:
                self.logger.info(f"Connection established after {attempt} attempts")
                return conn

    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """
        Run a query and return its result set as a DataFrame.

        Rows are fetched through an explicit cursor rather than
        ``pandas.read_sql``: pandas only supports SQLAlchemy/sqlite3
        connections, and the pool hands out DictCursor rows.

        Args:
            sql (str): SQL statement using ``%s`` / ``%(name)s`` placeholders.
            params (Union[tuple, dict, None]): Values bound to the placeholders.
            parse_dates (Union[List[str], bool]): Column names (or a
                ``{column: format}`` dict) to convert with ``pd.to_datetime``.
                ``True``/``False``/``None`` leave the values as returned by
                the driver.

        Returns:
            pd.DataFrame: One row per result row, columns in result-set order.

        Raises:
            Exception: Any connection or execution error (logged, re-raised).
        """
        try:
            with closing(self.get_connection()) as conn, \
                    closing(conn.cursor()) as cursor:
                # Non-Windows platforms raise the session's wait_timeout.
                if platform.system() != 'Windows':
                    cursor.execute("SET SESSION wait_timeout=600")

                cursor.execute(sql, params)
                columns = [column[0] for column in (cursor.description or ())]
                rows = list(cursor.fetchall())

            df = pd.DataFrame(rows, columns=columns)
            df = self._apply_parse_dates(df, parse_dates)

            self.logger.info("Query executed", rows=len(df))
            return df
        except Exception as e:
            self.logger.error("Query failed", sql=sql, params=params,
                              error=str(e), exc_info=True)
            raise

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False) -> int:
        """
        Insert the rows of a DataFrame into an existing table.

        Rows are written with parameterized ``INSERT`` statements in batches,
        inside a single transaction: either every row is stored or none is.
        The target table must already exist, and the DataFrame column names
        must match the table's column names.

        Args:
            table_name (str): Target table.
            df (pd.DataFrame): Data to insert; missing values become NULL.
            chunk_size (int): Rows per batch (capped at 500 on Windows).
            replace (bool): When True, all existing rows of the table are
                deleted (in the same transaction) before inserting.

        Returns:
            int: Number of rows inserted (0 for an empty DataFrame).

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
            Exception: Any database error (logged, rolled back, re-raised).
        """
        if df.empty:
            self.logger.warning("Empty DataFrame", table=table_name)
            return 0

        try:
            if chunk_size < 1:
                raise ValueError("chunk_size must be at least 1")

            # Smaller batches on Windows.
            if platform.system() == 'Windows':
                chunk_size = min(chunk_size, 500)

            # '%' is doubled because the driver %-formats this statement.
            table = self._quote_identifier(table_name)
            column_list = ', '.join(
                self._quote_identifier(column) for column in df.columns)
            placeholders = ', '.join(['%s'] * len(df.columns))
            insert_sql = (
                f"INSERT INTO {table.replace('%', '%%')} "
                f"({column_list.replace('%', '%%')}) VALUES ({placeholders})"
            )

            rows = [
                tuple(self._to_db_value(value) for value in record)
                for record in df.itertuples(index=False, name=None)
            ]

            conn = self.begin_transaction()
            try:
                with closing(conn.cursor()) as cursor:
                    if replace:
                        cursor.execute(f"DELETE FROM {table}")
                    for start in range(0, len(rows), chunk_size):
                        cursor.executemany(insert_sql, rows[start:start + chunk_size])
            except Exception:
                self.rollback_transaction(conn)
                raise
            self.commit_transaction(conn)

            total_rows = len(rows)
            self.logger.info("Data inserted", table=table_name, rows=total_rows)
            return total_rows
        except Exception as e:
            self.logger.error("Insert failed", table=table_name,
                              error=str(e), exc_info=True)
            raise

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """
        Execute one SQL statement on an autocommit connection.

        Args:
            sql (str): SQL statement using ``%s`` / ``%(name)s`` placeholders.
            params (Union[tuple, dict, None]): Values bound to the placeholders.
            fetch (bool): When True return the fetched rows, otherwise the
                cursor's row count.

        Returns:
            Union[int, List[Dict[str, Any]]]: Rows (one dict per row) when
            ``fetch`` is True, else the number of affected rows.

        Raises:
            Exception: Any connection or execution error (logged, re-raised).
        """
        try:
            with closing(self.get_connection()) as conn, \
                    closing(conn.cursor()) as cursor:
                # NOTE(review): max_execution_time is a MySQL system variable;
                # confirm the target server supports it.
                if platform.system() != 'Windows':
                    cursor.execute("SET SESSION max_execution_time=600000")

                cursor.execute(sql, params)

                if fetch:
                    result = list(cursor.fetchall())
                    self.logger.debug("Query executed", rows=len(result))
                    return result

                affected_rows = cursor.rowcount
                self.logger.debug("Update executed", affected_rows=affected_rows)
                return affected_rows
        except Exception as e:
            self.logger.error("SQL execution failed", sql=sql, params=params,
                              error=str(e), exc_info=True)
            raise

    def begin_transaction(self) -> pymysql.connections.Connection:
        """
        Start a transaction and return the connection that carries it.

        The caller must finish with ``commit_transaction`` or
        ``rollback_transaction``; both return the connection to the pool.

        Returns:
            pymysql.connections.Connection: Connection with autocommit off.

        Raises:
            Exception: Any connection error (logged, re-raised); a connection
            that was already fetched is returned to the pool first.
        """
        conn = None
        try:
            conn = self.get_connection()
            conn.autocommit(False)

            # macOS uses READ COMMITTED isolation.
            if platform.system() == 'Darwin':
                with closing(conn.cursor()) as cursor:
                    cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")

            # Tell DBUtils a transaction is open so it does not transparently
            # reconnect (and silently drop earlier statements) midway.
            conn.begin()

            self.logger.debug("Transaction started")
            return conn
        except Exception as e:
            self.logger.error("Begin transaction failed", error=str(e))
            if conn is not None:
                self._finish_transaction(conn)
            raise

    def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
        """
        Commit the transaction and return the connection to the pool.

        Args:
            conn: Connection obtained from ``begin_transaction``.

        Raises:
            Exception: If the commit fails (logged, re-raised). Pending work is
            rolled back before the connection goes back to the pool.
        """
        try:
            conn.commit()
            self.logger.debug("Transaction committed")
        except Exception as e:
            self.logger.error("Commit failed", error=str(e))
            raise
        finally:
            # Restores autocommit so pooled reuse by execute_sql() still commits.
            self._finish_transaction(conn)

    def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
        """
        Roll back the transaction and return the connection to the pool.

        Best effort: a failing rollback is logged, not raised.

        Args:
            conn: Connection obtained from ``begin_transaction``.
        """
        try:
            conn.rollback()
            self.logger.warning("Transaction rolled back")
        except Exception as e:
            self.logger.error("Rollback failed", error=str(e))
        finally:
            # Restores autocommit so pooled reuse by execute_sql() still commits.
            self._finish_transaction(conn)

    def __del__(self):
        """Close the connection pool when the instance is garbage-collected."""
        pool = getattr(self, '_pool', None)
        if pool is None:
            return
        try:
            pool.close()
            self.logger.info("Connection pool closed")
        except Exception as e:
            self.logger.error("Failed to close pool", error=str(e))


def get_default_config():
    """
    Return the default connection settings for the current platform.

    Returns:
        dict: Base settings (localhost:3306, user ``root``, empty password,
        database ``intelligence_system``, 5 connections) plus platform
        timeouts; macOS additionally gets an ``ssl`` CA entry.
    """
    current_platform = platform.system()

    base_config = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '',
        'database': 'intelligence_system',
        'max_connections': 5
    }

    if current_platform == 'Windows':
        return {
            **base_config,
            'connect_timeout': 10,
            'read_timeout': 30,
            'write_timeout': 30
        }

    if current_platform == 'Darwin':
        return {
            **base_config,
            'connect_timeout': 15,
            'read_timeout': 60,
            'write_timeout': 60,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
        }

    # Linux and every other platform
    return {
        **base_config,
        'connect_timeout': 15,
        'read_timeout': 60,
        'write_timeout': 60
    }


def main() -> None:
    """Usage example: connect with the platform defaults and print the server version."""
    # Pick the configuration that fits the current platform.
    config = get_default_config()

    # Initialise the database agent (creates the pool).
    db = MySQLAgent(config)

    # Smoke-test query.
    try:
        df = db.query_to_df("SELECT VERSION() as version")
        print(f"Database version: {df['version'].iloc[0]}")
        print(f"Running on: {platform.system()} {platform.release()}")
    except Exception as e:
        print(f"Error: {str(e)}")


if __name__ == "__main__":
    main()