|
from typing import List,Dict |
|
from pydantic import BaseModel |
|
import numpy as np |
|
from fastapi import FastAPI |
|
import torch |
|
from model import BinaryClassificationWithLogits |
|
import __main__ |
|
|
|
# Location of the serialized model file read by torch.jit.load below.
model_path = "classification_gaussian_binary_model_0v.pt"

# torch.jit.load reconstructs the complete scripted module from the file, so
# no BinaryClassificationWithLogits instance needs to be built beforehand; the
# instance that used to be created here was overwritten on the next statement
# without ever being used.
model = torch.jit.load(model_path, map_location="cpu")

# Switch to evaluation mode once, at import time, so the module is ready for
# inference as soon as the application starts.
model.eval()
|
|
|
|
|
class ClassificationFeatures(BaseModel):
    """Request body of the ``/predict`` endpoint: four float features of one sample."""

    # The fields are read in this order when the model input row is assembled.
    feature_1: float
    feature_2: float
    feature_3: float
    feature_4: float
|
|
|
|
|
|
|
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()
|
|
|
@app.get("/")
def home_page():
    """Return the greeting string served at the API root."""
    greeting = "Welcome the API with pytorch"
    return greeting
|
|
|
@app.post("/predict")
def predict_sample(cls_features: ClassificationFeatures) -> Dict:
    """Classify a single sample and return its predicted label.

    The four feature values are packed into one float32 row of shape (1, 4),
    passed through the module-level ``model``, and the resulting logit is
    turned into a label with a sigmoid followed by rounding.

    Args:
        cls_features: Validated request body holding ``feature_1`` to
            ``feature_4``.

    Returns:
        A dict ``{"prediction": label}`` where ``label`` is the rounded
        sigmoid output as a Python float (``0.0`` or ``1.0``).

    Raises:
        HTTPException: With status 422 when any feature is NaN or infinite
            after conversion to float32.
    """
    # Imported locally so this handler does not rely on additional
    # module-level imports.
    from fastapi import HTTPException

    # Build the input row directly as a tensor; an intermediate NumPy array
    # adds nothing because the values are cast to float32 either way.
    features = torch.tensor(
        [[
            cls_features.feature_1,
            cls_features.feature_2,
            cls_features.feature_3,
            cls_features.feature_4,
        ]],
        dtype=torch.float32,
    )

    # Non-finite inputs (NaN/Infinity in the payload, or values too large for
    # float32) can yield a NaN prediction, which cannot be encoded in the JSON
    # response; reject them with a client error instead.
    if not torch.isfinite(features).all():
        raise HTTPException(
            status_code=422,
            detail="All features must be finite numbers representable as float32.",
        )

    # Idempotent; guarantees evaluation mode regardless of how the model was
    # set up at import time.
    model.eval()
    with torch.inference_mode():
        logit = model(features)
        probability = torch.sigmoid(logit)
        label = torch.round(probability)

    # .item() requires a single-element tensor, which is what a one-row input
    # is expected to produce here.
    return {"prediction": label.item()}