"""MySQL access helper built on PyMySQL, a DBUtils connection pool and pandas."""

import os
import sys
import platform
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Union, List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB

# Project logging facility
from utils.logger import log


class MySQLAgent:
    """MySQL helper that works on Windows, macOS and Linux.

    The class is a process-wide singleton: ``__new__`` always returns the same
    instance, and ``__init__`` is a no-op once a connection pool exists, so the
    configuration passed to later constructor calls is ignored.
    All configuration is supplied by the caller.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if not cls._instance:
            with cls._lock:
                # Re-check under the lock: another thread may have created the
                # instance while this one was waiting.
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """Validate the configuration and create the connection pool.

        Args:
            config: Database settings. Required keys: ``host``, ``port``,
                ``user``, ``password``, ``database``. Optional keys:
                ``charset`` (default ``utf8mb4``), ``max_connections``
                (default 5), ``connect_timeout`` (default 10 s),
                ``read_timeout`` (default 30 s), ``write_timeout``
                (default 30 s) and ``ssl``.

        Raises:
            ValueError: If any required key is missing.
        """
        # Singleton: skip re-initialisation when a pool already exists.
        if hasattr(self, '_pool') and self._pool:
            return

        # Basic configuration validation
        required_keys = ['host', 'port', 'user', 'password', 'database']
        missing_keys = [key for key in required_keys if key not in config]
        if missing_keys:
            # Log only the missing key names; dumping the config would write
            # the database password into the log.
            log.warning(f"数据库配置缺少必要参数: {missing_keys}")
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            'cursorclass': cursors.DictCursor,
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl')
        }

        # Logger bound to the current platform name
        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")

        # Create the connection pool
        self.pool_size = config.get('max_connections', 5)
        self._pool = self._create_pool()

    def _create_pool(self) -> PooledDB:
        """Create the DBUtils connection pool from ``self.config``.

        Raises:
            Exception: Whatever pool creation raises, after logging it.
        """
        try:
            # Connection factory handed to the pool
            def connect():
                conn = pymysql.connect(**self.config)
                conn.threadsafety = 1  # explicitly tag the thread-safety level
                return conn

            pool = PooledDB(
                creator=connect,
                mincached=1,
                maxcached=3,
                maxconnections=self.pool_size,
                blocking=True,
                ping=1  # ping the server whenever a connection is fetched
            )
            self.log.info("连接池创建成功")
            return pool
        except Exception as e:
            self.log.critical("连接池创建失败", error=str(e), exc_info=True)
            raise

    def get_connection(self) -> pymysql.connections.Connection:
        """Fetch a connection from the pool.

        Adds a ``character_set_name`` method when the connection lacks one
        (needed for SQLAlchemy compatibility). On Windows a "timed out" error
        triggers ``_retry_connection``.

        Raises:
            Exception: If no connection can be obtained.
        """
        try:
            conn = self._pool.connection()

            # Provide the charset accessor SQLAlchemy expects
            if not hasattr(conn, 'character_set_name'):
                def _character_set_name():
                    return self.config.get('charset', 'utf8mb4')

                conn.character_set_name = _character_set_name

            # macOS with SSL configured: ping (and reconnect if needed)
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)

            self.log.trace("获取数据库连接成功")
            return conn
        except Exception as e:
            error_msg = str(e)
            # Windows: retry on connection timeout
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows连接超时,尝试重试...")
                return self._retry_connection()

            self.log.error("获取连接失败", error=error_msg, exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
        """Retry fetching a pooled connection, sleeping 1 s between attempts.

        Args:
            max_retries: Maximum number of attempts.

        Raises:
            Exception: The error of the final attempt when all attempts fail.
        """
        for attempt in range(max_retries):
            try:
                conn = self._pool.connection()
                self.log.info(f"第{attempt + 1}次尝试连接成功")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)  # wait 1 second before the next attempt

    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """Run a SQL query and return the result as a DataFrame.

        Args:
            sql: Query text.
            params: Parameters forwarded to ``pandas.read_sql``.
            parse_dates: Forwarded to ``pandas.read_sql``.

        Raises:
            Exception: Any query failure, after logging it.
        """
        conn = None
        try:
            self.log.debug("执行SQL查询", sql=sql)
            conn = self.get_connection()

            # Wrap the pooled connection in a SQLAlchemy engine; StaticPool
            # makes the engine reuse this one connection.
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool

            engine = create_engine(
                "mysql+pymysql://",
                creator=lambda: conn,
                poolclass=StaticPool,
                connect_args={'charset': self.config.get('charset', 'utf8mb4')}
            )

            # Execute the query
            df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
            self.log.info(f"查询成功,返回{len(df)}行数据")
            return df
        except Exception as e:
            self.log.error(f"SQL查询失败{sql}", sql=sql, params=params, error=str(e), exc_info=True)
            raise
        finally:
            # Always hand the connection back to the pool
            if conn:
                try:
                    conn.close()
                except Exception as e:
                    self.log.warning("关闭连接失败", error=str(e))

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False) -> int:
        """Insert a DataFrame into a table in chunks.

        Args:
            table_name: Target table.
            df: Rows to insert. An empty frame is a no-op.
            chunk_size: Requested rows per batch. Capped at 500 on Windows and
                raised to at least 1000 on Linux.
            replace: When true the first chunk is written with
                ``if_exists='replace'``; later chunks are appended.

        Returns:
            Number of rows written.

        Raises:
            Exception: Any insertion failure, after logging it.
        """
        if df.empty:
            self.log.warning(f"尝试插入空DataFrame到表{table_name}")
            return 0

        self.log.debug(f"准备插入DataFrame到表{table_name}", rows=len(df), chunk_size=chunk_size)

        # Adjust the batch size per platform
        current_platform = platform.system()
        if current_platform == 'Windows' and chunk_size > 500:
            chunk_size = 500
            self.log.debug(f"Windows平台自动调整批次大小为{chunk_size}")
        elif current_platform == 'Linux' and chunk_size < 1000:
            chunk_size = 1000
            self.log.debug(f"Linux平台自动调整批次大小为{chunk_size}")

        conn = None
        engine = None
        try:
            method = 'replace' if replace else 'append'
            total_rows = 0
            conn = self.get_connection()

            # Wrap the pooled connection in a SQLAlchemy engine
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool

            engine = create_engine(
                "mysql+pymysql://",
                creator=lambda: conn,
                poolclass=StaticPool,
                connect_args={
                    'charset': self.config.get('charset', 'utf8mb4'),
                    'autocommit': True
                }
            )

            for i in range(0, len(df), chunk_size):
                # Copy so the column assignment below never touches ``df``
                chunk = df.iloc[i:i + chunk_size].copy()

                # macOS: send datetime columns as formatted strings
                if current_platform == 'Darwin':
                    for col in chunk.select_dtypes(include=['datetime64']):
                        chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')

                chunk.to_sql(
                    table_name,
                    engine,
                    if_exists=method,
                    index=False,
                    method='multi'
                )
                total_rows += len(chunk)
                method = 'append'  # only the first chunk may replace the table
                self.log.trace(f"插入第{i // chunk_size + 1}批数据", rows=len(chunk), total=total_rows)

            self.log.info(f"数据插入成功,表{table_name}共插入{total_rows}行")
            return total_rows
        except Exception as e:
            self.log.error(f"数据插入失败,表{table_name}", error=str(e), exc_info=True)
            raise
        finally:
            # Release both resources even when engine creation itself failed;
            # cleanup errors are logged so they cannot mask the real failure.
            if engine is not None:
                try:
                    engine.dispose()
                except Exception as e:
                    self.log.warning("关闭连接失败", error=str(e))
            if conn is not None:
                try:
                    conn.close()
                except Exception as e:
                    self.log.warning("关闭连接失败", error=str(e))

    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """Update table rows from a DataFrame inside one transaction.

        Each DataFrame row becomes one ``UPDATE``: columns that exist in the
        table and are not key columns go into ``SET``; the key columns form
        the ``WHERE`` clause.

        Args:
            table_name: Target table.
            df: Source rows. An empty frame is a no-op.
            key_columns: Column name(s) identifying the row to update.

        Returns:
            Row count reported by the cursor, or 0 when nothing can be updated.

        Raises:
            ValueError: If a key column is missing from ``df``.
            Exception: Any database failure, after rolling back and logging.
        """
        if df.empty:
            self.log.warning(f"尝试用空DataFrame更新表{table_name}")
            return 0

        self.log.debug(f"准备从DataFrame更新表{table_name}", key_columns=key_columns, rows=len(df))

        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]

            # Every key column must be present in the DataFrame
            missing_keys = [key for key in key_columns if key not in df.columns]
            if missing_keys:
                raise ValueError(f"DataFrame中缺少关键列: {missing_keys}")

            # Read the table structure BEFORE opening the transaction: the
            # lookup borrows its own pooled connection, and holding the
            # transaction connection at the same time would block forever on
            # a pool limited to one connection.
            table_info = self._get_table_info(table_name)
            valid_columns = [col for col in df.columns if col in table_info]
            if not valid_columns:
                self.log.warning(f"DataFrame列与表{table_name}无匹配")
                return 0

            set_columns = [col for col in valid_columns if col not in key_columns]
            if not set_columns:
                # Only key columns matched: an UPDATE with an empty SET
                # clause would be invalid SQL.
                self.log.warning(f"表{table_name}没有可更新的非关键列")
                return 0

            # Build the UPDATE statement
            set_clause = ', '.join([f"`{col}`=%s" for col in set_columns])
            where_clause = ' AND '.join([f"`{col}`=%s" for col in key_columns])
            update_sql = f"UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}"
            self.log.trace("生成更新SQL", sql=update_sql)

            # One parameter tuple per row: SET values first, then key values.
            # itertuples keeps each column's own type (iterrows would upcast
            # mixed-type rows to a common dtype).
            ordered_columns = set_columns + list(key_columns)
            update_data = list(df[ordered_columns].itertuples(index=False, name=None))

            conn = self.begin_transaction()
            try:
                with conn.cursor() as cursor:
                    # Run the batched update
                    cursor.executemany(update_sql, update_data)
                    total_updated = cursor.rowcount

                self.commit_transaction(conn)
                self.log.info(f"数据更新成功,表{table_name}共更新{total_updated}行")
                return total_updated
            except Exception:
                self.rollback_transaction(conn)
                raise
        except Exception as e:
            self.log.error(f"数据更新失败,表{table_name}", error=str(e), exc_info=True)
            raise

    def _get_table_info(self, table_name: str) -> Dict[str, str]:
        """Return ``{column_name: data_type}`` for a table in the configured database.

        Raises:
            Exception: Any lookup failure, after logging it.
        """
        # Explicit aliases: MySQL 8.0 returns INFORMATION_SCHEMA column
        # headers in uppercase, which would break the lowercase dict keys
        # read below.
        sql = """
              SELECT COLUMN_NAME AS column_name, DATA_TYPE AS data_type
              FROM information_schema.columns
              WHERE table_schema = %s
                AND table_name = %s
              """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, (self.config['database'], table_name))
                    result = cursor.fetchall()
                    return {row['column_name']: row['data_type'] for row in result}
        except Exception as e:
            self.log.error(f"获取表{table_name}结构失败", error=str(e))
            raise

    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """Map each DataFrame column to a MySQL column type.

        Dtypes that are not in the mapping fall back to ``TEXT``.
        """
        type_mapping = {
            'int64': 'BIGINT',
            'int32': 'INT',
            'int16': 'SMALLINT',
            'int8': 'TINYINT',
            'uint64': 'BIGINT UNSIGNED',
            'float64': 'DOUBLE',
            'float32': 'FLOAT',
            'datetime64[ns]': 'DATETIME',
            'datetime64[ns, UTC]': 'DATETIME',
            'timedelta64[ns]': 'TIME',
            'object': 'TEXT',
            'string': 'VARCHAR(255)',
            'bool': 'TINYINT(1)',
            'category': 'VARCHAR(255)'
        }

        sql_types = {}
        for col, dtype in df.dtypes.items():
            dtype_str = str(dtype)
            sql_types[col] = type_mapping.get(dtype_str, 'TEXT')

        self.log.debug("DataFrame类型映射为SQL类型", mappings=sql_types)
        return sql_types

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """Create a table whose columns mirror a DataFrame.

        Args:
            table_name: Table to create.
            df: DataFrame supplying column names and dtypes.
            primary_key: Column name(s) for the primary key; names that are
                not DataFrame columns are ignored.

        Returns:
            True when the table was created; False when it already exists or
            creation failed (the failure is logged, not raised).
        """
        if self.table_exists(table_name):
            self.log.warning(f"表{table_name}已存在")
            return False

        self.log.debug(f"根据DataFrame结构创建表{table_name}", columns=list(df.columns))

        try:
            sql_types = self.df_to_sql_type(df)
            columns_sql = []

            for col, sql_type in sql_types.items():
                col_key = str(col).lower()
                # Timestamp-style column names get DATETIME with defaults
                # when their inferred type is not already DATETIME.
                if col_key in ['create_time', 'created_at'] and sql_type != 'DATETIME':
                    col_def = f"`{col}` DATETIME DEFAULT CURRENT_TIMESTAMP"
                elif col_key in ['update_time', 'updated_at'] and sql_type != 'DATETIME':
                    col_def = f"`{col}` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
                else:
                    col_def = f"`{col}` {sql_type}"
                columns_sql.append(col_def)

            # Primary key
            if primary_key:
                if isinstance(primary_key, str):
                    primary_key = [primary_key]
                pk_columns = [f"`{col}`" for col in primary_key if col in sql_types]
                if pk_columns:
                    columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
                    self.log.trace(f"表{table_name}设置主键", primary_key=pk_columns)

            # Join outside the f-string: a backslash inside an f-string
            # expression is a SyntaxError before Python 3.12.
            columns_block = ",\n ".join(columns_sql)
            create_sql = f"CREATE TABLE `{table_name}` (\n {columns_block}\n)"
            self.execute_sql(create_sql)
            self.log.info(f"表{table_name}创建成功")
            return True
        except Exception as e:
            self.log.error(f"表{table_name}创建失败", error=str(e), exc_info=True)
            return False

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """Execute one SQL statement on a pooled connection.

        Args:
            sql: Statement text.
            params: Parameters passed to ``cursor.execute``.
            fetch: When true return all fetched rows, otherwise the cursor's
                row count.

        Raises:
            Exception: Any execution failure, after logging it.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()

            # Non-Windows platforms: raise the statement timeout to 10 minutes
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")

            cursor.execute(sql, params)

            if fetch:
                result = cursor.fetchall()
                self.log.debug(f"查询执行完成,返回{len(result)}行")
                return result
            else:
                affected_rows = cursor.rowcount
                self.log.debug(f"更新执行完成,影响{affected_rows}行")
                return affected_rows
        except Exception as e:
            self.log.error("SQL执行失败", sql=sql, params=params, error=str(e), exc_info=True)
            raise
        finally:
            if cursor:
                try:
                    cursor.close()
                except Exception as e:
                    self.log.warning("关闭游标失败", error=str(e))
            if conn:
                try:
                    conn.close()
                except Exception as e:
                    self.log.warning("关闭连接失败", error=str(e))

    def begin_transaction(self) -> pymysql.connections.Connection:
        """Borrow a connection, switch autocommit off and return it.

        The caller must finish with ``commit_transaction`` or
        ``rollback_transaction``; both return the connection to the pool.
        On macOS the session isolation level is set to READ COMMITTED, on
        Linux to REPEATABLE READ.

        Raises:
            Exception: Any setup failure, after logging it.
        """
        conn = None
        try:
            conn = self.get_connection()
            conn.autocommit(False)

            # Platform-specific isolation level
            isolation_sql = None
            if platform.system() == 'Darwin':
                isolation_sql = "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED"
            elif platform.system() == 'Linux':
                isolation_sql = "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ"
            if isolation_sql:
                with conn.cursor() as cursor:
                    cursor.execute(isolation_sql)

            self.log.debug("事务开始")
            return conn
        except Exception as e:
            self.log.error("事务开始失败", error=str(e))
            # Do not leak a half-configured connection
            if conn is not None:
                self._release_transaction_connection(conn, "事务开始失败后关闭连接失败")
            raise

    def _release_transaction_connection(self, conn, close_failure_message: str) -> None:
        """Restore autocommit on a transaction connection and return it to the pool."""
        # Pooled connections are reused. Leaving autocommit off would make
        # later execute_sql() writes on this connection silently uncommitted.
        try:
            conn.autocommit(True)
        except Exception as e:
            self.log.warning("恢复自动提交失败", error=str(e))
        try:
            conn.close()
        except Exception as e:
            self.log.warning(close_failure_message, error=str(e))

    def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
        """Commit the transaction and return the connection to the pool.

        Raises:
            Exception: If the commit fails (the connection is still released).
        """
        try:
            conn.commit()
            self.log.debug("事务提交成功")
        except Exception as e:
            self.log.error("事务提交失败", error=str(e))
            raise
        finally:
            self._release_transaction_connection(conn, "事务提交后关闭连接失败")

    def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
        """Roll back the transaction and return the connection to the pool.

        Rollback failures are logged, not raised.
        """
        try:
            conn.rollback()
            self.log.warning("事务已回滚")
        except Exception as e:
            self.log.error("事务回滚失败", error=str(e))
        finally:
            self._release_transaction_connection(conn, "事务回滚后关闭连接失败")

    def table_exists(self, table_name: str) -> bool:
        """Return True when the table exists in the configured database.

        A failed check is logged and reported as False.
        """
        sql = """
              SELECT COUNT(*) as count
              FROM `information_schema`.`tables`
              WHERE `table_schema` = %s \
                AND `table_name` = %s \
              """
        try:
            result = self.execute_sql(sql, (self.config['database'], table_name), fetch=True)
            exists = result[0]['count'] > 0
            self.log.debug(f"表{table_name}存在性检查", exists=exists)
            return exists
        except Exception as e:
            self.log.warning(f"表{table_name}存在性检查失败", error=str(e))
            return False

    def drop_table(self, table_name: str) -> bool:
        """Drop a table.

        Returns:
            True when the table was dropped; False when it does not exist or
            the drop failed (the failure is logged, not raised).
        """
        if not self.table_exists(table_name):
            self.log.warning(f"表{table_name}不存在,无法删除")
            return False

        try:
            self.execute_sql(f"DROP TABLE `{table_name}`")
            self.log.info(f"表{table_name}删除成功")
            return True
        except Exception as e:
            self.log.error(f"表{table_name}删除失败", error=str(e), exc_info=True)
            return False

    def get_pool_status(self) -> Dict[str, int]:
        """Return connection counts read from the pool's private attributes."""
        pool = self._pool
        # NOTE(review): these are DBUtils internals. `_connections` is
        # believed to be an integer counter and `_shared_cache` to exist only
        # when shared connections are enabled — confirm against the installed
        # DBUtils version. Both shapes are handled here.
        active = pool._connections
        if not isinstance(active, int):
            active = len(active)
        status = {
            'max_connections': pool._maxconnections,
            'active_connections': active,
            'idle_connections': len(pool._idle_cache),
            'shared_connections': len(getattr(pool, '_shared_cache', ()))
        }
        self.log.debug("连接池状态", **status)
        return status

    def validate_connection(self) -> bool:
        """Run ``SELECT 1`` and report whether it returned 1.

        Any failure is logged and reported as False.
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1 AS health_check")
                    result = cursor.fetchone()
                    return result['health_check'] == 1
        except Exception as e:
            self.log.warning("连接健康检查失败", error=str(e))
            return False

    def __del__(self):
        """Close the connection pool when the instance is finalised."""
        if hasattr(self, '_pool') and self._pool:
            try:
                self._pool.close()
                self.log.info("连接池已关闭")
            except Exception as e:
                self.log.error("连接池关闭失败", error=str(e))


def get_default_config():
    """Return the default configuration for the current platform.

    NOTE(review): the credentials below are hard-coded development defaults;
    production callers should pass their own configuration to ``MySQLAgent``.
    """
    current_platform = platform.system()

    base_config = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence',
        'max_connections': 10,  # larger default pool size
        'charset': 'utf8mb4'
    }

    if current_platform == 'Windows':
        return {
            **base_config,
            'connect_timeout': 10,
            'read_timeout': 30,
            'write_timeout': 30,
            'ssl': None  # SSL disabled by default on Windows
        }
    elif current_platform == 'Darwin':  # macOS
        return {
            **base_config,
            'connect_timeout': 15,
            'read_timeout': 60,
            'write_timeout': 60,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}  # default macOS SSL setup
        }
    else:  # Linux and other platforms
        return {
            **base_config,
            'connect_timeout': 15,
            'read_timeout': 60,
            'write_timeout': 60,
            'ssl': None  # SSL disabled by default on Linux
        }


if __name__ == "__main__":
    # Usage example
    try:
        db = MySQLAgent(get_default_config())

        # Connection check
        if db.validate_connection():
            print("数据库连接成功")

            # Server version
            version_df = db.query_to_df("SELECT VERSION() as version")
            print(f"数据库版本: {version_df['version'].iloc[0]}")

            # Pool status
            print("连接池状态:", db.get_pool_status())

            # Create a test table
            test_df = pd.DataFrame({
                'id': [1, 2, 3],
                'name': ['测试1', '测试2', '测试3'],
                'value': [10.5, 20.3, 30.8],
                'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
            })

            db.create_table_from_df('test_table', test_df, primary_key='id')
            print("测试表创建成功")

            # Insert rows
            rows_inserted = db.insert_from_df('test_table', test_df)
            print(f"插入了{rows_inserted}行数据")

            # Query rows
            result_df = db.query_to_df("SELECT * FROM test_table")
            print("查询结果:")
            print(result_df)

            # Clean up the test table
            db.drop_table('test_table')
            print("测试表已删除")
        else:
            print("数据库连接失败")
    except Exception as e:
        print(f"示例执行失败: {str(e)}")