Files
2025-12-26 17:29:22 +08:00

78 lines
2.6 KiB
Python

"""
认证相关 API 端点
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlmodel import Session, select
from datetime import timedelta
from app.core.config import settings
from app.core.security import verify_password, create_access_token, get_password_hash
from app.db.session import get_db
from app.models.user import User
from app.schemas.user import Token, User as UserSchema, UserCreate
# Router that the endpoint functions in this module register their routes on.
router = APIRouter()
# OAuth2 bearer-token scheme; its token URL is "<settings.API_V1_STR>/auth/login".
# NOTE(review): presumably this resolves to the login endpoint of this router once
# it is mounted under an "/auth" prefix -- confirm where the router is included.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
@router.post("/register", response_model=UserSchema, status_code=status.HTTP_201_CREATED)
def register(user_in: UserCreate, db: Session = Depends(get_db)):
    """Register a new user.

    Rejects the request when the username is already present, or when an
    email was supplied and that email is already present. Otherwise the
    password is hashed, a ``User`` row is inserted and committed, and the
    refreshed row is returned (serialized through ``UserSchema``).

    Args:
        user_in: Registration payload; ``username``, ``email`` and
            ``password`` are read from it.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        The newly persisted ``User`` instance.

    Raises:
        HTTPException: 400 if the username or the email is already in use.
    """
    # NOTE(review): the lookups below and the insert are not atomic. If the
    # table enforces unique constraints, a concurrent registration could make
    # commit() raise instead of producing the 400 responses -- confirm.

    # Reject a username that is already taken.
    username_match = db.exec(
        select(User).where(User.username == user_in.username)
    ).first()
    if username_match:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户名已存在",
        )

    # The email is only checked for uniqueness when one was provided.
    if user_in.email:
        email_match = db.exec(
            select(User).where(User.email == user_in.email)
        ).first()
        if email_match:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="邮箱已被注册",
            )

    # Only the hash of the password is stored on the new record.
    new_user = User(
        username=user_in.username,
        email=user_in.email,
        hashed_password=get_password_hash(user_in.password),
    )
    db.add(new_user)
    db.commit()
    # Reload the row so database-populated fields are present in the response.
    db.refresh(new_user)
    return new_user
@router.post("/login", response_model=Token)
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate a user and issue an access token.

    Looks the user up by the submitted username and verifies the submitted
    password against the stored hash. On success, an access token is created
    with the username as its ``sub`` entry and a lifetime of
    ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Args:
        form_data: OAuth2 password form; ``username`` and ``password`` are
            read from it.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the user is not found or the password does not verify.
    """
    account = db.exec(
        select(User).where(User.username == form_data.username)
    ).first()

    # An unknown user and a wrong password produce the same response.
    if not (account and verify_password(form_data.password, account.hashed_password)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Build the token for the authenticated user.
    token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}