数据库操作
This commit is contained in:
@@ -1,255 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据库存储模块
|
||||
功能:
|
||||
1. 统一数据库接口(SQLite/MySQL/PostgreSQL)
|
||||
2. 自动处理多平台路径问题
|
||||
3. 连接池管理
|
||||
4. 数据加密存储
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Dict, Any, List
|
||||
from threading import Lock
|
||||
import logging
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
# 类型别名
|
||||
QueryParams = Union[tuple, Dict[str, Any]]
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库统一管理类"""
|
||||
|
||||
def __init__(self, db_config: Dict[str, Any]):
|
||||
"""
|
||||
初始化数据库连接
|
||||
|
||||
参数:
|
||||
db_config: 配置字典,包含:
|
||||
- type: 'sqlite'|'mysql'|'postgresql'
|
||||
- database: 数据库名/路径
|
||||
- [可选] host, port, user, password
|
||||
"""
|
||||
self.config = db_config
|
||||
self._lock = Lock()
|
||||
self._connection_pool = {}
|
||||
self._setup_crypto()
|
||||
|
||||
# 自动创建SQLite目录
|
||||
if db_config['type'] == 'sqlite':
|
||||
self._ensure_sqlite_dir()
|
||||
|
||||
def _ensure_sqlite_dir(self):
|
||||
"""确保SQLite数据库目录存在"""
|
||||
db_path = Path(self.config['database'])
|
||||
if not db_path.parent.exists():
|
||||
try:
|
||||
db_path.parent.mkdir(parents=True, mode=0o755)
|
||||
except Exception as e:
|
||||
logging.error(f"创建数据库目录失败: {str(e)}")
|
||||
|
||||
def _setup_crypto(self):
|
||||
"""初始化加密模块"""
|
||||
key_file = Path(os.path.expanduser("~/.db_encryption.key"))
|
||||
if key_file.exists():
|
||||
with open(key_file, 'rb') as f:
|
||||
self._fernet = Fernet(f.read())
|
||||
else:
|
||||
self._fernet = Fernet.generate_key()
|
||||
with open(key_file, 'wb') as f:
|
||||
f.write(self._fernet)
|
||||
key_file.chmod(0o600) # 仅限当前用户读写
|
||||
|
||||
def get_connection(self, reuse=True):
|
||||
"""
|
||||
获取数据库连接(线程安全)
|
||||
|
||||
参数:
|
||||
reuse: 是否复用现有连接(默认True)
|
||||
"""
|
||||
thread_id = threading.get_ident()
|
||||
|
||||
with self._lock:
|
||||
if reuse and thread_id in self._connection_pool:
|
||||
conn = self._connection_pool[thread_id]
|
||||
try:
|
||||
# 检查连接是否有效
|
||||
conn.execute("SELECT 1")
|
||||
return conn
|
||||
except:
|
||||
del self._connection_pool[thread_id]
|
||||
|
||||
# 创建新连接
|
||||
if self.config['type'] == 'sqlite':
|
||||
conn = self._create_sqlite_connection()
|
||||
elif self.config['type'] == 'mysql':
|
||||
conn = self._create_mysql_connection()
|
||||
elif self.config['type'] == 'postgresql':
|
||||
conn = self._create_pg_connection()
|
||||
else:
|
||||
raise ValueError("不支持的数据库类型")
|
||||
|
||||
self._connection_pool[thread_id] = conn
|
||||
return conn
|
||||
|
||||
def _create_sqlite_connection(self) -> sqlite3.Connection:
|
||||
"""创建SQLite连接(兼容多平台路径)"""
|
||||
db_path = self.config['database']
|
||||
|
||||
# Windows路径处理
|
||||
if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')):
|
||||
db_path = os.path.abspath(db_path)
|
||||
|
||||
conn = sqlite3.connect(db_path, timeout=15)
|
||||
conn.execute("PRAGMA journal_mode=WAL") # 写前日志提升并发
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.row_factory = sqlite3.Row # 支持字典式访问
|
||||
return conn
|
||||
|
||||
def _create_mysql_connection(self):
|
||||
"""创建MySQL连接(需安装PyMySQL)"""
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=self.config.get('host', 'localhost'),
|
||||
port=self.config.get('port', 3306),
|
||||
user=self.config.get('user', 'root'),
|
||||
password=self.config.get('password', ''),
|
||||
database=self.config['database'],
|
||||
charset='utf8mb4',
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
|
||||
def _create_pg_connection(self):
|
||||
"""创建PostgreSQL连接(需安装psycopg2)"""
|
||||
import psycopg2
|
||||
return psycopg2.connect(
|
||||
host=self.config.get('host', 'localhost'),
|
||||
port=self.config.get('port', 5432),
|
||||
user=self.config.get('user', 'postgres'),
|
||||
password=self.config.get('password', ''),
|
||||
dbname=self.config['database']
|
||||
)
|
||||
|
||||
def execute(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[QueryParams] = None,
|
||||
return_lastrowid: bool = False
|
||||
) -> Union[int, None]:
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
with conn:
|
||||
cursor = conn.cursor()
|
||||
if params is not None:
|
||||
cursor.execute(query, params)
|
||||
else:
|
||||
cursor.execute(query) # ✅ 无参数时不要传 params
|
||||
return cursor.lastrowid if return_lastrowid else None
|
||||
finally:
|
||||
if not return_lastrowid:
|
||||
self._release_connection()
|
||||
|
||||
def query(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[QueryParams] = None,
|
||||
fetchall: bool = True
|
||||
) -> Union[List[Dict], Dict]:
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
if params is not None:
|
||||
cursor.execute(query, params)
|
||||
else:
|
||||
cursor.execute(query) # ✅ 无参数时不要传 None
|
||||
|
||||
if self.config['type'] == 'sqlite':
|
||||
result = cursor.fetchall()
|
||||
if not fetchall and result:
|
||||
return dict(result[0])
|
||||
return [dict(row) for row in result]
|
||||
else:
|
||||
return cursor.fetchall() if fetchall else cursor.fetchone()
|
||||
finally:
|
||||
self._release_connection()
|
||||
|
||||
def _release_connection(self):
|
||||
"""释放当前线程的连接(SQLite除外)"""
|
||||
if self.config['type'] != 'sqlite':
|
||||
thread_id = threading.get_ident()
|
||||
with self._lock:
|
||||
if thread_id in self._connection_pool:
|
||||
self._connection_pool[thread_id].close()
|
||||
del self._connection_pool[thread_id]
|
||||
|
||||
def encrypt_data(self, plaintext: str) -> str:
|
||||
"""加密敏感数据"""
|
||||
return self._fernet.encrypt(plaintext.encode()).decode()
|
||||
|
||||
def decrypt_data(self, ciphertext: str) -> str:
|
||||
"""解密数据"""
|
||||
return self._fernet.decrypt(ciphertext.encode()).decode()
|
||||
|
||||
def close_all(self):
|
||||
"""关闭所有数据库连接"""
|
||||
with self._lock:
|
||||
for conn in self._connection_pool.values():
|
||||
try:
|
||||
conn.close()
|
||||
except:
|
||||
pass
|
||||
self._connection_pool.clear()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close_all()
|
||||
|
||||
|
||||
# 全局SQLite实例(默认配置)
|
||||
def get_default_db() -> DatabaseManager:
|
||||
"""获取默认SQLite数据库(跨平台路径处理)"""
|
||||
system = platform.system().lower()
|
||||
if system == 'windows':
|
||||
db_path = os.path.join(os.getenv('APPDATA'), 'app_name/data.db')
|
||||
elif system == 'darwin':
|
||||
db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db')
|
||||
else:
|
||||
db_path = '/var/lib/app_name/data.db' if os.access('/var/lib', os.W_OK) \
|
||||
else os.path.expanduser('~/.local/share/app_name/data.db')
|
||||
|
||||
return DatabaseManager({
|
||||
'type': 'sqlite',
|
||||
'database': db_path
|
||||
})
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
with get_default_db() as db:
|
||||
# 创建测试表
|
||||
db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS test_table (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
secret TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# 插入加密数据
|
||||
secret = db.encrypt_data("新敏感信息")
|
||||
db.execute(
|
||||
"INSERT INTO test_table (name, secret) VALUES (?, ?)",
|
||||
("测试记录", secret)
|
||||
)
|
||||
|
||||
# 查询并解密
|
||||
row = db.query("SELECT * FROM test_table", fetchall=False)
|
||||
print(f"解密数据: {db.decrypt_data(row['name'])}")
|
||||
@@ -0,0 +1,662 @@
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
from pymysql import cursors
|
||||
from pymysql.err import MySQLError
|
||||
from dbutils.pooled_db import PooledDB
|
||||
from typing import Union, List, Dict, Any, Optional, Tuple
|
||||
import threading
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# 导入日志系统
|
||||
from utils.logger import log
|
||||
|
||||
|
||||
class MySQLAgent:
|
||||
"""
|
||||
全平台兼容的MySQL数据库操作类
|
||||
支持Windows/macOS/Linux系统
|
||||
配置参数从外部传入
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: dict):
|
||||
"""
|
||||
初始化MySQL数据库连接
|
||||
|
||||
Args:
|
||||
config (dict): 数据库配置字典,包含以下键:
|
||||
- host: 数据库主机
|
||||
- port: 端口
|
||||
- user: 用户名
|
||||
- password: 密码
|
||||
- database: 数据库名
|
||||
- [可选] charset: 字符集(默认utf8mb4)
|
||||
- [可选] max_connections: 最大连接数(默认5)
|
||||
- [可选] connect_timeout: 连接超时(秒)
|
||||
- [可选] read_timeout: 读取超时(秒)
|
||||
- [可选] write_timeout: 写入超时(秒)
|
||||
- [可选] ssl: SSL配置
|
||||
"""
|
||||
if hasattr(self, '_pool') and self._pool:
|
||||
return
|
||||
|
||||
# 基础配置
|
||||
required_keys = ['host', 'port', 'user', 'password', 'database']
|
||||
if not all(key in config for key in required_keys):
|
||||
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
|
||||
|
||||
self.config = {
|
||||
'host': config['host'],
|
||||
'port': config['port'],
|
||||
'user': config['user'],
|
||||
'password': config['password'],
|
||||
'database': config['database'],
|
||||
'charset': config.get('charset', 'utf8mb4'),
|
||||
'cursorclass': cursors.DictCursor,
|
||||
'autocommit': True,
|
||||
'connect_timeout': config.get('connect_timeout', 10),
|
||||
'read_timeout': config.get('read_timeout', 30),
|
||||
'write_timeout': config.get('write_timeout', 30),
|
||||
'ssl': config.get('ssl')
|
||||
}
|
||||
|
||||
# 初始化log
|
||||
current_platform = platform.system()
|
||||
self.log = log.bind(module=f"MySQLAgent({current_platform})")
|
||||
|
||||
# 创建连接池
|
||||
self.pool_size = config.get('max_connections', 5)
|
||||
self._pool = self._create_pool()
|
||||
|
||||
def _create_pool(self) -> PooledDB:
|
||||
"""创建连接池"""
|
||||
try:
|
||||
# 使用包装函数确保线程安全
|
||||
def connect():
|
||||
conn = pymysql.connect(**self.config)
|
||||
conn.threadsafety = 1 # 显式设置线程安全级别
|
||||
return conn
|
||||
|
||||
pool = PooledDB(
|
||||
creator=connect,
|
||||
mincached=1,
|
||||
maxcached=3,
|
||||
maxconnections=self.pool_size,
|
||||
blocking=True,
|
||||
ping=1
|
||||
)
|
||||
|
||||
self.log.info("Connection pool created")
|
||||
return pool
|
||||
|
||||
except Exception as e:
|
||||
self.log.critical("Failed to create connection pool",
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def get_connection(self) -> pymysql.connections.Connection:
|
||||
"""
|
||||
获取数据库连接
|
||||
|
||||
Returns:
|
||||
pymysql.connections.Connection: 数据库连接对象
|
||||
|
||||
Raises:
|
||||
MySQLError: 如果获取连接失败
|
||||
"""
|
||||
try:
|
||||
conn = self._pool.connection()
|
||||
|
||||
# macOS需要特殊处理SSL
|
||||
if platform.system() == 'Darwin' and self.config.get('ssl'):
|
||||
conn.ping(reconnect=True)
|
||||
|
||||
self.log.trace("Database connection obtained")
|
||||
return conn
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Windows特定错误处理
|
||||
if platform.system() == 'Windows' and "timed out" in error_msg:
|
||||
self.log.warning("Windows connection timeout, retrying...")
|
||||
return self._retry_connection()
|
||||
|
||||
self.log.error("Connection failed",
|
||||
error=error_msg,
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
|
||||
"""Windows平台连接重试机制"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn = self._pool.connection()
|
||||
self.log.info(f"Connection established after {attempt + 1} attempts")
|
||||
return conn
|
||||
except Exception:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
import time
|
||||
time.sleep(1)
|
||||
|
||||
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
|
||||
"""
|
||||
执行SQL查询并返回DataFrame
|
||||
|
||||
Args:
|
||||
sql (str): SQL查询语句
|
||||
params (Union[tuple, dict, None]): 查询参数
|
||||
parse_dates (Union[List[str], bool]): 自动解析日期字段
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: 查询结果
|
||||
|
||||
Raises:
|
||||
MySQLError: 如果查询失败
|
||||
"""
|
||||
try:
|
||||
self.log.debug("Executing SQL query", sql=sql)
|
||||
|
||||
with self.get_connection() as conn:
|
||||
# Linux/macOS需要更长的查询超时
|
||||
if platform.system() != 'Windows':
|
||||
conn.cursor().execute("SET SESSION wait_timeout=600")
|
||||
|
||||
df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)
|
||||
|
||||
# Windows平台需要手动关闭游标
|
||||
if platform.system() == 'Windows':
|
||||
conn.cursor().close()
|
||||
|
||||
self.log.info("Query executed successfully", rows=len(df))
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("SQL query failed",
|
||||
sql=sql,
|
||||
params=params,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def insert_from_df(self, table_name: str, df: pd.DataFrame,
|
||||
chunk_size: int = 1000, replace: bool = False) -> int:
|
||||
"""
|
||||
将DataFrame数据插入到数据库表
|
||||
|
||||
Args:
|
||||
table_name (str): 目标表名
|
||||
df (pd.DataFrame): 要插入的数据
|
||||
chunk_size (int): 分批插入大小
|
||||
replace (bool): 是否替换现有数据
|
||||
|
||||
Returns:
|
||||
int: 插入的总行数
|
||||
|
||||
Raises:
|
||||
MySQLError: 如果插入失败
|
||||
"""
|
||||
if df.empty:
|
||||
self.log.warning("Attempted to insert empty DataFrame", table=table_name)
|
||||
return 0
|
||||
|
||||
self.log.debug("Preparing to insert DataFrame",
|
||||
table=table_name,
|
||||
rows=len(df),
|
||||
chunk_size=chunk_size)
|
||||
|
||||
try:
|
||||
method = 'replace' if replace else 'append'
|
||||
total_rows = 0
|
||||
|
||||
with self.get_connection() as conn:
|
||||
# 各平台不同的分批策略
|
||||
if platform.system() == 'Windows':
|
||||
chunk_size = min(chunk_size, 500) # Windows上减小批次
|
||||
|
||||
for i in range(0, len(df), chunk_size):
|
||||
chunk = df.iloc[i:i + chunk_size]
|
||||
|
||||
# macOS需要特殊处理datetime
|
||||
if platform.system() == 'Darwin':
|
||||
for col in chunk.select_dtypes(include=['datetime64']):
|
||||
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
chunk.to_sql(
|
||||
table_name,
|
||||
conn,
|
||||
if_exists=method,
|
||||
index=False,
|
||||
method='multi'
|
||||
)
|
||||
total_rows += len(chunk)
|
||||
method = 'append' # 第一次之后都使用追加模式
|
||||
self.log.trace(f"Inserted chunk {i // chunk_size + 1}",
|
||||
rows=len(chunk),
|
||||
total_inserted=total_rows)
|
||||
|
||||
self.log.info("Data inserted successfully",
|
||||
table=table_name,
|
||||
total_rows=total_rows)
|
||||
return total_rows
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("Data insertion failed",
|
||||
table=table_name,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def update_from_df(self, table_name: str, df: pd.DataFrame,
|
||||
key_columns: Union[str, List[str]]) -> int:
|
||||
"""
|
||||
使用DataFrame数据更新数据库表
|
||||
|
||||
Args:
|
||||
table_name (str): 目标表名
|
||||
df (pd.DataFrame): 包含更新数据
|
||||
key_columns (Union[str, List[str]]): 用于匹配记录的关键列
|
||||
|
||||
Returns:
|
||||
int: 更新的总行数
|
||||
|
||||
Raises:
|
||||
MySQLError: 如果更新失败
|
||||
"""
|
||||
if df.empty:
|
||||
self.log.warning("Attempted to update with empty DataFrame", table=table_name)
|
||||
return 0
|
||||
|
||||
self.log.debug("Preparing to update table from DataFrame",
|
||||
table=table_name,
|
||||
key_columns=key_columns,
|
||||
rows=len(df))
|
||||
|
||||
try:
|
||||
if isinstance(key_columns, str):
|
||||
key_columns = [key_columns]
|
||||
|
||||
total_updated = 0
|
||||
conn = self.begin_transaction()
|
||||
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取表结构信息
|
||||
table_info = self._get_table_info(table_name)
|
||||
columns = [col for col in df.columns if col in table_info]
|
||||
|
||||
# 构建UPDATE语句模板
|
||||
set_clause = ', '.join([f"{col}=%s" for col in columns if col not in key_columns])
|
||||
where_clause = ' AND '.join([f"{col}=%s" for col in key_columns])
|
||||
|
||||
update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
|
||||
self.log.trace("Generated update SQL", sql=update_sql)
|
||||
|
||||
# 准备数据
|
||||
update_data = []
|
||||
for _, row in df.iterrows():
|
||||
# SET部分的值
|
||||
set_values = [row[col] for col in columns if col not in key_columns]
|
||||
# WHERE部分的值
|
||||
key_values = [row[col] for col in key_columns]
|
||||
update_data.append(tuple(set_values + key_values))
|
||||
|
||||
# 执行批量更新
|
||||
cursor.executemany(update_sql, update_data)
|
||||
total_updated = cursor.rowcount
|
||||
|
||||
self.commit_transaction(conn)
|
||||
self.log.info("Data updated successfully",
|
||||
table=table_name,
|
||||
rows_updated=total_updated)
|
||||
return total_updated
|
||||
|
||||
except Exception as e:
|
||||
self.rollback_transaction(conn)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("Data update failed",
|
||||
table=table_name,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def _get_table_info(self, table_name: str) -> Dict[str, str]:
|
||||
"""
|
||||
获取表结构信息
|
||||
|
||||
Args:
|
||||
table_name (str): 表名
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 列名到类型的映射
|
||||
|
||||
Raises:
|
||||
MySQLError: 如果查询失败
|
||||
"""
|
||||
sql = f"""
|
||||
SELECT column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s
|
||||
"""
|
||||
|
||||
params = (self.config['database'], table_name)
|
||||
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
result = cursor.fetchall()
|
||||
return {row['column_name']: row['data_type'] for row in result}
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("Failed to get table info",
|
||||
table=table_name,
|
||||
error=str(e))
|
||||
raise
|
||||
|
||||
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
|
||||
"""
|
||||
推断DataFrame各列的SQL类型
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): 输入数据框
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 列名到SQL类型的映射
|
||||
"""
|
||||
type_mapping = {
|
||||
'int64': 'BIGINT',
|
||||
'float64': 'DOUBLE',
|
||||
'datetime64[ns]': 'DATETIME',
|
||||
'object': 'TEXT',
|
||||
'bool': 'TINYINT(1)',
|
||||
'category': 'VARCHAR(255)'
|
||||
}
|
||||
|
||||
sql_types = {}
|
||||
for col, dtype in df.dtypes.items():
|
||||
dtype_str = str(dtype)
|
||||
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
|
||||
|
||||
self.log.debug("Mapped DataFrame types to SQL types",
|
||||
mappings=sql_types)
|
||||
return sql_types
|
||||
|
||||
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
|
||||
primary_key: Union[str, List[str], None] = None) -> bool:
|
||||
"""
|
||||
根据DataFrame结构创建表
|
||||
|
||||
Args:
|
||||
table_name (str): 表名
|
||||
df (pd.DataFrame): 参考数据框
|
||||
primary_key (Union[str, List[str], None]): 主键列
|
||||
|
||||
Returns:
|
||||
bool: 是否创建成功
|
||||
"""
|
||||
if self.table_exists(table_name):
|
||||
self.log.warning("Table already exists", table=table_name)
|
||||
return False
|
||||
|
||||
self.log.debug("Creating new table from DataFrame schema",
|
||||
table=table_name,
|
||||
columns=list(df.columns))
|
||||
|
||||
try:
|
||||
sql_types = self.df_to_sql_type(df)
|
||||
columns_sql = []
|
||||
|
||||
for col, sql_type in sql_types.items():
|
||||
col_def = f"{col} {sql_type}"
|
||||
columns_sql.append(col_def)
|
||||
|
||||
if primary_key:
|
||||
if isinstance(primary_key, str):
|
||||
primary_key = [primary_key]
|
||||
pk_columns = [col for col in primary_key if col in sql_types]
|
||||
if pk_columns:
|
||||
columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
|
||||
self.log.trace("Set primary key",
|
||||
table=table_name,
|
||||
primary_key=pk_columns)
|
||||
|
||||
create_sql = f"CREATE TABLE {table_name} (\n {',\n '.join(columns_sql)}\n)"
|
||||
|
||||
self.execute_sql(create_sql)
|
||||
self.log.info("Table created successfully", table=table_name)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("Failed to create table",
|
||||
table=table_name,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
return False
|
||||
|
||||
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
|
||||
"""
|
||||
执行SQL语句
|
||||
|
||||
Args:
|
||||
sql (str): SQL语句
|
||||
params (Union[tuple, dict, None]): 参数
|
||||
fetch (bool): 是否获取结果
|
||||
|
||||
Returns:
|
||||
Union[int, List[Dict[str, Any]]]:
|
||||
- 如果是INSERT/UPDATE/DELETE,返回影响的行数
|
||||
- 如果是SELECT且fetch=True,返回结果列表
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = self.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Linux/macOS需要更长的执行时间
|
||||
if platform.system() != 'Windows':
|
||||
cursor.execute("SET SESSION max_execution_time=600000")
|
||||
|
||||
cursor.execute(sql, params)
|
||||
|
||||
if fetch:
|
||||
result = cursor.fetchall()
|
||||
self.log.debug("Query executed", rows=len(result))
|
||||
return result
|
||||
else:
|
||||
affected_rows = cursor.rowcount
|
||||
self.log.debug("Update executed", affected_rows=affected_rows)
|
||||
return affected_rows
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("SQL execution failed",
|
||||
sql=sql,
|
||||
params=params,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def begin_transaction(self) -> pymysql.connections.Connection:
|
||||
"""开始事务"""
|
||||
try:
|
||||
conn = self.get_connection()
|
||||
conn.autocommit(False)
|
||||
|
||||
# macOS需要特殊处理事务隔离级别
|
||||
if platform.system() == 'Darwin':
|
||||
conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
|
||||
|
||||
self.log.debug("Transaction started")
|
||||
return conn
|
||||
except Exception as e:
|
||||
self.log.error("Begin transaction failed", error=str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
|
||||
"""提交事务"""
|
||||
try:
|
||||
conn.commit()
|
||||
self.log.debug("Transaction committed")
|
||||
except Exception as e:
|
||||
self.log.error("Commit failed", error=str(e))
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
|
||||
"""回滚事务"""
|
||||
try:
|
||||
conn.rollback()
|
||||
self.log.warning("Transaction rolled back")
|
||||
except Exception as e:
|
||||
self.log.error("Rollback failed", error=str(e))
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""检查表是否存在"""
|
||||
sql = """
|
||||
SELECT COUNT(*) as count
|
||||
FROM `information_schema`.`tables`
|
||||
WHERE `table_schema` = %s AND `table_name` = %s
|
||||
"""
|
||||
|
||||
params = (self.config['database'], table_name)
|
||||
|
||||
try:
|
||||
result = self.execute_sql(sql, params, fetch=True)
|
||||
exists = result[0]['count'] > 0
|
||||
self.log.debug("Checked table existence",
|
||||
table=table_name,
|
||||
exists=exists)
|
||||
return exists
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def drop_table(self, table_name: str) -> bool:
|
||||
"""删除表"""
|
||||
if not self.table_exists(table_name):
|
||||
self.log.warning("Table does not exist", table=table_name)
|
||||
return False
|
||||
|
||||
try:
|
||||
self.execute_sql(f"DROP TABLE {table_name}")
|
||||
self.log.info("Table dropped successfully", table=table_name)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.log.error("Failed to drop table",
|
||||
table=table_name,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
return False
|
||||
|
||||
def get_pool_status(self) -> Dict[str, int]:
|
||||
"""获取连接池状态"""
|
||||
return {
|
||||
'max': self._pool._maxconnections,
|
||||
'active': self._pool._connections,
|
||||
'idle': len(self._pool._idle_cache),
|
||||
'shared': len(self._pool._shared_cache)
|
||||
}
|
||||
|
||||
def validate_connection(self) -> bool:
|
||||
"""验证连接是否有效"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
return cursor.fetchone()[0] == 1
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def __del__(self):
|
||||
"""析构函数"""
|
||||
if hasattr(self, '_pool'):
|
||||
try:
|
||||
self._pool.close()
|
||||
self.log.info("Connection pool closed")
|
||||
except Exception as e:
|
||||
self.log.error("Failed to close pool", error=str(e))
|
||||
|
||||
|
||||
# 平台特定的默认配置
|
||||
def get_default_config():
|
||||
"""获取各平台默认配置"""
|
||||
current_platform = platform.system()
|
||||
|
||||
base_config = {
|
||||
'host': 'localhost',
|
||||
'port': 3306,
|
||||
'user': 'root',
|
||||
'password': '123123',
|
||||
'database': 'intelligence_system',
|
||||
'max_connections': 5
|
||||
}
|
||||
|
||||
if current_platform == 'Windows':
|
||||
return {
|
||||
**base_config,
|
||||
'connect_timeout': 10,
|
||||
'read_timeout': 30,
|
||||
'write_timeout': 30
|
||||
}
|
||||
elif current_platform == 'Darwin':
|
||||
return {
|
||||
**base_config,
|
||||
'connect_timeout': 15,
|
||||
'read_timeout': 60,
|
||||
'write_timeout': 60,
|
||||
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||
}
|
||||
else: # Linux和其他平台
|
||||
return {
|
||||
**base_config,
|
||||
'connect_timeout': 15,
|
||||
'read_timeout': 60,
|
||||
'write_timeout': 60
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用示例
|
||||
db = MySQLAgent(get_default_config())
|
||||
|
||||
# 测试连接
|
||||
if db.validate_connection():
|
||||
print("Database connection successful")
|
||||
|
||||
# 获取数据库版本
|
||||
version = db.query_to_df("SELECT VERSION() as version")
|
||||
print(f"Database version: {version['version'].iloc[0]}")
|
||||
|
||||
# 查看连接池状态
|
||||
print("Connection pool status:", db.get_pool_status())
|
||||
else:
|
||||
print("Failed to connect to database")
|
||||
Reference in New Issue
Block a user