File size: 2,990 Bytes
b725c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional

import torch
import torch.nn as nn

from modules.general.utils import Linear


class PositionEncoder(nn.Module):
    r"""Sinusoidal positional-embedding encoder followed by an MLP.

    Each step is first turned into a ``d_raw_emb``-wide sinusoidal embedding,
    which is then projected by an MLP made of ``n_layer`` hidden
    ``Linear`` + activation layers and a final ``Linear`` output projection
    (``n_layer + 1`` linear layers in total).

    Args:
        d_raw_emb: The dimension of raw (sinusoidal) embedding vectors.
        d_out: The dimension of output embedding vectors, default to ``d_raw_emb``.
        d_mlp: The dimension of hidden layers in the MLP, default to ``d_raw_emb`` * 4.
        activation_function: The activation used in the MLP, one of ``SiLU``,
            ``ReLU``, ``GELU`` (case-insensitive), default to ``SiLU``.
        n_layer: The number of hidden ``Linear`` + activation layers in the MLP,
            default to 2. Values below 1 behave like 1.
        max_period: Controls the minimum frequency of the embeddings.

    Raises:
        ValueError: If ``activation_function`` is not a supported name.
    """

    # Lower-cased accepted names -> exact ``torch.nn`` class names.
    _ACTIVATIONS = {"silu": "SiLU", "relu": "ReLU", "gelu": "GELU"}

    def __init__(
        self,
        d_raw_emb: int = 128,
        d_out: Optional[int] = None,
        d_mlp: Optional[int] = None,
        activation_function: str = "SiLU",
        n_layer: int = 2,
        max_period: int = 10000,
    ):
        super().__init__()

        self.d_raw_emb = d_raw_emb
        self.d_out = d_raw_emb if d_out is None else d_out
        self.d_mlp = d_raw_emb * 4 if d_mlp is None else d_mlp
        self.n_layer = n_layer
        self.max_period = max_period

        # Keep the canonical spelling and build the layers from it, so case
        # variants such as "silu" resolve to ``nn.SiLU`` instead of failing
        # the ``getattr(nn, ...)`` lookup with an AttributeError.
        self.activation_function = self._canonical_activation(activation_function)
        act_cls = getattr(nn, self.activation_function)

        # Construction order (and thus Sequential indices / init RNG draws)
        # matches the original: Linear, act, [Linear, act]*(n_layer-1), Linear.
        layers = [Linear(self.d_raw_emb, self.d_mlp), act_cls()]
        for _ in range(self.n_layer - 1):
            layers += [Linear(self.d_mlp, self.d_mlp), act_cls()]
        layers.append(Linear(self.d_mlp, self.d_out))

        self.out = nn.Sequential(*layers)

    @staticmethod
    def _canonical_activation(name: str) -> str:
        """Return the ``torch.nn`` class name for a case-insensitive activation name.

        Raises:
            ValueError: If ``name`` is not one of SiLU, ReLU, GELU.
        """
        canonical = PositionEncoder._ACTIVATIONS.get(name.lower())
        if canonical is None:
            raise ValueError("activation_function must be one of SiLU, ReLU, GELU")
        return canonical

    def forward(self, steps: torch.Tensor) -> torch.Tensor:
        r"""Create sinusoidal embeddings for ``steps`` and project them with the MLP.

        Args:
            steps: a 1D Tensor of N indices, one per batch element.
                These may be fractional; they are cast to float32.

        Returns:
            an [N x ``d_out``] Tensor of positional embeddings.
        """
        half = self.d_raw_emb // 2
        # Geometric frequency ladder: freqs[i] = max_period ** (-i / half),
        # running from 1 down towards 1 / max_period.
        freqs = torch.exp(
            -math.log(self.max_period)
            / half
            * torch.arange(half, dtype=torch.float32, device=steps.device)
        )
        args = steps[:, None].float() * freqs[None]
        # Cosine half first, then sine half.
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.d_raw_emb % 2:
            # Odd width: append a zero column so the embedding is exactly
            # ``d_raw_emb`` wide, matching the first Linear's input size.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return self.out(embedding)