JohnnyPittt's picture
Duplicate from nateraw/deepafx-st
51da11a
import math
import torch
import librosa
# based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py
class RandomMelProjection(torch.nn.Module):
def __init__(
self,
sample_rate,
embed_dim=4096,
n_mels=128,
n_fft=4096,
hop_size=1024,
seed=0,
epsilon=1e-4,
):
super().__init__()
self.sample_rate = sample_rate
self.embed_dim = embed_dim
self.n_mels = n_mels
self.n_fft = n_fft
self.hop_size = hop_size
self.seed = seed
self.epsilon = epsilon
# Set random seed
torch.random.manual_seed(self.seed)
# Create a Hann window buffer to apply to frames prior to FFT.
self.register_buffer("window", torch.hann_window(self.n_fft))
# Create a mel filter buffer.
mel_scale = torch.tensor(
librosa.filters.mel(
self.sample_rate,
n_fft=self.n_fft,
n_mels=self.n_mels,
)
)
self.register_buffer("mel_scale", mel_scale)
# Projection matrices.
normalization = math.sqrt(self.n_mels)
self.projection = torch.nn.Parameter(
torch.rand(self.n_mels, self.embed_dim) / normalization,
requires_grad=False,
)
def forward(self, x):
bs, chs, samp = x.size()
x = torch.stft(
x.view(bs, -1),
self.n_fft,
self.hop_size,
window=self.window,
return_complex=True,
)
x = x.unsqueeze(1).permute(0, 1, 3, 2)
# Apply the mel-scale filter to the power spectrum.
x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1))
# power scale
x = torch.pow(x + self.epsilon, 0.3)
# apply random projection
e = x.matmul(self.projection)
# take mean across temporal dim
e = e.mean(dim=2).view(bs, -1)
return e
def compute_frame_embedding(self, x):
# Compute the real-valued Fourier transform on windowed input signal.
x = torch.fft.rfft(x * self.window)
# Convert to a power spectrum.
x = torch.abs(x) ** 2.0
# Apply the mel-scale filter to the power spectrum.
x = torch.matmul(x, self.mel_scale.transpose(0, 1))
# Convert to a log mel spectrum.
x = torch.log(x + self.epsilon)
# Apply projection to get a 4096 dimension embedding
embedding = x.matmul(self.projection)
return embedding