Spaces:
Running
Running
File size: 3,233 Bytes
c4c7cee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import torch
class SigmoidScheduler:
    """Schedule shaped like a flipped, rescaled segment of the sigmoid.

    The sigmoid is evaluated on ``(t * (end - start) + start) / tau`` and
    normalised with its values at ``start / tau`` and ``end / tau`` so that the
    unclamped schedule is 1 at ``t = 0`` and 0 at ``t = 1``. The result is then
    clamped to ``[clip_min, 1.0]``.
    """

    def __init__(self, start=-3, end=3, tau=1, clip_min=1e-9):
        """Store the window, temperature and lower clamp; cache edge values."""
        self.start, self.end = start, end
        self.tau = tau
        self.clip_min = clip_min
        # Sigmoid at both ends of the window, kept as 0-dim tensors and reused
        # as normalisation constants by ``__call__`` and ``derivative``.
        left_edge = torch.tensor(self.start / self.tau)
        right_edge = torch.tensor(self.end / self.tau)
        self.v_start = torch.sigmoid(left_edge)
        self.v_end = torch.sigmoid(right_edge)

    def _sigmoid_arg(self, t):
        """Map ``t`` affinely onto the sigmoid's input window, scaled by tau."""
        return (t * (self.end - self.start) + self.start) / self.tau

    def __call__(self, t):
        """Return the schedule value at ``t``, clamped to ``[clip_min, 1.0]``."""
        span = self.v_end - self.v_start
        value = (self.v_end - torch.sigmoid(self._sigmoid_arg(t))) / span
        return value.clamp(min=self.clip_min, max=1.0)

    def derivative(self, t):
        """Return d/dt of the unclamped schedule at ``t``.

        Uses sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) together with the
        chain rule; the clamp applied in ``__call__`` is not reflected here.
        """
        sig = torch.sigmoid(self._sigmoid_arg(t))
        numerator = -(self.end - self.start) * sig * (1 - sig)
        denominator = self.tau * (self.v_end - self.v_start)
        return numerator / denominator

    def alpha(self, t):
        """Return ``-derivative(t) / (1e-6 + schedule(t))``.

        The 1e-6 offset keeps the denominator away from zero.
        """
        negative_slope = -self.derivative(t)
        return negative_slope / (1e-6 + self(t))
class LinearScheduler:
    """Linear schedule ``(end - start) * t + start`` clamped to ``[clip_min, 1.0]``."""

    def __init__(self, start=1, end=0, clip_min=1e-9):
        """Store the value at ``t = 0`` (start), at ``t = 1`` (end) and the lower clamp."""
        self.start = start
        self.end = end
        self.clip_min = clip_min

    def __call__(self, t):
        """Return the schedule value at ``t``, clamped to ``[clip_min, 1.0]``."""
        output = (self.end - self.start) * t + self.start
        return torch.clamp(output, min=self.clip_min, max=1.0)

    def derivative(self, t):
        """Return the constant slope ``end - start`` as a 0-dim tensor.

        The tensor is created on ``t``'s device and in ``t``'s floating dtype.
        The default integer ``start``/``end`` would otherwise yield an int64
        tensor, so a non-floating ``t`` falls back to the default float dtype.
        The clamp applied in ``__call__`` is not reflected here.
        """
        dtype = t.dtype if t.is_floating_point() else torch.get_default_dtype()
        return torch.tensor(self.end - self.start, dtype=dtype, device=t.device)

    def alpha(self, t):
        """Return ``-derivative(t) / (1e-6 + schedule(t))``.

        The 1e-6 offset keeps the denominator away from zero.
        """
        return -self.derivative(t) / (1e-6 + self.__call__(t))
class CosineScheduler:
    """Schedule built from ``cos(x) ** (2 * tau)`` on a rescaled window.

    ``x = (t * (end - start) + start) * pi / 2``. The raw power-cosine is
    normalised with its values at ``start`` and ``end`` and then clamped to
    ``[clip_min, 1.0]``. With the defaults (``start=1, end=0, tau=1``) the
    unclamped schedule is 1 at ``t = 0`` and 0 at ``t = 1``.
    """

    def __init__(
        self,
        start: float = 1,
        end: float = 0,
        tau: float = 1.0,
        clip_min: float = 1e-9,
    ):
        """Store the window, exponent scale and lower clamp; cache edge values."""
        self.start = start
        self.end = end
        self.tau = tau
        self.clip_min = clip_min
        # Power-cosine at both ends of the window, kept as 0-dim tensors and
        # reused as normalisation constants.
        self.v_start = torch.cos(torch.tensor(self.start) * torch.pi / 2) ** (
            2 * self.tau
        )
        self.v_end = torch.cos(torch.tensor(self.end) * torch.pi / 2) ** (2 * self.tau)

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """Return the schedule value at ``t``, clamped to ``[clip_min, 1.0]``."""
        output = (
            torch.cos((t * (self.end - self.start) + self.start) * torch.pi / 2)
            ** (2 * self.tau)
            - self.v_end
        ) / (self.v_start - self.v_end)
        return torch.clamp(output, min=self.clip_min, max=1.0)

    def derivative(self, t: torch.Tensor) -> torch.Tensor:
        """Return d/dt of the unclamped schedule at ``t``.

        Chain rule: d/dt cos(x)^(2*tau)
            = 2*tau * cos(x)^(2*tau - 1) * (-sin(x)) * dx/dt,
        with dx/dt = (end - start) * pi / 2. The cosine therefore appears to
        the power ``2*tau - 1`` only; an additional ``cos(x)`` factor would
        make the result disagree with the gradient of ``__call__``.
        """
        x = (t * (self.end - self.start) + self.start) * torch.pi / 2
        cos_x = torch.cos(x)
        return (
            -2
            * self.tau
            * (self.end - self.start)
            * torch.pi
            / 2
            * (cos_x ** (2 * self.tau - 1))
            * torch.sin(x)
            / (self.v_start - self.v_end)
        )

    def alpha(self, t: torch.Tensor) -> torch.Tensor:
        """Return ``-derivative(t) / (1e-6 + schedule(t))``.

        Same definition as the sigmoid and linear schedulers; the 1e-6 offset
        keeps the denominator away from zero.
        """
        return -self.derivative(t) / (1e-6 + self.__call__(t))
class CosineSchedulerSimple:
    """Squared-cosine schedule ``cos(((t + ns) / (1 + ds)) * pi / 2) ** 2``.

    ``ns`` shifts ``t`` and ``1 + ds`` rescales it before the cosine is taken.
    No clamping is applied. ``t`` is annotated as ``float`` but is passed to
    ``torch.cos``, so a tensor is expected in practice.
    """

    def __init__(self, ns: float = 0.0002, ds: float = 0.00025):
        """Store the shift ``ns`` and the scale offset ``ds``."""
        self.ns, self.ds = ns, ds

    def _phase(self, t):
        """Return the cosine argument ``((t + ns) / (1 + ds)) * pi / 2``."""
        shifted = (t + self.ns) / (1 + self.ds)
        return shifted * torch.pi / 2

    def __call__(self, t: float) -> float:
        """Return the squared cosine of the phase at ``t``."""
        cosine = torch.cos(self._phase(t))
        return cosine ** 2

    def derivative(self, t: float) -> float:
        """Return d/dt of the schedule: ``-pi * cos(x) * sin(x) / (1 + ds)``.

        Follows from d/dt cos(x)^2 = -2 * cos(x) * sin(x) * dx/dt with
        dx/dt = (pi / 2) / (1 + ds).
        """
        phase = self._phase(t)
        product = -torch.pi * torch.cos(phase) * torch.sin(phase)
        return product / (1 + self.ds)
|