ai初期模板

This commit is contained in:
2025-08-05 15:00:46 +08:00
commit 71e9c7c5bc
21 changed files with 1446 additions and 0 deletions
+255
View File
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库存储模块
功能:
1. 统一数据库接口(SQLite/MySQL/PostgreSQL
2. 自动处理多平台路径问题
3. 连接池管理
4. 数据加密存储
"""
import os
import platform
import sqlite3
import threading
from pathlib import Path
from typing import Optional, Union, Dict, Any, List
from threading import Lock
import logging
from cryptography.fernet import Fernet
# 类型别名
QueryParams = Union[tuple, Dict[str, Any]]
class DatabaseManager:
"""数据库统一管理类"""
def __init__(self, db_config: Dict[str, Any]):
"""
初始化数据库连接
参数:
db_config: 配置字典,包含:
- type: 'sqlite'|'mysql'|'postgresql'
- database: 数据库名/路径
- [可选] host, port, user, password
"""
self.config = db_config
self._lock = Lock()
self._connection_pool = {}
self._setup_crypto()
# 自动创建SQLite目录
if db_config['type'] == 'sqlite':
self._ensure_sqlite_dir()
def _ensure_sqlite_dir(self):
"""确保SQLite数据库目录存在"""
db_path = Path(self.config['database'])
if not db_path.parent.exists():
try:
db_path.parent.mkdir(parents=True, mode=0o755)
except Exception as e:
logging.error(f"创建数据库目录失败: {str(e)}")
def _setup_crypto(self):
"""初始化加密模块"""
key_file = Path(os.path.expanduser("~/.db_encryption.key"))
if key_file.exists():
with open(key_file, 'rb') as f:
self._fernet = Fernet(f.read())
else:
self._fernet = Fernet.generate_key()
with open(key_file, 'wb') as f:
f.write(self._fernet)
key_file.chmod(0o600) # 仅限当前用户读写
def get_connection(self, reuse=True):
"""
获取数据库连接(线程安全)
参数:
reuse: 是否复用现有连接(默认True)
"""
thread_id = threading.get_ident()
with self._lock:
if reuse and thread_id in self._connection_pool:
conn = self._connection_pool[thread_id]
try:
# 检查连接是否有效
conn.execute("SELECT 1")
return conn
except:
del self._connection_pool[thread_id]
# 创建新连接
if self.config['type'] == 'sqlite':
conn = self._create_sqlite_connection()
elif self.config['type'] == 'mysql':
conn = self._create_mysql_connection()
elif self.config['type'] == 'postgresql':
conn = self._create_pg_connection()
else:
raise ValueError("不支持的数据库类型")
self._connection_pool[thread_id] = conn
return conn
def _create_sqlite_connection(self) -> sqlite3.Connection:
"""创建SQLite连接(兼容多平台路径)"""
db_path = self.config['database']
# Windows路径处理
if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')):
db_path = os.path.abspath(db_path)
conn = sqlite3.connect(db_path, timeout=15)
conn.execute("PRAGMA journal_mode=WAL") # 写前日志提升并发
conn.execute("PRAGMA synchronous=NORMAL")
conn.row_factory = sqlite3.Row # 支持字典式访问
return conn
def _create_mysql_connection(self):
"""创建MySQL连接(需安装PyMySQL"""
import pymysql
return pymysql.connect(
host=self.config.get('host', 'localhost'),
port=self.config.get('port', 3306),
user=self.config.get('user', 'root'),
password=self.config.get('password', ''),
database=self.config['database'],
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
def _create_pg_connection(self):
"""创建PostgreSQL连接(需安装psycopg2"""
import psycopg2
return psycopg2.connect(
host=self.config.get('host', 'localhost'),
port=self.config.get('port', 5432),
user=self.config.get('user', 'postgres'),
password=self.config.get('password', ''),
dbname=self.config['database']
)
def execute(
self,
query: str,
params: Optional[QueryParams] = None,
return_lastrowid: bool = False
) -> Union[int, None]:
conn = self.get_connection()
try:
with conn:
cursor = conn.cursor()
if params is not None:
cursor.execute(query, params)
else:
cursor.execute(query) # ✅ 无参数时不要传 params
return cursor.lastrowid if return_lastrowid else None
finally:
if not return_lastrowid:
self._release_connection()
def query(
self,
query: str,
params: Optional[QueryParams] = None,
fetchall: bool = True
) -> Union[List[Dict], Dict]:
conn = self.get_connection()
try:
cursor = conn.cursor()
if params is not None:
cursor.execute(query, params)
else:
cursor.execute(query) # ✅ 无参数时不要传 None
if self.config['type'] == 'sqlite':
result = cursor.fetchall()
if not fetchall and result:
return dict(result[0])
return [dict(row) for row in result]
else:
return cursor.fetchall() if fetchall else cursor.fetchone()
finally:
self._release_connection()
def _release_connection(self):
"""释放当前线程的连接(SQLite除外)"""
if self.config['type'] != 'sqlite':
thread_id = threading.get_ident()
with self._lock:
if thread_id in self._connection_pool:
self._connection_pool[thread_id].close()
del self._connection_pool[thread_id]
def encrypt_data(self, plaintext: str) -> str:
"""加密敏感数据"""
return self._fernet.encrypt(plaintext.encode()).decode()
def decrypt_data(self, ciphertext: str) -> str:
"""解密数据"""
return self._fernet.decrypt(ciphertext.encode()).decode()
def close_all(self):
"""关闭所有数据库连接"""
with self._lock:
for conn in self._connection_pool.values():
try:
conn.close()
except:
pass
self._connection_pool.clear()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close_all()
# 全局SQLite实例(默认配置)
def get_default_db() -> DatabaseManager:
"""获取默认SQLite数据库(跨平台路径处理)"""
system = platform.system().lower()
if system == 'windows':
db_path = os.path.join(os.getenv('APPDATA'), 'app_name/data.db')
elif system == 'darwin':
db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db')
else:
db_path = '/var/lib/app_name/data.db' if os.access('/var/lib', os.W_OK) \
else os.path.expanduser('~/.local/share/app_name/data.db')
return DatabaseManager({
'type': 'sqlite',
'database': db_path
})
# 测试代码
if __name__ == "__main__":
with get_default_db() as db:
# 创建测试表
db.execute("""
CREATE TABLE IF NOT EXISTS test_table (
id INTEGER PRIMARY KEY,
name TEXT,
secret TEXT
)
""")
# 插入加密数据
secret = db.encrypt_data("新敏感信息")
db.execute(
"INSERT INTO test_table (name, secret) VALUES (?, ?)",
("测试记录", secret)
)
# 查询并解密
row = db.query("SELECT * FROM test_table", fetchall=False)
print(f"解密数据: {db.decrypt_data(row['name'])}")