jbrinkma commited on
Commit
f2f193c
1 Parent(s): 1fe972f

update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -5,17 +5,21 @@ import gradio as gr
5
  import matplotlib
6
  import matplotlib.pyplot as plt
7
  import numpy as np
 
8
 
9
  from PIL import Image
10
 
11
- from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
12
 
13
  # suppress server-side GUI windows
14
  matplotlib.pyplot.switch_backend('Agg')
15
 
16
- # setup model
 
17
  sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
 
18
  mask_generator = SamAutomaticMaskGenerator(sam)
 
19
 
20
 
21
  # copied from: https://github.com/facebookresearch/segment-anything
@@ -60,15 +64,11 @@ with gr.Blocks() as demo:
60
 
61
  gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")
62
 
63
- with gr.Tabs():
 
 
64
 
65
- with gr.TabItem("Mask Generator"):
66
-
67
- with gr.Row():
68
- image_input = gr.Image()
69
- image_output = gr.Image()
70
- segment_image_button = gr.Button('Generate Mask')
71
-
72
- segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)
73
 
74
  demo.launch()
 
5
  import matplotlib
6
  import matplotlib.pyplot as plt
7
  import numpy as np
8
+ import torch
9
 
10
  from PIL import Image
11
 
12
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
13
 
14
  # suppress server-side GUI windows
15
  matplotlib.pyplot.switch_backend('Agg')
16
 
17
+ # setup models
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
  sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
20
+ sam.to(device=device)
21
  mask_generator = SamAutomaticMaskGenerator(sam)
22
+ predictor = SamPredictor(sam)
23
 
24
 
25
  # copied from: https://github.com/facebookresearch/segment-anything
 
64
 
65
  gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")
66
 
67
+ with gr.Row():
68
+ image_input = gr.Image()
69
+ image_output = gr.Image()
70
 
71
+ segment_image_button = gr.Button('Generate Mask')
72
+ segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)
 
 
 
 
 
 
73
 
74
  demo.launch()