File size: 1,651 Bytes
e7dd443 76a55af e7dd443 76a55af e7dd443 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from torch.optim import Optimizer
from typing import Callable
# A schedule is a callable taking one float (named ``progress_fraction`` by the
# factories in this module) and returning the scheduled float value.
# NOTE(review): progress is presumably the fraction of training completed,
# in [0, 1] — confirm against the callers.
Schedule = Callable[[float], float]
def linear_schedule(
    start_val: float, end_val: float, end_fraction: float = 1.0
) -> Schedule:
    """Build a schedule that moves linearly from ``start_val`` to ``end_val``.

    Args:
        start_val: Value produced at progress 0.
        end_val: Value produced once progress reaches ``end_fraction``.
        end_fraction: Progress at which ``end_val`` is reached; from that
            point on the schedule stays at ``end_val``.

    Returns:
        A callable mapping a progress fraction to the interpolated value.
    """

    def interpolate(progress_fraction: float) -> float:
        # Past the ramp: hold the final value.
        if progress_fraction >= end_fraction:
            return end_val
        span = end_val - start_val
        return start_val + span * progress_fraction / end_fraction

    return interpolate
def constant_schedule(val: float) -> Schedule:
    """Build a schedule that returns ``val`` regardless of progress.

    Args:
        val: The value the schedule always produces.

    Returns:
        A callable that ignores its progress argument and returns ``val``.
    """

    def constant(progress_fraction: float) -> float:
        # The progress argument is accepted only to satisfy the Schedule shape.
        return val

    return constant
def spike_schedule(
    max_value: float,
    start_fraction: float = 1e-2,
    end_fraction: float = 1e-4,
    peak_progress: float = 0.1,
) -> Schedule:
    """Build a schedule that ramps up to a peak and then ramps back down.

    The value rises linearly from ``max_value * start_fraction`` at progress 0
    to ``max_value`` at ``peak_progress``, then falls linearly to
    ``max_value * end_fraction`` at progress 1. Progress values outside
    ``[0, 1]`` are not clamped; the same linear segments are extrapolated.

    Args:
        max_value: Value produced at the peak.
        start_fraction: Fraction of ``max_value`` produced at progress 0.
        end_fraction: Fraction of ``max_value`` produced at progress 1.
        peak_progress: Progress at which the peak is reached; must lie
            strictly between 0 and 1.

    Returns:
        A callable mapping a progress fraction to the scheduled value.

    Raises:
        ValueError: If ``peak_progress`` is not strictly between 0 and 1.
    """
    # Validate with a real exception rather than ``assert``: asserts are
    # removed under ``python -O``, and an out-of-range peak would otherwise
    # only surface later as a ZeroDivisionError inside the returned callable.
    if not 0 < peak_progress < 1:
        raise ValueError(
            f"peak_progress must be strictly between 0 and 1, got {peak_progress}"
        )

    def func(progress_fraction: float) -> float:
        if progress_fraction < peak_progress:
            # Warm-up segment: start_fraction -> 1 over [0, peak_progress).
            fraction = (
                start_fraction
                + (1 - start_fraction) * progress_fraction / peak_progress
            )
        else:
            # Decay segment: 1 -> end_fraction over [peak_progress, 1].
            fraction = 1 + (end_fraction - 1) * (progress_fraction - peak_progress) / (
                1 - peak_progress
            )
        return max_value * fraction

    return func
def schedule(name: str, start_val: float) -> Schedule:
    """Create a schedule from its name.

    Supported names:
        ``"linear"``: linear schedule from ``start_val`` to 0.
        ``"none"``: constant schedule at ``start_val``.
        ``"spike"``: spike schedule peaking at ``start_val`` with the
            spike schedule's default fractions.

    Args:
        name: Name of the schedule to build.
        start_val: Starting (or peak, for ``"spike"``) value.

    Returns:
        The schedule callable for the requested name.

    Raises:
        ValueError: If ``name`` is not one of the supported names.
    """
    if name == "linear":
        return linear_schedule(start_val, 0)
    if name == "none":
        return constant_schedule(start_val)
    if name == "spike":
        return spike_schedule(start_val)
    raise ValueError(f"Schedule {name} not supported")
def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
    """Set the ``"lr"`` entry of every parameter group to ``learning_rate``.

    Args:
        optimizer: Optimizer whose ``param_groups`` are updated in place.
        learning_rate: New learning rate applied to all parameter groups.
    """
    for group in optimizer.param_groups:
        group["lr"] = learning_rate
|