Spaces: Build error
Commit · 90e5776
Parent(s): 1ccf223
V2 config, revert requirements.

Files changed:
- .gitignore +2 -1
- app.py +3 -3
- invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale_v2.py +202 -0
- requirements.txt +1 -1
.gitignore
CHANGED

@@ -1,4 +1,5 @@
 /venv
 /flagged
 /clevr_isa_ts
-*.pyc
+*.pyc
+*.DS_Store
app.py
CHANGED

@@ -10,7 +10,7 @@ import jax.numpy as jnp
 import numpy as np
 from PIL import Image
 
-from invariant_slot_attention.configs.clevr_with_masks.
+from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale_v2 import get_config
 from invariant_slot_attention.lib import utils
 
 
@@ -61,8 +61,8 @@ def load_image(name):
     return img
 
 
-download_path = snapshot_download(repo_id="ondrejbiza/isa", allow_patterns="
-checkpoint_dir = os.path.join(download_path, "
+download_path = snapshot_download(repo_id="ondrejbiza/isa", allow_patterns="clevr_isa_ts_v2*")
+checkpoint_dir = os.path.join(download_path, "clevr_isa_ts_v2")
 
 model, state, rng = load_model(get_config(), checkpoint_dir)
 
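For context, snapshot_download from huggingface_hub downloads a snapshot of the Hub repository into the local cache and returns its path; allow_patterns restricts the download to matching files. A minimal standalone sketch of the updated checkpoint fetch (the final listing line is illustrative and not part of app.py):

import os
from huggingface_hub import snapshot_download

# Fetch only the V2 checkpoint files from the Hub model repository.
download_path = snapshot_download(repo_id="ondrejbiza/isa",
                                  allow_patterns="clevr_isa_ts_v2*")
checkpoint_dir = os.path.join(download_path, "clevr_isa_ts_v2")

# The checkpoint directory should now exist in the local cache.
print(sorted(os.listdir(checkpoint_dir)))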
invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale_v2.py
ADDED

@@ -0,0 +1,202 @@
+# coding=utf-8
+# Copyright 2023 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Config for unsupervised training on CLEVR."""
+
+import ml_collections
+
+
+def get_config():
+  """Get the default hyperparameter configuration."""
+  config = ml_collections.ConfigDict()
+
+  config.seed = 42
+  config.seed_data = True
+
+  config.batch_size = 64
+  config.num_train_steps = 500000  # from the original Slot Attention
+  config.init_checkpoint = ml_collections.ConfigDict()
+  config.init_checkpoint.xid = 0  # Disabled by default.
+  config.init_checkpoint.wid = 1
+
+  config.optimizer_configs = ml_collections.ConfigDict()
+  config.optimizer_configs.optimizer = "adam"
+
+  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
+  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
+  config.optimizer_configs.grad_clip.clip_value = 0.05
+
+  config.lr_configs = ml_collections.ConfigDict()
+  config.lr_configs.learning_rate_schedule = "compound"
+  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
+  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
+  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
+  # from the original Slot Attention
+  config.lr_configs.base_learning_rate = 4e-4
+
+  config.eval_pad_last_batch = False  # True
+  config.log_loss_every_steps = 50
+  config.eval_every_steps = 5000
+  config.checkpoint_every_steps = 5000
+
+  config.train_metrics_spec = {
+      "loss": "loss",
+      "ari": "ari",
+      "ari_nobg": "ari_nobg",
+  }
+  config.eval_metrics_spec = {
+      "eval_loss": "loss",
+      "eval_ari": "ari",
+      "eval_ari_nobg": "ari_nobg",
+  }
+
+  config.data = ml_collections.ConfigDict({
+      "dataset_name": "clevr_with_masks",
+      "shuffle_buffer_size": config.batch_size * 8,
+      "resolution": (128, 128)
+  })
+
+  config.max_instances = 11
+  config.num_slots = config.max_instances  # Only used for metrics.
+  config.logging_min_n_colors = config.max_instances
+
+  config.preproc_train = [
+      "tfds_image_to_tfds_video",
+      "video_from_tfds",
+      "top_left_crop(top=29, left=64, height=192)",
+      "resize_small({size})".format(size=min(*config.data.resolution))
+  ]
+
+  config.preproc_eval = [
+      "tfds_image_to_tfds_video",
+      "video_from_tfds",
+      "top_left_crop(top=29, left=64, height=192)",
+      "resize_small({size})".format(size=min(*config.data.resolution))
+  ]
+
+  config.eval_slice_size = 1
+  config.eval_slice_keys = ["video", "segmentations_video"]
+
+  # Dictionary of targets and corresponding channels. Losses need to match.
+  targets = {"video": 3}
+  config.losses = {"recon": {"targets": list(targets)}}
+  config.losses = ml_collections.ConfigDict({
+      f"recon_{target}": {"loss_type": "recon", "key": target}
+      for target in targets})
+
+  config.model = ml_collections.ConfigDict({
+      "module": "invariant_slot_attention.modules.SAVi",
+
+      # Encoder.
+      "encoder": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.FrameEncoder",
+          "reduction": "spatial_flatten",
+          "backbone": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.SimpleCNN",
+              "features": [64, 64, 64, 64],
+              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
+              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
+          }),
+          "pos_emb": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.PositionEmbedding",
+              "embedding_type": "linear",
+              "update_type": "concat"
+          }),
+      }),
+
+      # Corrector.
+      "corrector": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
+          "num_iterations": 3,
+          "qkv_size": 64,
+          "mlp_size": 128,
+          "grid_encoder": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.MLP",
+              "hidden_size": 128,
+              "layernorm": "pre"
+          }),
+          "add_rel_pos_to_values": False,  # V2
+          "zero_position_init": False,  # Random positions.
+          "init_with_fixed_scale": None,  # Random scales.
+          "scales_factor": 5.0,
+      }),
+
+      # Predictor.
+      # Removed since we are running a single frame.
+      "predictor": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.Identity"
+      }),
+
+      # Initializer.
+      "initializer": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
+          "shape": (11, 64),  # (num_slots, slot_size)
+      }),
+
+      # Decoder.
+      "decoder": ml_collections.ConfigDict({
+          "module":
+              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
+          "resolution": (16, 16),  # Update if data resolution or strides change
+          "backbone": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.CNN",
+              "features": [64, 64, 64, 64, 64],
+              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
+              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
+              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
+              "layer_transpose": [True, True, True, False, False]
+          }),
+          "target_readout": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.Readout",
+              "keys": list(targets),
+              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
+                  "module": "invariant_slot_attention.modules.MLP",
+                  "num_hidden_layers": 0,
+                  "hidden_size": 0,
+                  "output_size": targets[k]}) for k in targets],
+          }),
+          "relative_positions_and_scales": True,
+          "pos_emb": ml_collections.ConfigDict({
+              "module":
+                  "invariant_slot_attention.modules.RelativePositionEmbedding",
+              "embedding_type":
+                  "linear",
+              "update_type":
+                  "project_add",
+              "scales_factor":
+                  5.0,
+          }),
+      }),
+      "decode_corrected": True,
+      "decode_predicted": False,
+  })
+
+  # Which video-shaped variables to visualize.
+  config.debug_var_video_paths = {
+      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
+  }
+
+  # Define which attention matrices to log/visualize.
+  config.debug_var_attn_paths = {
+      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
+  }
+
+  # Widths of attention matrices (for reshaping to image grid).
+  config.debug_var_attn_widths = {
+      "corrector_attn": 16,
+  }
+
+  return config
+
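The new config is a plain ml_collections.ConfigDict, so it can be inspected or overridden before model construction. A minimal sketch, using only field names defined in the file above:

from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale_v2 import get_config

config = get_config()
# The V2 corrector uses the translation/scale-equivariant Slot Attention module.
print(config.model.corrector.module)  # invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv
print(config.data.resolution)         # (128, 128)

# Fields can be overridden before the model is built, e.g. for single-image inference.
config.batch_size = 1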
requirements.txt
CHANGED

@@ -4,7 +4,7 @@ tensorflow-cpu>=2.12.0
 tensorflow-datasets>=4.4.0
 matplotlib>=3.5.0
 clu>=0.0.3
-flax
+flax==0.3.5
 chex>=0.0.7
 optax>=0.1.0
 ml-collections>=0.1.0
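The revert pins flax back to 0.3.5. A quick post-install sanity check (a minimal sketch, assuming a standard pip environment):

import flax

# The Space expects exactly this flax release; fail fast if a different one is installed.
assert flax.__version__ == "0.3.5", flax.__version__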