1. 同步MediaCrawler为最新版本
2. 修复数据库not null错误 3. 支持PG数据库 4. 规范环境变量及配置使用 5. 规范为uv安装 6. 使用loggru
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import asynccontextmanager
|
||||
from .models import Base
|
||||
import config
|
||||
from config.db_config import mysql_db_config, sqlite_db_config, postgresql_db_config
|
||||
|
||||
# Keep a cache of engines
|
||||
_engines = {}
|
||||
|
||||
|
||||
async def create_database_if_not_exists(db_type: str):
|
||||
if db_type == "mysql" or db_type == "db":
|
||||
# Connect to the server without a database
|
||||
server_url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}"
|
||||
engine = create_async_engine(server_url, echo=False)
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {mysql_db_config['db_name']}"))
|
||||
await engine.dispose()
|
||||
elif db_type == "postgresql":
|
||||
# Connect to PostgreSQL default database (postgres) to create target database
|
||||
server_url = f"postgresql+asyncpg://{postgresql_db_config['user']}:{postgresql_db_config['password']}@{postgresql_db_config['host']}:{postgresql_db_config['port']}/postgres"
|
||||
engine = create_async_engine(server_url, echo=False, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
# PostgreSQL uses different syntax - check if database exists first
|
||||
result = await conn.execute(
|
||||
text(f"SELECT 1 FROM pg_database WHERE datname = '{postgresql_db_config['db_name']}'")
|
||||
)
|
||||
exists = result.scalar() is not None
|
||||
if not exists:
|
||||
# Set autocommit for CREATE DATABASE
|
||||
await conn.commit()
|
||||
await conn.execute(text(f"CREATE DATABASE {postgresql_db_config['db_name']}"))
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def get_async_engine(db_type: str = None):
|
||||
if db_type is None:
|
||||
db_type = config.SAVE_DATA_OPTION
|
||||
|
||||
if db_type in _engines:
|
||||
return _engines[db_type]
|
||||
|
||||
if db_type in ["json", "csv"]:
|
||||
return None
|
||||
|
||||
if db_type == "sqlite":
|
||||
db_url = f"sqlite+aiosqlite:///{sqlite_db_config['db_path']}"
|
||||
elif db_type == "mysql" or db_type == "db":
|
||||
db_url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}/{mysql_db_config['db_name']}"
|
||||
elif db_type == "postgresql":
|
||||
db_url = f"postgresql+asyncpg://{postgresql_db_config['user']}:{postgresql_db_config['password']}@{postgresql_db_config['host']}:{postgresql_db_config['port']}/{postgresql_db_config['db_name']}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {db_type}")
|
||||
|
||||
engine = create_async_engine(db_url, echo=False)
|
||||
_engines[db_type] = engine
|
||||
return engine
|
||||
|
||||
|
||||
async def create_tables(db_type: str = None):
|
||||
if db_type is None:
|
||||
db_type = config.SAVE_DATA_OPTION
|
||||
await create_database_if_not_exists(db_type)
|
||||
engine = get_async_engine(db_type)
|
||||
if engine:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session() -> AsyncSession:
|
||||
engine = get_async_engine(config.SAVE_DATA_OPTION)
|
||||
if not engine:
|
||||
yield None
|
||||
return
|
||||
AsyncSessionFactory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
session = AsyncSessionFactory()
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
await session.close()
|
||||
Reference in New Issue
Block a user