# ai-challenge / routers/distilbert.py
# Author: Sebastian Kułaga
# Last commit: "Fix!" (dec29b0)
# Hosting-page metadata: raw | history | blame; virus scan: no virus; size: 398 bytes
from fastapi import APIRouter
from src.distilbert import Distilbert
# Checkpoint identifier handed to Distilbert. Presumably a Hugging Face Hub
# repo id ("<user>/<model>") — confirm against src.distilbert.
MODEL_PATH = "SebaK13/DistilBERT-finetuned-customer-queries-balanced"

# One shared classifier instance, constructed when this module is imported;
# importing the module therefore triggers whatever loading Distilbert's
# constructor performs.
distilbert_service = Distilbert(model_path=MODEL_PATH)

# Router for prediction endpoints: routes registered on it are mounted under
# the "/predict" prefix and grouped under the "predict" tag in the OpenAPI docs.
# NOTE(review): the route registered later in this module is itself declared
# as "/predict", so the resulting endpoint path is "/predict/predict" —
# confirm that is intended before clients depend on it.
prediction = APIRouter(
    prefix="/predict",
    tags=["predict"],
)
@prediction.get("/predict")
def predict(query: str):
    """Classify a free-text query and return its predicted type.

    ``query`` is a required query-string parameter (FastAPI derives this from
    the plain ``str`` annotation). The text is passed unchanged to the shared
    ``distilbert_service``. Whatever ``predict_query_type`` returns is wrapped
    as ``{"label": ...}``. Its exact type is defined in ``src.distilbert``.
    NOTE(review): confirm the returned value is JSON-serializable.

    Declared with plain ``def`` rather than ``async def`` because FastAPI runs
    sync path operations in a threadpool, so the model call does not block the
    event loop.
    """
    return {"label": distilbert_service.predict_query_type(query)}