初始版本

This commit is contained in:
z66
2025-12-26 13:42:22 +08:00
parent ddb90d6c20
commit b495bc1dca
43 changed files with 2179 additions and 20 deletions
+16
View File
@@ -0,0 +1,16 @@
"""
数据库模块
"""
from .base import Base
from .session import engine, SessionLocal, get_db
from .init_db import init_db, reset_db
__all__ = [
"Base",
"engine",
"SessionLocal",
"get_db",
"init_db",
"reset_db",
]
+71
View File
@@ -0,0 +1,71 @@
"""
数据库初始化脚本
运行此脚本创建数据库表结构
用法:
python -m app.db.init_db # 创建表(如果表已存在则跳过)
python -m app.db.init_db --reset # 删除所有表后重新创建(⚠️ 会丢失数据)
"""
import sys
import argparse
from sqlmodel import SQLModel
from app.db.session import engine
from app.db.base import Base # 这会导入所有模型
# 导入所有模型以确保表被注册
from app.models import (
User, Todo, Post, Transaction, Media, Tag, MediaTag,
ChatMessage, Upload
)
def init_db() -> None:
"""初始化数据库,创建所有表"""
SQLModel.metadata.create_all(engine)
print("✅ 数据库表创建完成")
def reset_db() -> None:
"""重置数据库:删除所有表后重新创建(⚠️ 会丢失所有数据)"""
print("⚠️ 警告:将删除所有表和数据!")
SQLModel.metadata.drop_all(engine)
print("✅ 已删除所有表")
SQLModel.metadata.create_all(engine)
print("✅ 已重新创建所有表")
def main() -> None:
"""初始化数据库"""
parser = argparse.ArgumentParser(description="初始化数据库")
parser.add_argument(
"--reset",
action="store_true",
help="删除所有表后重新创建(⚠️ 会丢失所有数据)"
)
args = parser.parse_args()
if args.reset:
print("🔄 重置模式:将删除所有表后重新创建...")
reset_db()
else:
print("📦 创建模式:创建不存在的表...")
init_db()
print(f"\n✅ 数据库表操作完成!")
print(f"数据库文件位置: {engine.url}")
# 显示创建的表
print("\n数据库中的表:")
tables = [
"users", "todos", "posts", "transactions",
"media", "tags", "media_tags",
"chat_messages", "uploads"
]
for table in tables:
print(f" - {table}")
if __name__ == "__main__":
main()
+26
View File
@@ -0,0 +1,26 @@
"""
数据库会话管理
"""
from sqlalchemy.engine import Engine
from sqlalchemy import create_engine
from sqlmodel import Session
from typing import Generator
from app.core.config import settings
# 创建数据库引擎
engine: Engine = create_engine(
settings.DATABASE_URL,
connect_args={"check_same_thread": False}, # SQLite 需要此参数
echo=True # 开发环境显示SQL语句,生产环境设为False
)
def get_db() -> Generator[Session, None, None]:
"""获取数据库会话(用于依赖注入)"""
with Session(engine) as session:
yield session
# 为了向后兼容,保留 SessionLocal 别名
SessionLocal = Session