ondrejbiza commited on
Commit
9fbcce5
1 Parent(s): f1a8131

Relative scales.

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -80,8 +80,8 @@ with gr.Blocks() as demo:
80
 
81
  local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))
82
 
83
- local_orig_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
84
- local_orig_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
85
 
86
  local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
87
  local_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
@@ -96,7 +96,7 @@ with gr.Blocks() as demo:
96
 
97
  with gr.Row():
98
 
99
- with gr.Column(min_width=600):
100
 
101
  with gr.Row():
102
  with gr.Column():
@@ -108,10 +108,10 @@ with gr.Blocks() as demo:
108
  gr_slot_slider = gr.Slider(1, 11, value=1, step=1, label="Slot Index",
109
  info="Change slot index too see the segmentation mask, position and scale of each slot.")
110
 
111
- gr_y_slider = gr.Slider(-1, 1, value=0, step=0.01, label="x")
112
- gr_x_slider = gr.Slider(-1, 1, value=0, step=0.01, label="y")
113
- gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
114
- gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")
115
 
116
  with gr.Row():
117
  with gr.Column():
@@ -139,13 +139,13 @@ with gr.Blocks() as demo:
139
  scale = slots_[0, 0, :, -2:]
140
 
141
  return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
142
- float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1]), probs, slots, pos, scale, pos, scale
143
 
144
  gr_choose_image.change(
145
  fn=update_image_and_segmentation,
146
  inputs=[gr_choose_image, gr_slot_slider],
147
- outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider,
148
- local_probs, local_slots, local_pos, local_scale, local_orig_pos, local_orig_scale]
149
  )
150
 
151
  def update_sliders(idx, local_probs, local_pos, local_scale):
@@ -200,10 +200,10 @@ with gr.Blocks() as demo:
200
  outputs=local_scale
201
  )
202
 
203
- def render(idx, local_slots, local_pos, local_scale):
204
  idx = idx - 1
205
 
206
- slots = np.concatenate([local_slots, local_pos, local_scale], axis=-1)
207
  slots = jnp.array(slots)
208
 
209
  out = decoder_model.apply(
@@ -219,18 +219,18 @@ with gr.Blocks() as demo:
219
 
220
  gr_button_render.click(
221
  fn=render,
222
- inputs=[gr_slot_slider, local_slots, local_pos, local_scale],
223
  outputs=[gr_image_1, gr_image_2, local_probs]
224
  )
225
 
226
- def reset(idx, local_orig_pos, local_orig_scale):
227
  idx = idx - 1
228
- return np.copy(local_orig_pos), np.copy(local_orig_scale), float(local_orig_pos[idx, 0]), \
229
- float(local_orig_pos[idx, 1]), float(local_orig_scale[idx, 0]), float(local_orig_scale[idx, 1])
230
 
231
  gr_button_reset.click(
232
  fn=reset,
233
- inputs=[gr_slot_slider, local_orig_pos, local_orig_scale],
234
  outputs=[local_pos, local_scale, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
235
  )
236
 
 
80
 
81
  local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))
82
 
83
+ local_orig_pos = gr.State(np.ones((11, 2), dtype=np.float32))
84
+ local_orig_scale = gr.State(np.ones((11, 2), dtype=np.float32))
85
 
86
  local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
87
  local_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
 
96
 
97
  with gr.Row():
98
 
99
+ with gr.Column():
100
 
101
  with gr.Row():
102
  with gr.Column():
 
108
  gr_slot_slider = gr.Slider(1, 11, value=1, step=1, label="Slot Index",
109
  info="Change slot index too see the segmentation mask, position and scale of each slot.")
110
 
111
+ gr_y_slider = gr.Slider(-1, 1, value=0, step=0.01, label="X")
112
+ gr_x_slider = gr.Slider(-1, 1, value=0, step=0.01, label="Y")
113
+ gr_sy_slider = gr.Slider(0.5, 1.5, value=1., step=0.1, label="Width Multiplier")
114
+ gr_sx_slider = gr.Slider(0.5, 1.5, value=1., step=0.1, label="Height Multiplier")
115
 
116
  with gr.Row():
117
  with gr.Column():
 
139
  scale = slots_[0, 0, :, -2:]
140
 
141
  return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
142
+ float(pos[idx, 1]), probs, slots, pos, np.ones((11, 2), dtype=np.float32), pos, scale
143
 
144
  gr_choose_image.change(
145
  fn=update_image_and_segmentation,
146
  inputs=[gr_choose_image, gr_slot_slider],
147
+ outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, local_probs,
148
+ local_slots, local_pos, local_scale, local_orig_pos, local_orig_scale]
149
  )
150
 
151
  def update_sliders(idx, local_probs, local_pos, local_scale):
 
200
  outputs=local_scale
201
  )
202
 
203
+ def render(idx, local_slots, local_pos, local_scale, local_orig_scale):
204
  idx = idx - 1
205
 
206
+ slots = np.concatenate([local_slots, local_pos, local_scale * local_orig_scale], axis=-1)
207
  slots = jnp.array(slots)
208
 
209
  out = decoder_model.apply(
 
219
 
220
  gr_button_render.click(
221
  fn=render,
222
+ inputs=[gr_slot_slider, local_slots, local_pos, local_scale, local_orig_scale],
223
  outputs=[gr_image_1, gr_image_2, local_probs]
224
  )
225
 
226
+ def reset(idx, local_orig_pos):
227
  idx = idx - 1
228
+ return np.copy(local_orig_pos), np.ones((11, 2), dtype=np.float32), float(local_orig_pos[idx, 0]), \
229
+ float(local_orig_pos[idx, 1]), 1., 1.
230
 
231
  gr_button_reset.click(
232
  fn=reset,
233
+ inputs=[gr_slot_slider, local_orig_pos],
234
  outputs=[local_pos, local_scale, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
235
  )
236