数据库操作

This commit is contained in:
2025-08-06 16:24:17 +08:00
parent aa0b71a90b
commit c8d268647f
11 changed files with 1344 additions and 706 deletions
-255
View File
@@ -1,255 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
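"""
Database storage module.

Features:
1. Unified database interface (SQLite / MySQL / PostgreSQL)
2. Automatic handling of platform-specific path differences
3. Connection pool management
4. Encrypted storage of sensitive data
"""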
import os
import platform
import sqlite3
import threading
from pathlib import Path
from typing import Optional, Union, Dict, Any, List
from threading import Lock
import logging
from cryptography.fernet import Fernet
# 类型别名
QueryParams = Union[tuple, Dict[str, Any]]
class DatabaseManager:
    """Unified access layer for SQLite, MySQL and PostgreSQL.

    One connection per thread is cached in ``_connection_pool`` (keyed by
    thread id).  SQLite connections stay cached between calls; MySQL and
    PostgreSQL connections are closed after every ``execute``/``query``.
    The class is also a context manager that closes every cached connection
    on exit.
    """

    def __init__(self, db_config: Dict[str, Any]):
        """Store the configuration and prepare encryption.

        Args:
            db_config: Configuration dictionary with the keys
                - type: 'sqlite' | 'mysql' | 'postgresql'
                - database: database name (or file path for SQLite)
                - optional: host, port, user, password
        """
        self.config = db_config
        self._lock = Lock()
        self._connection_pool = {}
        self._setup_crypto()
        # SQLite cannot create missing parent directories on its own.
        if db_config['type'] == 'sqlite':
            self._ensure_sqlite_dir()

    def _ensure_sqlite_dir(self):
        """Create the parent directory of the SQLite file if it is missing.

        Failures are logged rather than raised; a later ``connect`` will
        surface the problem.
        """
        db_path = Path(self.config['database'])
        try:
            # exist_ok avoids a race between an existence check and mkdir.
            db_path.parent.mkdir(parents=True, mode=0o755, exist_ok=True)
        except OSError as e:
            logging.error(f"创建数据库目录失败: {str(e)}")

    def _setup_crypto(self):
        """Load the Fernet key from ``~/.db_encryption.key`` or create it.

        ``self._fernet`` is always a ``Fernet`` instance afterwards.  (The
        earlier version stored the raw key bytes on first run, which made
        ``encrypt_data``/``decrypt_data`` fail until the next start.)
        """
        key_file = Path(os.path.expanduser("~/.db_encryption.key"))
        try:
            key = key_file.read_bytes()
        except FileNotFoundError:
            key = Fernet.generate_key()
            try:
                # Create the file owner-only from the start and refuse to
                # overwrite a key another process wrote in the meantime.
                fd = os.open(
                    str(key_file),
                    os.O_WRONLY | os.O_CREAT | os.O_EXCL,
                    0o600,
                )
            except FileExistsError:
                key = key_file.read_bytes()
            else:
                with os.fdopen(fd, 'wb') as f:
                    f.write(key)
        self._fernet = Fernet(key)

    @staticmethod
    def _is_alive(conn) -> bool:
        """Return True when a trivial statement can be run on ``conn``.

        A cursor is used because only sqlite3 connections offer
        ``Connection.execute``; the DB-API itself only guarantees cursors.
        """
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT 1")
                cursor.fetchall()
            finally:
                cursor.close()
            return True
        except Exception:
            return False

    def get_connection(self, reuse=True):
        """Return a connection for the calling thread.

        Args:
            reuse: When True (default) a cached, still working connection of
                this thread is returned; otherwise a new one is opened and
                replaces the cached entry.

        Raises:
            ValueError: If ``config['type']`` is not a supported backend.
        """
        thread_id = threading.get_ident()
        factories = {
            'sqlite': self._create_sqlite_connection,
            'mysql': self._create_mysql_connection,
            'postgresql': self._create_pg_connection,
        }
        with self._lock:
            if reuse:
                cached = self._connection_pool.get(thread_id)
                if cached is not None:
                    if self._is_alive(cached):
                        return cached
                    # Dead connection: drop it and close it best-effort.
                    del self._connection_pool[thread_id]
                    try:
                        cached.close()
                    except Exception:
                        pass
            factory = factories.get(self.config['type'])
            if factory is None:
                raise ValueError("不支持的数据库类型")
            conn = factory()
            self._connection_pool[thread_id] = conn
            return conn

    def _create_sqlite_connection(self) -> sqlite3.Connection:
        """Open the SQLite database with WAL journaling and dict-like rows."""
        db_path = self.config['database']
        # Relative Windows paths are resolved against the current directory.
        if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')):
            db_path = os.path.abspath(db_path)
        conn = sqlite3.connect(db_path, timeout=15)
        conn.execute("PRAGMA journal_mode=WAL")  # write-ahead log
        conn.execute("PRAGMA synchronous=NORMAL")
        conn.row_factory = sqlite3.Row  # rows convertible with dict(row)
        return conn

    def _create_mysql_connection(self):
        """Open a MySQL connection (requires PyMySQL)."""
        import pymysql
        return pymysql.connect(
            host=self.config.get('host', 'localhost'),
            port=self.config.get('port', 3306),
            user=self.config.get('user', 'root'),
            password=self.config.get('password', ''),
            database=self.config['database'],
            charset='utf8mb4',
            cursorclass=pymysql.cursors.DictCursor
        )

    def _create_pg_connection(self):
        """Open a PostgreSQL connection (requires psycopg2)."""
        import psycopg2
        return psycopg2.connect(
            host=self.config.get('host', 'localhost'),
            port=self.config.get('port', 5432),
            user=self.config.get('user', 'postgres'),
            password=self.config.get('password', ''),
            dbname=self.config['database']
        )

    def execute(
        self,
        query: str,
        params: Optional["QueryParams"] = None,
        return_lastrowid: bool = False
    ) -> Union[int, None]:
        """Run a data-modifying statement and commit it.

        Args:
            query: SQL text using the placeholder style of the backend.
            params: Optional positional or named parameters.
            return_lastrowid: Return ``cursor.lastrowid`` instead of None.

        Returns:
            The last inserted row id when requested, otherwise None.

        Raises:
            Exception: Whatever the driver raises; the transaction is rolled
                back before the error propagates.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                if params is not None:
                    cursor.execute(query, params)
                else:
                    # Some drivers reject an explicit params=None.
                    cursor.execute(query)
                lastrowid = cursor.lastrowid
                # Commit explicitly: the meaning of ``with conn:`` is not the
                # same across DB-API drivers.
                conn.commit()
            except Exception:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    logging.error(f"Rollback failed: {rollback_error}")
                raise
            finally:
                cursor.close()
            return lastrowid if return_lastrowid else None
        finally:
            # lastrowid is already a plain value, so releasing is always safe.
            self._release_connection()

    def query(
        self,
        query: str,
        params: Optional["QueryParams"] = None,
        fetchall: bool = True
    ) -> Union[List[Dict], Dict]:
        """Run a SELECT statement and return its rows.

        Args:
            query: SQL text using the placeholder style of the backend.
            params: Optional positional or named parameters.
            fetchall: True returns every row, False only the first one.

        Returns:
            SQLite: a list of dicts, or a single dict when ``fetchall`` is
            False and a row exists (an empty list when none exists).
            Other backends: whatever ``fetchall()`` / ``fetchone()`` of the
            driver cursor returns.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                if params is not None:
                    cursor.execute(query, params)
                else:
                    cursor.execute(query)
                if self.config['type'] == 'sqlite':
                    rows = cursor.fetchall()
                    if not fetchall and rows:
                        return dict(rows[0])
                    return [dict(row) for row in rows]
                return cursor.fetchall() if fetchall else cursor.fetchone()
            finally:
                cursor.close()
        finally:
            self._release_connection()

    def _release_connection(self):
        """Close and forget this thread's connection (no-op for SQLite)."""
        if self.config['type'] == 'sqlite':
            return
        thread_id = threading.get_ident()
        with self._lock:
            conn = self._connection_pool.pop(thread_id, None)
        if conn is not None:
            conn.close()

    def encrypt_data(self, plaintext: str) -> str:
        """Encrypt ``plaintext`` with the stored Fernet key."""
        return self._fernet.encrypt(plaintext.encode()).decode()

    def decrypt_data(self, ciphertext: str) -> str:
        """Decrypt a token produced by ``encrypt_data``."""
        return self._fernet.decrypt(ciphertext.encode()).decode()

    def close_all(self):
        """Close every cached connection; close errors are only logged."""
        with self._lock:
            for conn in self._connection_pool.values():
                try:
                    conn.close()
                except Exception as e:
                    # Best effort: one failing close must not block the rest.
                    logging.warning(f"Failed to close connection: {e}")
            self._connection_pool.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close_all()
# Default SQLite instance (default configuration)
def get_default_db() -> DatabaseManager:
    """Return a DatabaseManager for the per-platform default SQLite file.

    Locations:
        - Windows: ``%APPDATA%/app_name/data.db`` (falls back to
          ``~/AppData/Roaming`` when APPDATA is not set)
        - macOS: ``~/Library/Application Support/app_name/data.db``
        - other: ``/var/lib/app_name/data.db`` when ``/var/lib`` is
          writable, otherwise ``~/.local/share/app_name/data.db``
    """
    system = platform.system().lower()
    if system == 'windows':
        # APPDATA can be missing (services, stripped environments);
        # os.path.join(None, ...) would raise TypeError.
        appdata = os.getenv('APPDATA') or os.path.join(
            os.path.expanduser('~'), 'AppData', 'Roaming'
        )
        db_path = os.path.join(appdata, 'app_name', 'data.db')
    elif system == 'darwin':
        db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db')
    elif os.access('/var/lib', os.W_OK):
        db_path = '/var/lib/app_name/data.db'
    else:
        db_path = os.path.expanduser('~/.local/share/app_name/data.db')
    return DatabaseManager({
        'type': 'sqlite',
        'database': db_path
    })
# Smoke test
if __name__ == "__main__":
    with get_default_db() as db:
        # Create the test table
        db.execute("""
        CREATE TABLE IF NOT EXISTS test_table (
        id INTEGER PRIMARY KEY,
        name TEXT,
        secret TEXT
        )
        """)
        # Insert an encrypted value and remember which row it went into
        secret = db.encrypt_data("新敏感信息")
        row_id = db.execute(
            "INSERT INTO test_table (name, secret) VALUES (?, ?)",
            ("测试记录", secret),
            return_lastrowid=True
        )
        # Read that exact row back; the ciphertext lives in `secret`, not in
        # `name` (decrypting the plaintext name raises InvalidToken).
        row = db.query(
            "SELECT * FROM test_table WHERE id = ?",
            (row_id,),
            fetchall=False
        )
        print(f"解密数据: {db.decrypt_data(row['secret'])}")
+662
View File
@@ -0,0 +1,662 @@
import os
import sys
import platform
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
    """Process-wide MySQL helper built on PyMySQL and a DBUtils PooledDB pool.

    The class is a singleton: the first construction creates the pool, later
    constructions return the same object and ignore their ``config``.
    Connections use ``DictCursor`` and run in autocommit mode unless they are
    obtained through ``begin_transaction``.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Check, lock, check again so that only one instance is ever created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """Validate the configuration and create the connection pool.

        Args:
            config (dict): Database configuration with the keys
                - host, port, user, password, database (required)
                - charset: character set (default utf8mb4)
                - max_connections: pool size (default 5)
                - connect_timeout / read_timeout / write_timeout: seconds
                - ssl: SSL options passed to PyMySQL

        Raises:
            ValueError: If a required key is missing.
        """
        # Singleton: do not rebuild the pool on repeated construction.
        if getattr(self, '_pool', None):
            return
        required_keys = ['host', 'port', 'user', 'password', 'database']
        if not all(key in config for key in required_keys):
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            'cursorclass': cursors.DictCursor,
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl')
        }
        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")
        self.pool_size = config.get('max_connections', 5)
        self._pool = self._create_pool()

    # ------------------------------------------------------------------
    # small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _quote_identifier(name: Any, allow_qualified: bool = False) -> str:
        """Return ``name`` as a backtick-quoted MySQL identifier.

        Embedded backticks are doubled.  With ``allow_qualified`` a dotted
        name such as ``db.table`` is quoted part by part.  Parts that are
        already wrapped in backticks are not wrapped twice.
        """
        text = str(name)
        parts = text.split('.') if allow_qualified else [text]
        quoted = []
        for part in parts:
            if len(part) >= 2 and part[0] == '`' and part[-1] == '`':
                part = part[1:-1].replace('``', '`')
            quoted.append('`' + part.replace('`', '``') + '`')
        return '.'.join(quoted)

    @staticmethod
    def _frame_to_rows(frame: pd.DataFrame) -> List[tuple]:
        """Convert a DataFrame into parameter tuples for ``executemany``.

        Missing values (NaN/NaT/None) become None so they are sent as NULL,
        and pandas Timestamps become ``datetime`` objects.
        """
        prepared = frame.astype(object)
        prepared = prepared.where(pd.notnull(prepared), None)
        return [
            tuple(
                value.to_pydatetime() if isinstance(value, pd.Timestamp) else value
                for value in row
            )
            for row in prepared.itertuples(index=False, name=None)
        ]

    @staticmethod
    def _rows_to_frame(rows, description) -> pd.DataFrame:
        """Build a DataFrame from fetched rows (dict rows or tuple rows)."""
        rows = list(rows)
        if rows and isinstance(rows[0], dict):
            # DictCursor rows already carry their column labels.
            return pd.DataFrame(rows, columns=list(rows[0].keys()))
        columns = [column[0] for column in (description or ())]
        return pd.DataFrame(rows, columns=columns)

    # ------------------------------------------------------------------
    # pool / connections
    # ------------------------------------------------------------------
    def _create_pool(self) -> "PooledDB":
        """Create the PooledDB pool from ``self.config``."""
        try:
            def connect():
                conn = pymysql.connect(**self.config)
                # NOTE(review): kept from the original; this only sets an
                # attribute on the connection object - confirm it is needed.
                conn.threadsafety = 1
                return conn
            pool = PooledDB(
                creator=connect,
                mincached=1,
                maxcached=3,
                maxconnections=self.pool_size,
                blocking=True,
                ping=1
            )
            self.log.info("Connection pool created")
            return pool
        except Exception as e:
            self.log.critical("Failed to create connection pool",
                              error=str(e),
                              exc_info=True)
            raise

    def get_connection(self) -> "pymysql.connections.Connection":
        """Borrow a connection from the pool.

        Returns:
            The pooled connection; closing it returns it to the pool.

        Raises:
            Exception: If no connection can be obtained.
        """
        try:
            conn = self._pool.connection()
            # Kept from the original: ping SSL connections on macOS.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)
            self.log.trace("Database connection obtained")
            return conn
        except Exception as e:
            error_msg = str(e)
            # Kept from the original: retry timeouts on Windows.
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows connection timeout, retrying...")
                return self._retry_connection()
            self.log.error("Connection failed",
                           error=error_msg,
                           exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> "pymysql.connections.Connection":
        """Try to borrow a connection up to ``max_retries`` times.

        Waits one second between attempts and re-raises the last error.

        Raises:
            ValueError: If ``max_retries`` is smaller than 1.
        """
        import time

        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        for attempt in range(max_retries):
            try:
                conn = self._pool.connection()
                self.log.info(f"Connection established after {attempt + 1} attempts")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)

    # ------------------------------------------------------------------
    # DataFrame helpers
    # ------------------------------------------------------------------
    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """Run a query and return the result as a DataFrame.

        The rows are fetched through a cursor and converted directly.
        ``pd.read_sql`` only supports SQLAlchemy connectables and sqlite3
        connections, so a raw PyMySQL connection is not passed to it.

        Args:
            sql (str): SQL query.
            params (Union[tuple, dict, None]): Query parameters.
            parse_dates (Union[List[str], bool]): Column names to convert
                with ``pd.to_datetime``.  A bool selects no extra columns.

        Returns:
            pd.DataFrame: Query result.

        Raises:
            Exception: If the query fails.
        """
        try:
            self.log.debug("Executing SQL query", sql=sql)
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    # Kept from the original: longer idle timeout off Windows.
                    if platform.system() != 'Windows':
                        cursor.execute("SET SESSION wait_timeout=600")
                    cursor.execute(sql, params)
                    df = self._rows_to_frame(cursor.fetchall(), cursor.description)
            if isinstance(parse_dates, (list, tuple)):
                for column in parse_dates:
                    if column in df.columns:
                        df[column] = pd.to_datetime(df[column])
            self.log.info("Query executed successfully", rows=len(df))
            return df
        except Exception as e:
            self.log.error("SQL query failed",
                           sql=sql,
                           params=params,
                           error=str(e),
                           exc_info=True)
            raise

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False) -> int:
        """Insert the rows of a DataFrame into an existing table.

        Rows are written with parameterised ``INSERT`` statements in chunks
        inside one transaction: either every row is stored or none is.
        ``DataFrame.to_sql`` is not used because it only supports SQLAlchemy
        connectables and sqlite3 connections.

        Args:
            table_name (str): Target table.
            df (pd.DataFrame): Data to insert; column names must match the
                table's columns.
            chunk_size (int): Rows per ``executemany`` call (capped at 500 on
                Windows, as before).
            replace (bool): Delete all existing rows of the table first
                (same transaction) so the frame replaces the old data.

        Returns:
            int: Number of inserted rows.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
            Exception: If the insert fails (the transaction is rolled back).
        """
        if df.empty:
            self.log.warning("Attempted to insert empty DataFrame", table=table_name)
            return 0
        self.log.debug("Preparing to insert DataFrame",
                       table=table_name,
                       rows=len(df),
                       chunk_size=chunk_size)
        try:
            if chunk_size < 1:
                raise ValueError("chunk_size must be a positive integer")
            if platform.system() == 'Windows':
                chunk_size = min(chunk_size, 500)
            table = self._quote_identifier(table_name, allow_qualified=True)
            column_list = ', '.join(self._quote_identifier(col) for col in df.columns)
            placeholders = ', '.join(['%s'] * len(df.columns))
            insert_sql = f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})"
            total_rows = 0
            conn = self.begin_transaction()
            try:
                with conn.cursor() as cursor:
                    if replace:
                        cursor.execute(f"DELETE FROM {table}")
                    for chunk_no, start in enumerate(range(0, len(df), chunk_size), 1):
                        chunk = df.iloc[start:start + chunk_size]
                        # Kept from the original: format datetimes on macOS.
                        if platform.system() == 'Darwin':
                            chunk = chunk.copy()  # never write into a slice
                            for col in chunk.select_dtypes(include=['datetime64']):
                                chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
                        cursor.executemany(insert_sql, self._frame_to_rows(chunk))
                        total_rows += len(chunk)
                        self.log.trace(f"Inserted chunk {chunk_no}",
                                       rows=len(chunk),
                                       total_inserted=total_rows)
            except Exception:
                self.rollback_transaction(conn)
                raise
            self.commit_transaction(conn)
            self.log.info("Data inserted successfully",
                          table=table_name,
                          total_rows=total_rows)
            return total_rows
        except Exception as e:
            self.log.error("Data insertion failed",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """Update table rows from a DataFrame, matching on key columns.

        Only DataFrame columns that exist in the table are written.

        Args:
            table_name (str): Target table.
            df (pd.DataFrame): New values plus the key columns.
            key_columns (Union[str, List[str]]): Column(s) used in WHERE.

        Returns:
            int: Row count reported by the cursor.

        Raises:
            ValueError: If key columns are missing/empty or nothing is left
                to update.
            Exception: If the update fails (the transaction is rolled back).
        """
        if df.empty:
            self.log.warning("Attempted to update with empty DataFrame", table=table_name)
            return 0
        self.log.debug("Preparing to update table from DataFrame",
                       table=table_name,
                       key_columns=key_columns,
                       rows=len(df))
        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]
            key_columns = list(key_columns)
            if not key_columns:
                raise ValueError("key_columns must not be empty")
            missing = [col for col in key_columns if col not in df.columns]
            if missing:
                raise ValueError(f"key columns missing from DataFrame: {missing}")
            # Read the schema BEFORE opening the transaction so that only one
            # pooled connection is held at a time.
            table_info = self._get_table_info(table_name)
            set_columns = [col for col in df.columns
                           if col in table_info and col not in key_columns]
            if not set_columns:
                raise ValueError("no updatable columns shared by DataFrame and table")
            set_clause = ', '.join(f"{self._quote_identifier(col)}=%s" for col in set_columns)
            where_clause = ' AND '.join(f"{self._quote_identifier(col)}=%s" for col in key_columns)
            table = self._quote_identifier(table_name, allow_qualified=True)
            update_sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
            self.log.trace("Generated update SQL", sql=update_sql)
            # Parameter order: SET values first, then WHERE values.
            update_data = self._frame_to_rows(df[set_columns + key_columns])
            conn = self.begin_transaction()
            try:
                with conn.cursor() as cursor:
                    cursor.executemany(update_sql, update_data)
                    total_updated = cursor.rowcount
            except Exception:
                self.rollback_transaction(conn)
                raise
            self.commit_transaction(conn)
            self.log.info("Data updated successfully",
                          table=table_name,
                          rows_updated=total_updated)
            return total_updated
        except Exception as e:
            self.log.error("Data update failed",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def _get_table_info(self, table_name: str) -> Dict[str, str]:
        """Return a mapping of column name to data type for a table.

        Raises:
            Exception: If the metadata query fails.
        """
        # Explicit aliases plus key normalisation: the labels returned for
        # information_schema columns are not lower-case on every server.
        sql = """
        SELECT column_name AS column_name, data_type AS data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        """
        params = (self.config['database'], table_name)
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    result = cursor.fetchall()
            info = {}
            for row in result:
                normalised = {str(key).lower(): value for key, value in row.items()}
                info[normalised['column_name']] = normalised['data_type']
            return info
        except Exception as e:
            self.log.error("Failed to get table info",
                           table=table_name,
                           error=str(e))
            raise

    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """Infer a MySQL column type for every DataFrame column.

        Mapping: bool -> TINYINT(1), any integer -> BIGINT, any float ->
        DOUBLE, any datetime64 -> DATETIME, category -> VARCHAR(255),
        everything else -> TEXT.

        Returns:
            Dict[str, str]: Column name to SQL type.
        """
        type_checks = pd.api.types
        sql_types = {}
        for col, dtype in df.dtypes.items():
            if type_checks.is_bool_dtype(dtype):
                sql_type = 'TINYINT(1)'
            elif type_checks.is_integer_dtype(dtype):
                sql_type = 'BIGINT'
            elif type_checks.is_float_dtype(dtype):
                sql_type = 'DOUBLE'
            elif type_checks.is_datetime64_any_dtype(dtype):
                sql_type = 'DATETIME'
            elif isinstance(dtype, pd.CategoricalDtype):
                sql_type = 'VARCHAR(255)'
            else:
                sql_type = 'TEXT'
            sql_types[col] = sql_type
        self.log.debug("Mapped DataFrame types to SQL types",
                       mappings=sql_types)
        return sql_types

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """Create a table whose columns mirror a DataFrame.

        Args:
            table_name (str): Table to create.
            df (pd.DataFrame): Frame providing column names and dtypes.
            primary_key (Union[str, List[str], None]): Primary key column(s);
                names that are not DataFrame columns are ignored.

        Returns:
            bool: True when created, False when the table already exists or
            creation failed (the error is logged).
        """
        if self.table_exists(table_name):
            self.log.warning("Table already exists", table=table_name)
            return False
        self.log.debug("Creating new table from DataFrame schema",
                       table=table_name,
                       columns=list(df.columns))
        try:
            sql_types = self.df_to_sql_type(df)
            columns_sql = [f"{self._quote_identifier(col)} {sql_type}"
                           for col, sql_type in sql_types.items()]
            if primary_key:
                if isinstance(primary_key, str):
                    primary_key = [primary_key]
                pk_columns = [col for col in primary_key if col in sql_types]
                if pk_columns:
                    quoted_pk = ', '.join(self._quote_identifier(col) for col in pk_columns)
                    columns_sql.append(f"PRIMARY KEY ({quoted_pk})")
                    self.log.trace("Set primary key",
                                   table=table_name,
                                   primary_key=pk_columns)
            # Join outside the f-string: a backslash inside an f-string
            # expression is a SyntaxError before Python 3.12.
            separator = ",\n "
            table = self._quote_identifier(table_name, allow_qualified=True)
            create_sql = f"CREATE TABLE {table} (\n {separator.join(columns_sql)}\n)"
            self.execute_sql(create_sql)
            self.log.info("Table created successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to create table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    # ------------------------------------------------------------------
    # plain SQL and transactions
    # ------------------------------------------------------------------
    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """Execute one SQL statement on an autocommit connection.

        Args:
            sql (str): SQL statement.
            params (Union[tuple, dict, None]): Statement parameters.
            fetch (bool): Return the fetched rows instead of the row count.

        Returns:
            Union[int, List[Dict[str, Any]]]: ``cursor.rowcount`` when
            ``fetch`` is False, otherwise the result of ``fetchall()``.

        Raises:
            Exception: If execution fails.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()
            # Kept from the original: raise the statement time limit.
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")
            cursor.execute(sql, params)
            if fetch:
                result = cursor.fetchall()
                self.log.debug("Query executed", rows=len(result))
                return result
            affected_rows = cursor.rowcount
            self.log.debug("Update executed", affected_rows=affected_rows)
            return affected_rows
        except Exception as e:
            self.log.error("SQL execution failed",
                           sql=sql,
                           params=params,
                           error=str(e),
                           exc_info=True)
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def begin_transaction(self) -> "pymysql.connections.Connection":
        """Borrow a connection with autocommit switched off.

        The caller must finish with ``commit_transaction`` or
        ``rollback_transaction``; both return the connection to the pool.
        """
        conn = None
        try:
            conn = self.get_connection()
            conn.autocommit(False)
            # Kept from the original: READ COMMITTED on macOS.
            if platform.system() == 'Darwin':
                with conn.cursor() as cursor:
                    cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
            self.log.debug("Transaction started")
            return conn
        except Exception as e:
            self.log.error("Begin transaction failed", error=str(e), exc_info=True)
            if conn is not None:
                self._finish_transaction(conn)
            raise

    def _finish_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Restore autocommit and hand the connection back to the pool.

        Without the restore the pooled connection would stay in manual-commit
        mode and later ``execute_sql`` writes on it would never be committed.
        """
        try:
            conn.autocommit(self.config.get('autocommit', True))
        except Exception as e:
            self.log.warning("Failed to restore autocommit", error=str(e))
        finally:
            conn.close()

    def commit_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Commit and release a connection from ``begin_transaction``.

        Raises:
            Exception: If the commit fails (the connection is still released).
        """
        try:
            conn.commit()
            self.log.debug("Transaction committed")
        except Exception as e:
            self.log.error("Commit failed", error=str(e))
            raise
        finally:
            self._finish_transaction(conn)

    def rollback_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Roll back and release a connection from ``begin_transaction``.

        Rollback errors are logged, not raised, so they cannot hide the
        error that triggered the rollback.
        """
        try:
            conn.rollback()
            self.log.warning("Transaction rolled back")
        except Exception as e:
            self.log.error("Rollback failed", error=str(e))
        finally:
            self._finish_transaction(conn)

    # ------------------------------------------------------------------
    # schema helpers and diagnostics
    # ------------------------------------------------------------------
    def table_exists(self, table_name: str) -> bool:
        """Return True if the table exists in the configured database.

        Any error while checking is logged and reported as False.
        """
        sql = """
        SELECT COUNT(*) as count
        FROM `information_schema`.`tables`
        WHERE `table_schema` = %s AND `table_name` = %s
        """
        params = (self.config['database'], table_name)
        try:
            result = self.execute_sql(sql, params, fetch=True)
            exists = result[0]['count'] > 0
            self.log.debug("Checked table existence",
                           table=table_name,
                           exists=exists)
            return exists
        except Exception as e:
            self.log.warning("Table existence check failed",
                             table=table_name,
                             error=str(e))
            return False

    def drop_table(self, table_name: str) -> bool:
        """Drop a table; return False if it is missing or the drop fails."""
        if not self.table_exists(table_name):
            self.log.warning("Table does not exist", table=table_name)
            return False
        try:
            table = self._quote_identifier(table_name, allow_qualified=True)
            self.execute_sql(f"DROP TABLE {table}")
            self.log.info("Table dropped successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to drop table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def get_pool_status(self) -> Dict[str, int]:
        """Return pool counters read from PooledDB's private attributes.

        ``_shared_cache`` only exists when connection sharing is enabled, so
        every attribute is read defensively.
        """
        pool = self._pool
        return {
            'max': getattr(pool, '_maxconnections', 0),
            'active': getattr(pool, '_connections', 0),
            'idle': len(getattr(pool, '_idle_cache', ())),
            'shared': len(getattr(pool, '_shared_cache', ()))
        }

    def validate_connection(self) -> bool:
        """Return True if ``SELECT 1`` succeeds on a pooled connection."""
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1 AS ok")
                    row = cursor.fetchone()
            # DictCursor returns a mapping; indexing it with [0] raised
            # KeyError and made this method always report False.
            if isinstance(row, dict):
                return row.get('ok') == 1
            return bool(row) and row[0] == 1
        except Exception:
            return False

    def __del__(self):
        """Close the pool when the object is collected (best effort)."""
        pool = getattr(self, '_pool', None)
        if pool is None:
            return
        logger = getattr(self, 'log', None)
        try:
            pool.close()
            if logger is not None:
                logger.info("Connection pool closed")
        except Exception as e:
            if logger is not None:
                logger.error("Failed to close pool", error=str(e))
# Platform-specific default configuration
def get_default_config():
    """Return the default connection settings for the current platform.

    Every platform shares the same host/port/credentials; only the timeouts
    differ, and macOS additionally gets an ``ssl`` entry pointing at the
    OpenSSL CA bundle.
    """
    shared = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
        'max_connections': 5
    }
    # Timeouts used on Linux and on any platform without its own entry.
    relaxed_timeouts = {
        'connect_timeout': 15,
        'read_timeout': 60,
        'write_timeout': 60
    }
    overrides = {
        'Windows': {
            'connect_timeout': 10,
            'read_timeout': 30,
            'write_timeout': 30
        },
        'Darwin': {
            **relaxed_timeouts,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
        }
    }
    platform_settings = overrides.get(platform.system(), relaxed_timeouts)
    return {**shared, **platform_settings}
if __name__ == "__main__":
    # Usage example: connect with the platform defaults and print some
    # basic information about the server and the pool.
    agent = MySQLAgent(get_default_config())
    connected = agent.validate_connection()
    if not connected:
        print("Failed to connect to database")
    else:
        print("Database connection successful")
        # Ask the server for its version string
        version_frame = agent.query_to_df("SELECT VERSION() as version")
        server_version = version_frame['version'].iloc[0]
        print(f"Database version: {server_version}")
        # Show the current pool counters
        pool_status = agent.get_pool_status()
        print("Connection pool status:", pool_status)