File size: 850 Bytes
641e847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import Any

from pytorch_toolbelt.losses import BinaryFocalLoss
from torch import nn
from torch.nn.modules.loss import BCEWithLogitsLoss


class WeightedLosses(nn.Module):
    """Weighted sum of several loss criteria evaluated on the same inputs.

    Args:
        losses: Iterable of loss objects, paired positionally with ``weights``.
            When every element is an ``nn.Module`` they are stored in an
            ``nn.ModuleList`` so their parameters/buffers are registered with
            this module (moved by ``.to()``/``.cuda()``, included in
            ``state_dict()``). Otherwise they are kept in a plain list.
        weights: Sized sequence of multipliers, one per loss. It is stored by
            reference, so mutating it later changes the weighting.

    Raises:
        ValueError: If ``losses`` and ``weights`` have different lengths.
    """

    def __init__(self, losses, weights):
        super().__init__()
        loss_list = list(losses)
        # zip() in forward() would silently drop the tail of the longer
        # sequence; fail loudly at construction time instead.
        if len(loss_list) != len(weights):
            raise ValueError(
                f"Expected one weight per loss, got {len(loss_list)} losses "
                f"and {len(weights)} weights"
            )
        # A plain list hides submodules from nn.Module, so device moves and
        # state_dict() would skip their buffers; register them when possible.
        if all(isinstance(loss, nn.Module) for loss in loss_list):
            self.losses = nn.ModuleList(loss_list)
        else:
            self.losses = loss_list
        self.weights = weights

    def forward(self, *inputs: Any, **kwargs: Any):
        """Return ``sum(w_i * loss_i(*inputs, **kwargs))`` over all losses.

        Every loss receives the same positional and keyword arguments.
        Returns the integer ``0`` when no losses are configured.
        """
        total = 0
        for loss, weight in zip(self.losses, self.weights):
            # Invoke modules through __call__ (not .forward) so registered
            # hooks run; non-module objects keep the original .forward call.
            fn = loss if isinstance(loss, nn.Module) else loss.forward
            total = total + weight * fn(*inputs, **kwargs)
        return total


class BinaryCrossentropy(BCEWithLogitsLoss):
    """Alias of ``BCEWithLogitsLoss`` under this module's naming scheme.

    Adds no behavior of its own: constructor arguments and the computed loss
    are exactly those of ``torch.nn.BCEWithLogitsLoss``.
    """


class FocalLoss(BinaryFocalLoss):
    """``pytorch_toolbelt`` ``BinaryFocalLoss`` with ``gamma`` defaulting to 3.

    Every argument is forwarded unchanged to ``BinaryFocalLoss``; see that
    class for the meaning of each option.

    Args:
        alpha: Forwarded to ``BinaryFocalLoss``.
        gamma: Forwarded to ``BinaryFocalLoss``; defaults to 3 here.
        ignore_index: Forwarded to ``BinaryFocalLoss``.
        reduction: Forwarded to ``BinaryFocalLoss``.
        normalized: Forwarded to ``BinaryFocalLoss``.
        reduced_threshold: Forwarded to ``BinaryFocalLoss``.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``BinaryFocalLoss``. Only options supported by the installed
            pytorch_toolbelt version are accepted.
    """

    def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
                 reduced_threshold=None, **kwargs):
        # Forward by keyword so a reordering of the base-class signature in a
        # pytorch_toolbelt release cannot silently bind values to the wrong
        # parameter.
        super().__init__(
            alpha=alpha,
            gamma=gamma,
            ignore_index=ignore_index,
            reduction=reduction,
            normalized=normalized,
            reduced_threshold=reduced_threshold,
            **kwargs,
        )