Thiago Hersan commited on
Commit
a49b93f
1 Parent(s): 4c74ca5

simplifies query_image. uses non-deprecated function

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import gradio as gr
2
- import torch
3
- import random
4
  import numpy as np
5
  from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
6
 
@@ -10,23 +8,34 @@ from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmen
10
  preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-large-coco")
11
  model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-large-coco")
12
 
 
13
  def visualize_instance_seg_mask(mask):
14
  image = np.zeros((mask.shape[0], mask.shape[1], 3))
 
15
  labels = np.unique(mask)
16
- label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
 
 
 
17
  for i in range(image.shape[0]):
18
  for j in range(image.shape[1]):
19
  image[i, j, :] = label2color[mask[i, j]]
 
 
20
  image = image / 255
 
 
 
 
21
  return image
22
 
 
23
  def query_image(img):
24
- target_size = (img.shape[0], img.shape[1])
25
  inputs = preprocessor(images=img, return_tensors="pt")
26
  outputs = model(**inputs)
27
- results = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0]
28
- results = torch.argmax(results, dim=0).numpy()
29
- results = visualize_instance_seg_mask(results)
30
  return results
31
 
32
 
@@ -34,7 +43,7 @@ demo = gr.Interface(
34
  query_image,
35
  inputs=[gr.Image()],
36
  outputs="image",
37
- title="maskformer-swin-tiny-ade results",
38
  allow_flagging="never",
39
  analytics_enabled=None
40
  )
 
1
  import gradio as gr
 
 
2
  import numpy as np
3
  from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
4
 
 
8
  preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-large-coco")
9
  model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-large-coco")
10
 
11
+
12
  def visualize_instance_seg_mask(mask):
13
  image = np.zeros((mask.shape[0], mask.shape[1], 3))
14
+ image_total_pixels = mask.shape[0] * mask.shape[1]
15
  labels = np.unique(mask)
16
+
17
+ label2color = {label: (np.random.randint(0, 2), np.random.randint(0, 256), np.random.randint(0, 256)) for label in labels}
18
+ label2count = {label: 0 for label in labels}
19
+
20
  for i in range(image.shape[0]):
21
  for j in range(image.shape[1]):
22
  image[i, j, :] = label2color[mask[i, j]]
23
+ label2count[mask[i, j]] = label2count[mask[i, j]] + 1
24
+
25
  image = image / 255
26
+
27
+ for k, v in label2count.items():
28
+ label2count[k] = v / image_total_pixels
29
+
30
  return image
31
 
32
+
33
  def query_image(img):
34
+ img_size = (img.shape[0], img.shape[1])
35
  inputs = preprocessor(images=img, return_tensors="pt")
36
  outputs = model(**inputs)
37
+ results = preprocessor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[img_size])[0]
38
+ results = visualize_instance_seg_mask(results.numpy())
 
39
  return results
40
 
41
 
 
43
  query_image,
44
  inputs=[gr.Image()],
45
  outputs="image",
46
+ title="maskformer-swin-large-coco results",
47
  allow_flagging="never",
48
  analytics_enabled=None
49
  )