ondrejbiza commited on
Commit
f1a8131
1 Parent(s): c57a3a8

Added a reset button.

Browse files
Files changed (1) hide show
  1. app.py +30 -6
app.py CHANGED
@@ -77,10 +77,15 @@ class DecoderWrapper(nn.Module):
77
  decoder_model = DecoderWrapper(decoder=model.decoder)
78
 
79
  with gr.Blocks() as demo:
80
-
81
  local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))
 
 
 
 
82
  local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
83
  local_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
 
84
  local_probs = gr.State(np.zeros((11, 128, 128), dtype=np.float32))
85
 
86
  with gr.Row():
@@ -91,7 +96,7 @@ with gr.Blocks() as demo:
91
 
92
  with gr.Row():
93
 
94
- with gr.Column():
95
 
96
  with gr.Row():
97
  with gr.Column():
@@ -108,7 +113,11 @@ with gr.Blocks() as demo:
108
  gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
109
  gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")
110
 
111
- gr_button = gr.Button("Render", info="Render a new image with altered positions and scales.")
 
 
 
 
112
 
113
  def update_image_and_segmentation(name, idx):
114
  idx = idx - 1
@@ -130,13 +139,13 @@ with gr.Blocks() as demo:
130
  scale = slots_[0, 0, :, -2:]
131
 
132
  return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
133
- float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1]), probs, slots, pos, scale
134
 
135
  gr_choose_image.change(
136
  fn=update_image_and_segmentation,
137
  inputs=[gr_choose_image, gr_slot_slider],
138
  outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider,
139
- local_probs, local_slots, local_pos, local_scale]
140
  )
141
 
142
  def update_sliders(idx, local_probs, local_pos, local_scale):
@@ -151,18 +160,22 @@ with gr.Blocks() as demo:
151
  )
152
 
153
  def update_pos_x(idx, val, local_pos):
 
154
  local_pos[idx - 1, 0] = val
155
  return local_pos
156
 
157
  def update_pos_y(idx, val, local_pos):
 
158
  local_pos[idx - 1, 1] = val
159
  return local_pos
160
 
161
  def update_scale_x(idx, val, local_scale):
 
162
  local_scale[idx - 1, 0] = val
163
  return local_scale
164
 
165
  def update_scale_y(idx, val, local_scale):
 
166
  local_scale[idx - 1, 1] = val
167
  return local_scale
168
 
@@ -204,10 +217,21 @@ with gr.Blocks() as demo:
204
  image = np.clip(image, 0, 1)
205
  return (image * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), probs
206
 
207
- gr_button.click(
208
  fn=render,
209
  inputs=[gr_slot_slider, local_slots, local_pos, local_scale],
210
  outputs=[gr_image_1, gr_image_2, local_probs]
211
  )
212
 
 
 
 
 
 
 
 
 
 
 
 
213
  demo.launch()
 
77
  decoder_model = DecoderWrapper(decoder=model.decoder)
78
 
79
  with gr.Blocks() as demo:
80
+
81
  local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))
82
+
83
+ local_orig_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
84
+ local_orig_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
85
+
86
  local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
87
  local_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
88
+
89
  local_probs = gr.State(np.zeros((11, 128, 128), dtype=np.float32))
90
 
91
  with gr.Row():
 
96
 
97
  with gr.Row():
98
 
99
+ with gr.Column(min_width=600):
100
 
101
  with gr.Row():
102
  with gr.Column():
 
113
  gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
114
  gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")
115
 
116
+ with gr.Row():
117
+ with gr.Column():
118
+ gr_button_render = gr.Button("Render", variant="primary", info="Render a new image with altered positions and scales.")
119
+ with gr.Column():
120
+ gr_button_reset = gr.Button("Reset", info="Reset slot statistics.")
121
 
122
  def update_image_and_segmentation(name, idx):
123
  idx = idx - 1
 
139
  scale = slots_[0, 0, :, -2:]
140
 
141
  return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
142
+ float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1]), probs, slots, pos, scale, pos, scale
143
 
144
  gr_choose_image.change(
145
  fn=update_image_and_segmentation,
146
  inputs=[gr_choose_image, gr_slot_slider],
147
  outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider,
148
+ local_probs, local_slots, local_pos, local_scale, local_orig_pos, local_orig_scale]
149
  )
150
 
151
  def update_sliders(idx, local_probs, local_pos, local_scale):
 
160
  )
161
 
162
  def update_pos_x(idx, val, local_pos):
163
+ local_pos = np.copy(local_pos)
164
  local_pos[idx - 1, 0] = val
165
  return local_pos
166
 
167
  def update_pos_y(idx, val, local_pos):
168
+ local_pos = np.copy(local_pos)
169
  local_pos[idx - 1, 1] = val
170
  return local_pos
171
 
172
  def update_scale_x(idx, val, local_scale):
173
+ local_scale = np.copy(local_scale)
174
  local_scale[idx - 1, 0] = val
175
  return local_scale
176
 
177
  def update_scale_y(idx, val, local_scale):
178
+ local_scale = np.copy(local_scale)
179
  local_scale[idx - 1, 1] = val
180
  return local_scale
181
 
 
217
  image = np.clip(image, 0, 1)
218
  return (image * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), probs
219
 
220
+ gr_button_render.click(
221
  fn=render,
222
  inputs=[gr_slot_slider, local_slots, local_pos, local_scale],
223
  outputs=[gr_image_1, gr_image_2, local_probs]
224
  )
225
 
226
+ def reset(idx, local_orig_pos, local_orig_scale):
227
+ idx = idx - 1
228
+ return np.copy(local_orig_pos), np.copy(local_orig_scale), float(local_orig_pos[idx, 0]), \
229
+ float(local_orig_pos[idx, 1]), float(local_orig_scale[idx, 0]), float(local_orig_scale[idx, 1])
230
+
231
+ gr_button_reset.click(
232
+ fn=reset,
233
+ inputs=[gr_slot_slider, local_orig_pos, local_orig_scale],
234
+ outputs=[local_pos, local_scale, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
235
+ )
236
+
237
  demo.launch()