import gradio as gr
import numpy as np
import onnxruntime
import torch
from torch.nn import functional as F

from transforms import ResizeLongestSide

IMAGE_SIZE = 1024
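
# Preprocess an image for the SAM image encoder: resize the longest side to
# IMAGE_SIZE, normalize with SAM's pixel mean/std, and zero-pad to a square
# (1, 3, IMAGE_SIZE, IMAGE_SIZE) float32 array.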
def preprocess_image(image):
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device="cpu")
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    x = (input_image_torch - pixel_mean) / pixel_std

    h, w = x.shape[-2:]
    padh = IMAGE_SIZE - h
    padw = IMAGE_SIZE - w
    x = F.pad(x, (0, padw, 0, padh))
    x = x.numpy()
    return x
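
# Assemble the inputs the SAM ONNX decoder expects: the prompt point as a
# foreground point (label 1) plus a padding point (label -1), an empty low-res
# mask prompt, and the original image size so the mask can be upscaled back to it.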
def prepare_inputs(image_embedding, input_point, image_shape):
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_label = np.array([1])
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = transform.apply_coords(onnx_coord, image_shape).astype(np.float32)
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)
    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image_shape, dtype=np.float32),
    }
    return decoder_inputs
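
# ONNX Runtime sessions for the quantized SAM encoder and decoder
# (exported and quantized ahead of time).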
enc_session = onnxruntime.InferenceSession("encoder-quant.onnx")
dec_session = onnxruntime.InferenceSession("decoder-quant.onnx")
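
# Encode the uploaded image once, then decode a mask for the image centre.
# The embedding and original shape are also returned so later clicks can reuse
# them without re-running the encoder.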
def predict_image(img):
    x = preprocess_image(img)
    encoder_inputs = {
        "x": x,
    }
    output = enc_session.run(None, encoder_inputs)
    image_embedding = output[0]

    middle_of_photo = np.array([[img.shape[1] / 2, img.shape[0] / 2]])
    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, img.shape[:2])
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # normalize the results between 0 and 1
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks, image_embedding, img.shape[:2]
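
# Decode a new mask for the pixel the user clicked, reusing the cached embedding.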
def segment_image(image_embedding, shape, evt: gr.SelectData):
    image_embedding = np.array(image_embedding)
    clicked_point = np.array([evt.index])
    decoder_inputs = prepare_inputs(image_embedding, clicked_point, shape)
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # normalize the results between 0 and 1
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks
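
# Gradio UI: upload an image and press "Segment" for an initial mask, then click
# anywhere on the mask image to re-segment around that point.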
with gr.Blocks() as demo:
    gr.Markdown("# SAM quantized (Segment Anything Model)")
    markdown = """
This is a demo of SAM (Segment Anything Model), which segments anything in an image.
It returns the segmentation mask that overlaps with the clicked point.
The model is quantized with ONNX Runtime.
"""
    gr.Markdown(markdown)

    embedding = gr.State()
    shape = gr.State()

    with gr.Row():
        with gr.Column():
            inputs = gr.Image()
            start_segmentation = gr.Button("Segment")
        with gr.Column():
            outputs = gr.Image(label="Segmentation Mask")

    start_segmentation.click(
        predict_image,
        inputs,
        [outputs, embedding, shape],
    )
    outputs.select(
        segment_image,
        [embedding, shape],
        outputs,
    )

demo.launch()