sujeongim0402@gmail.com committed
Commit acd3317 • 1 Parent(s): cbc2699

edit codes

Files changed (1)
  1. app.py +20 -4
app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
 import tensorflow as tf
 from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
 
+# Load the pretrained Segformer feature extractor and semantic segmentation model
 feature_extractor = SegformerFeatureExtractor.from_pretrained(
     "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
 )
@@ -15,6 +16,7 @@ model = TFSegformerForSemanticSegmentation.from_pretrained(
     "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
 )
 
+# Define the RGB color values for the ADE20K palette
 def ade_palette():
     return [
         [204, 87, 92],
@@ -38,14 +40,17 @@ def ade_palette():
         [180, 32, 10],
     ]
 
+# Build the list of labels loaded from 'labels.txt'
 labels_list = []
 
 with open(r'labels.txt', 'r') as fp:
     for line in fp:
         labels_list.append(line[:-1])
 
+# Convert the defined color palette to a NumPy array
 colormap = np.asarray(ade_palette())
 
+# Map a label map to a color image
 def label_to_color_image(label):
     if label.ndim != 2:
         raise ValueError("Expect 2-D input label")
@@ -54,14 +59,17 @@ def label_to_color_image(label):
         raise ValueError("label value too large.")
     return colormap[label]
 
+# Draw a plot containing the predicted image and the color map
 def draw_plot(pred_img, seg):
+    # Create the figure for the predicted image and color map
     fig = plt.figure(figsize=(20, 15))
-
     grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
 
     plt.subplot(grid_spec[0])
     plt.imshow(pred_img)
     plt.axis('off')
+
+    # Set up the color map for the segmentation labels
     LABEL_NAMES = np.asarray(labels_list)
     FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
     FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
@@ -75,36 +83,44 @@ def draw_plot(pred_img, seg):
     ax.tick_params(width=0.0, labelsize=25)
     return fig
 
+# Apply the Segformer model to the input image and build the plot
 def sepia(input_img):
     input_img = Image.fromarray(input_img)
 
+    # Extract features, then predict with the Segformer model
     inputs = feature_extractor(images=input_img, return_tensors="tf")
     outputs = model(**inputs)
     logits = outputs.logits
 
+    # Resize the logits to match the input image size
     logits = tf.transpose(logits, [0, 2, 3, 1])
     logits = tf.image.resize(
         logits, input_img.size[::-1]
-    )  # We reverse the shape of `image` because `image.size` returns width and height.
+    )
+
+    # Extract the segmentation and map each label to its color
     seg = tf.math.argmax(logits, axis=-1)[0]
 
     color_seg = np.zeros(
         (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
-    )  # height, width, 3
+    )
     for label, color in enumerate(colormap):
         color_seg[seg.numpy() == label, :] = color
 
-    # Show image + mask
+    # Blend the original image with the segmentation mask
     pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
     pred_img = pred_img.astype(np.uint8)
 
+    # Draw the plot with the predicted image and the color map
     fig = draw_plot(pred_img, seg)
     return fig
 
+# Create the Gradio interface for the sepia function
 demo = gr.Interface(fn=sepia,
                     inputs=gr.Image(shape=(400, 600)),
                     outputs=['plot'],
                     examples=["city-1.jpg", "city-2.jpg", "city-3.jpg"],
                     allow_flagging='never')
 
+# Launch the Gradio interface
 demo.launch()
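For reference, the post-processing that the new comments describe (channels-last transpose, resize to the input size, per-pixel argmax, palette lookup) can be tried in isolation. The snippet below is a minimal sketch, not part of the commit: it substitutes random logits for outputs.logits and a random palette for colormap, and the class count and target size are made-up values.

import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the model output and palette (not from app.py).
num_classes = 19                                        # e.g. Cityscapes evaluation classes
palette = np.random.randint(0, 256, size=(num_classes, 3), dtype=np.uint8)
logits = tf.random.normal((1, num_classes, 128, 128))   # (batch, classes, H, W), like outputs.logits
target_size = (400, 600)                                # (height, width) of the original image

# Same steps as in sepia(): move channels last, resize to the image size, take argmax per pixel.
logits = tf.transpose(logits, [0, 2, 3, 1])
logits = tf.image.resize(logits, target_size)
seg = tf.math.argmax(logits, axis=-1)[0].numpy()

# Palette lookup, equivalent to label_to_color_image() / the color_seg loop.
color_seg = palette[seg]                                # (400, 600, 3) uint8 color mask
print(color_seg.shape, color_seg.dtype)

The fancy-indexing lookup at the end produces the same result as the "for label, color in enumerate(colormap)" loop in sepia, just vectorized.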