type cheking

This commit is contained in:
IluaAir
2025-08-11 10:07:56 +03:00
parent 644d5614b9
commit ddc38dbd07
5 changed files with 27 additions and 20 deletions

View File

@@ -1,15 +1,15 @@
from typing import Annotated, AsyncGenerator from typing import Annotated, AsyncGenerator
from fastapi import Depends from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.database import async_session_maker from src.core.database import async_session_maker
from src.core.db_manager import DBManager from src.core.db_manager import DBManager
from src.core.interfaces import IUOWDB
async def get_db() -> AsyncGenerator[AsyncSession, None]: async def get_db() -> AsyncGenerator[IUOWDB, None]:
async with DBManager(async_session_maker) as db: async with DBManager(async_session_maker) as db:
yield db yield db
sessionDep = Annotated[AsyncSession, Depends(get_db)] sessionDep = Annotated[IUOWDB, Depends(get_db)]

View File

@@ -1,20 +1,24 @@
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.repository.tasks import TasksRepo from src.repository.tasks import TasksRepo
from src.repository.users import UsersRepo from src.repository.users import UsersRepo
class DBManager: class DBManager:
def __init__(self, session_factory): def __init__(self, session_factory: async_sessionmaker[AsyncSession]):
self.session_factory = session_factory self.session_factory = session_factory
async def __aenter__(self): async def __aenter__(self) -> "DBManager":
self.session = self.session_factory() self.session: AsyncSession = self.session_factory()
self.user = UsersRepo(self.session) self.user = UsersRepo(self.session)
self.task = TasksRepo(self.session) self.task = TasksRepo(self.session)
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.session.rollback() await self.session.rollback()
await self.session.close() await self.session.close()
async def commit(self): async def commit(self) -> None:
await self.session.commit() await self.session.commit()

View File

@@ -1,15 +1,18 @@
from typing import Protocol from typing import Any, Protocol
from sqlalchemy.ext.asyncio import AsyncSession
from src.repository.tasks import TasksRepo from src.repository.tasks import TasksRepo
from src.repository.users import UsersRepo from src.repository.users import UsersRepo
class IUOWDB(Protocol): class IUOWDB(Protocol):
session: AsyncSession
user: UsersRepo user: UsersRepo
task: TasksRepo task: TasksRepo
async def __aenter__(self) -> "IUOWDB": ... async def __aenter__(self) -> "IUOWDB": ...
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ... async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
async def commit(self) -> None: ... async def commit(self) -> None: ...

View File

@@ -7,31 +7,31 @@ from src.core.database import Base
class BaseRepo: class BaseRepo:
model: type[Base] = None model: type[Base]
def __init__(self, session): def __init__(self, session):
self.session = session self.session = session
async def get_filtered(self, *filter, **filter_by) -> list[BaseModel | Any]: async def get_filtered(self, *filter, **filter_by) -> list[Base]:
query = select(self.model).filter(*filter).filter_by(**filter_by) query = select(self.model).filter(*filter).filter_by(**filter_by)
result = await self.session.execute(query) result = await self.session.execute(query)
model = result.scalars().one_or_none() models = result.scalars().all()
return model return models
async def create_one(self, data: BaseModel): async def create_one(self, data: BaseModel) -> Base:
statement = insert(self.model).values(data.model_dump()).returning(self.model) statement = insert(self.model).values(data.model_dump()).returning(self.model)
result = await self.session.execute(statement) result = await self.session.execute(statement)
obj = result.scalar_one() obj = result.scalar_one()
return obj return obj
async def get_one_or_none(self, **filter_by): async def get_one_or_none(self, **filter_by: Any) -> Base | None:
query = select(self.model).filter_by(**filter_by) query = select(self.model).filter_by(**filter_by)
result = await self.session.execute(query) result = await self.session.execute(query)
model = result.scalars().one_or_none() model = result.scalars().one_or_none()
return model return model
async def get_all(self, *args, **kwargs) -> list[model]: async def get_all(self, *args, **kwargs) -> list[Base]:
result = await self.get_filtered() result = await self.get_filtered(*args, **kwargs)
return result return result
async def delete_one(self, **filter_by) -> None: async def delete_one(self, **filter_by) -> None:

View File

@@ -2,7 +2,7 @@ from src.core.interfaces import IUOWDB
class BaseService: class BaseService:
session: IUOWDB | None session: IUOWDB
def __init__(self, session: "IUOWDB"): def __init__(self, session: "IUOWDB"):
self.session = session self.session = session