import os
import io
import numpy as np
import streamlit as st
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import torch
from PIL import Image
from skimage.io import imread
import torch.nn.functional as F

from training.metrics import *
from training.seg_models import *
from training.image_preprocessing import ImagePadder
from training.logger_utils import load_dict_from_json
from training.dataset import get_dataloader_for_inference


# Class names, indexed by predicted class label.
LABEL_NAMES = ["sea_surface", "oil_spill", "oil_spill_look_alike", "ship", "land"]

# RGB color used to render each class label in the predicted mask; row i is
# the color of class label i. Shared by run_inference (mask rendering) and
# show_mask_interpretation (legend) so the two can never drift apart.
LABEL_COLORS_RGB = np.array(
    [
        [0, 0, 0],
        [0, 255, 255],
        [255, 0, 0],
        [153, 76, 0],
        [0, 153, 0],
    ],
    dtype=np.uint8,
)

# Rows / columns trimmed from each side of the prediction after inference.
# NOTE(review): presumably this undoes the padding added by ImagePadder.pad_image;
# confirm against training/image_preprocessing before changing either side.
PAD_CROP_ROWS = 11
PAD_CROP_COLS = 15


def _load_model(file_weights, num_classes, device):
    """Build the segmentation model, load *file_weights* into it and switch it to eval mode."""
    # NOTE(review): pretrained=True may fetch backbone weights that
    # load_state_dict overwrites right away; consider pretrained=False if the
    # constructor supports it.
    model = ResNet50DeepLabV3Plus(num_classes=num_classes, pretrained=True)
    model.to(device)
    # SECURITY: torch.load unpickles the file and the path comes from a text box
    # in the UI -- only point it at trusted checkpoint files.
    model.load_state_dict(torch.load(file_weights, map_location=device))
    model.eval()
    return model


def _load_image_stats(file_stats_json):
    """Load the preprocessing stats JSON, retrying relative to this file's directory."""
    try:
        return load_dict_from_json(file_stats_json)
    except OSError:
        # the relative default path fails when the app is not launched from the repo root
        dir_json = os.path.dirname(os.path.realpath(__file__))
        return load_dict_from_json(os.path.join(dir_json, file_stats_json))


def _get_image_padder():
    """Return an ImagePadder, falling back to the bundled sample directory."""
    try:
        return ImagePadder("/data/images")
    except Exception:
        return ImagePadder("./sample_padding_image_for_inference")


def _preprocess(image_padded, dict_stats):
    """Scale to [0, 1], normalize with the dataset stats and convert HWC -> NCHW."""
    image = image_padded / 255.
    image = (image - dict_stats["mean"]) / dict_stats["std"]
    # HWC -> CHW, then add the batch axis -> NCHW
    return np.transpose(image, (2, 0, 1))[np.newaxis]


def run_inference(image_array, file_weights, num_classes=5, file_stats_json="training/image_stats.json"):
    """
    Run the oil-spill segmentation model on one image and return a color mask.

    ---------
    Arguments
    ---------
    image_array : ndarray
        a numpy array of the image (height x width x 3)
    file_weights : str
        full path to weights file
    num_classes : int
        number of classes in the dataset
    file_stats_json : str
        full path to the json stats file for preprocessing

    -------
    Returns
    -------
    pred_mask_arr : ndarray
        a uint8 numpy array (height x width x 3) of the prediction mask, where
        each pixel holds the LABEL_COLORS_RGB color of its predicted class
    """
    # NOTE(review): debugging setting that makes CUDA kernel launches synchronous
    # (slower); it only has an effect if set before CUDA is initialised.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    oil_spill_seg_model = _load_model(file_weights, num_classes, device)

    dict_stats = _load_image_stats(file_stats_json)
    image_padder = _get_image_padder()

    # apply padding and preprocessing
    image_padded = image_padder.pad_image(image_array)
    image_preprocessed = _preprocess(image_padded, dict_stats)
    image_tensor = torch.as_tensor(image_preprocessed, dtype=torch.float, device=device)

    # no_grad: inference only, so skip building the autograd graph
    with torch.no_grad():
        pred_logits = oil_spill_seg_model(image_tensor)
        pred_probs = F.softmax(pred_logits, dim=1)
        pred_label = torch.argmax(pred_probs, dim=1)

    # take the single batch element explicitly (squeeze would also drop a size-1 spatial axis)
    pred_label_arr = pred_label[0].cpu().numpy()

    # palette lookup: each class label indexes its RGB row -> (H, W, 3) uint8
    pred_mask_arr = LABEL_COLORS_RGB[pred_label_arr]

    padded_height, padded_width = pred_label_arr.shape
    pred_mask_arr = pred_mask_arr[
        PAD_CROP_ROWS:padded_height - PAD_CROP_ROWS,
        PAD_CROP_COLS:padded_width - PAD_CROP_COLS,
    ]
    return pred_mask_arr


