yourusername's picture
:beers: cheers
66a6dc0
import torch
class EpsilonScheduler:
    """Multiply ``epsilon`` by ``factor`` when a monitored metric stops improving.

    Each call to :meth:`step` compares the metric against the best (lowest)
    value seen so far. A strictly lower value becomes the new best and resets
    the stall counter. Any other value increments the counter. Once the
    counter reaches ``patience``, ``epsilon`` is multiplied by ``factor`` and
    the counter is reset.

    Args:
        epsilon: Initial value of the scheduled quantity.
        patience: Number of consecutive non-improving steps tolerated before
            ``epsilon`` is reduced.
        factor: Multiplier applied to ``epsilon`` on each reduction.
        verbose: If true, print progress messages to stdout.

    Attributes:
        epsilon: Current scheduled value.
        best: Lowest metric observed so far (``inf`` before the first step).
        count: Consecutive steps without improvement since the last
            improvement or reduction.
    """

    def __init__(
        self,
        epsilon: float = 0.001,
        patience: int = 10,
        factor: float = 0.5,
        verbose: bool = False,
    ):
        self.epsilon = epsilon
        self.patience = patience
        self.factor = factor
        # Start at +inf so the very first metric always registers as an
        # improvement; a finite sentinel (previously 1e16) would wrongly treat
        # a large first value as a stall.
        self.best = float("inf")
        self.count = 0
        self.verbose = verbose

    def step(self, metric: float) -> None:
        """Record one observation of the monitored metric.

        Lower is better. Only a strictly lower value counts as an
        improvement; equal or higher values (and NaN, which compares False)
        advance the stall counter.

        Args:
            metric: Latest value of the monitored quantity.
        """
        if metric < self.best:
            self.best = metric
            self.count = 0
            return

        self.count += 1
        if self.verbose:
            print(f"Train loss has not improved for {self.count} epochs.")
        if self.count >= self.patience:
            # Reset so the next reduction needs another full patience window.
            self.count = 0
            self.epsilon *= self.factor
            if self.verbose:
                print(f"Reducing epsilon to {self.epsilon:0.2e}...")