""" 认证相关 API 端点 """ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from sqlmodel import Session, select from datetime import timedelta from app.core.config import settings from app.core.security import verify_password, create_access_token, get_password_hash from app.db.session import get_db from app.models.user import User from app.schemas.user import Token, User as UserSchema, UserCreate router = APIRouter() oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") @router.post("/register", response_model=UserSchema, status_code=status.HTTP_201_CREATED) def register(user_in: UserCreate, db: Session = Depends(get_db)): """用户注册""" # 检查用户名是否已存在 statement = select(User).where(User.username == user_in.username) existing_user = db.exec(statement).first() if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在" ) # 检查邮箱是否已存在 if user_in.email: statement = select(User).where(User.email == user_in.email) existing_email = db.exec(statement).first() if existing_email: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册" ) # 创建新用户 hashed_password = get_password_hash(user_in.password) db_user = User( username=user_in.username, email=user_in.email, hashed_password=hashed_password ) db.add(db_user) db.commit() db.refresh(db_user) return db_user @router.post("/login", response_model=Token) def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): """用户登录""" # 查找用户 statement = select(User).where(User.username == form_data.username) user = db.exec(statement).first() if not user or not verify_password(form_data.password, user.hashed_password): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误", headers={"WWW-Authenticate": "Bearer"}, ) # 创建访问令牌 access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"}