def show_mask_interpretation():
    """Render a color legend mapping each mask color to its class name."""
    colors = ["#{:02X}{:02X}{:02X}".format(*rgb) for rgb in LABEL_COLORS_RGB]
    my_cmap = ListedColormap(colors, name="my_cmap")

    fig, ax = plt.subplots(figsize=(20, 2))
    # one cell per class, colored with that class's mask color
    ax.imshow([list(range(len(LABEL_NAMES)))], cmap=my_cmap)
    ax.set_title("Oil Spill mask interpretation")
    ax.set_xticks(np.arange(len(LABEL_NAMES)))
    ax.set_xticklabels(LABEL_NAMES)
    ax.set_yticks([])
    st.pyplot(fig)
    # close explicitly: Streamlit reruns this on every interaction and pyplot
    # otherwise keeps every figure alive
    plt.close(fig)
    return


def infer():
    """Render the inference page.

    Lets the user choose a weights file, an input SAR image and an optional
    ground-truth mask; on "Run inference" it predicts a mask, displays it and
    offers it for download.
    """
    st.title("Oil spill detection app")

    file_weights_default = "/data/models/oil_spill_seg_resnet_50_deeplab_v3+_80.pt"
    file_weights = st.sidebar.text_input("File model weights", file_weights_default)
    is_weights_file_valid = os.path.isfile(file_weights)
    if is_weights_file_valid:
        st.write(f"Weights file: {file_weights}")
    else:
        st.write("Wrong weights file path")

    # select an input SAR image file
    image_file_buffer = st.sidebar.file_uploader("Select input SAR image", type=["jpg", "jpeg"])

    # read the image; image_array stays None until an image is uploaded
    image_array = None
    if image_file_buffer is not None:
        # convert("RGB") guarantees an HxWx3 array; a grayscale JPEG would
        # otherwise give HxW and break the NCHW conversion in run_inference
        image = Image.open(image_file_buffer).convert("RGB")
        image_array = np.array(image)
        st.image(image_array, caption=f"Input image: {image_file_buffer.name}")
    else:
        st.write("Input image: not selected")

    # select a mask image file
    mask_file_buffer = st.sidebar.file_uploader("Select groundtruth mask image (optional, only for visual comparison with the prediction)", type=["png"])

    # read the mask
    if mask_file_buffer is not None:
        mask = Image.open(mask_file_buffer)
        mask_array = np.array(mask)
        st.image(mask_array, caption=f"Mask image: {mask_file_buffer.name}")
    else:
        st.write("Groundtruth mask image (optional): not selected")

    # run inference when the option is invoked by the user
    infer_button = st.sidebar.button("Run inference")
    if infer_button:
        if image_array is None:
            st.warning("Select an input SAR image before running inference")
        elif not is_weights_file_valid:
            st.warning("Provide a valid weights file path before running inference")
        else:
            mask_predicted = run_inference(image_array, file_weights)
            st.image(mask_predicted, caption=f"Predicted mask for the input: {image_file_buffer.name}")

            # option to download predicted mask (already uint8 HxWx3 -> RGB)
            mask_pred_image = Image.fromarray(mask_predicted)
            with io.BytesIO() as file_obj:
                mask_pred_image.save(file_obj, format="PNG")
                mask_for_download = file_obj.getvalue()
            st.download_button("Download predicted mask", data=mask_for_download, file_name="pred_mask.png", mime="image/png")

            # display a figure showing the interpretation of the mask labels
            show_mask_interpretation()
    return


def app_info():
    """Render the static page describing the project and the dataset."""
    st.title("App info")
    st.markdown("_Task - Oil Spill segmentation_")
    st.markdown("_Project repo - [https://github.com/AbhishekRS4/HTSM_Oil_Spill_Segmentation](https://github.com/AbhishekRS4/HTSM_Oil_Spill_Segmentation)_")
    st.markdown("_Dataset - [Oil Spill detection dataset](https://m4d.iti.gr/oil-spill-detection-dataset/)_")
    st.header("Brief description of the project and the dataset")
    st.write("The Oil Spill detection dataset contains images extracted from satellite Synthetic Aperture Radar (SAR) data.")
    st.write("This dataset contains labels for 5 classes --- sea_surface, oil_spill, oil_spill_look_alike, ship, and land.")
    st.write("A custom encoder-decoder architecture is modeled for the segmentation task.")
    st.write("The best performing model has been used for the deployed application.")
    return


# sidebar mode name -> page-rendering function
app_modes = {
    "App Info" : app_info,
    "Oil Spill Inference App": infer,
}


def start_app():
    """Render the page chosen in the sidebar mode selector."""
    selected_mode = st.sidebar.selectbox("Select mode", list(app_modes.keys()))
    app_modes[selected_mode]()
    return


def main():
    start_app()
    return


if __name__ == "__main__":
    main()