Spaces:
Runtime error
Runtime error
update app.py
Browse files
app.py
CHANGED
@@ -5,17 +5,21 @@ import gradio as gr
|
|
5 |
import matplotlib
|
6 |
import matplotlib.pyplot as plt
|
7 |
import numpy as np
|
|
|
8 |
|
9 |
from PIL import Image
|
10 |
|
11 |
-
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
12 |
|
13 |
# suppress server-side GUI windows
|
14 |
matplotlib.pyplot.switch_backend('Agg')
|
15 |
|
16 |
-
# setup
|
|
|
17 |
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
|
|
|
18 |
mask_generator = SamAutomaticMaskGenerator(sam)
|
|
|
19 |
|
20 |
|
21 |
# copied from: https://github.com/facebookresearch/segment-anything
|
@@ -60,15 +64,11 @@ with gr.Blocks() as demo:
|
|
60 |
|
61 |
gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")
|
62 |
|
63 |
-
with gr.
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
with gr.Row():
|
68 |
-
image_input = gr.Image()
|
69 |
-
image_output = gr.Image()
|
70 |
-
segment_image_button = gr.Button('Generate Mask')
|
71 |
-
|
72 |
-
segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)
|
73 |
|
74 |
demo.launch()
|
|
|
5 |
import matplotlib
|
6 |
import matplotlib.pyplot as plt
|
7 |
import numpy as np
|
8 |
+
import torch
|
9 |
|
10 |
from PIL import Image
|
11 |
|
12 |
+
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
13 |
|
14 |
# suppress server-side GUI windows
|
15 |
matplotlib.pyplot.switch_backend('Agg')
|
16 |
|
17 |
+
# setup models
|
18 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
19 |
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
|
20 |
+
sam.to(device=device)
|
21 |
mask_generator = SamAutomaticMaskGenerator(sam)
|
22 |
+
predictor = SamPredictor(sam)
|
23 |
|
24 |
|
25 |
# copied from: https://github.com/facebookresearch/segment-anything
|
|
|
64 |
|
65 |
gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")
|
66 |
|
67 |
+
with gr.Row():
|
68 |
+
image_input = gr.Image()
|
69 |
+
image_output = gr.Image()
|
70 |
|
71 |
+
segment_image_button = gr.Button('Generate Mask')
|
72 |
+
segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
demo.launch()
|