Spaces:
Runtime error
Runtime error
File size: 481 Bytes
c310e19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
# TODO maybe push this to nn?
def smooth_l1_loss(input, target, beta=1. / 9, size_average=True):
    """
    Smooth L1 loss with a configurable quadratic-to-linear transition point.

    Very similar to the smooth_l1_loss from pytorch, but with the extra
    ``beta`` parameter. Element-wise, with ``n = |input - target|``:

        loss = 0.5 * n ** 2 / beta   if n < beta
        loss = n - 0.5 * beta        otherwise

    Args:
        input: predicted tensor.
        target: tensor combined with ``input`` via ``input - target``
            (so it must match ``input``'s shape or broadcast against it).
        beta: threshold at which the loss switches from quadratic to linear.
            ``beta == 0`` yields a plain L1 loss.
        size_average: if True, return the mean of the element-wise losses;
            otherwise return their sum.

    Returns:
        A 0-dim tensor holding the reduced loss.
    """
    n = torch.abs(input - target)
    if beta == 0:
        # With beta == 0 the quadratic branch is never selected (n < 0 is
        # never true), but torch.where would still evaluate it and divide by
        # zero. In the backward pass the unselected branch's zero gradient is
        # divided by beta == 0, producing nan gradients everywhere. Using
        # L1 directly gives the same forward value with finite gradients.
        loss = n
    else:
        cond = n < beta
        loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    if size_average:
        return loss.mean()
    return loss.sum()
|