Spaces:
Sleeping
Sleeping
import onnx | |
import onnxruntime as rt | |
from PIL import Image | |
import numpy as np | |
from patchify import patchify, unpatchify | |
import torch | |
import gradio as gr | |
# Segmentation palette: RGB colour -> [class index, class name].
# The class index is the position of the row in the table below.
mapping = {
    colour: [class_idx, label]
    for class_idx, (colour, label) in enumerate((
        ((155, 155, 155), 'Unlabeled'),
        ((60, 16, 152), 'Building'),
        ((132, 41, 246), 'Land'),
        ((110, 193, 228), 'Road'),
        ((254, 221, 58), 'Vegetation'),
        ((226, 169, 41), 'Water'),
    ))
}
def get_nearest_multiple(target, divisor):
    """Return how many `divisor`-sized steps best approximate `target`.

    Note that this returns the *count* of steps, not the multiple itself;
    multiply the result by `divisor` to obtain the nearest multiple.

    The count is clamped to a minimum of 1: a `target` smaller than half of
    `divisor` would otherwise round down to 0, and a caller that scales the
    result back up would end with a zero-sized dimension.
    """
    return max(1, round(target / divisor))
def resize_img(img: Image.Image, patch_size: int):
    """Resize `img` so that both of its sides are multiples of `patch_size`.

    Returns a tuple `(resized image, new height, new width)`.
    """
    # img.size is (width, height); snap each side to a multiple of patch_size
    max_w, max_h = (
        get_nearest_multiple(side, patch_size) * patch_size
        for side in img.size
    )
    resized = img.resize((max_w, max_h), Image.Resampling.BILINEAR)
    return resized, max_h, max_w
def preprocess_patches(patches, patch_size):
    """Turn channel-last image patches into a model-ready float batch.

    `patches` is a numpy array whose trailing axes are
    (patch_size, patch_size, 3); any leading axes are flattened into a single
    batch axis. The result is a float tensor of shape
    n_patches x 3 x patch_size x patch_size with values scaled by 1/255.
    """
    # Flatten the leading axes first while the data is still channel-last.
    patches = np.asarray(patches).reshape(-1, patch_size, patch_size, 3)
    # Move the channel axis explicitly: reshaping H x W x C data straight into
    # a C x H x W shape only reinterprets memory and scrambles the pixels.
    patches = np.ascontiguousarray(patches.transpose(0, 3, 1, 2))
    patches = torch.tensor(patches) / 255.
    return patches
def seg_maps_to_rgb(maps):
    """Colour class-index maps using the module-level `mapping` palette.

    `maps` is a numpy array of class indices; the result has the same shape
    plus a trailing axis of 3 (uint8 RGB). Indices that are not present in
    `mapping` stay black.
    """
    rgb_maps = np.zeros((*maps.shape, 3), dtype=np.uint8)  # shape: n_patches x patch_size x patch_size x 3
    for colour, class_info in mapping.items():
        # class_info is [class index, class name]; paint every matching pixel
        rgb_maps[maps == class_info[0]] = colour
    return rgb_maps
def make_prediction(img: Image.Image, patch_size = 224):
    """Segment `img` patch by patch and return a colour-coded class mask.

    The image is resized so both sides are multiples of `patch_size`, split
    into non-overlapping patches, run through the module-level `model` one
    patch at a time, and the per-patch class maps are coloured, stitched back
    together and resized to the original image size.

    Returns a PIL image of the same size as the input.
    """
    # Force three channels: grayscale / palette images decode to 2-D arrays,
    # which cannot be split with the 3-channel patch shape used below.
    img = img.convert("RGB")
    w, h = img.size
    resized_img, max_h, max_w = resize_img(img, patch_size)

    # patchifying and inference on each patch
    patches = patchify(np.array(resized_img), (patch_size, patch_size, 3), step=patch_size)
    patches_shape = patches.shape
    patches = preprocess_patches(patches, patch_size)  # shape: n_patches x n_channels x patch_size x patch_size
    # patch[None] adds the batch axis; the model is called once per patch
    maps = [model(patch[None]) for patch in patches]
    maps = np.concatenate(maps)
    maps = np.argmax(maps, axis=1)  # shape: n_patches x patch_size x patch_size

    # convert 2d class maps to rgb
    rgb_maps = seg_maps_to_rgb(maps)

    # reshaping back the maps to the shape of the patches
    rgb_maps = rgb_maps.reshape(patches_shape)
    final_mask = unpatchify(rgb_maps, (max_h, max_w, 3))

    # resizing back the mask to its original size; NEAREST keeps the palette
    # colours exact instead of blending them at class borders
    final_mask = Image.fromarray(final_mask).resize((w, h), Image.Resampling.NEAREST)
    return final_mask
# ONNX Runtime session for the segmentation network, loaded once at import.
session = rt.InferenceSession('./satellite_seg_model_resnet50.onnx')


def model(patches):
    """Run the ONNX session on a tensor batch and return its 'seg_maps' output."""
    feeds = {'image_1_3_224_224': patches.numpy()}
    outputs = session.run(['seg_maps'], feeds)
    return outputs[0]
# Gradio UI: one PIL image in, the colour-coded segmentation mask out.
demo = gr.Interface(
    make_prediction,
    inputs=[gr.Image(type='pil')],
    outputs="image",
    # "Aerial", not "Ariel": the title is shown to users of the demo
    title="Aerial Imaging Segmentation",
    description="Please upload an image to see segmentation capabilities of this model",
    examples=[["img.jpg"]],
)
demo.launch()