diff --git a/src/core/interfaces.py b/src/core/interfaces.py index 805e362..ec36ab3 100644 --- a/src/core/interfaces.py +++ b/src/core/interfaces.py @@ -8,10 +8,6 @@ if TYPE_CHECKING: from src.repository.users import UsersRepo -class HasId(Protocol): - id: Any - - class IUOWDB(Protocol): session: AsyncSession user: "UsersRepo" diff --git a/src/repository/auth.py b/src/repository/auth.py index 3524cab..72f5df1 100644 --- a/src/repository/auth.py +++ b/src/repository/auth.py @@ -2,5 +2,5 @@ from src.models.tokens import RefreshTokensORM from src.repository.base import BaseRepo -class AuthRepo(BaseRepo): +class AuthRepo(BaseRepo[RefreshTokensORM]): model: type[RefreshTokensORM] = RefreshTokensORM diff --git a/src/repository/base.py b/src/repository/base.py index 60953e1..d960b8f 100644 --- a/src/repository/base.py +++ b/src/repository/base.py @@ -3,9 +3,7 @@ from typing import Any, Generic, Mapping, Sequence, Type, TypeVar from sqlalchemy import delete, insert, select, update from sqlalchemy.ext.asyncio import AsyncSession -from src.core.interfaces import HasId - -ModelType = TypeVar("ModelType", bound=HasId) +ModelType = TypeVar("ModelType") class BaseRepo(Generic[ModelType]): @@ -45,12 +43,9 @@ class BaseRepo(Generic[ModelType]): async def delete_one(self, **filter_by) -> None: await self.session.execute(delete(self.model).filter_by(**filter_by)) - async def update_one(self, id: int, data: dict[str, Any]) -> ModelType: + async def update_one(self, data: dict[str, Any], **filter_by: Any) -> ModelType: stmt = ( - update(self.model) - .where(self.model.id == id) - .values(data) - .returning(self.model) + update(self.model).filter_by(**filter_by).values(data).returning(self.model) ) result = await self.session.execute(stmt) model = result.scalar_one() diff --git a/src/repository/tasks.py b/src/repository/tasks.py index 83a2527..bb354cc 100644 --- a/src/repository/tasks.py +++ b/src/repository/tasks.py @@ -2,5 +2,5 @@ from src.models.tasks import TasksORM from src.repository.base import BaseRepo -class TasksRepo(BaseRepo): +class TasksRepo(BaseRepo[TasksORM]): model: type[TasksORM] = TasksORM diff --git a/src/repository/users.py b/src/repository/users.py index 021785c..6d87602 100644 --- a/src/repository/users.py +++ b/src/repository/users.py @@ -9,7 +9,7 @@ from src.models.tasks import TasksORM from src.repository.base import BaseRepo -class UsersRepo(BaseRepo): +class UsersRepo(BaseRepo[UsersORM]): model: type[UsersORM] = UsersORM async def get_one_with_load( diff --git a/src/services/tasks.py b/src/services/tasks.py index f093e05..85b88c4 100644 --- a/src/services/tasks.py +++ b/src/services/tasks.py @@ -32,7 +32,7 @@ class TaskService(BaseService): exclude_unset: bool = True, ): task = await self.session.task.update_one( - id=task_id, data=task_data.model_dump(exclude_unset=exclude_unset) + data=task_data.model_dump(exclude_unset=exclude_unset), id=task_id ) await self.session.commit() return Task.model_validate(task) diff --git a/src/services/users.py b/src/services/users.py index ce90ff2..cd425fd 100644 --- a/src/services/users.py +++ b/src/services/users.py @@ -36,7 +36,7 @@ class UserService(BaseService): async def update_user(self, id: int, update_data: UserUpdate) -> User: await self.get_user_by_filter_or_raise(id=id) user = await self.session.user.update_one( - id=id, data=update_data.model_dump(exclude_unset=True) + data=update_data.model_dump(exclude_unset=True), id=id ) await self.session.commit() return User.model_validate(user)