# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utils."""

import functools
import importlib
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Type, Union

from absl import logging
from clu import metrics as base_metrics
import flax
from flax import linen as nn
from flax import traverse_util
import jax
import jax.numpy as jnp
import jax.ops
import matplotlib
import matplotlib.pyplot as plt
import ml_collections
import numpy as np
import optax
import skimage.transform
import tensorflow as tf

from invariant_slot_attention.lib import metrics

Array = Any  # Union[np.ndarray, jnp.ndarray]
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
DictTree = Dict[str, Union[Array, "DictTree"]]  # pytype: disable=not-supported-yet
PRNGKey = Array
ConfigAttr = Any
MetricSpec = Dict[str, str]


@flax.struct.dataclass
class TrainState:
  """Data structure for checkpointing the model."""
  step: int
  opt_state: optax.OptState
  params: ArrayTree
  variables: flax.core.FrozenDict
  rng: PRNGKey


METRIC_TYPE_TO_CLS = {
    "loss": base_metrics.Average.from_output(name="loss"),
    "ari": metrics.Ari,
    "ari_nobg": metrics.AriNoBg,
}


def make_metrics_collection(
    class_name,
    metrics_spec):
  """Create a class inheriting from metrics.Collection based on spec."""
  metrics_dict = {}
  if metrics_spec:
    for m_name, m_type in metrics_spec.items():
      metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type]

  return flax.struct.dataclass(
      type(class_name,
           (base_metrics.Collection,),
           {"__annotations__": metrics_dict}))


def flatten_named_dicttree(metrics_res, sep="/"):
  """Flatten dictionary."""
  metrics_res_flat = {}
  for k, v in traverse_util.flatten_dict(metrics_res).items():
    metrics_res_flat[(sep.join(k)).strip(sep)] = v
  return metrics_res_flat
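

# Illustrative example (not from the original source): nested metric results
# are flattened into slash-separated keys.
#
#   flatten_named_dicttree({"eval": {"ari": 0.9, "loss": 0.1}})
#   # -> {"eval/ari": 0.9, "eval/loss": 0.1}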


def spatial_broadcast(x, resolution):
  """Broadcast flat inputs to a 2D grid of a given resolution."""
  # x.shape = (batch_size, features).
  x = x[:, jnp.newaxis, jnp.newaxis, :]
  return jnp.tile(x, [1, resolution[0], resolution[1], 1])
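

# Illustrative shapes (assumed, for clarity): a flat per-example feature
# vector is tiled onto every cell of a 2D grid.
#
#   x = jnp.zeros((4, 16))              # (batch_size, features)
#   spatial_broadcast(x, (8, 8)).shape  # -> (4, 8, 8, 16)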


def time_distributed(cls, in_axes=1, axis=1):
  """Wrapper for time-distributed (vmapped) application of a module."""
  return nn.vmap(
      cls, in_axes=in_axes, out_axes=axis, axis_name="time",
      # Stack debug vars along sequence dim and broadcast params.
      variable_axes={
          "params": None, "intermediates": axis, "batch_stats": None},
      split_rngs={"params": False, "dropout": True, "state_init": True})
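

# Illustrative sketch (the module name is hypothetical): wrapping a per-frame
# encoder so it maps over the time axis of a (batch, time, ...) input while
# sharing its parameters across frames.
#
#   FrameEncoder = time_distributed(Encoder)  # Encoder: some nn.Module
#   features = FrameEncoder(...)(video)       # video: (batch, time, h, w, c)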


def broadcast_across_batch(inputs, batch_size):
  """Broadcasts inputs across a batch of examples (creates new axis)."""
  return jnp.broadcast_to(
      array=jnp.expand_dims(inputs, axis=0),
      shape=(batch_size,) + inputs.shape)
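

# Illustrative shapes (assumed): a tensor shared across examples is given a
# leading batch axis without copying data.
#
#   grid = jnp.zeros((8, 8, 2))
#   broadcast_across_batch(grid, batch_size=4).shape  # -> (4, 8, 8, 2)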


def create_gradient_grid(
    samples_per_dim, value_range=(-1.0, 1.0)):
  """Creates a tensor with equidistant entries from -1 to +1 in each dim.

  Args:
    samples_per_dim: Number of points to have along each dimension.
    value_range: In each dimension, points will go from value_range[0] to
      value_range[1].

  Returns:
    A tensor of shape `samples_per_dim + [len(samples_per_dim)]`, i.e. with
    one coordinate channel per dimension in the last axis.
  """
  s = [jnp.linspace(value_range[0], value_range[1], n) for n in samples_per_dim]
  pe = jnp.stack(jnp.meshgrid(*s, sparse=False, indexing="ij"), axis=-1)
  return jnp.array(pe)
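

# Illustrative example (values follow from the definition above): a 4x4 grid
# with two coordinate channels, spanning [-1, 1] in each dimension.
#
#   grid = create_gradient_grid((4, 4))  # shape (4, 4, 2)
#   grid[0, 0]    # -> [-1., -1.]
#   grid[-1, -1]  # -> [1., 1.]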


def convert_to_fourier_features(inputs, basis_degree):
  """Convert inputs to Fourier features, e.g. for positional encoding."""
  # inputs.shape = (..., n_dims).
  # inputs should be in range [-pi, pi] or [0, 2pi].
  n_dims = inputs.shape[-1]

  # Generate frequency basis.
  freq_basis = jnp.concatenate(  # shape = (n_dims, n_dims * basis_degree)
      [2**i * jnp.eye(n_dims) for i in range(basis_degree)], 1)

  # x.shape = (..., n_dims * basis_degree)
  x = inputs @ freq_basis  # Project inputs onto frequency basis.

  # Obtain Fourier features as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)].
  return jnp.sin(jnp.concatenate([x, x + 0.5 * jnp.pi], axis=-1))
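

# Illustrative example (shapes follow from the definition above): for 2D
# coordinates and basis_degree=3, each position maps to 2 * 3 sin features
# plus 2 * 3 cos features.
#
#   pos = create_gradient_grid((8, 8)) * jnp.pi             # scale to [-pi, pi]
#   convert_to_fourier_features(pos, basis_degree=3).shape  # -> (8, 8, 12)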


def prepare_images_for_logging(
    config,
    batch=None,
    preds=None,
    n_samples=5,
    n_frames=5,
    min_n_colors=1,
    epsilon=1e-6,
    first_replica_only=False):
  """Prepare images from batch and/or model predictions for logging."""

  images = dict()

  # Convert all tensors to numpy arrays to run everything on CPU: JAX eager
  # mode is inefficient for these ops, and their memory usage may otherwise
  # lead to OOM errors.
  batch = jax.tree_map(np.array, batch)
  preds = jax.tree_map(np.array, preds)

  if n_samples <= 0:
    return images

  if not first_replica_only:
    # Move the two leading batch dimensions into a single dimension. We do this
    # to plot the same number of examples regardless of the data parallelism.
    batch = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), batch)
    preds = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), preds)
  else:
    batch = jax.tree_map(lambda x: x[0], batch)
    preds = jax.tree_map(lambda x: x[0], preds)

  # Limit the tensors to n_samples and n_frames.
  batch = jax.tree_map(
      lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
      batch)
  preds = jax.tree_map(
      lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
      preds)

  # Log input data.
  if batch is not None:
    images["video"] = video_to_image_grid(batch["video"])
    if "segmentations" in batch:
      images["mask"] = video_to_image_grid(convert_categories_to_color(
          batch["segmentations"], min_n_colors=min_n_colors))
    if "flow" in batch:
      images["flow"] = video_to_image_grid(batch["flow"])
    if "boxes" in batch:
      images["boxes"] = draw_bounding_boxes(
          batch["video"],
          batch["boxes"],
          min_n_colors=min_n_colors)

  # Log model predictions.
  if preds is not None and preds.get("outputs") is not None:
    if "segmentations" in preds["outputs"]:  # pytype: disable=attribute-error
      images["segmentations"] = video_to_image_grid(
          convert_categories_to_color(
              preds["outputs"]["segmentations"], min_n_colors=min_n_colors))

  def shape_fn(x):
    if isinstance(x, (np.ndarray, jnp.ndarray)):
      return x.shape

  # Log intermediate variables.
  if preds is not None and "intermediates" in preds:

    logging.info("intermediates: %s",
                 jax.tree_map(shape_fn, preds["intermediates"]))

    for key, path in config.debug_var_video_paths.items():
      log_vars = retrieve_from_collection(preds["intermediates"], path)
      if log_vars is not None:
        if not isinstance(log_vars, Sequence):
          log_vars = [log_vars]
        for i, log_var in enumerate(log_vars):
          log_var = np.array(log_var)  # Moves log_var to CPU.
          images[key + "_" + str(i)] = video_to_image_grid(log_var)
      else:
        logging.warning("%s not found in intermediates", path)

    # Log attention weights.
    for key, path in config.debug_var_attn_paths.items():
      log_vars = retrieve_from_collection(preds["intermediates"], path)
      if log_vars is not None:
        if not isinstance(log_vars, Sequence):
          log_vars = [log_vars]
        for i, log_var in enumerate(log_vars):
          log_var = np.array(log_var)  # Moves log_var to CPU.
          images.update(
              prepare_attention_maps_for_logging(
                  attn_maps=log_var,
                  key=key + "_" + str(i),
                  map_width=config.debug_var_attn_widths.get(key),
                  video=batch["video"],
                  epsilon=epsilon,
                  n_samples=n_samples,
                  n_frames=n_frames))
      else:
        logging.warning("%s not found in intermediates", path)

  # Crop each image to a maximum of 3 channels for RGB visualization.
  for key, image in images.items():
    if image.shape[-1] > 3:
      logging.warning("Truncating channels of %s for visualization.", key)
      images[key] = image[..., :3]

  return images


def prepare_attention_maps_for_logging(attn_maps, key,
                                       map_width, epsilon,
                                       n_samples, n_frames,
                                       video):
  """Visualize (overlayed) attention maps as an image grid."""
  images = {}  # Results dictionary.
  attn_maps = unflatten_image(attn_maps[..., None], width=map_width)
  num_heads = attn_maps.shape[2]
  for head_idx in range(num_heads):
    attn = attn_maps[:n_samples, :n_frames, head_idx]
    attn /= attn.max() + epsilon  # Standardizes scale for visualization.
    # attn.shape: [bs, seq_len, n_slots, h', w', 1]
    bs, seq_len, _, h_attn, w_attn, _ = attn.shape
    images[f"{key}_head_{head_idx}"] = video_to_image_grid(attn)

    # Attention maps are interpretable when they align with object boundaries.
    # However, if they are overly smooth, the following visualization, which
    # overlays attention maps on the video, is helpful.
    video = video[:n_samples, :n_frames]
    # video.shape: [bs, seq_len, h, w, 3]
    video_resized = []
    for i in range(n_samples):
      for j in range(n_frames):
        video_resized.append(
            skimage.transform.resize(video[i, j], (h_attn, w_attn), order=1))
    video_resized = np.array(video_resized).reshape(
        (bs, seq_len, h_attn, w_attn, 3))
    attn_overlayed = attn * np.expand_dims(video_resized, 2)
    images[f"{key}_head_{head_idx}_overlayed"] = video_to_image_grid(
        attn_overlayed)
  return images


def convert_categories_to_color(
    inputs, min_n_colors=1, include_black=True):
  """Converts int-valued categories to color in last axis of input tensor.

  Args:
    inputs: `np.ndarray` of arbitrary shape with integer entries, encoding the
      categories.
    min_n_colors: Minimum number of colors (excl. black) to encode categories.
    include_black: Include black as 0-th entry in the color palette. Increases
      `min_n_colors` by 1 if True.

  Returns:
    `np.ndarray` with RGB colors in last axis.
  """
  if inputs.shape[-1] == 1:  # Strip category axis.
    inputs = np.squeeze(inputs, axis=-1)
  inputs = np.array(inputs, dtype=np.int32)  # Convert to int.

  # Infer number of colors from inputs.
  n_colors = int(inputs.max()) + 1  # One color per category incl. 0.
  if include_black:
    n_colors -= 1  # If we include black, we need one color less.

  if min_n_colors > n_colors:  # Use more colors in color palette if requested.
    n_colors = min_n_colors

  rgb_colors = get_uniform_colors(n_colors)
  if include_black:  # Add black as color for zero-th index.
    rgb_colors = np.concatenate((np.zeros((1, 3)), rgb_colors), axis=0)
  return rgb_colors[inputs]
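

# Illustrative example (assumed input): integer category maps become RGB
# images, with category 0 rendered as black by default.
#
#   seg = np.array([[0, 1], [2, 1]])[..., None]  # (h, w, 1) category map
#   convert_categories_to_color(seg).shape       # -> (2, 2, 3)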


def get_uniform_colors(n_colors):
  """Get n_colors with uniformly spaced hues."""
  hues = np.linspace(0, 1, n_colors, endpoint=False)
  hsv_colors = np.concatenate(
      (np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1)
  rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors)
  return rgb_colors  # rgb_colors.shape = (n_colors, 3)


def unflatten_image(image, width=None):
  """Unflatten image array of shape [batch_dims..., height*width, channels]."""
  n_channels = image.shape[-1]
  # If width is not provided, we assume that the image is square.
  if width is None:
    width = int(np.floor(np.sqrt(image.shape[-2])))
    height = width
    assert width * height == image.shape[-2], "Image is not square."
  else:
    height = image.shape[-2] // width
  return image.reshape(image.shape[:-2] + (height, width, n_channels))
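

# Illustrative shapes (assumed): square images are inferred from the
# flattened axis; non-square ones need an explicit width.
#
#   unflatten_image(np.zeros((2, 3, 64, 1))).shape        # -> (2, 3, 8, 8, 1)
#   unflatten_image(np.zeros((2, 32, 1)), width=8).shape  # -> (2, 4, 8, 1)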


def video_to_image_grid(video):
  """Transform video to image grid by folding sequence dim along width."""
  if len(video.shape) == 5:
    n_samples, n_frames, height, width, n_channels = video.shape
    video = np.transpose(video, (0, 2, 1, 3, 4))  # Swap n_frames and height.
    image_grid = np.reshape(
        video, (n_samples, height, n_frames * width, n_channels))
  elif len(video.shape) == 6:
    n_samples, n_frames, n_slots, height, width, n_channels = video.shape
    # Put n_frames next to width.
    video = np.transpose(video, (0, 2, 3, 1, 4, 5))
    image_grid = np.reshape(
        video, (n_samples, n_slots * height, n_frames * width, n_channels))
  else:
    raise ValueError("Unsupported video shape for visualization.")
  return image_grid
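

# Illustrative shapes (assumed): frames are folded along width; for 6D inputs
# (e.g. per-slot masks), slots are additionally folded along height.
#
#   video_to_image_grid(np.zeros((2, 4, 32, 32, 3))).shape     # -> (2, 32, 128, 3)
#   video_to_image_grid(np.zeros((2, 4, 5, 32, 32, 1))).shape  # -> (2, 160, 128, 1)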


def draw_bounding_boxes(video,
                        boxes,
                        min_n_colors=1,
                        include_black=True):
  """Draw bounding boxes in videos."""
  colors = get_uniform_colors(min_n_colors - include_black)

  b, t, h, w, c = video.shape
  n = boxes.shape[2]
  image_grid = tf.image.draw_bounding_boxes(
      np.reshape(video, (b * t, h, w, c)),
      np.reshape(boxes, (b * t, n, 4)),
      colors).numpy()
  image_grid = np.reshape(
      np.transpose(np.reshape(image_grid, (b, t, h, w, c)),
                   (0, 2, 1, 3, 4)),
      (b, h, t * w, c))
  return image_grid


def plot_image(ax, image):
  """Add an image visualization to a provided `plt.Axes` instance."""
  num_channels = image.shape[-1]
  if num_channels == 1:
    image = image.reshape(image.shape[:2])
  ax.imshow(image, cmap="viridis")
  ax.grid(False)
  plt.axis("off")


def visualize_image_dict(images, plot_scale=10):
  """Visualize a dictionary of images in Colab using matplotlib."""
  for key in images.keys():
    logging.info("Visualizing key: %s", key)
    n_images = len(images[key])
    fig = plt.figure(figsize=(n_images * plot_scale, plot_scale))
    for idx, image in enumerate(images[key]):
      ax = fig.add_subplot(1, n_images, idx + 1)
      plot_image(ax, image)
    plt.show()


def filter_key_from_frozen_dict(
    frozen_dict, key):
  """Filters (removes) an item by key from a flax.core.FrozenDict."""
  if key in frozen_dict:
    frozen_dict, _ = frozen_dict.pop(key)
  return frozen_dict


def prepare_dict_for_logging(nested_dict, parent_key="", sep="_"):
  """Prepare a nested dictionary for logging with `clu.metric_writers`.

  Args:
    nested_dict: A nested dictionary, e.g. obtained from a
      `ml_collections.ConfigDict` via `.to_dict()`.
    parent_key: String used in recursion.
    sep: String used to separate parent and child keys.

  Returns:
    Flattened dict.
  """
  items = []
  for k, v in nested_dict.items():
    # Flatten keys of nested elements.
    new_key = parent_key + sep + k if parent_key else k

    # Convert None values, lists and tuples to strings.
    if v is None:
      v = "None"
    if isinstance(v, (list, tuple)):
      v = str(v)

    # Recursively flatten the dict.
    if isinstance(v, dict):
      items.extend(prepare_dict_for_logging(v, new_key, sep=sep).items())
    else:
      items.append((new_key, v))
  return dict(items)


def retrieve_from_collection(
    variable_collection, path):
  """Finds variables by their path by recursively searching the collection.

  Args:
    variable_collection: Nested dict containing the variables (or tuples/lists
      of variables).
    path: Path to variable in module tree, similar to Unix file names (e.g.
      '/module/dense/0/bias').

  Returns:
    The requested variable, variable collection or None (in case the variable
    could not be found).
  """
  key, _, rpath = path.strip("/").partition("/")

  # In case the variable is not found, we return None.
  if (key.isdigit() and not isinstance(variable_collection, Sequence)) or (
      key.isdigit() and int(key) >= len(variable_collection)) or (
          not key.isdigit() and key not in variable_collection):
    return None

  if key.isdigit():
    key = int(key)

  if not rpath:
    return variable_collection[key]
  else:
    return retrieve_from_collection(variable_collection[key], rpath)
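

# Illustrative example (variable names are hypothetical): digits in the path
# index into lists/tuples; other components index into nested dicts.
#
#   variables = {"intermediates": {"encoder": {"conv": [w0, w1]}}}
#   retrieve_from_collection(variables, "/intermediates/encoder/conv/1")  # -> w1
#   retrieve_from_collection(variables, "/intermediates/missing")         # -> None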


def build_model_from_config(config):
  """Build a Flax model from a (nested) ConfigDict."""
  model_constructor = _parse_config(config)
  if callable(model_constructor):
    return model_constructor()
  else:
    raise ValueError("Provided config does not contain module constructors.")


def _parse_config(config):
  """Recursively parses a nested ConfigDict and resolves module constructors."""
  if isinstance(config, list):
    return [_parse_config(c) for c in config]
  elif isinstance(config, tuple):
    return tuple([_parse_config(c) for c in config])
  elif not isinstance(config, ml_collections.ConfigDict):
    return config
  elif "module" in config:
    module_constructor = _resolve_module_constructor(config.module)
    kwargs = {k: _parse_config(v) for k, v in config.items() if k != "module"}
    return functools.partial(module_constructor, **kwargs)
  else:
    return {k: _parse_config(v) for k, v in config.items()}


def _resolve_module_constructor(constructor_str):
  """Imports and returns a constructor from a fully qualified module path."""
  import_str, _, module_name = constructor_str.rpartition(".")
  py_module = importlib.import_module(import_str)
  return getattr(py_module, module_name)
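

# Illustrative config sketch (the module path and fields are hypothetical): a
# ConfigDict whose "module" entry names a constructor is resolved into a
# partially applied module, which build_model_from_config then instantiates.
#
#   config = ml_collections.ConfigDict({
#       "module": "invariant_slot_attention.modules.SAVi",
#       "num_slots": 11,
#   })
#   model = build_model_from_config(config)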


def get_slices_along_axis(
    inputs,
    slice_keys,
    start_idx=0,
    end_idx=-1,
    axis=2,
    pad_value=0):
  """Extracts slices from a dictionary of tensors along the specified axis.

  The slice operation is only applied to `slice_keys` dictionary keys. If
  `end_idx` is larger than the actual size of the specified axis, padding is
  added (with values provided in `pad_value`).

  Args:
    inputs: Dictionary of tensors.
    slice_keys: Iterable of strings, the keys for the inputs dictionary for
      which to apply the slice operation.
    start_idx: Integer, defining the first index to be part of the slice.
    end_idx: Integer, defining the end of the slice interval (exclusive). If
      set to `-1`, the end index is set to the size of the axis. If a value is
      provided that is larger than the size of the axis, zero-padding is added
      for the remaining elements.
    axis: Integer, the axis along which to slice.
    pad_value: Integer, value to be used in padding.

  Returns:
    Dictionary containing the sliced (and possibly padded) tensors for the
    keys in `slice_keys`.
  """
  max_size = None
  pad_size = 0

  # Check shapes and get maximum size of requested axis.
  for key in slice_keys:
    curr_size = inputs[key].shape[axis]
    if max_size is None:
      max_size = curr_size
    elif max_size != curr_size:
      raise ValueError(
          "For specified tensors the requested axis needs to be of equal size.")

  # Infer end index if not provided.
  if end_idx == -1:
    end_idx = max_size
  # Set padding size if end index is larger than maximum size of requested axis.
  elif end_idx > max_size:
    pad_size = end_idx - max_size
    end_idx = max_size

  outputs = {}
  for key in slice_keys:
    outputs[key] = np.take(
        inputs[key], indices=np.arange(start_idx, end_idx), axis=axis)

    # Add padding if necessary.
    if pad_size > 0:
      pad_shape = np.array(outputs[key].shape)
      np.put(pad_shape, axis, pad_size)  # In-place op.
      padding = pad_value * np.ones(pad_shape, dtype=outputs[key].dtype)
      outputs[key] = np.concatenate((outputs[key], padding), axis=axis)
  return outputs
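

# Illustrative example (assumed shapes): requesting 8 frames from a 6-frame
# video axis pads the remaining 2 frames with zeros.
#
#   inputs = {"video": np.zeros((8, 2, 6, 64, 64, 3))}
#   out = get_slices_along_axis(inputs, {"video"}, start_idx=0, end_idx=8)
#   out["video"].shape  # -> (8, 2, 8, 64, 64, 3)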


def get_element_by_str(
    dictionary, multilevel_key, separator="/"):
  """Gets element in a dictionary with multilevel key (e.g., "key1/key2")."""
  keys = multilevel_key.split(separator)
  if len(keys) == 1:
    return dictionary[keys[0]]
  return get_element_by_str(
      dictionary[keys[0]], separator.join(keys[1:]), separator=separator)


def set_element_by_str(
    dictionary, multilevel_key, new_value, separator="/"):
  """Sets element in a dictionary with multilevel key (e.g., "key1/key2")."""
  keys = multilevel_key.split(separator)
  if len(keys) == 1:
    if keys[0] not in dictionary:
      key_error = (
          "Pretrained {type} {key} was not found in trained model. "
          "Make sure you are loading the correct pretrained model "
          "or consider adding {key} to exceptions.")
      raise KeyError(key_error.format(type="parameter", key=keys[0]))
    dictionary[keys[0]] = new_value
  else:
    set_element_by_str(
        dictionary[keys[0]],
        separator.join(keys[1:]),
        new_value,
        separator=separator)
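

# Illustrative example (the parameter tree is hypothetical): reading and
# writing entries of a nested dict with slash-separated keys, e.g. when
# loading pretrained parameters.
#
#   params = {"encoder": {"dense": {"kernel": w}}}
#   get_element_by_str(params, "encoder/dense/kernel")         # -> w
#   set_element_by_str(params, "encoder/dense/kernel", w_new)  # in-place update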


def remove_singleton_dim(inputs):
  """Removes the final dimension if it is singleton (i.e. of size 1)."""
  if inputs is None:
    return None
  if inputs.shape[-1] != 1:
    logging.warning("Expected final dimension of inputs to be 1, "
                    "received inputs of shape %s.", str(inputs.shape))
    return inputs
  return inputs[..., 0]