# This file is taken from signjoey repository
import math

import torch
from torch import Tensor, nn


def freeze_params(module: nn.Module) -> None:
    """
    Freeze every parameter of ``module`` (recursively) in place.

    Used by ``Embeddings`` / ``SpatialEmbeddings`` when ``freeze=True`` so the
    layer's weights receive no gradients.

    :param module: module whose parameters get ``requires_grad = False``
    """
    for param in module.parameters():
        param.requires_grad = False


# Activation name -> activation class. A new instance is built on each lookup
# so parameterized activations (e.g. PReLU) are never shared between layers.
_ACTIVATIONS = {
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "prelu": nn.PReLU,
    "selu": nn.SELU,
    "celu": nn.CELU,
    "gelu": nn.GELU,
    "silu": nn.SiLU,
    "sigmoid": nn.Sigmoid,
    "softplus": nn.Softplus,
    "softshrink": nn.Softshrink,
    "softsign": nn.Softsign,
    "tanh": nn.Tanh,
    "tanhshrink": nn.Tanhshrink,
}


def get_activation(activation_type):
    """
    Build the activation module named by ``activation_type``.

    :param activation_type: one of "relu", "relu6", "prelu", "selu", "celu",
        "gelu", "silu", "sigmoid", "softplus", "softshrink", "softsign",
        "tanh", "tanhshrink"
    :return: a new, default-constructed activation ``nn.Module``
    :raises ValueError: if the name is not recognised
    """
    try:
        activation_cls = _ACTIVATIONS[activation_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which should also be reported
        # as an unknown activation rather than leaking a dict error.
        raise ValueError(
            "Unknown activation type {}".format(activation_type)
        ) from None
    return activation_cls()


class MaskedNorm(nn.Module):
    """
    Normalization layer that excludes masked-out (padded) rows in training.

    Original Code from:
    https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
    """

    def __init__(self, norm_type, num_groups, num_features):
        """
        :param norm_type: "batch", "group" or "layer"
        :param num_groups: number of groups, only used when norm_type is "group"
        :param num_features: size of the feature (last) dimension
        :raises ValueError: for any other ``norm_type``
        """
        super().__init__()
        self.norm_type = norm_type
        if self.norm_type == "batch":
            self.norm = nn.BatchNorm1d(num_features=num_features)
        elif self.norm_type == "group":
            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
        elif self.norm_type == "layer":
            self.norm = nn.LayerNorm(normalized_shape=num_features)
        else:
            raise ValueError("Unsupported Normalization Layer")

        self.num_features = num_features

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """
        Normalize ``x`` over its feature dimension.

        In training mode with a mask, only rows whose mask value is > 0 are
        passed through the norm; the other rows are returned unchanged. In eval
        mode, or when ``mask`` is None, every row (padding included) is
        normalized.

        :param x: input tensor, presumably (batch, time, num_features); it is
            flattened to (-1, num_features) internally
        :param mask: one value per row of the flattened input; > 0 marks a
            valid row. None means all rows are valid.
        :return: tensor reshaped to (x.shape[0], -1, num_features)
        """
        flat = x.reshape([-1, self.num_features])
        if self.training and mask is not None:
            flat_mask = mask.reshape([-1, 1]) > 0
            selected = torch.masked_select(flat, flat_mask).reshape(
                [-1, self.num_features]
            )
            normed = self.norm(selected)
            # Put the normalized rows back into their original positions;
            # masked-out rows keep their input values.
            out = flat.masked_scatter(flat_mask, normed)
        else:
            out = self.norm(flat)
        return out.reshape([x.shape[0], -1, self.num_features])


# TODO (Cihan): Spatial and Word Embeddings are pretty much the same
# We might as well convert them into a single module class.
# Only difference is the lut vs linear layers.
class Embeddings(nn.Module):

    """
    Simple embeddings class: vocabulary lookup followed by optional masked
    normalization, activation and constant scaling.
    """

    # pylint: disable=unused-argument
    def __init__(
        self,
        embedding_dim: int = 64,
        num_heads: int = 8,
        scale: bool = False,
        scale_factor: float = None,
        norm_type: str = None,
        activation_type: str = None,
        vocab_size: int = 0,
        padding_idx: int = 1,
        freeze: bool = False,
        **kwargs
    ):
        """
        Create new embeddings for the vocabulary.
        Use scaling for the Transformer.

        :param embedding_dim: size of each embedding vector
        :param num_heads: used as ``num_groups`` when norm_type is "group"
        :param scale: if True, multiply the output by the scale factor
        :param scale_factor: scale constant; when falsy (None/0) and ``scale``
            is set, sqrt(embedding_dim) is used instead
        :param norm_type: "batch", "group", "layer", or None for no norm
        :param activation_type: name accepted by ``get_activation``, or None
        :param vocab_size: number of rows in the lookup table
        :param padding_idx: forwarded to ``nn.Embedding`` as ``padding_idx``
        :param freeze: freeze the embeddings during training
        :param kwargs: extra keyword arguments are accepted and ignored
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)

        self.norm_type = norm_type
        if self.norm_type:
            self.norm = MaskedNorm(
                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
            )

        self.activation_type = activation_type
        if self.activation_type:
            self.activation = get_activation(activation_type)

        self.scale = scale
        if self.scale:
            if scale_factor:
                self.scale_factor = scale_factor
            else:
                self.scale_factor = math.sqrt(self.embedding_dim)

        if freeze:
            freeze_params(self)

    # pylint: disable=arguments-differ
    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """
        Perform lookup for input `x` in the embedding table.

        :param x: index in the vocabulary
        :param mask: token masks, forwarded to the normalization layer
        :return: embedded representation for `x`
        """
        x = self.lut(x)

        if self.norm_type:
            x = self.norm(x, mask)

        if self.activation_type:
            x = self.activation(x)

        if self.scale:
            return x * self.scale_factor
        else:
            return x

    def __repr__(self):
        return "%s(embedding_dim=%d, vocab_size=%d)" % (
            self.__class__.__name__,
            self.embedding_dim,
            self.vocab_size,
        )


class SpatialEmbeddings(nn.Module):

    """
    Simple Linear Projection Layer
    (For encoder outputs to predict glosses)
    """

    # pylint: disable=unused-argument
    def __init__(
        self,
        embedding_dim: int,
        input_size: int,
        num_heads: int,
        freeze: bool = False,
        norm_type: str = "batch",
        activation_type: str = "softsign",
        scale: bool = False,
        scale_factor: float = None,
        **kwargs
    ):
        """
        Create a linear projection from ``input_size`` to ``embedding_dim``
        followed by optional masked normalization, activation and scaling.

        :param embedding_dim: output feature size
        :param input_size: input feature size
        :param num_heads: used as ``num_groups`` when norm_type is "group"
        :param freeze: freeze the embeddings during training
        :param norm_type: "batch", "group", "layer", or None for no norm
        :param activation_type: name accepted by ``get_activation``, or None
        :param scale: if True, multiply the output by the scale factor
        :param scale_factor: scale constant; when falsy (None/0) and ``scale``
            is set, sqrt(embedding_dim) is used instead
        :param kwargs: extra keyword arguments are accepted and ignored
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.input_size = input_size
        self.ln = nn.Linear(self.input_size, self.embedding_dim)

        self.norm_type = norm_type
        if self.norm_type:
            self.norm = MaskedNorm(
                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
            )

        self.activation_type = activation_type
        if self.activation_type:
            self.activation = get_activation(activation_type)

        self.scale = scale
        if self.scale:
            if scale_factor:
                self.scale_factor = scale_factor
            else:
                self.scale_factor = math.sqrt(self.embedding_dim)

        if freeze:
            freeze_params(self)

    # pylint: disable=arguments-differ
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        :param x: input frame features (last dimension is ``input_size``)
        :param mask: frame masks, forwarded to the normalization layer
        :return: embedded representation for `x`
        """
        x = self.ln(x)

        if self.norm_type:
            x = self.norm(x, mask)

        if self.activation_type:
            x = self.activation(x)

        if self.scale:
            return x * self.scale_factor
        else:
            return x

    def __repr__(self):
        return "%s(embedding_dim=%d, input_size=%d)" % (
            self.__class__.__name__,
            self.embedding_dim,
            self.input_size,
        )


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: if True, return [cos, sin] halves instead of
        [sin, cos].
    :param downscale_freq_shift: subtracted from ``embedding_dim // 2`` in the
        denominator of the frequency exponent.
    :param scale: multiplier applied to the angles before sin/cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings (odd dims are
        zero-padded by one column).
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    # NOTE(review): if half_dim == downscale_freq_shift (e.g. embedding_dim=2
    # with the default shift of 1) this divides 0 by 0 and yields NaN.
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP: Linear(channel -> time_embed_dim), optional activation,
    Linear(time_embed_dim -> time_embed_dim).
    """

    def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
        """
        :param channel: input feature size
        :param time_embed_dim: hidden and output feature size
        :param act_fn: "silu" inserts SiLU between the linears; any other value
            means no activation is applied
        """
        super().__init__()

        self.linear_1 = nn.Linear(channel, time_embed_dim)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample):
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample


class Timesteps(nn.Module):
    """Module wrapper around ``get_timestep_embedding`` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        """
        :param num_channels: embedding dimension passed to get_timestep_embedding
        :param flip_sin_to_cos: forwarded to get_timestep_embedding
        :param downscale_freq_shift: forwarded to get_timestep_embedding
        """
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb