"""Cross-platform MySQL helper built on PyMySQL, a DBUtils pool and pandas."""

import os
import platform
import sys
import threading
import time
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pymysql
from dbutils.pooled_db import PooledDB
from pymysql import cursors
from pymysql.err import MySQLError

# Project logging facade
from utils.logger import log


class _PlainCursorConnection:
    """Connection proxy whose ``cursor()`` defaults to tuple-row cursors.

    The pool is configured with ``DictCursor``.  pandas builds frames from
    positional row tuples, and dict rows can end up as column *names* in the
    cells on recent pandas versions, so ``query_to_df`` hands pandas this
    proxy instead of the raw pooled connection.  Every other attribute is
    forwarded to the wrapped connection unchanged.
    """

    def __init__(self, conn):
        self._conn = conn

    def cursor(self, *args, **kwargs):
        # Only inject the plain cursor class when the caller did not pick one.
        if not args and 'cursor' not in kwargs:
            args = (cursors.Cursor,)
        return self._conn.cursor(*args, **kwargs)

    def __getattr__(self, name):
        return getattr(self._conn, name)


class MySQLAgent:
    """
    MySQL access helper intended to run on Windows, macOS and Linux.

    The class is a process-wide singleton: the first successful construction
    creates the connection pool and later constructions return the same,
    already initialised object.  All connection settings are supplied by the
    caller through the ``config`` dictionary.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Check, lock, check again so concurrent first calls create one object.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """
        Initialise the connection pool (only once per process).

        Args:
            config (dict): Database settings with the following keys:
                - host: database host
                - port: TCP port
                - user: user name
                - password: password
                - database: schema name
                - [optional] charset: character set (default utf8mb4)
                - [optional] max_connections: pool size (default 5)
                - [optional] connect_timeout: connect timeout in seconds
                - [optional] read_timeout: read timeout in seconds
                - [optional] write_timeout: write timeout in seconds
                - [optional] ssl: SSL settings passed to PyMySQL

        Raises:
            ValueError: If a required key is missing from ``config``.
        """
        # The singleton is re-entered on every MySQLAgent(...) call; keep the
        # existing pool instead of building a second one.
        if hasattr(self, '_pool') and self._pool:
            return

        required_keys = ['host', 'port', 'user', 'password', 'database']
        if not all(key in config for key in required_keys):
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            'cursorclass': cursors.DictCursor,
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl')
        }

        # Logger tagged with the current platform name.
        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")

        self.pool_size = config.get('max_connections', 5)
        self._pool = self._create_pool()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Return ``name`` as a backtick-quoted MySQL identifier.

        A name that is already wrapped in backticks is returned unchanged so
        callers that pre-quoted their identifiers keep working.
        """
        text = str(name)
        if len(text) >= 2 and text.startswith("`") and text.endswith("`"):
            return text
        return "`" + text.replace("`", "``") + "`"

    @classmethod
    def _quote_table_name(cls, table_name: str) -> str:
        """Quote a table name, treating ``schema.table`` as two identifiers."""
        return ".".join(cls._quote_identifier(part)
                        for part in str(table_name).split("."))

    @staticmethod
    def _to_db_value(value: Any) -> Any:
        """Convert a pandas/numpy scalar into a value PyMySQL can encode.

        Missing values (None/NaN/NaT/NA) become ``None`` (SQL NULL), pandas
        timestamps/timedeltas become their ``datetime`` equivalents and numpy
        scalars become native Python scalars.  Anything else is passed through.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        if isinstance(value, pd.Timestamp):
            return value.to_pydatetime()
        if isinstance(value, pd.Timedelta):
            return value.to_pytimedelta()
        if isinstance(value, np.generic):
            return value.item()
        return value

    def _create_pool(self) -> PooledDB:
        """Create the DBUtils connection pool.

        Raises:
            Exception: Whatever the pool or the driver raises; it is logged
                and re-raised.
        """
        try:
            # Wrapper so the pool receives a plain callable as its creator.
            def connect():
                conn = pymysql.connect(**self.config)
                conn.threadsafety = 1  # explicitly mark the thread-safety level
                return conn

            pool = PooledDB(
                creator=connect,
                mincached=1,
                maxcached=3,
                maxconnections=self.pool_size,
                blocking=True,
                ping=1
            )
            self.log.info("Connection pool created")
            return pool
        except Exception as e:
            self.log.critical("Failed to create connection pool", error=str(e), exc_info=True)
            raise

    def get_connection(self) -> pymysql.connections.Connection:
        """
        Borrow a connection from the pool.

        Returns:
            pymysql.connections.Connection: Pooled connection wrapper; call
            ``close()`` on it to hand it back to the pool.

        Raises:
            MySQLError: If no connection can be obtained.
        """
        try:
            conn = self._pool.connection()

            # On macOS with SSL configured, ping (and reconnect) up front.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)

            self.log.trace("Database connection obtained")
            return conn
        except Exception as e:
            error_msg = str(e)

            # Windows-only: retry when the failure looks like a timeout.
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows connection timeout, retrying...")
                return self._retry_connection()

            self.log.error("Connection failed", error=error_msg, exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
        """Retry borrowing a connection, sleeping one second between attempts.

        Args:
            max_retries (int): Number of attempts; must be at least 1.

        Raises:
            ValueError: If ``max_retries`` is smaller than 1.
            Exception: The error of the last failed attempt.
        """
        if max_retries < 1:
            # Previously this fell through the loop and returned None.
            raise ValueError("max_retries must be at least 1")

        for attempt in range(max_retries):
            try:
                conn = self._pool.connection()
                self.log.info(f"Connection established after {attempt + 1} attempts")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """
        Run a SELECT statement and return the result as a DataFrame.

        Args:
            sql (str): SQL query text.
            params (Union[tuple, dict, None]): Query parameters.
            parse_dates (Union[List[str], bool]): Forwarded to
                ``pandas.read_sql``.

        Returns:
            pd.DataFrame: Query result.

        Raises:
            MySQLError: If the query fails.
        """
        try:
            self.log.debug("Executing SQL query", sql=sql)

            conn = self.get_connection()
            try:
                # Linux/macOS: raise the session wait timeout first.
                if platform.system() != 'Windows':
                    with closing(conn.cursor()) as cursor:
                        cursor.execute("SET SESSION wait_timeout=600")

                # pandas needs tuple rows; the pool default is DictCursor.
                df = pd.read_sql(sql, _PlainCursorConnection(conn),
                                 params=params, parse_dates=parse_dates)
            finally:
                conn.close()  # return the connection to the pool

            self.log.info("Query executed successfully", rows=len(df))
            return df

        except Exception as e:
            self.log.error("SQL query failed", sql=sql, params=params, error=str(e), exc_info=True)
            raise

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False) -> int:
        """
        Insert the rows of a DataFrame into a table via ``DataFrame.to_sql``.

        Args:
            table_name (str): Target table.
            df (pd.DataFrame): Rows to insert.
            chunk_size (int): Rows per ``to_sql`` call; must be positive.
            replace (bool): When True the first chunk is written with
                ``if_exists='replace'``; later chunks are always appended.

        Returns:
            int: Number of rows handed to ``to_sql``.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
            MySQLError: If the insert fails.
        """
        if df.empty:
            self.log.warning("Attempted to insert empty DataFrame", table=table_name)
            return 0

        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")

        self.log.debug("Preparing to insert DataFrame",
                       table=table_name,
                       rows=len(df),
                       chunk_size=chunk_size)

        try:
            method = 'replace' if replace else 'append'
            total_rows = 0

            # SQLAlchemy is only needed here, so import it lazily.
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool

            conn = self.get_connection()
            engine = None
            try:
                # The pooled wrapper does not expose character_set_name.
                if not hasattr(conn, 'character_set_name'):
                    conn.character_set_name = lambda: self.config.get('charset', 'utf8mb4')

                # Temporary engine that reuses the borrowed connection
                # (StaticPool) instead of opening connections of its own.
                engine = create_engine(
                    "mysql+pymysql://",
                    creator=lambda: conn,
                    poolclass=StaticPool,
                    connect_args={
                        'charset': self.config.get('charset', 'utf8mb4'),
                        'autocommit': True
                    }
                )

                for i in range(0, len(df), chunk_size):
                    chunk = df.iloc[i:i + chunk_size]

                    # macOS: send datetimes as formatted strings.
                    if platform.system() == 'Darwin':
                        # Work on a copy so the caller's frame is not modified.
                        chunk = chunk.copy()
                        for col in chunk.select_dtypes(include=['datetime64']):
                            chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')

                    chunk.to_sql(
                        table_name,
                        engine,
                        if_exists=method,
                        index=False,
                        method='multi'
                    )
                    total_rows += len(chunk)
                    method = 'append'  # only the first chunk may replace

                    self.log.trace(f"Inserted chunk {i // chunk_size + 1}",
                                   rows=len(chunk),
                                   total_inserted=total_rows)

                self.log.info("Data inserted successfully",
                              table=table_name,
                              total_rows=total_rows)
                return total_rows

            finally:
                # Always release the engine and hand the connection back.
                if engine is not None:
                    engine.dispose()
                conn.close()

        except Exception as e:
            self.log.error("Data insertion failed",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """
        Update table rows from a DataFrame, matching on key columns.

        Every DataFrame column that also exists in the table and is not a
        key column is written; key columns build the WHERE clause.

        Args:
            table_name (str): Target table.
            df (pd.DataFrame): Rows holding the new values and the keys.
            key_columns (Union[str, List[str]]): Column(s) identifying a row.

        Returns:
            int: Row count reported by the cursor after the batch.

        Raises:
            KeyError: If a key column is missing from ``df``.
            ValueError: If there is no non-key column to update.
            MySQLError: If the update fails.
        """
        if df.empty:
            self.log.warning("Attempted to update with empty DataFrame", table=table_name)
            return 0

        self.log.debug("Preparing to update table from DataFrame",
                       table=table_name,
                       key_columns=key_columns,
                       rows=len(df))

        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]
            else:
                key_columns = list(key_columns)

            missing_keys = [col for col in key_columns if col not in df.columns]
            if not key_columns or missing_keys:
                raise KeyError(f"Key columns missing from DataFrame: {missing_keys}")

            # Read the schema BEFORE opening the transaction so only one
            # pooled connection is held at a time.
            table_info = self._get_table_info(table_name)
            set_columns = [col for col in df.columns
                           if col in table_info and col not in key_columns]
            if not set_columns:
                raise ValueError(
                    f"No updatable columns shared between DataFrame and table {table_name}")

            set_clause = ', '.join(f"{self._quote_identifier(col)}=%s" for col in set_columns)
            where_clause = ' AND '.join(f"{self._quote_identifier(col)}=%s" for col in key_columns)
            update_sql = (f"UPDATE {self._quote_table_name(table_name)} "
                          f"SET {set_clause} WHERE {where_clause}")
            self.log.trace("Generated update SQL", sql=update_sql)

            # Parameter order is SET values first, then WHERE values.
            # itertuples keeps each column's own type (iterrows upcasts).
            ordered_columns = set_columns + key_columns
            update_data = [
                tuple(self._to_db_value(value) for value in record)
                for record in df[ordered_columns].itertuples(index=False, name=None)
            ]

            conn = self.begin_transaction()
            try:
                with closing(conn.cursor()) as cursor:
                    cursor.executemany(update_sql, update_data)
                    total_updated = cursor.rowcount

                self.commit_transaction(conn)

                self.log.info("Data updated successfully",
                              table=table_name,
                              rows_updated=total_updated)
                return total_updated

            except Exception:
                self.rollback_transaction(conn)
                raise

        except Exception as e:
            self.log.error("Data update failed",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def _get_table_info(self, table_name: str) -> Dict[str, str]:
        """
        Read the column names and data types of a table.

        Args:
            table_name (str): Table name (unquoted).

        Returns:
            Dict[str, str]: Mapping of column name to data type; empty when
            the table is not found in the configured schema.

        Raises:
            MySQLError: If the lookup fails.
        """
        # Explicit aliases pin the result keys; without them the server
        # decides the header case (MySQL 8 reports them in upper case).
        sql = """
            SELECT column_name AS column_name, data_type AS data_type
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
        """
        params = (self.config['database'], table_name)

        try:
            conn = self.get_connection()
            try:
                with closing(conn.cursor()) as cursor:
                    cursor.execute(sql, params)
                    result = cursor.fetchall()
            finally:
                conn.close()
            return {row['column_name']: row['data_type'] for row in result}
        except Exception as e:
            self.log.error("Failed to get table info", table=table_name, error=str(e))
            raise

    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """
        Infer a MySQL column type for every DataFrame column.

        Args:
            df (pd.DataFrame): Input frame.

        Returns:
            Dict[str, str]: Mapping of column name to SQL type.  Categorical
            columns map to VARCHAR(255), booleans to TINYINT(1), integers to
            BIGINT (unsigned ones to BIGINT UNSIGNED), floats to DOUBLE,
            datetimes to DATETIME and everything else to TEXT.
        """
        dtype_checks = pd.api.types
        sql_types = {}

        for col, dtype in df.dtypes.items():
            # Categorical first: its categories may themselves be bool/int.
            if isinstance(dtype, pd.CategoricalDtype):
                sql_type = 'VARCHAR(255)'
            elif dtype_checks.is_bool_dtype(dtype):
                sql_type = 'TINYINT(1)'
            elif dtype_checks.is_unsigned_integer_dtype(dtype):
                sql_type = 'BIGINT UNSIGNED'
            elif dtype_checks.is_integer_dtype(dtype):
                sql_type = 'BIGINT'
            elif dtype_checks.is_float_dtype(dtype):
                sql_type = 'DOUBLE'
            elif dtype_checks.is_datetime64_any_dtype(dtype):
                sql_type = 'DATETIME'
            else:
                sql_type = 'TEXT'
            sql_types[col] = sql_type

        self.log.debug("Mapped DataFrame types to SQL types", mappings=sql_types)
        return sql_types

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """
        Create a table whose columns mirror a DataFrame.

        Args:
            table_name (str): Name of the table to create.
            df (pd.DataFrame): Frame supplying column names and dtypes.
            primary_key (Union[str, List[str], None]): Primary key column(s);
                names not present in the frame are ignored.

        Returns:
            bool: True when the table was created, False when it already
            exists or creation failed (the failure is logged).
        """
        if self.table_exists(table_name):
            self.log.warning("Table already exists", table=table_name)
            return False

        self.log.debug("Creating new table from DataFrame schema",
                       table=table_name,
                       columns=list(df.columns))

        try:
            sql_types = self.df_to_sql_type(df)
            columns_sql = [f"{self._quote_identifier(col)} {sql_type}"
                           for col, sql_type in sql_types.items()]

            if primary_key:
                if isinstance(primary_key, str):
                    primary_key = [primary_key]

                pk_columns = [col for col in primary_key if col in sql_types]
                if pk_columns:
                    quoted_pk = ', '.join(self._quote_identifier(col) for col in pk_columns)
                    columns_sql.append(f"PRIMARY KEY ({quoted_pk})")
                    self.log.trace("Set primary key",
                                   table=table_name,
                                   primary_key=pk_columns)

            # Join outside the f-string: a backslash inside a replacement
            # field is a syntax error before Python 3.12.
            column_block = ",\n    ".join(columns_sql)
            create_sql = (f"CREATE TABLE {self._quote_table_name(table_name)} "
                          f"(\n    {column_block}\n)"
                          )
            self.execute_sql(create_sql)

            self.log.info("Table created successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to create table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """
        Execute a single SQL statement.

        Args:
            sql (str): SQL statement.
            params (Union[tuple, dict, None]): Statement parameters.
            fetch (bool): Whether to fetch and return the result rows.

        Returns:
            Union[int, List[Dict[str, Any]]]:
                - with ``fetch=False`` the cursor's affected-row count
                - with ``fetch=True`` the list of result rows

        Raises:
            MySQLError: If execution fails.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()

            # Linux/macOS: allow long-running statements (value is in ms).
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")

            cursor.execute(sql, params)

            if fetch:
                result = cursor.fetchall()
                self.log.debug("Query executed", rows=len(result))
                return result

            affected_rows = cursor.rowcount
            self.log.debug("Update executed", affected_rows=affected_rows)
            return affected_rows
        except Exception as e:
            self.log.error("SQL execution failed",
                           sql=sql,
                           params=params,
                           error=str(e),
                           exc_info=True)
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    # ------------------------------------------------------------------
    # Transactions
    # ------------------------------------------------------------------
    def begin_transaction(self) -> pymysql.connections.Connection:
        """Borrow a connection and start an explicit transaction on it.

        The caller must finish with ``commit_transaction`` or
        ``rollback_transaction``; both return the connection to the pool.

        Returns:
            pymysql.connections.Connection: Connection with an open
            transaction.

        Raises:
            MySQLError: If the connection or the transaction cannot be set up.
        """
        try:
            conn = self.get_connection()
            try:
                # macOS: use READ COMMITTED.  A session-level setting applies
                # to transactions started afterwards, so set it before BEGIN.
                if platform.system() == 'Darwin':
                    with closing(conn.cursor()) as cursor:
                        cursor.execute(
                            "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")

                # Use the pool wrapper's begin() rather than toggling
                # autocommit: the wrapper does not proxy autocommit(), and a
                # changed autocommit flag would follow the connection back
                # into the pool.
                conn.begin()
            except Exception:
                conn.close()  # do not leak the borrowed connection
                raise

            self.log.debug("Transaction started")
            return conn
        except Exception as e:
            self.log.error("Begin transaction failed", error=str(e))
            raise

    def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
        """Commit the transaction and return the connection to the pool.

        Raises:
            MySQLError: If the commit fails (the connection is still closed).
        """
        try:
            conn.commit()
            self.log.debug("Transaction committed")
        except Exception as e:
            self.log.error("Commit failed", error=str(e))
            raise
        finally:
            conn.close()

    def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
        """Roll the transaction back and return the connection to the pool.

        A failing rollback is logged and not re-raised, so the error that
        triggered the rollback stays the one the caller sees.
        """
        try:
            conn.rollback()
            self.log.warning("Transaction rolled back")
        except Exception as e:
            self.log.error("Rollback failed", error=str(e))
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Schema utilities
    # ------------------------------------------------------------------
    def table_exists(self, table_name: str) -> bool:
        """Return True when the table exists in the configured schema.

        Any lookup failure is logged and reported as False.
        """
        sql = """
            SELECT COUNT(*) as count
            FROM `information_schema`.`tables`
            WHERE `table_schema` = %s AND `table_name` = %s
        """
        params = (self.config['database'], table_name)

        try:
            result = self.execute_sql(sql, params, fetch=True)
            exists = result[0]['count'] > 0
            self.log.debug("Checked table existence", table=table_name, exists=exists)
            return exists
        except Exception as e:
            # Best effort by design, but leave a trace of the failure.
            self.log.warning("Table existence check failed", table=table_name, error=str(e))
            return False

    def drop_table(self, table_name: str) -> bool:
        """Drop a table.

        Returns:
            bool: True when the table was dropped, False when it does not
            exist or the drop failed (the failure is logged).
        """
        if not self.table_exists(table_name):
            self.log.warning("Table does not exist", table=table_name)
            return False

        try:
            self.execute_sql(f"DROP TABLE {self._quote_table_name(table_name)}")
            self.log.info("Table dropped successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to drop table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def get_pool_status(self) -> Dict[str, int]:
        """Return counters read from the pool's private attributes.

        Returns:
            Dict[str, int]: ``max``, ``active``, ``idle`` and ``shared``.
            ``shared`` is 0 when the pool keeps no shared-connection cache.
        """
        pool = self._pool
        return {
            'max': pool._maxconnections,
            'active': pool._connections,
            'idle': len(pool._idle_cache),
            # The shared cache only exists when shared connections are enabled.
            'shared': len(getattr(pool, '_shared_cache', ()))
        }

    def validate_connection(self) -> bool:
        """Return True when a trivial ``SELECT 1`` round trip succeeds."""
        try:
            conn = self.get_connection()
            try:
                with closing(conn.cursor()) as cursor:
                    cursor.execute("SELECT 1 AS ok")
                    row = cursor.fetchone()
            finally:
                conn.close()

            if not row:
                return False
            # Rows are dicts with the configured DictCursor; accept tuples too.
            value = row['ok'] if isinstance(row, dict) else row[0]
            return value == 1
        except Exception:
            return False

    def __del__(self):
        """Close the pool when the object is finalised."""
        if hasattr(self, '_pool'):
            try:
                self._pool.close()
                self.log.info("Connection pool closed")
            except Exception as e:
                self.log.error("Failed to close pool", error=str(e))


# Platform-specific default settings
def get_default_config():
    """Return the default connection settings for the current platform.

    Windows uses shorter timeouts; macOS additionally gets an SSL CA path.
    NOTE(review): the credentials below are hard-coded development defaults;
    confirm they are never used against a real deployment.
    """
    system_name = platform.system()

    config = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence',
        'max_connections': 5
    }

    if system_name == 'Windows':
        config.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return config

    # macOS, Linux and everything else share the longer timeouts.
    config.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        config['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return config


if __name__ == "__main__":
    # Usage example
    db = MySQLAgent(get_default_config())

    # Connection smoke test
    if db.validate_connection():
        print("Database connection successful")

        # Server version
        version = db.query_to_df("SELECT VERSION() as version")
        print(f"Database version: {version['version'].iloc[0]}")

        # Pool counters
        print("Connection pool status:", db.get_pool_status())
    else:
        print("Failed to connect to database")