初始版本
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
数据库模块
|
||||
"""
|
||||
from .base import Base
|
||||
from .session import engine, SessionLocal, get_db
|
||||
from .init_db import init_db, reset_db
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"engine",
|
||||
"SessionLocal",
|
||||
"get_db",
|
||||
"init_db",
|
||||
"reset_db",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
数据库初始化脚本
|
||||
运行此脚本创建数据库表结构
|
||||
|
||||
用法:
|
||||
python -m app.db.init_db # 创建表(如果表已存在则跳过)
|
||||
python -m app.db.init_db --reset # 删除所有表后重新创建(⚠️ 会丢失数据)
|
||||
"""
|
||||
import sys
|
||||
import argparse
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.db.session import engine
|
||||
from app.db.base import Base # 这会导入所有模型
|
||||
|
||||
# 导入所有模型以确保表被注册
|
||||
from app.models import (
|
||||
User, Todo, Post, Transaction, Media, Tag, MediaTag,
|
||||
ChatMessage, Upload
|
||||
)
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
"""初始化数据库,创建所有表"""
|
||||
SQLModel.metadata.create_all(engine)
|
||||
print("✅ 数据库表创建完成")
|
||||
|
||||
|
||||
def reset_db() -> None:
|
||||
"""重置数据库:删除所有表后重新创建(⚠️ 会丢失所有数据)"""
|
||||
print("⚠️ 警告:将删除所有表和数据!")
|
||||
SQLModel.metadata.drop_all(engine)
|
||||
print("✅ 已删除所有表")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
print("✅ 已重新创建所有表")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""初始化数据库"""
|
||||
parser = argparse.ArgumentParser(description="初始化数据库")
|
||||
parser.add_argument(
|
||||
"--reset",
|
||||
action="store_true",
|
||||
help="删除所有表后重新创建(⚠️ 会丢失所有数据)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.reset:
|
||||
print("🔄 重置模式:将删除所有表后重新创建...")
|
||||
reset_db()
|
||||
else:
|
||||
print("📦 创建模式:创建不存在的表...")
|
||||
init_db()
|
||||
|
||||
print(f"\n✅ 数据库表操作完成!")
|
||||
print(f"数据库文件位置: {engine.url}")
|
||||
|
||||
# 显示创建的表
|
||||
print("\n数据库中的表:")
|
||||
tables = [
|
||||
"users", "todos", "posts", "transactions",
|
||||
"media", "tags", "media_tags",
|
||||
"chat_messages", "uploads"
|
||||
]
|
||||
for table in tables:
|
||||
print(f" - {table}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
数据库会话管理
|
||||
"""
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import create_engine
|
||||
from sqlmodel import Session
|
||||
from typing import Generator
|
||||
from app.core.config import settings
|
||||
|
||||
# 创建数据库引擎
|
||||
engine: Engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
connect_args={"check_same_thread": False}, # SQLite 需要此参数
|
||||
echo=True # 开发环境显示SQL语句,生产环境设为False
|
||||
)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""获取数据库会话(用于依赖注入)"""
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
# 为了向后兼容,保留 SessionLocal 别名
|
||||
SessionLocal = Session
|
||||
|
||||
Reference in New Issue
Block a user