ondrejbiza commited on
Commit
8530ae8
1 Parent(s): 9b6ff29

New model.

Browse files
Files changed (1) hide show
  1. app.py +5 -10
app.py CHANGED
@@ -1,19 +1,14 @@
1
- import functools
2
  import os
3
 
4
- from absl import flags
5
  from clu import checkpoint
6
  import gradio as gr
7
  import jax
8
  import jax.numpy as jnp
9
  import numpy as np
10
  from PIL import Image
11
- import tensorflow as tf
12
  from huggingface_hub import snapshot_download
13
 
14
  from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
15
- from invariant_slot_attention.lib import input_pipeline
16
- from invariant_slot_attention.lib import preprocessing
17
  from invariant_slot_attention.lib import utils
18
 
19
 
@@ -33,7 +28,7 @@ def load_model(config, checkpoint_dir):
33
  {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
34
  video=init_inputs, conditioning=init_conditioning,
35
  padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
36
-
37
  # Split into state variables (e.g. for batchnorm stats) and model params.
38
  # Note that `pop()` on a FrozenDict performs a deep copy.
39
  state_vars, initial_params = initial_vars.pop("params") # pytype: disable=attribute-error
@@ -47,11 +42,11 @@ def load_model(config, checkpoint_dir):
47
 
48
  opt_state = None
49
  state = utils.TrainState(
50
- step=42, opt_state=opt_state, params=initial_params, rng=rng,
51
  variables=state_vars)
52
 
53
- ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
54
- state = ckpt.restore(state)
55
 
56
  return model, state, rng
57
 
@@ -67,7 +62,7 @@ def load_image(name):
67
 
68
 
69
  download_path = snapshot_download(repo_id="ondrejbiza/isa")
70
- checkpoint_dir = os.path.join(download_path, "clevr_isa_ts", "checkpoints")
71
 
72
  model, state, rng = load_model(get_config(), checkpoint_dir)
73
 
 
 
1
  import os
2
 
 
3
  from clu import checkpoint
4
  import gradio as gr
5
  import jax
6
  import jax.numpy as jnp
7
  import numpy as np
8
  from PIL import Image
 
9
  from huggingface_hub import snapshot_download
10
 
11
  from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
 
 
12
  from invariant_slot_attention.lib import utils
13
 
14
 
 
28
  {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
29
  video=init_inputs, conditioning=init_conditioning,
30
  padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
31
+
32
  # Split into state variables (e.g. for batchnorm stats) and model params.
33
  # Note that `pop()` on a FrozenDict performs a deep copy.
34
  state_vars, initial_params = initial_vars.pop("params") # pytype: disable=attribute-error
 
42
 
43
  opt_state = None
44
  state = utils.TrainState(
45
+ step=1, opt_state=opt_state, params=initial_params, rng=rng,
46
  variables=state_vars)
47
 
48
+ ckpt = checkpoint.Checkpoint(checkpoint_dir)
49
+ state = ckpt.restore(state, checkpoint=checkpoint_dir + "/ckpt-0")
50
 
51
  return model, state, rng
52
 
 
62
 
63
 
64
  download_path = snapshot_download(repo_id="ondrejbiza/isa")
65
+ checkpoint_dir = os.path.join(download_path, "clevr_isa_ts")
66
 
67
  model, state, rng = load_model(get_config(), checkpoint_dir)
68