diff --git a/src/api/dependacies/db_dep.py b/src/api/dependacies/db_dep.py index df46f22..12b0f51 100644 --- a/src/api/dependacies/db_dep.py +++ b/src/api/dependacies/db_dep.py @@ -1,15 +1,15 @@ from typing import Annotated, AsyncGenerator from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession from src.core.database import async_session_maker from src.core.db_manager import DBManager +from src.core.interfaces import IUOWDB -async def get_db() -> AsyncGenerator[AsyncSession, None]: +async def get_db() -> AsyncGenerator[IUOWDB, None]: async with DBManager(async_session_maker) as db: yield db -sessionDep = Annotated[AsyncSession, Depends(get_db)] +sessionDep = Annotated[IUOWDB, Depends(get_db)] diff --git a/src/core/db_manager.py b/src/core/db_manager.py index 85f6897..ee461a2 100644 --- a/src/core/db_manager.py +++ b/src/core/db_manager.py @@ -1,20 +1,24 @@ +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + from src.repository.tasks import TasksRepo from src.repository.users import UsersRepo class DBManager: - def __init__(self, session_factory): + def __init__(self, session_factory: async_sessionmaker[AsyncSession]): self.session_factory = session_factory - async def __aenter__(self): - self.session = self.session_factory() + async def __aenter__(self) -> "DBManager": + self.session: AsyncSession = self.session_factory() self.user = UsersRepo(self.session) self.task = TasksRepo(self.session) return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.session.rollback() await self.session.close() - async def commit(self): + async def commit(self) -> None: await self.session.commit() diff --git a/src/core/interfaces.py b/src/core/interfaces.py index d863fe3..e63811a 100644 --- a/src/core/interfaces.py +++ b/src/core/interfaces.py @@ -1,15 +1,18 @@ -from typing import Protocol +from typing import Any, Protocol + +from sqlalchemy.ext.asyncio import AsyncSession from src.repository.tasks import TasksRepo from src.repository.users import UsersRepo class IUOWDB(Protocol): + session: AsyncSession user: UsersRepo task: TasksRepo async def __aenter__(self) -> "IUOWDB": ... - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ... + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... async def commit(self) -> None: ... diff --git a/src/repository/base.py b/src/repository/base.py index 8a61a9f..56734b3 100644 --- a/src/repository/base.py +++ b/src/repository/base.py @@ -7,31 +7,31 @@ from src.core.database import Base class BaseRepo: - model: type[Base] = None + model: type[Base] def __init__(self, session): self.session = session - async def get_filtered(self, *filter, **filter_by) -> list[BaseModel | Any]: + async def get_filtered(self, *filter, **filter_by) -> list[Base]: query = select(self.model).filter(*filter).filter_by(**filter_by) result = await self.session.execute(query) - model = result.scalars().one_or_none() - return model + models = result.scalars().all() + return models - async def create_one(self, data: BaseModel): + async def create_one(self, data: BaseModel) -> Base: statement = insert(self.model).values(data.model_dump()).returning(self.model) result = await self.session.execute(statement) obj = result.scalar_one() return obj - async def get_one_or_none(self, **filter_by): + async def get_one_or_none(self, **filter_by: Any) -> Base | None: query = select(self.model).filter_by(**filter_by) result = await self.session.execute(query) model = result.scalars().one_or_none() return model - - async def get_all(self, *args, **kwargs) -> list[model]: - result = await self.get_filtered() + + async def get_all(self, *args, **kwargs) -> list[Base]: + result = await self.get_filtered(*args, **kwargs) return result async def delete_one(self, **filter_by) -> None: diff --git a/src/services/base.py b/src/services/base.py index ffea73c..0de2a50 100644 --- a/src/services/base.py +++ b/src/services/base.py @@ -2,7 +2,7 @@ from src.core.interfaces import IUOWDB class BaseService: - session: IUOWDB | None + session: IUOWDB def __init__(self, session: "IUOWDB"): self.session = session