ondrejbiza commited on
Commit
9d5d768
1 Parent(s): 65d6890

V1 works locally.

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. app.py +149 -21
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
app.py CHANGED
@@ -6,13 +6,18 @@ from clu import checkpoint
6
  import gradio as gr
7
  import jax
8
  import jax.numpy as jnp
 
 
 
 
9
 
10
  from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
11
  from invariant_slot_attention.lib import input_pipeline
 
12
  from invariant_slot_attention.lib import utils
13
 
14
 
15
- def load_model(config):
16
  rng = jax.random.PRNGKey(42)
17
  rng, data_rng = jax.random.split(rng)
18
 
@@ -42,27 +47,150 @@ def load_model(config):
42
 
43
  opt_state = None
44
  state = utils.TrainState(
45
- step=1, opt_state=opt_state, params=initial_params, rng=rng,
46
  variables=state_vars)
47
 
48
- checkpoint_dir = "clevr_isa_ts/checkpoints-0"
49
- ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
50
- state = ckpt.restore_or_initialize(state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
53
- rng, init_rng = jax.random.split(rng, num=2)
54
- out = model.apply(
55
- {"params": state.params, **state.variables},
56
- video=init_inputs,
57
- rngs={"state_init": init_rng},
58
- train=False)
59
- print(out.keys())
60
-
61
-
62
- def greet(name):
63
- return "Hello " + name + "!"
64
-
65
-
66
- load_model(get_config())
67
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
68
  demo.launch()
6
  import gradio as gr
7
  import jax
8
  import jax.numpy as jnp
9
+ import numpy as np
10
+ from PIL import Image
11
+ import tensorflow as tf
12
+ from huggingface_hub import snapshot_download
13
 
14
  from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
15
  from invariant_slot_attention.lib import input_pipeline
16
+ from invariant_slot_attention.lib import preprocessing
17
  from invariant_slot_attention.lib import utils
18
 
19
 
20
+ def load_model(config, checkpoint_dir):
21
  rng = jax.random.PRNGKey(42)
22
  rng, data_rng = jax.random.split(rng)
23
 
47
 
48
  opt_state = None
49
  state = utils.TrainState(
50
+ step=42, opt_state=opt_state, params=initial_params, rng=rng,
51
  variables=state_vars)
52
 
53
+ ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
54
+ state = ckpt.restore(state)
55
+
56
+ return model, state, rng
57
+
58
+
59
+ def load_image(name):
60
+ img = Image.open(f"images/{name}.png")
61
+ img = img.crop((64, 29, 64 + 192, 29 + 192))
62
+ img = img.resize((128, 128))
63
+ img_ = np.array(img)
64
+ img = np.array(img)[:, :, :3] / 255.
65
+ img = jnp.array(img, dtype=jnp.float32)
66
+ return img, img_
67
+
68
+
69
+ download_path = snapshot_download(repo_id="ondrejbiza/isa")
70
+ checkpoint_dir = os.path.join(download_path, "clevr_isa_ts", "checkpoints")
71
+
72
+ model, state, rng = load_model(get_config(), checkpoint_dir)
73
+
74
+ rng, init_rng = jax.random.split(rng, num=2)
75
+
76
+ from flax import linen as nn
77
+ from typing import Callable
78
+ class DecoderWrapper(nn.Module):
79
+ decoder: Callable[[], nn.Module]
80
+ @nn.compact
81
+ def __call__(self, slots, train=False):
82
+ return self.decoder()(slots, train)
83
+ decoder_model = DecoderWrapper(decoder=model.decoder)
84
+
85
+ slots = np.zeros((11, 64), dtype=np.float32)
86
+ pos = np.zeros((11, 2), dtype=np.float32)
87
+ scale = np.zeros((11, 2), dtype=np.float32)
88
+ probs = np.zeros((11, 128, 128), dtype=np.float32)
89
+
90
+ with gr.Blocks() as demo:
91
+
92
+ with gr.Row():
93
+
94
+ with gr.Column():
95
+ gr_choose_image = gr.Dropdown(
96
+ [f"img{i}" for i in range(1, 9)], label="CLEVR Image", info="Start by a picking an image from the CLEVR dataset."
97
+ )
98
+ gr_image_1 = gr.Image(type="numpy")
99
+ gr_image_2 = gr.Image(type="numpy")
100
+
101
+ with gr.Column():
102
+ gr_slot_slider = gr.Slider(1, 11, value=1, step=1, label="Slot")
103
+
104
+ gr_y_slider = gr.Slider(-1, 1, value=0, step=0.01, label="x")
105
+ gr_x_slider = gr.Slider(-1, 1, value=0, step=0.01, label="y")
106
+ gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
107
+ gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")
108
+
109
+ gr_button = gr.Button("Render")
110
+
111
+ def update_image_and_segmentation(name, idx):
112
+ idx = idx - 1
113
+
114
+ img_input, img = load_image(name)
115
+ out = model.apply(
116
+ {"params": state.params, **state.variables},
117
+ video=img_input[None, None],
118
+ rngs={"state_init": init_rng},
119
+ train=False)
120
+
121
+ probs[:] = nn.softmax(out["outputs"]["segmentation_logits"][0, 0, :, :, :, 0], axis=0)
122
+ slots_ = out["states"]
123
+ slots[:] = slots_[0, 0, :, :-4]
124
+ pos[:] = slots_[0, 0, :, -4: -2]
125
+ scale[:] = slots_[0, 0, :, -2:]
126
+
127
+ return img, (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
128
+ float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1])
129
+
130
+ gr_choose_image.change(
131
+ fn=update_image_and_segmentation,
132
+ inputs=[gr_choose_image, gr_slot_slider],
133
+ outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
134
+ )
135
+
136
+ def update_sliders(idx):
137
+ idx = idx - 1 # 1-indexing to 0-indexing
138
+ return (probs[idx] * 255).astype(np.uint8), float(pos[idx, 0]), \
139
+ float(pos[idx, 1]), float(scale[idx, 0]), float(scale[idx, 1])
140
+
141
+ gr_slot_slider.change(
142
+ fn=update_sliders,
143
+ inputs=gr_slot_slider,
144
+ outputs=[gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider]
145
+ )
146
+
147
+ def update_pos_x(idx, val):
148
+ pos[idx - 1, 0] = val
149
+ def update_pos_y(idx, val):
150
+ pos[idx - 1, 1] = val
151
+ def update_scale_x(idx, val):
152
+ scale[idx - 1, 0] = val
153
+ def update_scale_y(idx, val):
154
+ scale[idx - 1, 1] = val
155
+
156
+ gr_x_slider.change(
157
+ fn=update_pos_x,
158
+ inputs=[gr_slot_slider, gr_x_slider]
159
+ )
160
+ gr_y_slider.change(
161
+ fn=update_pos_y,
162
+ inputs=[gr_slot_slider, gr_y_slider]
163
+ )
164
+ gr_sx_slider.change(
165
+ fn=update_scale_x,
166
+ inputs=[gr_slot_slider, gr_sx_slider]
167
+ )
168
+ gr_sy_slider.change(
169
+ fn=update_scale_y,
170
+ inputs=[gr_slot_slider, gr_sy_slider]
171
+ )
172
+
173
+ def render(idx):
174
+ idx = idx - 1
175
+
176
+ slots_ = np.concatenate([slots, pos, scale], axis=-1)
177
+ slots_ = jnp.array(slots_)
178
+
179
+ out = decoder_model.apply(
180
+ {"params": state.params, **state.variables},
181
+ slots=slots_[None, None],
182
+ train=False
183
+ )
184
+
185
+ probs[:] = nn.softmax(out["segmentation_logits"][0, 0, :, :, :, 0], axis=0)
186
+ image = np.array(out["video"][0, 0])
187
+ image = np.clip(image, 0, 1)
188
+ return (image * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8)
189
+
190
+ gr_button.click(
191
+ fn=render,
192
+ inputs=gr_slot_slider,
193
+ outputs=[gr_image_1, gr_image_2]
194
+ )
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  demo.launch()