Spaces:
Runtime error
Runtime error
File size: 1,386 Bytes
3b49518 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
class NoWeightingStrategy(nn.Module):
    """Pass-through strategy: task losses are left unweighted."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not used by this strategy.
        super().__init__()

    def forward(self, task_losses):
        """Return ``task_losses`` unchanged (the same object, not a copy)."""
        return task_losses
class UncertaintyWeightingStrategy(nn.Module):
    """Weight task losses with one learnable log-variance per task.

    Each task loss ``L`` with log-variance ``s`` becomes ``exp(-s) * L + s``.
    The log-variances are stored in ``self.log_vars`` (initialised to zero),
    where entry ``i`` belongs to ``tasks[i]``.
    """

    def __init__(self, tasks):
        """Create one zero-initialised log-variance parameter per task.

        Args:
            tasks: Sized collection of task identifiers; its order defines
                the layout of ``self.log_vars``.
        """
        super().__init__()
        self.tasks = tasks
        self.log_vars = nn.Parameter(torch.zeros(len(tasks)))

    def _log_vars_for(self, names):
        """Return ``log_vars`` ordered to match ``names``.

        Reordering only happens when ``names`` is exactly a permutation of
        ``self.tasks``. In every other case (different length, unknown keys,
        duplicate task names) the parameter is returned as-is, so that losses
        are matched to log-variances by position.
        """
        tasks = list(self.tasks)
        if len(names) != len(tasks):
            return self.log_vars
        try:
            order = [tasks.index(name) for name in names]
        except ValueError:
            # At least one key is not a known task: fall back to positions.
            return self.log_vars
        identity = list(range(len(tasks)))
        if order == identity or sorted(order) != identity:
            # Already aligned, or not a true permutation (duplicate tasks).
            return self.log_vars
        # Integer-list indexing keeps the result connected to the parameter,
        # so gradients still reach the right log-variance entries.
        return self.log_vars[order]

    def forward(self, task_losses):
        """Return a copy of ``task_losses`` with uncertainty-weighted values.

        Args:
            task_losses: Dict-like mapping from task identifier to loss
                tensor. The values are combined with ``torch.stack``, so they
                must all have the same shape. A loss that is exactly ``0`` is
                treated as a dropped task.

        Returns:
            A copy of ``task_losses`` (same keys, same key order) whose
            values are the weighted losses. An empty input yields an empty
            copy.
        """
        if not task_losses:
            # Nothing to weight; torch.stack would fail on an empty list.
            return task_losses.copy()

        names = list(task_losses)
        losses = torch.stack([task_losses[name] for name in names])

        # Pair every loss with the log-variance of its own task, even when
        # the dict's key order differs from the order of ``self.tasks``.
        log_vars = self._log_vars_for(names)
        weighted = torch.exp(-log_vars) * losses + log_vars

        # If a loss was 0 (i.e. the task was dropped), the weighted loss
        # should also be 0 rather than just log_var, since no information
        # was gained. Done out of place so no autograd intermediate is
        # modified in place.
        weighted = weighted * (losses != 0.0)

        weighted_task_losses = task_losses.copy()
        weighted_task_losses.update(zip(names, weighted))
        return weighted_task_losses
|