Spaces:
Build error
Build error
Commit
·
65d6890
1
Parent(s):
94c5fd1
Update app, add CLEVR images.
Browse files- .DS_Store +0 -0
- app.py +15 -13
- images/img1.png +0 -0
- images/img2.png +0 -0
- images/img3.png +0 -0
- images/img4.png +0 -0
- images/img5.png +0 -0
- images/img6.png +0 -0
- images/img7.png +0 -0
- images/img8.png +0 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import functools
|
|
| 2 |
import os
|
| 3 |
|
| 4 |
from absl import flags
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import jax
|
| 7 |
import jax.numpy as jnp
|
|
@@ -12,6 +13,7 @@ from invariant_slot_attention.lib import utils
|
|
| 12 |
|
| 13 |
|
| 14 |
def load_model(config):
|
|
|
|
| 15 |
rng, data_rng = jax.random.split(rng)
|
| 16 |
|
| 17 |
# Initialize model
|
|
@@ -21,9 +23,7 @@ def load_model(config):
|
|
| 21 |
rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
|
| 22 |
|
| 23 |
init_conditioning = None
|
| 24 |
-
init_inputs = jnp.ones(
|
| 25 |
-
[1] + list(train_ds.element_spec["video"].shape)[2:],
|
| 26 |
-
jnp.float32)
|
| 27 |
initial_vars = model.init(
|
| 28 |
{"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
|
| 29 |
video=init_inputs, conditioning=init_conditioning,
|
|
@@ -40,27 +40,29 @@ def load_model(config):
|
|
| 40 |
|
| 41 |
state_vars, initial_params = init_model(rng)
|
| 42 |
|
| 43 |
-
|
| 44 |
-
tx = optimizers.get_optimizer(
|
| 45 |
-
config.optimizer_configs, learning_rate_fn, params=initial_params)
|
| 46 |
-
|
| 47 |
-
opt_state = tx.init(initial_params)
|
| 48 |
-
|
| 49 |
state = utils.TrainState(
|
| 50 |
step=1, opt_state=opt_state, params=initial_params, rng=rng,
|
| 51 |
variables=state_vars)
|
| 52 |
|
| 53 |
-
|
| 54 |
-
losses.compute_full_loss, loss_config=config.losses)
|
| 55 |
-
|
| 56 |
-
checkpoint_dir = os.path.join(workdir, "checkpoints")
|
| 57 |
ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
|
| 58 |
state = ckpt.restore_or_initialize(state)
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def greet(name):
|
| 62 |
return "Hello " + name + "!"
|
| 63 |
|
| 64 |
|
|
|
|
| 65 |
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
| 66 |
demo.launch()
|
|
|
|
| 2 |
import os
|
| 3 |
|
| 4 |
from absl import flags
|
| 5 |
+
from clu import checkpoint
|
| 6 |
import gradio as gr
|
| 7 |
import jax
|
| 8 |
import jax.numpy as jnp
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def load_model(config):
|
| 16 |
+
rng = jax.random.PRNGKey(42)
|
| 17 |
rng, data_rng = jax.random.split(rng)
|
| 18 |
|
| 19 |
# Initialize model
|
|
|
|
| 23 |
rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)
|
| 24 |
|
| 25 |
init_conditioning = None
|
| 26 |
+
init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
|
|
|
|
|
|
|
| 27 |
initial_vars = model.init(
|
| 28 |
{"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
|
| 29 |
video=init_inputs, conditioning=init_conditioning,
|
|
|
|
| 40 |
|
| 41 |
state_vars, initial_params = init_model(rng)
|
| 42 |
|
| 43 |
+
opt_state = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
state = utils.TrainState(
|
| 45 |
step=1, opt_state=opt_state, params=initial_params, rng=rng,
|
| 46 |
variables=state_vars)
|
| 47 |
|
| 48 |
+
checkpoint_dir = "clevr_isa_ts/checkpoints-0"
|
|
|
|
|
|
|
|
|
|
| 49 |
ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
|
| 50 |
state = ckpt.restore_or_initialize(state)
|
| 51 |
|
| 52 |
+
init_inputs = jnp.ones([1, 1, 128, 128, 3], jnp.float32)
|
| 53 |
+
rng, init_rng = jax.random.split(rng, num=2)
|
| 54 |
+
out = model.apply(
|
| 55 |
+
{"params": state.params, **state.variables},
|
| 56 |
+
video=init_inputs,
|
| 57 |
+
rngs={"state_init": init_rng},
|
| 58 |
+
train=False)
|
| 59 |
+
print(out.keys())
|
| 60 |
+
|
| 61 |
|
| 62 |
def greet(name):
|
| 63 |
return "Hello " + name + "!"
|
| 64 |
|
| 65 |
|
| 66 |
+
load_model(get_config())
|
| 67 |
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
| 68 |
demo.launch()
|
images/img1.png
ADDED
|
images/img2.png
ADDED
|
images/img3.png
ADDED
|
images/img4.png
ADDED
|
images/img5.png
ADDED
|
images/img6.png
ADDED
|
images/img7.png
ADDED
|
images/img8.png
ADDED
|