|
import os |
|
import io |
|
import numpy as np |
|
import streamlit as st |
|
|
|
import matplotlib as mpl |
|
import matplotlib.pyplot as plt |
|
from matplotlib.colors import ListedColormap |
|
|
|
import torch |
|
from PIL import Image |
|
from skimage.io import imread |
|
import torch.nn.functional as F |
|
|
|
from training.metrics import * |
|
from training.seg_models import * |
|
from training.image_preprocessing import ImagePadder |
|
from training.logger_utils import load_dict_from_json |
|
from training.dataset import get_dataloader_for_inference |
|
|
|
def run_inference(image_array, file_weights, num_classes=5, file_stats_json="training/image_stats.json"):
    """
    Segment a single SAR image and return the prediction as an RGB colour mask.

    The model is rebuilt and its weights reloaded on every call. The image is
    padded with ImagePadder, divided by 255, standardised with the mean/std
    read from ``file_stats_json``, and passed through the network. The
    per-pixel argmax class is mapped to a fixed colour, and a fixed border is
    cropped from the result.

    ---------
    Arguments
    ---------
    image_array : ndarray
        a numpy array of the image; after a batch axis is added it is
        transposed with axes (0, 3, 1, 2), so it must be (H, W, C)
    file_weights : str
        full path to weights file (a state_dict loaded into ResNet50DeepLabV3Plus)
    num_classes : int
        number of classes in the dataset; colours are defined for at most 5
    file_stats_json : str
        path to the json stats file for preprocessing (keys "mean" and "std");
        if it cannot be opened as given, it is retried relative to the
        directory containing this file

    -------
    Returns
    -------
    pred_mask_arr : ndarray
        uint8 array of shape (H_pad - 22, W_pad - 30, 3), where H_pad x W_pad
        is the padded image size; each pixel holds the colour of its
        predicted class

    ------
    Raises
    ------
    ValueError
        if num_classes exceeds the number of defined class colours
    """
    # Row i is the RGB colour for class label i
    # (0: sea_surface, 1: oil_spill, 2: oil_spill_look_alike, 3: ship, 4: land),
    # the same colours as the legend drawn by show_mask_interpretation.
    label_colors = np.array(
        [
            [0, 0, 0],
            [0, 255, 255],
            [255, 0, 0],
            [153, 76, 0],
            [0, 153, 0],
        ],
        dtype=np.uint8,
    )
    if num_classes > len(label_colors):
        raise ValueError(
            f"num_classes={num_classes} but only {len(label_colors)} class colours are defined"
        )

    # Border removed from each side of the padded prediction.
    # NOTE(review): presumably equals the padding ImagePadder.pad_image adds
    # (11 rows top/bottom, 15 columns left/right) — confirm against ImagePadder.
    crop_rows, crop_cols = 11, 15

    oil_spill_seg_model = ResNet50DeepLabV3Plus(num_classes=num_classes, pretrained=True)
    # NOTE(review): debugging setting that makes CUDA launches synchronous;
    # it may have no effect if CUDA was already initialised in this process.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    oil_spill_seg_model.to(device)
    oil_spill_seg_model.load_state_dict(torch.load(file_weights, map_location=device))
    oil_spill_seg_model.eval()

    try:
        dict_stats = load_dict_from_json(file_stats_json)
    except OSError:
        # Not readable relative to the working directory: retry relative to
        # the directory containing this file.
        dir_json = os.path.dirname(os.path.realpath(__file__))
        dict_stats = load_dict_from_json(os.path.join(dir_json, file_stats_json))

    try:
        image_padder = ImagePadder("/data/images")
    except Exception:
        # Dataset image directory unavailable: fall back to the local sample directory.
        image_padder = ImagePadder("./sample_padding_image_for_inference")

    image_padded = image_padder.pad_image(image_array)
    image_preprocessed = (image_padded / 255.0 - dict_stats["mean"]) / dict_stats["std"]
    # (H, W, C) -> (1, C, H, W): add a batch axis and move channels first.
    image_preprocessed = np.transpose(image_preprocessed[np.newaxis], (0, 3, 1, 2))
    image_tensor = torch.from_numpy(image_preprocessed).to(device, dtype=torch.float)

    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        pred_logits = oil_spill_seg_model(image_tensor)
    # Softmax is monotonic per pixel, so the argmax over logits is the same class.
    pred_label_arr = torch.argmax(pred_logits, dim=1)[0].cpu().numpy()

    # Palette lookup: (H, W) integer labels -> (H, W, 3) uint8 colours.
    pred_mask_arr = label_colors[pred_label_arr]

    padded_height, padded_width = pred_label_arr.shape
    return pred_mask_arr[crop_rows:padded_height - crop_rows, crop_cols:padded_width - crop_cols]
