from abc import ABCMeta, abstractmethod
from typing import Dict, List

import torch


class BaseOptimWrapper(metaclass=ABCMeta):
    """Base class for optimizer wrappers.

    Provides a unified interface for updating parameters, performing
    gradient back propagation, clearing gradients and stepping the wrapped
    optimizer, so that components such as parameter schedulers can work with
    any optimizer in the same way.

    Args:
        optimizer (Optimizer): The optimizer to be wrapped.
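
    Examples:
        A minimal sketch of a concrete subclass, assuming nothing beyond the
        abstract interface above (``SimpleOptimWrapper`` and ``model`` are
        illustrative names, not part of the library):

        >>> from torch.optim import SGD
        >>> class SimpleOptimWrapper(BaseOptimWrapper):
        ...     def update_params(self, loss):
        ...         self.backward(loss)
        ...         self.step()
        ...         self.zero_grad()
        ...     def backward(self, loss, **kwargs):
        ...         loss.backward()
        ...     def zero_grad(self, **kwargs):
        ...         self.optimizer.zero_grad()
        ...     def step(self, **kwargs):
        ...         self.optimizer.step()
        >>> wrapper = SimpleOptimWrapper(SGD(model.parameters(), lr=0.01))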
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer

        # When the optimizer has multiple parameter groups, keep a dummy
        # "base" parameter group that carries the optimizer defaults, so
        # that the base settings can be tracked by parameter schedulers and
        # saved/restored together with the optimizer state.
        if len(optimizer.param_groups) > 1:
            self.base_param_settings = {
                'params': torch.tensor([0.0], dtype=torch.float)
            }
            self.base_param_settings.update(**self.optimizer.defaults)
        else:
            self.base_param_settings = None

    @abstractmethod
    def update_params(self, *args, **kwargs):
        """Update parameters in :attr:`optimizer`."""

    @abstractmethod
    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation."""

    @abstractmethod
    def zero_grad(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.zero_grad``."""

    @abstractmethod
    def step(self, **kwargs):
        """Call the step method of the optimizer."""

    def state_dict(self) -> dict:
        """A wrapper of ``Optimizer.state_dict``."""
        state_dict = self.optimizer.state_dict()
        if self.base_param_settings is not None:
            state_dict['base_param_settings'] = self.base_param_settings
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """A wrapper of ``Optimizer.load_state_dict``. Load the state dict
        of :attr:`optimizer`.

        Provides a unified ``load_state_dict`` interface compatible with
        automatic mixed precision training. Subclasses can override this
        method to implement the required logic. For example, the state
        dictionary of the ``GradScaler`` should be loaded when training
        with ``torch.cuda.amp``.

        Args:
            state_dict (dict): The state dictionary of :attr:`optimizer`.
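
        Examples:
            A rough sketch of how an AMP-aware subclass might override this
            method; the ``AmpWrapper`` name and the ``'loss_scaler'`` key
            are illustrative assumptions, not the library's actual API:

            >>> class AmpWrapper(BaseOptimWrapper):
            ...     # other abstract methods omitted for brevity
            ...     def __init__(self, optimizer):
            ...         super().__init__(optimizer)
            ...         self.loss_scaler = torch.cuda.amp.GradScaler()
            ...     def load_state_dict(self, state_dict):
            ...         if 'loss_scaler' in state_dict:
            ...             self.loss_scaler.load_state_dict(
            ...                 state_dict.pop('loss_scaler'))
            ...         super().load_state_dict(state_dict)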
""" |
|
|
base_param_settings = state_dict.pop('base_param_settings', None) |
|
|
|
|
|
if base_param_settings is not None: |
|
|
self.base_param_settings = base_param_settings |
|
|
|
|
|
|
|
|
self.optimizer.load_state_dict(state_dict) |
|
|
|
|
|

    @property
    def param_groups(self) -> List[dict]:
        """A wrapper of ``Optimizer.param_groups``.

        Makes the optimizer wrapper compatible with :class:`_ParamScheduler`.

        Returns:
            List[dict]: The ``param_groups`` of :attr:`optimizer`.
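
        Examples:
            Illustrative only, assuming ``wrapper`` wraps an optimizer with
            two parameter groups: an extra "base" group holding the
            optimizer defaults is appended.

            >>> len(wrapper.optimizer.param_groups), len(wrapper.param_groups)
            (2, 3)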
""" |
|
|
if self.base_param_settings is not None: |
|
|
return self.optimizer.param_groups + [self.base_param_settings] |
|
|
else: |
|
|
return self.optimizer.param_groups |
|
|
|
|
|

    @property
    def defaults(self) -> dict:
        """A wrapper of ``Optimizer.defaults``.

        Makes the optimizer wrapper compatible with :class:`_ParamScheduler`.

        Returns:
            dict: The ``defaults`` of :attr:`optimizer`.
        """
        return self.optimizer.defaults

    def get_lr(self) -> Dict[str, List[float]]:
        """Get the learning rate of the optimizer.

        Provides a unified interface to get the learning rate of the
        optimizer.

        Returns:
            Dict[str, List[float]]: The learning rate of each
            ``param_group`` of the optimizer.
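
        Examples:
            Illustrative values, assuming ``wrapper`` wraps an optimizer
            built with a single parameter group and ``lr=0.01``:

            >>> wrapper.get_lr()
            {'lr': [0.01]}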
""" |
|
|
res = {} |
|
|
if self.base_param_settings is not None: |
|
|
res['base_lr'] = [self.base_param_settings['lr']] |
|
|
|
|
|
res['lr'] = [group['lr'] for group in self.optimizer.param_groups] |
|
|
|
|
|
return res |
|
|
|
|
|

    def get_momentum(self) -> Dict[str, List[float]]:
        """Get the momentum of the optimizer.

        Provides a unified interface to get the momentum of the optimizer.

        Returns:
            Dict[str, List[float]]: Momentum of the optimizer.
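
        Examples:
            Illustrative values, assuming ``wrapper`` wraps a single-group
            ``SGD`` optimizer built with ``momentum=0.9`` (for Adam-style
            optimizers the first ``beta`` would be reported instead):

            >>> wrapper.get_momentum()
            {'momentum': [0.9]}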
""" |
|
|
momentum = [] |
|
|
for group in self.optimizer.param_groups: |
|
|
|
|
|
if 'momentum' in group.keys(): |
|
|
momentum.append(group['momentum']) |
|
|
|
|
|
elif 'betas' in group.keys(): |
|
|
momentum.append(group['betas'][0]) |
|
|
else: |
|
|
momentum.append(0) |
|
|
return dict(momentum=momentum) |
|
|
|