type cheking

This commit is contained in:
IluaAir
2025-08-11 10:07:56 +03:00
parent 644d5614b9
commit ddc38dbd07
5 changed files with 27 additions and 20 deletions

View File

@@ -1,15 +1,15 @@
from typing import Annotated, AsyncGenerator
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.database import async_session_maker
from src.core.db_manager import DBManager
from src.core.interfaces import IUOWDB
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async def get_db() -> AsyncGenerator[IUOWDB, None]:
async with DBManager(async_session_maker) as db:
yield db
sessionDep = Annotated[AsyncSession, Depends(get_db)]
sessionDep = Annotated[IUOWDB, Depends(get_db)]

View File

@@ -1,20 +1,24 @@
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.repository.tasks import TasksRepo
from src.repository.users import UsersRepo
class DBManager:
def __init__(self, session_factory):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]):
self.session_factory = session_factory
async def __aenter__(self):
self.session = self.session_factory()
async def __aenter__(self) -> "DBManager":
self.session: AsyncSession = self.session_factory()
self.user = UsersRepo(self.session)
self.task = TasksRepo(self.session)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.session.rollback()
await self.session.close()
async def commit(self):
async def commit(self) -> None:
await self.session.commit()

View File

@@ -1,15 +1,18 @@
from typing import Protocol
from typing import Any, Protocol
from sqlalchemy.ext.asyncio import AsyncSession
from src.repository.tasks import TasksRepo
from src.repository.users import UsersRepo
class IUOWDB(Protocol):
session: AsyncSession
user: UsersRepo
task: TasksRepo
async def __aenter__(self) -> "IUOWDB": ...
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
async def commit(self) -> None: ...

View File

@@ -7,31 +7,31 @@ from src.core.database import Base
class BaseRepo:
model: type[Base] = None
model: type[Base]
def __init__(self, session):
self.session = session
async def get_filtered(self, *filter, **filter_by) -> list[BaseModel | Any]:
async def get_filtered(self, *filter, **filter_by) -> list[Base]:
query = select(self.model).filter(*filter).filter_by(**filter_by)
result = await self.session.execute(query)
model = result.scalars().one_or_none()
return model
models = result.scalars().all()
return models
async def create_one(self, data: BaseModel):
async def create_one(self, data: BaseModel) -> Base:
statement = insert(self.model).values(data.model_dump()).returning(self.model)
result = await self.session.execute(statement)
obj = result.scalar_one()
return obj
async def get_one_or_none(self, **filter_by):
async def get_one_or_none(self, **filter_by: Any) -> Base | None:
query = select(self.model).filter_by(**filter_by)
result = await self.session.execute(query)
model = result.scalars().one_or_none()
return model
async def get_all(self, *args, **kwargs) -> list[model]:
result = await self.get_filtered()
async def get_all(self, *args, **kwargs) -> list[Base]:
result = await self.get_filtered(*args, **kwargs)
return result
async def delete_one(self, **filter_by) -> None:

View File

@@ -2,7 +2,7 @@ from src.core.interfaces import IUOWDB
class BaseService:
session: IUOWDB | None
session: IUOWDB
def __init__(self, session: "IUOWDB"):
self.session = session