from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

from .models import Base
import config
from config.db_config import mysql_db_config, sqlite_db_config, postgresql_db_config

# Cache of async engines, keyed by the db_type string they were requested with.
_engines = {}


async def create_database_if_not_exists(db_type: str):
    """Create the configured database on the server if it is missing.

    Only the "mysql"/"db" and "postgresql" types do anything; every other
    value is a no-op. A temporary engine is opened against the server and is
    always disposed before returning, including when a statement fails.

    Args:
        db_type: Storage backend identifier ("mysql", "db", "postgresql", ...).
    """
    if db_type in ("mysql", "db"):
        # Connect to the server without selecting a database.
        server_url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}?charset=utf8mb4"
        engine = create_async_engine(server_url, echo=False)
        try:
            # Quote the name only when the dialect requires it (hyphens,
            # reserved words, upper-case letters); plain names are unchanged.
            db_name = engine.dialect.identifier_preparer.quote(mysql_db_config['db_name'])
            async with engine.connect() as conn:
                # Make sure the database uses the utf8mb4 character set.
                await conn.execute(
                    text(
                        f"CREATE DATABASE IF NOT EXISTS {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
                    )
                )
        finally:
            # Release the pool even if connecting or executing raised.
            await engine.dispose()
    elif db_type == "postgresql":
        # Connect to the default "postgres" database to create the target one.
        server_url = f"postgresql+asyncpg://{postgresql_db_config['user']}:{postgresql_db_config['password']}@{postgresql_db_config['host']}:{postgresql_db_config['port']}/postgres"
        # AUTOCOMMIT because PostgreSQL rejects CREATE DATABASE inside a
        # transaction block.
        engine = create_async_engine(server_url, echo=False, isolation_level="AUTOCOMMIT")
        try:
            target_name = postgresql_db_config['db_name']
            async with engine.connect() as conn:
                # PostgreSQL has no CREATE DATABASE IF NOT EXISTS, so look the
                # name up first. The name is passed as a bound parameter
                # rather than interpolated into the SQL text.
                result = await conn.execute(
                    text("SELECT 1 FROM pg_database WHERE datname = :name"),
                    {"name": target_name},
                )
                exists = result.scalar() is not None
                if not exists:
                    # Close the unit of work begun by the SELECT before
                    # issuing the DDL statement.
                    await conn.commit()
                    quoted_name = engine.dialect.identifier_preparer.quote(target_name)
                    await conn.execute(text(f"CREATE DATABASE {quoted_name}"))
        finally:
            # Release the pool even if connecting or executing raised.
            await engine.dispose()


def get_async_engine(db_type: Optional[str] = None):
    """Return the cached async engine for ``db_type``, creating it on first use.

    Args:
        db_type: Storage backend identifier. Defaults to
            ``config.SAVE_DATA_OPTION`` when ``None``.

    Returns:
        The engine for the backend, or ``None`` for the "json" and "csv"
        backends.

    Raises:
        ValueError: If ``db_type`` is not one of the handled backends.
    """
    if db_type is None:
        db_type = config.SAVE_DATA_OPTION

    if db_type in _engines:
        return _engines[db_type]

    if db_type in ["json", "csv"]:
        return None

    if db_type == "sqlite":
        db_url = f"sqlite+aiosqlite:///{sqlite_db_config['db_path']}"
    elif db_type in ("mysql", "db"):
        # charset=utf8mb4 enables full UTF-8 support (including emoji and
        # Chinese text).
        db_url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}/{mysql_db_config['db_name']}?charset=utf8mb4"
    elif db_type == "postgresql":
        db_url = f"postgresql+asyncpg://{postgresql_db_config['user']}:{postgresql_db_config['password']}@{postgresql_db_config['host']}:{postgresql_db_config['port']}/{postgresql_db_config['db_name']}"
    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    engine = create_async_engine(db_url, echo=False)
    _engines[db_type] = engine
    return engine


async def create_tables(db_type: Optional[str] = None):
    """Ensure the database exists, then create every table in ``Base.metadata``.

    Args:
        db_type: Storage backend identifier. Defaults to
            ``config.SAVE_DATA_OPTION`` when ``None``. Nothing is created for
            backends that have no engine.
    """
    if db_type is None:
        db_type = config.SAVE_DATA_OPTION
    await create_database_if_not_exists(db_type)
    engine = get_async_engine(db_type)
    if engine is not None:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)


@asynccontextmanager
async def get_session() -> AsyncIterator[Optional[AsyncSession]]:
    """Yield a session for the configured backend inside a commit/rollback scope.

    Yields ``None`` when the configured backend has no engine. Otherwise the
    session is committed when the ``async with`` body finishes normally,
    rolled back (and the exception re-raised) when it raises ``Exception``,
    and closed in every case.
    """
    engine = get_async_engine(config.SAVE_DATA_OPTION)
    if engine is None:
        yield None
        return

    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    session = session_factory()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        # Bare raise preserves the original traceback.
        raise
    finally:
        await session.close()