delete protocol

This commit is contained in:
IluaAir
2025-09-28 22:34:14 +03:00
parent 64dcc77518
commit 91daaf9275
7 changed files with 8 additions and 17 deletions

View File

@@ -8,10 +8,6 @@ if TYPE_CHECKING:
from src.repository.users import UsersRepo from src.repository.users import UsersRepo
class HasId(Protocol):
id: Any
class IUOWDB(Protocol): class IUOWDB(Protocol):
session: AsyncSession session: AsyncSession
user: "UsersRepo" user: "UsersRepo"

View File

@@ -2,5 +2,5 @@ from src.models.tokens import RefreshTokensORM
from src.repository.base import BaseRepo from src.repository.base import BaseRepo
class AuthRepo(BaseRepo): class AuthRepo(BaseRepo[RefreshTokensORM]):
model: type[RefreshTokensORM] = RefreshTokensORM model: type[RefreshTokensORM] = RefreshTokensORM

View File

@@ -3,9 +3,7 @@ from typing import Any, Generic, Mapping, Sequence, Type, TypeVar
from sqlalchemy import delete, insert, select, update from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from src.core.interfaces import HasId ModelType = TypeVar("ModelType")
ModelType = TypeVar("ModelType", bound=HasId)
class BaseRepo(Generic[ModelType]): class BaseRepo(Generic[ModelType]):
@@ -45,12 +43,9 @@ class BaseRepo(Generic[ModelType]):
async def delete_one(self, **filter_by) -> None: async def delete_one(self, **filter_by) -> None:
await self.session.execute(delete(self.model).filter_by(**filter_by)) await self.session.execute(delete(self.model).filter_by(**filter_by))
async def update_one(self, id: int, data: dict[str, Any]) -> ModelType: async def update_one(self, data: dict[str, Any], **filter_by: Any) -> ModelType:
stmt = ( stmt = (
update(self.model) update(self.model).filter_by(**filter_by).values(data).returning(self.model)
.where(self.model.id == id)
.values(data)
.returning(self.model)
) )
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
model = result.scalar_one() model = result.scalar_one()

View File

@@ -2,5 +2,5 @@ from src.models.tasks import TasksORM
from src.repository.base import BaseRepo from src.repository.base import BaseRepo
class TasksRepo(BaseRepo): class TasksRepo(BaseRepo[TasksORM]):
model: type[TasksORM] = TasksORM model: type[TasksORM] = TasksORM

View File

@@ -9,7 +9,7 @@ from src.models.tasks import TasksORM
from src.repository.base import BaseRepo from src.repository.base import BaseRepo
class UsersRepo(BaseRepo): class UsersRepo(BaseRepo[UsersORM]):
model: type[UsersORM] = UsersORM model: type[UsersORM] = UsersORM
async def get_one_with_load( async def get_one_with_load(

View File

@@ -32,7 +32,7 @@ class TaskService(BaseService):
exclude_unset: bool = True, exclude_unset: bool = True,
): ):
task = await self.session.task.update_one( task = await self.session.task.update_one(
id=task_id, data=task_data.model_dump(exclude_unset=exclude_unset) data=task_data.model_dump(exclude_unset=exclude_unset), id=task_id
) )
await self.session.commit() await self.session.commit()
return Task.model_validate(task) return Task.model_validate(task)

View File

@@ -36,7 +36,7 @@ class UserService(BaseService):
async def update_user(self, id: int, update_data: UserUpdate) -> User: async def update_user(self, id: int, update_data: UserUpdate) -> User:
await self.get_user_by_filter_or_raise(id=id) await self.get_user_by_filter_or_raise(id=id)
user = await self.session.user.update_one( user = await self.session.user.update_one(
id=id, data=update_data.model_dump(exclude_unset=True) data=update_data.model_dump(exclude_unset=True), id=id
) )
await self.session.commit() await self.session.commit()
return User.model_validate(user) return User.model_validate(user)