File size: 3,373 Bytes
746c674 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
import torch.nn as nn
import math
try:
from . import helpers as h
except:
import helpers as h
class Const():
    """Schedule whose value does not depend on time.

    Also serves as the base class for the other schedule types in this
    module, all of which expose the same ``getVal`` / ``__str__`` interface.
    """

    def __init__(self, c):
        """Store ``c`` as a float, or keep ``None`` to mean "no fixed value"."""
        self.c = c if c is None else float(c)

    def getVal(self, c = None, **kargs):
        """Return the stored constant.

        When the stored constant is ``None`` the caller-supplied fallback
        ``c`` is returned instead.  Any other keyword arguments (for example
        ``time``) are accepted and ignored so that every schedule can be
        called the same way.
        """
        return self.c if self.c is not None else c

    def __str__(self):
        return str(self.c)

    # Declared as a staticmethod so that it works both as
    # ``Const.initConst(x)`` and when reached through an instance; without
    # the decorator an instance call would pass ``self`` as ``x`` and fail.
    @staticmethod
    def initConst(x):
        """Return ``x`` unchanged if it is already a ``Const``, else wrap it."""
        return x if isinstance(x, Const) else Const(x)
class Lin(Const):
    """Linear ramp from ``start`` to ``end``.

    The ramp begins at time ``initial`` and takes ``steps`` units of time;
    outside that window the value is held at ``start`` (before) or ``end``
    (after).
    """

    def __init__(self, start, end, steps, initial = 0, quant = False):
        """Store the ramp parameters.

        start, end -- values at the beginning and end of the ramp.
        steps      -- duration of the ramp (a value of 0 makes ``getVal``
                      raise ``ZeroDivisionError``).
        initial    -- time at which the ramp starts.
        quant      -- when true, ``time`` is floored before interpolating.
        """
        self.start = float(start)
        self.end = float(end)
        self.steps = float(steps)
        self.initial = float(initial)
        self.quant = quant

    def getVal(self, time = 0, **kargs):
        """Return the interpolated value at ``time``; other kwargs are ignored."""
        if self.quant:
            time = math.floor(time)
        # Fraction of the ramp completed, clamped to [0, 1].
        progress = max(0, min(1, float(time - self.initial) / self.steps))
        return (self.end - self.start) * progress + self.start

    def __str__(self):
        # The template uses %-style placeholders, so it must be filled with
        # the % operator; str.format would leave the "%s" markers untouched.
        return "Lin(%s,%s,%s,%s, quant=%s)" % (
            str(self.start),
            str(self.end),
            str(self.steps),
            str(self.initial),
            str(self.quant),
        )
class Until(Const):
    """Follow schedule ``a`` while ``time < thresh``, then schedule ``b``.

    Schedule ``b`` is evaluated with its clock shifted so that it sees
    ``time - thresh``.
    """

    def __init__(self, thresh, a, b):
        """Store the switch-over time and wrap ``a`` / ``b`` as schedules."""
        self.thresh = thresh
        self.a = Const.initConst(a)
        self.b = Const.initConst(b)

    def getVal(self, *args, time = 0, **kargs):
        """Delegate to ``a`` before the threshold and to ``b`` from it onwards."""
        if time < self.thresh:
            return self.a.getVal(*args, time = time, **kargs)
        shifted = time - self.thresh
        return self.b.getVal(*args, time = shifted, **kargs)

    def __str__(self):
        pieces = tuple(str(part) for part in (self.thresh, self.a, self.b))
        return "Until(%s, %s, %s)" % pieces
class Scale(Const): # use with mix when aw = 1, and 0 <= c < 1
    """Turn a fraction ``c`` in ``[0, 1)`` into the ratio ``c / (1 - c)``."""

    def __init__(self, c):
        """Wrap ``c`` as a schedule (plain numbers become ``Const``)."""
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        """Evaluate the wrapped schedule and return ``c / (1 - c)``.

        A value equal to 0 short-circuits to the integer ``0``; any other
        value must satisfy ``0 <= c < 1`` (checked with ``assert``).
        """
        frac = self.c.getVal(*args, **kargs)
        if not (frac == 0):
            assert frac >= 0
            assert frac < 1
            return frac / (1 - frac)
        return 0

    def __str__(self):
        inner = str(self.c)
        return "Scale(%s)" % inner
