Applio/rvc/lib/algorithm/normalization.py
Vidalnt
fix: :art: Update Applio
39f384e
raw
history blame contribute delete
953 Bytes
import torch
class LayerNorm(torch.nn.Module):
    """Channel-wise layer normalization for channels-first tensors.

    ``torch.nn.functional.layer_norm`` normalizes over trailing dimensions,
    so the channel axis (dim 1) is temporarily swapped with the last axis,
    normalized there, and then swapped back.

    Args:
        channels (int): Size of the channel dimension (dim 1 of the input).
        eps (float, optional): Value added to the variance for numerical
            stability. Defaults to 1e-5.
    """

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        # Learnable per-channel scale (initialized to 1) and shift (initialized to 0).
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` across its channel dimension.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, time_steps).

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        # Move channels to the last axis, where layer_norm operates.
        channels_last = x.transpose(1, -1)
        normed = torch.nn.functional.layer_norm(
            channels_last,
            normalized_shape=(channels_last.shape[-1],),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        # Swapping the same two axes again restores the original layout.
        return normed.transpose(1, -1)