# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import collections.abc
import copy
from functools import partial
from itertools import repeat
from typing import Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


# From PyTorch internals
def _ntuple(n):
    """Build a parser that turns a value into a tuple.

    The returned ``parse(x)`` converts any non-string iterable to a tuple
    unchanged (its length is NOT checked against ``n``); any other value is
    repeated ``n`` times.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    """Round ``v`` to a multiple of ``divisor``.

    Args:
        v: value to round.
        divisor: the result is a multiple of this value.
        min_value: lower bound for the result; a falsy value (``None`` or 0)
            falls back to ``divisor``.
        round_limit: if the rounded value is below ``round_limit * v``, one
            extra ``divisor`` is added.

    Returns:
        The rounded value.
    """
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v


def extend_tuple(x, n):
    """Pad ``x`` to length ``n`` by repeating its last value.

    A value that is not a tuple or list is first wrapped in a 1-tuple. When
    ``x`` already has ``n`` or more items it is sliced with ``x[:n]`` instead.

    Raises:
        IndexError: if ``x`` is an empty sequence and padding is required.
    """
    if not isinstance(x, (tuple, list)):
        x = (x,)
    else:
        x = tuple(x)
    pad_n = n - len(x)
    if pad_n <= 0:
        return x[:n]
    return x + (x[-1],) * pad_n


def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
    that are temporally closest to the current frame at `frame_idx`. Here, we take
    - a) the closest conditioning frame before `frame_idx` (if any);
    - b) the closest conditioning frame after `frame_idx` (if any);
    - c) any other temporally closest conditioning frames until reaching a total
         of `max_cond_frame_num` conditioning frames.

    If `max_cond_frame_num` is -1, or `cond_frame_outputs` has no more than
    `max_cond_frame_num` entries, everything is selected and `selected_outputs`
    is the `cond_frame_outputs` object itself (not a copy).

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        selected_outputs = cond_frame_outputs
        unselected_outputs = {}
    else:
        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
        selected_outputs = {}

        # the closest conditioning frame before `frame_idx` (if any)
        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
        if idx_before is not None:
            selected_outputs[idx_before] = cond_frame_outputs[idx_before]

        # the closest conditioning frame after `frame_idx` (if any);
        # a conditioning frame exactly at `frame_idx` counts as "after"
        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
        if idx_after is not None:
            selected_outputs[idx_after] = cond_frame_outputs[idx_after]

        # add other temporally closest conditioning frames until reaching a total
        # of `max_cond_frame_num` conditioning frames.
        num_remain = max_cond_frame_num - len(selected_outputs)
        inds_remain = sorted(
            (t for t in cond_frame_outputs if t not in selected_outputs),
            key=lambda x: abs(x - frame_idx),
        )[:num_remain]
        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
        unselected_outputs = {
            t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
        }

    return selected_outputs, unselected_outputs


def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Get 1D sine positional embedding as in the original Transformer paper.

    `pos_inds` is a tensor of positions; the result appends one trailing
    dimension of size `2 * (dim // 2)`, holding the sine terms followed by the
    cosine terms.
    """
    pe_dim = dim // 2
    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)

    pos_embed = pos_inds.unsqueeze(-1) / dim_t
    pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
    return pos_embed


def get_activation_fn(activation):
    """Return an activation function given a string.

    Supported names are "relu", "gelu" and "glu".

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def get_clones(module, N):
    """Return an `nn.ModuleList` holding `N` independent deep copies of `module`."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class DropPath(nn.Module):
    """Randomly zero whole samples of the input during training.

    Args:
        drop_prob: probability of zeroing a sample.
        scale_by_keep: if True, kept samples are divided by the keep
            probability.
    """

    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Identity when dropping is disabled or the module is in eval mode.
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per entry along dim 0, broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        # Skip the rescale when keep_prob is 0 to avoid dividing by zero.
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    """Stack of `num_layers` linear layers with an activation between them.

    Args:
        input_dim: input feature size of the first layer.
        hidden_dim: feature size of every intermediate layer.
        output_dim: output feature size of the last layer.
        num_layers: total number of linear layers.
        activation: activation module class; it is instantiated with no
            arguments and applied after every layer except the last.
        sigmoid_output: if True, a sigmoid is applied to the final output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: Type[nn.Module] = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output
        self.act = activation()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
class LayerNorm2d(nn.Module):
    """Layer normalization over dim 1 with a learnable per-channel affine.

    The affine parameters are broadcast as `(C, 1, 1)`, so the input is
    expected to have its channels in dim 1 followed by two trailing dims
    (e.g. `(N, C, H, W)`).
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        # Biased variance over the channel dim.
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class LayerScale(nn.Module):
    """Multiply the input by a learnable vector `gamma` of size `dim`.

    Args:
        dim: size of `gamma`.
        init_values: initial value(s) of `gamma`.
        inplace: if True, the input tensor is modified in place with `mul_`.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, torch.Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma