"""This pytorch_utils.py contains functions from: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/pytorch_utils.py """ import torch def move_data_to_device(x, device): if 'float' in str(x.dtype): x = torch.Tensor(x) elif 'int' in str(x.dtype): x = torch.LongTensor(x) else: return x return x.to(device) def interpolate(x, ratio): """Interpolate the prediction to compensate the downsampling operation in a CNN. Args: x: (batch_size, time_steps, classes_num) ratio: int, ratio to upsample """ (batch_size, time_steps, classes_num) = x.shape upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) return upsampled def pad_framewise_output(framewise_output, frames_num): """Pad framewise_output to the same length as input frames. Args: framewise_output: (batch_size, frames_num, classes_num) frames_num: int, number of frames to pad Outputs: output: (batch_size, frames_num, classes_num) """ pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1) """tensor for padding""" output = torch.cat((framewise_output, pad), dim=1) """(batch_size, frames_num, classes_num)""" return output def do_mixup(x, mixup_lambda): out = x[0::2].transpose(0, -1) * mixup_lambda[0::2] + \ x[1::2].transpose(0, -1) * mixup_lambda[1::2] return out.transpose(0, -1)