from torch import nn

from .model import MLPLayers


class LinearProbe(nn.Module):
    def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
        """
        Args:
            model: nn.Module, the pretrained CLAP model
            mlp: bool, if True, use an MLP head as the linear-probe module
                instead of a single linear layer
            freeze: bool, if True, freeze all of the CLAP model's layers while
                training the linear probe
            in_ch: int, the output channel dimension of the CLAP model
                (overridden to 512 below)
            out_ch: int, the output channel dimension of the linear probe (class_num)
            act: str, the activation applied before the loss function, one of
                "None", "relu", "elu", "prelu", "softmax", "sigmoid"
        """
        super().__init__()
        in_ch = 512  # hard-coded to CLAP's 512-dim embedding, overriding the in_ch argument
        self.clap_model = model
        self.clap_model.text_branch = None  # the text branch is unused for audio probing; drop it to save memory
        self.freeze = freeze
        if mlp:
            self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
        else:
            self.lp_layer = nn.Linear(in_ch, out_ch)

        if self.freeze:
            for param in self.clap_model.parameters():
                param.requires_grad = False

        if act is None or act == "None":
            self.act = None
        elif act == "relu":
            self.act = nn.ReLU()
        elif act == "elu":
            self.act = nn.ELU()
        elif act == "prelu":
            # one learnable slope per output channel of the probe
            self.act = nn.PReLU(num_parameters=out_ch)
        elif act == "softmax":
            self.act = nn.Softmax(dim=-1)
        elif act == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            raise ValueError(f"Unsupported activation: {act}")

    def forward(self, x, mix_lambda=None, device=None):
        """
        Args:
            x: torch.tensor [batch, t_samples], a batch of waveforms (or a
                batch of mel_specs and "longer" flags, depending on the audio
                branch's expected input)
            mix_lambda: torch.tensor [batch], the mixup lambda
        Returns:
            class_prob: torch.tensor [batch, class_num]
        """
        # keep the frozen backbone in eval mode so BatchNorm running stats
        # and Dropout are not perturbed while training the probe
        if self.freeze:
            self.clap_model.eval()

        # embed the audio with the CLAP audio branch, then map the embedding
        # into the shared embedding space with the audio projection head
        x = self.clap_model.audio_projection(
            self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
                "embedding"
            ]
        )
        out = self.lp_layer(x)
        if self.act is not None:
            out = self.act(out)
        return out
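

# Minimal usage sketch. `_DummyCLAP` below is a hypothetical stand-in that only
# mimics the interface LinearProbe relies on: an `audio_branch` returning a
# dict with an "embedding" key, plus `audio_projection` and `text_branch`
# attributes. A real run would pass a pretrained CLAP model instead; also note
# the relative import above means this module only runs inside its package.
if __name__ == "__main__":
    import torch

    class _DummyCLAP(nn.Module):
        def __init__(self):
            super().__init__()
            self.text_branch = nn.Identity()  # discarded by LinearProbe
            self.audio_projection = nn.Linear(512, 512)

        def audio_branch(self, x, mixup_lambda=None, device=None):
            # the real audio branch returns a dict of outputs
            return {"embedding": torch.randn(x.shape[0], 512)}

    probe = LinearProbe(
        _DummyCLAP(), mlp=False, freeze=True, in_ch=512, out_ch=50, act="softmax"
    )
    class_prob = probe(torch.randn(4, 480000))  # 4 waveforms of 10 s at 48 kHz
    print(class_prob.shape)  # -> torch.Size([4, 50])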