import gradio as gr
import numpy as np
import onnxruntime
import torch
from torch.nn import functional as F

from transforms import ResizeLongestSide

IMAGE_SIZE = 1024


def preprocess_image(image):
    """Resize, normalize and pad an image to the fixed input size expected by the SAM encoder."""
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device="cpu")
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    # Normalize with the pixel mean/std used by SAM.
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    x = (input_image_torch - pixel_mean) / pixel_std

    # Pad the shorter side so the tensor is IMAGE_SIZE x IMAGE_SIZE.
    h, w = x.shape[-2:]
    padh = IMAGE_SIZE - h
    padw = IMAGE_SIZE - w
    x = F.pad(x, (0, padw, 0, padh))
    return x.numpy()


def prepare_inputs(image_embedding, input_point, image_shape):
    """Build the input dictionary for the SAM decoder from an image embedding and a single prompt point."""
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_label = np.array([1])

    # Append the padding point/label expected by the exported ONNX decoder.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = transform.apply_coords(onnx_coord, image_shape).astype(np.float32)

    # No previous mask is supplied, so pass an empty mask input.
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image_shape, dtype=np.float32),
    }
    return decoder_inputs


enc_session = onnxruntime.InferenceSession("encoder-quant.onnx")
dec_session = onnxruntime.InferenceSession("decoder-quant.onnx")


def predict_image(img):
    """Run the encoder once, then decode a mask prompted by the center of the image."""
    x = preprocess_image(img)
    encoder_inputs = {"x": x}
    output = enc_session.run(None, encoder_inputs)
    image_embedding = output[0]

    # Use the center of the photo as the initial prompt point.
    middle_of_photo = np.array([[img.shape[1] / 2, img.shape[0] / 2]])
    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, img.shape[:2])
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # Clamp negative logits to zero and normalize the mask to [0, 1] for display.
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks, image_embedding, img.shape[:2]


def segment_image(image_embedding, shape, evt: gr.SelectData):
    """Decode a new mask for the clicked point, reusing the cached image embedding."""
    image_embedding = np.array(image_embedding)
    clicked_point = np.array([evt.index])
    decoder_inputs = prepare_inputs(image_embedding, clicked_point, shape)
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # Clamp negative logits to zero and normalize the mask to [0, 1] for display.
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks


with gr.Blocks() as demo:
    gr.Markdown("# SAM quantized (Segment Anything Model)")
    markdown = """
    This is a demo of SAM, a model for segmenting anything in an image.
    It returns the segmentation mask of the region in the image that overlaps with the clicked point.
    The model is quantized using ONNX Runtime.
    """
    gr.Markdown(markdown)

    # Cache the image embedding and original image shape between clicks.
    embedding = gr.State()
    shape = gr.State()

    with gr.Row():
        with gr.Column():
            inputs = gr.Image()
            start_segmentation = gr.Button("Segment")
        with gr.Column():
            outputs = gr.Image(label="Segmentation Mask")

    start_segmentation.click(
        predict_image,
        inputs,
        [outputs, embedding, shape],
    )
    outputs.select(
        segment_image,
        [embedding, shape],
        outputs,
    )

demo.launch()
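
# Note: this demo assumes "encoder-quant.onnx" and "decoder-quant.onnx" already exist
# next to this script. As a minimal sketch (not part of the demo itself), SAM encoder
# and decoder ONNX exports could be dynamically quantized with ONNX Runtime roughly
# like this, assuming float32 "encoder.onnx" and "decoder.onnx" files are available
# (the float32 file names here are illustrative):
#
#   from onnxruntime.quantization import QuantType, quantize_dynamic
#   quantize_dynamic("encoder.onnx", "encoder-quant.onnx", weight_type=QuantType.QUInt8)
#   quantize_dynamic("decoder.onnx", "decoder-quant.onnx", weight_type=QuantType.QUInt8)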