Skip to content

Commit

Permalink
add mock inference endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
hutanmihai committed Apr 22, 2024
1 parent e697ab1 commit af25faa
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 2 deletions.
3 changes: 2 additions & 1 deletion backend/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.ext.asyncio import AsyncEngine
from src import models
from src.settings import settings

from src import models


def get_database_url() -> str:
return settings.database_url.unicode_string()
Expand Down
2 changes: 2 additions & 0 deletions backend/src/apis/auth_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
status_code=status.HTTP_201_CREATED,
response_description="User created successfully",
responses=generate_error_responses(status.HTTP_404_NOT_FOUND, status.HTTP_403_FORBIDDEN),
response_model=TokenSchema,
)
async def register(register_schema: RegisterSchema, user_srv: UserSrv = Depends(UserSrv)):
username = register_schema.username
Expand All @@ -46,6 +47,7 @@ async def register(register_schema: RegisterSchema, user_srv: UserSrv = Depends(
status_code=status.HTTP_200_OK,
response_description="Login successful",
responses=generate_error_responses(status.HTTP_404_NOT_FOUND, status.HTTP_403_FORBIDDEN),
response_model=TokenSchema,
)
async def login(login_schema: LoginSchema, user_srv: UserSrv = Depends(UserSrv)):
email = login_schema.email
Expand Down
25 changes: 25 additions & 0 deletions backend/src/apis/inference_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from uuid import UUID

from fastapi import APIRouter, Depends, status
from src.apis.utils.utils import generate_error_responses
from src.auth.auth_bearer import auth_required
from src.schemas.inference_schema import InferenceResponseSchema, StructuredDataSchema
from src.services.inference_srv import InferenceSrv

router = APIRouter(tags=["inference"])


@router.post(
"/inference",
summary="Inference endpoint",
description="Inference endpoint",
status_code=status.HTTP_200_OK,
response_description="Inference successful",
responses=generate_error_responses(status.HTTP_403_FORBIDDEN),
response_model=InferenceResponseSchema,
)
async def inference(
structured_data_schema: StructuredDataSchema, inference_srv: InferenceSrv = Depends(InferenceSrv), user_id: UUID = Depends(auth_required)
):
prediction = await inference_srv.predict(structured_data_schema.dict())
return InferenceResponseSchema(prediction=prediction)
2 changes: 2 additions & 0 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from fastapi.middleware.cors import CORSMiddleware
from src.apis.auth_api import router as auth_router
from src.apis.health_api import router as health_router
from src.apis.inference_api import router as inference_router
from src.apis.payment_api import router as payment_router
from src.settings import settings
from uvicorn import run as uvicorn_run
Expand All @@ -12,6 +13,7 @@ def _register_api_handlers(app: FastAPI) -> FastAPI:
app.include_router(health_router)
app.include_router(auth_router)
app.include_router(payment_router)
app.include_router(inference_router)
return app


Expand Down
21 changes: 21 additions & 0 deletions backend/src/schemas/inference_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pydantic import BaseModel, Field


class StructuredDataSchema(BaseModel):
# categorical
manufacturer: str
model: str
fuel: str
sold_by: bool # True for dealer, False for private
gearbox: bool # True for automatic, False for manual
chassis: str

# numerical
km: int
power: int
engine: int
year: int


class InferenceResponseSchema(BaseModel):
prediction: float = Field(..., example=10000)
21 changes: 21 additions & 0 deletions backend/src/services/inference_srv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import random

from fastapi import Depends
from src.repositories.sqlalchemy_repo import SQLAlchemyRepository
from src.services.abstract_srv import AbstractService


class InferenceSrv(AbstractService):
_repository: SQLAlchemyRepository

def __init__(self, repo: SQLAlchemyRepository = Depends(SQLAlchemyRepository)):
super().__init__(repo)

async def preprocess_data(self, data):
# TODO
return data

async def predict(self, data):
# TODO
data = await self.preprocess_data(data)
return random.randint(1000, 50000)
3 changes: 2 additions & 1 deletion backend/src/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from src import models
from src.database import async_engine, async_session
from src.main import app

from src import models


@pytest.fixture(scope="session")
def event_loop():
Expand Down

0 comments on commit af25faa

Please sign in to comment.