File size: 2,776 Bytes
67e67c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import logging

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile

from config import settings
from modeling.models import SimpleCNN, SimpleResNet

# FastAPI application object; the route decorators in this module register on it.
app = FastAPI()
logging.basicConfig(level=logging.INFO)

# Model/runtime configuration read from the project settings.
num_classes = settings.num_classes
model_type = settings.model_type
device = settings.device

# Mapping from predicted class index (looked up as a string key) to label.
file_json = "label_mapping.json"
# The context manager closes the handle as soon as the mapping is parsed,
# instead of keeping it open for the lifetime of the process.
# JSON text is UTF-8 by specification, so do not depend on the locale default.
with open(file_json, encoding="utf-8") as file_desc_json:
    label_mapping = json.load(file_desc_json)
logging.info(label_mapping)

# Candidate weight locations: a path relative to the working directory first,
# then an absolute path under /data/models.
file_model_local = f"./trained_models/{model_type}/{model_type}.pt"
file_model_cont = f"/data/models/{model_type}/{model_type}.pt"

logging.info(f"model_type: {model_type}")
if model_type == "simple_cnn":
    overhead_mnist_model = SimpleCNN(num_classes=num_classes)
elif model_type == "simple_resnet":
    overhead_mnist_model = SimpleResNet(num_classes=num_classes)
elif model_type == "medium_resnet":
    overhead_mnist_model = SimpleResNet(
        list_num_res_units_per_block=[4, 4], num_classes=num_classes
    )
elif model_type == "deep_resnet":
    overhead_mnist_model = SimpleResNet(
        list_num_res_units_per_block=[6, 6], num_classes=num_classes
    )
else:
    # Fail at startup with a clear message rather than with a NameError on
    # the first use of the never-assigned model variable.
    raise ValueError(f"unsupported model_type: {model_type!r}")

try:
    logging.info(f"loading model from {file_model_local}")
    overhead_mnist_model.load_state_dict(
        torch.load(file_model_local, map_location=device)
    )
except Exception as err:
    # Best-effort fallback: any failure with the first path triggers an
    # attempt with the second one. `Exception` (not a bare `except`) keeps
    # KeyboardInterrupt/SystemExit propagating, and the reason is logged.
    logging.warning(f"could not load {file_model_local}: {err!r}")
    logging.info(f"loading model from {file_model_cont}")
    overhead_mnist_model.load_state_dict(
        torch.load(file_model_cont, map_location=device)
    )
overhead_mnist_model.to(device)
# Switch the model to evaluation mode before serving predictions.
overhead_mnist_model.eval()


def get_prediction(img_arr):
    """
    Classify a single image array with the module-level model.

    ----------
    Parameters
    ----------
    img_arr : array-like
        image data; two leading singleton axes are prepended and the values
        are scaled by 1/255 before the forward pass
        (presumably a 2-D grayscale image, giving a (1, 1, H, W) batch --
        TODO confirm against the model's expected input)

    -------
    Returns
    -------
    pred_label_str
        the entry of ``label_mapping`` stored under the predicted class
        index (the index is converted to a string for the lookup)

    ------
    Raises
    ------
    KeyError
        if the predicted index has no entry in ``label_mapping``
    """
    # Prepend two singleton axes, then scale to float32.
    batch = np.asarray(img_arr)[np.newaxis, np.newaxis, ...]
    batch = batch.astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(batch).to(device, dtype=torch.float)

    # Inference only: disable autograd so no graph is built per request.
    with torch.no_grad():
        pred_logits = overhead_mnist_model(img_tensor)
    pred_label = torch.argmax(pred_logits, dim=1)

    # The batch holds exactly one image, so the result is a single index.
    pred_index = int(pred_label.item())
    pred_label_str = label_mapping[str(pred_index)]
    return pred_label_str


@app.get("/info")
def get_app_info():
    """
    Report the application's name and version.

    -------
    Returns
    -------
    dict
        a mapping with the keys "app_name" and "version", both read from
        the project settings, returned as the response to the GET request
    """
    return dict(app_name=settings.app_name, version=settings.version)


@app.post("/predict")
def _file_upload(image_file: UploadFile = File(...)):
    """
    Decode an uploaded image as grayscale and return its predicted label.

    ----------
    Parameters
    ----------
    image_file : UploadFile
        the uploaded file taken from the request's form data

    -------
    Returns
    -------
    response_json : dict
        a dictionary with the uploaded file's name under "name" and the
        predicted label under "prediction"

    ------
    Raises
    ------
    HTTPException
        with status 400 when the upload is empty or cannot be decoded as
        an image
    """
    logging.info(image_file)
    img_bytes = image_file.file.read()
    if not img_bytes:
        # Reject explicitly instead of handing an empty buffer to OpenCV.
        raise HTTPException(status_code=400, detail="uploaded file is empty")

    # IMREAD_GRAYSCALE has the value 0, the flag originally passed here.
    img_decoded = cv2.imdecode(
        np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_GRAYSCALE
    )
    if img_decoded is None:
        # imdecode signals undecodable data by returning None; report a
        # client error rather than failing later inside the prediction code.
        raise HTTPException(
            status_code=400, detail="uploaded file is not a decodable image"
        )

    pred_label_str = get_prediction(img_decoded)
    response_json = {"name": image_file.filename, "prediction": pred_label_str}
    logging.info(response_json)
    return response_json