# coding=utf-8 # Copyright 2023 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Common utils.""" import functools import importlib from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Type, Union from absl import logging from clu import metrics as base_metrics import flax from flax import linen as nn from flax import traverse_util import jax import jax.numpy as jnp import jax.ops import matplotlib import matplotlib.pyplot as plt import ml_collections import numpy as np import optax import skimage.transform import tensorflow as tf from invariant_slot_attention.lib import metrics Array = Any # Union[np.ndarray, jnp.ndarray] ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]] # pytype: disable=not-supported-yet DictTree = Dict[str, Union[Array, "DictTree"]] # pytype: disable=not-supported-yet PRNGKey = Array ConfigAttr = Any MetricSpec = Dict[str, str] @flax.struct.dataclass class TrainState: """Data structure for checkpointing the model.""" step: int opt_state: optax.OptState params: ArrayTree variables: flax.core.FrozenDict rng: PRNGKey METRIC_TYPE_TO_CLS = { "loss": base_metrics.Average.from_output(name="loss"), "ari": metrics.Ari, "ari_nobg": metrics.AriNoBg, } def make_metrics_collection( class_name, metrics_spec): """Create class inhering from metrics.Collection based on spec.""" metrics_dict = {} if metrics_spec: for m_name, m_type in metrics_spec.items(): metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type] return flax.struct.dataclass( type(class_name, (base_metrics.Collection,), {"__annotations__": metrics_dict})) def flatten_named_dicttree(metrics_res, sep = "/"): """Flatten dictionary.""" metrics_res_flat = {} for k, v in traverse_util.flatten_dict(metrics_res).items(): metrics_res_flat[(sep.join(k)).strip(sep)] = v return metrics_res_flat def spatial_broadcast(x, resolution): """Broadcast flat inputs to a 2D grid of a given resolution.""" # x.shape = (batch_size, features). x = x[:, jnp.newaxis, jnp.newaxis, :] return jnp.tile(x, [1, resolution[0], resolution[1], 1]) def time_distributed(cls, in_axes=1, axis=1): """Wrapper for time-distributed (vmapped) application of a module.""" return nn.vmap( cls, in_axes=in_axes, out_axes=axis, axis_name="time", # Stack debug vars along sequence dim and broadcast params. variable_axes={ "params": None, "intermediates": axis, "batch_stats": None}, split_rngs={"params": False, "dropout": True, "state_init": True}) def broadcast_across_batch(inputs, batch_size): """Broadcasts inputs across a batch of examples (creates new axis).""" return jnp.broadcast_to( array=jnp.expand_dims(inputs, axis=0), shape=(batch_size,) + inputs.shape) def create_gradient_grid( samples_per_dim, value_range = (-1.0, 1.0) ): """Creates a tensor with equidistant entries from -1 to +1 in each dim. Args: samples_per_dim: Number of points to have along each dimension. value_range: In each dimension, points will go from range[0] to range[1] Returns: A tensor of shape [samples_per_dim] + [len(samples_per_dim)]. """ s = [jnp.linspace(value_range[0], value_range[1], n) for n in samples_per_dim] pe = jnp.stack(jnp.meshgrid(*s, sparse=False, indexing="ij"), axis=-1) return jnp.array(pe) def convert_to_fourier_features(inputs, basis_degree): """Convert inputs to Fourier features, e.g. for positional encoding.""" # inputs.shape = (..., n_dims). # inputs should be in range [-pi, pi] or [0, 2pi]. n_dims = inputs.shape[-1] # Generate frequency basis. freq_basis = jnp.concatenate( # shape = (n_dims, n_dims * basis_degree) [2**i * jnp.eye(n_dims) for i in range(basis_degree)], 1) # x.shape = (..., n_dims * basis_degree) x = inputs @ freq_basis # Project inputs onto frequency basis. # Obtain Fourier features as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)]. return jnp.sin(jnp.concatenate([x, x + 0.5 * jnp.pi], axis=-1)) def prepare_images_for_logging( config, batch = None, preds = None, n_samples = 5, n_frames = 5, min_n_colors = 1, epsilon = 1e-6, first_replica_only = False): """Prepare images from batch and/or model predictions for logging.""" images = dict() # Converts all tensors to numpy arrays to run everything on CPU as JAX # eager mode is inefficient and because memory usage from these ops may # lead to OOM errors. batch = jax.tree_map(np.array, batch) preds = jax.tree_map(np.array, preds) if n_samples <= 0: return images if not first_replica_only: # Move the two leading batch dimensions into a single dimension. We do this # to plot the same number of examples regardless of the data parallelism. batch = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), batch) preds = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), preds) else: batch = jax.tree_map(lambda x: x[0], batch) preds = jax.tree_map(lambda x: x[0], preds) # Limit the tensors to n_samples and n_frames. batch = jax.tree_map( lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples], batch) preds = jax.tree_map( lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples], preds) # Log input data. if batch is not None: images["video"] = video_to_image_grid(batch["video"]) if "segmentations" in batch: images["mask"] = video_to_image_grid(convert_categories_to_color( batch["segmentations"], min_n_colors=min_n_colors)) if "flow" in batch: images["flow"] = video_to_image_grid(batch["flow"]) if "boxes" in batch: images["boxes"] = draw_bounding_boxes( batch["video"], batch["boxes"], min_n_colors=min_n_colors) # Log model predictions. if preds is not None and preds.get("outputs") is not None: if "segmentations" in preds["outputs"]: # pytype: disable=attribute-error images["segmentations"] = video_to_image_grid( convert_categories_to_color( preds["outputs"]["segmentations"], min_n_colors=min_n_colors)) def shape_fn(x): if isinstance(x, (np.ndarray, jnp.ndarray)): return x.shape # Log intermediate variables. if preds is not None and "intermediates" in preds: logging.info("intermediates: %s", jax.tree_map(shape_fn, preds["intermediates"])) for key, path in config.debug_var_video_paths.items(): log_vars = retrieve_from_collection(preds["intermediates"], path) if log_vars is not None: if not isinstance(log_vars, Sequence): log_vars = [log_vars] for i, log_var in enumerate(log_vars): log_var = np.array(log_var) # Moves log_var to CPU. images[key + "_" + str(i)] = video_to_image_grid(log_var) else: logging.warning("%s not found in intermediates", path) # Log attention weights. for key, path in config.debug_var_attn_paths.items(): log_vars = retrieve_from_collection(preds["intermediates"], path) if log_vars is not None: if not isinstance(log_vars, Sequence): log_vars = [log_vars] for i, log_var in enumerate(log_vars): log_var = np.array(log_var) # Moves log_var to CPU. images.update( prepare_attention_maps_for_logging( attn_maps=log_var, key=key + "_" + str(i), map_width=config.debug_var_attn_widths.get(key), video=batch["video"], epsilon=epsilon, n_samples=n_samples, n_frames=n_frames)) else: logging.warning("%s not found in intermediates", path) # Crop each image to a maximum of 3 channels for RGB visualization. for key, image in images.items(): if image.shape[-1] > 3: logging.warning("Truncating channels of %s for visualization.", key) images[key] = image[Ellipsis, :3] return images def prepare_attention_maps_for_logging(attn_maps, key, map_width, epsilon, n_samples, n_frames, video): """Visualize (overlayed) attention maps as an image grid.""" images = {} # Results dictionary. attn_maps = unflatten_image(attn_maps[Ellipsis, None], width=map_width) num_heads = attn_maps.shape[2] for head_idx in range(num_heads): attn = attn_maps[:n_samples, :n_frames, head_idx] attn /= attn.max() + epsilon # Standardizes scale for visualization. # attn.shape: [bs, seq_len, 11, h', w', 1] bs, seq_len, _, h_attn, w_attn, _ = attn.shape images[f"{key}_head_{head_idx}"] = video_to_image_grid(attn) # Attention maps are interpretable when they align with object boundaries. # However, if they are overly smooth then the following visualization which # overlays attention maps on video is helpful. video = video[:n_samples, :n_frames] # video.shape: [bs, seq_len, h, w, 3] video_resized = [] for i in range(n_samples): for j in range(n_frames): video_resized.append( skimage.transform.resize(video[i, j], (h_attn, w_attn), order=1)) video_resized = np.array(video_resized).reshape( (bs, seq_len, h_attn, w_attn, 3)) attn_overlayed = attn * np.expand_dims(video_resized, 2) images[f"{key}_head_{head_idx}_overlayed"] = video_to_image_grid( attn_overlayed) return images def convert_categories_to_color( inputs, min_n_colors = 1, include_black = True): """Converts int-valued categories to color in last axis of input tensor. Args: inputs: `np.ndarray` of arbitrary shape with integer entries, encoding the categories. min_n_colors: Minimum number of colors (excl. black) to encode categories. include_black: Include black as 0-th entry in the color palette. Increases `min_n_colors` by 1 if True. Returns: `np.ndarray` with RGB colors in last axis. """ if inputs.shape[-1] == 1: # Strip category axis. inputs = np.squeeze(inputs, axis=-1) inputs = np.array(inputs, dtype=np.int32) # Convert to int. # Infer number of colors from inputs. n_colors = int(inputs.max()) + 1 # One color per category incl. 0. if include_black: n_colors -= 1 # If we include black, we need one color less. if min_n_colors > n_colors: # Use more colors in color palette if requested. n_colors = min_n_colors rgb_colors = get_uniform_colors(n_colors) if include_black: # Add black as color for zero-th index. rgb_colors = np.concatenate((np.zeros((1, 3)), rgb_colors), axis=0) return rgb_colors[inputs] def get_uniform_colors(n_colors): """Get n_colors with uniformly spaced hues.""" hues = np.linspace(0, 1, n_colors, endpoint=False) hsv_colors = np.concatenate( (np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1) rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors) return rgb_colors # rgb_colors.shape = (n_colors, 3) def unflatten_image(image, width = None): """Unflatten image array of shape [batch_dims..., height*width, channels].""" n_channels = image.shape[-1] # If width is not provided, we assume that the image is square. if width is None: width = int(np.floor(np.sqrt(image.shape[-2]))) height = width assert width * height == image.shape[-2], "Image is not square." else: height = image.shape[-2] // width return image.reshape(image.shape[:-2] + (height, width, n_channels)) def video_to_image_grid(video): """Transform video to image grid by folding sequence dim along width.""" if len(video.shape) == 5: n_samples, n_frames, height, width, n_channels = video.shape video = np.transpose(video, (0, 2, 1, 3, 4)) # Swap n_frames and height. image_grid = np.reshape( video, (n_samples, height, n_frames * width, n_channels)) elif len(video.shape) == 6: n_samples, n_frames, n_slots, height, width, n_channels = video.shape # Put n_frames next to width. video = np.transpose(video, (0, 2, 3, 1, 4, 5)) image_grid = np.reshape( video, (n_samples, n_slots * height, n_frames * width, n_channels)) else: raise ValueError("Unsupported video shape for visualization.") return image_grid def draw_bounding_boxes(video, boxes, min_n_colors = 1, include_black = True): """Draw bounding boxes in videos.""" colors = get_uniform_colors(min_n_colors - include_black) b, t, h, w, c = video.shape n = boxes.shape[2] image_grid = tf.image.draw_bounding_boxes( np.reshape(video, (b * t, h, w, c)), np.reshape(boxes, (b * t, n, 4)), colors).numpy() image_grid = np.reshape( np.transpose(np.reshape(image_grid, (b, t, h, w, c)), (0, 2, 1, 3, 4)), (b, h, t * w, c)) return image_grid def plot_image(ax, image): """Add an image visualization to a provided `plt.Axes` instance.""" num_channels = image.shape[-1] if num_channels == 1: image = image.reshape(image.shape[:2]) ax.imshow(image, cmap="viridis") ax.grid(False) plt.axis("off") def visualize_image_dict(images, plot_scale = 10): """Visualize a dictionary of images in colab using maptlotlib.""" for key in images.keys(): logging.info("Visualizing key: %s", key) n_images = len(images[key]) fig = plt.figure(figsize=(n_images * plot_scale, plot_scale)) for idx, image in enumerate(images[key]): ax = fig.add_subplot(1, n_images, idx+1) plot_image(ax, image) plt.show() def filter_key_from_frozen_dict( frozen_dict, key): """Filters (removes) an item by key from a flax.core.FrozenDict.""" if key in frozen_dict: frozen_dict, _ = frozen_dict.pop(key) return frozen_dict def prepare_dict_for_logging(nested_dict, parent_key = "", sep = "_"): """Prepare a nested dictionary for logging with `clu.metric_writers`. Args: nested_dict: A nested dictionary, e.g. obtained from a `ml_collections.ConfigDict` via `.to_dict()`. parent_key: String used in recursion. sep: String used to separate parent and child keys. Returns: Flattened dict. """ items = [] for k, v in nested_dict.items(): # Flatten keys of nested elements. new_key = parent_key + sep + k if parent_key else k # Convert None values, lists and tuples to strings. if v is None: v = "None" if isinstance(v, list) or isinstance(v, tuple): v = str(v) # Recursively flatten the dict. if isinstance(v, dict): items.extend(prepare_dict_for_logging(v, new_key, sep=sep).items()) else: items.append((new_key, v)) return dict(items) def retrieve_from_collection( variable_collection, path): """Finds variables by their path by recursively searching the collection. Args: variable_collection: Nested dict containing the variables (or tuples/lists of variables). path: Path to variable in module tree, similar to Unix file names (e.g. '/module/dense/0/bias'). Returns: The requested variable, variable collection or None (in case the variable could not be found). """ key, _, rpath = path.strip("/").partition("/") # In case the variable is not found, we return None. if (key.isdigit() and not isinstance(variable_collection, Sequence)) or ( key.isdigit() and int(key) >= len(variable_collection)) or ( not key.isdigit() and key not in variable_collection): return None if key.isdigit(): key = int(key) if not rpath: return variable_collection[key] else: return retrieve_from_collection(variable_collection[key], rpath) def build_model_from_config(config): """Build a Flax model from a (nested) ConfigDict.""" model_constructor = _parse_config(config) if callable(model_constructor): return model_constructor() else: raise ValueError("Provided config does not contain module constructors.") def _parse_config(config ): """Recursively parses a nested ConfigDict and resolves module constructors.""" if isinstance(config, list): return [_parse_config(c) for c in config] elif isinstance(config, tuple): return tuple([_parse_config(c) for c in config]) elif not isinstance(config, ml_collections.ConfigDict): return config elif "module" in config: module_constructor = _resolve_module_constructor(config.module) kwargs = {k: _parse_config(v) for k, v in config.items() if k != "module"} return functools.partial(module_constructor, **kwargs) else: return {k: _parse_config(v) for k, v in config.items()} def _resolve_module_constructor( constructor_str): import_str, _, module_name = constructor_str.rpartition(".") py_module = importlib.import_module(import_str) return getattr(py_module, module_name) def get_slices_along_axis( inputs, slice_keys, start_idx = 0, end_idx = -1, axis = 2, pad_value = 0): """Extracts slices from a dictionary of tensors along the specified axis. The slice operation is only applied to `slice_keys` dictionary keys. If `end_idx` is larger than the actual size of the specified axis, padding is added (with values provided in `pad_value`). Args: inputs: Dictionary of tensors. slice_keys: Iterable of strings, the keys for the inputs dictionary for which to apply the slice operation. start_idx: Integer, defining the first index to be part of the slice. end_idx: Integer, defining the end of the slice interval (exclusive). If set to `-1`, the end index is set to the size of the axis. If a value is provided that is larger than the size of the axis, zero-padding is added for the remaining elements. axis: Integer, the axis along which to slice. pad_value: Integer, value to be used in padding. Returns: Dictionary of tensors where elements described in `slice_keys` are sliced, and all other elements are returned as original. """ max_size = None pad_size = 0 # Check shapes and get maximum size of requested axis. for key in slice_keys: curr_size = inputs[key].shape[axis] if max_size is None: max_size = curr_size elif max_size != curr_size: raise ValueError( "For specified tensors the requested axis needs to be of equal size.") # Infer end index if not provided. if end_idx == -1: end_idx = max_size # Set padding size if end index is larger than maximum size of requested axis. elif end_idx > max_size: pad_size = end_idx - max_size end_idx = max_size outputs = {} for key in slice_keys: outputs[key] = np.take( inputs[key], indices=np.arange(start_idx, end_idx), axis=axis) # Add padding if necessary. if pad_size > 0: pad_shape = np.array(outputs[key].shape) np.put(pad_shape, axis, pad_size) # In-place op. padding = pad_value * np.ones(pad_shape, dtype=outputs[key].dtype) outputs[key] = np.concatenate((outputs[key], padding), axis=axis) return outputs def get_element_by_str( dictionary, multilevel_key, separator = "/" ): """Gets element in a dictionary with multilevel key (e.g., "key1/key2").""" keys = multilevel_key.split(separator) if len(keys) == 1: return dictionary[keys[0]] return get_element_by_str( dictionary[keys[0]], separator.join(keys[1:]), separator=separator) def set_element_by_str( dictionary, multilevel_key, new_value, separator = "/"): """Sets element in a dictionary with multilevel key (e.g., "key1/key2").""" keys = multilevel_key.split(separator) if len(keys) == 1: if keys[0] not in dictionary: key_error = ( "Pretrained {key} was not found in trained model. " "Make sure you are loading the correct pretrained model " "or consider adding {key} to exceptions.") raise KeyError(key_error.format(type="parameter", key=keys[0])) dictionary[keys[0]] = new_value else: set_element_by_str( dictionary[keys[0]], separator.join(keys[1:]), new_value, separator=separator) def remove_singleton_dim(inputs): """Removes the final dimension if it is singleton (i.e. of size 1).""" if inputs is None: return None if inputs.shape[-1] != 1: logging.warning("Expected final dimension of inputs to be 1, " "received inputs of shape %s: ", str(inputs.shape)) return inputs return inputs[Ellipsis, 0]