File size: 3,837 Bytes
66a6dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch

from deepafx_st.models.mobilenetv2 import MobileNetV2
from deepafx_st.models.efficient_net import EfficientNet


class SpectralEncoder(torch.nn.Module):
    """Encode audio waveforms into L2-normalized embeddings via a spectrogram CNN.

    The waveform is turned into a power-compressed STFT magnitude "image",
    standardized with fixed statistics, and passed through a vision backbone
    (MobileNetV2 or EfficientNet-B2) whose output is reduced to one embedding
    vector per batch item.
    """

    # STFT analysis settings, in samples.
    _N_FFT = 4096
    _HOP_LENGTH = 2048
    # Magnitudes are compressed as (|X| + eps) ** exponent.
    _MAG_EPS = 1e-8
    _MAG_EXPONENT = 0.3
    # Fixed standardization statistics for the compressed magnitudes.
    # NOTE(review): presumably mean/std measured on training data — confirm.
    _SPEC_MEAN = 0.322970
    _SPEC_STD = 0.278452

    def __init__(
        self,
        num_params,
        sample_rate,
        encoder_model="mobilenet_v2",
        embed_dim=1028,
        width_mult=1,
        min_level_db=-80,
    ):
        """Encoder operating on spectrograms.

        Args:
            num_params (int): Number of processor parameters to generate.
                Stored as an attribute; not used within this class.
            sample_rate (float): Audio sample rate. Stored as an attribute;
                the STFT in ``forward`` does not use it.
            encoder_model (str, optional): Backbone architecture, either
                "mobilenet_v2" or "efficient_net". Default: "mobilenet_v2"
            embed_dim (int, optional): Dimensionality of the encoder
                representations. Default: 1028
            width_mult (int, optional): Width multiplier passed to the
                MobileNetV2 backbone (unused by "efficient_net"). Default: 1
            min_level_db (float, optional): Minimal dB value for the
                spectrogram. Stored as an attribute; not used within this
                class. Default: -80

        Raises:
            ValueError: If ``encoder_model`` is not a supported backbone name.
        """
        super().__init__()
        self.num_params = num_params
        self.sample_rate = sample_rate
        self.encoder_model = encoder_model
        self.embed_dim = embed_dim
        self.width_mult = width_mult
        self.min_level_db = min_level_db

        # Build the backbone from the project's local model definitions.
        if encoder_model == "mobilenet_v2":
            self.encoder = MobileNetV2(embed_dim=embed_dim, width_mult=width_mult)
        elif encoder_model == "efficient_net":
            self.encoder = EfficientNet.from_name(
                "efficientnet-b2",
                in_channels=1,
                image_size=(128, 65),
                include_top=False,
            )
            # 1x1 conv mapping the backbone's 1408 feature channels (the width
            # this code assumes for efficientnet-b2) to embed_dim channels.
            self.embedding_projection = torch.nn.Conv2d(
                in_channels=1408,
                out_channels=embed_dim,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                bias=True,
            )
        else:
            raise ValueError(f"Invalid encoder_model: {encoder_model}.")

        # NOTE(review): registered as a trainable Parameter, so an optimizer
        # over .parameters() will update the STFT window. If a fixed Hann
        # window is intended, consider requires_grad=False; kept as-is here to
        # stay compatible with existing checkpoints and optimizer states.
        self.window = torch.nn.Parameter(torch.hann_window(self._N_FFT))

    def forward(self, x):
        """Embed a batch of waveforms.

        Args:
            x (Tensor): Input waveform of shape [batch x channels x samples].
                All channels are flattened into one signal per batch item
                before the STFT, so mono input is presumably expected.

        Returns:
            e (Tensor): L2-normalized latent embedding produced by the encoder,
                shape [batch x embed_dim]. An all-zero embedding stays zero
                instead of becoming NaN.

        Raises:
            ValueError: If ``self.encoder_model`` is not a supported backbone.
        """
        # Unpacking enforces a 3-D [batch, channels, samples] input.
        bs, _chs, _samples = x.size()

        # Complex STFT of shape [batch, n_fft // 2 + 1, frames]. reshape (not
        # view) also accepts non-contiguous inputs.
        X = torch.stft(
            x.reshape(bs, -1),
            self._N_FFT,
            self._HOP_LENGTH,
            window=self.window,
            return_complex=True,
        )
        # Power-law compression; the eps keeps the base strictly positive so
        # the gradient of |X| ** 0.3 stays finite at silent bins.
        X_mag = torch.pow(X.abs() + self._MAG_EPS, self._MAG_EXPONENT)

        # Standardize with the fixed statistics, then lay out as a one-channel
        # image [batch, 1, frames, freq_bins].
        X_norm = (X_mag - self._SPEC_MEAN) / self._SPEC_STD
        X_norm = X_norm.unsqueeze(1).permute(0, 1, 3, 2)

        if self.encoder_model == "mobilenet_v2":
            # Tile the single channel to the 3 channels a vision model expects.
            features = self.encoder(X_norm.repeat(1, 3, 1, 1))
            # Average-pool the remaining spatial map down to one vector.
            e = torch.nn.functional.adaptive_avg_pool2d(features, 1).reshape(
                features.shape[0], -1
            )
        elif self.encoder_model == "efficient_net":
            # The backbone downsamples internally and average-pools its output
            # (per the original note: by 32 on the time and frequency axes).
            e = self.encoder(X_norm)
            # Project to the requested embedding size, then drop the 1x1
            # spatial dims.
            e = self.embedding_projection(e)
            e = torch.squeeze(e, dim=3)
            e = torch.squeeze(e, dim=2)
        else:
            raise ValueError(f"Invalid encoder_model: {self.encoder_model}.")

        # L2-normalize each embedding. F.normalize divides by max(norm, eps),
        # so a zero vector yields zeros instead of NaNs from 0 / 0.
        return torch.nn.functional.normalize(e, p=2, dim=-1)