wav_SED / panns_inference /pytorch_utils.py
pcdarvin's picture
draft
5e740f6
"""This pytorch_utils.py contains functions from:
https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/pytorch_utils.py
"""
import torch
def move_data_to_device(x, device):
    """Convert numeric array-like data to a torch tensor on ``device``.

    The dtype is classified by substring match on ``str(x.dtype)``:
    anything containing ``'float'`` becomes ``torch.float32`` and anything
    containing ``'int'`` becomes ``torch.int64``. Data of any other dtype,
    or objects that expose no ``dtype`` attribute at all, are returned
    unchanged.

    Args:
      x: data to move, e.g. a numpy array or a torch tensor
      device: target device, forwarded to ``Tensor.to``

    Returns:
      A float32 / int64 tensor on ``device``, or ``x`` itself when its
      dtype is neither floating point nor integer.
    """
    # getattr keeps the pass-through contract for objects without a dtype
    # (plain lists, strings, ...) instead of raising AttributeError.
    dtype_name = str(getattr(x, 'dtype', ''))

    if 'float' in dtype_name:
        target_dtype = torch.float32
    elif 'int' in dtype_name:
        target_dtype = torch.int64
    else:
        return x

    # torch.as_tensor accepts numpy arrays and torch tensors of any source
    # dtype, whereas the legacy torch.Tensor / torch.LongTensor constructors
    # are tied to a single dtype.
    return torch.as_tensor(x, dtype=target_dtype).to(device)
def interpolate(x, ratio):
    """Upsample a prediction along the time axis by repeating each step.

    Every time step is duplicated ``ratio`` times in place, which undoes
    the temporal downsampling performed inside a CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to upsample

    Returns:
      (batch_size, time_steps * ratio, classes_num)
    """
    n_batch, n_steps, n_classes = x.shape
    # Add a singleton axis right after time, tile it `ratio` times, then
    # fold that axis back into time so copies of a step stay adjacent.
    tiled = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return tiled.reshape(n_batch, n_steps * ratio, n_classes)
def pad_framewise_output(framewise_output, frames_num):
    """Extend framewise_output along time by repeating its last frame.

    Args:
      framewise_output: (batch_size, time_steps, classes_num)
      frames_num: int, total number of frames the result should have;
        expected to be at least ``time_steps``, since the repeat count is
        ``frames_num - time_steps``

    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    missing = frames_num - framewise_output.shape[1]

    # Slice with -1: (not -1) so the time axis is kept and the last frame
    # can be tiled along it.
    last_frame = framewise_output[:, -1:, :]
    tail = last_frame.repeat(1, missing, 1)

    # (batch_size, frames_num, classes_num)
    return torch.cat((framewise_output, tail), dim=1)
def do_mixup(x, mixup_lambda):
    """Blend each even-indexed sample with the odd-indexed sample after it.

    Sample ``2k`` is weighted by ``mixup_lambda[2k]`` and sample ``2k + 1``
    by ``mixup_lambda[2k + 1]``; the two weighted samples are summed.

    Args:
      x: (batch_size * 2, ...) -- presumably a torch tensor, given the
        two-argument ``transpose`` call
      mixup_lambda: (batch_size * 2,)

    Returns:
      out: (batch_size, ...)
    """
    # Swapping the batch axis to the end lets the 1-D weights broadcast
    # against it, one weight per sample.
    even_term = x[::2].transpose(0, -1) * mixup_lambda[::2]
    odd_term = x[1::2].transpose(0, -1) * mixup_lambda[1::2]
    mixed = even_term + odd_term
    # Restore the original axis order.
    return mixed.transpose(0, -1)