"""Interactive Gradio demo for an Invariant Slot Attention model.

The script downloads a checkpoint from the Hugging Face Hub, runs the model on
a selected image to obtain per-slot features and segmentation masks, and lets
the user edit the last four components of a slot (exposed in the UI as a
position pair and a scale pair) before re-rendering through the decoder only.
"""

import functools
import os
from typing import Callable

from absl import flags
from clu import checkpoint
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import tensorflow as tf
from huggingface_hub import snapshot_download
from flax import linen as nn

from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
from invariant_slot_attention.lib import input_pipeline
from invariant_slot_attention.lib import preprocessing
from invariant_slot_attention.lib import utils


# Sizes used for the shared buffers, the dummy init input and the slot slider.
NUM_SLOTS = 11
SLOT_DIM = 64
IMAGE_SIZE = 128


def load_model(config, checkpoint_dir):
    """Build the model described by `config` and restore it from a checkpoint.

    Args:
      config: Configuration object; only `config.model` is read here.
      checkpoint_dir: Directory handed to `checkpoint.MultihostCheckpoint`.

    Returns:
      A `(model, state, rng)` tuple: the built model, the `utils.TrainState`
      returned by the checkpoint restore, and the PRNG key that remains after
      the initial split.
    """
    rng = jax.random.PRNGKey(42)
    # The second key (originally bound to `data_rng`) is never used. The split
    # itself is kept so that `rng` has exactly the same value as before.
    rng, _ = jax.random.split(rng)

    # Initialize model
    model = utils.build_model_from_config(config.model)

    def init_model(rng):
        """Initialize variables with a dummy input; return (state_vars, params)."""
        _, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)

        init_conditioning = None
        init_inputs = jnp.ones([1, 1, IMAGE_SIZE, IMAGE_SIZE, 3], jnp.float32)
        initial_vars = model.init(
            {"params": model_rng, "state_init": init_rng,
             "dropout": dropout_rng},
            video=init_inputs, conditioning=init_conditioning,
            padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))

        # Split into state variables (e.g. for batchnorm stats) and model
        # params. Note that `pop()` on a FrozenDict performs a deep copy.
        state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error

        # Filter out intermediates (we don't want to store these in the
        # TrainState).
        state_vars = utils.filter_key_from_frozen_dict(
            state_vars, key="intermediates")
        return state_vars, initial_params

    state_vars, initial_params = init_model(rng)

    # Template state passed to `restore`; no optimizer state is created for
    # this inference-only script.
    # NOTE(review): `step=42` looks like a placeholder that the restore is
    # expected to replace - confirm.
    state = utils.TrainState(
        step=42, opt_state=None, params=initial_params, rng=rng,
        variables=state_vars)

    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    state = ckpt.restore(state)

    return model, state, rng


def load_image(name):
    """Load `images/{name}.png`, crop and resize it for the model.

    Args:
      name: Base name (without extension) of a PNG file in `images/`.

    Returns:
      A `(model_input, display_image)` pair. `model_input` is a float32 JAX
      array holding the first three channels divided by 255. `display_image`
      is the NumPy array of the cropped and resized image with all of its
      channels.
    """
    # Use a context manager so the underlying file handle is always released.
    with Image.open(f"images/{name}.png") as source:
        cropped = source.crop((64, 29, 64 + 192, 29 + 192))
        resized = cropped.resize((IMAGE_SIZE, IMAGE_SIZE))
        display_image = np.array(resized)

    # The division allocates a new array, so `display_image` is not aliased.
    model_input = jnp.array(display_image[:, :, :3] / 255., dtype=jnp.float32)
    return model_input, display_image


download_path = snapshot_download(repo_id="ondrejbiza/isa")
checkpoint_dir = os.path.join(download_path, "clevr_isa_ts", "checkpoints")

model, state, rng = load_model(get_config(), checkpoint_dir)

rng, init_rng = jax.random.split(rng, num=2)


class DecoderWrapper(nn.Module):
    """Flax module that instantiates the decoder and applies it to slots.

    Attributes:
      decoder: Zero-argument callable that returns the decoder module.
    """
    decoder: Callable[[], nn.Module]

    @nn.compact
    def __call__(self, slots, train=False):
        return self.decoder()(slots, train)


decoder_model = DecoderWrapper(decoder=model.decoder)

# Buffers shared by all UI callbacks. They are updated in place so that every
# handler sees the result of the most recent inference and slider edits.
slots = np.zeros((NUM_SLOTS, SLOT_DIM), dtype=np.float32)
pos = np.zeros((NUM_SLOTS, 2), dtype=np.float32)
scale = np.zeros((NUM_SLOTS, 2), dtype=np.float32)
probs = np.zeros((NUM_SLOTS, IMAGE_SIZE, IMAGE_SIZE), dtype=np.float32)


def _slot_index(idx):
    """Convert the 1-based slot slider value into a 0-based array index.

    The value is coerced with `int` because slider values may arrive as
    floats, which NumPy rejects as indices.
    """
    return int(idx) - 1


def _mask_image(idx):
    """Return the stored mask of 0-based slot `idx` as a uint8 image."""
    return (probs[idx] * 255).astype(np.uint8)


