# jiuku's picture
# Duplicate from haoheliu/audioldm-text-to-audio-generation
# 4039be3
# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
# Some layers are re-designed for CLAP
import os
# Set numba's cache directory to /tmp/ before the remaining imports run.
# NOTE(review): presumably needed because torchlibrosa pulls in librosa/numba,
# which caches compiled code on import and may lack a writable default
# location in hosted environments — confirm before moving this line.
os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from .utils import do_mixup, interpolate, pad_framewise_output
from .feature_fusion import iAFF, AFF, DAF
def init_layer(layer):
    """Xavier-initialise the weight of a Linear/Convolutional layer in place.

    The bias, when the layer has one, is reset to zero; layers built with
    ``bias=False`` (or without a ``bias`` attribute at all) are left alone.
    """
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_bn(bn):
    """Reset a batch-norm layer in place: shift (bias) to 0, scale (weight) to 1."""
    for param, value in ((bn.bias, 0.0), (bn.weight, 1.0)):
        param.data.fill_(value)
class ConvBlock(nn.Module):
    """Two 3x3 convolution + batch-norm + ReLU stages followed by 2-D pooling.

    Both convolutions use stride 1 and padding 1, so the spatial size is only
    changed by the pooling step selected in :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Apply the module-level initialisers to both convs and both batch norms."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Run both conv stages and pool the result.

        Args:
            input: 4-D tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.
            pool_size: kernel size handed to the pooling function(s).
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"`` (sum of both).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: if ``pool_type`` is not one of the supported names.
        """
        # Fail fast: reject a bad pool_type before spending compute on the
        # convolutions (and before batch-norm statistics are touched).
        if pool_type not in ("max", "avg", "avg+max"):
            raise ValueError(
                f"Incorrect argument! Unsupported pool_type: {pool_type!r}"
            )

        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))

        if pool_type == "max":
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg":
            x = F.avg_pool2d(x, kernel_size=pool_size)
        else:  # "avg+max"
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2

        return x
class ConvBlock5x5(nn.Module):
    """Single 5x5 convolution + batch-norm + ReLU stage followed by 2-D pooling.

    The convolution uses stride 1 and padding 2, so the spatial size is only
    changed by the pooling step selected in :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock5x5, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Apply the module-level initialisers to the conv and the batch norm."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Run the conv stage and pool the result.

        Args:
            input: 4-D tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.
            pool_size: kernel size handed to the pooling function(s).
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"`` (sum of both).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: if ``pool_type`` is not one of the supported names.
        """
        # Fail fast: reject a bad pool_type before spending compute on the
        # convolution (and before batch-norm statistics are touched).
        if pool_type not in ("max", "avg", "avg+max"):
            raise ValueError(
                f"Incorrect argument! Unsupported pool_type: {pool_type!r}"
            )

        x = input
        x = F.relu_(self.bn1(self.conv1(x)))

        if pool_type == "max":
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg":
            x = F.avg_pool2d(x, kernel_size=pool_size)
        else:  # "avg+max"
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2

        return x
