|
import os |
|
import numpy as np |
|
import streamlit as st |
|
|
|
import torch |
|
from PIL import Image |
|
from skimage.io import imread |
|
import torch.nn.functional as F |
|
|
|
from training.metrics import * |
|
from training.seg_models import * |
|
from training.image_preprocessing import ImagePadder |
|
from training.logger_utils import load_dict_from_json |
|
from training.dataset import get_dataloader_for_inference |
|
|
|
# RGB colour used to draw each predicted class; row i is the colour of class i.
_LABEL_COLORS = np.array(
    [
        [0, 0, 0],
        [0, 255, 255],
        [255, 0, 0],
        [153, 76, 0],
        [0, 153, 0],
    ],
    dtype=np.uint8,
)

# Rows / columns trimmed from each border of the prediction to undo padding.
# NOTE(review): assumes ImagePadder adds exactly 11 rows top/bottom and 15
# columns left/right (e.g. 650x1250 -> 672x1280) — confirm against ImagePadder.
_CROP_ROWS = 11
_CROP_COLS = 15


def _load_model(file_weights, num_classes, device):
    """Build the DeepLabV3+ network, load *file_weights* onto *device*, set eval mode."""
    # NOTE(review): pretrained backbone weights are overwritten by
    # load_state_dict below; pretrained=False may avoid a download — confirm.
    model = ResNet50DeepLabV3Plus(num_classes=num_classes, pretrained=True)
    model.to(device)
    # torch.load unpickles the file: only point it at trusted weight files.
    model.load_state_dict(torch.load(file_weights, map_location=device))
    model.eval()
    return model


def _load_stats(file_stats_json):
    """Load normalisation stats, retrying relative to this module's directory."""
    try:
        return load_dict_from_json(file_stats_json)
    except (OSError, ValueError):
        # The default path is relative; when the app is launched from another
        # working directory, resolve it next to this source file instead.
        dir_json = os.path.dirname(os.path.realpath(__file__))
        return load_dict_from_json(os.path.join(dir_json, file_stats_json))


def _make_padder():
    """Create an ImagePadder, preferring /data/images over the bundled sample."""
    try:
        return ImagePadder("/data/images")
    except Exception:  # ImagePadder is opaque here; any failure -> use fallback dir
        return ImagePadder("./sample_padding_image_for_inference")


def _preprocess(image_padded, dict_stats):
    """Scale to [0, 1], standardise with *dict_stats*, return a 1xCxHxW float tensor."""
    image = image_padded / 255.0
    image = (image - dict_stats["mean"]) / dict_stats["std"]
    # HWC -> NCHW with a batch dimension of 1.
    image = np.transpose(image[np.newaxis], (0, 3, 1, 2))
    return torch.tensor(image).float()


def run_inference(image_array, file_weights, num_classes=5, file_stats_json="training/image_stats.json"):
    """Segment one image and return a colour-coded prediction mask.

    Args:
        image_array: Input image as a numpy array. Presumably HxWx3 (the
            padded image is transposed with axes (0, 3, 1, 2)) — confirm.
        file_weights: Path to a state_dict for ResNet50DeepLabV3Plus.
        num_classes: Number of output classes; at most len(_LABEL_COLORS).
        file_stats_json: JSON with "mean" and "std" used for normalisation.
            Tried as given, then relative to this module's directory.

    Returns:
        uint8 array of shape (H_pad - 22, W_pad - 30, 3), where H_pad/W_pad
        are the padded image's height/width; each pixel is its class colour.

    Raises:
        ValueError: if num_classes exceeds the number of defined colours.
    """
    if num_classes > len(_LABEL_COLORS):
        raise ValueError(
            f"num_classes={num_classes} but only {len(_LABEL_COLORS)} class colours are defined"
        )

    # Synchronous CUDA kernel launches (debugging aid kept from the original).
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(file_weights, num_classes, device)

    dict_stats = _load_stats(file_stats_json)
    image_padded = _make_padder().pad_image(image_array)
    image_tensor = _preprocess(image_padded, dict_stats).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        pred_logits = model(image_tensor)

    # Softmax is monotonic, so argmax over raw logits picks the same class.
    # Index batch 0 instead of np.squeeze so tiny images keep their 2-D shape.
    pred_label_arr = torch.argmax(pred_logits, dim=1)[0].cpu().numpy()

    padded_height, padded_width = pred_label_arr.shape
    # Palette lookup: each label index selects its RGB row -> (H, W, 3) uint8.
    pred_mask_arr = _LABEL_COLORS[pred_label_arr]

    return pred_mask_arr[
        _CROP_ROWS:padded_height - _CROP_ROWS,
        _CROP_COLS:padded_width - _CROP_COLS,
    ]
|
|
|
def infer():
    """Render the Streamlit page for oil-spill segmentation of a single image.

    The sidebar collects a model-weights path, an input SAR image (JPEG) and
    an optional mask image (PNG, displayed only). When "Run inference" is
    clicked and both a valid weights file and an input image are present,
    run_inference is called and the predicted colour mask is shown.
    """
    st.title("Oil spill detection app")

    file_weights_default = "/data/models/oil_spill_seg_resnet_50_deeplab_v3+_80.pt"
    file_weights = st.sidebar.text_input("File model weights", file_weights_default)

    weights_ok = os.path.isfile(file_weights)
    if weights_ok:
        st.write(f"Weights file: {file_weights}")
    else:
        st.write("Wrong weights file path")

    # Stays None until an image is uploaded; checked before running inference.
    image_array = None
    image_file_buffer = st.sidebar.file_uploader("Select input SAR image", type=["jpg", "jpeg"])
    if image_file_buffer is not None:
        # run_inference transposes a 3-channel HxWxC array, so normalise
        # grayscale / CMYK JPEGs to RGB (a no-op for RGB images).
        image = Image.open(image_file_buffer).convert("RGB")
        image_array = np.array(image)
        st.image(image_array, caption=f"Input image: {image_file_buffer.name}")
    else:
        st.write("Input image: not selected")

    mask_file_buffer = st.sidebar.file_uploader("Select mask image", type=["png"])
    if mask_file_buffer is not None:
        mask = Image.open(mask_file_buffer)
        mask_array = np.array(mask)
        st.image(mask_array, caption=f"Mask image: {mask_file_buffer.name}")
    else:
        st.write("Mask image: not selected")

    infer_button = st.sidebar.button("Run inference")
    if infer_button:
        # Previously an unselected image raised UnboundLocalError here.
        if image_array is None:
            st.error("Select an input SAR image before running inference")
        elif not weights_ok:
            st.error("Cannot run inference: wrong weights file path")
        else:
            mask_predicted = run_inference(image_array, file_weights)
            st.image(mask_predicted, caption=f"Predicted mask for the input: {image_file_buffer.name}")
|
|
|
def main():
    """Program entry point: draw the inference page."""
    infer()
|
|
|
# Build the page when this file is executed as a script.
# NOTE(review): presumably launched via `streamlit run`, which executes the
# file as __main__ — confirm.
if __name__ == "__main__":
    main()
|
|