import math
import torch
import librosa

# based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py
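# A non-trainable baseline embedding: mel spectrogram frames are power-compressed
# and passed through a fixed random projection to produce one embedding per clip.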


class RandomMelProjection(torch.nn.Module):
    def __init__(
        self,
        sample_rate,
        embed_dim=4096,
        n_mels=128,
        n_fft=4096,
        hop_size=1024,
        seed=0,
        epsilon=1e-4,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.embed_dim = embed_dim
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.seed = seed
        self.epsilon = epsilon

        # Set random seed
        torch.random.manual_seed(self.seed)

        # Create a Hann window buffer to apply to frames prior to FFT.
        self.register_buffer("window", torch.hann_window(self.n_fft))

        # Create a mel filter buffer.
        mel_scale = torch.tensor(
            librosa.filters.mel(
                sr=self.sample_rate,
                n_fft=self.n_fft,
                n_mels=self.n_mels,
            )
        )
        self.register_buffer("mel_scale", mel_scale)

        # Projection matrices.
        normalization = math.sqrt(self.n_mels)
        self.projection = torch.nn.Parameter(
            torch.rand(self.n_mels, self.embed_dim) / normalization,
            requires_grad=False,
        )

    def forward(self, x):
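        # x: (batch, channels, samples). Channels are flattened into the time
        # axis before the STFT; returns a (batch, embed_dim) embedding per clip.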
        bs, chs, samp = x.size()

        # Compute the complex STFT of the flattened waveform using the Hann window.
        x = torch.stft(
            x.view(bs, -1),
            self.n_fft,
            hop_length=self.hop_size,
            window=self.window,
            return_complex=True,
        )
        # (bs, freq, frames) -> (bs, 1, frames, freq) so the mel filter applies over frequency.
        x = x.unsqueeze(1).permute(0, 1, 3, 2)

        # Apply the mel-scale filter to the magnitude spectrum.
        x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1))

        # Compress the dynamic range with a power-law nonlinearity.
        x = torch.pow(x + self.epsilon, 0.3)

        # Project the mel frames into the embedding space.
        e = x.matmul(self.projection)

        # Average across the temporal dimension to get one embedding per clip.
        e = e.mean(dim=2).view(bs, -1)

        return e

    def compute_frame_embedding(self, x):
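        # x: (..., n_fft) pre-framed audio; each frame is windowed, transformed,
        # mel-filtered, log-scaled, and projected independently.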
        # Compute the real-valued Fourier transform of the windowed input signal.
        x = torch.fft.rfft(x * self.window)

        # Convert to a power spectrum.
        x = torch.abs(x) ** 2.0

        # Apply the mel-scale filter to the power spectrum.
        x = torch.matmul(x, self.mel_scale.transpose(0, 1))

        # Convert to a log mel spectrum.
        x = torch.log(x + self.epsilon)

        # Apply the random projection to get an embed_dim-dimensional embedding.
        embedding = x.matmul(self.projection)

        return embedding
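

# A minimal usage sketch, not part of the original file: the 44.1 kHz sample rate,
# batch size, and one-second random input below are illustrative assumptions.
if __name__ == "__main__":
    model = RandomMelProjection(sample_rate=44100)
    audio = torch.randn(2, 1, 44100)  # (batch, channels, samples): 1 s of noise
    with torch.no_grad():
        emb = model(audio)
    print(emb.shape)  # expected: torch.Size([2, 4096])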