def _slot_view(idx):
    """Return the mask image and the four slider values of 0-based slot `idx`."""
    return (
        _mask_image(idx),
        float(pos[idx, 0]),
        float(pos[idx, 1]),
        float(scale[idx, 0]),
        float(scale[idx, 1]),
    )


def update_image_and_segmentation(name, idx):
    """Run the full model on image `name` and refresh the shared buffers.

    Args:
      name: Image base name as accepted by `load_image`.
      idx: 1-based index of the slot currently selected in the UI.

    Returns:
      The display image, followed by the mask image and the two position and
      two scale values of the selected slot.
    """
    idx = _slot_index(idx)
    img_input, img = load_image(name)
    out = model.apply(
        {"params": state.params, **state.variables},
        video=img_input[None, None],
        rngs={"state_init": init_rng},
        train=False)

    probs[:] = nn.softmax(
        out["outputs"]["segmentation_logits"][0, 0, :, :, :, 0], axis=0)

    # The last four components of each slot are kept separately so that the
    # sliders can edit them: two for position, then two for scale.
    slots_ = out["states"]
    slots[:] = slots_[0, 0, :, :-4]
    pos[:] = slots_[0, 0, :, -4: -2]
    scale[:] = slots_[0, 0, :, -2:]

    return (img, *_slot_view(idx))


def update_sliders(idx):
    """Return the mask image and slider values for the 1-based slot `idx`."""
    return _slot_view(_slot_index(idx))  # 1-indexing to 0-indexing


def update_pos_x(idx, val):
    """Store `val` as position component 0 of the 1-based slot `idx`."""
    pos[_slot_index(idx), 0] = val


def update_pos_y(idx, val):
    """Store `val` as position component 1 of the 1-based slot `idx`."""
    pos[_slot_index(idx), 1] = val


def update_scale_x(idx, val):
    """Store `val` as scale component 0 of the 1-based slot `idx`."""
    scale[_slot_index(idx), 0] = val


def update_scale_y(idx, val):
    """Store `val` as scale component 1 of the 1-based slot `idx`."""
    scale[_slot_index(idx), 1] = val


def render(idx):
    """Decode the (possibly edited) slots and refresh the stored masks.

    Args:
      idx: 1-based index of the slot whose mask should be shown.

    Returns:
      A `(image, mask)` pair of uint8 arrays: the decoded image clipped to
      [0, 1] before scaling, and the mask of the selected slot.
    """
    idx = _slot_index(idx)

    # Reassemble slots in the layout they were split from: features, position,
    # scale.
    slots_ = jnp.array(np.concatenate([slots, pos, scale], axis=-1))

    out = decoder_model.apply(
        {"params": state.params, **state.variables},
        slots=slots_[None, None],
        train=False
    )

    probs[:] = nn.softmax(out["segmentation_logits"][0, 0, :, :, :, 0], axis=0)
    image = np.clip(np.array(out["video"][0, 0]), 0, 1)
    return (image * 255).astype(np.uint8), _mask_image(idx)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr_choose_image = gr.Dropdown(
                [f"img{i}" for i in range(1, 9)],
                label="CLEVR Image",
                info="Start by picking an image from the CLEVR dataset."
            )
            gr_image_1 = gr.Image(type="numpy")
            gr_image_2 = gr.Image(type="numpy")

        with gr.Column():
            gr_slot_slider = gr.Slider(
                1, NUM_SLOTS, value=1, step=1, label="Slot")

            # NOTE(review): the variable names look inverted relative to the
            # labels (`gr_x_slider` is labeled "y" and drives component 0).
            # Labels and wiring are consistent with each other, which suggests
            # a (row, column) component order - confirm before renaming.
            gr_y_slider = gr.Slider(-1, 1, value=0, step=0.01, label="x")
            gr_x_slider = gr.Slider(-1, 1, value=0, step=0.01, label="y")
            gr_sy_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="width")
            gr_sx_slider = gr.Slider(0.01, 1, value=0.1, step=0.01, label="height")

            gr_button = gr.Button("Render")

    gr_choose_image.change(
        fn=update_image_and_segmentation,
        inputs=[gr_choose_image, gr_slot_slider],
        outputs=[gr_image_1, gr_image_2, gr_x_slider, gr_y_slider,
                 gr_sx_slider, gr_sy_slider]
    )

    gr_slot_slider.change(
        fn=update_sliders,
        inputs=gr_slot_slider,
        outputs=[gr_image_2, gr_x_slider, gr_y_slider, gr_sx_slider,
                 gr_sy_slider]
    )

    gr_x_slider.change(
        fn=update_pos_x,
        inputs=[gr_slot_slider, gr_x_slider]
    )
    gr_y_slider.change(
        fn=update_pos_y,
        inputs=[gr_slot_slider, gr_y_slider]
    )
    gr_sx_slider.change(
        fn=update_scale_x,
        inputs=[gr_slot_slider, gr_sx_slider]
    )
    gr_sy_slider.change(
        fn=update_scale_y,
        inputs=[gr_slot_slider, gr_sy_slider]
    )

    gr_button.click(
        fn=render,
        inputs=gr_slot_slider,
        outputs=[gr_image_1, gr_image_2]
    )

demo.launch()