import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel


class BirdAST(nn.Module):
    """Classifier made of a pre-trained AST backbone followed by an MLP head.

    The head is ``n_mlp_layers`` repetitions of ``Linear(hidden, hidden)``
    plus the chosen activation, followed by a final
    ``Linear(hidden, n_classes)`` that produces the logits.
    """

    # Supported activation names mapped to their module classes.
    _ACTIVATIONS = {'relu': nn.ReLU, 'silu': nn.SiLU}

    def __init__(
        self,
        backbone_name: str,
        n_classes: int,
        n_mlp_layers: int = 1,
        activation: str = 'silu',
    ) -> None:
        """Load the backbone and build the classification head.

        Args:
            backbone_name: Identifier passed to ``ASTConfig.from_pretrained``
                and ``ASTModel.from_pretrained``.
            n_classes: Output size of the final linear layer.
            n_mlp_layers: Number of hidden ``Linear`` + activation pairs
                placed before the final linear layer.
            activation: Either ``'relu'`` or ``'silu'``.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()

        # Validate the cheap argument first so that a bad value fails before
        # the pre-trained backbone is loaded.
        activation_cls = (
            self._ACTIVATIONS.get(activation)
            if isinstance(activation, str)
            else None
        )
        if activation_cls is None:
            raise ValueError("Unsupported activation function. Choose 'relu' or 'silu'.")

        # Pre-trained backbone.
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # One activation module, reused at every position of the head.
        self.activation = activation_cls()

        # MLP head: hidden layers with activation, then the output projection.
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram: torch.Tensor) -> dict:
        """Run the backbone and the head on a batch of spectrograms.

        Args:
            spectrogram: Batch of input features forwarded unchanged to the
                AST backbone.
                NOTE(review): the original comment gave the shape as
                (batch_size, n_mels, n_frames), whereas the Hugging Face
                documentation describes ``ASTModel`` ``input_values`` as
                (batch_size, max_length, num_mel_bins) -- confirm the axis
                order against the feature extractor / caller.

        Returns:
            A dict with a single key ``'logits'`` holding the head output,
            expected to be of shape (batch_size, n_classes).
        """
        ast_output = self.ast(spectrogram, output_hidden_states=False)
        # Position 0 of the last hidden state is used as the sequence summary
        # (the CLS token).
        cls_embedding = ast_output.last_hidden_state[:, 0, :]
        logits = self.mlp(cls_embedding)
        return {'logits': logits}