Karin0616 committed on
Commit
a45a6f9
•
1 Parent(s): 7cb998a

example radio

Files changed (1)
  1. app.py +58 -126
app.py CHANGED
@@ -1,129 +1,61 @@
  import gradio as gr
- import random
-
- from matplotlib import gridspec
- import matplotlib.pyplot as plt
- import numpy as np
- from PIL import Image
  import tensorflow as tf
- from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
-
- feature_extractor = SegformerFeatureExtractor.from_pretrained(
-     "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
- )
- model = TFSegformerForSemanticSegmentation.from_pretrained(
-     "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
  )
-
- def ade_palette():
-     return [
-         [204, 87, 92],   # road (Reddish)
-         [112, 185, 212], # sidewalk (Blue)
-         [196, 160, 122], # building (Brown)
-         [106, 135, 242], # wall (Light Blue)
-         [91, 192, 222],  # fence (Turquoise)
-         [255, 192, 203], # pole (Pink)
-         [176, 224, 230], # traffic light (Light Blue)
-         [222, 49, 99],   # traffic sign (Red)
-         [139, 69, 19],   # vegetation (Brown)
-         [255, 0, 0],     # terrain (Red)
-         [0, 0, 255],     # sky (Blue)
-         [255, 228, 181], # person (Peach)
-         [128, 0, 0],     # rider (Maroon)
-         [0, 128, 0],     # car (Green)
-         [255, 99, 71],   # truck (Tomato)
-         [0, 255, 0],     # bus (Lime)
-         [128, 0, 128],   # train (Purple)
-         [255, 255, 0],   # motorcycle (Yellow)
-         [128, 0, 128]    # bicycle (Purple)
-     ]
-
- labels_list = []
-
- with open(r'labels.txt', 'r') as fp:
-     for line in fp:
-         labels_list.append(line[:-1])
-
- colormap = np.asarray(ade_palette())
-
- def label_to_color_image(label):
-     if label.ndim != 2:
-         raise ValueError("Expect 2-D input label")
-     if np.max(label) >= len(colormap):
-         raise ValueError("label value too large.")
-     return colormap[label]
-
- def draw_plot(pred_img, seg):
-     fig = plt.figure(figsize=(20, 15))
-     grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
-     plt.subplot(grid_spec[0])
-     plt.imshow(pred_img)
-     plt.axis('off')
-     LABEL_NAMES = np.asarray(labels_list)
-     FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
-     FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
-     unique_labels = np.unique(seg.numpy().astype("uint8"))
-     ax = plt.subplot(grid_spec[1])
-     plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
-     ax.yaxis.tick_left()
-     plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
-     plt.xticks([], [])
-     ax.tick_params(width=0.0, labelsize=27)
-     return fig
-
- def sepia(input_img):
-     input_img = Image.fromarray(input_img)
-     inputs = feature_extractor(images=input_img, return_tensors="tf")
-     outputs = model(**inputs)
-     logits = outputs.logits
-     logits = tf.transpose(logits, [0, 2, 3, 1])
-     logits = tf.image.resize(
-         logits, input_img.size[::-1]
-     )  # We reverse the shape of `image` because `image.size` returns width and height.
-     seg = tf.math.argmax(logits, axis=-1)[0]
-     color_seg = np.zeros(
-         (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
-     )  # height, width, 3
-     for label, color in enumerate(colormap):
-         color_seg[seg.numpy() == label, :] = color
-     # Show image + mask
-     pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
-     pred_img = pred_img.astype(np.uint8)
-     fig = draw_plot(pred_img, seg)
-     return fig
-
- demo = gr.Interface(fn=sepia,
-                     inputs=gr.Image(shape=(564, 846)),
-                     outputs=['plot'],
-                     live=True,
-                     examples=["city1.jpg", "city2.jpg", "city3.jpg"],
-                     allow_flagging='never',
-                     title="This is a machine learning activity project at Kyunggi University.",
-                     theme="darkpeach",
-                     css="""
-                     body {
-                         background-color: dark;
-                         color: white;  /* set the font color */
-                         font-family: Arial, sans-serif;  /* set the font family */
-                     }
-                     """
-                     )
-
- demo.launch()
-
 
  import gradio as gr
  import tensorflow as tf
+ from PIL import Image
+
+ # Load the exported model
+ model = tf.saved_model.load("nvidia_segformer_b5_finetuned_cityscapes_1024")
+
+ # Labels and their display colors, in Cityscapes label order
+ label_colors = {
+     "road": [204, 87, 92],
+     "sidewalk": [112, 185, 212],
+     "building": [196, 160, 122],
+     "wall": [106, 135, 242],
+     "fence": [91, 192, 222],
+     "pole": [255, 192, 203],
+     "traffic_light": [176, 224, 230],
+     "traffic_sign": [222, 49, 99],
+     "vegetation": [139, 69, 19],
+     "terrain": [255, 0, 0],
+     "sky": [0, 0, 255],
+     "person": [255, 228, 181],
+     "rider": [128, 0, 0],
+     "car": [0, 128, 0],
+     "truck": [255, 99, 71],
+     "bus": [0, 255, 0],
+     "train": [128, 0, 128],
+     "motorcycle": [255, 255, 0],
+     "bicycle": [128, 0, 128]
+ }
+
+ # Image segmentation function
+ def predict_segmentation(image, model):
+     # Convert the array to a PIL image, keeping the input resolution for later
+     original_size = (image.shape[1], image.shape[0])  # (width, height)
+     pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
+     pil_image = pil_image.resize((1024, 1024))  # match the model's input size
+     image_array = tf.keras.preprocessing.image.img_to_array(pil_image)
+     image_array = tf.expand_dims(image_array, 0)
+
+     # Run inference; "output_0" is assumed to hold channels-last class logits
+     logits = model(image_array)["output_0"]
+
+     # Most likely class per pixel
+     seg = tf.math.argmax(logits, axis=-1)[0]
+
+     # Map each label index to its RGB color
+     palette = tf.constant(list(label_colors.values()), dtype=tf.uint8)
+     segmented_image = tf.gather(palette, seg)
+
+     # Resize back to the input resolution and return
+     segmented_image = tf.image.resize(segmented_image, original_size[::-1], method="nearest")
+     return segmented_image.numpy()
+
+ # Define the Gradio interface (the function must be defined before launch() blocks)
+ iface = gr.Interface(
+     fn=lambda image: predict_segmentation(image, model),
+     inputs="image",
+     outputs="image"
  )
+ iface.launch()
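
A minimal sketch for sanity-checking the new model path outside Gradio (not part of the commit; it assumes the SavedModel directory above exists locally and that the app is right that calling the model returns an "output_0" logits tensor):

import tensorflow as tf

# Load the same SavedModel directory the app uses.
model = tf.saved_model.load("nvidia_segformer_b5_finetuned_cityscapes_1024")

# Confirm the output key the app indexes with ("output_0").
infer = model.signatures["serving_default"]
print(infer.structured_outputs)

# Push a dummy 1024x1024 image through and check the logits layout; the app's
# argmax over axis -1 assumes channels-last logits with 19 Cityscapes classes.
dummy = tf.zeros([1, 1024, 1024, 3], dtype=tf.float32)
print(model(dummy)["output_0"].shape)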