import torch | |
def l1(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error between ``output`` and ``target``.

    The element-wise difference ``output - target`` is taken first, then
    its absolute value is averaged over all elements.

    Args:
        output: Predicted values.
        target: Reference values. No shape check is performed here; the
            subtraction follows whatever rules ``torch`` applies to the
            two operands.

    Returns:
        The mean of ``|output - target|`` over every element.
    """
    abs_error = torch.abs(output - target)
    return torch.mean(abs_error)
def l1_wav(output_dict, target_dict):
    """Compute the L1 loss between the ``'segment'`` entries of two dicts.

    Args:
        output_dict: Mapping that holds the prediction under the key
            ``'segment'``.
        target_dict: Mapping that holds the reference under the key
            ``'segment'``.

    Returns:
        The result of ``l1`` applied to the two ``'segment'`` values.

    Raises:
        KeyError: If either mapping is a plain dict without a
            ``'segment'`` key.
    """
    # NOTE(review): the name suggests 'segment' is a waveform tensor —
    # confirm against the model / data loader that builds these dicts.
    predicted_segment = output_dict['segment']
    reference_segment = target_dict['segment']
    return l1(predicted_segment, reference_segment)
def get_loss_function(loss_type: str):
    """Look up a loss function by its name.

    Args:
        loss_type: Name of the loss. Only ``"l1_wav"`` is supported.

    Returns:
        The callable implementing the requested loss.

    Raises:
        NotImplementedError: If ``loss_type`` is not a supported name.
    """
    if loss_type == "l1_wav":
        return l1_wav

    # Name the rejected value and the supported ones so a configuration
    # typo can be diagnosed from the message alone.
    raise NotImplementedError(
        f"Unsupported loss_type {loss_type!r}; supported values: 'l1_wav'"
    )