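"""Streamlit app for oil spill segmentation of SAR images.

Loads a trained ResNet50-DeepLabV3+ model, runs inference on an uploaded SAR
image, and displays the predicted segmentation mask (optionally next to a
user-supplied ground-truth mask).
"""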
import os
import numpy as np
import streamlit as st

import torch
import torch.nn.functional as F
from PIL import Image

from training.seg_models import ResNet50DeepLabV3Plus
from training.image_preprocessing import ImagePadder
from training.logger_utils import load_dict_from_json

def run_inference(image_array, file_weights, num_classes=5, file_stats_json="training/image_stats.json"):
    """Run oil spill segmentation on a single SAR image array and return an RGB mask."""
    oil_spill_seg_model = ResNet50DeepLabV3Plus(
        num_classes=num_classes, pretrained=True
    )
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous CUDA calls for clearer error reporting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    oil_spill_seg_model.to(device)
    # replace the ImageNet-pretrained weights with the trained segmentation weights
    oil_spill_seg_model.load_state_dict(torch.load(file_weights, map_location=device))
    oil_spill_seg_model.eval()

    # RGB colour assigned to each predicted class index; the class semantics
    # (0: sea surface, 1: oil spill, 2: look-alike, 3: ship, 4: land) are
    # assumed from the colour convention of the oil spill SAR dataset used
    # for training
    dict_label_to_color_mapping = {
        0: np.array([0, 0, 0]),
        1: np.array([0, 255, 255]),
        2: np.array([255, 0, 0]),
        3: np.array([153, 76, 0]),
        4: np.array([0, 153, 0]),
    }

    try:
        dict_stats = load_dict_from_json(file_stats_json)
    except Exception:
        # fall back to a path relative to this script when the app is started
        # from a different working directory
        dir_json = os.path.dirname(os.path.realpath(__file__))
        dict_stats = load_dict_from_json(os.path.join(dir_json, file_stats_json))

    try:
        image_padder = ImagePadder("/data/images")
    except Exception:
        # fall back to the local sample padding image directory
        image_padder = ImagePadder("./sample_padding_image_for_inference")

    # pad the image and normalise it with the training-set statistics
    image_padded = image_padder.pad_image(image_array)
    image_preprocessed = image_padded / 255.
    image_preprocessed = image_preprocessed - dict_stats["mean"]
    image_preprocessed = image_preprocessed / dict_stats["std"]
    # add a batch dimension and convert HWC to NCHW format
    image_preprocessed = np.expand_dims(image_preprocessed, axis=0)
    image_preprocessed = np.transpose(image_preprocessed, (0, 3, 1, 2))

    image_tensor = torch.tensor(image_preprocessed).float()
    image_tensor = image_tensor.to(device, dtype=torch.float)
    # forward pass without gradient tracking
    with torch.no_grad():
        pred_logits = oil_spill_seg_model(image_tensor)
    pred_probs = F.softmax(pred_logits, dim=1)
    pred_label = torch.argmax(pred_probs, dim=1)

    pred_label_arr = pred_label.detach().cpu().numpy()
    pred_label_arr = np.squeeze(pred_label_arr)
    # one-hot encode the per-pixel class labels: (H, W) -> (H, W, num_classes)
    pred_label_one_hot = np.eye(num_classes)[pred_label_arr]

    # build an RGB mask by colouring each class with its mapped colour
    pred_mask_arr = np.zeros((pred_label_arr.shape[0], pred_label_arr.shape[1], 3))
    for sem_class in range(num_classes):
        curr_class_label = pred_label_one_hot[:, :, sem_class]
        curr_class_label = curr_class_label.reshape(pred_label_one_hot.shape[0], pred_label_one_hot.shape[1], 1)

        curr_class_color_mapping = dict_label_to_color_mapping[sem_class]
        curr_class_color_mapping = curr_class_color_mapping.reshape(1, curr_class_color_mapping.shape[0])

        pred_mask_arr += curr_class_label * curr_class_color_mapping

    pred_label_arr = pred_label_arr.astype(np.uint8)
    pred_mask_arr = pred_mask_arr.astype(np.uint8)
    # crop away the padding border (11 rows top/bottom, 15 columns left/right),
    # presumably what ImagePadder added, so the mask matches the original image size
    padded_height, padded_width = pred_label_arr.shape
    pred_mask_arr = pred_mask_arr[11:padded_height-11, 15:padded_width-15]

    return pred_mask_arr
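
# A minimal standalone usage sketch for run_inference (the image path below is
# a placeholder, the weights path is the app's default):
#
#     import numpy as np
#     from PIL import Image
#
#     image_array = np.array(Image.open("example_sar_image.jpg"))
#     rgb_mask = run_inference(
#         image_array, "/data/models/oil_spill_seg_resnet_50_deeplab_v3+_80.pt"
#     )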

def infer():
    st.title("Oil spill detection app")

    file_weights_default = "/data/models/oil_spill_seg_resnet_50_deeplab_v3+_80.pt"
    file_weights = st.sidebar.text_input("File model weights", file_weights_default)

    if not os.path.isfile(file_weights):
        st.write("Weights file not found, please enter a valid path")
    else:
        st.write(f"Weights file: {file_weights}")

    # select an input SAR image file
    image_file_buffer = st.sidebar.file_uploader("Select input SAR image", type=["jpg", "jpeg"])
    image_array = None
    # read the image
    if image_file_buffer is not None:
        image = Image.open(image_file_buffer)
        image_array = np.array(image)
        st.image(image_array, caption=f"Input image: {image_file_buffer.name}")
    else:
        st.write("Input image: not selected")

    # select a mask image file
    mask_file_buffer = st.sidebar.file_uploader("Select mask image", type=["png"])
    # read the mask
    if mask_file_buffer is not None:
        mask = Image.open(mask_file_buffer)
        mask_array = np.array(mask)
        st.image(mask_array, caption=f"Mask image: {mask_file_buffer.name}")
    else:
        st.write("Mask image: not selected")

    # run inference when the user clicks the button, but only if an input image
    # and a valid weights file have been provided
    infer_button = st.sidebar.button("Run inference")
    if infer_button:
        if image_array is None or not os.path.isfile(file_weights):
            st.write("Please select an input image and a valid weights file before running inference")
        else:
            mask_predicted = run_inference(image_array, file_weights)
            st.image(mask_predicted, caption=f"Predicted mask for the input: {image_file_buffer.name}")
    return

def main():
    infer()
    return

if __name__ == "__main__":
    main()
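
# To launch the app locally (assuming Streamlit is installed and this file is
# the entry point), run something like:
#     streamlit run <path_to_this_file>.py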