# Scrape residue from the hosting page (Spaces status, file size, commit
# hashes, line-number gutter) — kept as a comment so the file parses:
#   Runtime error | File size: 2,851 Bytes
#   af93c38 68a44cd 1fe972f f2f193c 02cf8e9 348cb84 a747dc5
import os
import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
# suppress server-side GUI windows
# Agg is a non-interactive backend: safe on a headless server where no
# display is available.
matplotlib.pyplot.switch_backend('Agg')
# setup models
# Prefer GPU when available; SAM inference on CPU works but is slow.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ViT-B is the smallest SAM backbone; checkpoint must be downloaded next
# to this script beforehand.
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
sam.to(device=device)
# Automatic generator produces masks for ALL objects (no prompt needed);
# SamPredictor is the prompt-based interface (currently unused below).
mask_generator = SamAutomaticMaskGenerator(sam)
predictor = SamPredictor(sam)
# copied from: https://github.com/facebookresearch/segment-anything
def show_anns(anns):
    """Overlay translucent, randomly colored masks on the current axes.

    Adapted from https://github.com/facebookresearch/segment-anything.

    Args:
        anns: list of annotation dicts as produced by
            SamAutomaticMaskGenerator.generate(); each entry must contain
            'segmentation' (HxW boolean array) and 'area' (int).

    Returns:
        None. Draws onto plt.gca() as a side effect.
    """
    if not anns:
        return
    # Draw the largest masks first so smaller objects stay visible on top.
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        # Build a solid random-color image; the mask becomes its alpha
        # channel (0.35 where the mask is on, fully transparent elsewhere).
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random(3)
        for i in range(3):
            img[:, :, i] = color_mask[i]
        ax.imshow(np.dstack((img, m * 0.35)))
# demo function
def segment_image(input_image):
if input_image is not None:
# generate masks
masks = mask_generator.generate(input_image)
# add masks to image
plt.clf()
ppi = 100
height, width, _ = input_image.shape
plt.figure(figsize=(width / ppi, height / ppi)) # convert pixel to inches
plt.imshow(input_image)
show_anns(masks)
plt.axis('off')
# save and get figure
plt.savefig('output_figure.png', bbox_inches='tight')
output_image = cv2.imread('output_figure.png')
return Image.fromarray(output_image)
# Declarative Gradio UI: rows/columns nest under the Blocks context; the
# click handler is wired after both components exist.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Segment Anything (by Meta AI Research)")
    with gr.Row():
        gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")
    with gr.Row():
        # left column: input image + trigger button; right column: result
        with gr.Column():
            image_input = gr.Image()
            segment_image_button = gr.Button('Generate Mask')
        with gr.Column():
            image_output = gr.Image()
    segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)
    # clickable example images; caching disabled so startup does not run SAM
    gr.Examples(
        examples=[
            ['./examples/dog.jpg'],
            ['./examples/groceries.jpg'],
            ['./examples/truck.jpg']
        ],
        inputs=[image_input],
        outputs=[image_output],
        fn=segment_image,
        #cache_examples=True
    )
demo.launch()