# isa / app.py
# ondrejbiza's picture
# Update app, add CLEVR images.
# 65d6890
import functools
import os
from absl import flags
from clu import checkpoint
import gradio as gr
import jax
import jax.numpy as jnp
from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
from invariant_slot_attention.lib import input_pipeline
from invariant_slot_attention.lib import utils
def load_model(config, checkpoint_dir="clevr_isa_ts/checkpoints-0"):
    """Build the model from `config`, restore its checkpoint and smoke-test it.

    The model is first initialised on an all-ones dummy video so that a
    `utils.TrainState` with the right structure exists. That state is then
    passed to `restore_or_initialize` on a `checkpoint.MultihostCheckpoint`
    rooted at `checkpoint_dir`. Finally a single forward pass with
    `train=False` is run on the dummy video and the keys of its output are
    printed.

    Args:
      config: Experiment config; only `config.model` is read here.
      checkpoint_dir: Directory handed to `checkpoint.MultihostCheckpoint`.
        Defaults to the path this function previously hard-coded.

    Returns:
      A `(model, state)` tuple: the module built by
      `utils.build_model_from_config` and the `utils.TrainState` returned by
      `restore_or_initialize`.
    """
    rng = jax.random.PRNGKey(42)
    # The second key (originally bound to `data_rng`) is never used; the split
    # itself is kept so that every key derived below stays unchanged.
    rng, _ = jax.random.split(rng)

    # Initialize model
    model = utils.build_model_from_config(config.model)

    # One dummy input shared by initialisation and the smoke-test forward pass.
    # presumably (batch, time, height, width, channels) — TODO confirm.
    dummy_video = jnp.ones([1, 1, 128, 128, 3], jnp.float32)

    def _init_variables(init_key):
        """Run `model.init` on the dummy video; return (state_vars, params)."""
        # The first sub-key is discarded, exactly as in the original code.
        _, state_init_key, params_key, dropout_key = jax.random.split(
            init_key, num=4)
        initial_vars = model.init(
            {
                "params": params_key,
                "state_init": state_init_key,
                "dropout": dropout_key,
            },
            video=dummy_video,
            conditioning=None,
            padding_mask=jnp.ones(dummy_video.shape[:-1], jnp.int32),
        )
        # Split into state variables (e.g. for batchnorm stats) and model
        # params. Note that `pop()` on a FrozenDict performs a deep copy and
        # returns a (remaining_collections, popped_value) pair.
        state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error
        # Filter out intermediates (we don't want to store these in the
        # TrainState).
        state_vars = utils.filter_key_from_frozen_dict(
            state_vars, key="intermediates")
        return state_vars, initial_params

    state_vars, initial_params = _init_variables(rng)

    # No optimizer is needed for inference, hence `opt_state=None`.
    state = utils.TrainState(
        step=1,
        opt_state=None,
        params=initial_params,
        rng=rng,
        variables=state_vars,
    )

    # NOTE(review): if no checkpoint exists under `checkpoint_dir`, this
    # presumably falls back to the freshly initialised `state` (i.e. random
    # weights) — confirm against the clu checkpoint documentation.
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    state = ckpt.restore_or_initialize(state)

    # Smoke test: one eval-mode forward pass on the dummy input.
    _, apply_key = jax.random.split(rng, num=2)
    out = model.apply(
        {"params": state.params, **state.variables},
        video=dummy_video,
        rngs={"state_init": apply_key},
        train=False,
    )
    print(out.keys())

    # Hand the restored model back so callers can actually use it.
    return model, state
def greet(name):
    """Return the greeting "Hello <name>!" for the given string `name`."""
    pieces = ("Hello ", name, "!")
    return "".join(pieces)
# Build the model and restore its checkpoint once at import time. The result
# is not wired into the interface below, which still serves `greet`.
load_model(get_config())

# Minimal text-in / text-out Gradio interface.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
demo.launch()