# Overhead_MNIST / app.py
# Author: abhishekrs4
# Commit 67e67c1: "added scripts for fastapi application"
# (residue from the web file viewer: raw | history | blame | "No virus" | 2.78 kB)
import json
import logging

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile

from config import settings
from modeling.models import SimpleCNN, SimpleResNet
app = FastAPI()
logging.basicConfig(level=logging.INFO)

# Runtime configuration pulled from the project settings object.
num_classes = settings.num_classes
model_type = settings.model_type
device = settings.device

# Class-index -> label mapping used by get_prediction (keys are looked up as
# str(index)). The file is closed as soon as it has been parsed.
file_json = "label_mapping.json"
with open(file_json, encoding="utf-8") as file_desc_json:
    label_mapping = json.load(file_desc_json)
logging.info(label_mapping)
# Candidate weight files for the selected architecture: a path relative to the
# working directory, and an absolute fallback (presumably the mount point used
# when running inside a container -- TODO confirm).
file_model_local = f"./trained_models/{model_type}/{model_type}.pt"
file_model_cont = f"/data/models/{model_type}/{model_type}.pt"
logging.info(f"model_type: {model_type}")

# Instantiate the architecture named by settings.model_type. The ResNet
# variants differ only in list_num_res_units_per_block.
if model_type == "simple_cnn":
    overhead_mnist_model = SimpleCNN(num_classes=num_classes)
elif model_type == "simple_resnet":
    overhead_mnist_model = SimpleResNet(num_classes=num_classes)
elif model_type == "medium_resnet":
    overhead_mnist_model = SimpleResNet(
        list_num_res_units_per_block=[4, 4], num_classes=num_classes
    )
elif model_type == "deep_resnet":
    overhead_mnist_model = SimpleResNet(
        list_num_res_units_per_block=[6, 6], num_classes=num_classes
    )
else:
    # Fail fast with a clear message instead of a NameError on
    # `overhead_mnist_model` a few lines further down.
    raise ValueError(
        f"unsupported model_type {model_type!r}; expected one of "
        "'simple_cnn', 'simple_resnet', 'medium_resnet', 'deep_resnet'"
    )

# Try the local weights first and fall back to the second path if that fails.
# `except Exception` (rather than a bare `except:`) so KeyboardInterrupt /
# SystemExit during startup are not swallowed; the failure reason is logged
# instead of being silently discarded.
try:
    logging.info(f"loading model from {file_model_local}")
    overhead_mnist_model.load_state_dict(
        torch.load(file_model_local, map_location=device)
    )
except Exception as err:
    logging.warning(f"could not load {file_model_local} ({err!r}); falling back")
    logging.info(f"loading model from {file_model_cont}")
    overhead_mnist_model.load_state_dict(
        torch.load(file_model_cont, map_location=device)
    )

overhead_mnist_model.to(device)
# Switch to inference mode (affects layers such as dropout / batch-norm, if
# the model contains any).
overhead_mnist_model.eval()
def get_prediction(img_arr, model=None, label_map=None, target_device=None):
    """Classify a single 2-D image and return the mapped label.

    Parameters
    ----------
    img_arr : numpy.ndarray
        2-D image array. Values are divided by 255, so this presumably
        expects 8-bit pixel intensities (e.g. the output of
        ``cv2.imdecode(..., cv2.IMREAD_GRAYSCALE)``).
    model : torch.nn.Module, optional
        Classifier to run; defaults to the module-level ``overhead_mnist_model``.
    label_map : dict, optional
        Mapping from ``str(class_index)`` to label; defaults to the
        module-level ``label_mapping``.
    target_device : optional
        Device the input tensor is moved to; defaults to the module-level
        ``device``.

    Returns
    -------
    object
        ``label_map[str(predicted_index)]`` -- the label for the arg-max logit.

    Raises
    ------
    ValueError
        If ``img_arr`` is None (e.g. the image could not be decoded).
    KeyError
        If the predicted index has no entry in ``label_map``.
    """
    if img_arr is None:
        raise ValueError("img_arr is None; the image could not be decoded")
    if model is None:
        model = overhead_mnist_model
    if label_map is None:
        label_map = label_mapping
    if target_device is None:
        target_device = device

    # (H, W) -> (1, 1, H, W): add batch and channel axes, scale to [0, 1].
    batch = np.asarray(img_arr, dtype=np.float32)[np.newaxis, np.newaxis] / 255.0
    img_tensor = torch.from_numpy(batch).to(target_device)

    # No autograd graph is needed for inference; skipping it saves memory
    # and time on every request.
    with torch.no_grad():
        pred_logits = model(img_tensor)

    # Single-item batch: take the arg-max class index of the only row.
    pred_index = int(torch.argmax(pred_logits, dim=1).reshape(-1)[0])
    return label_map[str(pred_index)]
@app.get("/info")
def get_app_info():
    """Describe this service for GET /info.

    Returns
    -------
    dict
        ``{"app_name": settings.app_name, "version": settings.version}``.
    """
    return dict(app_name=settings.app_name, version=settings.version)
@app.post("/predict")
def _file_upload(image_file: UploadFile = File(...)):
    """Classify an uploaded image for POST /predict.

    Parameters
    ----------
    image_file : UploadFile
        Multipart-uploaded image; decoded by OpenCV as single-channel.

    Returns
    -------
    dict
        ``{"name": <uploaded filename>, "prediction": <predicted label>}``.

    Raises
    ------
    HTTPException
        400 if the upload is empty or cannot be decoded as an image.
    """
    logging.info(image_file)
    img_str = image_file.file.read()
    # cv2.imdecode asserts on an empty buffer, so reject empty uploads first.
    if not img_str:
        raise HTTPException(status_code=400, detail="uploaded file is empty")
    img_decoded = cv2.imdecode(
        np.frombuffer(img_str, np.uint8), cv2.IMREAD_GRAYSCALE
    )
    # imdecode signals undecodable data by returning None, not by raising.
    if img_decoded is None:
        raise HTTPException(
            status_code=400,
            detail=f"could not decode {image_file.filename!r} as an image",
        )
    pred_label_str = get_prediction(img_decoded)
    response_json = {"name": image_file.filename, "prediction": pred_label_str}
    logging.info(response_json)
    return response_json