Overhead_MNIST / app.py
abhishekrs4's picture
Added a try/except block in the FastAPI app so that requests are handled smoothly when the file is sent via Streamlit's file_uploader
9abecf2
raw
history blame contribute delete
No virus
3.88 kB
import json
import logging

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile, Form

from config import settings
from modeling.models import SimpleCNN, SimpleResNet, ComplexResNet, ComplexResNetV2
# FastAPI application object; the route decorators in this module register on it.
app = FastAPI()
logging.basicConfig(level=logging.INFO)

# Runtime configuration read from the project settings object.
num_classes = settings.num_classes
model_type = settings.model_type
device = settings.device

# Label mapping loaded from JSON; at prediction time it is indexed with the
# stringified predicted class index. The file is opened in a ``with`` block so
# the handle is closed as soon as the JSON has been parsed.
file_json = "label_mapping.json"
with open(file_json, encoding="utf-8") as file_desc_json:
    label_mapping = json.load(file_desc_json)
logging.info(label_mapping)

# Candidate checkpoint locations: a path relative to the working directory and
# an absolute path under /data/models.
file_model_local = f"./trained_models/{model_type}/{model_type}.pt"
file_model_cont = f"/data/models/{model_type}/{model_type}.pt"
logging.info(f"model_type: {model_type}")
# Map every supported ``model_type`` value to a zero-argument factory that
# builds the matching (still untrained) network. Factories are lambdas so that
# only the selected architecture is actually instantiated.
_MODEL_FACTORIES = {
    "simple_cnn": lambda: SimpleCNN(num_classes=num_classes),
    "simple_resnet": lambda: SimpleResNet(num_classes=num_classes),
    "medium_simple_resnet": lambda: SimpleResNet(
        list_num_res_units_per_block=[4, 4], num_classes=num_classes
    ),
    "deep_simple_resnet": lambda: SimpleResNet(
        list_num_res_units_per_block=[6, 6], num_classes=num_classes
    ),
    "complex_resnet": lambda: ComplexResNet(
        list_num_res_units_per_block=[4, 4, 4], num_classes=num_classes
    ),
    "complex_resnet_v2": lambda: ComplexResNetV2(
        list_num_res_units_per_block=[4, 4, 4], num_classes=num_classes
    ),
}

# Fail fast with an explicit message instead of leaving ``overhead_mnist_model``
# undefined, which would only surface later as an unrelated NameError.
if model_type not in _MODEL_FACTORIES:
    raise ValueError(
        f"unsupported model_type {model_type!r}; "
        f"expected one of {sorted(_MODEL_FACTORIES)}"
    )
overhead_mnist_model = _MODEL_FACTORIES[model_type]()
# Load the trained weights, trying the local checkpoint path first and falling
# back to the path under /data/models when that fails for any reason.
# ``except Exception`` (rather than a bare ``except``) keeps the best-effort
# fallback while letting KeyboardInterrupt / SystemExit propagate, and the
# original failure is logged so it is not silently lost.
try:
    logging.info(f"loading model from {file_model_local}")
    overhead_mnist_model.load_state_dict(
        torch.load(file_model_local, map_location=device)
    )
except Exception as load_err:
    logging.warning(
        "could not load model from %s (%s: %s)",
        file_model_local,
        type(load_err).__name__,
        load_err,
    )
    logging.info(f"loading model from {file_model_cont}")
    overhead_mnist_model.load_state_dict(
        torch.load(file_model_cont, map_location=device)
    )

# Move the network to the configured device and switch it to evaluation mode
# before any request is served.
overhead_mnist_model.to(device)
overhead_mnist_model.eval()
def get_prediction(img_arr: np.ndarray) -> str:
    """
    Run the loaded model on one image and return the predicted label.

    ---------
    Arguments
    ---------
    img_arr: ndarray
        a numpy array of the image; two leading singleton axes are added
        and the values are divided by 255 before the forward pass
        (presumably a 2-D grayscale image -- confirm against the caller)

    -------
    Returns
    -------
    pred_label_str : str
        a string representing the label of the prediction

    ------
    Raises
    ------
    ValueError
        if img_arr is None, e.g. because the image could not be decoded
    """
    if img_arr is None:
        # A failed decode yields None; reject it here with a clear message
        # rather than failing later inside the dtype conversion.
        raise ValueError("img_arr is None; the image could not be decoded")

    # Prepend two singleton axes and scale the values by 1/255 as float32.
    img_batch = np.asarray(img_arr)[np.newaxis, np.newaxis, ...]
    img_batch = img_batch.astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_batch).to(device, dtype=torch.float)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pred_logits = overhead_mnist_model(img_tensor)

    # One image in, so argmax over dim 1 is expected to hold a single index.
    pred_label = int(torch.argmax(pred_logits, dim=1).item())

    # The label mapping is keyed by the class index as a string.
    pred_label_str = label_mapping[str(pred_label)]
    return pred_label_str
@app.get("/info")
def get_app_info() -> dict:
    """
    Report the application's name and version for GET /info.

    -------
    Returns
    -------
    dict_info : dict
        a dictionary with the keys "app_name" and "version", both read
        from the project settings object
    """
    return dict(app_name=settings.app_name, version=settings.version)
@app.post("/predict")
def _file_upload(image_file: UploadFile = File(...)) -> dict:
    """
    Decode the uploaded image as grayscale and return its predicted label.

    ---------
    Arguments
    ---------
    image_file: object
        an object of type UploadFile

    -------
    Returns
    -------
    response_json : dict
        a dict as a response json for the post request, with the keys
        "name" (the uploaded file name) and "prediction" (the label)

    ------
    Raises
    ------
    HTTPException
        with status 400 if the upload is empty or cannot be decoded
        as an image
    """
    logging.info(image_file)
    try:
        # if the file is sent via post request with open()
        img_bytes = image_file.file.read()
    except AttributeError:
        # if the file is sent via post request from streamlit
        # (object without a ``.file`` attribute but with ``getvalue()``)
        img_bytes = image_file.getvalue()

    if not img_bytes:
        raise HTTPException(status_code=400, detail="uploaded file is empty")

    # Decode as single-channel grayscale (cv2.IMREAD_GRAYSCALE == 0).
    img_decoded = cv2.imdecode(
        np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_GRAYSCALE
    )
    if img_decoded is None:
        # imdecode returns None when the bytes are not a readable image;
        # report a client error instead of passing None to the model.
        raise HTTPException(
            status_code=400,
            detail="uploaded file could not be decoded as an image",
        )

    pred_label_str = get_prediction(img_decoded)
    response_json = {"name": image_file.filename, "prediction": pred_label_str}
    logging.info(response_json)
    return response_json