# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utils."""

import functools
import importlib
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Type, Union

from absl import logging
from clu import metrics as base_metrics

import flax
from flax import linen as nn
from flax import traverse_util

import jax
import jax.numpy as jnp
import jax.ops

import matplotlib
import matplotlib.pyplot as plt
import ml_collections
import numpy as np
import optax

import skimage.transform
import tensorflow as tf

from invariant_slot_attention.lib import metrics


Array = Any  # Union[np.ndarray, jnp.ndarray]
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
DictTree = Dict[str, Union[Array, "DictTree"]]  # pytype: disable=not-supported-yet
PRNGKey = Array
ConfigAttr = Any
MetricSpec = Dict[str, str]


@flax.struct.dataclass
class TrainState:
  """Data structure for checkpointing the model."""
  step: int
  opt_state: optax.OptState
  params: ArrayTree
  variables: flax.core.FrozenDict
  rng: PRNGKey


METRIC_TYPE_TO_CLS = {
    "loss": base_metrics.Average.from_output(name="loss"),
    "ari": metrics.Ari,
    "ari_nobg": metrics.AriNoBg,
}


def make_metrics_collection(
    class_name,
    metrics_spec):
  """Create class inhering from metrics.Collection based on spec."""
  metrics_dict = {}
  if metrics_spec:
    for m_name, m_type in metrics_spec.items():
      metrics_dict[m_name] = METRIC_TYPE_TO_CLS[m_type]

  return flax.struct.dataclass(
      type(class_name,
           (base_metrics.Collection,),
           {"__annotations__": metrics_dict}))

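# Illustrative usage sketch (not part of the library API; the spec values must
# be keys of `METRIC_TYPE_TO_CLS`, and `batch_loss` is a hypothetical scalar):
#
#   TrainMetrics = make_metrics_collection("TrainMetrics", {"loss": "loss"})
#   update = TrainMetrics.single_from_model_output(loss=batch_loss)
#   summary = update.compute()  # {"loss": ...}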

def flatten_named_dicttree(metrics_res, sep="/"):
  """Flatten dictionary."""
  metrics_res_flat = {}
  for k, v in traverse_util.flatten_dict(metrics_res).items():
    metrics_res_flat[(sep.join(k)).strip(sep)] = v
  return metrics_res_flat


def spatial_broadcast(x, resolution):
  """Broadcast flat inputs to a 2D grid of a given resolution."""
  # x.shape = (batch_size, features).
  x = x[:, jnp.newaxis, jnp.newaxis, :]
  return jnp.tile(x, [1, resolution[0], resolution[1], 1])


def time_distributed(cls, in_axes=1, axis=1):
  """Wrapper for time-distributed (vmapped) application of a module."""
  return nn.vmap(
      cls, in_axes=in_axes, out_axes=axis, axis_name="time",
      # Stack debug vars along sequence dim and broadcast params.
      variable_axes={
          "params": None, "intermediates": axis, "batch_stats": None},
      split_rngs={"params": False, "dropout": True, "state_init": True})


def broadcast_across_batch(inputs, batch_size):
  """Broadcasts inputs across a batch of examples (creates new axis)."""
  return jnp.broadcast_to(
      array=jnp.expand_dims(inputs, axis=0),
      shape=(batch_size,) + inputs.shape)


def create_gradient_grid(samples_per_dim, value_range=(-1.0, 1.0)):
  """Creates a tensor with equidistant entries from -1 to +1 in each dim.

  Args:
    samples_per_dim: Number of points to have along each dimension.
    value_range: In each dimension, points will range from value_range[0] to
      value_range[1].

  Returns:
    A tensor of shape [samples_per_dim] + [len(samples_per_dim)].
  """
  s = [jnp.linspace(value_range[0], value_range[1], n) for n in samples_per_dim]
  pe = jnp.stack(jnp.meshgrid(*s, sparse=False, indexing="ij"), axis=-1)
  return jnp.array(pe)


def convert_to_fourier_features(inputs, basis_degree):
  """Convert inputs to Fourier features, e.g. for positional encoding."""

  # inputs.shape = (..., n_dims).
  # inputs should be in range [-pi, pi] or [0, 2pi].
  n_dims = inputs.shape[-1]

  # Generate frequency basis.
  freq_basis = jnp.concatenate(  # shape = (n_dims, n_dims * basis_degree)
      [2**i * jnp.eye(n_dims) for i in range(basis_degree)], 1)

  # x.shape = (..., n_dims * basis_degree)
  x = inputs @ freq_basis  # Project inputs onto frequency basis.

  # Obtain Fourier features as [sin(x), cos(x)] = [sin(x), sin(x + 0.5 * pi)].
  return jnp.sin(jnp.concatenate([x, x + 0.5 * jnp.pi], axis=-1))

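# Illustrative usage sketch: one way (an assumption, not prescribed by this
# module) to build a Fourier positional encoding for a 16x16 feature map from
# the two helpers above.
#
#   grid = create_gradient_grid((16, 16), value_range=(-1.0, 1.0))
#   # grid.shape == (16, 16, 2)
#   pos_enc = convert_to_fourier_features(grid * jnp.pi, basis_degree=4)
#   # pos_enc.shape == (16, 16, 16), i.e. n_dims * basis_degree * 2 features.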

def prepare_images_for_logging(
    config,
    batch=None,
    preds=None,
    n_samples=5,
    n_frames=5,
    min_n_colors=1,
    epsilon=1e-6,
    first_replica_only=False):
  """Prepare images from batch and/or model predictions for logging."""

  images = dict()
  # Convert all tensors to numpy arrays to run everything on CPU, as JAX eager
  # mode is inefficient and the memory used by these ops may otherwise lead to
  # OOM errors.
  batch = jax.tree_map(np.array, batch)
  preds = jax.tree_map(np.array, preds)

  if n_samples <= 0:
    return images

  if not first_replica_only:
    # Move the two leading batch dimensions into a single dimension. We do this
    # to plot the same number of examples regardless of the data parallelism.
    batch = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), batch)
    preds = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), preds)
  else:
    batch = jax.tree_map(lambda x: x[0], batch)
    preds = jax.tree_map(lambda x: x[0], preds)

  # Limit the tensors to n_samples and n_frames.
  batch = jax.tree_map(
      lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
      batch)
  preds = jax.tree_map(
      lambda x: x[:n_samples, :n_frames] if x.ndim > 2 else x[:n_samples],
      preds)

  # Log input data.
  if batch is not None:
    images["video"] = video_to_image_grid(batch["video"])
    if "segmentations" in batch:
      images["mask"] = video_to_image_grid(convert_categories_to_color(
          batch["segmentations"], min_n_colors=min_n_colors))
    if "flow" in batch:
      images["flow"] = video_to_image_grid(batch["flow"])
    if "boxes" in batch:
      images["boxes"] = draw_bounding_boxes(
          batch["video"],
          batch["boxes"],
          min_n_colors=min_n_colors)

  # Log model predictions.
  if preds is not None and preds.get("outputs") is not None:
    if "segmentations" in preds["outputs"]:  # pytype: disable=attribute-error
      images["segmentations"] = video_to_image_grid(
          convert_categories_to_color(
              preds["outputs"]["segmentations"], min_n_colors=min_n_colors))

  def shape_fn(x):
    if isinstance(x, (np.ndarray, jnp.ndarray)):
      return x.shape

  # Log intermediate variables.
  if preds is not None and "intermediates" in preds:

    logging.info("intermediates: %s",
                 jax.tree_map(shape_fn, preds["intermediates"]))

    for key, path in config.debug_var_video_paths.items():
      log_vars = retrieve_from_collection(preds["intermediates"], path)
      if log_vars is not None:
        if not isinstance(log_vars, Sequence):
          log_vars = [log_vars]
        for i, log_var in enumerate(log_vars):
          log_var = np.array(log_var)  # Moves log_var to CPU.
          images[key + "_" + str(i)] = video_to_image_grid(log_var)
      else:
        logging.warning("%s not found in intermediates", path)

    # Log attention weights.
    for key, path in config.debug_var_attn_paths.items():
      log_vars = retrieve_from_collection(preds["intermediates"], path)
      if log_vars is not None:
        if not isinstance(log_vars, Sequence):
          log_vars = [log_vars]
        for i, log_var in enumerate(log_vars):
          log_var = np.array(log_var)  # Moves log_var to CPU.
          images.update(
              prepare_attention_maps_for_logging(
                  attn_maps=log_var,
                  key=key + "_" + str(i),
                  map_width=config.debug_var_attn_widths.get(key),
                  video=batch["video"],
                  epsilon=epsilon,
                  n_samples=n_samples,
                  n_frames=n_frames))
      else:
        logging.warning("%s not found in intermediates", path)

  # Crop each image to a maximum of 3 channels for RGB visualization.
  for key, image in images.items():
    if image.shape[-1] > 3:
      logging.warning("Truncating channels of %s for visualization.", key)
      images[key] = image[Ellipsis, :3]

  return images


def prepare_attention_maps_for_logging(attn_maps, key,
                                       map_width, epsilon,
                                       n_samples, n_frames,
                                       video):
  """Visualize (overlayed) attention maps as an image grid."""
  images = {}  # Results dictionary.
  attn_maps = unflatten_image(attn_maps[Ellipsis, None], width=map_width)

  num_heads = attn_maps.shape[2]
  for head_idx in range(num_heads):
    attn = attn_maps[:n_samples, :n_frames, head_idx]
    attn /= attn.max() + epsilon  # Standardizes scale for visualization.
    # attn.shape: [bs, seq_len, 11, h', w', 1]

    bs, seq_len, _, h_attn, w_attn, _ = attn.shape
    images[f"{key}_head_{head_idx}"] = video_to_image_grid(attn)

    # Attention maps are interpretable when they align with object boundaries.
    # However, if they are overly smooth then the following visualization which
    # overlays attention maps on video is helpful.
    video = video[:n_samples, :n_frames]
    # video.shape: [bs, seq_len, h, w, 3]
    video_resized = []
    for i in range(n_samples):
      for j in range(n_frames):
        video_resized.append(
            skimage.transform.resize(video[i, j], (h_attn, w_attn), order=1))
    video_resized = np.array(video_resized).reshape(
        (bs, seq_len, h_attn, w_attn, 3))
    attn_overlayed = attn * np.expand_dims(video_resized, 2)
    images[f"{key}_head_{head_idx}_overlayed"] = video_to_image_grid(
        attn_overlayed)

  return images


def convert_categories_to_color(
    inputs, min_n_colors=1, include_black=True):
  """Converts int-valued categories to color in last axis of input tensor.

  Args:
    inputs: `np.ndarray` of arbitrary shape with integer entries, encoding the
      categories.
    min_n_colors: Minimum number of colors (excl. black) to encode categories.
    include_black: Include black as 0-th entry in the color palette. Increases
      `min_n_colors` by 1 if True.

  Returns:
    `np.ndarray` with RGB colors in last axis.
  """
  if inputs.shape[-1] == 1:  # Strip category axis.
    inputs = np.squeeze(inputs, axis=-1)
  inputs = np.array(inputs, dtype=np.int32)  # Convert to int.

  # Infer number of colors from inputs.
  n_colors = int(inputs.max()) + 1  # One color per category incl. 0.
  if include_black:
    n_colors -= 1  # If we include black, we need one color less.

  if min_n_colors > n_colors:  # Use more colors in color palette if requested.
    n_colors = min_n_colors

  rgb_colors = get_uniform_colors(n_colors)

  if include_black:  # Add black as color for zero-th index.
    rgb_colors = np.concatenate((np.zeros((1, 3)), rgb_colors), axis=0)
  return rgb_colors[inputs]


def get_uniform_colors(n_colors):
  """Get n_colors with uniformly spaced hues."""
  hues = np.linspace(0, 1, n_colors, endpoint=False)
  hsv_colors = np.concatenate(
      (np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1)
  rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors)
  return rgb_colors  # rgb_colors.shape = (n_colors, 3)


def unflatten_image(image, width=None):
  """Unflatten image array of shape [batch_dims..., height*width, channels]."""
  n_channels = image.shape[-1]
  # If width is not provided, we assume that the image is square.
  if width is None:
    width = int(np.floor(np.sqrt(image.shape[-2])))
    height = width
    assert width * height == image.shape[-2], "Image is not square."
  else:
    height = image.shape[-2] // width
  return image.reshape(image.shape[:-2] + (height, width, n_channels))

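# Illustrative shape example (assumed inputs): a flattened square image of
# shape (batch, 64, 3) becomes (batch, 8, 8, 3) when `width` is omitted; with
# `width=16` it becomes (batch, 4, 16, 3).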

def video_to_image_grid(video):
  """Transform video to image grid by folding sequence dim along width."""
  if len(video.shape) == 5:
    n_samples, n_frames, height, width, n_channels = video.shape
    video = np.transpose(video, (0, 2, 1, 3, 4))  # Swap n_frames and height.
    image_grid = np.reshape(
        video, (n_samples, height, n_frames * width, n_channels))
  elif len(video.shape) == 6:
    n_samples, n_frames, n_slots, height, width, n_channels = video.shape
    # Put n_frames next to width.
    video = np.transpose(video, (0, 2, 3, 1, 4, 5))
    image_grid = np.reshape(
        video, (n_samples, n_slots * height, n_frames * width, n_channels))
  else:
    raise ValueError("Unsupported video shape for visualization.")
  return image_grid

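# Illustrative shape example (assumed inputs): a video of shape
# [2, 4, 64, 64, 3] becomes an image grid of shape [2, 64, 256, 3] (frames
# folded along width); a per-slot tensor of shape [2, 4, 11, 64, 64, 1]
# becomes [2, 704, 256, 1] (slots folded along height, frames along width).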

def draw_bounding_boxes(video,
                        boxes,
                        min_n_colors=1,
                        include_black=True):
  """Draw bounding boxes in videos."""
  colors = get_uniform_colors(min_n_colors - include_black)

  b, t, h, w, c = video.shape
  n = boxes.shape[2]
  image_grid = tf.image.draw_bounding_boxes(
      np.reshape(video, (b * t, h, w, c)),
      np.reshape(boxes, (b * t, n, 4)),
      colors).numpy()
  image_grid = np.reshape(
      np.transpose(np.reshape(image_grid, (b, t, h, w, c)),
                   (0, 2, 1, 3, 4)),
      (b, h, t * w, c))
  return image_grid


def plot_image(ax, image):
  """Add an image visualization to a provided `plt.Axes` instance."""
  num_channels = image.shape[-1]
  if num_channels == 1:
    image = image.reshape(image.shape[:2])
  ax.imshow(image, cmap="viridis")
  ax.grid(False)
  plt.axis("off")


def visualize_image_dict(images, plot_scale=10):
  """Visualize a dictionary of images in colab using maptlotlib."""

  for key in images.keys():
    logging.info("Visualizing key: %s", key)
    n_images = len(images[key])
    fig = plt.figure(figsize=(n_images * plot_scale, plot_scale))
    for idx, image in enumerate(images[key]):
      ax = fig.add_subplot(1, n_images, idx+1)
      plot_image(ax, image)
    plt.show()


def filter_key_from_frozen_dict(
    frozen_dict, key):
  """Filters (removes) an item by key from a flax.core.FrozenDict."""
  if key in frozen_dict:
    frozen_dict, _ = frozen_dict.pop(key)
  return frozen_dict


def prepare_dict_for_logging(nested_dict, parent_key="",
                             sep="_"):
  """Prepare a nested dictionary for logging with `clu.metric_writers`.

  Args:
    nested_dict: A nested dictionary, e.g. obtained from a
      `ml_collections.ConfigDict` via `.to_dict()`.
    parent_key: String used in recursion.
    sep: String used to separate parent and child keys.

  Returns:
    Flattened dict.
  """
  items = []
  for k, v in nested_dict.items():
    # Flatten keys of nested elements.
    new_key = parent_key + sep + k if parent_key else k

    # Convert None values, lists and tuples to strings.
    if v is None:
      v = "None"
    if isinstance(v, list) or isinstance(v, tuple):
      v = str(v)

    # Recursively flatten the dict.
    if isinstance(v, dict):
      items.extend(prepare_dict_for_logging(v, new_key, sep=sep).items())
    else:
      items.append((new_key, v))
  return dict(items)


def retrieve_from_collection(
    variable_collection, path):
  """Finds variables by their path by recursively searching the collection.

  Args:
    variable_collection: Nested dict containing the variables (or tuples/lists
      of variables).
    path: Path to variable in module tree, similar to Unix file names (e.g.
      '/module/dense/0/bias').

  Returns:
    The requested variable, variable collection or None (in case the variable
      could not be found).
  """
  key, _, rpath = path.strip("/").partition("/")

  # In case the variable is not found, we return None.
  if (key.isdigit() and not isinstance(variable_collection, Sequence)) or (
      key.isdigit() and int(key) >= len(variable_collection)) or (
          not key.isdigit() and key not in variable_collection):
    return None

  if key.isdigit():
    key = int(key)

  if not rpath:
    return variable_collection[key]
  else:
    return retrieve_from_collection(variable_collection[key], rpath)

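# Illustrative usage sketch (hypothetical collection layout): paths may mix
# dict keys and list/tuple indices; None is returned if any component is
# missing.
#
#   intermediates = {"encoder": {"dense": ({"bias": bias_array},)}}
#   retrieve_from_collection(intermediates, "/encoder/dense/0/bias")
#   # -> bias_array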

def build_model_from_config(config):
  """Build a Flax model from a (nested) ConfigDict."""
  model_constructor = _parse_config(config)
  if callable(model_constructor):
    return model_constructor()
  else:
    raise ValueError("Provided config does not contain module constructors.")

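# Illustrative config sketch (hypothetical import path "mymodels.TinyMLP"): any
# ConfigDict containing a "module" key is resolved to that constructor and
# partially applied with the remaining keys, recursively for nested sub-module
# configs.
#
#   config = ml_collections.ConfigDict(
#       {"module": "mymodels.TinyMLP", "hidden_size": 128})
#   model = build_model_from_config(config)  # == TinyMLP(hidden_size=128)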

def _parse_config(config):
  """Recursively parses a nested ConfigDict and resolves module constructors."""

  if isinstance(config, list):
    return [_parse_config(c) for c in config]
  elif isinstance(config, tuple):
    return tuple([_parse_config(c) for c in config])
  elif not isinstance(config, ml_collections.ConfigDict):
    return config
  elif "module" in config:
    module_constructor = _resolve_module_constructor(config.module)
    kwargs = {k: _parse_config(v) for k, v in config.items() if k != "module"}
    return functools.partial(module_constructor, **kwargs)
  else:
    return {k: _parse_config(v) for k, v in config.items()}


def _resolve_module_constructor(
    constructor_str):
  import_str, _, module_name = constructor_str.rpartition(".")
  py_module = importlib.import_module(import_str)
  return getattr(py_module, module_name)


def get_slices_along_axis(
    inputs,
    slice_keys,
    start_idx=0,
    end_idx=-1,
    axis=2,
    pad_value=0):
  """Extracts slices from a dictionary of tensors along the specified axis.

  The slice operation is only applied to `slice_keys` dictionary keys. If
  `end_idx` is larger than the actual size of the specified axis, padding is
  added (with values provided in `pad_value`).

  Args:
    inputs: Dictionary of tensors.
    slice_keys: Iterable of strings, the keys for the inputs dictionary for
      which to apply the slice operation.
    start_idx: Integer, defining the first index to be part of the slice.
    end_idx: Integer, defining the end of the slice interval (exclusive). If set
      to `-1`, the end index is set to the size of the axis. If a value is
      provided that is larger than the size of the axis, zero-padding is added
      for the remaining elements.
    axis: Integer, the axis along which to slice.
    pad_value: Integer, value to be used in padding.

  Returns:
    Dictionary of tensors where elements described in `slice_keys` are sliced,
      and all other elements are returned as original.
  """

  max_size = None
  pad_size = 0

  # Check shapes and get maximum size of requested axis.
  for key in slice_keys:
    curr_size = inputs[key].shape[axis]
    if max_size is None:
      max_size = curr_size
    elif max_size != curr_size:
      raise ValueError(
          "For specified tensors the requested axis needs to be of equal size.")

  # Infer end index if not provided.
  if end_idx == -1:
    end_idx = max_size

  # Set padding size if end index is larger than maximum size of requested axis.
  elif end_idx > max_size:
    pad_size = end_idx - max_size
    end_idx = max_size

  outputs = {}
  for key in slice_keys:
    outputs[key] = np.take(
        inputs[key], indices=np.arange(start_idx, end_idx), axis=axis)

    # Add padding if necessary.
    if pad_size > 0:
      pad_shape = np.array(outputs[key].shape)
      np.put(pad_shape, axis, pad_size)  # In-place op.
      padding = pad_value * np.ones(pad_shape, dtype=outputs[key].dtype)
      outputs[key] = np.concatenate((outputs[key], padding), axis=axis)

  return outputs

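# Illustrative usage sketch (assumed shapes): take the first four frames along
# the time axis of a video tensor of shape [batch, n_frames, height, width, 3],
# padding with zeros if fewer than four frames are available.
#
#   out = get_slices_along_axis(
#       {"video": video}, slice_keys=("video",),
#       start_idx=0, end_idx=4, axis=1)
#   # out["video"].shape[1] == 4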

def get_element_by_str(dictionary, multilevel_key, separator="/"):
  """Gets element in a dictionary with multilevel key (e.g., "key1/key2")."""
  keys = multilevel_key.split(separator)
  if len(keys) == 1:
    return dictionary[keys[0]]
  return get_element_by_str(
      dictionary[keys[0]], separator.join(keys[1:]), separator=separator)


def set_element_by_str(
    dictionary, multilevel_key, new_value,
    separator = "/"):
  """Sets element in a dictionary with multilevel key (e.g., "key1/key2")."""
  keys = multilevel_key.split(separator)
  if len(keys) == 1:
    if keys[0] not in dictionary:
      key_error = (
          "Pretrained {key} was not found in trained model. "
          "Make sure you are loading the correct pretrained model "
          "or consider adding {key} to exceptions.")
      raise KeyError(key_error.format(type="parameter", key=keys[0]))
    dictionary[keys[0]] = new_value
  else:
    set_element_by_str(
        dictionary[keys[0]],
        separator.join(keys[1:]),
        new_value,
        separator=separator)

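# Illustrative usage sketch (hypothetical parameter tree): multilevel keys
# address nested dictionaries, e.g. when overwriting parameters restored from a
# pretrained checkpoint.
#
#   params = {"encoder": {"dense": {"kernel": kernel}}}
#   get_element_by_str(params, "encoder/dense/kernel")              # -> kernel
#   set_element_by_str(params, "encoder/dense/kernel", new_kernel)  # in place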

def remove_singleton_dim(inputs):
  """Removes the final dimension if it is singleton (i.e. of size 1)."""
  if inputs is None:
    return None
  if inputs.shape[-1] != 1:
    logging.warning("Expected final dimension of inputs to be 1, "
                    "received inputs of shape %s: ", str(inputs.shape))
    return inputs
  return inputs[Ellipsis, 0]