Files
intelligence_system/config/settings.py
T
2025-08-06 16:24:17 +08:00

409 lines
13 KiB
Python

import os
import sys
import platform
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
# 导入您的日志系统
from utils.logger import log as logger
class MySQLAgent:
    """Singleton helper around a pooled PyMySQL connection.

    The class keeps one shared instance per process, builds a DBUtils
    ``PooledDB`` from the supplied configuration merged with per-platform
    defaults (Windows / macOS / Linux), and offers helpers for DataFrame
    queries, bulk inserts, raw SQL execution and manual transactions.
    """

    _instance = None
    _lock = threading.Lock()

    # Per-platform connection defaults, keyed by ``platform.system()``.
    # ``socket_timeout`` is not a PyMySQL connect option; it is translated
    # to ``read_timeout`` / ``write_timeout`` by ``_platform_connect_options``.
    PLATFORM_CONFIG = {
        'Windows': {
            'socket_timeout': 30,
            'connect_timeout': 10,
            'ssl': None
        },
        'Darwin': {  # macOS
            'socket_timeout': 60,
            'connect_timeout': 15,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
        },
        'Linux': {
            'socket_timeout': 60,
            'connect_timeout': 15,
            'ssl': None
        }
    }

    # Keys a caller-supplied config may use to override platform defaults.
    _TIMEOUT_KEYS = ('connect_timeout', 'read_timeout', 'write_timeout')

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it under the lock on first use."""
        if not cls._instance:
            with cls._lock:
                # Re-check inside the lock: another thread may have won the race.
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: Optional[dict] = None):
        """Build the connection pool the first time the singleton is created.

        Args:
            config: Connection settings (``host``, ``port``, ``user``,
                ``password``, ``database``, ``charset``, ``max_connections``
                and optionally ``connect_timeout`` / ``read_timeout`` /
                ``write_timeout``). When omitted or empty,
                ``config.settings.DATABASE_CONFIG`` is used.
        """
        # The singleton is re-initialised on every MySQLAgent(...) call;
        # keep the existing pool instead of building a second one.
        if getattr(self, '_pool', None):
            return
        if not config:
            from config.settings import DATABASE_CONFIG
            config = DATABASE_CONFIG

        current_platform = platform.system()
        # Bind the logger before anything else: _create_pool logs through
        # self.logger, so it must exist before the pool is built.
        self.logger = logger.bind(module=f"MySQLAgent({current_platform})")

        self.config = {
            'host': config.get('host', 'localhost'),
            'port': config.get('port', 3306),
            'user': config.get('user', 'root'),
            'password': config.get('password', ''),
            'database': config.get('database', 'intelligence_system'),
            'charset': config.get('charset', 'utf8mb4'),
            'cursorclass': cursors.DictCursor,
            'autocommit': True,
            # Platform defaults, already expressed as valid connect options.
            **self._platform_connect_options(current_platform)
        }
        # Explicit timeouts from the caller win over the platform defaults.
        for key in self._TIMEOUT_KEYS:
            if config.get(key) is not None:
                self.config[key] = config[key]

        if current_platform == 'Windows':
            # Windows normally runs without an SSL configuration.
            self.config['ssl'] = None
        elif current_platform == 'Darwin':
            # Fall back to a plain connection when the CA bundle is missing.
            ssl_options = self.config.get('ssl')
            ca_path = ssl_options.get('ca') if ssl_options else None
            if ca_path and not os.path.exists(ca_path):
                self.config['ssl'] = None
                self.logger.warning("macOS SSL certificate not found, disabling SSL")

        self.pool_size = config.get('max_connections', 5)
        self._pool = self._create_pool()

    @classmethod
    def _platform_connect_options(cls, platform_name: str) -> Dict[str, Any]:
        """Return the platform defaults as valid connection keyword arguments.

        ``socket_timeout`` from ``PLATFORM_CONFIG`` is mapped onto
        ``read_timeout`` and ``write_timeout``. Unknown platforms yield an
        empty dict.
        """
        options = dict(cls.PLATFORM_CONFIG.get(platform_name, {}))
        socket_timeout = options.pop('socket_timeout', None)
        if socket_timeout is not None:
            options.setdefault('read_timeout', socket_timeout)
            options.setdefault('write_timeout', socket_timeout)
        return options

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Wrap a table or column name in backticks, doubling embedded ones."""
        if not name:
            raise ValueError("identifier must be a non-empty string")
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _to_db_value(value: Any) -> Any:
        """Convert a pandas/NumPy scalar into a plain Python value for the driver.

        Missing values (NaN, NaT, ``pd.NA``, None) become ``None`` so they are
        sent as SQL NULL.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        if isinstance(value, pd.Timestamp):
            return value.to_pydatetime()
        if isinstance(value, pd.Timedelta):
            return value.to_pytimedelta()
        if isinstance(value, np.generic):
            return value.item()
        return value

    @staticmethod
    def _rows_to_frame(rows, columns,
                       parse_dates: Union[List[str], bool, dict] = True) -> pd.DataFrame:
        """Build a DataFrame from fetched rows.

        Args:
            rows: Sequence of dict rows (DictCursor) or tuple rows.
            columns: Column names taken from the cursor description; used for
                tuple rows and for empty result sets.
            parse_dates: List of column names to convert with
                ``pd.to_datetime``, or a dict mapping column name to a format
                string / ``to_datetime`` keyword dict. Booleans convert nothing.
        """
        rows = list(rows)
        if rows and isinstance(rows[0], dict):
            frame = pd.DataFrame(rows)
        else:
            frame = pd.DataFrame(rows, columns=list(columns))

        if isinstance(parse_dates, dict):
            for column, spec in parse_dates.items():
                if column not in frame.columns:
                    continue
                if isinstance(spec, dict):
                    frame[column] = pd.to_datetime(frame[column], **spec)
                else:
                    frame[column] = pd.to_datetime(frame[column], format=spec)
        elif isinstance(parse_dates, (list, tuple)):
            for column in parse_dates:
                if column in frame.columns:
                    frame[column] = pd.to_datetime(frame[column])
        return frame

    def _release_connection(self, conn) -> None:
        """Hand a transactional connection back to the pool in autocommit mode.

        Pending work is rolled back first so that switching autocommit back
        on cannot implicitly commit it. Without the reset, the next borrower
        of this pooled connection would run with autocommit disabled.
        """
        try:
            conn.rollback()
            conn.autocommit(True)
        except Exception as e:
            self.logger.warning("Failed to restore autocommit", error=str(e))
        finally:
            conn.close()

    def _create_pool(self) -> "PooledDB":
        """Create the connection pool from ``self.config`` and ``self.pool_size``.

        Raises:
            Exception: Whatever ``PooledDB`` raises; it is logged and re-raised.
        """
        try:
            pool_config = {
                'creator': pymysql,
                'maxconnections': self.pool_size,
                'mincached': 1,
                'maxcached': 3,
                'blocking': True,
                'ping': 1,  # check connection health
                **self.config
            }
            # Connection pinging is disabled on Windows, where it was
            # observed to be unreliable.
            if platform.system() == 'Windows':
                pool_config['ping'] = 0
            pool = PooledDB(**pool_config)
            self.logger.info(f"Connection pool created for {platform.system()}")
            return pool
        except Exception as e:
            self.logger.critical("Failed to create connection pool",
                                 error=str(e),
                                 exc_info=True)
            raise

    def _handle_path(self, path: str) -> str:
        """Return ``path`` with forward slashes turned into backslashes on Windows."""
        if platform.system() == 'Windows':
            return path.replace('/', '\\')
        return path

    def get_connection(self) -> "pymysql.connections.Connection":
        """Borrow a connection from the pool.

        Returns:
            A pooled connection; call ``close()`` to return it to the pool.

        Raises:
            Exception: The underlying error when no connection can be
                obtained (after retries on Windows timeouts).
        """
        try:
            conn = self._pool.connection()
            # On macOS with SSL enabled, verify the link before handing it out.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)
            self.logger.trace("Connection obtained")
            return conn
        except Exception as e:
            error_msg = str(e)
            # Timeouts on Windows get a few more attempts before giving up.
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.logger.warning("Windows connection timeout, retrying...")
                return self._retry_connection()
            self.logger.error("Connection failed",
                              error=error_msg,
                              exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> "pymysql.connections.Connection":
        """Retry borrowing a connection, sleeping one second between attempts.

        Args:
            max_retries: Number of attempts; must be at least 1.

        Raises:
            ValueError: If ``max_retries`` is smaller than 1.
            Exception: The error of the final failed attempt.
        """
        import time

        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        for attempt in range(max_retries):
            try:
                conn = self._pool.connection()
                self.logger.info(f"Connection established after {attempt+1} attempts")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)

    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """Run a query and return the result set as a DataFrame.

        Rows are fetched through the driver and converted here, because the
        pool hands out dict-cursor DBAPI connections, which ``pd.read_sql``
        does not support.

        Args:
            sql: SQL statement using the driver's ``%s`` / ``%(name)s`` placeholders.
            params: Values for the placeholders.
            parse_dates: Column names to convert with ``pd.to_datetime``;
                booleans leave the columns as returned by the driver.

        Returns:
            The query result; empty (with column names) when no rows match.

        Raises:
            Exception: Any connection or execution error, logged and re-raised.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                try:
                    # Off Windows, lengthen the session's wait_timeout.
                    if platform.system() != 'Windows':
                        cursor.execute("SET SESSION wait_timeout=600")
                    cursor.execute(sql, params)
                    rows = cursor.fetchall()
                    columns = [column[0] for column in (cursor.description or ())]
                finally:
                    cursor.close()
            df = self._rows_to_frame(rows, columns, parse_dates)
            self.logger.info("Query executed", rows=len(df))
            return df
        except Exception as e:
            self.logger.error("Query failed",
                              sql=sql,
                              params=params,
                              error=str(e),
                              exc_info=True)
            raise

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False) -> int:
        """Insert the rows of a DataFrame into an existing table.

        All chunks are written inside one transaction through parameterized
        ``INSERT`` statements, so a failure leaves the table unchanged.

        Args:
            table_name: Target table; it must already exist and have columns
                named like the DataFrame's columns.
            df: Rows to insert. Missing values are stored as NULL.
            chunk_size: Maximum rows per batch (capped at 500 on Windows).
            replace: When True, existing rows are deleted before inserting.

        Returns:
            Number of rows sent to the database (0 for an empty DataFrame).

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
            Exception: Any database error, logged and re-raised after rollback.
        """
        if df.empty:
            self.logger.warning("Empty DataFrame", table=table_name)
            return 0
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        # Smaller batches on Windows.
        if platform.system() == 'Windows':
            chunk_size = min(chunk_size, 500)

        conn = None
        cursor = None
        try:
            table_sql = self._quote_identifier(table_name)
            columns_sql = ', '.join(self._quote_identifier(str(c)) for c in df.columns)
            placeholders = ', '.join(['%s'] * len(df.columns))
            # The driver %-formats this statement, so literal percent signs
            # inside identifiers have to be doubled.
            insert_sql = "INSERT INTO {} ({}) VALUES ({})".format(
                table_sql.replace('%', '%%'),
                columns_sql.replace('%', '%%'),
                placeholders,
            )

            conn = self.get_connection()
            conn.autocommit(False)
            cursor = conn.cursor()
            if replace:
                cursor.execute(f"DELETE FROM {table_sql}")

            total_rows = 0
            for start in range(0, len(df), chunk_size):
                chunk = df.iloc[start:start + chunk_size]
                rows = [
                    tuple(self._to_db_value(value) for value in row)
                    for row in chunk.itertuples(index=False, name=None)
                ]
                cursor.executemany(insert_sql, rows)
                total_rows += len(rows)
            conn.commit()
            self.logger.info("Data inserted", table=table_name, rows=total_rows)
            return total_rows
        except Exception as e:
            self.logger.error("Insert failed",
                              table=table_name,
                              error=str(e),
                              exc_info=True)
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                # Rolls back anything uncommitted and restores autocommit.
                self._release_connection(conn)

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """Execute one SQL statement on an autocommit connection.

        Args:
            sql: SQL statement using the driver's placeholders.
            params: Values for the placeholders.
            fetch: When True return the fetched rows, otherwise the row count.

        Returns:
            The fetched rows when ``fetch`` is True, else ``cursor.rowcount``.

        Raises:
            Exception: Any connection or execution error, logged and re-raised.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()
            # Off Windows, allow statements to run for up to 600 seconds.
            # NOTE(review): max_execution_time is a MySQL server variable -
            # confirm the target server supports it.
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")
            cursor.execute(sql, params)
            if fetch:
                result = cursor.fetchall()
                self.logger.debug("Query executed", rows=len(result))
                return result
            affected_rows = cursor.rowcount
            self.logger.debug("Update executed", affected_rows=affected_rows)
            return affected_rows
        except Exception as e:
            self.logger.error("SQL execution failed",
                              sql=sql,
                              params=params,
                              error=str(e),
                              exc_info=True)
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def begin_transaction(self) -> "pymysql.connections.Connection":
        """Borrow a connection with autocommit disabled.

        The caller must finish with ``commit_transaction`` or
        ``rollback_transaction``, which return the connection to the pool.

        Raises:
            Exception: Any error while preparing the connection; the
                connection is released before re-raising.
        """
        conn = None
        try:
            conn = self.get_connection()
            conn.autocommit(False)
            # On macOS the session is switched to READ COMMITTED.
            if platform.system() == 'Darwin':
                cursor = conn.cursor()
                try:
                    cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
                finally:
                    cursor.close()
            self.logger.debug("Transaction started")
            return conn
        except Exception as e:
            self.logger.error("Begin transaction failed", error=str(e))
            if conn is not None:
                # Do not leak the half-prepared connection.
                self._release_connection(conn)
            raise

    def commit_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Commit the transaction on ``conn`` and return it to the pool.

        Raises:
            Exception: The commit error, logged and re-raised.
        """
        try:
            conn.commit()
            self.logger.debug("Transaction committed")
        except Exception as e:
            self.logger.error("Commit failed", error=str(e))
            raise
        finally:
            self._release_connection(conn)

    def rollback_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Roll back the transaction on ``conn`` and return it to the pool.

        A failing rollback is logged but not raised.
        """
        try:
            conn.rollback()
            self.logger.warning("Transaction rolled back")
        except Exception as e:
            self.logger.error("Rollback failed", error=str(e))
        finally:
            self._release_connection(conn)

    def __del__(self):
        """Close the pool, if one was created, when the instance is collected."""
        pool = getattr(self, '_pool', None)
        if pool is None:
            return
        try:
            pool.close()
            self.logger.info("Connection pool closed")
        except Exception as e:
            self.logger.error("Failed to close pool", error=str(e))
# 平台特定的默认配置
def get_default_config():
    """Return default connection settings suited to the current OS.

    Every platform shares the same host, port, credentials, database name
    and pool size. Windows gets shorter timeouts; macOS additionally
    carries an SSL CA bundle path; any other system uses the Linux values.

    Returns:
        dict: A freshly built configuration dictionary.
    """
    settings = dict(
        host='localhost',
        port=3306,
        user='root',
        password='',
        database='intelligence_system',
        max_connections=5,
    )
    system_name = platform.system()

    if system_name == 'Windows':
        settings.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return settings

    # macOS, Linux and everything else share the longer timeouts.
    settings.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
# 使用示例
if __name__ == "__main__":
    # Build the agent from the defaults that match the host platform.
    agent = MySQLAgent(get_default_config())
    # Smoke test: ask the server for its version and report the host OS.
    try:
        version_frame = agent.query_to_df("SELECT VERSION() as version")
        server_version = version_frame['version'].iloc[0]
        print(f"Database version: {server_version}")
        print(f"Running on: {platform.system()} {platform.release()}")
    except Exception as exc:
        print(f"Error: {exc!s}")