ondrejbiza commited on
Commit
65d6890
1 Parent(s): 94c5fd1

Update app, add CLEVR images.

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
app.py CHANGED
@@ -2,6 +2,7 @@ import functools
2
  import os
3
 
4
  from absl import flags
 
5
  import gradio as gr
6
  import jax
7
  import jax.numpy as jnp
@@ -12,6 +13,7 @@ from invariant_slot_attention.lib import utils
12
 
13
 
14
  def load_model(config):
 
15
  rng, data_rng = jax.random.split(rng)
16
 
17
  # Initialize model
@@ -21,9 +23,7 @@ def load_model(config):
21
  rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
22
 
23
  init_conditioning = None
24
- init_inputs = jnp.ones(
25
- [1] + list(train_ds.element_spec["video"].shape)[2:],
26
- jnp.float32)
27
  initial_vars = model.init(
28
  {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
29
  video=init_inputs, conditioning=init_conditioning,
@@ -40,27 +40,29 @@ def load_model(config):
40
 
41
  state_vars, initial_params = init_model(rng)
42
 
43
- learning_rate_fn = lr_schedules.get_learning_rate_fn(config)
44
- tx = optimizers.get_optimizer(
45
- config.optimizer_configs, learning_rate_fn, params=initial_params)
46
-
47
- opt_state = tx.init(initial_params)
48
-
49
  state = utils.TrainState(
50
  step=1, opt_state=opt_state, params=initial_params, rng=rng,
51
  variables=state_vars)
52
 
53
- loss_fn = functools.partial(
54
- losses.compute_full_loss, loss_config=config.losses)
55
-
56
- checkpoint_dir = os.path.join(workdir, "checkpoints")
57
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
58
  state = ckpt.restore_or_initialize(state)
59
 
 
 
 
 
 
 
 
 
 
60
 
61
  def greet(name):
62
  return "Hello " + name + "!"
63
 
64
 
 
65
  demo = gr.Interface(fn=greet, inputs="text", outputs="text")
66
  demo.launch()
2
  import os
3
 
4
  from absl import flags
5
+ from clu import checkpoint
6
  import gradio as gr
7
  import jax
8
  import jax.numpy as jnp
13
 
14
 
15
  def load_model(config):
16
+ rng = jax.random.PRNGKey(42)
17
  rng, data_rng = jax.random.split(rng)
18
 
19
  # Initialize model
23
  rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
24
 
25
  init_conditioning = None
26
+ init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
 
 
27
  initial_vars = model.init(
28
  {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
29
  video=init_inputs, conditioning=init_conditioning,
40
 
41
  state_vars, initial_params = init_model(rng)
42
 
43
+ opt_state = None
 
 
 
 
 
44
  state = utils.TrainState(
45
  step=1, opt_state=opt_state, params=initial_params, rng=rng,
46
  variables=state_vars)
47
 
48
+ checkpoint_dir = "clevr_isa_ts/checkpoints-0"
 
 
 
49
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
50
  state = ckpt.restore_or_initialize(state)
51
 
52
+ init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
53
+ rng, init_rng = jax.random.split(rng, num=2)
54
+ out = model.apply(
55
+ {"params": state.params, **state.variables},
56
+ video=init_inputs,
57
+ rngs={"state_init": init_rng},
58
+ train=False)
59
+ print(out.keys())
60
+
61
 
62
  def greet(name):
63
  return "Hello " + name + "!"
64
 
65
 
66
+ load_model(get_config())
67
  demo = gr.Interface(fn=greet, inputs="text", outputs="text")
68
  demo.launch()
images/img1.png ADDED
images/img2.png ADDED
images/img3.png ADDED
images/img4.png ADDED
images/img5.png ADDED
images/img6.png ADDED
images/img7.png ADDED
images/img8.png ADDED