class AttBlock(nn.Module):
    """Attention pooling over the last axis using two 1x1 ``Conv1d`` heads.

    ``att`` produces attention logits and ``cla`` produces per-step outputs;
    the block returns the attention-weighted sum of the latter.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        activation: ``"linear"`` (identity) or ``"sigmoid"``, applied to the
            ``cla`` head.
        temperature: stored on the module; not read anywhere in this class.
    """

    def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
        super(AttBlock, self).__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.cla = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Not used in forward(); kept so the module's parameters/buffers (and
        # therefore its state_dict keys) stay unchanged.
        self.bn_att = nn.BatchNorm1d(n_out)
        self.init_weights()

    def init_weights(self):
        """Apply the module-level initialisers to both heads and the batch norm."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Pool ``x`` over its last axis with learned attention weights.

        Args:
            x: tensor of shape (n_samples, n_in, n_time).

        Returns:
            Tuple ``(pooled, norm_att, cla)``: the weighted sum over the last
            axis, the softmax-normalised attention weights, and the
            (activated) ``cla`` head output.

        Raises:
            ValueError: if ``self.activation`` is not a supported name.
        """
        # Logits are clamped to [-10, 10] before the softmax over the last axis.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to ``x``.

        Raises:
            ValueError: for an activation other than "linear" or "sigmoid".
                Previously this case fell through and returned ``None``, which
                only surfaced later as a confusing ``TypeError`` in forward().
        """
        if self.activation == "linear":
            return x
        if self.activation == "sigmoid":
            return torch.sigmoid(x)
        raise ValueError(
            f"Unsupported activation {self.activation!r}; "
            "expected 'linear' or 'sigmoid'."
        )
class Cnn14(nn.Module):
    """PANN encoder: six ``ConvBlock`` stages followed by two linear layers.

    When ``enable_fusion`` is set, extra layers are created according to
    ``fusion_type`` ("daf_1d"/"aff_1d"/"iaff_1d", "daf_2d"/"aff_2d"/"iaff_2d"
    or "channel_map") and ``forward`` consumes a pre-computed ``mel_fusion``
    tensor instead of a raw waveform.

    Args:
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax: forwarded to
            the torchlibrosa ``Spectrogram`` / ``LogmelFilterBank`` extractors.
        classes_num: output size of the ``fc_audioset`` classifier.
        enable_fusion: turn the feature-fusion paths on.
        fusion_type: which fusion variant to build/use (see above).
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super(Cnn14, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # Applied in forward() after transpose(1, 3), i.e. over the last input
        # axis, which therefore has to be of size 64.
        self.bn0 = nn.BatchNorm2d(64)

        # "channel_map" fusion feeds all 4 mel channels straight into the
        # first conv block; every other mode feeds a single channel.
        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
            self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        else:
            self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
        ):
            self.mel_conv1d = nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
                nn.BatchNorm1d(64),  # No Relu
            )
            if self.fusion_type == "daf_1d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_1d":
                self.fusion_model = AFF(channels=64, type="1D")
            elif self.fusion_type == "iaff_1d":
                self.fusion_model = iAFF(channels=64, type="1D")

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            self.mel_conv2d = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )

            if self.fusion_type == "daf_2d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_2d":
                self.fusion_model = AFF(channels=64, type="2D")
            elif self.fusion_type == "iaff_2d":
                self.fusion_model = iAFF(channels=64, type="2D")
        self.init_weight()

    def init_weight(self):
        """Initialise ``bn0`` and the two linear layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch.

        Args:
            input: mapping describing the batch. Without fusion only
                ``input["waveform"]`` is read. With fusion ``input["longer"]``
                (boolean mask, one entry per item) and ``input["mel_fusion"]``
                are read instead; ``input["longer"]`` may be modified in place
                (see below).
            mixup_lambda: if given and the module is training, passed to
                ``do_mixup`` together with the spectrogram.
            device: target device for the tensors taken from ``input``.

        Returns:
            dict with keys ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """
        if self.enable_fusion and input["longer"].sum() == 0:
            # if no audio is longer than 10s, then randomly select one audio to be longer
            # (this writes into the caller's tensor).
            input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True

        if not self.enable_fusion:
            x = self.spectrogram_extractor(
                input["waveform"].to(device=device, non_blocking=True)
            )  # (batch_size, 1, time_steps, freq_bins)
            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
        else:
            longer_list = input["longer"].to(device=device, non_blocking=True)
            x = input["mel_fusion"].to(device=device, non_blocking=True)
            longer_list_idx = torch.where(longer_list)[0]
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
                # Channel 0 is treated as the global view, channels 1: as the
                # local views of the items flagged in ``longer``.
                new_x = x[:, 0:1, :, :].clone().contiguous()
                # local processing
                if len(longer_list_idx) > 0:
                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
                    FB, FC, FT, FF = fusion_x_local.size()
                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
                    fusion_x_local = torch.permute(
                        fusion_x_local, (0, 2, 1)
                    ).contiguous()
                    fusion_x_local = self.mel_conv1d(fusion_x_local)
                    fusion_x_local = fusion_x_local.view(
                        FB, FC, FF, fusion_x_local.size(-1)
                    )
                    fusion_x_local = (
                        torch.permute(fusion_x_local, (0, 2, 1, 3))
                        .contiguous()
                        .flatten(2)
                    )
                    # Pad with zeros or crop so the last axis has length FT.
                    if fusion_x_local.size(-1) < FT:
                        fusion_x_local = torch.cat(
                            [
                                fusion_x_local,
                                torch.zeros(
                                    (FB, FF, FT - fusion_x_local.size(-1)),
                                    # Follow the tensor being padded: the
                                    # ``device`` argument may be None while the
                                    # data already lives on an accelerator,
                                    # which would make torch.cat fail.
                                    device=fusion_x_local.device,
                                ),
                            ],
                            dim=-1,
                        )
                    else:
                        fusion_x_local = fusion_x_local[:, :, :FT]
                    # 1D fusion
                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
                    new_x[longer_list_idx] = self.fusion_model(
                        new_x[longer_list_idx], fusion_x_local
                    )
                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
                else:
                    x = new_x
            elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
                x = x  # no change

        if self.training:
            x = self.spec_augmenter(x)
        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            global_x = x[:, 0:1, :, :]

            # global processing
            global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
            if len(longer_list_idx) > 0:
                local_x = x[longer_list_idx, 1:, :, :].contiguous()
                TH = global_x.size(-2)
                # local processing
                B, C, H, W = local_x.shape
                local_x = local_x.view(B * C, 1, H, W)
                local_x = self.mel_conv2d(local_x)
                local_x = local_x.view(
                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
                )
                local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
                TB, TC, _, TW = local_x.size()
                # Pad with zeros or crop so axis -2 matches the global branch.
                if local_x.size(-2) < TH:
                    local_x = torch.cat(
                        [
                            local_x,
                            torch.zeros(
                                (TB, TC, TH - local_x.size(-2), TW),
                                device=global_x.device,
                            ),
                        ],
                        dim=-2,
                    )
                else:
                    local_x = local_x[:, :, :TH, :]

                global_x[longer_list_idx] = self.fusion_model(
                    global_x[longer_list_idx], local_x
                )
            x = global_x
        else:
            x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")

        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)

        # Per-step ("fine grained") branch: smooth along the remaining axis,
        # project with fc1, then hand to ``interpolate``.
        # NOTE(review): the factor 32 presumably undoes the five 2x poolings
        # above — confirm against ``utils.interpolate``.
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        latent_output = interpolate(latent_x, 32)

        # Clip-level branch: max + mean over axis 2, then fc1.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads ``x`` (before the dropout that yields ``embedding``).
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }
        return output_dict
