File size: 8,126 Bytes
a560c26
1e7763d
a560c26
65d6890
1e7763d
a560c26
1e7763d
a560c26
 
9d5d768
 
a560c26
90e5776
a560c26
 
 
9d5d768
65d6890
a560c26
 
 
 
 
 
 
 
65d6890
a560c26
 
 
 
8530ae8
a560c26
 
 
 
 
 
 
 
 
 
 
65d6890
a560c26
8530ae8
a560c26
 
8530ae8
 
9d5d768
 
 
 
 
 
 
 
 
 
1e7763d
9d5d768
 
90e5776
 
9d5d768
 
 
 
 
1e7763d
9d5d768
 
 
 
 
 
 
 
f1a8131
1e7763d
f1a8131
1c5afc2
 
f1a8131
1e7763d
1c5afc2
f1a8131
1e7763d
9b6ff29
9d5d768
 
 
 
 
9b6ff29
 
9d5d768
9fbcce5
9b6ff29
c57a3a8
 
 
 
 
9b6ff29
 
 
 
9d5d768
9fbcce5
 
 
 
9d5d768
f1a8131
 
 
 
 
9d5d768
 
 
 
1e7763d
9d5d768
 
 
 
 
 
1e7763d
9b6ff29
 
 
1e7763d
 
 
 
9d5d768
9b6ff29
9fbcce5
9d5d768
 
 
 
9fbcce5
 
9d5d768
 
1e7763d
9d5d768
1e7763d
 
9d5d768
911983c
9d5d768
1e7763d
9d5d768
 
 
1e7763d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d5d768
911983c
9d5d768
1e7763d
 
9d5d768
911983c
9d5d768
1e7763d
 
9d5d768
911983c
9d5d768
1e7763d
 
9d5d768
911983c
9d5d768
1e7763d
 
9d5d768
 
9fbcce5
9d5d768
 
9fbcce5
1e7763d
9d5d768
 
 
1e7763d
9d5d768
 
 
1e7763d
9d5d768
 
1e7763d
9d5d768
f1a8131
9d5d768
9fbcce5
1e7763d
9d5d768
a560c26
9fbcce5
f1a8131
9fbcce5
 
f1a8131
 
 
9fbcce5
f1a8131
 
 
a560c26
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import os
from typing import Callable

from clu import checkpoint
from flax import linen as nn
import gradio as gr
from huggingface_hub import snapshot_download
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image

from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale_v2 import get_config
from invariant_slot_attention.lib import utils


def load_model(config, checkpoint_dir, *, ckpt_name="ckpt-0", seed=42):
  """Build the model from `config` and restore its weights from a checkpoint.

  The model is first initialised on a dummy input to obtain parameter and
  state-variable trees of the right structure. Those trees are wrapped in a
  `utils.TrainState`, which is then passed to `clu.checkpoint` to be restored
  from `<checkpoint_dir>/<ckpt_name>`.

  Args:
    config: Experiment config; only `config.model` is read here.
    checkpoint_dir: Directory containing the checkpoint files.
    ckpt_name: Name of the checkpoint inside `checkpoint_dir` (keyword-only,
      defaults to the previously hard-coded "ckpt-0").
    seed: Seed for the PRNG key used for initialisation (keyword-only,
      defaults to the previously hard-coded 42).

  Returns:
    A `(model, state, rng)` tuple: the Flax model, the restored train state
    and the PRNG key that was used for initialisation.
  """
  rng = jax.random.PRNGKey(seed)

  # Initialize model
  model = utils.build_model_from_config(config.model)

  def init_model(rng):
    # Only the three sub-keys are needed; the first split output is unused.
    _, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)

    # Dummy video of shape (1, 1, 128, 128, 3) -- presumably
    # (batch, frames, height, width, channels); TODO confirm against the model.
    init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
    initial_vars = model.init(
        {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
        video=init_inputs, conditioning=None,
        padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))

    # Split into state variables (e.g. for batchnorm stats) and model params.
    # Note that `pop()` on a FrozenDict performs a deep copy.
    state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error

    # Filter out intermediates (we don't want to store these in the TrainState).
    state_vars = utils.filter_key_from_frozen_dict(
        state_vars, key="intermediates")
    return state_vars, initial_params

  state_vars, initial_params = init_model(rng)

  # No optimizer is needed for inference, so `opt_state` is left as None.
  state = utils.TrainState(
      step=1, opt_state=None, params=initial_params, rng=rng,
      variables=state_vars)

  ckpt = checkpoint.Checkpoint(checkpoint_dir)
  state = ckpt.restore(
      state, checkpoint=os.path.join(checkpoint_dir, ckpt_name))

  return model, state, rng


