#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Database storage module.

Features:
1. Unified database interface (SQLite / MySQL / PostgreSQL)
2. Cross-platform handling of SQLite file paths
3. Per-thread connection reuse
4. Symmetric encryption helpers for storing sensitive values
"""

import os
import platform
import sqlite3
import threading
from pathlib import Path
from typing import Optional, Union, Dict, Any, List
from threading import Lock
import logging

# `cryptography` is imported lazily-tolerant so the module itself can be
# imported without it; DatabaseManager raises ImportError on construction
# when the package is missing.
_CRYPTO_IMPORT_ERROR: Optional[ImportError] = None
try:
    from cryptography.fernet import Fernet
except ImportError as exc:  # depends on the environment
    Fernet = None
    _CRYPTO_IMPORT_ERROR = exc

# Type alias: positional parameter tuple or named-parameter mapping.
QueryParams = Union[tuple, Dict[str, Any]]

logger = logging.getLogger(__name__)

# SQLite "filenames" with special meaning that must not be path-resolved.
_SQLITE_SPECIAL_NAMES = ('', ':memory:')


class DatabaseManager:
    """Unified manager for SQLite / MySQL / PostgreSQL connections."""

    def __init__(self, db_config: Dict[str, Any]):
        """
        Initialise the manager.

        Args:
            db_config: configuration dict containing
                - type: 'sqlite' | 'mysql' | 'postgresql'
                - database: database name, or file path for SQLite
                - optional: host, port, user, password

        Raises:
            ImportError: if the ``cryptography`` package is unavailable.
        """
        self.config = db_config
        self._lock = Lock()
        # Maps thread id -> the connection opened by that thread.
        self._connection_pool: Dict[int, Any] = {}
        self._setup_crypto()

        # Make sure the directory that will hold the SQLite file exists.
        if db_config['type'] == 'sqlite':
            self._ensure_sqlite_dir()

    def _ensure_sqlite_dir(self):
        """Create the parent directory of the SQLite file if it is missing.

        Failures are logged, not raised; opening the connection later will
        surface the underlying problem.
        """
        parent = Path(self.config['database']).parent
        if not parent.exists():
            try:
                # exist_ok: another thread/process may create it concurrently.
                parent.mkdir(parents=True, exist_ok=True, mode=0o755)
            except OSError as e:
                logger.error(f"创建数据库目录失败: {str(e)}")

    def _setup_crypto(self):
        """Load the Fernet key from ``~/.db_encryption.key``, creating it if absent.

        Raises:
            ImportError: if the ``cryptography`` package is not installed.
        """
        if Fernet is None:
            raise ImportError(
                "cryptography is required for DatabaseManager encryption"
            ) from _CRYPTO_IMPORT_ERROR

        key_file = Path(os.path.expanduser("~/.db_encryption.key"))
        if key_file.exists():
            key = key_file.read_bytes().strip()
        else:
            key = Fernet.generate_key()
            try:
                # Create the file with owner-only permissions up front so the
                # key is never readable by other users, not even briefly.
                fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
            except FileExistsError:
                # Another process created the key in the meantime; use it.
                key = key_file.read_bytes().strip()
            else:
                with os.fdopen(fd, 'wb') as f:
                    f.write(key)
        # Always hold a Fernet instance; encrypt_data/decrypt_data call its
        # methods, so storing the raw key bytes here would break them.
        self._fernet = Fernet(key)

    @staticmethod
    def _is_alive(conn) -> bool:
        """Return True if ``conn`` can still execute a trivial query.

        A cursor is used because only sqlite3 connections expose
        ``execute`` directly; PyMySQL/psycopg2 connections do not.
        """
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT 1")
                cursor.fetchall()
            finally:
                cursor.close()
            return True
        except Exception:
            return False

    @staticmethod
    def _safe_close(conn) -> None:
        """Close ``conn`` best-effort, logging instead of raising on failure."""
        try:
            conn.close()
        except Exception:
            logger.warning("关闭数据库连接失败", exc_info=True)

    @staticmethod
    def _rollback_quietly(conn) -> None:
        """Roll back ``conn`` best-effort so the original error is not masked."""
        try:
            conn.rollback()
        except Exception:
            logger.warning("回滚事务失败", exc_info=True)

    @staticmethod
    def _run(cursor, query: str, params: Optional[QueryParams]) -> None:
        """Execute ``query`` on ``cursor``; params are omitted entirely when None."""
        if params is None:
            cursor.execute(query)
        else:
            cursor.execute(query, params)

    def get_connection(self, reuse=True):
        """
        Return a connection owned by the calling thread (pool access is locked).

        Args:
            reuse: reuse this thread's cached connection if it still works
                (default True). When False a new connection is opened and
                replaces the cached one; the previous connection is not
                closed here because a caller may still hold it.

        Raises:
            ValueError: if ``config['type']`` is not a supported database type.
        """
        thread_id = threading.get_ident()
        with self._lock:
            if reuse and thread_id in self._connection_pool:
                conn = self._connection_pool[thread_id]
                if self._is_alive(conn):
                    return conn
                # Stale connection: forget it and release its resources.
                del self._connection_pool[thread_id]
                self._safe_close(conn)

            # Open a new connection for this thread.
            db_type = self.config['type']
            if db_type == 'sqlite':
                conn = self._create_sqlite_connection()
            elif db_type == 'mysql':
                conn = self._create_mysql_connection()
            elif db_type == 'postgresql':
                conn = self._create_pg_connection()
            else:
                raise ValueError("不支持的数据库类型")

            self._connection_pool[thread_id] = conn
            return conn

    def _create_sqlite_connection(self) -> sqlite3.Connection:
        """Open a SQLite connection with WAL journaling and name-addressable rows."""
        db_path = os.fspath(self.config['database'])
        # On Windows, resolve relative paths to absolute ones (UNC and
        # '/'-rooted paths are left alone). Special SQLite names such as
        # ':memory:' must be passed through untouched.
        if (platform.system() == 'Windows'
                and db_path not in _SQLITE_SPECIAL_NAMES
                and not db_path.startswith(('\\\\', '/'))):
            db_path = os.path.abspath(db_path)

        conn = sqlite3.connect(db_path, timeout=15)
        try:
            conn.execute("PRAGMA journal_mode=WAL")  # WAL improves read/write concurrency
            conn.execute("PRAGMA synchronous=NORMAL")
        except Exception:
            conn.close()  # do not leak the handle if configuration fails
            raise
        conn.row_factory = sqlite3.Row  # rows can be converted with dict(row)
        return conn

    def _create_mysql_connection(self):
        """Open a MySQL connection (requires PyMySQL) returning dict rows."""
        import pymysql
        return pymysql.connect(
            host=self.config.get('host', 'localhost'),
            port=self.config.get('port', 3306),
            user=self.config.get('user', 'root'),
            password=self.config.get('password', ''),
            database=self.config['database'],
            charset='utf8mb4',
            cursorclass=pymysql.cursors.DictCursor
        )

    def _create_pg_connection(self):
        """Open a PostgreSQL connection (requires psycopg2)."""
        import psycopg2
        return psycopg2.connect(
            host=self.config.get('host', 'localhost'),
            port=self.config.get('port', 5432),
            user=self.config.get('user', 'postgres'),
            password=self.config.get('password', ''),
            dbname=self.config['database']
        )

    def execute(
        self,
        query: str,
        params: Optional[QueryParams] = None,
        return_lastrowid: bool = False
    ) -> Optional[int]:
        """
        Run a write statement and commit it.

        The transaction is committed explicitly on success and rolled back on
        error (driver connection context managers differ: PyMySQL >= 1.0
        closes instead of committing). Non-SQLite connections are closed
        afterwards.

        Args:
            query: SQL text in the driver's placeholder style.
            params: tuple or mapping of parameters, or None for none.
            return_lastrowid: when True, return ``cursor.lastrowid``.

        Returns:
            The cursor's last row id if requested, otherwise None.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                self._run(cursor, query, params)
                lastrowid = cursor.lastrowid
            finally:
                cursor.close()
            conn.commit()
        except Exception:
            self._rollback_quietly(conn)
            raise
        finally:
            self._release_connection()
        return lastrowid if return_lastrowid else None

    def query(
        self,
        query: str,
        params: Optional[QueryParams] = None,
        fetchall: bool = True
    ) -> Union[List[Dict], Dict, None]:
        """
        Run a read statement.

        Args:
            query: SQL text in the driver's placeholder style.
            params: tuple or mapping of parameters, or None for none.
            fetchall: return all rows (True) or only the first row (False).

        Returns:
            All rows, or the first row / None when the result is empty.
            SQLite rows are converted to plain dicts; MySQL rows are dicts via
            DictCursor; PostgreSQL rows come from psycopg2's default cursor.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                self._run(cursor, query, params)
                if self.config['type'] == 'sqlite':
                    if fetchall:
                        return [dict(row) for row in cursor.fetchall()]
                    row = cursor.fetchone()
                    return dict(row) if row is not None else None
                return cursor.fetchall() if fetchall else cursor.fetchone()
            finally:
                cursor.close()
        finally:
            self._release_connection()

    def _release_connection(self):
        """Close and forget the calling thread's connection.

        SQLite connections are kept for reuse and are not released here.
        """
        if self.config['type'] == 'sqlite':
            return
        thread_id = threading.get_ident()
        with self._lock:
            conn = self._connection_pool.pop(thread_id, None)
        if conn is not None:
            self._safe_close(conn)

    def encrypt_data(self, plaintext: str) -> str:
        """Encrypt ``plaintext`` with the Fernet key; returns a text token."""
        return self._fernet.encrypt(plaintext.encode()).decode()

    def decrypt_data(self, ciphertext: str) -> str:
        """Decrypt a token produced by :meth:`encrypt_data`."""
        return self._fernet.decrypt(ciphertext.encode()).decode()

    def close_all(self):
        """Close every cached connection (best-effort) and empty the pool."""
        with self._lock:
            conns = list(self._connection_pool.values())
            self._connection_pool.clear()
        for conn in conns:
            self._safe_close(conn)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close_all()


def get_default_db() -> DatabaseManager:
    """Return a DatabaseManager for the default per-platform SQLite file."""
    system = platform.system().lower()
    if system == 'windows':
        # APPDATA can be unset (e.g. some service accounts); fall back to home.
        base = os.getenv('APPDATA') or os.path.expanduser('~')
        db_path = os.path.join(base, 'app_name', 'data.db')
    elif system == 'darwin':
        db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db')
    elif os.access('/var/lib', os.W_OK):
        db_path = '/var/lib/app_name/data.db'
    else:
        db_path = os.path.expanduser('~/.local/share/app_name/data.db')

    return DatabaseManager({
        'type': 'sqlite',
        'database': db_path
    })


# Demo / smoke test
if __name__ == "__main__":
    with get_default_db() as db:
        # Create the demo table.
        db.execute("""
            CREATE TABLE IF NOT EXISTS test_table (
                id INTEGER PRIMARY KEY,
                name TEXT,
                secret TEXT
            )
        """)

        # Insert a row whose 'secret' column is encrypted.
        secret = db.encrypt_data("新敏感信息")
        new_id = db.execute(
            "INSERT INTO test_table (name, secret) VALUES (?, ?)",
            ("测试记录", secret),
            return_lastrowid=True
        )

        # Read it back and decrypt the encrypted column ('name' is plaintext).
        row = db.query("SELECT * FROM test_table WHERE id = ?", (new_id,), fetchall=False)
        print(f"解密数据: {db.decrypt_data(row['secret'])}")