# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
# Some layers are re-designed for CLAP
import os

# NOTE(review): presumably this has to be set before torchlibrosa (and whatever
# it pulls in) is imported -- keep it above the imports below.
os.environ["NUMBA_CACHE_DIR"] = "/tmp/"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation

from .utils import do_mixup, interpolate
from .feature_fusion import iAFF, AFF, DAF

# Fusion modes handled by Cnn14, grouped by where the fusion is applied.
_FUSION_TYPES_1D = ("daf_1d", "aff_1d", "iaff_1d")
_FUSION_TYPES_2D = ("daf_2d", "aff_2d", "iaff_2d")


def init_layer(layer):
    """Initialize a Linear or Convolutional layer.

    The weight gets Xavier-uniform initialization; the bias, when the layer
    has one, is zeroed.
    """
    nn.init.xavier_uniform_(layer.weight)
    if getattr(layer, "bias", None) is not None:
        layer.bias.data.fill_(0.0)


def init_bn(bn):
    """Initialize a Batchnorm layer (bias to 0, weight to 1)."""
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)


def _pool2d(x, pool_size, pool_type):
    """Apply the 2D pooling selected by ``pool_type`` to ``x``.

    Args:
        x: Tensor accepted by ``F.avg_pool2d`` / ``F.max_pool2d``.
        pool_size: Kernel size forwarded to the pooling function.
        pool_type: ``"max"``, ``"avg"`` or ``"avg+max"`` (sum of both).

    Raises:
        ValueError: If ``pool_type`` is not one of the supported values.
    """
    if pool_type == "max":
        return F.max_pool2d(x, kernel_size=pool_size)
    if pool_type == "avg":
        return F.avg_pool2d(x, kernel_size=pool_size)
    if pool_type == "avg+max":
        avg_pooled = F.avg_pool2d(x, kernel_size=pool_size)
        max_pooled = F.max_pool2d(x, kernel_size=pool_size)
        return avg_pooled + max_pooled
    raise ValueError("Incorrect argument!")


def _waveform_from_input(input, device):
    """Return the waveform tensor for the non-fusion CNNs.

    A dict is treated the way ``Cnn14.forward`` treats it: the ``"waveform"``
    entry is moved to ``device``. Anything else is returned untouched, which
    keeps the historical "pass a tensor directly" call style working.
    """
    if isinstance(input, dict):
        return input["waveform"].to(device=device, non_blocking=True)
    return input


def _clip_and_frame_outputs(model, x, interpolate_ratio):
    """Turn the last conv feature map into the model's output dictionary.

    This is the head shared by Cnn6, Cnn10 and Cnn14. It uses ``model.fc1``,
    ``model.fc_audioset`` and ``model.training``.

    Args:
        model: The CNN module owning ``fc1`` and ``fc_audioset``.
        x: 4D output of the last conv block; dim 3 is averaged away first.
        interpolate_ratio: Second argument handed to ``interpolate`` for the
            frame-level embedding.

    Returns:
        dict with keys ``"clipwise_output"``, ``"embedding"`` and
        ``"fine_grained_embedding"``.
    """
    x = torch.mean(x, dim=3)

    # Frame-level branch: smooth along the remaining last axis, then fc1.
    latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
    latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
    latent_x = latent_x1 + latent_x2
    latent_x = latent_x.transpose(1, 2)
    latent_x = F.relu_(model.fc1(latent_x))
    latent_output = interpolate(latent_x, interpolate_ratio)

    # Clip-level branch: max + mean over dim 2, then the same fc1.
    (x1, _) = torch.max(x, dim=2)
    x2 = torch.mean(x, dim=2)
    x = x1 + x2
    x = F.dropout(x, p=0.5, training=model.training)
    x = F.relu_(model.fc1(x))
    embedding = F.dropout(x, p=0.5, training=model.training)
    # The classifier reads the pre-dropout activation, not ``embedding``.
    clipwise_output = torch.sigmoid(model.fc_audioset(x))

    return {
        "clipwise_output": clipwise_output,
        "embedding": embedding,
        "fine_grained_embedding": latent_output,
    }


class ConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm + ReLU layers followed by 2D pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialize both convolutions and both batch-norm layers."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Run the block.

        Args:
            input: Tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.
            pool_size: Pooling kernel size.
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"``.

        Raises:
            ValueError: If ``pool_type`` is not supported.
        """
        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))
        return _pool2d(x, pool_size, pool_type)


class ConvBlock5x5(nn.Module):
    """One 5x5 conv + batch-norm + ReLU layer followed by 2D pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialize the convolution and its batch-norm layer."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Run the block.

        Args:
            input: Tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.
            pool_size: Pooling kernel size.
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"``.

        Raises:
            ValueError: If ``pool_type`` is not supported.
        """
        x = F.relu_(self.bn1(self.conv1(input)))
        return _pool2d(x, pool_size, pool_type)


class AttBlock(nn.Module):
    """Attention pooling over the last axis using two 1x1 ``Conv1d`` heads.

    ``temperature`` is stored and ``bn_att`` is created and initialized, but
    neither is used by ``forward`` in this file; both are kept so the
    constructor signature and the state_dict layout stay unchanged.
    """

    def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
        super().__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.cla = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        self.bn_att = nn.BatchNorm1d(n_out)
        self.init_weights()

    def init_weights(self):
        """Initialize both conv heads and the batch-norm layer."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Pool ``x`` with learned attention weights.

        Args:
            x: (n_samples, n_in, n_time)

        Returns:
            Tuple ``(pooled, norm_att, cla)``: the attention-weighted sum over
            dim 2, the softmax attention weights, and the per-step
            (optionally sigmoid-activated) classification map.
        """
        # Clamp the logits to [-10, 10] before the softmax over the last axis.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation (``"linear"`` or ``"sigmoid"``).

        Raises:
            ValueError: If ``self.activation`` is neither supported value.
                Previously this fell through and returned ``None``, which only
                failed later inside ``forward``.
        """
        if self.activation == "linear":
            return x
        if self.activation == "sigmoid":
            return torch.sigmoid(x)
        raise ValueError(f"Unsupported activation: {self.activation!r}")