|
|
|
def show_mask_interpretation():
    """
    Render a colour legend for the predicted mask in the Streamlit page.

    Draws one row of swatches, one per class in label order, with the class
    names as x tick labels. The figure is closed after it is handed to
    Streamlit, because this function runs on every rerun and an unclosed
    pyplot figure would accumulate.
    """
    colors = ["#000000", "#00FFFF", "#FF0000", "#994C00", "#009900"]
    labels = ["sea_surface", "oil_spill", "oil_spill_look_alike", "ship", "land"]
    my_cmap = ListedColormap(colors, name="my_cmap")
    # Values 1..N are normalised across the colormap, so cell i gets colour i.
    data = [list(range(1, len(labels) + 1))]

    fig, ax = plt.subplots(figsize=(20, 2))
    ax.set_title("Oil Spill mask interpretation")
    ax.imshow(data, cmap=my_cmap)
    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels)
    ax.set_yticks([])
    st.pyplot(fig)
    # Release the figure so pyplot does not keep it alive across reruns.
    plt.close(fig)
    return
|
|
|
def infer():
    """
    Render the "Oil Spill Inference App" page.

    The sidebar collects three inputs:
    - a model weights file path;
    - an input SAR image (jpg/jpeg);
    - an optional groundtruth mask (png), shown only for visual comparison.

    When "Run inference" is pressed, the predicted colour mask is shown and
    offered as a PNG download. This only happens if an input image has been
    uploaded and the weights path points to an existing file; otherwise a
    message explains what is missing. A colour legend for the mask is always
    drawn at the bottom of the page.
    """
    st.title("Oil spill detection app")

    file_weights_default = "/data/models/oil_spill_seg_resnet_50_deeplab_v3+_80.pt"
    file_weights = st.sidebar.text_input("File model weights", file_weights_default)
    is_weights_file_valid = os.path.isfile(file_weights)
    if not is_weights_file_valid:
        st.write("Wrong weights file path")
    else:
        st.write(f"Weights file: {file_weights}")

    # Stays None until a file is uploaded, so the inference step can check it
    # instead of referencing an unbound name.
    image_array = None
    image_file_buffer = st.sidebar.file_uploader("Select input SAR image", type=["jpg", "jpeg"])
    if image_file_buffer is not None:
        image = Image.open(image_file_buffer)
        image_array = np.array(image)
        st.image(image_array, caption=f"Input image: {image_file_buffer.name}")
    else:
        st.write("Input image: not selected")

    mask_file_buffer = st.sidebar.file_uploader("Select groundtruth mask image (optional, only for visual comparison with the prediction)", type=["png"])
    if mask_file_buffer is not None:
        mask = Image.open(mask_file_buffer)
        mask_array = np.array(mask)
        st.image(mask_array, caption=f"Mask image: {mask_file_buffer.name}")
    else:
        st.write("Groundtruth mask image (optional): not selected")

    infer_button = st.sidebar.button("Run inference")
    if infer_button:
        if image_array is None:
            st.write("Cannot run inference: select an input SAR image first")
        elif not is_weights_file_valid:
            st.write("Cannot run inference: provide a valid weights file path first")
        else:
            mask_predicted = run_inference(image_array, file_weights)
            st.image(mask_predicted, caption=f"Predicted mask for the input: {image_file_buffer.name}")

            # Encode the prediction as PNG in memory for the download button.
            mask_pred_image = Image.fromarray(mask_predicted.astype("uint8"), "RGB")
            with io.BytesIO() as file_obj:
                mask_pred_image.save(file_obj, format="PNG")
                mask_for_download = file_obj.getvalue()
            st.download_button("Download predicted mask", data=mask_for_download, file_name="pred_mask.png", mime="image/png")

    show_mask_interpretation()
    return
|
|
|
def app_info():
    """
    Render the static "App info" page.

    Shows the task and links to the project repository and dataset, followed
    by a short description of the dataset and the deployed model.
    """
    st.title("App info")

    intro_lines = (
        "_Task - Oil Spill segmentation_",
        "_Project repo - [https://github.com/AbhishekRS4/HTSM_Oil_Spill_Segmentation](https://github.com/AbhishekRS4/HTSM_Oil_Spill_Segmentation)_",
        "_Dataset - [Oil Spill detection dataset](https://m4d.iti.gr/oil-spill-detection-dataset/)_",
    )
    for text in intro_lines:
        st.markdown(text)

    st.header("Brief description of the project and the dataset")

    description = (
        "The Oil Spill detection dataset contains images extracted from satellite Synthetic Aperture Radar (SAR) data.",
        "This dataset contains labels for 5 classes --- sea_surface, oil_spill, oil_spill_look_alike, ship, and land.",
        "A custom encoder-decoder architecture is modeled for the segmentation task.",
        "The best performing model has been used for the deployed application.",
    )
    for paragraph in description:
        st.write(paragraph)
|
|
|
# Sidebar page name -> function that renders the page. Insertion order is the
# order in which start_app offers the pages in its selectbox.
app_modes = dict(
    [
        ("App Info", app_info),
        ("Oil Spill Inference App", infer),
    ]
)
|
|
|
def start_app():
    """Let the user pick a page from the sidebar selectbox and render it."""
    page_names = list(app_modes)
    chosen_page = st.sidebar.selectbox("Select mode", page_names)
    render_page = app_modes[chosen_page]
    render_page()
|
|
|
def main():
    """Script entry point: hand control to the Streamlit page router."""
    start_app()


if __name__ == "__main__":
    # Launch only when executed as a script, not when imported.
    main()
|
|