new refresh dep add refresh endpoint

This commit is contained in:
IluaAir
2025-09-21 14:59:26 +03:00
parent c3bfb9cb6a
commit 42b8c3a2c9
12 changed files with 125 additions and 66 deletions

View File

@@ -1,7 +1,7 @@
from typing import Annotated from typing import Annotated
from fastapi import Depends, HTTPException, Path from fastapi import Depends, HTTPException, Path
from fastapi.security import OAuth2PasswordBearer from fastapi.security import HTTPBearer, OAuth2PasswordBearer
from jwt import InvalidTokenError from jwt import InvalidTokenError
from src.api.dependacies.db_dep import sessionDep from src.api.dependacies.db_dep import sessionDep
@@ -11,37 +11,46 @@ from src.schemas.auth import TokenData
from src.services.tasks import TaskService from src.services.tasks import TaskService
from src.services.users import UserService from src.services.users import UserService
http_bearer = HTTPBearer(auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api.v1_login_url}/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api.v1_login_url}/login")
AccessTokenDep = Annotated[str, Depends(oauth2_scheme)]
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): async def get_current_user(
token: AccessTokenDep, verify_exp: bool = True, check_active: bool = False
):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=401, status_code=401,
detail="Could not validate credentials", detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = AuthManager.decode_access_token(token=token) payload = AuthManager.decode_access_token(token, verify_exp)
if payload is None: if payload is None:
raise credentials_exception raise credentials_exception
user = TokenData(**payload) user = TokenData(**payload)
if check_active and not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
except InvalidTokenError: except InvalidTokenError:
raise credentials_exception raise credentials_exception
return user return user
CurrentUser = Annotated[TokenData, Depends(get_current_user)] async def get_current_user_basic(token: AccessTokenDep):
return await get_current_user(token, verify_exp=True, check_active=False)
def get_current_active_user( async def get_current_active_user(token: AccessTokenDep):
current_user: CurrentUser, return await get_current_user(token, verify_exp=True, check_active=True)
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
async def get_current_user_for_refresh(token: AccessTokenDep):
return await get_current_user(token, verify_exp=False, check_active=True)
CurrentUser = Annotated[TokenData, Depends(get_current_user_basic)]
ActiveUser = Annotated[TokenData, Depends(get_current_active_user)] ActiveUser = Annotated[TokenData, Depends(get_current_active_user)]
RefreshUser = Annotated[TokenData, Depends(get_current_user_for_refresh)]
async def get_admin_user(db: sessionDep, current_user: ActiveUser): async def get_admin_user(db: sessionDep, current_user: ActiveUser):

View File

@@ -1,16 +1,19 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Response from fastapi import APIRouter, Cookie, Depends, HTTPException, Response
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from src.api.dependacies.db_dep import sessionDep from src.api.dependacies.db_dep import sessionDep
from src.api.dependacies.user_dep import ActiveUser from src.api.dependacies.user_dep import ActiveUser, RefreshUser, http_bearer
from src.core.settings import settings from src.core.settings import settings
from src.schemas.auth import Token from src.schemas.auth import Token
from src.schemas.users import UserRequestADD from src.schemas.users import UserRequestADD
from src.services.auth import AuthService from src.services.auth import AuthService
from src.services.users import UserService
router = APIRouter(prefix=settings.api.v1.auth, tags=["Auth"]) router = APIRouter(
prefix=settings.api.v1.auth, tags=["Auth"], dependencies=[Depends(http_bearer)]
)
@router.post(path="/signup") @router.post(path="/signup")
@@ -23,28 +26,55 @@ async def registration(session: sessionDep, credential: UserRequestADD):
async def login( async def login(
session: sessionDep, session: sessionDep,
credential: Annotated[OAuth2PasswordRequestForm, Depends()], credential: Annotated[OAuth2PasswordRequestForm, Depends()],
response: Response response: Response,
): ):
result = await AuthService(session).login( result = await AuthService(session).login(credential.username, credential.password)
credential.username, credential.password
)
response.set_cookie( response.set_cookie(
key="refresh_token", key="refresh_token",
value=result["refresh_token"], value=result["refresh_token"],
httponly=True, httponly=True,
samesite='lax', samesite="lax",
path=settings.api.v1.auth, path=settings.api.v1.auth,
max_age=60 * 60 * 24 * 7 max_age=60 * 60 * 24 * 7,
) )
return result return result
@router.post(path="/refresh") @router.post(path="/refresh", response_model=Token)
async def refresh(user: ActiveUser, response: Response): async def refresh(
print(response) session: sessionDep,
response: Response,
current_user: RefreshUser,
refresh_token: Annotated[str | None, Cookie(name="refresh_token")] = None,
):
if refresh_token is None:
raise HTTPException(status_code=401, detail="No refresh token")
result = await AuthService(session).refresh_tokens(refresh_token, current_user)
response.set_cookie(
key="refresh_token",
value=result["refresh_token"],
httponly=True,
samesite="lax",
path=settings.api.v1.auth,
max_age=60 * 60 * 24 * 7,
)
return result
@router.post(path='/logout') @router.get("/me")
async def logout(response: Response): async def get_me(session: sessionDep, user: ActiveUser):
cur_user = await UserService(session).get_user_by_filter_or_raise(id=user.id)
return cur_user
@router.post(path="/logout")
async def logout(
session: sessionDep,
response: Response,
refresh_token: Annotated[str | None, Cookie(name="refresh_token")] = None,
):
if refresh_token is None:
raise HTTPException(status_code=401, detail="No refresh token")
await AuthService(session).delete_token(token=refresh_token)
response.delete_cookie(key="refresh_token") response.delete_cookie(key="refresh_token")
return {'status': 'ok'} return {"status": "ok"}

View File

@@ -2,7 +2,6 @@ from fastapi import APIRouter, Body
from src.api.dependacies.db_dep import sessionDep from src.api.dependacies.db_dep import sessionDep
from src.api.dependacies.user_dep import ( from src.api.dependacies.user_dep import (
ActiveUser,
AdminUser, AdminUser,
OwnerDep, OwnerDep,
) )
@@ -13,12 +12,6 @@ from src.services.users import UserService
router = APIRouter(prefix=settings.api.v1.users, tags=["Users"]) router = APIRouter(prefix=settings.api.v1.users, tags=["Users"])
@router.get("/me")
async def get_me(session: sessionDep, user: ActiveUser):
cur_user = await UserService(session).get_user_by_filter_or_raise(id=user.id)
return cur_user
@router.get("/") @router.get("/")
async def get_all_users(session: sessionDep, _: AdminUser): async def get_all_users(session: sessionDep, _: AdminUser):
users = await UserService(session).get_all_users() users = await UserService(session).get_all_users()

View File

@@ -34,19 +34,17 @@ class AuthManager:
algorithm=settings.access_token.algorithm, algorithm=settings.access_token.algorithm,
) )
return encoded_jwt return encoded_jwt
@classmethod @classmethod
def create_refresh_token(cls) -> str: def create_refresh_token(cls) -> str:
# random_bytes = os.urandom(32)
# data = settings.refresh_token.secret_key.encode() + random_bytes
# token_hash = bcrypt.hashpw(data, bcrypt.gensalt(rounds=12)).decode()
token_hash = secrets.token_urlsafe(32) token_hash = secrets.token_urlsafe(32)
return token_hash return token_hash
@classmethod @classmethod
def decode_access_token(cls, token: str) -> dict: def decode_access_token(cls, token: str, verify_exp: bool = True) -> dict:
return jwt.decode( return jwt.decode(
token, token,
settings.access_token.secret_key, settings.access_token.secret_key,
algorithms=[settings.access_token.algorithm], algorithms=[settings.access_token.algorithm],
options={"verify_exp": verify_exp},
) )

