File size: 331 Bytes
ae29df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch


def l1(output, target):
    """Return the mean absolute error between ``output`` and ``target``.

    ``output - target`` is computed first (so torch broadcasting rules apply
    when the inputs are tensors). Its absolute value is then averaged over
    every element, giving a 0-dim tensor.

    Args:
        output: Predicted values (presumably a torch tensor — confirm at callers).
        target: Reference values, compatible with ``output`` for subtraction.

    Returns:
        A scalar (0-dim) tensor holding the mean absolute difference.
    """
    residual = output - target
    absolute_error = torch.abs(residual)
    return torch.mean(absolute_error)


def l1_wav(output_dict, target_dict):
    """Return the L1 loss between the ``'segment'`` entries of two dicts.

    Args:
        output_dict: Mapping whose ``'segment'`` value is the prediction
            (presumably a waveform tensor — confirm against callers).
        target_dict: Mapping whose ``'segment'`` value is the reference.

    Returns:
        The mean absolute difference computed by ``l1`` on the two segments.

    Raises:
        KeyError: If either mapping has no ``'segment'`` key.
    """
    return l1(output_dict['segment'], target_dict['segment'])


def get_loss_function(loss_type):
    """Return the loss callable registered under ``loss_type``.

    Args:
        loss_type: Name of the loss. Only ``"l1_wav"`` is currently supported.

    Returns:
        The matching loss function (``l1_wav`` for ``"l1_wav"``).

    Raises:
        NotImplementedError: If ``loss_type`` is not a supported name.
    """
    if loss_type == "l1_wav":
        return l1_wav

    # Include the rejected value and the supported names so a misconfigured
    # run can be diagnosed from the traceback alone.
    raise NotImplementedError(
        f"Unsupported loss_type {loss_type!r}; expected one of: 'l1_wav'."
    )