ai初期模板
This commit is contained in:
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据库存储模块
|
||||
功能:
|
||||
1. 统一数据库接口(SQLite/MySQL/PostgreSQL)
|
||||
2. 自动处理多平台路径问题
|
||||
3. 连接池管理
|
||||
4. 数据加密存储
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Dict, Any, List
|
||||
from threading import Lock
|
||||
import logging
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
# 类型别名
|
||||
QueryParams = Union[tuple, Dict[str, Any]]
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库统一管理类"""
|
||||
|
||||
def __init__(self, db_config: Dict[str, Any]):
|
||||
"""
|
||||
初始化数据库连接
|
||||
|
||||
参数:
|
||||
db_config: 配置字典,包含:
|
||||
- type: 'sqlite'|'mysql'|'postgresql'
|
||||
- database: 数据库名/路径
|
||||
- [可选] host, port, user, password
|
||||
"""
|
||||
self.config = db_config
|
||||
self._lock = Lock()
|
||||
self._connection_pool = {}
|
||||
self._setup_crypto()
|
||||
|
||||
# 自动创建SQLite目录
|
||||
if db_config['type'] == 'sqlite':
|
||||
self._ensure_sqlite_dir()
|
||||
|
||||
def _ensure_sqlite_dir(self):
|
||||
"""确保SQLite数据库目录存在"""
|
||||
db_path = Path(self.config['database'])
|
||||
if not db_path.parent.exists():
|
||||
try:
|
||||
db_path.parent.mkdir(parents=True, mode=0o755)
|
||||
except Exception as e:
|
||||
logging.error(f"创建数据库目录失败: {str(e)}")
|
||||
|
||||
def _setup_crypto(self):
|
||||
"""初始化加密模块"""
|
||||
key_file = Path(os.path.expanduser("~/.db_encryption.key"))
|
||||
if key_file.exists():
|
||||
with open(key_file, 'rb') as f:
|
||||
self._fernet = Fernet(f.read())
|
||||
else:
|
||||
self._fernet = Fernet.generate_key()
|
||||
with open(key_file, 'wb') as f:
|
||||
f.write(self._fernet)
|
||||
key_file.chmod(0o600) # 仅限当前用户读写
|
||||
|
||||
def get_connection(self, reuse=True):
|
||||
"""
|
||||
获取数据库连接(线程安全)
|
||||
|
||||
参数:
|
||||
reuse: 是否复用现有连接(默认True)
|
||||
"""
|
||||
thread_id = threading.get_ident()
|
||||
|
||||
with self._lock:
|
||||
if reuse and thread_id in self._connection_pool:
|
||||
conn = self._connection_pool[thread_id]
|
||||
try:
|
||||
# 检查连接是否有效
|
||||
conn.execute("SELECT 1")
|
||||
return conn
|
||||
except:
|
||||
del self._connection_pool[thread_id]
|
||||
|
||||
# 创建新连接
|
||||
if self.config['type'] == 'sqlite':
|
||||
conn = self._create_sqlite_connection()
|
||||
elif self.config['type'] == 'mysql':
|
||||
conn = self._create_mysql_connection()
|
||||
elif self.config['type'] == 'postgresql':
|
||||
conn = self._create_pg_connection()
|
||||
else:
|
||||
raise ValueError("不支持的数据库类型")
|
||||
|
||||
self._connection_pool[thread_id] = conn
|
||||
return conn
|
||||
|
||||
def _create_sqlite_connection(self) -> sqlite3.Connection:
|
||||
"""创建SQLite连接(兼容多平台路径)"""
|
||||
db_path = self.config['database']
|
||||
|
||||
# Windows路径处理
|
||||
if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')):
|
||||
db_path = os.path.abspath(db_path)
|
||||
|
||||
conn = sqlite3.connect(db_path, timeout=15)
|
||||
conn.execute("PRAGMA journal_mode=WAL") # 写前日志提升并发
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.row_factory = sqlite3.Row # 支持字典式访问
|
||||
return conn
|
||||
|
||||
def _create_mysql_connection(self):
|
||||
"""创建MySQL连接(需安装PyMySQL)"""
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=self.config.get('host', 'localhost'),
|
||||
port=self.config.get('port', 3306),
|
||||
user=self.config.get('user', 'root'),
|
||||
password=self.config.get('password', ''),
|
||||
database=self.config['database'],
|
||||
charset='utf8mb4',
|
||||
cursorclass=pymysql.cursors.DictCursor
|
||||
)
|
||||
|
||||
def _create_pg_connection(self):
|
||||
"""创建PostgreSQL连接(需安装psycopg2)"""
|
||||
import psycopg2
|
||||
return psycopg2.connect(
|
||||
host=self.config.get('host', 'localhost'),
|
||||
port=self.config.get('port', 5432),
|
||||
user=self.config.get('user', 'postgres'),
|
||||
password=self.config.get('password', ''),
|
||||
dbname=self.config['database']
|
||||
)
|
||||
|
||||
def execute(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[QueryParams] = None,
|
||||
return_lastrowid: bool = False
|
||||
) -> Union[int, None]:
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
with conn:
|
||||
cursor = conn.cursor()
|
||||
if params is not None:
|
||||
cursor.execute(query, params)
|
||||
else:
|
||||
cursor.execute(query) # ✅ 无参数时不要传 params
|
||||
return cursor.lastrowid if return_lastrowid else None
|
||||
finally:
|
||||
if not return_lastrowid:
|
||||
self._release_connection()
|
||||
|
||||
def query(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[QueryParams] = None,
|
||||
fetchall: bool = True
|
||||
) -> Union[List[Dict], Dict]:
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
if params is not None:
|
||||
cursor.execute(query, params)
|
||||
else:
|
||||
cursor.execute(query) # ✅ 无参数时不要传 None
|
||||
|
||||
if self.config['type'] == 'sqlite':
|
||||
result = cursor.fetchall()
|
||||
if not fetchall and result:
|
||||
return dict(result[0])
|
||||
return [dict(row) for row in result]
|
||||
else:
|
||||
return cursor.fetchall() if fetchall else cursor.fetchone()
|
||||
finally:
|
||||
self._release_connection()
|
||||
|
||||
def _release_connection(self):
|
||||
"""释放当前线程的连接(SQLite除外)"""
|
||||
if self.config['type'] != 'sqlite':
|
||||
thread_id = threading.get_ident()
|
||||
with self._lock:
|
||||
if thread_id in self._connection_pool:
|
||||
self._connection_pool[thread_id].close()
|
||||
del self._connection_pool[thread_id]
|
||||
|
||||
def encrypt_data(self, plaintext: str) -> str:
|
||||
"""加密敏感数据"""
|
||||
return self._fernet.encrypt(plaintext.encode()).decode()
|
||||
|
||||
def decrypt_data(self, ciphertext: str) -> str:
|
||||
"""解密数据"""
|
||||
return self._fernet.decrypt(ciphertext.encode()).decode()
|
||||
|
||||
def close_all(self):
|
||||
"""关闭所有数据库连接"""
|
||||
with self._lock:
|
||||
for conn in self._connection_pool.values():
|
||||
try:
|
||||
conn.close()
|
||||
except:
|
||||
pass
|
||||
self._connection_pool.clear()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close_all()
|
||||
|
||||
|
||||
# 全局SQLite实例(默认配置)
|
||||
def get_default_db() -> DatabaseManager:
|
||||
"""获取默认SQLite数据库(跨平台路径处理)"""
|
||||
system = platform.system().lower()
|
||||
if system == 'windows':
|
||||
db_path = os.path.join(os.getenv('APPDATA'), 'app_name/data.db')
|
||||
elif system == 'darwin':
|
||||
db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db')
|
||||
else:
|
||||
db_path = '/var/lib/app_name/data.db' if os.access('/var/lib', os.W_OK) \
|
||||
else os.path.expanduser('~/.local/share/app_name/data.db')
|
||||
|
||||
return DatabaseManager({
|
||||
'type': 'sqlite',
|
||||
'database': db_path
|
||||
})
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
with get_default_db() as db:
|
||||
# 创建测试表
|
||||
db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS test_table (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
secret TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# 插入加密数据
|
||||
secret = db.encrypt_data("新敏感信息")
|
||||
db.execute(
|
||||
"INSERT INTO test_table (name, secret) VALUES (?, ?)",
|
||||
("测试记录", secret)
|
||||
)
|
||||
|
||||
# 查询并解密
|
||||
row = db.query("SELECT * FROM test_table", fetchall=False)
|
||||
print(f"解密数据: {db.decrypt_data(row['name'])}")
|
||||
Reference in New Issue
Block a user