Spaces:
Runtime error
Runtime error
File size: 4,854 Bytes
95c01d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
import numpy as np
import streamlit as st
import torch
from PIL import Image
from skimage.io import imread
import torch.nn.functional as F
from training.metrics import *
from training.seg_models import *
from training.image_preprocessing import ImagePadder
from training.logger_utils import load_dict_from_json
from training.dataset import get_dataloader_for_inference
# RGB colour drawn for each predicted class; the row index is the class label.
_LABEL_TO_COLOR = np.array(
    [
        [0, 0, 0],
        [0, 255, 255],
        [255, 0, 0],
        [153, 76, 0],
        [0, 153, 0],
    ],
    dtype=np.uint8,
)

# Rows / columns trimmed from each side of the prediction before returning it.
# NOTE(review): presumably the amount of padding ImagePadder adds — confirm
# against ImagePadder if the input resolution changes.
_CROP_ROWS = 11
_CROP_COLS = 15


def _load_model(file_weights, num_classes, device):
    """Build ResNet50DeepLabV3Plus, load ``file_weights`` onto ``device`` and set eval mode."""
    model = ResNet50DeepLabV3Plus(num_classes=num_classes, pretrained=True)
    model.to(device)
    # NOTE(review): torch.load unpickles the file, and the path comes from a UI
    # text box — only point it at trusted checkpoints.
    model.load_state_dict(torch.load(file_weights, map_location=device))
    model.eval()
    return model


def _load_image_stats(file_stats_json):
    """Load the normalisation stats JSON, falling back to a path relative to this module."""
    try:
        return load_dict_from_json(file_stats_json)
    except Exception:
        # Relative path did not resolve from the working directory; retry next to this file.
        dir_json = os.path.dirname(os.path.realpath(__file__))
        return load_dict_from_json(os.path.join(dir_json, file_stats_json))


def _get_image_padder():
    """Create an ImagePadder, preferring /data/images and falling back to the bundled sample dir."""
    try:
        return ImagePadder("/data/images")
    except Exception:
        return ImagePadder("./sample_padding_image_for_inference")


def _preprocess(image_padded, dict_stats):
    """Divide by 255, standardise with ``dict_stats`` mean/std and reshape HWC -> 1xCxHxW."""
    image_normalised = (image_padded / 255.0 - dict_stats["mean"]) / dict_stats["std"]
    # Add a batch axis, then move channels in front of the spatial axes (NCHW).
    return np.transpose(image_normalised[np.newaxis], (0, 3, 1, 2))


def run_inference(image_array, file_weights, num_classes=5, file_stats_json="training/image_stats.json"):
    """Segment one SAR image and return a colour-coded prediction mask.

    Args:
        image_array: Input image array. NOTE(review): after padding it must be
            3-D (H, W, C) for the NCHW transpose — a 2-D grayscale upload would
            fail there; confirm ImagePadder's handling.
        file_weights: Path to a ``state_dict`` checkpoint for ResNet50DeepLabV3Plus.
        num_classes: Number of model output classes; at most the number of
            palette colours (5).
        file_stats_json: JSON file with ``"mean"`` and ``"std"`` entries used to
            normalise the input.

    Returns:
        ``uint8`` array of shape (H, W, 3) holding one RGB colour per predicted
        class, with 11 rows and 15 columns trimmed from each side.

    Raises:
        ValueError: If ``num_classes`` exceeds the number of palette colours.
    """
    if num_classes > len(_LABEL_TO_COLOR):
        raise ValueError(
            f"num_classes={num_classes} exceeds the {len(_LABEL_TO_COLOR)} colours in the label palette"
        )

    # NOTE(review): debugging aid that makes CUDA kernel launches synchronous;
    # it slows GPU inference and only applies if set before CUDA initialises.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(file_weights, num_classes, device)

    dict_stats = _load_image_stats(file_stats_json)
    image_padded = _get_image_padder().pad_image(image_array)
    image_tensor = torch.tensor(_preprocess(image_padded, dict_stats), dtype=torch.float32, device=device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        pred_logits = model(image_tensor)
    # Argmax over the class axis; softmax is monotonic, so it is not needed for
    # the label. Index the single batch item explicitly instead of np.squeeze,
    # which would also collapse a spatial axis of size 1.
    pred_labels = torch.argmax(pred_logits, dim=1)[0].cpu().numpy()

    pred_mask = _LABEL_TO_COLOR[pred_labels]
    padded_height, padded_width = pred_labels.shape
    return pred_mask[_CROP_ROWS:padded_height - _CROP_ROWS, _CROP_COLS:padded_width - _CROP_COLS]
def infer():
    """Render the oil-spill detection Streamlit UI.

    The sidebar collects a weights file path, an input SAR image (jpg/jpeg) and
    an optional ground-truth mask (png). The mask is only displayed. When
    "Run inference" is clicked and both the image and the weights file are
    available, the predicted colour mask is shown.
    """
    st.title("Oil spill detection app")
    file_weights_default = "/data/models/oil_spill_seg_resnet_50_deeplab_v3+_80.pt"
    file_weights = st.sidebar.text_input("File model weights", file_weights_default)
    weights_ok = os.path.isfile(file_weights)
    if not weights_ok:
        st.write("Wrong weights file path")
    else:
        st.write(f"Weights file: {file_weights}")

    # select an input SAR image file
    image_file_buffer = st.sidebar.file_uploader("Select input SAR image", type=["jpg", "jpeg"])
    image_array = None  # stays None until an image is uploaded
    if image_file_buffer is not None:
        with Image.open(image_file_buffer) as image:
            image_array = np.array(image)
        st.image(image_array, caption=f"Input image: {image_file_buffer.name}")
    else:
        st.write("Input image: not selected")

    # select a mask image file (shown for visual comparison only)
    mask_file_buffer = st.sidebar.file_uploader("Select mask image", type=["png"])
    if mask_file_buffer is not None:
        with Image.open(mask_file_buffer) as mask:
            mask_array = np.array(mask)
        st.image(mask_array, caption=f"Mask image: {mask_file_buffer.name}")
    else:
        st.write("Mask image: not selected")

    # run inference when the option is invoked by the user
    if not st.sidebar.button("Run inference"):
        return
    # Guard against clicking before the inputs are valid; previously this raised
    # UnboundLocalError (no image) or crashed inside torch.load (bad path).
    if image_array is None:
        st.warning("Select an input SAR image before running inference.")
        return
    if not weights_ok:
        st.warning("Cannot run inference: the weights file path is invalid.")
        return
    mask_predicted = run_inference(image_array, file_weights)
    st.image(mask_predicted, caption=f"Predicted mask for the input: {image_file_buffer.name}")
def main() -> None:
    """Script entry point: draw the oil-spill detection app."""
    infer()


if __name__ == "__main__":
    main()
|