MultiMAE / utils /task_balancing.py
Bachmann Roman Christian
Initial commit
3b49518
raw
history blame
No virus
1.39 kB
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
class NoWeightingStrategy(nn.Module):
    """Pass-through balancing strategy that leaves task losses untouched.

    Any keyword arguments given to the constructor are accepted and
    ignored, so it can be built with the same call as other strategies.
    """

    def __init__(self, **kwargs):
        # kwargs are intentionally unused; this module has no parameters.
        super().__init__()

    def forward(self, task_losses):
        """Return ``task_losses`` as-is (the very same object, not a copy)."""
        return task_losses
class UncertaintyWeightingStrategy(nn.Module):
    """Uncertainty weighting strategy.

    Holds one learnable scalar ``log_var`` per task (initialised to zero)
    and rescales each task loss as ``exp(-log_var) * loss + log_var``.

    Args:
        tasks: Sized collection of tasks; only its length is used to size
            the ``log_vars`` parameter. It is also stored as ``self.tasks``.
    """

    def __init__(self, tasks):
        super().__init__()
        self.tasks = tasks
        self.log_vars = nn.Parameter(torch.zeros(len(tasks)))

    def forward(self, task_losses):
        """Weight every loss in ``task_losses`` and return a new mapping.

        Args:
            task_losses: Mapping from task name to loss tensor. The values
                are stacked in the mapping's iteration order and paired
                positionally with ``self.log_vars``.
                NOTE(review): this presumes the mapping's order matches the
                order of ``tasks`` given at construction -- confirm at the
                call site.

        Returns:
            A shallow copy of ``task_losses`` whose values are replaced by
            the weighted losses. The input mapping itself is not modified.
        """
        stacked = torch.stack(list(task_losses.values()))
        # Remember which tasks actually produced a (non-zero) loss.
        active = stacked != 0.0

        # Weighted loss per task: exp(-log_var) * loss + log_var.
        scaled = torch.exp(-self.log_vars) * stacked + self.log_vars

        # A loss of exactly 0 means the task was dropped; its weighted loss
        # must be 0 as well instead of the bare log_var, since no
        # information was gained from that task.
        scaled = scaled * active

        # Build the result on a copy so the caller's mapping stays intact.
        weighted_task_losses = task_losses.copy()
        for task_name, weighted_loss in zip(task_losses, scaled):
            weighted_task_losses[task_name] = weighted_loss
        return weighted_task_losses