# Source file: intelligence_system/storage/mysql_agent.py
# Snapshot: 2025-08-06 17:29:46 +08:00 (684 lines, 22 KiB, Python)

import os
import sys
import platform
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
    """
    Cross-platform MySQL access helper (Windows / macOS / Linux).

    The class is a process-wide singleton: every ``MySQLAgent(...)`` call
    returns the same object, and only the first successful initialisation
    builds the connection pool; later ``config`` arguments are ignored.
    All settings are supplied by the caller.

    Connections handed out by this class come from a DBUtils ``PooledDB``
    pool, so ``close()`` returns them to the pool instead of closing the
    socket.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if not cls._instance:
            with cls._lock:
                # Re-check under the lock so concurrent first calls create one object.
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """
        Initialise the connection pool from ``config``.

        Args:
            config (dict): database settings with the keys
                - host: database host
                - port: port
                - user: user name
                - password: password
                - database: schema name
                - [optional] charset: character set (default utf8mb4)
                - [optional] max_connections: pool size (default 5)
                - [optional] connect_timeout: connect timeout in seconds (default 10)
                - [optional] read_timeout: read timeout in seconds (default 30)
                - [optional] write_timeout: write timeout in seconds (default 30)
                - [optional] ssl: SSL settings passed through to PyMySQL

        Raises:
            ValueError: if a required key is missing.
        """
        # The singleton is re-initialised on every MySQLAgent(...) call; keep
        # the pool that already exists.
        if hasattr(self, '_pool') and self._pool:
            return

        required_keys = ['host', 'port', 'user', 'password', 'database']
        if not all(key in config for key in required_keys):
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            # Rows are returned as dicts everywhere in this class.
            'cursorclass': cursors.DictCursor,
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl'),
        }

        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")

        self.pool_size = config.get('max_connections', 5)
        self._pool = self._create_pool()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _quote_identifier(name) -> str:
        """Return ``name`` as a backtick-quoted MySQL identifier.

        Embedded backticks are doubled, so reserved words, spaces and
        punctuation in table/column names cannot break the statement.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _to_native(value):
        """Convert a pandas/numpy scalar into a plain Python value.

        Missing values (None, NaN, NaT) become ``None`` (SQL NULL), pandas and
        numpy datetimes/timedeltas become ``datetime``/``timedelta`` and other
        numpy scalars are unboxed with ``.item()``.  Non-scalars are returned
        unchanged.
        """
        if not pd.api.types.is_scalar(value):
            return value
        if pd.isna(value):
            return None
        if isinstance(value, (pd.Timestamp, np.datetime64)):
            return pd.Timestamp(value).to_pydatetime()
        if isinstance(value, (pd.Timedelta, np.timedelta64)):
            return pd.Timedelta(value).to_pytimedelta()
        if isinstance(value, np.generic):
            return value.item()
        return value

    def _create_pool(self) -> "PooledDB":
        """Build the ``PooledDB`` pool; log and re-raise on failure."""
        try:
            def connect():
                conn = pymysql.connect(**self.config)
                conn.threadsafety = 1  # explicitly tag the thread-safety level
                return conn

            pool = PooledDB(
                creator=connect,
                mincached=1,
                maxcached=3,
                maxconnections=self.pool_size,
                blocking=True,
                ping=1,
            )
            self.log.info("Connection pool created")
            return pool
        except Exception as e:
            self.log.critical("Failed to create connection pool",
                              error=str(e),
                              exc_info=True)
            raise

    def get_connection(self) -> "pymysql.connections.Connection":
        """
        Take a connection from the pool.

        Returns:
            The pooled connection; calling ``close()`` on it gives it back to
            the pool.

        Raises:
            Exception: whatever the pool/driver raised if no connection could
                be obtained (after retries on Windows time-outs).
        """
        try:
            conn = self._pool.connection()
            # macOS with SSL: ping so a stale connection is re-established.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                try:
                    conn.ping(reconnect=True)
                except Exception:
                    conn.close()  # do not leak the pooled connection
                    raise
            self.log.trace("Database connection obtained")
            return conn
        except Exception as e:
            error_msg = str(e)
            # Windows: connect time-outs are retried.
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows connection timeout, retrying...")
                return self._retry_connection()
            self.log.error("Connection failed",
                           error=error_msg,
                           exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> "pymysql.connections.Connection":
        """Retry taking a pooled connection, sleeping 1 s between attempts.

        Raises:
            ValueError: if ``max_retries`` is smaller than 1.
            Exception: the error of the last failed attempt.
        """
        import time

        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        for attempt in range(max_retries):
            try:
                conn = self._pool.connection()
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)
            else:
                self.log.info(f"Connection established after {attempt + 1} attempts")
                return conn

    # ------------------------------------------------------------------
    # Query / bulk operations
    # ------------------------------------------------------------------
    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """
        Run a query and return the result as a DataFrame.

        Args:
            sql (str): SQL query.
            params (Union[tuple, dict, None]): query parameters.
            parse_dates (Union[List[str], bool]): forwarded to ``pandas.read_sql``.

        Returns:
            pd.DataFrame: the query result.

        Raises:
            Exception: re-raised after logging if the query fails.
        """
        try:
            self.log.debug("Executing SQL query", sql=sql)
            with self.get_connection() as conn:
                # Linux/macOS: allow a longer idle timeout for this session.
                if platform.system() != 'Windows':
                    with conn.cursor() as cursor:
                        cursor.execute("SET SESSION wait_timeout=600")
                df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)
            self.log.info("Query executed successfully", rows=len(df))
            return df
        except Exception as e:
            self.log.error("SQL query failed",
                           sql=sql,
                           params=params,
                           error=str(e),
                           exc_info=True)
            raise

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False) -> int:
        """
        Insert the rows of a DataFrame into a table via ``DataFrame.to_sql``.

        Args:
            table_name (str): target table.
            df (pd.DataFrame): rows to insert.
            chunk_size (int): rows per ``to_sql`` call (must be >= 1).
            replace (bool): when True the first chunk is written with
                ``if_exists='replace'``; all later chunks are appended.

        Returns:
            int: number of rows written (0 for an empty DataFrame).

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1.
            Exception: re-raised after logging if the insert fails.
        """
        if df.empty:
            self.log.warning("Attempted to insert empty DataFrame", table=table_name)
            return 0
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")

        self.log.debug("Preparing to insert DataFrame",
                       table=table_name,
                       rows=len(df),
                       chunk_size=chunk_size)
        try:
            # A throw-away SQLAlchemy engine wraps ONE pooled connection, so no
            # second connection pool is created.
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool

            method = 'replace' if replace else 'append'
            total_rows = 0
            charset = self.config.get('charset', 'utf8mb4')
            stringify_datetimes = platform.system() == 'Darwin'

            conn = self.get_connection()
            engine = None
            try:
                # The pooled wrapper does not expose character_set_name(),
                # which the SQLAlchemy dialect asks for.
                if not hasattr(conn, 'character_set_name'):
                    conn.character_set_name = lambda: charset
                engine = create_engine(
                    "mysql+pymysql://",
                    creator=lambda: conn,
                    poolclass=StaticPool,  # static pool: reuse the one connection
                    connect_args={
                        'charset': charset,
                        'autocommit': True
                    }
                )
                for start in range(0, len(df), chunk_size):
                    chunk = df.iloc[start:start + chunk_size]
                    # macOS: datetimes are sent as formatted strings.
                    if stringify_datetimes:
                        # Copy first: assigning into an iloc slice would target
                        # a view of the caller's DataFrame.
                        chunk = chunk.copy()
                        for col in chunk.select_dtypes(include=['datetime64']).columns:
                            chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
                    chunk.to_sql(
                        table_name,
                        engine,
                        if_exists=method,
                        index=False,
                        method='multi'
                    )
                    total_rows += len(chunk)
                    method = 'append'  # only the first chunk may replace
                    self.log.trace(f"Inserted chunk {start // chunk_size + 1}",
                                   rows=len(chunk),
                                   total_inserted=total_rows)
                self.log.info("Data inserted successfully",
                              table=table_name,
                              total_rows=total_rows)
                return total_rows
            finally:
                # Always release the engine and hand the connection back.
                if engine is not None:
                    engine.dispose()
                conn.close()
        except Exception as e:
            self.log.error("Data insertion failed",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """
        Update table rows from a DataFrame, matching on ``key_columns``.

        Every DataFrame column that also exists in the table and is not a key
        column is written; the key columns form the WHERE clause.  All rows
        are sent in one transaction.

        Args:
            table_name (str): target table.
            df (pd.DataFrame): rows carrying the new values.
            key_columns (Union[str, List[str]]): column(s) identifying a row.

        Returns:
            int: row count reported by the driver (0 when the DataFrame is
            empty or contains no updatable column).

        Raises:
            KeyError: if a key column is not present in ``df``.
            ValueError: if ``key_columns`` is empty.
            Exception: re-raised after logging if the update fails.
        """
        if df.empty:
            self.log.warning("Attempted to update with empty DataFrame", table=table_name)
            return 0
        self.log.debug("Preparing to update table from DataFrame",
                       table=table_name,
                       key_columns=key_columns,
                       rows=len(df))
        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]
            else:
                key_columns = list(key_columns)
            if not key_columns:
                raise ValueError("key_columns must not be empty")
            missing_keys = [col for col in key_columns if col not in df.columns]
            if missing_keys:
                raise KeyError(f"key columns missing from DataFrame: {missing_keys}")

            # Read the table layout BEFORE opening the transaction so the
            # transaction connection is not held while a second connection is
            # requested from the blocking pool.
            table_info = self._get_table_info(table_name)
            set_columns = [col for col in df.columns
                           if col in table_info and col not in key_columns]
            if not set_columns:
                # An empty SET clause would be invalid SQL.
                self.log.warning("No updatable columns found", table=table_name)
                return 0

            quote = self._quote_identifier
            set_clause = ', '.join(f"{quote(col)}=%s" for col in set_columns)
            where_clause = ' AND '.join(f"{quote(col)}=%s" for col in key_columns)
            update_sql = f"UPDATE {quote(table_name)} SET {set_clause} WHERE {where_clause}"
            self.log.trace("Generated update SQL", sql=update_sql)

            # SET values first, then WHERE values, each converted to a type
            # the driver can escape (NaN/NaT -> NULL).
            ordered_columns = set_columns + key_columns
            update_data = [
                tuple(self._to_native(value) for value in record)
                for record in df[ordered_columns].itertuples(index=False, name=None)
            ]

            conn = self.begin_transaction()
            try:
                with conn.cursor() as cursor:
                    cursor.executemany(update_sql, update_data)
                    total_updated = cursor.rowcount
            except Exception:
                self.rollback_transaction(conn)
                raise
            self.commit_transaction(conn)
            self.log.info("Data updated successfully",
                          table=table_name,
                          rows_updated=total_updated)
            return total_updated
        except Exception as e:
            self.log.error("Data update failed",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def _get_table_info(self, table_name: str) -> Dict[str, str]:
        """
        Read the column layout of a table in the configured schema.

        Args:
            table_name (str): table name.

        Returns:
            Dict[str, str]: column name -> data type (empty if the table is
            not found).

        Raises:
            Exception: re-raised after logging if the lookup fails.
        """
        # Explicit aliases pin the result keys: without them the server may
        # report the headings as COLUMN_NAME / DATA_TYPE.
        sql = """
            SELECT COLUMN_NAME AS column_name, DATA_TYPE AS data_type
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
        """
        params = (self.config['database'], table_name)
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    result = cursor.fetchall()
            return {row['column_name']: row['data_type'] for row in result}
        except Exception as e:
            self.log.error("Failed to get table info",
                           table=table_name,
                           error=str(e))
            raise

    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """
        Infer a MySQL column type for every DataFrame column.

        Mapping: categorical -> VARCHAR(255), boolean -> TINYINT(1),
        any integer -> BIGINT, any float -> DOUBLE, any datetime64 ->
        DATETIME, everything else -> TEXT.

        Args:
            df (pd.DataFrame): input frame.

        Returns:
            Dict[str, str]: column label -> SQL type.
        """
        dtype_checks = pd.api.types

        def infer(dtype) -> str:
            # Classify by dtype family instead of the exact dtype string, so
            # int32, nullable Int64/boolean and non-ns datetimes are covered.
            if isinstance(dtype, pd.CategoricalDtype):
                return 'VARCHAR(255)'
            if dtype_checks.is_bool_dtype(dtype):
                return 'TINYINT(1)'
            if dtype_checks.is_integer_dtype(dtype):
                return 'BIGINT'
            if dtype_checks.is_float_dtype(dtype):
                return 'DOUBLE'
            if dtype_checks.is_datetime64_any_dtype(dtype):
                return 'DATETIME'
            return 'TEXT'

        sql_types = {col: infer(dtype) for col, dtype in df.dtypes.items()}
        self.log.debug("Mapped DataFrame types to SQL types",
                       mappings=sql_types)
        return sql_types

    def _build_create_table_sql(self, table_name: str, df: pd.DataFrame,
                                primary_key: Union[str, List[str], None] = None) -> str:
        """Return the CREATE TABLE statement for ``df``'s inferred schema."""
        sql_types = self.df_to_sql_type(df)
        quote = self._quote_identifier
        definitions = [f"{quote(col)} {sql_type}" for col, sql_type in sql_types.items()]
        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            # Key columns that are not part of the frame are ignored.
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                pk_list = ', '.join(quote(col) for col in pk_columns)
                definitions.append(f"PRIMARY KEY ({pk_list})")
                self.log.trace("Set primary key",
                               table=table_name,
                               primary_key=pk_columns)
        # Joined outside the f-string: a backslash inside a replacement field
        # is a syntax error before Python 3.12.
        body = ",\n  ".join(definitions)
        return f"CREATE TABLE {quote(table_name)} (\n  {body}\n)"

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """
        Create a table whose columns mirror the DataFrame.

        Args:
            table_name (str): table name.
            df (pd.DataFrame): frame providing column names and dtypes.
            primary_key (Union[str, List[str], None]): primary-key column(s).

        Returns:
            bool: True if the table was created; False if it already exists or
            creation failed (the failure is logged, not raised).
        """
        if self.table_exists(table_name):
            self.log.warning("Table already exists", table=table_name)
            return False
        self.log.debug("Creating new table from DataFrame schema",
                       table=table_name,
                       columns=list(df.columns))
        try:
            create_sql = self._build_create_table_sql(table_name, df, primary_key)
            self.execute_sql(create_sql)
            self.log.info("Table created successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to create table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """
        Execute one SQL statement on a pooled connection.

        Args:
            sql (str): SQL statement.
            params (Union[tuple, dict, None]): statement parameters.
            fetch (bool): when True, fetch and return all rows.

        Returns:
            Union[int, List[Dict[str, Any]]]:
                - ``fetch=False``: the cursor's row count
                - ``fetch=True``: the fetched rows

        Raises:
            Exception: re-raised after logging if execution fails.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()
            # Linux/macOS: raise the statement time limit for this session.
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")
            cursor.execute(sql, params)
            if fetch:
                result = cursor.fetchall()
                self.log.debug("Query executed", rows=len(result))
                return result
            affected_rows = cursor.rowcount
            self.log.debug("Update executed", affected_rows=affected_rows)
            return affected_rows
        except Exception as e:
            self.log.error("SQL execution failed",
                           sql=sql,
                           params=params,
                           error=str(e),
                           exc_info=True)
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    # ------------------------------------------------------------------
    # Transactions
    # ------------------------------------------------------------------
    def begin_transaction(self) -> "pymysql.connections.Connection":
        """
        Take a connection and start an explicit transaction on it.

        The caller must finish with ``commit_transaction`` or
        ``rollback_transaction``; both return the connection to the pool.

        Raises:
            Exception: re-raised after logging if the transaction cannot start.
        """
        conn = None
        try:
            conn = self.get_connection()
            # macOS: use READ COMMITTED.  Set before BEGIN because a session
            # level change only applies to transactions started afterwards.
            if platform.system() == 'Darwin':
                with conn.cursor() as cursor:
                    cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
            # begin() is used instead of switching autocommit off: the pooled
            # wrapper does not forward raw driver attributes (see the
            # character_set_name workaround in insert_from_df), and a pooled
            # connection must not go back to the pool with autocommit changed.
            conn.begin()
            self.log.debug("Transaction started")
            return conn
        except Exception as e:
            if conn is not None:
                conn.close()  # give the connection back instead of leaking it
            self.log.error("Begin transaction_failed", error=str(e))
            raise

    def commit_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Commit ``conn`` and return it to the pool; re-raise commit errors."""
        try:
            conn.commit()
            self.log.debug("Transaction committed")
        except Exception as e:
            self.log.error("Commit failed", error=str(e))
            raise
        finally:
            conn.close()

    def rollback_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Roll back ``conn`` and return it to the pool.

        Rollback errors are logged but not raised, so the error that caused
        the rollback stays visible to the caller.
        """
        try:
            conn.rollback()
            self.log.warning("Transaction rolled back")
        except Exception as e:
            self.log.error("Rollback failed", error=str(e))
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Schema helpers / diagnostics
    # ------------------------------------------------------------------
    def table_exists(self, table_name: str) -> bool:
        """Return True if ``table_name`` exists in the configured schema.

        Lookup errors are logged and reported as False.
        """
        sql = """
            SELECT COUNT(*) as count
            FROM `information_schema`.`tables`
            WHERE `table_schema` = %s AND `table_name` = %s
        """
        params = (self.config['database'], table_name)
        try:
            result = self.execute_sql(sql, params, fetch=True)
            exists = result[0]['count'] > 0
            self.log.debug("Checked table existence",
                           table=table_name,
                           exists=exists)
            return exists
        except Exception as e:
            self.log.warning("Table existence check failed",
                             table=table_name,
                             error=str(e))
            return False

    def drop_table(self, table_name: str) -> bool:
        """Drop a table.

        Returns:
            bool: True if the table was dropped; False if it does not exist or
            the drop failed (the failure is logged, not raised).
        """
        if not self.table_exists(table_name):
            self.log.warning("Table does not exist", table=table_name)
            return False
        try:
            self.execute_sql(f"DROP TABLE {self._quote_identifier(table_name)}")
            self.log.info("Table dropped successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to drop table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def get_pool_status(self) -> Dict[str, int]:
        """Return pool counters: ``max``, ``active``, ``idle`` and ``shared``.

        The values are read from private ``PooledDB`` attributes.
        """
        pool = self._pool
        return {
            'max': pool._maxconnections,
            'active': pool._connections,
            'idle': len(pool._idle_cache),
            # The shared cache only exists when shared connections are enabled.
            'shared': len(getattr(pool, '_shared_cache', ())),
        }

    def validate_connection(self) -> bool:
        """Return True if ``SELECT 1`` succeeds on a pooled connection."""
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1 AS ok")
                    row = cursor.fetchone()
            # The pool is configured with DictCursor, so rows are dicts; a
            # tuple row is still accepted.
            value = row['ok'] if isinstance(row, dict) else row[0]
            return value == 1
        except Exception:
            return False

    def __del__(self):
        """Close the pool when the object is collected (best effort)."""
        pool = getattr(self, '_pool', None)
        if pool is None:
            return
        try:
            pool.close()
            self.log.info("Connection pool closed")
        except Exception as e:
            try:
                self.log.error("Failed to close pool", error=str(e))
            except Exception:
                # The logger may already be gone during interpreter shutdown.
                pass
# Platform-specific default configuration
def get_default_config() -> dict:
    """Return the default connection settings for the current platform.

    Every platform shares the same host/credentials block; they differ in
    the timeouts: Windows uses 10/30/30 seconds (connect/read/write), all
    other platforms 15/60/60.  On macOS an ``ssl`` entry pointing at the
    OpenSSL CA bundle is added, but only when that file exists.

    NOTE(review): the credentials below are hard-coded development defaults;
    real deployments should pass their own config to ``MySQLAgent``.

    Returns:
        dict: a new config dict suitable for ``MySQLAgent(config)``.
    """
    current_platform = platform.system()
    base_config = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence',
        'max_connections': 5,
    }

    if current_platform == 'Windows':
        return {
            **base_config,
            'connect_timeout': 10,
            'read_timeout': 30,
            'write_timeout': 30,
        }

    # macOS, Linux and every other platform share the longer timeouts.
    config = {
        **base_config,
        'connect_timeout': 15,
        'read_timeout': 60,
        'write_timeout': 60,
    }
    if current_platform == 'Darwin':
        ca_bundle = '/usr/local/etc/openssl/cert.pem'
        # Pointing the driver at a CA file that does not exist makes every
        # connection attempt fail, so only enable SSL when the bundle is there.
        if os.path.isfile(ca_bundle):
            config['ssl'] = {'ca': ca_bundle}
    return config
if __name__ == "__main__":
    # Usage example: build the agent from the platform defaults and run a
    # quick smoke test against the database.
    agent = MySQLAgent(get_default_config())

    if not agent.validate_connection():
        print("Failed to connect to database")
    else:
        print("Database connection successful")

        # Ask the server for its version.
        version_frame = agent.query_to_df("SELECT VERSION() as version")
        server_version = version_frame['version'].iloc[0]
        print(f"Database version: {server_version}")

        # Show the current connection-pool counters.
        print("Connection pool status:", agent.get_pool_status())