ariel-imaging / app.py
Hazem-Ahmed-Abdelraouf
added app files
1a0f92e
import onnx
import onnxruntime as rt
from PIL import Image
import numpy as np
from patchify import patchify, unpatchify
import torch
import gradio as gr
# Colour palette for the segmentation classes.
# Each key is an (R, G, B) colour tuple; each value is [class index, class label].
# seg_maps_to_rgb paints every pixel whose predicted class index equals the
# first element of a value with the corresponding key colour.
mapping = {(155, 155, 155): [0, 'Unlabeled'],
(60, 16, 152): [1, 'Building'],
(132, 41, 246): [2, 'Land'],
(110, 193, 228): [3, 'Road'],
(254, 221, 58): [4, 'Vegetation'],
(226, 169, 41): [5, 'Water']}
def get_nearest_multiple(target, divisor):
    """Return how many `divisor`-sized units best approximate `target`.

    The result is the rounded quotient ``target / divisor``, clamped to a
    minimum of 1. Callers multiply it by `divisor` to obtain a dimension, so
    a count of 0 (which plain rounding yields whenever `target` is at most
    half of `divisor`) would produce a zero-sized dimension.

    Args:
        target: The length to approximate (e.g. an image width in pixels).
        divisor: The unit size (e.g. the patch size). Must be non-zero.

    Returns:
        An int >= 1: the number of `divisor` units nearest to `target`.
    """
    # Never go below one unit: a zero count would collapse the dimension.
    return max(1, round(target / divisor))
def resize_img(img:Image.Image, patch_size: int):
    """Resize an image so that both sides are whole multiples of the patch size.

    Each side is snapped to the multiple of `patch_size` chosen by
    get_nearest_multiple, and the image is resampled bilinearly.

    Returns:
        A tuple ``(resized_image, new_height, new_width)`` -- note that the
        height comes before the width.
    """
    src_w, src_h = img.size  # PIL reports size as (width, height)
    new_w = patch_size * get_nearest_multiple(src_w, patch_size)
    new_h = patch_size * get_nearest_multiple(src_h, patch_size)
    resized = img.resize((new_w, new_h), Image.Resampling.BILINEAR)
    return resized, new_h, new_w
def preprocess_patches(patches, patch_size):
    """Turn channels-last image patches into a normalized channels-first tensor.

    Args:
        patches: Array of RGB patches whose trailing axes are
            ``(patch_size, patch_size, 3)``; any leading grid axes are
            flattened into a single batch axis.
        patch_size: Side length of each square patch.

    Returns:
        A torch tensor of shape ``(n_patches, 3, patch_size, patch_size)``
        with values divided by 255.
    """
    # Collapse the leading grid axes while keeping every patch channels-last.
    flat = np.asarray(patches).reshape(-1, patch_size, patch_size, 3)
    # Move the channel axis forward with a transpose. Reshaping straight to
    # (-1, 3, H, W) would only reinterpret the buffer and interleave the
    # colour values across the spatial positions.
    channels_first = np.ascontiguousarray(flat.transpose(0, 3, 1, 2))
    return torch.tensor(channels_first) / 255.
def seg_maps_to_rgb(maps):
    """Colourize class-index maps using the module-level `mapping` palette.

    Args:
        maps: Integer array of class indices, one 2-D map per patch.

    Returns:
        A uint8 array shaped ``maps.shape + (3,)``. Pixels whose class index
        is not present in `mapping` stay black (all zeros).
    """
    colored = np.zeros(maps.shape + (3,), np.uint8)  # n_patches x patch_size x patch_size x 3
    for color, (class_id, _label) in mapping.items():
        # One boolean mask across all patches selects every pixel of this class.
        colored[maps == class_id] = color
    return colored
def make_prediction(img: Image.Image, patch_size = 224):
    """Segment an image patch by patch and return the colour-coded mask.

    The image is resized to a whole number of patches, cut into
    non-overlapping tiles, each tile is passed through the module-level
    `model` callable on its own, and the per-tile class maps are colourized
    and stitched back together.

    Args:
        img: Input PIL image.
        patch_size: Side length of the square tiles fed to the model.

    Returns:
        A PIL image of the same size as `img` holding the segmentation mask.
    """
    orig_w, orig_h = img.size
    resized, grid_h, grid_w = resize_img(img, patch_size)

    # Cut the resized image into non-overlapping tiles.
    tiles = patchify(np.array(resized), (patch_size, patch_size, 3), step=patch_size)
    tiles_shape = tiles.shape  # remembered so the masks can be regridded later
    batch = preprocess_patches(tiles, patch_size)  # n_patches x n_channels x patch_size x patch_size

    # Run the model one tile at a time, each with a leading batch axis of 1.
    outputs = [model(tile[None]) for tile in batch]
    scores = np.concatenate(outputs)
    class_maps = np.argmax(scores, axis=1)  # n_patches x patch_size x patch_size

    # Convert the 2-D class maps to RGB and restore the patch-grid layout.
    colored = seg_maps_to_rgb(class_maps).reshape(tiles_shape)
    stitched = unpatchify(colored, (grid_h, grid_w, 3))

    # Resize the mask back to the original size; nearest-neighbour resampling
    # does not blend the palette colours.
    return Image.fromarray(stitched).resize((orig_w, orig_h), Image.Resampling.NEAREST)
# ONNX Runtime session created once at import time and shared by every request.
session = rt.InferenceSession('./satellite_seg_model_resnet50.onnx')


def model(patches):
    """Run the ONNX model on a tensor of patches and return its 'seg_maps' output.

    `patches` must expose ``.numpy()`` (e.g. a CPU torch tensor); the resulting
    array is fed to the session input named 'image_1_3_224_224'.
    """
    feed = {'image_1_3_224_224': patches.numpy()}
    results = session.run(['seg_maps'], feed)
    return results[0]
# Gradio UI: one PIL image in, the predicted segmentation mask image out.
demo = gr.Interface(
    fn=make_prediction,
    inputs=[gr.Image(type='pil')],
    outputs="image",
    examples=[["img.jpg"]],
    title="Ariel Imaging Segmentation",
    description="Please upload an image to see segmentation capabilities of this model",
)
demo.launch()