new refresh dep add refresh endpoint
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Path
|
from fastapi import Depends, HTTPException, Path
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import HTTPBearer, OAuth2PasswordBearer
|
||||||
from jwt import InvalidTokenError
|
from jwt import InvalidTokenError
|
||||||
|
|
||||||
from src.api.dependacies.db_dep import sessionDep
|
from src.api.dependacies.db_dep import sessionDep
|
||||||
@@ -11,37 +11,46 @@ from src.schemas.auth import TokenData
|
|||||||
from src.services.tasks import TaskService
|
from src.services.tasks import TaskService
|
||||||
from src.services.users import UserService
|
from src.services.users import UserService
|
||||||
|
|
||||||
|
http_bearer = HTTPBearer(auto_error=False)
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api.v1_login_url}/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api.v1_login_url}/login")
|
||||||
|
AccessTokenDep = Annotated[str, Depends(oauth2_scheme)]
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
|
async def get_current_user(
|
||||||
|
token: AccessTokenDep, verify_exp: bool = True, check_active: bool = False
|
||||||
|
):
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
detail="Could not validate credentials",
|
detail="Could not validate credentials",
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
payload = AuthManager.decode_access_token(token=token)
|
payload = AuthManager.decode_access_token(token, verify_exp)
|
||||||
if payload is None:
|
if payload is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
user = TokenData(**payload)
|
user = TokenData(**payload)
|
||||||
|
if check_active and not user.is_active:
|
||||||
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
except InvalidTokenError:
|
except InvalidTokenError:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
CurrentUser = Annotated[TokenData, Depends(get_current_user)]
|
async def get_current_user_basic(token: AccessTokenDep):
|
||||||
|
return await get_current_user(token, verify_exp=True, check_active=False)
|
||||||
|
|
||||||
|
|
||||||
def get_current_active_user(
|
async def get_current_active_user(token: AccessTokenDep):
|
||||||
current_user: CurrentUser,
|
return await get_current_user(token, verify_exp=True, check_active=True)
|
||||||
):
|
|
||||||
if not current_user.is_active:
|
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_for_refresh(token: AccessTokenDep):
|
||||||
|
return await get_current_user(token, verify_exp=False, check_active=True)
|
||||||
|
|
||||||
|
|
||||||
|
CurrentUser = Annotated[TokenData, Depends(get_current_user_basic)]
|
||||||
ActiveUser = Annotated[TokenData, Depends(get_current_active_user)]
|
ActiveUser = Annotated[TokenData, Depends(get_current_active_user)]
|
||||||
|
RefreshUser = Annotated[TokenData, Depends(get_current_user_for_refresh)]
|
||||||
|
|
||||||
|
|
||||||
async def get_admin_user(db: sessionDep, current_user: ActiveUser):
|
async def get_admin_user(db: sessionDep, current_user: ActiveUser):
|
||||||
|
|||||||
@@ -1,16 +1,19 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Response
|
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from src.api.dependacies.db_dep import sessionDep
|
from src.api.dependacies.db_dep import sessionDep
|
||||||
from src.api.dependacies.user_dep import ActiveUser
|
from src.api.dependacies.user_dep import ActiveUser, RefreshUser, http_bearer
|
||||||
from src.core.settings import settings
|
from src.core.settings import settings
|
||||||
from src.schemas.auth import Token
|
from src.schemas.auth import Token
|
||||||
from src.schemas.users import UserRequestADD
|
from src.schemas.users import UserRequestADD
|
||||||
from src.services.auth import AuthService
|
from src.services.auth import AuthService
|
||||||
|
from src.services.users import UserService
|
||||||
|
|
||||||
router = APIRouter(prefix=settings.api.v1.auth, tags=["Auth"])
|
router = APIRouter(
|
||||||
|
prefix=settings.api.v1.auth, tags=["Auth"], dependencies=[Depends(http_bearer)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post(path="/signup")
|
@router.post(path="/signup")
|
||||||
@@ -23,28 +26,55 @@ async def registration(session: sessionDep, credential: UserRequestADD):
|
|||||||
async def login(
|
async def login(
|
||||||
session: sessionDep,
|
session: sessionDep,
|
||||||
credential: Annotated[OAuth2PasswordRequestForm, Depends()],
|
credential: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
response: Response
|
response: Response,
|
||||||
):
|
):
|
||||||
result = await AuthService(session).login(
|
result = await AuthService(session).login(credential.username, credential.password)
|
||||||
credential.username, credential.password
|
|
||||||
)
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="refresh_token",
|
key="refresh_token",
|
||||||
value=result["refresh_token"],
|
value=result["refresh_token"],
|
||||||
httponly=True,
|
httponly=True,
|
||||||
samesite='lax',
|
samesite="lax",
|
||||||
path=settings.api.v1.auth,
|
path=settings.api.v1.auth,
|
||||||
max_age=60 * 60 * 24 * 7
|
max_age=60 * 60 * 24 * 7,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post(path="/refresh")
|
@router.post(path="/refresh", response_model=Token)
|
||||||
async def refresh(user: ActiveUser, response: Response):
|
async def refresh(
|
||||||
print(response)
|
session: sessionDep,
|
||||||
|
response: Response,
|
||||||
|
current_user: RefreshUser,
|
||||||
|
refresh_token: Annotated[str | None, Cookie(name="refresh_token")] = None,
|
||||||
|
):
|
||||||
|
if refresh_token is None:
|
||||||
|
raise HTTPException(status_code=401, detail="No refresh token")
|
||||||
|
result = await AuthService(session).refresh_tokens(refresh_token, current_user)
|
||||||
|
response.set_cookie(
|
||||||
|
key="refresh_token",
|
||||||
|
value=result["refresh_token"],
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
path=settings.api.v1.auth,
|
||||||
|
max_age=60 * 60 * 24 * 7,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post(path='/logout')
|
@router.get("/me")
|
||||||
async def logout(response: Response):
|
async def get_me(session: sessionDep, user: ActiveUser):
|
||||||
|
cur_user = await UserService(session).get_user_by_filter_or_raise(id=user.id)
|
||||||
|
return cur_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(path="/logout")
|
||||||
|
async def logout(
|
||||||
|
session: sessionDep,
|
||||||
|
response: Response,
|
||||||
|
refresh_token: Annotated[str | None, Cookie(name="refresh_token")] = None,
|
||||||
|
):
|
||||||
|
if refresh_token is None:
|
||||||
|
raise HTTPException(status_code=401, detail="No refresh token")
|
||||||
|
await AuthService(session).delete_token(token=refresh_token)
|
||||||
response.delete_cookie(key="refresh_token")
|
response.delete_cookie(key="refresh_token")
|
||||||
return {'status': 'ok'}
|
return {"status": "ok"}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from fastapi import APIRouter, Body
|
|||||||
|
|
||||||
from src.api.dependacies.db_dep import sessionDep
|
from src.api.dependacies.db_dep import sessionDep
|
||||||
from src.api.dependacies.user_dep import (
|
from src.api.dependacies.user_dep import (
|
||||||
ActiveUser,
|
|
||||||
AdminUser,
|
AdminUser,
|
||||||
OwnerDep,
|
OwnerDep,
|
||||||
)
|
)
|
||||||
@@ -13,12 +12,6 @@ from src.services.users import UserService
|
|||||||
router = APIRouter(prefix=settings.api.v1.users, tags=["Users"])
|
router = APIRouter(prefix=settings.api.v1.users, tags=["Users"])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me")
|
|
||||||
async def get_me(session: sessionDep, user: ActiveUser):
|
|
||||||
cur_user = await UserService(session).get_user_by_filter_or_raise(id=user.id)
|
|
||||||
return cur_user
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def get_all_users(session: sessionDep, _: AdminUser):
|
async def get_all_users(session: sessionDep, _: AdminUser):
|
||||||
users = await UserService(session).get_all_users()
|
users = await UserService(session).get_all_users()
|
||||||
|
|||||||
@@ -37,16 +37,14 @@ class AuthManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_refresh_token(cls) -> str:
|
def create_refresh_token(cls) -> str:
|
||||||
# random_bytes = os.urandom(32)
|
|
||||||
# data = settings.refresh_token.secret_key.encode() + random_bytes
|
|
||||||
# token_hash = bcrypt.hashpw(data, bcrypt.gensalt(rounds=12)).decode()
|
|
||||||
token_hash = secrets.token_urlsafe(32)
|
token_hash = secrets.token_urlsafe(32)
|
||||||
return token_hash
|
return token_hash
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_access_token(cls, token: str) -> dict:
|
def decode_access_token(cls, token: str, verify_exp: bool = True) -> dict:
|
||||||
return jwt.decode(
|
return jwt.decode(
|
||||||
token,
|
token,
|
||||||
settings.access_token.secret_key,
|
settings.access_token.secret_key,
|
||||||
algorithms=[settings.access_token.algorithm],
|
algorithms=[settings.access_token.algorithm],
|
||||||
|
options={"verify_exp": verify_exp},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ class HasId(Protocol):
|
|||||||
|
|
||||||
class IUOWDB(Protocol):
|
class IUOWDB(Protocol):
|
||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
user: 'UsersRepo'
|
user: "UsersRepo"
|
||||||
task: 'TasksRepo'
|
task: "TasksRepo"
|
||||||
auth: 'AuthRepo'
|
auth: "AuthRepo"
|
||||||
|
|
||||||
async def __aenter__(self) -> "IUOWDB": ...
|
async def __aenter__(self) -> "IUOWDB": ...
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,15 @@ Revises: 4b0f3ea2fd26
|
|||||||
Create Date: 2025-09-08 14:56:01.439089
|
Create Date: 2025-09-08 14:56:01.439089
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = 'b879d3502c37'
|
revision: str = "b879d3502c37"
|
||||||
down_revision: Union[str, None] = '4b0f3ea2fd26'
|
down_revision: Union[str, None] = "4b0f3ea2fd26"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -20,13 +21,22 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
"""Upgrade schema."""
|
"""Upgrade schema."""
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table('refresh_tokens',
|
op.create_table(
|
||||||
sa.Column('id', sa.Integer(), nullable=False),
|
"refresh_tokens",
|
||||||
sa.Column('token', sa.String(length=255), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
sa.Column("token", sa.String(length=255), nullable=False),
|
||||||
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
sa.Column(
|
||||||
sa.PrimaryKeyConstraint('id')
|
"created_at",
|
||||||
|
sa.TIMESTAMP(timezone=True),
|
||||||
|
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
@@ -34,5 +44,5 @@ def upgrade() -> None:
|
|||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
"""Downgrade schema."""
|
"""Downgrade schema."""
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_table('refresh_tokens')
|
op.drop_table("refresh_tokens")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -2,8 +2,4 @@ from src.models.tasks import TasksORM
|
|||||||
from src.models.tokens import RefreshTokensORM
|
from src.models.tokens import RefreshTokensORM
|
||||||
from src.models.users import UsersORM
|
from src.models.users import UsersORM
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["UsersORM", "TasksORM", "RefreshTokensORM"]
|
||||||
"UsersORM",
|
|
||||||
"TasksORM",
|
|
||||||
"RefreshTokensORM"
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
model_config = ConfigDict(extra='ignore')
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
class TokenData(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
sub: str = Field(alias='username')
|
sub: str = Field(alias="username")
|
||||||
is_superuser: bool
|
is_superuser: bool
|
||||||
is_active: bool
|
is_active: bool
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,6 @@ def ensure_username(value: str) -> str:
|
|||||||
value = value.strip()
|
value = value.strip()
|
||||||
if len(value) < 3:
|
if len(value) < 3:
|
||||||
raise ValueError("Username must be at least 3 characters")
|
raise ValueError("Username must be at least 3 characters")
|
||||||
elif value.lower() in ['admin', 'moderator', 'админ', 'модератор']:
|
elif value.lower() in ["admin", "moderator", "админ", "модератор"]:
|
||||||
raise ValueError("Login is already taken")
|
raise ValueError("Login is already taken")
|
||||||
return value
|
return value
|
||||||
@@ -36,10 +36,34 @@ class AuthService(BaseService):
|
|||||||
status_code=401,
|
status_code=401,
|
||||||
detail="Incorrect username or password",
|
detail="Incorrect username or password",
|
||||||
)
|
)
|
||||||
access_token = AuthManager.create_access_token(
|
access_token = AuthManager.create_access_token(user_token.model_dump())
|
||||||
user_token.model_dump()
|
|
||||||
)
|
|
||||||
refresh_token = AuthManager.create_refresh_token()
|
refresh_token = AuthManager.create_refresh_token()
|
||||||
await self.session.auth.create_one({"token": refresh_token, "user_id": user.id})
|
await self.session.auth.create_one({"token": refresh_token, "user_id": user.id})
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
return {"access_token": access_token, "token_type": settings.access_token.token_type, "refresh_token": refresh_token}
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": settings.access_token.token_type,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def delete_token(self, token: str) -> None:
|
||||||
|
await self.session.auth.delete_one(token=token)
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
async def refresh_tokens(self, refresh_token: str, user_data: TokenData):
|
||||||
|
token_record = await self.session.auth.get_one_or_none(token=refresh_token)
|
||||||
|
if not token_record or token_record.user_id != user_data.id:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||||
|
new_access_token = AuthManager.create_access_token(user_data.model_dump())
|
||||||
|
new_refresh_token = AuthManager.create_refresh_token()
|
||||||
|
await self.session.auth.delete_one(token=refresh_token)
|
||||||
|
await self.session.auth.create_one({
|
||||||
|
"token": new_refresh_token,
|
||||||
|
"user_id": user_data.id,
|
||||||
|
})
|
||||||
|
await self.session.commit()
|
||||||
|
return {
|
||||||
|
"access_token": new_access_token,
|
||||||
|
"token_type": settings.access_token.token_type,
|
||||||
|
"refresh_token": new_refresh_token,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
from sqlalchemy import NullPool, insert
|
from sqlalchemy import NullPool, insert
|
||||||
|
|||||||
Reference in New Issue
Block a user