# Source code for accelerate.optimizer

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .state import AcceleratorState, DistributedType, is_tpu_available


# `xm` is only bound when `is_tpu_available()` returns a truthy value; any code below that uses `xm` must
# therefore only be reached in that situation.
if is_tpu_available():
    import torch_xla.core.xla_model as xm


def move_to_device(state, device):
    """
    Recursively move every tensor found in a nested structure to :obj:`device`.

    Args:
        state: A :obj:`torch.Tensor`, or a (possibly nested) list, tuple, namedtuple or dict whose leaves are
            tensors or arbitrary other objects.
        device: The target device, in any form accepted by :obj:`torch.Tensor.to`.

    Returns:
        A structure built from the same container types as :obj:`state` in which every tensor has been replaced
        by the result of ``tensor.to(device)``. Leaves that are neither containers nor tensors are returned
        unchanged.
    """
    if isinstance(state, (list, tuple)):
        moved = [move_to_device(t, device) for t in state]
        if isinstance(state, tuple) and hasattr(state, "_fields"):
            # Namedtuple constructors take one positional argument per field rather than a single iterable,
            # so the moved elements have to be unpacked.
            return type(state)(*moved)
        return type(state)(moved)
    elif isinstance(state, dict):
        return type(state)({k: move_to_device(v, device) for k, v in state.items()})
    elif isinstance(state, torch.Tensor):
        return state.to(device)
    return state


class AcceleratedOptimizer(torch.optim.Optimizer):
    """
    Internal wrapper around a torch optimizer.

    All optimizer operations are delegated to the wrapped optimizer. The wrapper adds optional device placement
    of the optimizer state and picks the right way to perform a step (TPU, gradient scaler, or plain step).

    Args:
        optimizer (:obj:`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        device_placement (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the optimizer should handle device placement. If so, it will place the state
            dictionary of :obj:`optimizer` on the right device.
        scaler (:obj:`torch.cuda.amp.grad_scaler.GradScaler`, `optional`):
            The scaler to use in the step function if training with mixed precision.
    """

    def __init__(self, optimizer, device_placement=True, scaler=None):
        # NOTE: `torch.optim.Optimizer.__init__` is not called here; parameter groups and optimizer state are
        # read from the wrapped optimizer instead.
        self.optimizer = optimizer
        self.scaler = scaler
        # Stored because `load_state_dict` consults it later on.
        self.device_placement = device_placement
        self.state = AcceleratorState()

        # Handle device placement
        if device_placement:
            state_dict = self.optimizer.state_dict()
            if self.state.distributed_type == DistributedType.TPU:
                xm.send_cpu_data_to_device(state_dict, self.state.device)
            else:
                state_dict = move_to_device(state_dict, self.state.device)
            self.optimizer.load_state_dict(state_dict)

    @property
    def param_groups(self):
        """The parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    def add_param_group(self, param_group):
        """Add :obj:`param_group` to the wrapped optimizer."""
        self.optimizer.add_param_group(param_group)

    def load_state_dict(self, state_dict):
        """
        Load :obj:`state_dict` into the wrapped optimizer.

        When device placement is enabled and the distributed type is TPU, :obj:`state_dict` is first passed to
        :obj:`xm.send_cpu_data_to_device` together with the accelerator device.
        """
        if self.device_placement and self.state.distributed_type == DistributedType.TPU:
            xm.send_cpu_data_to_device(state_dict, self.state.device)
        self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        """Return the state dictionary of the wrapped optimizer."""
        return self.optimizer.state_dict()

    def zero_grad(self):
        """Call :obj:`zero_grad` on the wrapped optimizer."""
        self.optimizer.zero_grad()

    def step(self):
        """
        Perform one optimization step.

        On TPU the step goes through :obj:`xm.optimizer_step`. Otherwise, if a scaler was provided, the step is
        performed with :obj:`scaler.step` followed by :obj:`scaler.update`. In every other case the wrapped
        optimizer's own :obj:`step` is called.
        """
        if self.state.distributed_type == DistributedType.TPU:
            xm.optimizer_step(self.optimizer)
        elif self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

    def _switch_parameters(self, parameters_map):
        """
        Replace parameters of the wrapped optimizer according to :obj:`parameters_map`.

        Every parameter ``p`` in every parameter group is replaced by ``parameters_map.get(p, p)``, so
        parameters that are not keys of the mapping are kept as they are.
        """
        for param_group in self.optimizer.param_groups:
            param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]