|
from transformers import pipeline |
|
import os |
|
from fastapi import APIRouter |
|
from models.prediction import Prediction |
|
|
|
from services.answers_service import AnswersService |
|
# Router for the model-prediction endpoints: every route registered on it is
# served under the "/prediction" path prefix and listed under the
# "prediction" tag in the generated OpenAPI schema.
router = APIRouter(prefix='/prediction', tags=['prediction'])
|
|
|
|
|
# Pipeline configuration is taken from the environment. If a variable is
# unset, ``None`` is forwarded to ``pipeline`` unchanged.
nlp_task = os.environ.get('NLP_TASK')
model_name = os.environ.get('MODEL')

# The pipeline is constructed once, at import time, and reused by the
# request handler(s) in this module rather than being rebuilt per request.
model = pipeline(task=nlp_task, model=model_name)
|
|
|
|
|
# Route name below is Russian for "Getting the model prediction".
@router.post("/get_prediction", name='Получение предсказания модели')
def get_prediction(message_from_user: str):
    """Run the shared pipeline on a user message and return the matching answer.

    ``message_from_user`` is a bare ``str`` parameter that is not part of the
    path, so FastAPI treats it as a query parameter even though this is a
    POST route.

    Only the first element of the pipeline's output is used. It is turned
    into a ``Prediction`` and passed to ``AnswersService.get_answer_by_id``,
    whose return value becomes the response body.
    """
    first_result = model(message_from_user)[0]
    prediction = Prediction.from_output(dictionary=first_result)
    return AnswersService.get_answer_by_id(prediction=prediction)
|
|