# Hugging Face Space (status: Running)
import functools | |
import os | |
from absl import flags | |
from clu import checkpoint | |
import gradio as gr | |
import jax | |
import jax.numpy as jnp | |
from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config | |
from invariant_slot_attention.lib import input_pipeline | |
from invariant_slot_attention.lib import utils | |
def _init_model_variables(model, rng, video):
    """Initialize `model` on `video` and return `(state_vars, params)`.

    The `"intermediates"` collection is filtered out of `state_vars` because it
    should not be stored in the TrainState.
    """
    model_rng, state_init_rng, dropout_rng = jax.random.split(rng, num=3)
    initial_vars = model.init(
        {"params": model_rng, "state_init": state_init_rng, "dropout": dropout_rng},
        video=video,
        conditioning=None,
        padding_mask=jnp.ones(video.shape[:-1], jnp.int32))
    # Split into state variables (e.g. for batchnorm stats) and model params.
    # Note that `pop()` on a FrozenDict performs a deep copy.
    # NOTE(review): this two-value unpacking relies on `model.init` returning a
    # FrozenDict; newer Flax versions return a plain dict, whose `pop` returns
    # only the value — confirm against the pinned Flax version.
    state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error
    state_vars = utils.filter_key_from_frozen_dict(state_vars, key="intermediates")
    return state_vars, initial_params


def load_model(config, checkpoint_dir="clevr_isa_ts/checkpoints-0", *, seed=42):
    """Build the model from `config`, restore its weights and run a smoke test.

    Args:
        config: Experiment config (e.g. from `get_config()`); only
            `config.model` is read here.
        checkpoint_dir: Base directory handed to `checkpoint.MultihostCheckpoint`.
            NOTE(review): clu's MultihostCheckpoint appends `-{host_id}` to this
            base path, so the default probably resolves to
            `clevr_isa_ts/checkpoints-0-0` — confirm it matches the on-disk
            layout (the base is likely `clevr_isa_ts/checkpoints`).
        seed: Seed for the PRNG key used for initialization and the smoke test.

    Returns:
        A `(model, state)` tuple: the built model and the `utils.TrainState`
        holding the restored params and state variables.

    Raises:
        FileNotFoundError: If no checkpoint can be found to restore from.
    """
    rng = jax.random.PRNGKey(seed)
    # Split once into independent keys. Reusing a single key for several
    # splits (and storing it in the state) would correlate random streams.
    state_rng, init_rng, apply_rng = jax.random.split(rng, num=3)

    model = utils.build_model_from_config(config.model)
    # One dummy 5-D input is used for both initialization and the smoke test.
    # The axis layout is presumably (batch, time, height, width, channels)
    # — TODO confirm.
    dummy_video = jnp.ones([1, 1, 128, 128, 3], jnp.float32)

    state_vars, initial_params = _init_model_variables(model, init_rng, dummy_video)
    state = utils.TrainState(
        step=1, opt_state=None, params=initial_params, rng=state_rng,
        variables=state_vars)

    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    # When nothing is found, clu's restore_or_initialize falls back to the
    # freshly initialized (random) state. For a demo that would silently serve
    # an untrained model, so fail loudly instead.
    if not ckpt.get_latest_checkpoint_to_restore_from():
        raise FileNotFoundError(
            f"No checkpoint to restore from (base directory {checkpoint_dir!r}).")
    state = ckpt.restore_or_initialize(state)

    # Smoke test: run one inference pass with the restored weights.
    out = model.apply(
        {"params": state.params, **state.variables},
        video=dummy_video,
        rngs={"state_init": apply_rng},
        train=False)
    print(out.keys())
    return model, state
def greet(name):
    """Build the greeting shown by the demo, e.g. ``greet("World") -> "Hello World!"``."""
    return "".join(("Hello ", name, "!"))
# Build and restore the model once at startup. The result is not yet used by
# the UI below (`greet` is a placeholder).
load_model(get_config())

demo = gr.Interface(fn=greet, inputs="text", outputs="text")

# Guard the blocking server start so that importing this module (tests,
# `gradio app.py` reload mode) does not launch it; `python app.py` is unchanged.
if __name__ == "__main__":
    demo.launch()