| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| import warnings |
|
|
| from .state import AcceleratorState, GradientState |
|
|
|
|
# Silence UserWarnings emitted from within ``torch.optim.lr_scheduler``: the
# wrapper below deliberately controls when (and how many times) the underlying
# scheduler is stepped, which can otherwise trigger torch's step-order warning.
warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
|
|
|
|
class AcceleratedScheduler:
    """
    Wraps a learning-rate scheduler so that it only advances when the wrapped optimizer(s)
    actually performed a training step. This prevents the schedule from racing ahead when a
    step was skipped (e.g. after a gradient overflow in mixed-precision training).

    Under gradient accumulation the scheduler length should be left untouched: Accelerate
    always steps the scheduler the appropriate number of times on your behalf.

    Args:
        scheduler (`torch.optim.lr_scheduler._LRScheduler`):
            The scheduler being wrapped.
        optimizers (one or a list of `torch.optim.Optimizer`):
            The optimizer(s) whose steps gate the scheduler.
        step_with_optimizer (`bool`, *optional*, defaults to `True`):
            Whether the scheduler should be tied to optimizer steps at all. When `False`,
            `step()` simply forwards to the wrapped scheduler.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the dataloaders split one batch across processes (batch size identical
            regardless of process count) or build a full batch per process (effective batch
            size multiplied by the number of processes).
    """

    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
        self.scheduler = scheduler
        # Normalize to a list so `step()` can always iterate.
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        self.optimizers = optimizers
        self.split_batches = split_batches
        self.step_with_optimizer = step_with_optimizer
        self.gradient_state = GradientState()

    def step(self, *args, **kwargs):
        """Advance the wrapped scheduler, gated on real optimizer steps."""
        # Not tied to the optimizer at all: forward the call unconditionally.
        if not self.step_with_optimizer:
            self.scheduler.step(*args, **kwargs)
            return

        # Mid-accumulation (gradients not yet synced): no optimizer step happened.
        # Optionally bump the internal counter so the schedule length stays aligned.
        if not self.gradient_state.sync_gradients:
            if self.gradient_state.adjust_scheduler:
                self.scheduler._step_count += 1
            return

        # If any optimizer skipped its step (e.g. inf/nan grads), skip the scheduler too.
        if any(optimizer.step_was_skipped for optimizer in self.optimizers):
            return

        if self.split_batches:
            # One global batch per training step: advance the scheduler exactly once.
            self.scheduler.step(*args, **kwargs)
            return

        # Each process consumed its own batch, so the effective batch size is
        # num_processes times larger: advance the scheduler once per process.
        for _ in range(AcceleratorState().num_processes):
            # Guard for schedulers with a hard cap (e.g. OneCycle exposes
            # `total_steps`); never step past it.
            capped = hasattr(self.scheduler, "total_steps")
            if not capped or self.scheduler._step_count <= self.scheduler.total_steps:
                self.scheduler.step(*args, **kwargs)

    def get_last_lr(self):
        """Return the last learning rate(s) computed by the wrapped scheduler."""
        return self.scheduler.get_last_lr()

    def state_dict(self):
        """Return the wrapped scheduler's state dict."""
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped scheduler from `state_dict`."""
        self.scheduler.load_state_dict(state_dict)

    def get_lr(self):
        """Return the learning rate(s) the wrapped scheduler would apply."""
        return self.scheduler.get_lr()

    def print_lr(self, *args, **kwargs):
        """Forward to the wrapped scheduler's `print_lr`."""
        return self.scheduler.print_lr(*args, **kwargs)
|
|