# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


def normalization(channels: int, groups: int = 32) -> nn.GroupNorm:
    r"""Make a standard normalization layer, i.e. GroupNorm.

    Args:
        channels: number of input channels. ``nn.GroupNorm`` requires this
            to be divisible by ``groups``.
        groups: number of groups for group normalization.

    Returns:
        a ``nn.Module`` for normalization.
    """
    assert groups > 0, f"invalid number of groups: {groups}"
    return nn.GroupNorm(groups, channels)


def _kaiming_init(layer: nn.Module) -> nn.Module:
    r"""Re-initialize ``layer.weight`` in place with ``kaiming_normal_`` and return ``layer``."""
    # Default kaiming_normal_ arguments are used; the bias (if any) keeps
    # PyTorch's default initialization.
    nn.init.kaiming_normal_(layer.weight)
    return layer


def Linear(*args, **kwargs) -> nn.Linear:
    r"""Wrapper of ``nn.Linear`` with kaiming_normal_ initialization.

    All arguments are forwarded unchanged to ``nn.Linear``.
    """
    return _kaiming_init(nn.Linear(*args, **kwargs))


def Conv1d(*args, **kwargs) -> nn.Conv1d:
    r"""Wrapper of ``nn.Conv1d`` with kaiming_normal_ initialization.

    All arguments are forwarded unchanged to ``nn.Conv1d``.
    """
    return _kaiming_init(nn.Conv1d(*args, **kwargs))


def Conv2d(*args, **kwargs) -> nn.Conv2d:
    r"""Wrapper of ``nn.Conv2d`` with kaiming_normal_ initialization.

    All arguments are forwarded unchanged to ``nn.Conv2d``.
    """
    return _kaiming_init(nn.Conv2d(*args, **kwargs))


def ConvNd(dims: int = 1, *args, **kwargs) -> nn.Module:
    r"""Wrapper of N-dimension convolution with kaiming_normal_ initialization.

    Args:
        dims: number of dimensions of the convolution (1 or 2).
        *args, **kwargs: forwarded to :func:`Conv1d` / :func:`Conv2d`.

    Returns:
        a kaiming-initialized ``nn.Conv1d`` or ``nn.Conv2d``.

    Raises:
        ValueError: if ``dims`` is neither 1 nor 2.
    """
    if dims == 1:
        return Conv1d(*args, **kwargs)
    if dims == 2:
        return Conv2d(*args, **kwargs)
    raise ValueError(f"invalid number of dimensions: {dims}")


def zero_module(module: nn.Module) -> nn.Module:
    r"""Zero out the parameters of a module and return it.

    Every parameter registered on ``module`` (including those of its
    submodules) is zeroed in place. Layers built with ``bias=False`` and
    container modules therefore work too, rather than failing on a missing
    ``weight``/``bias``.
    """
    for p in module.parameters():
        # ``detach()`` shares storage with ``p``, so the in-place write hits
        # the parameter itself without being recorded by autograd.
        p.detach().zero_()
    return module


def scale_module(module: nn.Module, scale) -> nn.Module:
    r"""Scale the parameters of a module in place and return it."""
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor: torch.Tensor) -> torch.Tensor:
    r"""Take the mean over all non-batch dimensions.

    Dimension 0 is treated as the batch dimension. For inputs with at least
    one dimension the result has shape ``(tensor.size(0),)``.
    """
    if tensor.dim() < 2:
        # There are no non-batch dimensions to reduce. Passing ``dim=()``
        # would make PyTorch reduce over *all* dimensions and collapse a 1-D
        # batch to a scalar. Reduce over a freshly added singleton axis
        # instead, which returns the values unchanged as a new tensor.
        return tensor.unsqueeze(-1).mean(dim=-1)
    return tensor.mean(dim=tuple(range(1, tensor.dim())))


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    r"""Appends dimensions to the end of a tensor until it has target_dims dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    dims_to_append = target_dims - x.dim()
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.dim()} dims but target_dims is {target_dims}, which is less"
        )
    # Indexing with ``None`` inserts size-1 axes (a view, no copy).
    return x[(...,) + (None,) * dims_to_append]


def append_zero(x: torch.Tensor, count: int = 1) -> torch.Tensor:
    r"""Appends ``count`` zeros to the end of a tensor along the last dimension."""
    assert count > 0, f"invalid count: {count}"
    # ``new_zeros`` matches x's dtype and device, so ``torch.cat`` is valid.
    return torch.cat([x, x.new_zeros((*x.size()[:-1], count))], dim=-1)


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)

    Swaps dimensions 1 and 2 of the input; it has no parameters.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)