# A2C playing Walker2DBulletEnv-v0
# Source: https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3
from torch.optim import Optimizer
from typing import Callable
# A Schedule maps a training-progress fraction (float) to a scheduled value
# (e.g. a learning rate) — see linear_schedule / constant_schedule below.
Schedule = Callable[[float], float]
def linear_schedule(
    start_val: float, end_val: float, end_fraction: float = 1.0
) -> Schedule:
    """Build a schedule that ramps linearly from start_val to end_val.

    The ramp spans the first end_fraction of progress (progress is expected
    in [0, 1]); beyond that point the schedule holds end_val.
    """

    def _interp(progress_fraction: float) -> float:
        if progress_fraction < end_fraction:
            # Still on the ramp: interpolate proportionally to progress.
            return start_val + (end_val - start_val) * progress_fraction / end_fraction
        return end_val

    return _interp
def constant_schedule(val: float) -> Schedule:
    """Build a schedule that ignores progress and always returns val."""

    def _const(progress_fraction: float) -> float:
        return val

    return _const
def schedule(name: str, start_val: float, end_val: float = 0) -> Schedule:
    """Build a schedule by name.

    :param name: "linear" selects a linear decay from start_val to end_val
        over the whole run; any other name yields a constant schedule.
    :param start_val: initial (and, for constant schedules, only) value.
    :param end_val: final value of the linear decay. Defaults to 0, matching
        the previously hard-coded behavior, so existing callers are unchanged.
    """
    if name == "linear":
        return linear_schedule(start_val, end_val)
    return constant_schedule(start_val)
def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
    """Set the learning rate of every parameter group of optimizer in place."""
    for group in optimizer.param_groups:
        group["lr"] = learning_rate