diff --git a/src/core/interfaces.py b/src/core/interfaces.py index e63811a..0468ffc 100644 --- a/src/core/interfaces.py +++ b/src/core/interfaces.py @@ -1,15 +1,20 @@ -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol from sqlalchemy.ext.asyncio import AsyncSession -from src.repository.tasks import TasksRepo -from src.repository.users import UsersRepo +if TYPE_CHECKING: + from src.repository.tasks import TasksRepo + from src.repository.users import UsersRepo + + +class HasId(Protocol): + id: Any class IUOWDB(Protocol): session: AsyncSession - user: UsersRepo - task: TasksRepo + user: 'UsersRepo' + task: 'TasksRepo' async def __aenter__(self) -> "IUOWDB": ... diff --git a/src/repository/base.py b/src/repository/base.py index 8115822..cdfd891 100644 --- a/src/repository/base.py +++ b/src/repository/base.py @@ -1,11 +1,11 @@ from typing import Any, Generic, Mapping, Sequence, Type, TypeVar -from sqlalchemy import delete, insert, select +from sqlalchemy import delete, insert, select, update from sqlalchemy.ext.asyncio import AsyncSession -from src.core.database import Base +from src.core.interfaces import HasId -ModelType = TypeVar("ModelType", bound=Base) +ModelType = TypeVar("ModelType", bound=HasId) class BaseRepo(Generic[ModelType]): @@ -44,3 +44,14 @@ class BaseRepo(Generic[ModelType]): async def delete_one(self, **filter_by) -> None: await self.session.execute(delete(self.model).filter_by(**filter_by)) + + async def update_one(self, id: int, data: Mapping[str, Any]) -> ModelType: + stmt = ( + update(self.model) + .where(self.model.id == id) + .values(data) + .returning(self.model) + ) + result = await self.session.execute(stmt) + model = result.scalar_one() + return model diff --git a/src/services/tasks.py b/src/services/tasks.py index 19bc499..8d43720 100644 --- a/src/services/tasks.py +++ b/src/services/tasks.py @@ -28,7 +28,11 @@ class TaskService(BaseService): await self.session.task.delete_one(id=task_id) await self.session.commit() - async def update_task(self, task_id: int, task_data: TaskPATCHRequest, exclude_unset: bool = True): - task = await self.session.task.update_one(id=task_id, data=task_data.model_dump(exclude_unset=exclude_unset)) + async def update_task( + self, task_id: int, task_data: TaskPATCHRequest, exclude_unset: bool = True + ): + task = await self.session.task.update_one( + id=task_id, data=task_data.model_dump(exclude_unset=exclude_unset) + ) await self.session.commit() return Task.model_validate(task) diff --git a/tests/conftest.py b/tests/conftest.py index 6cfaa03..9aaab30 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -import json + import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import NullPool, insert