"""Gradio demo: automatic mask generation with Segment Anything (ViT-B).

At import time the ViT-B checkpoint is downloaded into ``./ckpts`` (if not
already present), a SAM automatic mask generator is built, and a Gradio UI is
launched that overlays every generated mask on the uploaded image.
"""

import io
import os
import urllib.request  # plain `import urllib` does not guarantee `urllib.request` is loaded

import cv2  # noqa: F401  -- no longer used for decoding the output (it returned BGR); import kept
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

CKPT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

# download model weights
ckpts_dir = os.path.join(os.getcwd(), "ckpts")
os.makedirs(ckpts_dir, exist_ok=True)
ckpt_path = os.path.join(ckpts_dir, "sam_vit_b_01ec64.pth")
if not os.path.exists(ckpt_path):
    # Download under a temporary name and rename only once complete, so an
    # interrupted download cannot leave a truncated checkpoint that the
    # existence check above would then accept on the next start.
    tmp_path = ckpt_path + ".part"
    urllib.request.urlretrieve(CKPT_URL, filename=tmp_path)
    os.replace(tmp_path, ckpt_path)

# setup model
sam = sam_model_registry["vit_b"](checkpoint=ckpt_path)
mask_generator = SamAutomaticMaskGenerator(sam)


# adapted from: https://github.com/facebookresearch/segment-anything
def show_anns(anns, ax=None):
    """Overlay each annotation's mask on an axes in a random translucent colour.

    Args:
        anns: List of mask records as returned by
            ``SamAutomaticMaskGenerator.generate``; the ``'segmentation'``
            (2-D mask array) and ``'area'`` keys of each record are read.
        ax: Matplotlib axes to draw on. Defaults to ``plt.gca()`` so existing
            callers relying on the current pyplot axes keep working.
    """
    if not anns:
        return
    if ax is None:
        ax = plt.gca()
    # Draw the largest masks first so smaller ones remain visible on top.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    # Freeze the limits set by the underlying image; overlays must not rescale.
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.random.random(3)
        # RGBA overlay: constant random colour, alpha 0.35 inside the mask, 0 outside.
        overlay = np.empty((m.shape[0], m.shape[1], 4))
        overlay[..., :3] = color_mask
        overlay[..., 3] = m * 0.35
        ax.imshow(overlay)


# demo function
def segment_image(input_image):
    """Generate SAM masks for an image and return it with the masks overlaid.

    Args:
        input_image: Image array from the Gradio ``Image`` component, or
            ``None`` when the button is pressed without an upload. Presumably
            HxWx3 uint8 RGB, as SAM's generator expects -- TODO confirm for
            unusual uploads.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, or ``None`` if no input was given.
    """
    if input_image is None:
        return None

    # generate masks
    masks = mask_generator.generate(input_image)

    # Size the figure so one image pixel maps to roughly one output pixel.
    dpi = 100
    height, width = input_image.shape[:2]
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps
    # global state, and figures created through it are never freed without
    # plt.close(); that state is also shared by any concurrent requests.
    fig = Figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    ax = fig.subplots()
    ax.imshow(input_image)
    show_anns(masks, ax=ax)
    ax.axis('off')

    # Render in memory rather than to a fixed file that concurrent requests
    # would overwrite. Decoding with PIL yields RGB directly; cv2.imread
    # returned BGR, which swapped the colour channels of the result.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf).convert('RGB')


with gr.Blocks() as demo:
    with gr.Row():
        input_image = gr.Image(label='Input Image')
        output_image = gr.Image(label='Output Image')
    button = gr.Button('Mask Image')
    button.click(segment_image, inputs=[input_image], outputs=output_image)
    gr.Examples(
        examples=[['./imgs/cat.jpg']],
        inputs=[input_image],
        outputs=[output_image],
        fn=segment_image,
        cache_examples=True,
    )

demo.launch()