File size: 146 Bytes
9667e74
 
 
 
 
1
2
3
4
5
6
def l2_loss(input, target, mask, batch_size):
    """Return the masked half squared error, scaled by ``1 / batch_size``.

    Each element of ``input - target`` is multiplied by ``mask``, squared,
    halved and divided by ``batch_size``. All resulting elements are then
    summed into a single value.

    Args:
        input: Predictions. Assumed to be an array/tensor-like object that
            supports elementwise arithmetic (e.g. a NumPy array or torch
            tensor) -- TODO confirm against callers.
        target: Values to compare against; must combine elementwise with
            ``input``.
        mask: Elementwise weights applied to the difference before squaring.
            Entries of 0 exclude an element from the loss.
        batch_size: Divisor applied to every squared term.

    Returns:
        The result of ``.sum()`` over the per-element loss terms.
    """
    # Mask the residual before squaring, so masked-out entries contribute 0.
    masked_diff = (input - target) * mask
    # Same evaluation order as ((d * d) / 2) / batch_size, per element.
    per_element = masked_diff * masked_diff / 2 / batch_size
    return per_element.sum()