256 lines
8.2 KiB
Python
256 lines
8.2 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库存储模块

功能:
1. 统一数据库接口(SQLite/MySQL/PostgreSQL)
2. 自动处理多平台路径问题
3. 按线程复用数据库连接(简易连接池)
4. 敏感数据加密存储(基于 Fernet 对称加密)
"""

import os
|
||
import platform
|
||
import sqlite3
|
||
import threading
|
||
from pathlib import Path
|
||
from typing import Optional, Union, Dict, Any, List
|
||
from threading import Lock
|
||
import logging
|
||
from cryptography.fernet import Fernet
|
||
|
||
# Bind parameters accepted by DatabaseManager.execute()/query(): either a
# positional tuple or a name -> value mapping, passed to cursor.execute() as-is.
QueryParams = Union[tuple, Dict[str, Any]]


class DatabaseManager:
    """Unified access layer over SQLite, MySQL (PyMySQL) and PostgreSQL (psycopg2).

    Connections are cached per thread, keyed by ``threading.get_ident()``.
    SQLite connections stay cached between calls; MySQL/PostgreSQL
    connections are closed after every ``execute``/``query`` call.
    ``encrypt_data``/``decrypt_data`` use a Fernet key stored in
    ``~/.db_encryption.key``. Usable as a context manager: leaving the
    ``with`` block calls ``close_all()``.
    """

    def __init__(self, db_config: Dict[str, Any]):
        """Initialise the manager and its encryption key.

        Args:
            db_config: configuration dict containing
                - type: 'sqlite' | 'mysql' | 'postgresql'
                - database: database name, or file path for SQLite
                - optional: host, port, user, password
        """
        self.config = db_config
        self._lock = Lock()
        # thread id -> connection opened by (and for) that thread
        self._connection_pool: Dict[int, Any] = {}
        self._setup_crypto()

        # SQLite cannot create missing parent directories of the db file.
        if db_config['type'] == 'sqlite':
            self._ensure_sqlite_dir()

    def _ensure_sqlite_dir(self):
        """Create the parent directory of the SQLite file if it is missing.

        Failures are logged, not raised; connecting will surface the error.
        """
        db_path = Path(self.config['database'])
        try:
            # exist_ok avoids a race between an existence check and mkdir.
            db_path.parent.mkdir(parents=True, mode=0o755, exist_ok=True)
        except OSError as e:
            logging.error("创建数据库目录失败: %s", e)

    def _setup_crypto(self):
        """Load the Fernet key from ``~/.db_encryption.key``, creating it if absent."""
        key_file = Path(os.path.expanduser("~/.db_encryption.key"))
        try:
            key = key_file.read_bytes()
        except FileNotFoundError:
            key = Fernet.generate_key()
            flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, 'O_BINARY', 0)
            try:
                # Creating the file with mode 0o600 means the key is never
                # readable by other users, not even between write and chmod;
                # O_EXCL prevents overwriting a key another process just wrote.
                fd = os.open(key_file, flags, 0o600)
            except FileExistsError:
                key = key_file.read_bytes()
            else:
                with os.fdopen(fd, 'wb') as f:
                    f.write(key)
        # encrypt_data/decrypt_data call methods on this object, so it must be
        # a Fernet instance wrapping the key, never the raw key bytes.
        self._fernet = Fernet(key)

    def get_connection(self, reuse: bool = True):
        """Return a connection for the calling thread (thread-safe).

        Args:
            reuse: reuse this thread's cached connection if it still answers
                a ``SELECT 1`` probe (default True); otherwise open a new one.

        Raises:
            ValueError: if ``config['type']`` is not a supported database type.
        """
        thread_id = threading.get_ident()

        with self._lock:
            if reuse and thread_id in self._connection_pool:
                conn = self._connection_pool[thread_id]
                if self._is_usable(conn):
                    return conn
                # Stale connection: forget it and close it so it is not leaked.
                del self._connection_pool[thread_id]
                self._close_quietly(conn)

            conn = self._create_connection()
            self._connection_pool[thread_id] = conn
            return conn

    @staticmethod
    def _is_usable(conn) -> bool:
        """Return True if ``SELECT 1`` succeeds on *conn*.

        The probe goes through a cursor because DB-API 2.0 connections are not
        required to offer ``execute()`` (sqlite3 has it as a non-standard
        shortcut; the MySQL/PostgreSQL drivers do not).
        """
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT 1")
                cursor.fetchall()
            finally:
                cursor.close()
        except Exception:
            return False
        return True

    @staticmethod
    def _close_quietly(conn) -> None:
        """Close *conn*, logging (at debug level) instead of raising on failure."""
        try:
            conn.close()
        except Exception as e:
            logging.debug("关闭数据库连接失败: %s", e)

    def _create_connection(self):
        """Open a new connection for ``config['type']``.

        Raises:
            ValueError: if the configured type is not supported.
        """
        factories = {
            'sqlite': self._create_sqlite_connection,
            'mysql': self._create_mysql_connection,
            'postgresql': self._create_pg_connection,
        }
        factory = factories.get(self.config['type'])
        if factory is None:
            raise ValueError("不支持的数据库类型")
        return factory()

    def _create_sqlite_connection(self) -> sqlite3.Connection:
        """Open a SQLite connection in WAL mode with name-addressable rows."""
        db_path = os.fspath(self.config['database'])

        # On Windows, make relative paths absolute. Skip UNC/rooted paths and
        # the special in-memory name, which abspath would turn into a file path.
        if (platform.system() == 'Windows'
                and db_path != ':memory:'
                and not db_path.startswith(('\\\\', '/'))):
            db_path = os.path.abspath(db_path)

        conn = sqlite3.connect(db_path, timeout=15)
        conn.execute("PRAGMA journal_mode=WAL")  # write-ahead log: better concurrency
        conn.execute("PRAGMA synchronous=NORMAL")
        conn.row_factory = sqlite3.Row  # rows can be read by column name
        return conn

    def _create_mysql_connection(self):
        """Open a MySQL connection (requires PyMySQL); rows come back as dicts."""
        import pymysql
        return pymysql.connect(
            host=self.config.get('host', 'localhost'),
            port=self.config.get('port', 3306),
            user=self.config.get('user', 'root'),
            password=self.config.get('password', ''),
            database=self.config['database'],
            charset='utf8mb4',
            cursorclass=pymysql.cursors.DictCursor
        )

    def _create_pg_connection(self):
        """Open a PostgreSQL connection (requires psycopg2).

        NOTE(review): no cursor_factory is set, so rows use psycopg2's default
        cursor type rather than dicts — confirm callers expect that.
        """
        import psycopg2
        return psycopg2.connect(
            host=self.config.get('host', 'localhost'),
            port=self.config.get('port', 5432),
            user=self.config.get('user', 'postgres'),
            password=self.config.get('password', ''),
            dbname=self.config['database']
        )

    def execute(
        self,
        query: str,
        params: Optional[QueryParams] = None,
        return_lastrowid: bool = False
    ) -> Union[int, None]:
        """Run a write statement and commit it.

        Args:
            query: SQL written in the placeholder style of the configured driver.
            params: tuple or mapping bound to the placeholders; None runs the
                statement without parameters.
            return_lastrowid: if True, return ``cursor.lastrowid``.

        Returns:
            ``cursor.lastrowid`` when ``return_lastrowid`` is True, else None.

        The transaction is committed on success and rolled back on any error.
        The connection is always released afterwards (a no-op for SQLite).
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                if params is not None:
                    cursor.execute(query, params)
                else:
                    cursor.execute(query)  # never pass None as params
                lastrowid = cursor.lastrowid
            finally:
                cursor.close()
            # Explicit commit instead of ``with conn:``: the connection context
            # manager does not mean "transaction" for every driver.
            conn.commit()
        except BaseException:
            try:
                conn.rollback()
            except Exception as e:
                logging.debug("回滚失败: %s", e)
            raise  # re-raises the original error, not the rollback failure
        finally:
            self._release_connection()
        return lastrowid if return_lastrowid else None

    def query(
        self,
        query: str,
        params: Optional[QueryParams] = None,
        fetchall: bool = True
    ) -> Union[List[Dict], Dict, None]:
        """Run a read query.

        Args:
            query: SQL written in the placeholder style of the configured driver.
            params: tuple or mapping bound to the placeholders, or None.
            fetchall: True returns all rows; False returns only the first row,
                or None when there is none.

        SQLite rows are converted to plain dicts.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                if params is not None:
                    cursor.execute(query, params)
                else:
                    cursor.execute(query)  # never pass None as params

                if self.config['type'] == 'sqlite':
                    if fetchall:
                        return [dict(row) for row in cursor.fetchall()]
                    # Fetch a single row instead of materializing the whole result.
                    row = cursor.fetchone()
                    return dict(row) if row is not None else None
                return cursor.fetchall() if fetchall else cursor.fetchone()
            finally:
                cursor.close()
        finally:
            self._release_connection()

    def _release_connection(self):
        """Close and forget the calling thread's connection (no-op for SQLite).

        SQLite connections stay cached so the same thread can reuse them.
        """
        if self.config['type'] == 'sqlite':
            return
        with self._lock:
            # Remove first so a failing close() cannot leave a dead entry behind.
            conn = self._connection_pool.pop(threading.get_ident(), None)
        if conn is not None:
            self._close_quietly(conn)

    def encrypt_data(self, plaintext: str) -> str:
        """Encrypt *plaintext* with the Fernet key; returns the token as str."""
        return self._fernet.encrypt(plaintext.encode()).decode()

    def decrypt_data(self, ciphertext: str) -> str:
        """Decrypt a token produced by ``encrypt_data`` back to plaintext."""
        return self._fernet.decrypt(ciphertext.encode()).decode()

    def close_all(self):
        """Close every cached connection and empty the cache.

        Close errors are logged at debug level and skipped, so one bad
        connection does not stop the others from being closed.
        """
        with self._lock:
            connections = list(self._connection_pool.values())
            self._connection_pool.clear()
        for conn in connections:
            self._close_quietly(conn)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close_all()


# Factory for the default SQLite-backed DatabaseManager (new instance per call).
def get_default_db(app_name: str = 'app_name') -> DatabaseManager:
    """Return a DatabaseManager for the platform's default SQLite location.

    Args:
        app_name: directory name placed under the per-platform data
            directory (defaults to 'app_name').

    Locations:
        - Windows: %APPDATA%/<app_name>/data.db, falling back to
          ~/AppData/Roaming when APPDATA is not set
        - macOS: ~/Library/Application Support/<app_name>/data.db
        - other: /var/lib/<app_name>/data.db if /var/lib is writable,
          otherwise ~/.local/share/<app_name>/data.db
    """
    system = platform.system().lower()

    if system == 'windows':
        # APPDATA may be unset (e.g. some services); os.path.join(None, ...)
        # would raise TypeError, so fall back to its usual location.
        base_dir = os.getenv('APPDATA') or os.path.join(
            os.path.expanduser('~'), 'AppData', 'Roaming')
    elif system == 'darwin':
        base_dir = os.path.expanduser('~/Library/Application Support')
    elif os.access('/var/lib', os.W_OK):
        base_dir = '/var/lib'
    else:
        base_dir = os.path.expanduser('~/.local/share')

    db_path = os.path.join(base_dir, app_name, 'data.db')

    return DatabaseManager({
        'type': 'sqlite',
        'database': db_path
    })


# Demo / smoke test
if __name__ == "__main__":
    with get_default_db() as db:
        # Create the demo table
        db.execute("""
            CREATE TABLE IF NOT EXISTS test_table (
                id INTEGER PRIMARY KEY,
                name TEXT,
                secret TEXT
            )
        """)

        # Insert a row whose `secret` column holds an encrypted token
        secret = db.encrypt_data("新敏感信息")
        row_id = db.execute(
            "INSERT INTO test_table (name, secret) VALUES (?, ?)",
            ("测试记录", secret),
            return_lastrowid=True
        )

        # Read back the row just inserted and decrypt its `secret` column.
        # `name` is stored in plaintext and is not a valid Fernet token.
        row = db.query(
            "SELECT * FROM test_table WHERE id = ?", (row_id,), fetchall=False
        )
        print(f"解密数据: {db.decrypt_data(row['secret'])}")