File size: 2,851 Bytes
af93c38
 
 
68a44cd
1fe972f
af93c38
 
f2f193c
af93c38
 
 
f2f193c
af93c38
1fe972f
 
af93c38
f2f193c
 
02cf8e9
f2f193c
af93c38
f2f193c
68a44cd
1fe972f
af93c38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348cb84
af93c38
a747dc5
af93c38
348cb84
 
 
 
af93c38
348cb84
 
 
 
 
 
 
 
af93c38
348cb84
 
 
 
a747dc5
 
 
 
348cb84
 
 
 
1fe972f
f2f193c
1fe972f
348cb84
 
 
 
 
 
 
f2f193c
a747dc5
348cb84
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os

import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch

from PIL import Image

from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry

# suppress server-side GUI windows
# (Agg is a non-interactive matplotlib backend, safe for headless servers)
matplotlib.pyplot.switch_backend('Agg') 

# setup models
# NOTE(review): checkpoint path is hard-coded; assumes sam_vit_b_01ec64.pth
# sits next to this script — confirm deployment layout.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
sam.to(device=device)
# module-level singletons shared by every request
mask_generator = SamAutomaticMaskGenerator(sam)  # whole-image automatic masking
predictor = SamPredictor(sam)  # prompt-based predictor — appears unused below, TODO confirm


# adapted from: https://github.com/facebookresearch/segment-anything
def show_anns(anns):
    """Overlay SAM annotation masks on the current matplotlib axes.

    Each annotation dict must carry a boolean 'segmentation' HxW array and an
    'area' value. Masks are drawn largest-area first so smaller masks remain
    visible on top. Returns None; draws onto plt.gca() as a side effect.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        # HxWx3 image of one random colour, broadcast instead of a per-channel loop
        color_mask = np.random.random(3)
        img = np.ones((m.shape[0], m.shape[1], 3)) * color_mask
        # use the boolean mask (scaled) as the alpha channel: 0.35 where masked
        ax.imshow(np.dstack((img, m * 0.35)))


# demo function
def segment_image(input_image):
    """Run SAM automatic mask generation on *input_image* and return the overlay.

    Args:
        input_image: HxWx3 RGB numpy array from the Gradio image widget,
            or None when no image has been provided.

    Returns:
        A PIL.Image with translucent masks drawn over the input,
        or None when input_image is None.
    """
    if input_image is None:
        return None

    # generate masks for every object SAM finds
    masks = mask_generator.generate(input_image)

    # render input + masks at roughly the input's pixel size
    ppi = 100
    height, width, _ = input_image.shape
    fig = plt.figure(figsize=(width / ppi, height / ppi))  # convert pixels to inches
    plt.imshow(input_image)
    show_anns(masks)
    plt.axis('off')

    # save and reload the rendered figure
    plt.savefig('output_figure.png', bbox_inches='tight')
    plt.close(fig)  # close the figure; otherwise every request leaks one
    output_image = cv2.imread('output_figure.png')
    # cv2.imread returns BGR — convert so PIL/Gradio display correct colours
    output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(output_image)


# Gradio UI: title + description rows, then input image and button on the
# left column and the segmented output image on the right column.
with gr.Blocks() as demo:

    with gr.Row():
        gr.Markdown("## Segment Anything (by Meta AI Research)")
    with gr.Row():
        gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")

    with gr.Row():

        with gr.Column():
            image_input = gr.Image()
            segment_image_button = gr.Button('Generate Mask')

        with gr.Column():
            image_output = gr.Image()

    # run segmentation when the button is clicked
    segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)

    # clickable example images; assumes ./examples/ exists relative to the
    # working directory — TODO confirm paths at deployment
    gr.Examples(
        examples=[
            ['./examples/dog.jpg'],
            ['./examples/groceries.jpg'],
            ['./examples/truck.jpg']

        ],
        inputs=[image_input],
        outputs=[image_output],
        fn=segment_image,
        #cache_examples=True
    )

# start the Gradio server (blocks until shutdown)
demo.launch()