import functools
import os
import urllib.request

import gradio as gr
import numpy as np
import torch
from segment_anything import SamAutomaticMaskGenerator, build_sam

CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"

# Run SAM on the GPU when one is available; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _ensure_checkpoint(url=CHECKPOINT_URL, path=CHECKPOINT_PATH):
    """Download the SAM ViT-H checkpoint to *path* unless it already exists.

    The data is written to a temporary ``.part`` file and renamed only after
    the download completes, so an interrupted download is never mistaken for
    a valid checkpoint on the next start. Network errors propagate here
    instead of surfacing later as a confusing failure inside ``build_sam``.
    """
    if os.path.exists(path):
        return
    partial_path = path + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, path)


# Fetch the weights once at import time (skipped if already on disk).
_ensure_checkpoint()

# Speed-mode name -> keyword arguments forwarded to build_sam as
# ``hourglass_kwargs``; "baseline" passes none.
hourglass_args = {
    "baseline": {},
    "1.2x faster": {
        "use_hourglass": True,
        "hourglass_clustering_location": 14,
        "hourglass_num_cluster": 100,
    },
    "1.5x faster": {
        "use_hourglass": True,
        "hourglass_clustering_location": 6,
        "hourglass_num_cluster": 81,
    },
}


@functools.lru_cache(maxsize=1)
def _get_mask_generator(speed_mode):
    """Return a SamAutomaticMaskGenerator for *speed_mode*, building it once.

    Loading the ViT-H checkpoint is expensive, so the generator is reused
    across requests. ``maxsize=1`` keeps only the most recently used speed
    mode resident, so at most one model is cached at a time.

    Raises:
        KeyError: if *speed_mode* is not a key of ``hourglass_args``.
    """
    sam = build_sam(
        checkpoint=CHECKPOINT_PATH,
        hourglass_kwargs=hourglass_args[speed_mode],
    )
    sam.to(device=DEVICE)
    return SamAutomaticMaskGenerator(sam)


def predict(image, speed_mode="baseline"):
    """Overlay randomly coloured SAM masks on *image*.

    Args:
        image: the value of the ``gr.Image`` input, or None when no image is
            loaded. Presumably an HxWx3 uint8 numpy array (the gr.Image
            default) — TODO confirm if the input component changes.
        speed_mode: key of ``hourglass_args``. Defaults to "baseline" so that
            ``gr.Examples``, which supplies only the image, can call this.

    Returns:
        None if no image was given; *image* unchanged if SAM finds no masks;
        otherwise a uint8 array that is a 50/50 blend of *image* with the
        mask overlay (pixels covered by no mask are blended with white).
    """
    if image is None:
        return None

    masks = _get_mask_generator(speed_mode).generate(image)
    if not masks:
        return image

    overlay = np.ones(image.shape)
    # Paint the largest masks first so smaller ones remain visible on top.
    for mask in sorted(masks, key=lambda m: m["area"], reverse=True):
        overlay[mask["segmentation"]] = np.random.random(3)
    return ((image + overlay * 255) / 2).astype(np.uint8)


description = """ #
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
"""


def main():
    """Build the Gradio Blocks UI, wire up its events, and launch the server."""
    with gr.Blocks() as demo:
        gr.Markdown(description)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image")
                    speed_mode = gr.Dropdown(
                        choices=list(hourglass_args.keys()),
                        value="baseline",
                        label="Speed Mode",
                        multiselect=False,
                    )
                    with gr.Row():
                        # `value` is the button caption and `elem_id` its HTML id
                        # (Gradio's documented Button parameters).
                        run_btn = gr.Button(value="Run", elem_id="run")
                        clear_btn = gr.Button(value="Clear", elem_id="clear")
                output_image = gr.Image(label="Output Image")
            gr.Examples(
                examples=[
                    ["notebooks/images/dog.jpg"],
                    ["notebooks/images/groceries.jpg"],
                    ["notebooks/images/truck.jpg"],
                ],
                inputs=[input_image],
                outputs=[output_image],
                # Examples supply only the image, so predict falls back to its
                # default speed_mode when called from here.
                fn=predict,
            )
        run_btn.click(
            fn=predict,
            inputs=[input_image, speed_mode],
            outputs=output_image,
        )
        clear_btn.click(
            fn=lambda: [None, None],
            inputs=None,
            outputs=[input_image, output_image],
            queue=False,
        )
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()