# pytorch_docker / main.py
# Uploaded by Zarzamorati10 ("Upload 5 files"), commit ab63513 (verified).
import __main__
import math
from typing import Dict, List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from model import BinaryClassificationWithLogits
# Path to the serialized TorchScript model, loaded once at import time.
model_path = "classification_gaussian_binary_model_0v.pt"

# The TorchScript archive carries its own architecture and weights, so no eager
# BinaryClassificationWithLogits instance is needed here (previously one was
# constructed and then immediately discarded by this assignment).
# map_location="cpu" lets the service run on machines without a GPU.
model = torch.jit.load(model_path, map_location="cpu")

# Switch to inference mode once at startup, in case the module was saved while
# in training mode. eval() is idempotent.
model.eval()
class ClassificationFeatures(BaseModel):
    # Request body for POST /predict: four numeric inputs. The prediction
    # endpoint stacks them in this order (feature_1 .. feature_4) before
    # passing them to the model.
    feature_1: float
    feature_2: float
    feature_3: float
    feature_4: float
# Create the FastAPI application instance that the routes below register on.
app = FastAPI()
# Root route: returns a fixed greeting string.
@app.get("/")
def home_page():
    greeting = "Welcome the API with pytorch"
    return greeting
# Inference route.
@app.post("/predict")
def predict_sample(cls_features: ClassificationFeatures) -> Dict:
    """Classify a single sample with the loaded model.

    Args:
        cls_features: The four input features from the request body.

    Returns:
        ``{"prediction": label}``, where ``label`` is the float ``0.0`` or
        ``1.0`` obtained by rounding the sigmoid of the model's logit.
        ``torch.round`` rounds half to even, so a probability of exactly 0.5
        yields ``0.0``.

    Raises:
        HTTPException: 422 if any feature is NaN or infinite. Such inputs would
            otherwise produce a NaN prediction that cannot be JSON-encoded,
            surfacing as a 500 error.
    """
    features = [
        cls_features.feature_1,
        cls_features.feature_2,
        cls_features.feature_3,
        cls_features.feature_4,
    ]
    # Pydantic float fields accept values such as "nan" and "inf" by default;
    # reject them explicitly instead of failing later during serialization.
    if not all(math.isfinite(value) for value in features):
        raise HTTPException(
            status_code=422,
            detail="All features must be finite numbers.",
        )

    # Shape (1, 4): a batch containing one sample.
    X = torch.tensor([features], dtype=torch.float32)

    # Idempotent; kept so this endpoint does not depend on how the model was
    # prepared at load time.
    model.eval()
    with torch.inference_mode():
        logit = model(X)
        pred_prob = torch.sigmoid(logit)
        pred_label = torch.round(pred_prob)
    # .item() assumes the model returns a single-element tensor for one sample.
    return {"prediction": pred_label.item()}