"""Gradio demo for an invariant-slot-attention model on CLEVR.

Restores a pretrained checkpoint (translation+scale equivariant config),
runs one forward pass as a smoke test, then serves a trivial greeting
Interface.  NOTE(review): the original file had all newlines stripped;
this is a faithful re-indentation of the same statements.
"""

import functools
import os

from absl import flags
from clu import checkpoint
import gradio as gr
import jax
import jax.numpy as jnp

from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
from invariant_slot_attention.lib import input_pipeline
from invariant_slot_attention.lib import utils


def load_model(config):
  """Build the model from `config`, restore its checkpoint, and smoke-test it.

  Args:
    config: ml_collections-style config; only `config.model` is read here.

  Side effects: reads checkpoint files from "clevr_isa_ts/checkpoints-0" and
  prints the keys of the model's output dict.  Returns None.
  """
  rng = jax.random.PRNGKey(42)
  rng, data_rng = jax.random.split(rng)

  # Initialize model
  model = utils.build_model_from_config(config.model)

  def init_model(rng):
    # Initialize parameters/state with a dummy single-frame 128x128 RGB video
    # (shape [batch, time, H, W, C] — inferred from the ones() call below).
    rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)

    init_conditioning = None
    init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
    initial_vars = model.init(
        {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
        video=init_inputs,
        conditioning=init_conditioning,
        padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))

    # Split into state variables (e.g. for batchnorm stats) and model params.
    # Note that `pop()` on a FrozenDict performs a deep copy.
    state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error

    # Filter out intermediates (we don't want to store these in the TrainState).
    state_vars = utils.filter_key_from_frozen_dict(
        state_vars, key="intermediates")
    return state_vars, initial_params

  state_vars, initial_params = init_model(rng)

  # No optimizer needed for inference-only restore.
  opt_state = None

  state = utils.TrainState(
      step=1, opt_state=opt_state, params=initial_params, rng=rng,
      variables=state_vars)

  checkpoint_dir = "clevr_isa_ts/checkpoints-0"
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)

  # One forward pass on dummy input to verify the restored weights apply.
  init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
  rng, init_rng = jax.random.split(rng, num=2)

  out = model.apply(
      {"params": state.params, **state.variables},
      video=init_inputs,
      rngs={"state_init": init_rng},
      train=False)
  print(out.keys())


def greet(name):
  """Return a greeting for `name` (demo Interface callback)."""
  return "Hello " + name + "!"


# Module-level on purpose: Gradio deployment expects `demo` at import time.
load_model(get_config())
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()