File size: 3,373 Bytes
746c674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import math

try:
    from . import helpers as h
except:
    import helpers as h



class Const():
    """A constant schedule value.

    ``Const(None)`` acts as a pass-through: ``getVal`` then returns the ``c``
    argument supplied by the caller instead of a stored constant.
    """

    def __init__(self, c):
        # None is kept as-is (pass-through marker); anything else is coerced to float.
        self.c = c if c is None else float(c)

    def getVal(self, c = None, **kargs):
        """Return the stored constant, or the caller-provided ``c`` if none is stored.

        Extra keyword arguments (such as ``time`` or ``shape`` that other
        schedules forward) are accepted and ignored.
        """
        return self.c if self.c is not None else c

    def __str__(self):
        return str(self.c)

    @staticmethod
    def initConst(x):
        """Return ``x`` unchanged if it is already a ``Const``, otherwise wrap it as ``Const(x)``.

        Declared static so it behaves the same whether called on the class or on an instance.
        """
        return x if isinstance(x, Const) else Const(x)

class Lin(Const):
    """Linear ramp from ``start`` to ``end`` over ``steps`` time units, beginning at ``initial``.

    Before ``initial`` the value is ``start``; from ``initial + steps`` on it is ``end``.
    With ``quant=True`` the time is floored first, so the value only changes at integer times.
    """
    def __init__(self, start, end, steps, initial = 0, quant = False):
        self.start = float(start)
        self.end = float(end)
        self.steps = float(steps)
        self.initial = float(initial)
        self.quant = quant

    def getVal(self, time = 0, **kargs):
        """Return the ramp value at ``time``; other keyword arguments are ignored."""
        if self.quant:
            time = math.floor(time)
        elapsed = float(time - self.initial)
        if self.steps == 0:
            # A zero-length ramp would divide by zero; treat it as an immediate step to ``end``.
            progress = 1.0 if elapsed >= 0 else 0.0
        else:
            # Fraction of the ramp completed, clamped to [0, 1].
            progress = max(0, min(1, elapsed / self.steps))
        return (self.end - self.start) * progress + self.start

    def __str__(self):
        # Use %-formatting: the placeholders are "%s", so str.format would leave them unfilled.
        return "Lin(%s,%s,%s,%s, quant=%s)" % (str(self.start), str(self.end), str(self.steps), str(self.initial), str(self.quant))

class Until(Const):
    """Piecewise schedule: follow ``a`` while ``time < thresh``, then follow ``b``.

    When ``b`` takes over, its clock is shifted so that it sees ``time - thresh``,
    i.e. it starts from 0 at the threshold. Non-Const arguments are wrapped via
    ``Const.initConst``.
    """
    def __init__(self, thresh, a, b):
        self.a = Const.initConst(a)
        self.b = Const.initConst(b)
        self.thresh = thresh

    def getVal(self, *args, time = 0, **kargs):
        if time < self.thresh:
            return self.a.getVal(*args, time = time, **kargs)
        shifted = time - self.thresh
        return self.b.getVal(*args, time = shifted, **kargs)

    def __str__(self):
        parts = (str(self.thresh), str(self.a), str(self.b))
        return "Until(%s, %s, %s)" % parts

class Scale(Const):
    """Map a schedule value ``c`` to the ratio ``c / (1 - c)``.

    Intended for use with mix when aw = 1. ``c`` must lie in ``[0, 1)``
    (checked with asserts); ``c == 0`` short-circuits to the int ``0``.
    """
    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        ratio_in = self.c.getVal(*args, **kargs)
        if ratio_in == 0:
            return 0
        assert ratio_in >= 0
        assert ratio_in < 1
        denominator = 1 - ratio_in
        return ratio_in / denominator

    def __str__(self):
        inner = str(self.c)
        return "Scale(%s)" % inner

def MixLin(*args, **kargs):
    """Build a ``Scale`` around a ``Lin`` ramp; every argument is forwarded to ``Lin``."""
    ramp = Lin(*args, **kargs)
    return Scale(ramp)

class Normal(Const):
    """Random magnitude: absolute values of standard-normal samples, scaled by schedule ``c``."""
    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, shape = None, **kargs):
        """Sample ``|randn(shape)| * c``; ``shape`` defaults to ``[1]``.

        ``shape`` is also forwarded to the inner schedule's ``getVal``.
        """
        # Build the default per call instead of sharing one mutable list across calls,
        # because the list is handed on to arbitrary child schedules (e.g. Fun).
        if shape is None:
            shape = [1]
        c = self.c.getVal(*args, shape = shape, **kargs)
        # h.device comes from the project helpers module (presumably the active torch device).
        return torch.randn(shape, device = h.device).abs() * c

    def __str__(self):
        return "Normal(%s)" % str(self.c)

class Clip(Const):
    """Clamp schedule ``c`` into ``[l, u]``, where ``l`` and ``u`` are schedules too."""
    def __init__(self, c, l, u):
        self.c = Const.initConst(c)
        self.l = Const.initConst(l)
        self.u = Const.initConst(u)

    def getVal(self, *args, **kargs):
        """Evaluate all three schedules with the same arguments and clamp ``c``.

        Tensors are clamped element-wise with ``Tensor.clamp``; plain numbers
        (ints as well as floats) use ``min``/``max``.
        """
        c = self.c.getVal(*args, **kargs)
        l = self.l.getVal(*args, **kargs)
        u = self.u.getVal(*args, **kargs)
        # Dispatch on tensor-ness rather than float-ness: non-float scalars such as the
        # int 0 returned by Scale.getVal have no .clamp method.
        if isinstance(c, torch.Tensor):
            return c.clamp(l, u)
        return min(max(c, l), u)

    def __str__(self):
        return "Clip(%s, %s, %s)" % (str(self.c), str(self.l), str(self.u))

class Fun(Const):
    """Adapt an arbitrary callable into a schedule.

    ``getVal`` forwards every positional and keyword argument to the wrapped
    callable and returns its result unchanged.
    """
    def __init__(self, foo):
        self.foo = foo

    def getVal(self, *args, **kargs):
        wrapped = self.foo
        result = wrapped(*args, **kargs)
        return result

    def __str__(self):
        label = "Fun(...)"
        return label

class Complement(Const):
    """Return ``1 - c`` for a schedule value ``c``.

    ``c`` is asserted to lie in the closed interval ``[0, 1]``.
    Intended for use with mix when aw = 1.
    """
    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        value = self.c.getVal(*args, **kargs)
        assert value >= 0
        assert value <= 1
        complement = 1 - value
        return complement

    def __str__(self):
        inner = str(self.c)
        return "Complement(%s)" % inner