# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Common utils.""" | |
| import functools | |
| import importlib | |
| from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Type, Union | |
| from absl import logging | |
| from clu import metrics as base_metrics | |
| import flax | |
| from flax import linen as nn | |
| from flax import traverse_util | |
| import jax | |
| import jax.numpy as jnp | |
| import jax.ops | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import ml_collections | |
| import numpy as np | |
| import optax | |
| import skimage.transform | |
| import tensorflow as tf | |
| from invariant_slot_attention.lib import metrics | |
| Array = Any # Union[np.ndarray, jnp.ndarray] | |
| ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet | |
| DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet | |
| PRNGKey = Array | |
| ConfigAttr = Any | |
| MetricSpec = Dict[str, str] | |


@flax.struct.dataclass
class TrainState:
  """Data structure for checkpointing the model."""
  step: int
  opt_state: optax.OptState
  params: ArrayTree
  variables: flax.core.FrozenDict
  rng: PRNGKey


METRIC_TYPE_TO_CLS = {
    "loss": base_metrics.Average.from_output(name="loss"),
    "ari": metrics.Ari,
    "ari_nobg": metrics.AriNoBg,
}


def make_metrics_collection(
    class_name: str,
    metrics_spec: MetricSpec) -> Type[base_metrics.Collection]:
  """Create class inheriting from metrics.Collection based on spec."""
  metrics_dict = {}
  if metrics_spec:
    for m_name, m_type in metrics_spec.items():
      metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type]

  return flax.struct.dataclass(
      type(class_name,
           (base_metrics.Collection,),
           {"__annotations__": metrics_dict}))


def flatten_named_dicttree(metrics_res: DictTree, sep: str = "/"):
  """Flatten dictionary."""
  metrics_res_flat = {}
  for k, v in traverse_util.flatten_dict(metrics_res).items():
    metrics_res_flat[(sep.join(k)).strip(sep)] = v
  return metrics_res_flat
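

# Example (illustrative): a nested metrics result is flattened into
# separator-joined keys, e.g. for logging.
#
#   flatten_named_dicttree({"eval": {"loss": 0.1, "ari": 0.9}})
#   # -> {"eval/loss": 0.1, "eval/ari": 0.9}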


def spatial_broadcast(x, resolution):
  """Broadcast flat inputs to a 2D grid of a given resolution."""
  # x.shape = (batch_size, features).
  x = x[:, jnp.newaxis, jnp.newaxis, :]
  return jnp.tile(x, [1, resolution[0], resolution[1], 1])
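

# Example (illustrative): broadcasting per-sample feature vectors onto a grid.
#
#   x = jnp.ones((2, 64))               # (batch_size, features)
#   spatial_broadcast(x, (8, 8)).shape  # -> (2, 8, 8, 64)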


def time_distributed(cls, in_axes=1, axis=1):
  """Wrapper for time-distributed (vmapped) application of a module."""
  return nn.vmap(
      cls, in_axes=in_axes, out_axes=axis, axis_name="time",
      # Stack debug vars along sequence dim and broadcast params.
      variable_axes={
          "params": None, "intermediates": axis, "batch_stats": None},
      split_rngs={"params": False, "dropout": True, "state_init": True})
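

# Example usage (an illustrative sketch, assuming a plain flax module): apply
# nn.Dense independently to every frame of a (batch, time, features) input.
#
#   TimeDense = time_distributed(nn.Dense)
#   out, _ = TimeDense(features=32).init_with_output(
#       jax.random.PRNGKey(0), jnp.ones((4, 6, 16)))
#   out.shape  # -> (4, 6, 32); params are shared across the time axis.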


def broadcast_across_batch(inputs, batch_size):
  """Broadcasts inputs across a batch of examples (creates new axis)."""
  return jnp.broadcast_to(
      array=jnp.expand_dims(inputs, axis=0),
      shape=(batch_size,) + inputs.shape)
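

# Example (illustrative): tiling a single example to a batch.
#
#   broadcast_across_batch(jnp.ones((3, 4)), batch_size=8).shape  # -> (8, 3, 4)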


def create_gradient_grid(samples_per_dim, value_range=(-1.0, 1.0)):
  """Creates a tensor with equidistant entries from -1 to +1 in each dim.

  Args:
    samples_per_dim: Number of points to have along each dimension.
    value_range: In each dimension, points will go from range[0] to range[1].

  Returns:
    A tensor of shape [samples_per_dim] + [len(samples_per_dim)].
  """
  s = [jnp.linspace(value_range[0], value_range[1], n) for n in samples_per_dim]
  pe = jnp.stack(jnp.meshgrid(*s, sparse=False, indexing="ij"), axis=-1)
  return jnp.array(pe)
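

# Example (illustrative): a 2x2 coordinate grid over [-1, 1] x [-1, 1].
#
#   create_gradient_grid((2, 2))
#   # -> shape (2, 2, 2):
#   # [[[-1., -1.], [-1., 1.]],
#   #  [[ 1., -1.], [ 1., 1.]]]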


def convert_to_fourier_features(inputs, basis_degree):
  """Convert inputs to Fourier features, e.g. for positional encoding."""
  # inputs.shape = (..., n_dims).
  # inputs should be in range [-pi, pi] or [0, 2pi].
  n_dims = inputs.shape[-1]

  # Generate frequency basis.
  freq_basis = jnp.concatenate(  # shape = (n_dims, n_dims * basis_degree)
      [2**i * jnp.eye(n_dims) for i in range(basis_degree)], 1)

  # x.shape = (..., n_dims * basis_degree)
  x = inputs @ freq_basis  # Project inputs onto frequency basis.

  # Obtain Fourier features as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)].
  return jnp.sin(jnp.concatenate([x, x + 0.5 * jnp.pi], axis=-1))
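

# Example (illustrative): encoding 2-D positions with a degree-3 basis yields
# 2 * 2 * 3 = 12 features per position.
#
#   pos = create_gradient_grid((8, 8)) * jnp.pi             # in [-pi, pi]
#   convert_to_fourier_features(pos, basis_degree=3).shape  # -> (8, 8, 12)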


def prepare_images_for_logging(
    config,
    batch=None,
    preds=None,
    n_samples=5,
    n_frames=5,
    min_n_colors=1,
    epsilon=1e-6,
    first_replica_only=False):
  """Prepare images from batch and/or model predictions for logging."""

  images = dict()
  # Convert all tensors to numpy arrays to run everything on CPU: JAX eager
  # mode is inefficient and the memory usage of these ops may cause OOM errors.
  batch = jax.tree_map(np.array, batch)
  preds = jax.tree_map(np.array, preds)

  if n_samples <= 0:
    return images

  if not first_replica_only:
    # Move the two leading batch dimensions into a single dimension. We do this
    # to plot the same number of examples regardless of the data parallelism.
    batch = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), batch)
    preds = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), preds)
  else:
    batch = jax.tree_map(lambda x: x[0], batch)
    preds = jax.tree_map(lambda x: x[0], preds)

  # Limit the tensors to n_samples and n_frames.
  batch = jax.tree_map(
      lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
      batch)
  preds = jax.tree_map(
      lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
      preds)

  # Log input data.
  if batch is not None:
    images["video"] = video_to_image_grid(batch["video"])
    if "segmentations" in batch:
      images["mask"] = video_to_image_grid(convert_categories_to_color(
          batch["segmentations"], min_n_colors=min_n_colors))
    if "flow" in batch:
      images["flow"] = video_to_image_grid(batch["flow"])
    if "boxes" in batch:
      images["boxes"] = draw_bounding_boxes(
          batch["video"],
          batch["boxes"],
          min_n_colors=min_n_colors)

  # Log model predictions.
  if preds is not None and preds.get("outputs") is not None:
    if "segmentations" in preds["outputs"]:  # pytype: disable=attribute-error
      images["segmentations"] = video_to_image_grid(
          convert_categories_to_color(
              preds["outputs"]["segmentations"], min_n_colors=min_n_colors))

  def shape_fn(x):
    if isinstance(x, (np.ndarray, jnp.ndarray)):
      return x.shape

  # Log intermediate variables.
  if preds is not None and "intermediates" in preds:

    logging.info("intermediates: %s",
                 jax.tree_map(shape_fn, preds["intermediates"]))

    for key, path in config.debug_var_video_paths.items():
      log_vars = retrieve_from_collection(preds["intermediates"], path)
      if log_vars is not None:
        if not isinstance(log_vars, Sequence):
          log_vars = [log_vars]
        for i, log_var in enumerate(log_vars):
          log_var = np.array(log_var)  # Moves log_var to CPU.
          images[key + "_" + str(i)] = video_to_image_grid(log_var)
      else:
        logging.warning("%s not found in intermediates", path)

    # Log attention weights.
    for key, path in config.debug_var_attn_paths.items():
      log_vars = retrieve_from_collection(preds["intermediates"], path)
      if log_vars is not None:
        if not isinstance(log_vars, Sequence):
          log_vars = [log_vars]
        for i, log_var in enumerate(log_vars):
          log_var = np.array(log_var)  # Moves log_var to CPU.
          images.update(
              prepare_attention_maps_for_logging(
                  attn_maps=log_var,
                  key=key + "_" + str(i),
                  map_width=config.debug_var_attn_widths.get(key),
                  video=batch["video"],
                  epsilon=epsilon,
                  n_samples=n_samples,
                  n_frames=n_frames))
      else:
        logging.warning("%s not found in intermediates", path)

  # Crop each image to a maximum of 3 channels for RGB visualization.
  for key, image in images.items():
    if image.shape[-1] > 3:
      logging.warning("Truncating channels of %s for visualization.", key)
      images[key] = image[Ellipsis, :3]

  return images


def prepare_attention_maps_for_logging(attn_maps, key,
                                       map_width, epsilon,
                                       n_samples, n_frames,
                                       video):
  """Visualize (overlaid) attention maps as an image grid."""
  images = {}  # Results dictionary.
  attn_maps = unflatten_image(attn_maps[Ellipsis, None], width=map_width)
  num_heads = attn_maps.shape[2]
  for head_idx in range(num_heads):
    attn = attn_maps[:n_samples, :n_frames, head_idx]
    attn /= attn.max() + epsilon  # Standardizes scale for visualization.
    # attn.shape: [bs, seq_len, n_slots, h', w', 1]

    bs, seq_len, _, h_attn, w_attn, _ = attn.shape
    images[f"{key}_head_{head_idx}"] = video_to_image_grid(attn)

    # Attention maps are interpretable when they align with object boundaries.
    # However, if they are overly smooth, the following visualization, which
    # overlays the attention maps on the video, is helpful.
    video = video[:n_samples, :n_frames]
    # video.shape: [bs, seq_len, h, w, 3]
    video_resized = []
    for i in range(n_samples):
      for j in range(n_frames):
        video_resized.append(
            skimage.transform.resize(video[i, j], (h_attn, w_attn), order=1))
    video_resized = np.array(video_resized).reshape(
        (bs, seq_len, h_attn, w_attn, 3))
    attn_overlayed = attn * np.expand_dims(video_resized, 2)
    images[f"{key}_head_{head_idx}_overlayed"] = video_to_image_grid(
        attn_overlayed)
  return images


def convert_categories_to_color(
    inputs, min_n_colors=1, include_black=True):
  """Converts int-valued categories to color in last axis of input tensor.

  Args:
    inputs: `np.ndarray` of arbitrary shape with integer entries, encoding the
      categories.
    min_n_colors: Minimum number of colors (excl. black) to encode categories.
    include_black: Include black as 0-th entry in the color palette. Increases
      `min_n_colors` by 1 if True.

  Returns:
    `np.ndarray` with RGB colors in last axis.
  """
  if inputs.shape[-1] == 1:  # Strip category axis.
    inputs = np.squeeze(inputs, axis=-1)
  inputs = np.array(inputs, dtype=np.int32)  # Convert to int.

  # Infer number of colors from inputs.
  n_colors = int(inputs.max()) + 1  # One color per category incl. 0.
  if include_black:
    n_colors -= 1  # If we include black, we need one color less.

  if min_n_colors > n_colors:  # Use more colors in color palette if requested.
    n_colors = min_n_colors

  rgb_colors = get_uniform_colors(n_colors)

  if include_black:  # Add black as color for zero-th index.
    rgb_colors = np.concatenate((np.zeros((1, 3)), rgb_colors), axis=0)
  return rgb_colors[inputs]
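

# Example (illustrative): segmentation IDs 0..2, with ID 0 rendered black.
#
#   seg = np.array([[0, 1], [2, 1]])[..., None]  # (2, 2, 1) category ids
#   convert_categories_to_color(seg).shape       # -> (2, 2, 3) RGB image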


def get_uniform_colors(n_colors):
  """Get n_colors with uniformly spaced hues."""
  hues = np.linspace(0, 1, n_colors, endpoint=False)
  hsv_colors = np.concatenate(
      (np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1)
  rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors)
  return rgb_colors  # rgb_colors.shape = (n_colors, 3)


def unflatten_image(image, width=None):
  """Unflatten image array of shape [batch_dims..., height*width, channels]."""
  n_channels = image.shape[-1]
  # If width is not provided, we assume that the image is square.
  if width is None:
    width = int(np.floor(np.sqrt(image.shape[-2])))
    height = width
    assert width * height == image.shape[-2], "Image is not square."
  else:
    height = image.shape[-2] // width
  return image.reshape(image.shape[:-2] + (height, width, n_channels))
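

# Example (illustrative): restoring spatial dims of a flattened square image.
#
#   flat = np.zeros((4, 64, 3))  # (batch, height*width, channels)
#   unflatten_image(flat).shape  # -> (4, 8, 8, 3)
#   unflatten_image(np.zeros((4, 32, 3)), width=8).shape  # -> (4, 4, 8, 3)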


def video_to_image_grid(video):
  """Transform video to image grid by folding sequence dim along width."""
  if len(video.shape) == 5:
    n_samples, n_frames, height, width, n_channels = video.shape
    video = np.transpose(video, (0, 2, 1, 3, 4))  # Swap n_frames and height.
    image_grid = np.reshape(
        video, (n_samples, height, n_frames * width, n_channels))
  elif len(video.shape) == 6:
    n_samples, n_frames, n_slots, height, width, n_channels = video.shape
    # Put n_frames next to width.
    video = np.transpose(video, (0, 2, 3, 1, 4, 5))
    image_grid = np.reshape(
        video, (n_samples, n_slots * height, n_frames * width, n_channels))
  else:
    raise ValueError("Unsupported video shape for visualization.")
  return image_grid
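

# Example (illustrative): frames are laid out side by side along width, and
# (for 6-D inputs) slots are stacked along height.
#
#   video_to_image_grid(np.zeros((2, 5, 64, 64, 3))).shape
#   # -> (2, 64, 320, 3)
#   video_to_image_grid(np.zeros((2, 5, 4, 64, 64, 3))).shape
#   # -> (2, 256, 320, 3)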


def draw_bounding_boxes(video,
                        boxes,
                        min_n_colors=1,
                        include_black=True):
  """Draw bounding boxes in videos."""
  colors = get_uniform_colors(min_n_colors - include_black)

  b, t, h, w, c = video.shape
  n = boxes.shape[2]
  image_grid = tf.image.draw_bounding_boxes(
      np.reshape(video, (b * t, h, w, c)),
      np.reshape(boxes, (b * t, n, 4)),
      colors).numpy()
  image_grid = np.reshape(
      np.transpose(np.reshape(image_grid, (b, t, h, w, c)),
                   (0, 2, 1, 3, 4)),
      (b, h, t * w, c))
  return image_grid


def plot_image(ax, image):
  """Add an image visualization to a provided `plt.Axes` instance."""
  num_channels = image.shape[-1]
  if num_channels == 1:
    image = image.reshape(image.shape[:2])
  ax.imshow(image, cmap="viridis")
  ax.grid(False)
  plt.axis("off")


def visualize_image_dict(images, plot_scale=10):
  """Visualize a dictionary of images in colab using matplotlib."""
  for key in images.keys():
    logging.info("Visualizing key: %s", key)
    n_images = len(images[key])
    fig = plt.figure(figsize=(n_images * plot_scale, plot_scale))
    for idx, image in enumerate(images[key]):
      ax = fig.add_subplot(1, n_images, idx + 1)
      plot_image(ax, image)
    plt.show()


def filter_key_from_frozen_dict(
    frozen_dict, key):
  """Filters (removes) an item by key from a flax.core.FrozenDict."""
  if key in frozen_dict:
    frozen_dict, _ = frozen_dict.pop(key)
  return frozen_dict


def prepare_dict_for_logging(nested_dict, parent_key="", sep="_"):
  """Prepare a nested dictionary for logging with `clu.metric_writers`.

  Args:
    nested_dict: A nested dictionary, e.g. obtained from a
      `ml_collections.ConfigDict` via `.to_dict()`.
    parent_key: String used in recursion.
    sep: String used to separate parent and child keys.

  Returns:
    Flattened dict.
  """
  items = []
  for k, v in nested_dict.items():
    # Flatten keys of nested elements.
    new_key = parent_key + sep + k if parent_key else k

    # Convert None values, lists and tuples to strings.
    if v is None:
      v = "None"
    if isinstance(v, (list, tuple)):
      v = str(v)

    # Recursively flatten the dict.
    if isinstance(v, dict):
      items.extend(prepare_dict_for_logging(v, new_key, sep=sep).items())
    else:
      items.append((new_key, v))
  return dict(items)
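

# Example (illustrative): flattening a config dict for a metric writer.
#
#   prepare_dict_for_logging({"opt": {"lr": None, "betas": (0.9, 0.999)}})
#   # -> {"opt_lr": "None", "opt_betas": "(0.9, 0.999)"}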


def retrieve_from_collection(
    variable_collection, path):
  """Finds variables by their path by recursively searching the collection.

  Args:
    variable_collection: Nested dict containing the variables (or tuples/lists
      of variables).
    path: Path to variable in module tree, similar to Unix file names (e.g.
      '/module/dense/0/bias').

  Returns:
    The requested variable, variable collection or None (in case the variable
    could not be found).
  """
  key, _, rpath = path.strip("/").partition("/")

  # In case the variable is not found, we return None.
  if (key.isdigit() and not isinstance(variable_collection, Sequence)) or (
      key.isdigit() and int(key) >= len(variable_collection)) or (
          not key.isdigit() and key not in variable_collection):
    return None

  if key.isdigit():
    key = int(key)

  if not rpath:
    return variable_collection[key]
  else:
    return retrieve_from_collection(variable_collection[key], rpath)
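

# Example (illustrative): digit path components index into sequences, other
# components index into dicts; missing paths yield None rather than an error.
#
#   tree = {"encoder": {"dense": [{"bias": 0.5}]}}
#   retrieve_from_collection(tree, "/encoder/dense/0/bias")  # -> 0.5
#   retrieve_from_collection(tree, "/encoder/conv")          # -> None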


def build_model_from_config(config):
  """Build a Flax model from a (nested) ConfigDict."""
  model_constructor = _parse_config(config)
  if callable(model_constructor):
    return model_constructor()
  else:
    raise ValueError("Provided config does not contain module constructors.")


def _parse_config(config: ConfigAttr) -> ConfigAttr:
  """Recursively parses a nested ConfigDict and resolves module constructors."""
  if isinstance(config, list):
    return [_parse_config(c) for c in config]
  elif isinstance(config, tuple):
    return tuple([_parse_config(c) for c in config])
  elif not isinstance(config, ml_collections.ConfigDict):
    return config
  elif "module" in config:
    module_constructor = _resolve_module_constructor(config.module)
    kwargs = {k: _parse_config(v) for k, v in config.items() if k != "module"}
    return functools.partial(module_constructor, **kwargs)
  else:
    return {k: _parse_config(v) for k, v in config.items()}


def _resolve_module_constructor(constructor_str):
  import_str, _, module_name = constructor_str.rpartition(".")
  py_module = importlib.import_module(import_str)
  return getattr(py_module, module_name)
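

# Example usage (illustrative; "flax.linen.Dense" stands in for any importable
# constructor path): a ConfigDict with a "module" entry is resolved into a
# constructor, and its remaining entries become keyword arguments.
#
#   config = ml_collections.ConfigDict(
#       {"module": "flax.linen.Dense", "features": 128})
#   model = build_model_from_config(config)  # == nn.Dense(features=128)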


def get_slices_along_axis(
    inputs,
    slice_keys,
    start_idx=0,
    end_idx=-1,
    axis=2,
    pad_value=0):
  """Extracts slices from a dictionary of tensors along the specified axis.

  The slice operation is only applied to `slice_keys` dictionary keys. If
  `end_idx` is larger than the actual size of the specified axis, padding is
  added (with values provided in `pad_value`).

  Args:
    inputs: Dictionary of tensors.
    slice_keys: Iterable of strings, the keys for the inputs dictionary for
      which to apply the slice operation.
    start_idx: Integer, defining the first index to be part of the slice.
    end_idx: Integer, defining the end of the slice interval (exclusive). If
      set to `-1`, the end index is set to the size of the axis. If a value is
      provided that is larger than the size of the axis, zero-padding is added
      for the remaining elements.
    axis: Integer, the axis along which to slice.
    pad_value: Integer, value to be used in padding.

  Returns:
    Dictionary of tensors where elements described in `slice_keys` are sliced,
    and all other elements are returned as original.
  """
  max_size = None
  pad_size = 0

  # Check shapes and get maximum size of requested axis.
  for key in slice_keys:
    curr_size = inputs[key].shape[axis]
    if max_size is None:
      max_size = curr_size
    elif max_size != curr_size:
      raise ValueError(
          "For specified tensors the requested axis needs to be of equal size.")

  # Infer end index if not provided.
  if end_idx == -1:
    end_idx = max_size

  # Set padding size if end index is larger than maximum size of requested axis.
  elif end_idx > max_size:
    pad_size = end_idx - max_size
    end_idx = max_size

  outputs = {}
  for key in slice_keys:
    outputs[key] = np.take(
        inputs[key], indices=np.arange(start_idx, end_idx), axis=axis)

    # Add padding if necessary.
    if pad_size > 0:
      pad_shape = np.array(outputs[key].shape)
      np.put(pad_shape, axis, pad_size)  # In-place op.
      padding = pad_value * np.ones(pad_shape, dtype=outputs[key].dtype)
      outputs[key] = np.concatenate((outputs[key], padding), axis=axis)
  return outputs
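

# Example (illustrative): slicing past the end of the axis pads with pad_value.
#
#   inputs = {"video": np.ones((1, 3, 10, 4))}
#   out = get_slices_along_axis(inputs, ("video",), end_idx=12, axis=2)
#   out["video"].shape  # -> (1, 3, 12, 4); the last 2 entries are zeros.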


def get_element_by_str(
    dictionary, multilevel_key, separator="/"):
  """Gets element in a dictionary with multilevel key (e.g., "key1/key2")."""
  keys = multilevel_key.split(separator)
  if len(keys) == 1:
    return dictionary[keys[0]]
  return get_element_by_str(
      dictionary[keys[0]], separator.join(keys[1:]), separator=separator)


def set_element_by_str(
    dictionary, multilevel_key, new_value, separator="/"):
  """Sets element in a dictionary with multilevel key (e.g., "key1/key2")."""
  keys = multilevel_key.split(separator)
  if len(keys) == 1:
    if keys[0] not in dictionary:
      key_error = (
          "Pretrained {type} {key} was not found in trained model. "
          "Make sure you are loading the correct pretrained model "
          "or consider adding {key} to exceptions.")
      raise KeyError(key_error.format(type="parameter", key=keys[0]))
    dictionary[keys[0]] = new_value
  else:
    set_element_by_str(
        dictionary[keys[0]],
        separator.join(keys[1:]),
        new_value,
        separator=separator)
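

# Example (illustrative): reading and overwriting nested entries by path, e.g.
# when transplanting pretrained parameters into a new parameter tree.
#
#   params = {"encoder": {"dense": {"bias": 0.0}}}
#   get_element_by_str(params, "encoder/dense/bias")       # -> 0.0
#   set_element_by_str(params, "encoder/dense/bias", 1.0)  # In-place update.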


def remove_singleton_dim(inputs):
  """Removes the final dimension if it is singleton (i.e. of size 1)."""
  if inputs is None:
    return None
  if inputs.shape[-1] != 1:
    logging.warning("Expected final dimension of inputs to be 1, "
                    "received inputs of shape %s.", str(inputs.shape))
    return inputs
  return inputs[Ellipsis, 0]