Spaces:
Runtime error
Runtime error
ondrejbiza
commited on
Commit
•
f1a8131
1
Parent(s):
c57a3a8
Added a reset button.
Browse files
app.py
CHANGED
@@ -77,10 +77,15 @@ class DecoderWrapper(nn.Module):
|
|
77 |
decoder_model = DecoderWrapper(decoder=model.decoder)
|
78 |
|
79 |
with gr.Blocks() as demo:
|
80 |
-
|
81 |
local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))
|
|
|
|
|
|
|
|
|
82 |
local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
|
83 |
local_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
|
|
|
84 |
local_probs = gr.State(np.zeros((11, 128, 128), dtype=np.float32))
|
85 |
|
86 |
with gr.Row():
|
@@ -91,7 +96,7 @@ with gr.Blocks() as demo:
|
|
91 |
|
92 |
with gr.Row():
|
93 |
|
94 |
-
with gr.Column():
|
95 |
|
96 |
with gr.Row():
|
97 |
with gr.Column():
|
@@ -108,7 +113,11 @@ with gr.Blocks() as demo:
|
|
108 |
gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
|
109 |
gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")
|
110 |
|
111 |
-
|
|
|
|
|
|
|
|
|
112 |
|
113 |
def update_image_and_segmentation(name, idx):
|
114 |
idx = idx - 1
|
@@ -130,13 +139,13 @@ with gr.Blocks() as demo:
|
|
130 |
scale = slots_[0, 0, :, -2:]
|
131 |
|
132 |
return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
|
133 |
-
float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1]), probs, slots, pos, scale
|
134 |
|
135 |
gr_choose_image.change(
|
136 |
fn=update_image_and_segmentation,
|
137 |
inputs=[gr_choose_image, gr_slot_slider],
|
138 |
outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider,
|
139 |
-
local_probs, local_slots, local_pos, local_scale]
|
140 |
)
|
141 |
|
142 |
def update_sliders(idx, local_probs, local_pos, local_scale):
|
@@ -151,18 +160,22 @@ with gr.Blocks() as demo:
|
|
151 |
)
|
152 |
|
153 |
def update_pos_x(idx, val, local_pos):
|
|
|
154 |
local_pos[idx - 1, 0] = val
|
155 |
return local_pos
|
156 |
|
157 |
def update_pos_y(idx, val, local_pos):
|
|
|
158 |
local_pos[idx - 1, 1] = val
|
159 |
return local_pos
|
160 |
|
161 |
def update_scale_x(idx, val, local_scale):
|
|
|
162 |
local_scale[idx - 1, 0] = val
|
163 |
return local_scale
|
164 |
|
165 |
def update_scale_y(idx, val, local_scale):
|
|
|
166 |
local_scale[idx - 1, 1] = val
|
167 |
return local_scale
|
168 |
|
@@ -204,10 +217,21 @@ with gr.Blocks() as demo:
|
|
204 |
image = np.clip(image, 0, 1)
|
205 |
return (image * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), probs
|
206 |
|
207 |
-
|
208 |
fn=render,
|
209 |
inputs=[gr_slot_slider, local_slots, local_pos, local_scale],
|
210 |
outputs=[gr_image_1, gr_image_2, local_probs]
|
211 |
)
|
212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
demo.launch()
|
|
|
77 |
decoder_model = DecoderWrapper(decoder=model.decoder)
|
78 |
|
79 |
with gr.Blocks() as demo:
|
80 |
+
|
81 |
local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))
|
82 |
+
|
83 |
+
local_orig_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
|
84 |
+
local_orig_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
|
85 |
+
|
86 |
local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
|
87 |
local_scale = gr.State(np.zeros((11, 2), dtype=np.float32))
|
88 |
+
|
89 |
local_probs = gr.State(np.zeros((11, 128, 128), dtype=np.float32))
|
90 |
|
91 |
with gr.Row():
|
|
|
96 |
|
97 |
with gr.Row():
|
98 |
|
99 |
+
with gr.Column(min_width=600):
|
100 |
|
101 |
with gr.Row():
|
102 |
with gr.Column():
|
|
|
113 |
gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
|
114 |
gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")
|
115 |
|
116 |
+
with gr.Row():
|
117 |
+
with gr.Column():
|
118 |
+
gr_button_render = gr.Button("Render", variant="primary", info="Render a new image with altered positions and scales.")
|
119 |
+
with gr.Column():
|
120 |
+
gr_button_reset = gr.Button("Reset", info="Reset slot statistics.")
|
121 |
|
122 |
def update_image_and_segmentation(name, idx):
|
123 |
idx = idx - 1
|
|
|
139 |
scale = slots_[0, 0, :, -2:]
|
140 |
|
141 |
return (img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
|
142 |
+
float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1]), probs, slots, pos, scale, pos, scale
|
143 |
|
144 |
gr_choose_image.change(
|
145 |
fn=update_image_and_segmentation,
|
146 |
inputs=[gr_choose_image, gr_slot_slider],
|
147 |
outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider,
|
148 |
+
local_probs, local_slots, local_pos, local_scale, local_orig_pos, local_orig_scale]
|
149 |
)
|
150 |
|
151 |
def update_sliders(idx, local_probs, local_pos, local_scale):
|
|
|
160 |
)
|
161 |
|
162 |
def update_pos_x(idx, val, local_pos):
|
163 |
+
local_pos = np.copy(local_pos)
|
164 |
local_pos[idx - 1, 0] = val
|
165 |
return local_pos
|
166 |
|
167 |
def update_pos_y(idx, val, local_pos):
|
168 |
+
local_pos = np.copy(local_pos)
|
169 |
local_pos[idx - 1, 1] = val
|
170 |
return local_pos
|
171 |
|
172 |
def update_scale_x(idx, val, local_scale):
|
173 |
+
local_scale = np.copy(local_scale)
|
174 |
local_scale[idx - 1, 0] = val
|
175 |
return local_scale
|
176 |
|
177 |
def update_scale_y(idx, val, local_scale):
|
178 |
+
local_scale = np.copy(local_scale)
|
179 |
local_scale[idx - 1, 1] = val
|
180 |
return local_scale
|
181 |
|
|
|
217 |
image = np.clip(image, 0, 1)
|
218 |
return (image * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8), probs
|
219 |
|
220 |
+
gr_button_render.click(
|
221 |
fn=render,
|
222 |
inputs=[gr_slot_slider, local_slots, local_pos, local_scale],
|
223 |
outputs=[gr_image_1, gr_image_2, local_probs]
|
224 |
)
|
225 |
|
226 |
+
def reset(idx, local_orig_pos, local_orig_scale):
|
227 |
+
idx = idx - 1
|
228 |
+
return np.copy(local_orig_pos), np.copy(local_orig_scale), float(local_orig_pos[idx, 0]), \
|
229 |
+
float(local_orig_pos[idx, 1]), float(local_orig_scale[idx, 0]), float(local_orig_scale[idx, 1])
|
230 |
+
|
231 |
+
gr_button_reset.click(
|
232 |
+
fn=reset,
|
233 |
+
inputs=[gr_slot_slider, local_orig_pos, local_orig_scale],
|
234 |
+
outputs=[local_pos, local_scale, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
|
235 |
+
)
|
236 |
+
|
237 |
demo.launch()
|