Spaces:
Runtime error
Runtime error
ondrejbiza
commited on
Commit
•
65d6890
1
Parent(s):
94c5fd1
Update app, add CLEVR images.
Browse files- .DS_Store +0 -0
- app.py +15 -13
- images/img1.png +0 -0
- images/img2.png +0 -0
- images/img3.png +0 -0
- images/img4.png +0 -0
- images/img5.png +0 -0
- images/img6.png +0 -0
- images/img7.png +0 -0
- images/img8.png +0 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
app.py
CHANGED
@@ -2,6 +2,7 @@ import functools
|
|
2 |
import os
|
3 |
|
4 |
from absl import flags
|
|
|
5 |
import gradio as gr
|
6 |
import jax
|
7 |
import jax.numpy as jnp
|
@@ -12,6 +13,7 @@ from invariant_slot_attention.lib import utils
|
|
12 |
|
13 |
|
14 |
def load_model(config):
|
|
|
15 |
rng, data_rng = jax.random.split(rng)
|
16 |
|
17 |
# Initialize model
|
@@ -21,9 +23,7 @@ def load_model(config):
|
|
21 |
rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
|
22 |
|
23 |
init_conditioning = None
|
24 |
-
init_inputs = jnp.ones(
|
25 |
-
[1] + list(train_ds.element_spec["video"].shape)[2:],
|
26 |
-
jnp.float32)
|
27 |
initial_vars = model.init(
|
28 |
{"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
|
29 |
video=init_inputs, conditioning=init_conditioning,
|
@@ -40,27 +40,29 @@ def load_model(config):
|
|
40 |
|
41 |
state_vars, initial_params = init_model(rng)
|
42 |
|
43 |
-
|
44 |
-
tx = optimizers.get_optimizer(
|
45 |
-
config.optimizer_configs, learning_rate_fn, params=initial_params)
|
46 |
-
|
47 |
-
opt_state = tx.init(initial_params)
|
48 |
-
|
49 |
state = utils.TrainState(
|
50 |
step=1, opt_state=opt_state, params=initial_params, rng=rng,
|
51 |
variables=state_vars)
|
52 |
|
53 |
-
|
54 |
-
losses.compute_full_loss, loss_config=config.losses)
|
55 |
-
|
56 |
-
checkpoint_dir = os.path.join(workdir, "checkpoints")
|
57 |
ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
|
58 |
state = ckpt.restore_or_initialize(state)
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def greet(name):
|
62 |
return "Hello " + name + "!"
|
63 |
|
64 |
|
|
|
65 |
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
66 |
demo.launch()
|
|
|
2 |
import os
|
3 |
|
4 |
from absl import flags
|
5 |
+
from clu import checkpoint
|
6 |
import gradio as gr
|
7 |
import jax
|
8 |
import jax.numpy as jnp
|
|
|
13 |
|
14 |
|
15 |
def load_model(config):
|
16 |
+
rng = jax.random.PRNGKey(42)
|
17 |
rng, data_rng = jax.random.split(rng)
|
18 |
|
19 |
# Initialize model
|
|
|
23 |
rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
|
24 |
|
25 |
init_conditioning = None
|
26 |
+
init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
|
|
|
|
|
27 |
initial_vars = model.init(
|
28 |
{"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
|
29 |
video=init_inputs, conditioning=init_conditioning,
|
|
|
40 |
|
41 |
state_vars, initial_params = init_model(rng)
|
42 |
|
43 |
+
opt_state = None
|
|
|
|
|
|
|
|
|
|
|
44 |
state = utils.TrainState(
|
45 |
step=1, opt_state=opt_state, params=initial_params, rng=rng,
|
46 |
variables=state_vars)
|
47 |
|
48 |
+
checkpoint_dir = "clevr_isa_ts/checkpoints-0"
|
|
|
|
|
|
|
49 |
ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
|
50 |
state = ckpt.restore_or_initialize(state)
|
51 |
|
52 |
+
init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
|
53 |
+
rng, init_rng = jax.random.split(rng, num=2)
|
54 |
+
out = model.apply(
|
55 |
+
{"params": state.params, **state.variables},
|
56 |
+
video=init_inputs,
|
57 |
+
rngs={"state_init": init_rng},
|
58 |
+
train=False)
|
59 |
+
print(out.keys())
|
60 |
+
|
61 |
|
62 |
def greet(name):
|
63 |
return "Hello " + name + "!"
|
64 |
|
65 |
|
66 |
+
load_model(get_config())
|
67 |
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
68 |
demo.launch()
|
images/img1.png
ADDED
images/img2.png
ADDED
images/img3.png
ADDED
images/img4.png
ADDED
images/img5.png
ADDED
images/img6.png
ADDED
images/img7.png
ADDED
images/img8.png
ADDED