File size: 2,885 Bytes
b725c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


def normalization(channels: int, groups: int = 32):
    r"""Build the default normalization layer, which is a ``nn.GroupNorm``.

    Args:
        channels: number of channels the layer normalizes.
        groups: number of groups the channels are split into; must be
            positive.

    Returns:
        a ``nn.GroupNorm`` module configured with ``groups`` groups over
        ``channels`` channels.
    """
    assert groups > 0, f"invalid number of groups: {groups}"
    group_norm = nn.GroupNorm(num_groups=groups, num_channels=channels)
    return group_norm


def Linear(*args, **kwargs):
    r"""Create a ``nn.Linear`` and re-draw its weight with kaiming_normal_.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Linear``. Only the weight is re-initialized here.

    Returns:
        the constructed ``nn.Linear`` module.
    """
    linear = nn.Linear(*args, **kwargs)
    weight = linear.weight
    nn.init.kaiming_normal_(weight)
    return linear


def Conv1d(*args, **kwargs):
    r"""Create a ``nn.Conv1d`` and re-draw its weight with kaiming_normal_.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``. Only the weight is re-initialized here.

    Returns:
        the constructed ``nn.Conv1d`` module.
    """
    conv = nn.Conv1d(*args, **kwargs)
    weight = conv.weight
    nn.init.kaiming_normal_(weight)
    return conv


def Conv2d(*args, **kwargs):
    r"""Create a ``nn.Conv2d`` and re-draw its weight with kaiming_normal_.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv2d``. Only the weight is re-initialized here.

    Returns:
        the constructed ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(*args, **kwargs)
    weight = conv.weight
    nn.init.kaiming_normal_(weight)
    return conv


def ConvNd(dims: int = 1, *args, **kwargs):
    r"""Build a 1-D or 2-D convolution through the ``Conv1d``/``Conv2d`` wrappers.

    Args:
        dims: number of dimensions of the convolution; 1 and 2 are supported.
        *args: positional arguments forwarded to the selected wrapper.
        **kwargs: keyword arguments forwarded to the selected wrapper.

    Returns:
        the layer returned by ``Conv1d`` (``dims == 1``) or ``Conv2d``
        (``dims == 2``).

    Raises:
        ValueError: if ``dims`` is neither 1 nor 2.
    """
    if dims == 1:
        builder = Conv1d
    elif dims == 2:
        builder = Conv2d
    else:
        raise ValueError(f"invalid number of dimensions: {dims}")
    return builder(*args, **kwargs)


def zero_module(module: nn.Module):
    r"""Zero out the weight and bias of a module and return it.

    Args:
        module: a module exposing a ``weight`` parameter and, optionally, a
            ``bias`` parameter (e.g. ``nn.Linear`` or ``nn.Conv1d``).

    Returns:
        the same ``module`` object, modified in place.
    """
    nn.init.zeros_(module.weight)
    # Layers constructed with ``bias=False`` keep ``bias`` set to ``None``;
    # passing that to ``nn.init.zeros_`` would raise, so skip it.
    bias = getattr(module, "bias", None)
    if bias is not None:
        nn.init.zeros_(bias)
    return module


def scale_module(module: nn.Module, scale):
    r"""Multiply every parameter of a module by ``scale`` in place.

    Args:
        module: module whose parameters are rescaled.
        scale: factor each parameter is multiplied by.

    Returns:
        the same ``module`` object, with its parameters scaled.
    """
    for param in module.parameters():
        # Work on a detached view so the in-place multiply is not tracked
        # by autograd while still writing into the parameter's storage.
        raw = param.detach()
        raw.mul_(scale)
    return module


def mean_flat(tensor: torch.Tensor):
    r"""Take the mean over all non-batch dimensions.

    Args:
        tensor: input whose first dimension is treated as the batch
            dimension.

    Returns:
        a tensor with one value per batch element, i.e. ``tensor`` averaged
        over every dimension except the first. An input with fewer than two
        dimensions has no non-batch dimensions and is returned as a copy.
    """
    if tensor.dim() < 2:
        # With no non-batch dimensions the reduction tuple would be empty,
        # and ``Tensor.mean(dim=())`` reduces over *every* dimension, which
        # would collapse the batch axis into a single scalar.
        return tensor.clone()
    return tensor.mean(dim=tuple(range(1, tensor.dim())))


def append_dims(x, target_dims):
    r"""Pad a tensor with trailing singleton dimensions up to ``target_dims``.

    Args:
        x: tensor to expand.
        target_dims: total number of dimensions the result should have.

    Returns:
        ``x`` indexed with one trailing ``None`` per missing dimension, so
        the result has ``target_dims`` dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    current_dims = x.dim()
    missing = target_dims - current_dims
    if missing < 0:
        raise ValueError(
            f"input has {current_dims} dims but target_dims is {target_dims}, which is less"
        )
    # Ellipsis keeps every existing dimension; each None adds a size-1 axis.
    index = (Ellipsis, *([None] * missing))
    return x[index]


def append_zero(x, count=1):
    r"""Append ``count`` zeros to a tensor along its last dimension.

    Args:
        x: tensor to extend.
        count: number of zeros to append; must be positive.

    Returns:
        a new tensor equal to ``x`` followed by ``count`` zeros on the last
        dimension. The zeros are created with ``x.new_zeros``.
    """
    assert count > 0, f"invalid count: {count}"
    # Same leading shape as ``x``, with ``count`` entries on the last axis.
    pad_shape = tuple(x.shape[:-1]) + (count,)
    padding = x.new_zeros(pad_shape)
    return torch.cat((x, padding), dim=-1)


class Transpose(nn.Identity):
    """Swap dimensions 1 and 2 of the input: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Exchange the second and third axes; the batch axis stays first.
        swapped = torch.transpose(input, 1, 2)
        return swapped