jbrinkma commited on
Commit
1fe972f
1 Parent(s): 02cf8e9

update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -16
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
 
3
  import cv2
4
  import gradio as gr
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
 
@@ -9,11 +10,14 @@ from PIL import Image
9
 
10
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
11
 
 
 
12
 
13
  # setup model
14
  sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
15
  mask_generator = SamAutomaticMaskGenerator(sam)
16
 
 
17
  # copied from: https://github.com/facebookresearch/segment-anything
18
  def show_anns(anns):
19
  if len(anns) == 0:
@@ -54,21 +58,17 @@ def segment_image(input_image):
54
 
55
  with gr.Blocks() as demo:
56
 
57
- with gr.Row():
58
- input_image = gr.Image(label='Input Image')
59
- output_image = gr.Image(label='Output Image')
60
-
61
- button = gr.Button('Mask Image')
62
- button.click(segment_image, inputs=[input_image], outputs=output_image)
63
-
64
- gr.Examples(
65
- examples = [
66
- ['./imgs/cat.jpg']
67
- ],
68
- inputs=[input_image],
69
- outputs=[output_image],
70
- fn=segment_image,
71
- cache_examples=True
72
- )
73
 
74
  demo.launch()
 
2
 
3
  import cv2
4
  import gradio as gr
5
+ import matplotlib
6
  import matplotlib.pyplot as plt
7
  import numpy as np
8
 
 
10
 
11
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
12
 
13
+ # suppress server-side GUI windows
14
+ matplotlib.pyplot.switch_backend('Agg')
15
 
16
  # setup model
17
  sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
18
  mask_generator = SamAutomaticMaskGenerator(sam)
19
 
20
+
21
  # copied from: https://github.com/facebookresearch/segment-anything
22
  def show_anns(anns):
23
  if len(anns) == 0:
 
58
 
59
  with gr.Blocks() as demo:
60
 
61
+ gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")
62
+
63
+ with gr.Tabs():
64
+
65
+ with gr.TabItem("Mask Generator"):
66
+
67
+ with gr.Row():
68
+ image_input = gr.Image()
69
+ image_output = gr.Image()
70
+ segment_image_button = gr.Button('Generate Mask')
71
+
72
+ segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)
 
 
 
 
73
 
74
  demo.launch()