|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import nn |
|
|
|
|
|
class MultipleOutputLoss2(nn.Module):
    """Weighted sum of one loss applied pairwise to several outputs.

    Use this when a model produces several outputs and there is one target
    per output (``x`` and ``y`` are sequences of the same length). ``loss``
    is evaluated on ``(x[0], y[0])``, ``(x[1], y[1])``, ... and the results
    are combined as ``sum_i weight_i * loss(x[i], y[i])``.

    :param loss: callable invoked as ``loss(output, target)`` for each pair.
    :param weight_factors: optional indexable of per-output weights, one per
        element of ``x``. When ``None``, only the first pair contributes
        (weight 1); every other pair gets weight 0.
    """

    def __init__(self, loss, weight_factors=None):
        super().__init__()
        self.weight_factors = weight_factors
        self.loss = loss

    def forward(self, x, y):
        """Return the weighted sum of ``self.loss`` over the paired outputs.

        :param x: tuple or list of model outputs.
        :param y: tuple or list of targets, indexed in step with ``x``.
        :return: ``weights[0] * loss(x[0], y[0])`` plus
            ``weights[i] * loss(x[i], y[i])`` for every ``i >= 1`` whose
            weight is non-zero.
        :raises AssertionError: if ``x`` or ``y`` is not a tuple/list. These
            checks use ``assert`` and are skipped under ``python -O``.
        :raises IndexError: if ``x`` is empty, if the weights are shorter
            than ``x``, or if ``y`` lacks an entry whose weight is non-zero.
        """
        assert isinstance(x, (tuple, list)), "x must be either tuple or list"
        assert isinstance(y, (tuple, list)), "y must be either tuple or list"

        if self.weight_factors is not None:
            weights = self.weight_factors
        else:
            # Default weighting: only the first output contributes.
            weights = [0] * len(x)
            weights[0] = 1

        # The first term is always evaluated. Unlike the later terms, it is
        # not skipped when its weight is 0.
        total = weights[0] * self.loss(x[0], y[0])
        for idx in range(1, len(x)):
            w = weights[idx]
            if w == 0:
                # Avoid calling the loss at all for outputs that do not count.
                continue
            total += w * self.loss(x[idx], y[idx])
        return total
|
|
|
|
|
|