lmzjms's picture
Upload 35 files
15ac91d
raw
history blame
313 Bytes
import torch
import torch.nn.functional as F
def clip_bce(output_dict, target_dict):
    """Compute the clip-level binary cross-entropy loss.

    Args:
        output_dict: Mapping that holds the model predictions under the
            ``'clipwise_output'`` key. The tensor is handed directly to
            ``F.binary_cross_entropy``, which expects probabilities in
            [0, 1] rather than raw logits. Presumably these are post-sigmoid
            outputs; confirm against the model.
        target_dict: Mapping that holds the ground-truth labels under the
            ``'target'`` key. They must match the predictions' shape.

    Returns:
        A scalar tensor: the element-wise BCE averaged over every entry.
        This is ``F.binary_cross_entropy``'s default ``reduction='mean'``.
    """
    predictions = output_dict['clipwise_output']
    labels = target_dict['target']
    return F.binary_cross_entropy(input=predictions, target=labels)
def get_loss_func(loss_type):
    """Return the loss function registered under ``loss_type``.

    Args:
        loss_type: Name of the loss to use. Only ``'clip_bce'`` is
            currently supported.

    Returns:
        The matching loss callable, which takes ``(output_dict, target_dict)``.

    Raises:
        ValueError: If ``loss_type`` is not a supported loss name.
    """
    if loss_type == 'clip_bce':
        return clip_bce
    # Fail loudly here. Implicitly returning None would only surface later
    # as "'NoneType' object is not callable", far from the real cause.
    raise ValueError(
        f"Unsupported loss_type {loss_type!r}; expected one of: 'clip_bce'"
    )