class Cnn14(nn.Module):
    """Six-``ConvBlock`` CNN with optional feature fusion.

    Args:
        sample_rate: Forwarded to ``LogmelFilterBank`` as ``sr``.
        window_size: FFT size and window length of the spectrogram.
        hop_size: Hop length of the spectrogram.
        mel_bins: Number of mel bins. ``bn0`` is fixed at 64 features and is
            applied to that axis, so this must be 64.
        fmin: Lower mel filterbank frequency.
        fmax: Upper mel filterbank frequency.
        classes_num: Output size of ``fc_audioset``.
        enable_fusion: When True, ``forward`` reads ``input["mel_fusion"]``
            and ``input["longer"]`` instead of ``input["waveform"]``.
        fusion_type: One of ``"daf_1d"``, ``"aff_1d"``, ``"iaff_1d"``,
            ``"daf_2d"``, ``"aff_2d"``, ``"iaff_2d"`` or ``"channel_map"``;
            only consulted when ``enable_fusion`` is True.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super().__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        self.bn0 = nn.BatchNorm2d(64)

        # "channel_map" feeds all four fused mel channels to the first block.
        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
            self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        else:
            self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

        if (self.enable_fusion) and (self.fusion_type in _FUSION_TYPES_1D):
            self.mel_conv1d = nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
                nn.BatchNorm1d(64),  # No Relu
            )
            if self.fusion_type == "daf_1d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_1d":
                self.fusion_model = AFF(channels=64, type="1D")
            elif self.fusion_type == "iaff_1d":
                self.fusion_model = iAFF(channels=64, type="1D")

        if (self.enable_fusion) and (self.fusion_type in _FUSION_TYPES_2D):
            self.mel_conv2d = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )

            if self.fusion_type == "daf_2d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_2d":
                self.fusion_model = AFF(channels=64, type="2D")
            elif self.fusion_type == "iaff_2d":
                self.fusion_model = iAFF(channels=64, type="2D")
        self.init_weight()

    def init_weight(self):
        """Initialize ``bn0`` and the two fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Compute clip-level and frame-level outputs.

        Args:
            input: Dict. Without fusion, ``input["waveform"]`` is read
                (batch_size, data_length). With fusion, ``input["mel_fusion"]``
                and the boolean mask ``input["longer"]`` are read.
                NOTE: with fusion enabled and no entry of ``input["longer"]``
                set, one random entry of the caller's tensor is set to True
                in place.
            mixup_lambda: When not None and the module is training, passed to
                ``do_mixup``.
            device: Target device for the tensors read from ``input``; None
                leaves them where they are.

        Returns:
            dict with ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """
        if self.enable_fusion and input["longer"].sum() == 0:
            # if no audio is longer than 10s, then randomly select one audio to be longer
            input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True

        if not self.enable_fusion:
            x = self.spectrogram_extractor(
                input["waveform"].to(device=device, non_blocking=True)
            )  # (batch_size, 1, time_steps, freq_bins)
            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

            # Batch-norm over the mel axis: swap it into the channel position.
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
        else:
            longer_list = input["longer"].to(device=device, non_blocking=True)
            x = input["mel_fusion"].to(device=device, non_blocking=True)
            longer_list_idx = torch.where(longer_list)[0]
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            if self.fusion_type in _FUSION_TYPES_1D:
                # Channel 0 is the global view; the rest are local views.
                new_x = x[:, 0:1, :, :].clone().contiguous()
                # local processing
                if len(longer_list_idx) > 0:
                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
                    FB, FC, FT, FF = fusion_x_local.size()
                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
                    fusion_x_local = torch.permute(
                        fusion_x_local, (0, 2, 1)
                    ).contiguous()
                    fusion_x_local = self.mel_conv1d(fusion_x_local)
                    fusion_x_local = fusion_x_local.view(
                        FB, FC, FF, fusion_x_local.size(-1)
                    )
                    fusion_x_local = (
                        torch.permute(fusion_x_local, (0, 2, 1, 3))
                        .contiguous()
                        .flatten(2)
                    )
                    # Pad with zeros or crop so the last axis has length FT.
                    if fusion_x_local.size(-1) < FT:
                        # new_zeros keeps the padding on the same device (and
                        # dtype) as the features; the old ``device=device``
                        # fell back to the CPU whenever ``device`` was None.
                        padding = fusion_x_local.new_zeros(
                            (FB, FF, FT - fusion_x_local.size(-1))
                        )
                        fusion_x_local = torch.cat([fusion_x_local, padding], dim=-1)
                    else:
                        fusion_x_local = fusion_x_local[:, :, :FT]
                    # 1D fusion
                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
                    new_x[longer_list_idx] = self.fusion_model(
                        new_x[longer_list_idx], fusion_x_local
                    )
                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
                else:
                    x = new_x
            # The 2D fusion types and "channel_map" leave x unchanged here.

        if self.training:
            x = self.spec_augmenter(x)
        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        if (self.enable_fusion) and (self.fusion_type in _FUSION_TYPES_2D):
            global_x = x[:, 0:1, :, :]

            # global processing
            global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
            if len(longer_list_idx) > 0:
                local_x = x[longer_list_idx, 1:, :, :].contiguous()
                TH = global_x.size(-2)
                # local processing
                B, C, H, W = local_x.shape
                local_x = local_x.view(B * C, 1, H, W)
                local_x = self.mel_conv2d(local_x)
                local_x = local_x.view(
                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
                )
                local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
                TB, TC, _, TW = local_x.size()
                # Pad with zeros or crop so dim -2 matches the global branch.
                if local_x.size(-2) < TH:
                    local_x = torch.cat(
                        [
                            local_x,
                            torch.zeros(
                                (TB, TC, TH - local_x.size(-2), TW),
                                device=global_x.device,
                            ),
                        ],
                        dim=-2,
                    )
                else:
                    local_x = local_x[:, :, :TH, :]

                global_x[longer_list_idx] = self.fusion_model(
                    global_x[longer_list_idx], local_x
                )
            x = global_x
        else:
            x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")

        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)

        return _clip_and_frame_outputs(self, x, 32)


class Cnn6(nn.Module):
    """Four-``ConvBlock5x5`` CNN.

    Takes the same constructor arguments as ``Cnn14``. ``enable_fusion`` and
    ``fusion_type`` are stored but ``forward`` does not consult them.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super().__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_audioset = nn.Linear(512, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialize ``bn0`` and the two fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Compute clip-level and frame-level outputs.

        Args:
            input: Waveform tensor (batch_size, data_length), or a dict whose
                ``"waveform"`` entry holds it (the form ``Cnn14`` takes).
            mixup_lambda: When not None and the module is training, passed to
                ``do_mixup``.
            device: Only used for dict input, to move the waveform.

        Returns:
            dict with ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """
        waveform = _waveform_from_input(input, device)

        x = self.spectrogram_extractor(
            waveform
        )  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        # Batch-norm over the mel axis: swap it into the channel position.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)

        return _clip_and_frame_outputs(self, x, 16)


class Cnn10(nn.Module):
    """Five-``ConvBlock`` CNN.

    Takes the same constructor arguments as ``Cnn14``. ``enable_fusion`` and
    ``fusion_type`` are stored but ``forward`` does not consult them.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super().__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)

        self.fc1 = nn.Linear(1024, 1024, bias=True)
        self.fc_audioset = nn.Linear(1024, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialize ``bn0`` and the two fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Compute clip-level and frame-level outputs.

        Args:
            input: Waveform tensor (batch_size, data_length), or a dict whose
                ``"waveform"`` entry holds it (the form ``Cnn14`` takes).
            mixup_lambda: When not None and the module is training, passed to
                ``do_mixup``.
            device: Only used for dict input, to move the waveform.

        Returns:
            dict with ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """
        waveform = _waveform_from_input(input, device)

        x = self.spectrogram_extractor(
            waveform
        )  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        # Batch-norm over the mel axis: swap it into the channel position.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)

        return _clip_and_frame_outputs(self, x, 32)


# Explicit name -> class table so a config string is never handed to eval().
_PANN_MODELS = {
    "Cnn6": Cnn6,
    "Cnn10": Cnn10,
    "Cnn14": Cnn14,
}


def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
    """Build the PANN model named by ``audio_cfg.model_name``.

    Args:
        audio_cfg: Object exposing ``model_name``, ``sample_rate``,
            ``window_size``, ``hop_size``, ``mel_bins``, ``fmin``, ``fmax``
            and ``class_num``.
        enable_fusion: Forwarded to the model constructor.
        fusion_type: Forwarded to the model constructor.

    Returns:
        The constructed ``Cnn6``, ``Cnn10`` or ``Cnn14`` instance.

    Raises:
        RuntimeError: If the model name is unknown or construction fails; the
            original exception is attached as ``__cause__``.
    """
    try:
        model_cls = _PANN_MODELS[audio_cfg.model_name]
        return model_cls(
            sample_rate=audio_cfg.sample_rate,
            window_size=audio_cfg.window_size,
            hop_size=audio_cfg.hop_size,
            mel_bins=audio_cfg.mel_bins,
            fmin=audio_cfg.fmin,
            fmax=audio_cfg.fmax,
            classes_num=audio_cfg.class_num,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )
    except Exception as err:
        raise RuntimeError(
            f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
        ) from err