def MixLin(*args, **kargs):
    """Build a ``Lin`` ramp from the given arguments and wrap it in ``Scale``."""
    ramp = Lin(*args, **kargs)
    return Scale(ramp)
class Normal(Const):
    """Random magnitude: ``|N(0, 1)|`` samples scaled by a schedule ``c``."""

    def __init__(self, c):
        """Wrap ``c`` as a schedule (plain numbers become ``Const``)."""
        self.c = Const.initConst(c)

    def getVal(self, *args, shape = None, **kargs):
        """Return a tensor of the given ``shape`` with scaled random magnitudes.

        ``shape`` defaults to ``[1]``.  The default is built fresh on every
        call rather than shared as a mutable default argument, because the
        same object is forwarded to the wrapped schedule, which may be
        arbitrary user code.

        The samples come from ``torch.randn`` on ``h.device``; their absolute
        value is multiplied by the wrapped schedule's current value.
        """
        if shape is None:
            shape = [1]
        c = self.c.getVal(*args, shape = shape, **kargs)
        return torch.randn(shape, device = h.device).abs() * c

    def __str__(self):
        return "Normal(%s)" % str(self.c)
class Clip(Const):
    """Clamp the value of schedule ``c`` to the range ``[l, u]``.

    The bounds are schedules too, so they may themselves vary over time.
    """

    def __init__(self, c, l, u):
        """Wrap the value and both bounds as schedules."""
        self.c = Const.initConst(c)
        self.l = Const.initConst(l)
        self.u = Const.initConst(u)

    def getVal(self, *args, **kargs):
        """Evaluate ``c``, ``l`` and ``u`` with the same arguments and clamp.

        Plain Python numbers are clamped with ``min``/``max``.  Anything else
        is clamped through its own ``.clamp(l, u)`` method (assumed to be a
        tensor-like object -- TODO confirm which types reach this branch).
        """
        c = self.c.getVal(*args, **kargs)
        l = self.l.getVal(*args, **kargs)
        u = self.u.getVal(*args, **kargs)
        # Accept ints as well as floats: a wrapped schedule can legitimately
        # produce an int (Scale.getVal returns the literal 0), and ints have
        # no .clamp method.
        if isinstance(c, (int, float)):
            return min(max(c, l), u)
        return c.clamp(l, u)

    def __str__(self):
        return "Clip(%s, %s, %s)" % (str(self.c), str(self.l), str(self.u))
class Fun(Const):
    """Schedule backed by an arbitrary callable."""

    def __init__(self, foo):
        """Remember the callable ``foo`` that produces the value."""
        self.foo = foo

    def getVal(self, *args, **kargs):
        """Call the stored callable with all arguments and return its result."""
        callback = self.foo
        return callback(*args, **kargs)

    def __str__(self):
        # The callable itself is not rendered.
        return "Fun(...)"
class Complement(Const): # use with mix when aw = 1, and 0 <= c < 1
    """Schedule yielding ``1 - c`` for a wrapped schedule ``c``.

    Note: the asserts in ``getVal`` accept the closed range ``0 <= c <= 1``,
    i.e. ``c == 1`` is allowed here even though the header comment says
    ``c < 1``.
    """

    def __init__(self, c):
        """Wrap ``c`` as a schedule (plain numbers become ``Const``)."""
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        """Evaluate the wrapped schedule, check its range, return ``1 - c``."""
        frac = self.c.getVal(*args, **kargs)
        assert frac >= 0
        assert frac <= 1
        remainder = 1 - frac
        return remainder

    def __str__(self):
        inner = str(self.c)
        return "Complement(%s)" % inner
|