class Cnn6(nn.Module):
    """PANN encoder: four ``ConvBlock5x5`` stages followed by two linear layers.

    Args:
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax: forwarded to
            the torchlibrosa ``Spectrogram`` / ``LogmelFilterBank`` extractors.
        classes_num: output size of the ``fc_audioset`` classifier.
        enable_fusion, fusion_type: stored on the module; no fusion layers are
            built or used by this class.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super(Cnn6, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # Applied in forward() after transpose(1, 3), i.e. over the last axis
        # of the log-mel output, which therefore has to be of size 64.
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_audioset = nn.Linear(512, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialise ``bn0`` and the two linear layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch of waveforms.

        Args:
            input: either a waveform tensor of shape (batch_size, data_length),
                or a dict holding that tensor under ``"waveform"`` (the batch
                format ``Cnn14.forward`` consumes).
            mixup_lambda: if given and the module is training, passed to
                ``do_mixup`` together with the spectrogram.
            device: only used for dict input, where the waveform is moved to
                it; a plain tensor is used as passed.

        Returns:
            dict with keys ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """
        # All Cnn* classes are built interchangeably by create_pann_model and
        # share this signature, so accept Cnn14's dict batch as well.
        if isinstance(input, dict):
            waveform = input["waveform"].to(device=device, non_blocking=True)
        else:
            waveform = input

        x = self.spectrogram_extractor(waveform)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)

        # Per-step ("fine grained") branch.
        # NOTE(review): the factor 16 presumably undoes the four 2x poolings
        # above — confirm against ``utils.interpolate``.
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        latent_output = interpolate(latent_x, 16)

        # Clip-level branch: max + mean over axis 2, then fc1.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads ``x`` (before the dropout that yields ``embedding``).
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }

        return output_dict
