import numpy as np
import torch.nn.functional as F
from torch import nn

from .model import MLPLayers


class LinearProbe(nn.Module):
    def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
        """
        Args:
            model: nn.Module, the pretrained CLAP model
            mlp: bool, if True, use an MLP layer as the linear-probe head instead of a single linear layer
            freeze: bool, if True, freeze all of the CLAP model's layers while training the linear probe
            in_ch: int, the embedding dimension output by the CLAP model
            out_ch: int, the output dimension of the linear probe (number of classes)
            act: str, name of the activation applied before the loss function
                ("relu", "elu", "prelu", "softmax", "sigmoid", or None)
        """
        super().__init__()
        in_ch = 512  # CLAP audio embeddings are 512-dimensional; overrides the passed-in value
        self.clap_model = model
        self.clap_model.text_branch = None  # drop the text branch to save memory
        self.freeze = freeze
        if mlp:
            self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
        else:
            self.lp_layer = nn.Linear(in_ch, out_ch)

        if self.freeze:
            for param in self.clap_model.parameters():
                param.requires_grad = False

        if act is None or act == "None":
            self.act = None
        elif act == "relu":
            self.act = nn.ReLU()
        elif act == "elu":
            self.act = nn.ELU()
        elif act == "prelu":
            self.act = nn.PReLU(num_parameters=in_ch)
        elif act == "softmax":
            self.act = nn.Softmax(dim=-1)
        elif act == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            raise ValueError(f"Unsupported activation: {act}")

    def forward(self, x, mix_lambda=None, device=None):
        """
        Args:
            x: waveform, torch.Tensor [batch, t_samples], or a batch of mel_spec and longer list
            mix_lambda: torch.Tensor [batch], the mixup lambda
        Returns:
            class_prob: torch.Tensor [batch, class_num]
        """
        # keep batchnorm in eval mode so frozen running statistics are not updated
        if self.freeze:
            self.clap_model.eval()

        x = self.clap_model.audio_projection(
            self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
                "embedding"
            ]
        )
        out = self.lp_layer(x)
        if self.act is not None:
            out = self.act(out)
        return out
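

# A minimal usage sketch (kept as a comment; not part of the module's API). It assumes
# a pretrained CLAP model instance `clap_model` whose audio branch yields 512-d
# embeddings, and a hypothetical 50-class downstream task. The names `clap_model`,
# `waveforms`, and the class count are illustrative assumptions; the exact input
# format accepted by forward() depends on the configured audio branch.
#
#     import torch
#
#     probe = LinearProbe(
#         model=clap_model,   # pretrained CLAP model (assumed to exist)
#         mlp=False,          # single nn.Linear head; set True for the MLP head
#         freeze=True,        # keep the CLAP backbone frozen during probing
#         in_ch=512,
#         out_ch=50,          # number of target classes (hypothetical)
#         act="softmax",
#     )
#     waveforms = torch.randn(8, 480000)   # e.g. 10-second clips at 48 kHz
#     class_prob = probe(waveforms)        # -> [8, 50]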