From 42b8c3a2c9cb25a6502de2533917db79a6e582c8 Mon Sep 17 00:00:00 2001 From: IluaAir Date: Sun, 21 Sep 2025 14:59:26 +0300 Subject: [PATCH] new refresh dep add refresh endpoint --- src/api/dependacies/user_dep.py | 29 +++++---- src/api/v1/auth.py | 60 ++++++++++++++----- src/api/v1/users.py | 7 --- src/core/auth_manager.py | 8 +-- src/core/interfaces.py | 6 +- ..._08_1456-b879d3502c37_add_refresh_token.py | 30 ++++++---- src/models/__init__.py | 6 +- src/schemas/auth.py | 6 +- src/schemas/validators.py | 4 +- src/services/auth.py | 32 ++++++++-- tests/conftest.py | 1 - tests/unit_tests/test_repo_db.py | 2 +- 12 files changed, 125 insertions(+), 66 deletions(-) diff --git a/src/api/dependacies/user_dep.py b/src/api/dependacies/user_dep.py index d5e8e5d..40dc167 100644 --- a/src/api/dependacies/user_dep.py +++ b/src/api/dependacies/user_dep.py @@ -1,7 +1,7 @@ from typing import Annotated from fastapi import Depends, HTTPException, Path -from fastapi.security import OAuth2PasswordBearer +from fastapi.security import HTTPBearer, OAuth2PasswordBearer from jwt import InvalidTokenError from src.api.dependacies.db_dep import sessionDep @@ -11,37 +11,46 @@ from src.schemas.auth import TokenData from src.services.tasks import TaskService from src.services.users import UserService +http_bearer = HTTPBearer(auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api.v1_login_url}/login") +AccessTokenDep = Annotated[str, Depends(oauth2_scheme)] -async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): +async def get_current_user( + token: AccessTokenDep, verify_exp: bool = True, check_active: bool = False +): credentials_exception = HTTPException( status_code=401, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) try: - payload = AuthManager.decode_access_token(token=token) + payload = AuthManager.decode_access_token(token, verify_exp) if payload is None: raise credentials_exception user = TokenData(**payload) + if check_active and not user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") except InvalidTokenError: raise credentials_exception return user -CurrentUser = Annotated[TokenData, Depends(get_current_user)] +async def get_current_user_basic(token: AccessTokenDep): + return await get_current_user(token, verify_exp=True, check_active=False) -def get_current_active_user( - current_user: CurrentUser, -): - if not current_user.is_active: - raise HTTPException(status_code=400, detail="Inactive user") - return current_user +async def get_current_active_user(token: AccessTokenDep): + return await get_current_user(token, verify_exp=True, check_active=True) +async def get_current_user_for_refresh(token: AccessTokenDep): + return await get_current_user(token, verify_exp=False, check_active=True) + + +CurrentUser = Annotated[TokenData, Depends(get_current_user_basic)] ActiveUser = Annotated[TokenData, Depends(get_current_active_user)] +RefreshUser = Annotated[TokenData, Depends(get_current_user_for_refresh)] async def get_admin_user(db: sessionDep, current_user: ActiveUser): diff --git a/src/api/v1/auth.py b/src/api/v1/auth.py index 3ee1e00..26d98ce 100644 --- a/src/api/v1/auth.py +++ b/src/api/v1/auth.py @@ -1,16 +1,19 @@ from typing import Annotated -from fastapi import APIRouter, Depends, Response +from fastapi import APIRouter, Cookie, Depends, HTTPException, Response from fastapi.security import OAuth2PasswordRequestForm from src.api.dependacies.db_dep import sessionDep -from src.api.dependacies.user_dep import ActiveUser +from src.api.dependacies.user_dep import ActiveUser, RefreshUser, http_bearer from src.core.settings import settings from src.schemas.auth import Token from src.schemas.users import UserRequestADD from src.services.auth import AuthService +from src.services.users import UserService -router = APIRouter(prefix=settings.api.v1.auth, tags=["Auth"]) +router = APIRouter( + prefix=settings.api.v1.auth, tags=["Auth"], dependencies=[Depends(http_bearer)] +) @router.post(path="/signup") @@ -23,28 +26,55 @@ async def registration(session: sessionDep, credential: UserRequestADD): async def login( session: sessionDep, credential: Annotated[OAuth2PasswordRequestForm, Depends()], - response: Response + response: Response, ): - result = await AuthService(session).login( - credential.username, credential.password - ) + result = await AuthService(session).login(credential.username, credential.password) response.set_cookie( key="refresh_token", value=result["refresh_token"], httponly=True, - samesite='lax', + samesite="lax", path=settings.api.v1.auth, - max_age=60 * 60 * 24 * 7 + max_age=60 * 60 * 24 * 7, ) return result -@router.post(path="/refresh") -async def refresh(user: ActiveUser, response: Response): - print(response) +@router.post(path="/refresh", response_model=Token) +async def refresh( + session: sessionDep, + response: Response, + current_user: RefreshUser, + refresh_token: Annotated[str | None, Cookie(name="refresh_token")] = None, +): + if refresh_token is None: + raise HTTPException(status_code=401, detail="No refresh token") + result = await AuthService(session).refresh_tokens(refresh_token, current_user) + response.set_cookie( + key="refresh_token", + value=result["refresh_token"], + httponly=True, + samesite="lax", + path=settings.api.v1.auth, + max_age=60 * 60 * 24 * 7, + ) + return result -@router.post(path='/logout') -async def logout(response: Response): +@router.get("/me") +async def get_me(session: sessionDep, user: ActiveUser): + cur_user = await UserService(session).get_user_by_filter_or_raise(id=user.id) + return cur_user + + +@router.post(path="/logout") +async def logout( + session: sessionDep, + response: Response, + refresh_token: Annotated[str | None, Cookie(name="refresh_token")] = None, +): + if refresh_token is None: + raise HTTPException(status_code=401, detail="No refresh token") + await AuthService(session).delete_token(token=refresh_token) response.delete_cookie(key="refresh_token") - return {'status': 'ok'} \ No newline at end of file + return {"status": "ok"} diff --git a/src/api/v1/users.py b/src/api/v1/users.py index 55fad33..ea48990 100644 --- a/src/api/v1/users.py +++ b/src/api/v1/users.py @@ -2,7 +2,6 @@ from fastapi import APIRouter, Body from src.api.dependacies.db_dep import sessionDep from src.api.dependacies.user_dep import ( - ActiveUser, AdminUser, OwnerDep, ) @@ -13,12 +12,6 @@ from src.services.users import UserService router = APIRouter(prefix=settings.api.v1.users, tags=["Users"]) -@router.get("/me") -async def get_me(session: sessionDep, user: ActiveUser): - cur_user = await UserService(session).get_user_by_filter_or_raise(id=user.id) - return cur_user - - @router.get("/") async def get_all_users(session: sessionDep, _: AdminUser): users = await UserService(session).get_all_users() diff --git a/src/core/auth_manager.py b/src/core/auth_manager.py index 2a2dd2a..61cc9ac 100644 --- a/src/core/auth_manager.py +++ b/src/core/auth_manager.py @@ -34,19 +34,17 @@ class AuthManager: algorithm=settings.access_token.algorithm, ) return encoded_jwt - + @classmethod def create_refresh_token(cls) -> str: - # random_bytes = os.urandom(32) - # data = settings.refresh_token.secret_key.encode() + random_bytes - # token_hash = bcrypt.hashpw(data, bcrypt.gensalt(rounds=12)).decode() token_hash = secrets.token_urlsafe(32) return token_hash @classmethod - def decode_access_token(cls, token: str) -> dict: + def decode_access_token(cls, token: str, verify_exp: bool = True) -> dict: return jwt.decode( token, settings.access_token.secret_key, algorithms=[settings.access_token.algorithm], + options={"verify_exp": verify_exp}, ) diff --git a/src/core/interfaces.py b/src/core/interfaces.py index 9dcdff6..805e362 100644 --- a/src/core/interfaces.py +++ b/src/core/interfaces.py @@ -14,9 +14,9 @@ class HasId(Protocol): class IUOWDB(Protocol): session: AsyncSession - user: 'UsersRepo' - task: 'TasksRepo' - auth: 'AuthRepo' + user: "UsersRepo" + task: "TasksRepo" + auth: "AuthRepo" async def __aenter__(self) -> "IUOWDB": ... diff --git a/src/migrations/versions/2025_09_08_1456-b879d3502c37_add_refresh_token.py b/src/migrations/versions/2025_09_08_1456-b879d3502c37_add_refresh_token.py index 47f0fbd..35ec4f8 100644 --- a/src/migrations/versions/2025_09_08_1456-b879d3502c37_add_refresh_token.py +++ b/src/migrations/versions/2025_09_08_1456-b879d3502c37_add_refresh_token.py @@ -5,14 +5,15 @@ Revises: 4b0f3ea2fd26 Create Date: 2025-09-08 14:56:01.439089 """ + from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision: str = 'b879d3502c37' -down_revision: Union[str, None] = '4b0f3ea2fd26' +revision: str = "b879d3502c37" +down_revision: Union[str, None] = "4b0f3ea2fd26" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -20,13 +21,22 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.create_table('refresh_tokens', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('token', sa.String(length=255), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "refresh_tokens", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("token", sa.String(length=255), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("(CURRENT_TIMESTAMP)"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), ) # ### end Alembic commands ### @@ -34,5 +44,5 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('refresh_tokens') + op.drop_table("refresh_tokens") # ### end Alembic commands ### diff --git a/src/models/__init__.py b/src/models/__init__.py index bc35909..9018ce8 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -2,8 +2,4 @@ from src.models.tasks import TasksORM from src.models.tokens import RefreshTokensORM from src.models.users import UsersORM -__all__ = [ - "UsersORM", - "TasksORM", - "RefreshTokensORM" -] +__all__ = ["UsersORM", "TasksORM", "RefreshTokensORM"] diff --git a/src/schemas/auth.py b/src/schemas/auth.py index 46f0143..3d5faf1 100644 --- a/src/schemas/auth.py +++ b/src/schemas/auth.py @@ -4,13 +4,13 @@ from pydantic import BaseModel, ConfigDict, Field class Token(BaseModel): access_token: str token_type: str - model_config = ConfigDict(extra='ignore') + model_config = ConfigDict(extra="ignore") class TokenData(BaseModel): id: int - sub: str = Field(alias='username') + sub: str = Field(alias="username") is_superuser: bool is_active: bool - model_config = ConfigDict(populate_by_name=True) \ No newline at end of file + model_config = ConfigDict(populate_by_name=True) diff --git a/src/schemas/validators.py b/src/schemas/validators.py index 32db8f6..39cc1c8 100644 --- a/src/schemas/validators.py +++ b/src/schemas/validators.py @@ -16,6 +16,6 @@ def ensure_username(value: str) -> str: value = value.strip() if len(value) < 3: raise ValueError("Username must be at least 3 characters") - elif value.lower() in ['admin', 'moderator', 'админ', 'модератор']: + elif value.lower() in ["admin", "moderator", "админ", "модератор"]: raise ValueError("Login is already taken") - return value \ No newline at end of file + return value diff --git a/src/services/auth.py b/src/services/auth.py index 3f24ff6..a332f78 100644 --- a/src/services/auth.py +++ b/src/services/auth.py @@ -36,10 +36,34 @@ class AuthService(BaseService): status_code=401, detail="Incorrect username or password", ) - access_token = AuthManager.create_access_token( - user_token.model_dump() - ) + access_token = AuthManager.create_access_token(user_token.model_dump()) refresh_token = AuthManager.create_refresh_token() await self.session.auth.create_one({"token": refresh_token, "user_id": user.id}) await self.session.commit() - return {"access_token": access_token, "token_type": settings.access_token.token_type, "refresh_token": refresh_token} \ No newline at end of file + return { + "access_token": access_token, + "token_type": settings.access_token.token_type, + "refresh_token": refresh_token, + } + + async def delete_token(self, token: str) -> None: + await self.session.auth.delete_one(token=token) + await self.session.commit() + + async def refresh_tokens(self, refresh_token: str, user_data: TokenData): + token_record = await self.session.auth.get_one_or_none(token=refresh_token) + if not token_record or token_record.user_id != user_data.id: + raise HTTPException(status_code=401, detail="Invalid refresh token") + new_access_token = AuthManager.create_access_token(user_data.model_dump()) + new_refresh_token = AuthManager.create_refresh_token() + await self.session.auth.delete_one(token=refresh_token) + await self.session.auth.create_one({ + "token": new_refresh_token, + "user_id": user_data.id, + }) + await self.session.commit() + return { + "access_token": new_access_token, + "token_type": settings.access_token.token_type, + "refresh_token": new_refresh_token, + } diff --git a/tests/conftest.py b/tests/conftest.py index 9aaab30..521a923 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ - import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import NullPool, insert diff --git a/tests/unit_tests/test_repo_db.py b/tests/unit_tests/test_repo_db.py index 1c3e4c8..d9bc89f 100644 --- a/tests/unit_tests/test_repo_db.py +++ b/tests/unit_tests/test_repo_db.py @@ -83,4 +83,4 @@ async def test_tasks_crud(db: "TestDBManager"): assert task.title == data["title"] await db.task.delete_one(id=task.id) task = await db.task.get_one_or_none(id=task.id) - assert not task \ No newline at end of file + assert not task