hyo37009 committed
Commit 8de291e
1 Parent(s): 3c17f91
Files changed (2)
  1. app.py +109 -13
  2. labels.txt +18 -0
app.py CHANGED
@@ -1,30 +1,126 @@
  import gradio as gr
  #
- from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
+ from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
+ import matplotlib.pyplot as plt
+ from matplotlib import gridspec
  from PIL import Image
+ import numpy as np
+ import tensorflow as tf
  import requests
+
  #

  feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-640-1280")
- model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-640-1280")
-
- url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- image = Image.open(requests.get(url, stream=True).raw)
-
- inputs = feature_extractor(images=image, return_tensors="pt")
- outputs = model(**inputs)
- logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)
-
- def greet():
-     return outputs
+ model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-640-1280")
+
+ # Example images for the demo.
+ urls = [
+     "http://farm3.staticflickr.com/2523/3705549787_79049b1b6d_z.jpg",
+     "http://farm8.staticflickr.com/7012/6476201279_52db36af64_z.jpg",
+     "http://farm8.staticflickr.com/7180/6967423255_a3d65d5f6b_z.jpg",
+     "http://farm4.staticflickr.com/3563/3470840644_3378804bea_z.jpg",
+     "http://farm9.staticflickr.com/8388/8516454091_0ebdc1130a_z.jpg",
+ ]
+ images = []
+ for i in urls:
+     images.append(Image.open(requests.get(i, stream=True).raw))
+
+ # inputs = feature_extractor(images=image, return_tensors="pt")
+ # outputs = model(**inputs)
+ # logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)
+
+
+ def my_palette():
+     # One RGB color per entry in labels.txt.
+     return [
+         [131, 162, 255],
+         [180, 189, 255],
+         [255, 227, 187],
+         [255, 210, 143],
+         [248, 117, 170],
+         [255, 223, 223],
+         [255, 246, 246],
+         [174, 222, 252],
+         [150, 194, 145],
+         [255, 219, 170],
+         [244, 238, 238],
+         [50, 38, 83],
+         [128, 98, 214],
+         [146, 136, 248],
+         [255, 210, 215],
+         [255, 152, 152],
+         [162, 103, 138],
+         [63, 29, 56],
+     ]
+
+
+ labels_list = []
+
+ with open(r"labels.txt", "r") as fp:
+     for line in fp:
+         labels_list.append(line.rstrip("\n"))
+
+ colormap = np.asarray(my_palette())
+
+
+ def greet(input_img):
+     inputs = feature_extractor(images=input_img, return_tensors="pt")
+     outputs = model(**inputs)
+     logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)
+
+     # The checkpoint is the PyTorch model, so convert its output to NumPy
+     # before handing it to the TensorFlow resize/argmax ops below.
+     logits = tf.transpose(logits.detach().numpy(), [0, 2, 3, 1])
+
+     logits = tf.image.resize(
+         logits, input_img.size[::-1]
+     )  # We reverse the shape of `image` because `image.size` returns width and height.
+     seg = tf.math.argmax(logits, axis=-1)[0]
+
+     color_seg = np.zeros(
+         (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
+     )  # height, width, 3
+     for label, color in enumerate(colormap):
+         color_seg[seg.numpy() == label, :] = color
+
+     # Show image + mask
+     pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
+     pred_img = pred_img.astype(np.uint8)
+
+     fig = draw_plot(pred_img, seg)
+     return fig
+
+
+ def draw_plot(pred_img, seg):
+     fig = plt.figure(figsize=(20, 15))
+
+     grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
+
+     plt.subplot(grid_spec[0])
+     plt.imshow(pred_img)
+     plt.axis("off")
+     LABEL_NAMES = np.asarray(labels_list)
+     FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
+     FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
+
+     unique_labels = np.unique(seg.numpy().astype("uint8"))
+     ax = plt.subplot(grid_spec[1])
+     plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
+     ax.yaxis.tick_right()
+     plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
+     plt.xticks([], [])
+     ax.tick_params(width=0.0, labelsize=25)
+     return fig
+
+
+ def label_to_color_image(label):
+     if label.ndim != 2:
+         raise ValueError("Expect 2-D input label")
+
+     if np.max(label) >= len(colormap):
+         raise ValueError("label value too large.")
+     return colormap[label]
+
+
  iface = gr.Interface(
-     fn=sepia,
-     inputs="text",
-     outputs=["plot"])
- iface.launch(share=True)
+     fn=greet,
+     inputs=gr.Image(shape=(640, 1280), type="pil"),  # PIL input, so `input_img.size` is (width, height)
+     outputs=["plot"],
+     examples=images,  # one example per image
+     allow_flagging="never")
+ iface.launch(share=True)
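The only TensorFlow the commit adds is the resize/argmax post-processing inside greet. As a sketch of an alternative (not what this commit does), that step can stay in PyTorch and drop the tensorflow dependency entirely; feature_extractor, model, and a PIL input_img are assumed from app.py above, and the helper name segment_ids is ours:

import torch
import torch.nn.functional as F

def segment_ids(input_img):
    # Same preprocessing as greet() in app.py.
    inputs = feature_extractor(images=input_img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Upsample (batch, num_labels, height/4, width/4) logits to the input size;
    # PIL's .size is (width, height), so reverse it to get (height, width).
    logits = F.interpolate(
        outputs.logits,
        size=input_img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    # Per-pixel class ids as a NumPy array, like seg.numpy() in greet().
    return logits.argmax(dim=1)[0].numpy()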
labels.txt ADDED
@@ -0,0 +1,18 @@
+ sidewalk
+ building
+ wall
+ fence
+ pole
+ traffic light
+ traffic sign
+ vegetation
+ terrain
+ sky
+ person
+ rider
+ car
+ truck
+ bus
+ train
+ motorcycle
+ bicycle
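For a quick local check of the new app, a minimal smoke test might look like this (not part of the commit: the URL is the first demo example from app.py, and the output filename is arbitrary):

from PIL import Image
import requests

# First example image from app.py's urls list.
url = "http://farm3.staticflickr.com/2523/3705549787_79049b1b6d_z.jpg"
img = Image.open(requests.get(url, stream=True).raw)

fig = greet(img)                      # matplotlib figure: blended mask + legend
fig.savefig("segmentation_demo.png")  # arbitrary output path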