chojo12 commited on
Commit
d73e50a
1 Parent(s): 958a268

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -83,31 +83,29 @@ def sepia(input_img):
83
  logits = outputs.logits
84
 
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
- logits = tf.image.resize(logits, input_img.size[::-1])
 
 
87
  seg = tf.math.argmax(logits, axis=-1)[0]
88
 
89
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
 
 
90
  for label, color in enumerate(colormap):
91
  color_seg[seg.numpy() == label, :] = color
92
 
93
- # 이미지 + 마스크 표시
94
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
95
  pred_img = pred_img.astype(np.uint8)
96
 
97
  fig = draw_plot(pred_img, seg)
98
  return fig
99
 
100
- demo = gr.Interface(
101
- fn=sepia,
102
- inputs=gr.Image(shape=(400, 600), label="입력 이미지"),
103
- outputs=gr.Image(type="plot", label="Image Segmentation 결과"),
104
- examples=["cityscape_example1.jpeg"],
105
- theme="compact",
106
- title="Image Segmentation 결과",
107
- description="이미지를 업로드하고 Segmentation 결과를 확인하세요.",
108
- allow_flagging='never'
109
- )
110
-
111
- demo.launch()
112
 
113
 
 
 
83
  logits = outputs.logits
84
 
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
+ logits = tf.image.resize(
87
+ logits, input_img.size[::-1]
88
+ ) # We reverse the shape of `image` because `image.size` returns width and height.
89
  seg = tf.math.argmax(logits, axis=-1)[0]
90
 
91
+ color_seg = np.zeros(
92
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
93
+ ) # height, width, 3
94
  for label, color in enumerate(colormap):
95
  color_seg[seg.numpy() == label, :] = color
96
 
97
+ # Show image + mask
98
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
99
  pred_img = pred_img.astype(np.uint8)
100
 
101
  fig = draw_plot(pred_img, seg)
102
  return fig
103
 
104
+ demo = gr.Interface(fn=sepia,
105
+ inputs=gr.Image(shape=(400, 600)),
106
+ outputs=['plot'],
107
+ examples=["cityscape_example1.jpeg"],
108
+ allow_flagging='never')
 
 
 
 
 
 
 
109
 
110
 
111
+ demo.launch()