jbrinkma commited on
Commit
af93c38
1 Parent(s): a747dc5

add segmentation function

Browse files
Files changed (1) hide show
  1. app.py +58 -2
app.py CHANGED
@@ -1,9 +1,65 @@
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def segment_image(input_image):
5
- output_image = input_image
6
- return output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
 
9
  with gr.Blocks() as demo:
 
1
+ import os
2
+ import urllib
3
+
4
+ import cv2
5
  import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
+ from PIL import Image
10
+
11
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
12
+
13
+
14
+ # download model weights
15
+ ckpts_dir = os.path.join(os.getcwd() + "/ckpts")
16
+ if not os.path.exists(ckpts_dir):
17
+ os.makedirs(ckpts_dir)
18
+ ckpt_path = os.path.join(ckpts_dir + "/vit_b.pth")
19
+ if not os.path.exists(ckpt_path):
20
+ url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
21
+ urllib.request.urlretrieve(url, filename=ckpt_path)
22
 
23
+ # setup model
24
+ sam = sam_model_registry["vit_b"](checkpoint=ckpt_path)
25
+ mask_generator = SamAutomaticMaskGenerator(sam)
26
 
27
+ # copied from: https://github.com/facebookresearch/segment-anything
28
+ def show_anns(anns):
29
+ if len(anns) == 0:
30
+ return
31
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
32
+ ax = plt.gca()
33
+ ax.set_autoscale_on(False)
34
+ polygons = []
35
+ color = []
36
+ for ann in sorted_anns:
37
+ m = ann['segmentation']
38
+ img = np.ones((m.shape[0], m.shape[1], 3))
39
+ color_mask = np.random.random((1, 3)).tolist()[0]
40
+ for i in range(3):
41
+ img[:,:,i] = color_mask[i]
42
+ ax.imshow(np.dstack((img, m*0.35)))
43
+
44
+ # demo function
45
  def segment_image(input_image):
46
+
47
+ # generate masks
48
+ masks = mask_generator.generate(input_image)
49
+
50
+ # add masks to image
51
+ plt.clf()
52
+ ppi = 100
53
+ height, width, _ = input_image.shape
54
+ plt.figure(figsize=(width / ppi, height / ppi)) # convert pixel to inches
55
+ plt.imshow(input_image)
56
+ show_anns(masks)
57
+ plt.axis('off')
58
+
59
+ # save and get figure
60
+ plt.savefig('output_figure.png', bbox_inches='tight')
61
+ output_image = cv2.imread('output_figure.png')
62
+ return Image.fromarray(output_image)
63
 
64
 
65
  with gr.Blocks() as demo: