File size: 3,300 Bytes
6a6edcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import numpy as np

# --- Loss Weighting
class BaseLossWeight():
    def weight(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
        if shift != 1:
            logSNR = logSNR.clone() + 2 * np.log(shift)
        return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range)

class ComposedLossWeight(BaseLossWeight):
    def __init__(self, div, mul):
        self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
        self.div = [div] if isinstance(div, BaseLossWeight) else div

    def weight(self, logSNR):
        prod, div = 1, 1
        for m in self.mul:
            prod *= m.weight(logSNR)
        for d in self.div:
            div *= d.weight(logSNR)
        return prod/div

class ConstantLossWeight(BaseLossWeight):
    def __init__(self, v=1):
        self.v = v

    def weight(self, logSNR):
        return torch.ones_like(logSNR) * self.v

class SNRLossWeight(BaseLossWeight):
    def weight(self, logSNR):
        return logSNR.exp()

class P2LossWeight(BaseLossWeight):
    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k, self.gamma, self.s = k, gamma, s

    def weight(self, logSNR):
        return (self.k + (logSNR * self.s).exp()) ** -self.gamma

class SNRPlusOneLossWeight(BaseLossWeight):
    def weight(self, logSNR):
        return logSNR.exp() + 1

class MinSNRLossWeight(BaseLossWeight):
    def __init__(self, max_snr=5):
        self.max_snr = max_snr

    def weight(self, logSNR):
        return logSNR.exp().clamp(max=self.max_snr)

class MinSNRPlusOneLossWeight(BaseLossWeight):
    def __init__(self, max_snr=5):
        self.max_snr = max_snr

    def weight(self, logSNR):
        return (logSNR.exp() + 1).clamp(max=self.max_snr)

class TruncatedSNRLossWeight(BaseLossWeight):
    def __init__(self, min_snr=1):
        self.min_snr = min_snr

    def weight(self, logSNR):
        return logSNR.exp().clamp(min=self.min_snr)

class SechLossWeight(BaseLossWeight):
    def __init__(self, div=2):
        self.div = div

    def weight(self, logSNR):
        return 1/(logSNR/self.div).cosh()

class DebiasedLossWeight(BaseLossWeight):
    def weight(self, logSNR):
        return 1/logSNR.exp().sqrt()

class SigmoidLossWeight(BaseLossWeight):
    def __init__(self, s=1):
        self.s = s

    def weight(self, logSNR):
        return (logSNR * self.s).sigmoid()

class AdaptiveLossWeight(BaseLossWeight):
    def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]):
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, logSNR):
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
        return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta)