# nlp-rgr/api/prediction.py
# Author: Max Bushuev — commit 9a4ae09 ("dfdf")
from transformers import pipeline
import os
from fastapi import APIRouter
from models.prediction import Prediction
from services.answers_service import AnswersService
# Router for the prediction endpoints: every route declared on it is mounted
# under the '/prediction' path prefix and grouped under the 'prediction' tag.
router = APIRouter(tags=['prediction'], prefix='/prediction')
# The Hugging Face checkpoint can be overridden with the MODEL environment
# variable. The default keeps the original intent-classification model.
# The task stays fixed at "text-classification" because get_prediction takes
# the first item of the pipeline output and hands it on as a dictionary,
# which relies on this task's result format.
model_name = os.getenv('MODEL', 'Maxim01/intent-classification')
# Built once at import time. The request handler reuses this module-level pipeline.
model = pipeline("text-classification", model=model_name)
@router.post("/get_prediction", name='Получение предсказания модели')
def get_prediction(message_from_user: str):
    """Classify the user's message and return the answer chosen for it.

    ``message_from_user`` is a plain ``str`` parameter, so FastAPI reads it
    from the query string. The first pipeline result is converted to a
    ``Prediction`` and passed to ``AnswersService.get_answer_by_id``. Its
    return value is sent back to the client unchanged.
    """
    # Only the first result produced by the pipeline is used.
    top_result = model(message_from_user)[0]
    prediction = Prediction.from_output(dictionary=top_result)
    return AnswersService.get_answer_by_id(prediction=prediction)