# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn

from modules.general.utils import Linear


class PositionEncoder(nn.Module):
    r"""Encoder of positional embedding: generates sinusoidal PE and feeds it into an
    ``n_layer``-layer MLP (2 fully-connected layers with ``SiLU`` by default).

    Args:
        d_raw_emb: The dimension of the raw sinusoidal embedding vectors.
        d_out: The dimension of the output embedding vectors, defaults to ``d_raw_emb``.
        d_mlp: The dimension of the hidden layers in the MLP, defaults to ``d_raw_emb * 4``.
        activation_function: The activation function used in the MLP, defaults to ``SiLU``.
        n_layer: The number of layers in the MLP, defaults to 2.
        max_period: Controls the minimum frequency of the embeddings.
    """

    def __init__(
        self,
        d_raw_emb: int = 128,
        d_out: int = None,
        d_mlp: int = None,
        activation_function: str = "SiLU",
        n_layer: int = 2,
        max_period: int = 10000,
    ):
        super().__init__()

        self.d_raw_emb = d_raw_emb
        self.d_out = d_raw_emb if d_out is None else d_out
        self.d_mlp = d_raw_emb * 4 if d_mlp is None else d_mlp
        self.n_layer = n_layer
        self.max_period = max_period

        # Normalize the activation name to the exact class name in ``torch.nn``.
        if activation_function.lower() == "silu":
            self.activation_function = "SiLU"
        elif activation_function.lower() == "relu":
            self.activation_function = "ReLU"
        elif activation_function.lower() == "gelu":
            self.activation_function = "GELU"
        else:
            raise ValueError("activation_function must be one of SiLU, ReLU, GELU")

        # MLP: d_raw_emb -> d_mlp -> ... -> d_mlp -> d_out, with an activation
        # after every hidden layer.
        tmp = [
            Linear(self.d_raw_emb, self.d_mlp),
            getattr(nn, self.activation_function)(),
        ]
        for _ in range(self.n_layer - 1):
            tmp.append(Linear(self.d_mlp, self.d_mlp))
            tmp.append(getattr(nn, self.activation_function)())
        tmp.append(Linear(self.d_mlp, self.d_out))

        self.out = nn.Sequential(*tmp)

    def forward(self, steps: torch.Tensor) -> torch.Tensor:
        r"""Create sinusoidal timestep embeddings and project them through the MLP.

        Args:
            steps: A 1D Tensor of N indices, one per batch element. These may be fractional.

        Returns:
            An [N x ``d_out``] Tensor of positional embeddings.
        """

        half = self.d_raw_emb // 2
        # Geometric progression of frequencies, from 1 down to roughly 1 / max_period.
        freqs = torch.exp(
            -math.log(self.max_period)
            / half
            * torch.arange(half, dtype=torch.float32, device=steps.device)
        )
        args = steps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.d_raw_emb % 2:
            # Zero-pad the last channel when d_raw_emb is odd.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return self.out(embedding)
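

# A minimal usage sketch (an illustration, not part of the original module); it assumes
# this file lives inside the Amphion source tree so that ``modules.general.utils.Linear``
# resolves on import.
if __name__ == "__main__":
    encoder = PositionEncoder(d_raw_emb=128)
    # One (possibly fractional) diffusion timestep per batch element.
    steps = torch.tensor([0.0, 10.5, 999.0])
    emb = encoder(steps)
    print(emb.shape)  # torch.Size([3, 128]), since d_out defaults to d_raw_emb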