Spaces:
Sleeping
Sleeping
| import onnx | |
| import onnxruntime as rt | |
| from PIL import Image | |
| import numpy as np | |
| from patchify import patchify, unpatchify | |
| import torch | |
| import gradio as gr | |
# Palette for the segmentation output: each RGB color maps to
# [class index, human-readable class name].
mapping = {
    (155, 155, 155): [0, 'Unlabeled'],
    (60, 16, 152): [1, 'Building'],
    (132, 41, 246): [2, 'Land'],
    (110, 193, 228): [3, 'Road'],
    (254, 221, 58): [4, 'Vegetation'],
    (226, 169, 41): [5, 'Water'],
}
def get_nearest_multiple(target, divisor):
    """Return ``target / divisor`` rounded to the nearest integer.

    Despite the name, the result is the multiplier rather than the multiple
    itself; multiply it by ``divisor`` to obtain the nearest multiple.
    Rounding uses the built-in ``round``, so exact halves go to the even
    neighbor (e.g. 2.5 -> 2).
    """
    quotient = target / divisor
    return round(quotient)
def resize_img(img: Image.Image, patch_size: int):
    """Resize an image so that both sides are whole multiples of ``patch_size``.

    Each side is snapped to the nearest multiple of ``patch_size`` but never
    to fewer than one patch, so images much smaller than a patch are scaled
    up to a single patch instead of producing a zero-sized target.

    Args:
        img: Image to resize.
        patch_size: Side length, in pixels, of the square patches.

    Returns:
        A ``(resized_img, max_h, max_w)`` tuple: the resized image followed
        by its height and width.
    """
    width, height = img.size
    # get_nearest_multiple rounds to the nearest patch count, which is 0 for a
    # side shorter than half a patch; a zero-length side cannot hold a patch,
    # so keep at least one patch per dimension.
    max_w = max(1, get_nearest_multiple(width, patch_size)) * patch_size
    max_h = max(1, get_nearest_multiple(height, patch_size)) * patch_size
    img = img.resize((max_w, max_h), Image.Resampling.BILINEAR)
    return img, max_h, max_w
def preprocess_patches(patches, patch_size):
    """Turn channels-last image patches into a scaled channels-first batch.

    Args:
        patches: NumPy array of RGB patches whose trailing axes are
            ``(patch_size, patch_size, 3)``; any leading axes (for example a
            patch grid) are flattened into a single batch axis.
        patch_size: Side length, in pixels, of the square patches.

    Returns:
        A ``torch.Tensor`` of shape
        ``(n_patches, 3, patch_size, patch_size)`` with values divided by 255.
    """
    # Flatten the leading axes while keeping the channels-last layout intact.
    patches = patches.reshape(-1, patch_size, patch_size, 3)
    # Move the channel axis explicitly. Reshaping straight to (-1, 3, H, W)
    # would reinterpret the interleaved RGB values as channel planes and
    # scramble the pixels instead of transposing them.
    patches = patches.transpose(0, 3, 1, 2)
    patches = torch.tensor(patches) / 255.
    return patches
def seg_maps_to_rgb(maps):
    """Colorize integer class maps with the module-level ``mapping`` palette.

    Args:
        maps: Integer array of class indices, e.g.
            ``n_patches x patch_size x patch_size``.

    Returns:
        A ``uint8`` array shaped like ``maps`` with a trailing axis of
        length 3 holding the RGB color of each class. Class indices that are
        not present in ``mapping`` stay black (all zeros).
    """
    rgb_maps = np.zeros((*maps.shape, 3), dtype=np.uint8)
    for color, entry in mapping.items():
        class_id = entry[0]
        # Paint every pixel of this class across all patches in one go.
        rgb_maps[maps == class_id] = color
    return rgb_maps
def make_prediction(img: Image.Image, patch_size=224):
    """Segment an image patch by patch and return a color-coded mask.

    The image is resized to a whole number of patches, split into
    non-overlapping ``patch_size`` x ``patch_size`` tiles, each tile is passed
    through the module-level ``model`` callable, and the per-tile class maps
    are colorized, stitched back together and resized to the input size.

    Args:
        img: Input image. Non-RGB images are converted to RGB first.
        patch_size: Side length, in pixels, of the square tiles.
            NOTE(review): the ONNX input name used by ``model`` suggests a
            fixed 224x224 input, so other values may be rejected - confirm.

    Returns:
        A PIL image of the same size as ``img`` holding the RGB class mask.
    """
    # patchify below cuts (patch_size, patch_size, 3) windows, so the pixel
    # array must have exactly three channels; normalize RGBA, grayscale and
    # palette images up front.
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size
    resized_img, max_h, max_w = resize_img(img, patch_size)
    # patchifying and inference on each patch
    patches = patchify(np.array(resized_img), (patch_size, patch_size, 3), step=patch_size)
    # remember the patch-grid layout so the colored maps can be stitched back
    patches_shape = patches.shape
    patches = preprocess_patches(patches, patch_size)  # shape: n_patches x n_channels x patch_size x patch_size
    # the model is fed one patch at a time, with a leading batch axis of 1
    maps = [model(patch[None]) for patch in patches]
    maps = np.concatenate(maps)
    maps = np.argmax(maps, axis=1)  # shape: n_patches x patch_size x patch_size
    # convert 2d class maps to rgb
    rgb_maps = seg_maps_to_rgb(maps)
    # reshaping back the maps to the shape of patches
    rgb_maps = rgb_maps.reshape(patches_shape)
    final_mask = unpatchify(rgb_maps, (max_h, max_w, 3))
    # resizing back the mask to its original size; NEAREST keeps the palette
    # colors exact instead of blending neighboring classes
    final_mask = Image.fromarray(final_mask).resize((w, h), Image.Resampling.NEAREST)
    return final_mask
# Load the exported ONNX segmentation network once at start-up.
session = rt.InferenceSession('./satellite_seg_model_resnet50.onnx')


def model(patches):
    """Run the ONNX session on a tensor and return its 'seg_maps' output."""
    feeds = {'image_1_3_224_224': patches.numpy()}
    outputs = session.run(['seg_maps'], feeds)
    return outputs[0]


# Gradio UI: one image in, the predicted segmentation mask out.
demo = gr.Interface(
    fn=make_prediction,
    inputs=[gr.Image(type='pil')],
    outputs="image",
    title="Ariel Imaging Segmentation",
    description="Please upload an image to see segmentation capabilities of this model",
    examples=[["img.jpg"]],
)
demo.launch()