# nlp-rgr / api/prediction.py
# Author: Max Bushuev (commit cf509fc, message "f")
from transformers import pipeline
import os
from fastapi import APIRouter
from models.prediction import Prediction
from services.answers_service import AnswersService
# Router that groups this module's endpoints under the /prediction prefix.
router = APIRouter(prefix='/prediction', tags=['prediction'])

# Pipeline configuration is read from the environment. An unset variable
# yields None, which is forwarded to pipeline() unchanged.
nlp_task = os.environ.get('NLP_TASK')
model_name = os.environ.get('MODEL')

# Built once, at import time, because this is a module-level statement.
model = pipeline(task=nlp_task, model=model_name)
@router.post("/get_prediction", name='Получение предсказания модели')
def get_prediction(message_from_user: str):
    """Run the module-level pipeline on a user message and return the answer.

    Only the first element of the pipeline output is used. It is converted
    into a ``Prediction`` via ``Prediction.from_output`` and then resolved to
    an answer through ``AnswersService.get_answer_by_id``.

    Args:
        message_from_user: Raw text passed straight to the pipeline.

    Returns:
        Whatever ``AnswersService.get_answer_by_id`` returns for the prediction.
    """
    # NOTE(review): assumes the pipeline returns a non-empty sequence whose
    # first element is the dict that Prediction.from_output expects — confirm
    # against the configured NLP_TASK.
    first_result = model(message_from_user)[0]
    return AnswersService.get_answer_by_id(
        prediction=Prediction.from_output(dictionary=first_result),
    )