class Cnn10(nn.Module):
    """PANN encoder: five ``ConvBlock`` stages followed by two linear layers.

    Args:
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax: forwarded to
            the torchlibrosa ``Spectrogram`` / ``LogmelFilterBank`` extractors.
        classes_num: output size of the ``fc_audioset`` classifier.
        enable_fusion, fusion_type: stored on the module; no fusion layers are
            built or used by this class.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super(Cnn10, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # Applied in forward() after transpose(1, 3), i.e. over the last axis
        # of the log-mel output, which therefore has to be of size 64.
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)

        self.fc1 = nn.Linear(1024, 1024, bias=True)
        self.fc_audioset = nn.Linear(1024, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialise ``bn0`` and the two linear layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch of waveforms.

        Args:
            input: either a waveform tensor of shape (batch_size, data_length),
                or a dict holding that tensor under ``"waveform"`` (the batch
                format ``Cnn14.forward`` consumes).
            mixup_lambda: if given and the module is training, passed to
                ``do_mixup`` together with the spectrogram.
            device: only used for dict input, where the waveform is moved to
                it; a plain tensor is used as passed.

        Returns:
            dict with keys ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """
        # All Cnn* classes are built interchangeably by create_pann_model and
        # share this signature, so accept Cnn14's dict batch as well.
        if isinstance(input, dict):
            waveform = input["waveform"].to(device=device, non_blocking=True)
        else:
            waveform = input

        x = self.spectrogram_extractor(waveform)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)

        # Per-step ("fine grained") branch.
        # NOTE(review): the factor 32 presumably undoes the five 2x poolings
        # above — confirm against ``utils.interpolate``.
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        latent_output = interpolate(latent_x, 32)

        # Clip-level branch: max + mean over axis 2, then fc1.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads ``x`` (before the dropout that yields ``embedding``).
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }

        return output_dict
def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
    """Instantiate the model class named by ``audio_cfg.model_name``.

    Args:
        audio_cfg: object exposing ``model_name``, ``sample_rate``,
            ``window_size``, ``hop_size``, ``mel_bins``, ``fmin``, ``fmax``
            and ``class_num`` attributes.
        enable_fusion: forwarded to the model constructor.
        fusion_type: forwarded to the model constructor.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: if the name cannot be resolved or construction fails;
            the underlying error is attached as ``__cause__``.
    """
    try:
        # NOTE(review): ``eval`` executes arbitrary expressions. This is only
        # safe while ``model_name`` comes from trusted configuration; consider
        # an explicit name -> class registry if configs can be user-supplied.
        ModelProto = eval(audio_cfg.model_name)
        model = ModelProto(
            sample_rate=audio_cfg.sample_rate,
            window_size=audio_cfg.window_size,
            hop_size=audio_cfg.hop_size,
            mel_bins=audio_cfg.mel_bins,
            fmin=audio_cfg.fmin,
            fmax=audio_cfg.fmax,
            classes_num=audio_cfg.class_num,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )
        return model
    except Exception as err:
        # ``Exception`` rather than a bare except, so KeyboardInterrupt and
        # SystemExit propagate; chain the cause so the real failure is visible.
        raise RuntimeError(
            f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
        ) from err