View File

@@ -14,9 +14,9 @@ class HasId(Protocol):
class IUOWDB(Protocol): class IUOWDB(Protocol):
session: AsyncSession session: AsyncSession
user: 'UsersRepo' user: "UsersRepo"
task: 'TasksRepo' task: "TasksRepo"
auth: 'AuthRepo' auth: "AuthRepo"
async def __aenter__(self) -> "IUOWDB": ... async def __aenter__(self) -> "IUOWDB": ...

View File

@@ -5,14 +5,15 @@ Revises: 4b0f3ea2fd26
Create Date: 2025-09-08 14:56:01.439089 Create Date: 2025-09-08 14:56:01.439089
""" """
from typing import Sequence, Union from typing import Sequence, Union
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'b879d3502c37' revision: str = "b879d3502c37"
down_revision: Union[str, None] = '4b0f3ea2fd26' down_revision: Union[str, None] = "4b0f3ea2fd26"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -20,13 +21,22 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('refresh_tokens', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "refresh_tokens",
sa.Column('token', sa.String(length=255), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column("token", sa.String(length=255), nullable=False),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.Column(
sa.PrimaryKeyConstraint('id') "created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
) )
# ### end Alembic commands ### # ### end Alembic commands ###
@@ -34,5 +44,5 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
"""Downgrade schema.""" """Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table('refresh_tokens') op.drop_table("refresh_tokens")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -2,8 +2,4 @@ from src.models.tasks import TasksORM
from src.models.tokens import RefreshTokensORM from src.models.tokens import RefreshTokensORM
from src.models.users import UsersORM from src.models.users import UsersORM
__all__ = [ __all__ = ["UsersORM", "TasksORM", "RefreshTokensORM"]
"UsersORM",
"TasksORM",
"RefreshTokensORM"
]

View File

@@ -4,13 +4,13 @@ from pydantic import BaseModel, ConfigDict, Field
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
model_config = ConfigDict(extra='ignore') model_config = ConfigDict(extra="ignore")
class TokenData(BaseModel): class TokenData(BaseModel):
id: int id: int
sub: str = Field(alias='username') sub: str = Field(alias="username")
is_superuser: bool is_superuser: bool
is_active: bool is_active: bool
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)

View File

@@ -16,6 +16,6 @@ def ensure_username(value: str) -> str:
value = value.strip() value = value.strip()
if len(value) < 3: if len(value) < 3:
raise ValueError("Username must be at least 3 characters") raise ValueError("Username must be at least 3 characters")
elif value.lower() in ['admin', 'moderator', 'админ', 'модератор']: elif value.lower() in ["admin", "moderator", "админ", "модератор"]:
raise ValueError("Login is already taken") raise ValueError("Login is already taken")
return value return value

View File

@@ -36,10 +36,34 @@ class AuthService(BaseService):
status_code=401, status_code=401,
detail="Incorrect username or password", detail="Incorrect username or password",
) )
access_token = AuthManager.create_access_token( access_token = AuthManager.create_access_token(user_token.model_dump())
user_token.model_dump()
)
refresh_token = AuthManager.create_refresh_token() refresh_token = AuthManager.create_refresh_token()
await self.session.auth.create_one({"token": refresh_token, "user_id": user.id}) await self.session.auth.create_one({"token": refresh_token, "user_id": user.id})
await self.session.commit() await self.session.commit()
return {"access_token": access_token, "token_type": settings.access_token.token_type, "refresh_token": refresh_token} return {
"access_token": access_token,
"token_type": settings.access_token.token_type,
"refresh_token": refresh_token,
}
async def delete_token(self, token: str) -> None:
await self.session.auth.delete_one(token=token)
await self.session.commit()
async def refresh_tokens(self, refresh_token: str, user_data: TokenData):
token_record = await self.session.auth.get_one_or_none(token=refresh_token)
if not token_record or token_record.user_id != user_data.id:
raise HTTPException(status_code=401, detail="Invalid refresh token")
new_access_token = AuthManager.create_access_token(user_data.model_dump())
new_refresh_token = AuthManager.create_refresh_token()
await self.session.auth.delete_one(token=refresh_token)
await self.session.auth.create_one({
"token": new_refresh_token,
"user_id": user_data.id,
})
await self.session.commit()
return {
"access_token": new_access_token,
"token_type": settings.access_token.token_type,
"refresh_token": new_refresh_token,
}

View File

@@ -1,4 +1,3 @@
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy import NullPool, insert from sqlalchemy import NullPool, insert

View File

@@ -83,4 +83,4 @@ async def test_tasks_crud(db: "TestDBManager"):
assert task.title == data["title"] assert task.title == data["title"]
await db.task.delete_one(id=task.id) await db.task.delete_one(id=task.id)
task = await db.task.get_one_or_none(id=task.id) task = await db.task.get_one_or_none(id=task.id)
assert not task assert not task