Spaces:
Build error
Build error
import math | |
import torch | |
import librosa | |
# based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py | |
class RandomMelProjection(torch.nn.Module):
    """Embed audio by randomly projecting a compressed mel spectrogram.

    A fixed (non-trainable) random matrix of shape ``(n_mels, embed_dim)``
    maps mel-band energies to an embedding.  ``forward`` embeds a whole
    batch of clips and averages over time; ``compute_frame_embedding``
    embeds already-framed audio using a log power mel spectrum instead.

    Args:
        sample_rate: Sample rate passed to ``librosa.filters.mel``.
        embed_dim: Number of columns of the projection matrix, i.e. the
            size of the produced embedding.
        n_mels: Number of mel bands.
        n_fft: FFT size; also the length of the Hann window.
        hop_size: Hop length handed to ``torch.stft`` in ``forward``.
        seed: Seed used before drawing the random projection matrix.
        epsilon: Small constant added before the power/log compression
            to avoid taking the power/log of zero.
    """

    def __init__(
        self,
        sample_rate,
        embed_dim=4096,
        n_mels=128,
        n_fft=4096,
        hop_size=1024,
        seed=0,
        epsilon=1e-4,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.embed_dim = embed_dim
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.seed = seed
        self.epsilon = epsilon

        # Seed so that the projection matrix below is reproducible.
        # NOTE: this reseeds torch's *global* RNG, which affects any random
        # draws the caller makes afterwards; kept for backward compatibility.
        torch.random.manual_seed(self.seed)

        # Hann window buffer applied to frames prior to the FFT.
        self.register_buffer("window", torch.hann_window(self.n_fft))

        # Mel filter bank buffer.  ``sr`` must be passed by keyword:
        # librosa >= 0.10 made the arguments of ``filters.mel`` keyword-only,
        # so the positional form raises ``TypeError`` there.
        mel_scale = torch.tensor(
            librosa.filters.mel(
                sr=self.sample_rate,
                n_fft=self.n_fft,
                n_mels=self.n_mels,
            )
        )
        self.register_buffer("mel_scale", mel_scale)

        # Fixed random projection matrix, scaled by 1 / sqrt(n_mels).
        normalization = math.sqrt(self.n_mels)
        self.projection = torch.nn.Parameter(
            torch.rand(self.n_mels, self.embed_dim) / normalization,
            requires_grad=False,
        )

    def forward(self, x):
        """Return one time-averaged embedding per batch item.

        Args:
            x: 3-D tensor unpacked as ``(batch, channels, samples)``.  The
                channel and sample axes are flattened into a single signal
                per batch item before the STFT.

        Returns:
            Tensor of shape ``(batch, embed_dim)``.
        """
        # Unpacking also enforces that the input is exactly 3-D.
        bs, _chs, _samp = x.size()

        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # input (e.g. a channels-last tensor that was transposed), while
        # ``reshape`` yields the same result and copies only when needed.
        x = torch.stft(
            x.reshape(bs, -1),
            self.n_fft,
            self.hop_size,
            window=self.window,
            return_complex=True,
        )
        # (batch, freq, frames) -> (batch, 1, frames, freq)
        x = x.unsqueeze(1).permute(0, 1, 3, 2)

        # Apply the mel-scale filter to the magnitude spectrum.
        x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1))

        # Power-law compression (exponent 0.3).
        x = torch.pow(x + self.epsilon, 0.3)

        # Apply the random projection.
        e = x.matmul(self.projection)

        # Take the mean across the temporal dim, then drop the singleton dim.
        e = e.mean(dim=2).reshape(bs, -1)
        return e

    def compute_frame_embedding(self, x):
        """Embed already-framed audio via a log power mel spectrum.

        Args:
            x: Tensor whose last dimension is multiplied element-wise with
                the ``n_fft``-long window, so it must broadcast against it.

        Returns:
            Tensor whose last dimension has size ``embed_dim``.
        """
        # Compute the real-valued Fourier transform on windowed input signal.
        x = torch.fft.rfft(x * self.window)

        # Convert to a power spectrum.
        x = torch.abs(x) ** 2.0

        # Apply the mel-scale filter to the power spectrum.
        x = torch.matmul(x, self.mel_scale.transpose(0, 1))

        # Convert to a log mel spectrum.
        x = torch.log(x + self.epsilon)

        # Apply projection to get an ``embed_dim``-dimensional embedding.
        embedding = x.matmul(self.projection)
        return embedding