"""HTTP routes exposing the fine-tuned DistilBERT customer-query classifier."""

from fastapi import APIRouter

from src.distilbert import Distilbert

# Checkpoint identifier handed to Distilbert; presumably a Hugging Face Hub
# repo id ("user/repo" form) — confirm against src.distilbert.
MODEL_PATH = "SebaK13/DistilBERT-finetuned-customer-queries-balanced"

prediction = APIRouter(prefix="/predict", tags=["predict"])

# Constructed once at module import and shared by every request, so the
# model is not rebuilt per call.
distilbert_service = Distilbert(model_path=MODEL_PATH)


# The router prefix is already "/predict", so the "/predict" route below is
# served at "/predict/predict". It is kept for existing clients; the ""
# route additionally exposes the same handler at the bare prefix ("/predict").
@prediction.get("")
@prediction.get("/predict")
def predict(query: str):
    """Classify a customer query and return its predicted label.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    worker threadpool instead of blocking the event loop during inference.

    Args:
        query: Text to classify, supplied as the required ``query``
            query-string parameter.

    Returns:
        A dict ``{"label": ...}`` wrapping the value returned by
        ``Distilbert.predict_query_type``.
    """
    output = distilbert_service.predict_query_type(query)
    return {"label": output}