def load_image(name, *, image_dir="images", crop_box=(64, 29, 64 + 192, 29 + 192),
               size=(128, 128)):
  """Load a demo CLEVR image as a float32 RGB array with values in [0, 1].

  The PNG at `<image_dir>/<name>.png` is cropped to `crop_box`
  (left, upper, right, lower), resized to `size` (width, height) and reduced
  to its RGB channels.

  Args:
    name: File stem of the image, e.g. "img1".
    image_dir: Directory holding the PNG files (keyword-only).
    crop_box: PIL crop box; the default is a 192x192 window at (64, 29).
    size: Output (width, height) passed to `Image.resize`.

  Returns:
    A `jnp.float32` array of shape (height, width, 3).
  """
  path = os.path.join(image_dir, f"{name}.png")
  # Open the file as a context manager so the handle is always closed; the
  # resized image is a new in-memory image independent of the source file.
  with Image.open(path) as src:
    img = src.crop(crop_box).resize(size)
  # convert("RGB") after resizing: for RGBA input it only drops the alpha
  # channel (identical to slicing `[:, :, :3]`), and it also handles
  # grayscale/palette PNGs, which would otherwise yield a 2-D array.
  img = np.array(img.convert("RGB")) / 255.
  return jnp.array(img, dtype=jnp.float32)


# Download only the `clevr_isa_ts_v2` checkpoint files from the
# `ondrejbiza/isa` Hugging Face Hub repository, then restore the model.
download_path = snapshot_download(
    repo_id="ondrejbiza/isa",
    allow_patterns="clevr_isa_ts_v2*",
)
checkpoint_dir = os.path.join(download_path, "clevr_isa_ts_v2")

model, state, rng = load_model(get_config(), checkpoint_dir)

# `init_rng` is the "state_init" RNG used when the full model is applied to
# an image; the re-bound `rng` is not used again in this file.
rng, init_rng = jax.random.split(rng, 2)


class DecoderWrapper(nn.Module):
  """Run only the decoder of the slot-attention model on a set of slots.

  `decoder` is a zero-argument factory for the decoder module (here
  `model.decoder`). It is instantiated inside the compact `__call__`.
  NOTE(review): the demo applies this wrapper with the full model's
  variables, which presumably relies on the decoder submodule receiving the
  same auto-generated name as inside the full model -- confirm.
  """

  decoder: Callable[[], nn.Module]

  @nn.compact
  def __call__(self, slots, train=False):
    decode = self.decoder()
    return decode(slots, train)


decoder_model = DecoderWrapper(decoder=model.decoder)

