akhaliq3
spaces demo
5019931
from typing import List, NoReturn, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_embedding(layer: nn.Module) -> None:
    r"""Initialize the weights of an embedding-style layer in place.

    The weight tensor is filled with samples from a uniform distribution on
    [-1, 1]. If the layer has a ``bias`` attribute that is not ``None``, the
    bias is filled with zeros.

    Args:
        layer: module exposing a ``weight`` tensor and, optionally, a ``bias``.
    """
    nn.init.uniform_(layer.weight, -1.0, 1.0)

    # Covers both "no bias attribute at all" and "bias explicitly set to None".
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_layer(layer: nn.Module) -> None:
    r"""Initialize a Linear or Convolutional layer in place.

    The weight tensor is filled using Xavier (Glorot) uniform initialization.
    If the layer has a ``bias`` attribute that is not ``None``, the bias is
    filled with zeros.

    Args:
        layer: module exposing a ``weight`` tensor and, optionally, a ``bias``.
    """
    nn.init.xavier_uniform_(layer.weight)

    # Covers both "no bias attribute at all" and "bias explicitly set to None".
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_bn(bn: nn.Module) -> None:
    r"""Initialize a Batchnorm layer in place.

    Sets ``bias`` and ``running_mean`` to 0 and ``weight`` and ``running_var``
    to 1. Tensors that are absent or ``None`` on the layer are skipped
    instead of raising ``AttributeError``.

    Args:
        bn: batch-normalization module to reset.
    """
    targets = (
        ("bias", 0.0),
        ("weight", 1.0),
        ("running_mean", 0.0),
        ("running_var", 1.0),
    )
    for attr_name, fill_value in targets:
        tensor = getattr(bn, attr_name, None)
        if tensor is not None:
            tensor.data.fill_(fill_value)
def act(x: torch.Tensor, activation: str) -> torch.Tensor:
    r"""Apply the activation function selected by name.

    Args:
        x: input tensor. NOTE: the "relu" and "leaky_relu" branches use the
            in-place functional variants, so ``x`` itself is overwritten and
            returned. The "swish" branch returns a new tensor.
        activation: one of "relu", "leaky_relu" (negative slope 0.01) or
            "swish" (``x * sigmoid(x)``).

    Returns:
        The activated tensor, same shape as ``x``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu_(x)

    if activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)

    if activation == "swish":
        return x * torch.sigmoid(x)

    # ValueError is still an Exception, so existing broad handlers keep working.
    raise ValueError(f"Incorrect activation! Got {activation!r}.")
class Base:
    r"""Helpers for extracting the magnitude, cos and sin of an STFT.

    This class does not define ``stft`` itself: every method calls
    ``self.stft(input)`` and unpacks the result as a ``(real, imag)`` pair, so
    ``stft`` has to be supplied by the class this one is combined with.
    """

    def __init__(self):
        r"""Base function for extracting spectrogram, cos, and sin, etc."""

    def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        r"""Calculate the magnitude spectrogram.

        Args:
            input: waveform batch passed straight to ``self.stft``,
                documented by the callers as (batch_size, segments_num)
            eps: lower bound applied to the power (real^2 + imag^2) before
                the square root is taken

        Returns:
            spectrogram: same shape as the tensors returned by ``self.stft``
        """
        real, imag = self.stft(input)
        power = real ** 2 + imag ** 2
        return torch.clamp(power, min=eps) ** 0.5

    def spectrogram_phase(
        self, input: torch.Tensor, eps: float = 0.0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Calculate the magnitude, cos, and sin of the STFT of input.

        Args:
            input: waveform batch passed straight to ``self.stft``,
                documented by the callers as (batch_size, segments_num)
            eps: lower bound applied to the power (real^2 + imag^2) before
                the square root is taken. With the default of 0.0, any bin
                whose real and imaginary parts are both exactly zero yields
                0 / 0 = NaN in ``cos`` and ``sin``; pass a small positive
                value to avoid that.

        Returns:
            mag, cos, sin: three tensors with the same shape as the tensors
                returned by ``self.stft``
        """
        real, imag = self.stft(input)
        mag = torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Convert waveforms to magnitude, cos, and sin of STFT.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: lower bound applied to the power before the square root

        Returns:
            mag: (batch_size, channels_num, time_steps, freq_bins)
            cos: (batch_size, channels_num, time_steps, freq_bins)
            sin: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, channels_num, segment_samples = input.shape

        # Fold the channels into the batch axis: (n, segments_num) is the
        # layout the stft function is called with.
        x = input.reshape(batch_size * channels_num, segment_samples)

        mag, cos, sin = self.spectrogram_phase(x, eps=eps)
        # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)

        _, _, time_steps, freq_bins = mag.shape
        out_shape = (batch_size, channels_num, time_steps, freq_bins)

        # Unfold the channels again; this absorbs the singleton second axis.
        return mag.reshape(out_shape), cos.reshape(out_shape), sin.reshape(out_shape)

    def wav_to_spectrogram(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> torch.Tensor:
        r"""Convert waveforms to the magnitude of their STFT.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: lower bound applied to the power before the square root

        Returns:
            mag: (batch_size, channels_num, time_steps, freq_bins)
        """
        mag, _, _ = self.wav_to_spectrogram_phase(input, eps)
        return mag
class Subband:
    def __init__(self, subbands_num: int):
        r"""Warning!! This class is not used!!

        Splits a time-frequency representation into subbands along the
        frequency axis. This does not work as well as [1], which splits the
        subbands in the time domain. Please refer to [1] for the formal
        implementation.

        [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
        accompaniment separation on high resolution music." arXiv preprint
        arXiv:2008.05216 (2020).

        Args:
            subbands_num: int, number of subbands, e.g., 4
        """
        self.subbands_num = subbands_num

    def analysis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Split the frequency axis into subbands and stack the subbands
        along the channel axis.

        Args:
            x: (batch_size, channels_num, time_steps, freq_bins)

        Returns:
            output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
        """
        batch_size, channels_num, time_steps, freq_bins = x.shape
        bands = self.subbands_num
        band_width = freq_bins // bands

        # (batch_size, channels_num, time_steps, subbands_num, band_width)
        split = x.reshape(batch_size, channels_num, time_steps, bands, band_width)

        # Bring the subband axis next to the channel axis:
        # (batch_size, channels_num, subbands_num, time_steps, band_width)
        split = split.permute(0, 1, 3, 2, 4)

        # Merge channel and subband axes into one.
        return split.reshape(batch_size, channels_num * bands, time_steps, band_width)

    def synthesis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Merge subbands stacked along the channel axis back into the
        original time-frequency representation.

        Args:
            x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)

        Returns:
            output: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, stacked_channels, time_steps, band_width = x.shape
        bands = self.subbands_num
        channels_num = stacked_channels // bands

        # (batch_size, channels_num, subbands_num, time_steps, band_width)
        stacked = x.reshape(batch_size, channels_num, bands, time_steps, band_width)

        # Move the subband axis back in front of the per-band frequency axis:
        # (batch_size, channels_num, time_steps, subbands_num, band_width)
        stacked = stacked.permute(0, 1, 3, 2, 4)

        # Flatten the subbands back into one frequency axis.
        return stacked.reshape(batch_size, channels_num, time_steps, band_width * bands)