Spaces:
Running
Running
ondrejbiza
commited on
Commit
•
8530ae8
1
Parent(s):
9b6ff29
New model.
Browse files
app.py
CHANGED
@@ -1,19 +1,14 @@
|
|
1 |
-
import functools
|
2 |
import os
|
3 |
|
4 |
-
from absl import flags
|
5 |
from clu import checkpoint
|
6 |
import gradio as gr
|
7 |
import jax
|
8 |
import jax.numpy as jnp
|
9 |
import numpy as np
|
10 |
from PIL import Image
|
11 |
-
import tensorflow as tf
|
12 |
from huggingface_hub import snapshot_download
|
13 |
|
14 |
from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
|
15 |
-
from invariant_slot_attention.lib import input_pipeline
|
16 |
-
from invariant_slot_attention.lib import preprocessing
|
17 |
from invariant_slot_attention.lib import utils
|
18 |
|
19 |
|
@@ -33,7 +28,7 @@ def load_model(config, checkpoint_dir):
|
|
33 |
{"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
|
34 |
video=init_inputs, conditioning=init_conditioning,
|
35 |
padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
|
36 |
-
|
37 |
# Split into state variables (e.g. for batchnorm stats) and model params.
|
38 |
# Note that `pop()` on a FrozenDict performs a deep copy.
|
39 |
state_vars, initial_params = initial_vars.pop("params") # pytype: disable=attribute-error
|
@@ -47,11 +42,11 @@ def load_model(config, checkpoint_dir):
|
|
47 |
|
48 |
opt_state = None
|
49 |
state = utils.TrainState(
|
50 |
-
step=
|
51 |
variables=state_vars)
|
52 |
|
53 |
-
ckpt = checkpoint.
|
54 |
-
state = ckpt.restore(state)
|
55 |
|
56 |
return model, state, rng
|
57 |
|
@@ -67,7 +62,7 @@ def load_image(name):
|
|
67 |
|
68 |
|
69 |
download_path = snapshot_download(repo_id="ondrejbiza/isa")
|
70 |
-
checkpoint_dir = os.path.join(download_path, "clevr_isa_ts"
|
71 |
|
72 |
model, state, rng = load_model(get_config(), checkpoint_dir)
|
73 |
|
|
|
|
|
1 |
import os
|
2 |
|
|
|
3 |
from clu import checkpoint
|
4 |
import gradio as gr
|
5 |
import jax
|
6 |
import jax.numpy as jnp
|
7 |
import numpy as np
|
8 |
from PIL import Image
|
|
|
9 |
from huggingface_hub import snapshot_download
|
10 |
|
11 |
from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
|
|
|
|
|
12 |
from invariant_slot_attention.lib import utils
|
13 |
|
14 |
|
|
|
28 |
{"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
|
29 |
video=init_inputs, conditioning=init_conditioning,
|
30 |
padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))
|
31 |
+
|
32 |
# Split into state variables (e.g. for batchnorm stats) and model params.
|
33 |
# Note that `pop()` on a FrozenDict performs a deep copy.
|
34 |
state_vars, initial_params = initial_vars.pop("params") # pytype: disable=attribute-error
|
|
|
42 |
|
43 |
opt_state = None
|
44 |
state = utils.TrainState(
|
45 |
+
step=1, opt_state=opt_state, params=initial_params, rng=rng,
|
46 |
variables=state_vars)
|
47 |
|
48 |
+
ckpt = checkpoint.Checkpoint(checkpoint_dir)
|
49 |
+
state = ckpt.restore(state, checkpoint=checkpoint_dir + "/ckpt-0")
|
50 |
|
51 |
return model, state, rng
|
52 |
|
|
|
62 |
|
63 |
|
64 |
download_path = snapshot_download(repo_id="ondrejbiza/isa")
|
65 |
+
checkpoint_dir = os.path.join(download_path, "clevr_isa_ts")
|
66 |
|
67 |
model, state, rng = load_model(get_config(), checkpoint_dir)
|
68 |
|