import torch
from deepafx_st.models.mobilenetv2 import MobileNetV2
from deepafx_st.models.efficient_net import EfficientNet


class SpectralEncoder(torch.nn.Module):
    """Embed audio by encoding a compressed STFT magnitude with an image CNN.

    The waveform is turned into a power-law compressed (|X| ** 0.3) magnitude
    spectrogram, standardized with fixed statistics, and passed through either
    a MobileNetV2 or an EfficientNet-B2 backbone. The result is pooled to one
    vector per item and L2-normalized.
    """

    # STFT configuration (the window length equals the FFT size).
    N_FFT = 4096
    HOP_LENGTH = 2048

    # Fixed standardization statistics for the compressed magnitude
    # spectrogram (presumably precomputed on training data -- TODO confirm).
    SPEC_MEAN = 0.322970
    SPEC_STD = 0.278452

    def __init__(
        self,
        num_params,
        sample_rate,
        encoder_model="mobilenet_v2",
        embed_dim=1028,
        width_mult=1,
        min_level_db=-80,
    ):
        """Encoder operating on spectrograms.

        Args:
            num_params (int): Number of processor parameters to generate.
                Stored on the module; not used by this class itself.
            sample_rate (float): Audio sample rate. Stored on the module; not
                used by the spectrogram computation in this class.
            encoder_model (str, optional): Encoder backbone, either
                "mobilenet_v2" or "efficient_net". Default: "mobilenet_v2"
            embed_dim (int, optional): Dimensionality of the encoder
                representations. Default: 1028
            width_mult (int, optional): Width multiplier passed to the
                MobileNetV2 backbone (unused for "efficient_net"). Default: 1
            min_level_db (float, optional): Minimal dB value for the
                spectrogram. Stored on the module; not used by the
                computation in this class. Default: -80

        Raises:
            ValueError: If ``encoder_model`` is not a supported backbone.
        """
        super().__init__()
        self.num_params = num_params
        self.sample_rate = sample_rate
        self.encoder_model = encoder_model
        self.embed_dim = embed_dim
        self.width_mult = width_mult
        self.min_level_db = min_level_db

        # Build the backbone locally from the project's model definitions.
        if encoder_model == "mobilenet_v2":
            self.encoder = MobileNetV2(embed_dim=embed_dim, width_mult=width_mult)
        elif encoder_model == "efficient_net":
            # Single-channel spectrogram input. include_top=False presumably
            # drops the classification head; the 1x1 projection below expects
            # 1408 feature channels from the backbone.
            self.encoder = EfficientNet.from_name(
                "efficientnet-b2",
                in_channels=1,
                image_size=(128, 65),
                include_top=False,
            )
            self.embedding_projection = torch.nn.Conv2d(
                in_channels=1408,
                out_channels=embed_dim,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                bias=True,
            )
        else:
            raise ValueError(f"Invalid encoder_model: {encoder_model}.")

        # NOTE(review): registered as a trainable Parameter, so the window
        # appears in parameters() and receives gradients. If a fixed window is
        # intended, register_buffer would be the usual choice; kept as a
        # Parameter for checkpoint/optimizer compatibility.
        self.window = torch.nn.Parameter(torch.hann_window(self.N_FFT))

    def _compressed_spectrogram(self, x):
        """Return the standardized, power-law compressed STFT magnitude.

        Args:
            x (Tensor): Waveforms of shape [batch x samples].

        Returns:
            Tensor: Shape [batch x 1 x frames x (N_FFT // 2 + 1)], i.e. a
                single-channel "image" with time along the height axis.
        """
        X = torch.stft(
            x,
            self.N_FFT,
            self.HOP_LENGTH,
            window=self.window,
            return_complex=True,
        )
        # Power-law compression of the magnitude (not decibels). The epsilon
        # keeps the gradient of pow(., 0.3) finite at zero magnitude.
        X_mag = torch.pow(X.abs() + 1e-8, 0.3)
        # Out-of-place standardization (same values as the former in-place ops).
        X_norm = (X_mag - self.SPEC_MEAN) / self.SPEC_STD
        # [bs, bins, frames] -> [bs, 1, frames, bins]
        return X_norm.unsqueeze(1).permute(0, 1, 3, 2)

    def forward(self, x):
        """Compute an L2-normalized embedding for a batch of waveforms.

        Args:
            x (Tensor): Input waveform of shape [batch x channels x samples].
                All channels of an item are flattened into one signal before
                the STFT, so multichannel input is concatenated end-to-end
                rather than mixed down.

        Returns:
            e (Tensor): Latent embedding produced by the encoder, of shape
                [batch x embed_dim], with unit L2 norm along the last dim.

        Raises:
            ValueError: If ``self.encoder_model`` is not a supported backbone.
        """
        bs, chs, samp = x.size()

        # reshape (unlike view) also accepts non-contiguous input.
        spec = self._compressed_spectrogram(x.reshape(bs, -1))

        if self.encoder_model == "mobilenet_v2":
            # Repeat the single channel 3x to fit the 3-channel vision model.
            e = self.encoder(spec.repeat(1, 3, 1, 1))
            # Global average pooling over the remaining 2-D grid -> [bs, C].
            e = torch.nn.functional.adaptive_avg_pool2d(e, 1).reshape(e.shape[0], -1)
        elif self.encoder_model == "efficient_net":
            # EfficientNet internally downsamples by 32 on the time and freq
            # axes, then average pools the rest.
            e = self.encoder(spec)
            # 1x1 conv projects the backbone channels to embed_dim.
            e = self.embedding_projection(e)
            e = e.squeeze(3).squeeze(2)
        else:
            raise ValueError(f"Invalid encoder_model: {self.encoder_model}.")

        # L2-normalize; the clamped denominator avoids NaN for an all-zero
        # embedding and is otherwise identical to e / ||e||.
        return torch.nn.functional.normalize(e, p=2, dim=-1)