with gr.Blocks() as demo:

    # Per-session state with one row per slot (11 rows, matching the range of
    # the "Slot Index" slider below).

    # Slot feature vectors (the trailing position/scale entries removed).
    local_slots = gr.State(np.zeros((11, 64), dtype=np.float32))

    # Positions and scales inferred by the model for the current image;
    # these are the values "Reset" goes back to.
    local_orig_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
    local_orig_scale = gr.State(np.zeros((11, 2), dtype=np.float32))

    # User-edited positions and scale multipliers (the multipliers are applied
    # to `local_orig_scale` when rendering).
    local_pos = gr.State(np.zeros((11, 2), dtype=np.float32))
    local_scale = gr.State(np.ones((11, 2), dtype=np.float32))

    # Per-slot segmentation probabilities for a 128x128 image.
    local_probs = gr.State(np.zeros((11, 128, 128), dtype=np.float32))

    with gr.Row():
        gr_choose_image = gr.Dropdown(
            [f"img{i}" for i in range(1, 9)], label="CLEVR Image",
            info="Start by picking an image from the CLEVR dataset.",
        )

    with gr.Row():

        with gr.Column():
            with gr.Row():
                with gr.Column():
                    gr_image_1 = gr.Image(type="numpy", shape=(112, 112), source="canvas", label="Decoding")
                with gr.Column():
                    gr_image_2 = gr.Image(type="numpy", shape=(112, 112), source="canvas", label="Segmentation")

        with gr.Column():
            gr_slot_slider = gr.Slider(
                1, 11, value=1, step=1, label="Slot Index",
                info="Change the slot index to see the segmentation mask, position and scale of each slot.")

            # Variable names follow the column order of the position/scale
            # arrays (`x` <-> column 0, `y` <-> column 1), while the labels name
            # image axes. Column 0 is treated as the vertical axis, so
            # `gr_x_slider` is deliberately labelled "Y" (and `gr_sx_slider`
            # "Height"), and vice versa.
            gr_y_slider = gr.Slider(-1, 1, value=0, step=0.01, label="X")
            gr_x_slider = gr.Slider(-1, 1, value=0, step=0.01, label="Y")
            gr_sy_slider = gr.Slider(0.5, 1.5, value=1., step=0.1, label="Width Multiplier")
            gr_sx_slider = gr.Slider(0.5, 1.5, value=1., step=0.1, label="Height Multiplier")

            with gr.Row():
                with gr.Column():
                    gr_button_render = gr.Button("Render", variant="primary", info="Render a new image with altered positions and scales.")
                with gr.Column():
                    gr_button_reset = gr.Button("Reset", info="Reset slot statistics.")

    def update_image_and_segmentation(name, idx):
      """Run the full model on the chosen image and reset the editing state.

      Args:
        name: Image name from the dropdown (e.g. "img1").
        idx: 1-based slot index from the slot slider.

      Returns:
        Values for, in order: the decoded image, the selected slot's mask, the
        two position sliders, the two scale sliders (reset to 1.0), and the
        session states probs, slots, pos, scale multipliers (all ones),
        original pos and original scale.
      """
      idx = idx - 1  # 1-indexing to 0-indexing

      img_input = load_image(name)
      out = model.apply(
        {"params": state.params, **state.variables},
        video=img_input[None, None],
        rngs={"state_init": init_rng},
        train=False)

      # Softmax over the slot axis gives per-slot mask probabilities.
      probs = np.array(nn.softmax(out["outputs"]["segmentation_logits"][0, 0, :, :, :, 0], axis=0))
      img = np.clip(np.array(out["outputs"]["video"][0, 0]), 0, 1)

      # Each slot state ends with 2 position and 2 scale entries.
      slots_ = np.array(out["states"])
      slots = slots_[0, 0, :, :-4]
      pos = slots_[0, 0, :, -4:-2]
      scale = slots_[0, 0, :, -2:]

      # `local_pos` is edited by the slider handlers, so it must not share
      # memory with `local_orig_pos`; otherwise "Reset" would restore the
      # edited positions instead of the inferred ones.
      return ((img * 255).astype(np.uint8), (probs[idx] * 255).astype(np.uint8),
              float(pos[idx, 0]), float(pos[idx, 1]), 1., 1.,
              probs, slots, pos.copy(), np.ones((11, 2), dtype=np.float32),
              pos.copy(), scale)

    # The scale sliders are included so they are reset to 1.0 together with
    # the `local_scale` multipliers, instead of showing stale values.
    gr_choose_image.change(
       fn=update_image_and_segmentation,
       inputs=[gr_choose_image, gr_slot_slider],
       outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider,
                local_probs, local_slots, local_pos, local_scale, local_orig_pos, local_orig_scale]
    )

    def update_sliders(idx, local_probs, local_pos, local_scale):
      """Return the mask, position and scale multipliers of the chosen slot.

      `idx` is the 1-based value of the slot slider. The result is
      (uint8 mask, position col 0, position col 1, scale col 0, scale col 1).
      """
      row = idx - 1  # slider is 1-based, arrays are 0-based
      mask = (local_probs[row] * 255).astype(np.uint8)
      pos_a, pos_b = float(local_pos[row, 0]), float(local_pos[row, 1])
      scale_a, scale_b = float(local_scale[row, 0]), float(local_scale[row, 1])
      return mask, pos_a, pos_b, scale_a, scale_b

    # Selecting another slot refreshes the mask view and all four sliders.
    gr_slot_slider.release(
        update_sliders,
        inputs=[gr_slot_slider, local_probs, local_pos, local_scale],
        outputs=[gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider],
    )

    def _with_entry(arr, row, col, val):
        """Return a copy of `arr` with `arr[row, col]` set to `val`.

        The state array is copied instead of modified in place, so arrays
        shared with other session states (e.g. the original positions used by
        "Reset") are never changed as a side effect.
        """
        out = np.array(arr, copy=True)
        out[row, col] = val
        return out

    def update_pos_x(idx, val, local_pos):
        """Set column 0 of the selected slot's position (`idx` is 1-based)."""
        return _with_entry(local_pos, idx - 1, 0, val)

    def update_pos_y(idx, val, local_pos):
        """Set column 1 of the selected slot's position (`idx` is 1-based)."""
        return _with_entry(local_pos, idx - 1, 1, val)

    def update_scale_x(idx, val, local_scale):
        """Set column 0 of the selected slot's scale multiplier (1-based `idx`)."""
        return _with_entry(local_scale, idx - 1, 0, val)

    def update_scale_y(idx, val, local_scale):
        """Set column 1 of the selected slot's scale multiplier (1-based `idx`)."""
        return _with_entry(local_scale, idx - 1, 1, val)

    # Releasing a position/scale slider writes its value for the selected
    # slot into the matching session state. Registration order is unchanged.
    for _slider, _handler, _state in (
        (gr_x_slider, update_pos_x, local_pos),
        (gr_y_slider, update_pos_y, local_pos),
        (gr_sx_slider, update_scale_x, local_scale),
        (gr_sy_slider, update_scale_y, local_scale),
    ):
        _slider.release(
            fn=_handler,
            inputs=[gr_slot_slider, _slider, _state],
            outputs=_state,
        )
    del _slider, _handler, _state

    def render(idx, local_slots, local_pos, local_scale, local_orig_scale):
      """Decode the (possibly edited) slots into a new image and segmentation.

      Each slot is re-assembled as [slot features, position,
      scale multiplier * original scale] and passed through the decoder only.

      Args:
        idx: 1-based index of the slot whose mask is returned.
        local_slots: Slot feature vectors.
        local_pos: Edited positions.
        local_scale: Scale multipliers.
        local_orig_scale: Scales inferred for the current image.

      Returns:
        (decoded image as uint8, mask of slot `idx` as uint8, all slot probs).
      """
      slot = idx - 1
      effective_scale = local_scale * local_orig_scale
      full_slots = jnp.array(
          np.concatenate([local_slots, local_pos, effective_scale], axis=-1))

      decoded = decoder_model.apply(
        {"params": state.params, **state.variables},
        slots=full_slots[None, None],
        train=False,
      )

      # Softmax over the slot axis gives per-slot mask probabilities.
      logits = decoded["segmentation_logits"][0, 0, :, :, :, 0]
      probs = np.array(nn.softmax(logits, axis=0))
      image = np.clip(np.array(decoded["video"][0, 0]), 0, 1)
      return (image * 255).astype(np.uint8), (probs[slot] * 255).astype(np.uint8), probs

    # "Render" decodes the current slot features with the edited positions
    # and scales.
    gr_button_render.click(
        render,
        inputs=[gr_slot_slider, local_slots, local_pos, local_scale, local_orig_scale],
        outputs=[gr_image_1, gr_image_2, local_probs],
    )

    def reset(idx, local_orig_pos):
       """Restore the inferred positions and unit scale multipliers.

       Returns new values for the position state, the scale-multiplier state,
       the two position sliders (for the 1-based slot `idx`) and the two scale
       sliders.
       """
       row = idx - 1
       restored_pos = np.copy(local_orig_pos)
       unit_scale = np.ones((11, 2), dtype=np.float32)
       first = float(local_orig_pos[row, 0])
       second = float(local_orig_pos[row, 1])
       return restored_pos, unit_scale, first, second, 1., 1.

    # "Reset" restores positions/scale multipliers and the sliders for the
    # currently selected slot.
    gr_button_reset.click(
        reset,
        inputs=[gr_slot_slider, local_orig_pos],
        outputs=[local_pos, local_scale, gr_x_slider, gr_y_slider, gr_sx_slider, gr_sy_slider],
    )

if __name__ == "__main__":
  # Start the web server only when run as a script, so importing this module
  # (e.g. by tooling that looks up `demo`) does not launch it.
  demo.launch()