import torch
from torch import nn
import torch.nn.functional as F
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

BirdAST_FEATURE_EXTRACTOR = ASTFeatureExtractor()
DEFAULT_SR = 16_000
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_N_CLASSES = 728
DEFAULT_ACTIVATION = "silu"
DEFAULT_N_MLP_LAYERS = 1


def birdast_preprocess(audio_array, sr=DEFAULT_SR):
    """
    Preprocess an audio array into a spectrogram for the BirdAST model.

    audio_array: np.array, audio array of the recording, shape (n_samples,)
    sr: int, sampling rate of the audio array (default: 16_000); it is passed
        to the module-level ASTFeatureExtractor as ``sampling_rate``

    Returns:
        spectrogram: torch.Tensor, the extractor's ``input_values`` with the
            leading batch dimension removed, i.e. (n_frames, n_mels) for a
            single 1-D recording

    Note:
        1. The audio array should be normalized to [-1, 1].
        2. The audio length should be 10 seconds (or 10.24 seconds).
           Longer audio will be truncated.
    """
    # Extract features. With return_tensors="pt" the extractor already returns
    # a torch.Tensor, so it is used directly; wrapping it in torch.tensor()
    # would make a redundant copy and emit a UserWarning.
    features = BirdAST_FEATURE_EXTRACTOR(
        audio_array, sampling_rate=sr, padding="max_length", return_tensors="pt"
    )
    # Drop the batch dimension added by the extractor.
    spectrogram = features['input_values'].squeeze(0)
    return spectrogram


def birdast_inference(
    model_weights,
    spectrogram,
    device='cpu',
    backbone_name=DEFAULT_BACKBONE,
    n_classes=DEFAULT_N_CLASSES,
    activation=DEFAULT_ACTIVATION,
    n_mlp_layers=DEFAULT_N_MLP_LAYERS,
):
    """
    Perform inference with an ensemble of BirdAST checkpoints.

    model_weights: list, list of paths to saved model state dicts
        (a single path is also accepted)
    spectrogram: torch.Tensor, spectrogram tensor, shape (batch_size, n_frames, n_mels),
        or (n_frames, n_mels) for a single example (a batch dim is added)
    device: str, device to run inference (default: 'cpu')
    backbone_name: str, name of the backbone model (default: 'MIT/ast-finetuned-audioset-10-10-0.4593')
    n_classes: int, number of classes (default: 728)
    activation: str, activation function (default: 'silu')
    n_mlp_layers: int, number of MLP layers (default: 1)

    Returns:
        predictions: np.array, softmax probabilities of every checkpoint,
            shape (n_models, batch_size, n_classes)

    Raises:
        ValueError: if model_weights is empty.
    """
    # Accept a single checkpoint path instead of silently iterating its characters.
    if isinstance(model_weights, str) or hasattr(model_weights, '__fspath__'):
        model_weights = [model_weights]
    model_weights = list(model_weights)
    if not model_weights:
        # Fail before building (and possibly downloading) the backbone.
        raise ValueError("model_weights must contain at least one checkpoint path.")

    model = BirdAST(
        backbone_name=backbone_name,
        n_classes=n_classes,
        n_mlp_layers=n_mlp_layers,
        activation=activation,
    )
    # load_state_dict copies values into the existing parameters, so the
    # device placement and eval mode only need to be set once.
    model.to(device)
    model.eval()

    # The input is the same for every checkpoint: prepare it once.
    spectrogram = spectrogram.to(device)
    if spectrogram.dim() == 2:
        spectrogram = spectrogram.unsqueeze(0)  # -> (batch_size, n_frames, n_mels)

    predict_collects = []
    with torch.no_grad():
        for _weights in model_weights:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            model.load_state_dict(torch.load(_weights, map_location=device))
            logits = model(spectrogram)['logits']
            probs = F.softmax(logits, dim=-1)
            predict_collects.append(probs.cpu())

    # stack (not cat) keeps one leading axis per model, matching the
    # documented (n_models, batch_size, n_classes) shape.
    return torch.stack(predict_collects, dim=0).numpy()


class BirdAST(nn.Module):
    """
    Audio Spectrogram Transformer backbone with an MLP classification head.

    The head is applied to the backbone's first output token
    (``last_hidden_state[:, 0, :]``) and produces ``n_classes`` logits.
    """

    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        """
        backbone_name: str, Hugging Face name/path of the pretrained AST backbone
        n_classes: int, number of output classes
        n_mlp_layers: int, number of hidden (Linear + activation) layers
            placed before the final Linear layer
        activation: str, 'relu' or 'silu'

        Raises:
            ValueError: if activation is not 'relu' or 'silu'.
        """
        super().__init__()

        # Validate the activation before loading the pretrained backbone,
        # so a typo fails fast instead of after a download.
        if activation == 'relu':
            activation_cls = nn.ReLU
        elif activation == 'silu':
            activation_cls = nn.SiLU
        else:
            raise ValueError("Unsupported activation function. Choose 'relu' or 'silu'.")

        # pre-trained backbone
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        self.activation = activation_cls()

        # define MLP layers with activation. The layer order (and hence the
        # Sequential indices used as state_dict keys, e.g. "mlp.0.weight")
        # is kept unchanged so existing checkpoints still load.
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram):
        """
        spectrogram: torch.Tensor, shape (batch_size, n_frames, n_mels)

        Returns:
            dict with key 'logits': torch.Tensor, shape (batch_size, n_classes)
        """
        ast_output = self.ast(spectrogram, output_hidden_states=False)
        cls_token = ast_output.last_hidden_state[:, 0, :]  # Use the CLS token
        logits = self.mlp(cls_token)
        return {'logits': logits}


if __name__ == '__main__':
    import numpy as np
    import matplotlib.pyplot as plt

    # example usage of BirdAST
    # create a random 10-second audio array in [-1, 1] at the default sampling rate
    audio_array = np.random.uniform(-1.0, 1.0, DEFAULT_SR * 10)

    # Preprocess audio array
    spectrogram = birdast_preprocess(audio_array)

    model_weights_dir = '/workspace/voice_of_jungle/training_logs'

    # Load model weights
    model_weights = [f'{model_weights_dir}/BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(5)]

    # Perform inference -> shape (n_models, batch_size, n_classes)
    predictions = birdast_inference(model_weights, spectrogram.unsqueeze(0))

    # Plot predictions of the first (only) batch element for each model
    fig, ax = plt.subplots()
    for i, pred in enumerate(predictions):
        ax.plot(pred[0], label=f'model_{i}')
    ax.legend()
    fig.savefig('test_BirdAST_Seq.png')
    plt.close(fig)

    print("Inference completed successfully!")