|
import itertools
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import torch

from uniperceiver.config import CfgNode
from uniperceiver.utils.registry import Registry

SOLVER_REGISTRY = Registry("SOLVER")
SOLVER_REGISTRY.__doc__ = """
Registry for SOLVER.
"""

# A gradient clipper takes a tensor or an iterable of tensors and clips
# their gradients in place.
_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]
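# Usage sketch for the registry (hypothetical solver class; nothing is
# actually registered in this module):
#
#   @SOLVER_REGISTRY.register()
#   class MySolver(torch.optim.AdamW):
#       ...
#
#   solver_cls = SOLVER_REGISTRY.get("MySolver")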
|
|
|
def _create_gradient_clipper(cfg: CfgNode) -> Tuple[Optional[_GradientClipper], Optional[_GradientClipper]]:
    """
    Creates a gradient clipper from the config and returns it as a
    (per_param_clipper, global_clipper) pair, exactly one of which is set.
    Value clipping can be applied independently per parameter, while norm
    clipping must see all parameters at once to compute the global norm.
    """
    assert cfg.SOLVER.GRAD_CLIP_TYPE in ('value', 'norm'), cfg.SOLVER.GRAD_CLIP_TYPE

    def clip_grad_norm(p: _GradientClipperInput):
        # clip_grad_norm_ returns the total norm before clipping.
        return torch.nn.utils.clip_grad_norm_(p, cfg.SOLVER.GRAD_CLIP, cfg.SOLVER.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.SOLVER.GRAD_CLIP)

    if cfg.SOLVER.GRAD_CLIP_TYPE == 'value':
        return clip_grad_value, None
    else:
        return None, clip_grad_norm
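# For example, with cfg.SOLVER.GRAD_CLIP_TYPE = 'norm' this returns
# (None, clip_grad_norm), and the optimizer wrapper further below applies
# the norm clipper once over all parameters per step.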
|
|
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Build per-parameter optimizer settings: norm-layer weights use
    `weight_decay_norm`, biases use `bias_lr_factor` and `weight_decay_bias`,
    and parameters listed by `model.no_weight_decay()` get zero decay.
    `overrides` maps a parameter name (e.g. "bias") to optimizer options
    that take precedence over all of the rules above.
    """
    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    # Parameters the model itself wants excluded from weight decay.
    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    for module_name, module in model.named_modules():
        no_decay = module_name in no_decay_list
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters that are shared between modules.
            if value in memo:
                continue
            memo.add(value)

            schedule_params = {
                "lr": base_lr,
                "weight_decay": weight_decay,
            }
            if isinstance(module, norm_module_types):
                schedule_params["weight_decay"] = weight_decay_norm
            elif module_param_name == "bias":
                schedule_params["lr"] = base_lr * bias_lr_factor
                schedule_params["weight_decay"] = weight_decay_bias

            if no_decay or (module_param_name in no_decay_list):
                schedule_params["weight_decay"] = 0.

            if overrides is not None and module_param_name in overrides:
                schedule_params.update(overrides[module_param_name])
            params.append({
                "params": [value],
                "lr": schedule_params["lr"],
                "weight_decay": schedule_params["weight_decay"],
            })

    return params
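# Usage sketch (hypothetical hyperparameter values):
#
#   params = get_default_optimizer_params(
#       model, base_lr=1e-4, weight_decay=0.05,
#       weight_decay_norm=0.0, weight_decay_bias=0.0,
#   )
#   optimizer = torch.optim.AdamW(params, lr=1e-4)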
|
|
|
def get_layer_id(module_name, num_layers):
    """
    Assign a parameter to its layer id for layer-wise lr decay.
    Modified from BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if module_name.split('.')[0] in [
            'video_embed', 'token_embed', 'prompt_embed', 'visual_embed', 'cls_token'
    ]:
        return 0
    elif module_name.startswith('encoder'):
        # Expects names like 'encoder.layers.<idx>....'; layer ids start at 1.
        return int(module_name.split('.')[2]) + 1
    elif module_name.startswith('predictor'):
        return num_layers
    else:
        raise NotImplementedError('please check this layer')
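# Worked example (hypothetical module names; num_layers = 13 corresponds to
# 12 hidden layers + 1):
#   get_layer_id('token_embed.embeddings', 13)      -> 0   (embedding stage)
#   get_layer_id('encoder.layers.3.self_attn', 13)  -> 4   (block index + 1)
#   get_layer_id('predictor.head', 13)              -> 13  (output stage)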
|
|
|
def create_separate_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
    cfg: dict = None,
):
    """
    Build one parameter group per parameter, with MoE-aware flags and
    layer-wise lr decay. Expert-gate ('wg') weights, embeddings, norm layers
    and biases each get their own lr / weight-decay treatment.
    """
    try:
        from deepspeed.moe.utils import is_moe_param
    except ImportError:
        # Fallback following the DeepSpeed convention: expert parameters
        # carry `allreduce = False`.
        def is_moe_param(param: torch.Tensor) -> bool:
            if hasattr(param, "allreduce") and not param.allreduce:
                return True
            return False

    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    # Layer-wise lr decay: the scale for layer i is layer_decay^(num_layers - i),
    # so earlier layers receive smaller learning rates.
    num_layers = cfg.MODEL.BERT.NUM_HIDDEN_LAYERS + 1
    layer_decay = cfg.SOLVER.LAYER_LR_DECAY
    layer_scales = list(layer_decay**(num_layers - i) for i in range(num_layers + 1))

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    # Names of expert-gate (router) modules, if the model exposes them.
    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    for module_name, module in model.named_modules():
        no_decay = module_name in no_decay_list
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                break

        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            if value in memo:
                continue
            memo.add(value)

            this_scale = layer_scales[get_layer_id(module_name, num_layers)] if layer_decay < 1.0 else 1.0

            schedule_params = {
                "lr": base_lr,
                "weight_decay": weight_decay,
                "moe": False,
            }
            if is_moe_param(value):
                schedule_params['moe'] = True

            if no_decay or (module_param_name in no_decay_list):
                schedule_params["weight_decay"] = 0.
            elif is_wg_param and isinstance(module, torch.nn.Linear) and module_param_name != "bias":
                # Expert-gate weights get their own lr factor and decay.
                schedule_params["lr"] = base_lr * wg_lr_factor
                schedule_params["weight_decay"] = weight_decay_wg
            elif isinstance(module, torch.nn.Embedding):
                schedule_params['weight_decay'] = weight_decay_embedding
            elif isinstance(module, norm_module_types):
                if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
                    schedule_params["lr"] = base_lr * bias_lr_factor
                    schedule_params['weight_decay'] = weight_decay_bias
                else:
                    schedule_params['weight_decay'] = weight_decay_norm
            elif module_param_name == "bias" or value.ndim == 1:
                schedule_params["lr"] = base_lr * bias_lr_factor
                schedule_params['weight_decay'] = weight_decay_bias

            params.append({
                "params": [value],
                "lr": max(schedule_params["lr"] * this_scale, cfg.LR_SCHEDULER.get('MIN_LR', 1e-6)),
                "moe": schedule_params['moe'],
                "weight_decay": schedule_params["weight_decay"],
                "name": f'{module_name}.{module_param_name}',
            })

    return params


# Backwards-compatible alias for the original (misspelled) name.
create_seperate_moe_param_groups = create_separate_moe_param_groups
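# Scale note: with layer_decay = 0.9 and num_layers = 13 (hypothetical
# values), layer 0 is scaled by 0.9**13 ≈ 0.254 while the topmost layer
# keeps scale 0.9**0 = 1.0.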
|
|
def create_group_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
    cfg: dict = None,
):
    """
    Same per-parameter rules as `create_separate_moe_param_groups`, but
    parameters sharing the same (lr, weight_decay, moe) triple are merged
    into a single group, which keeps the number of optimizer groups small.
    """
    from deepspeed.moe.utils import is_moe_param

    memo: Set[torch.nn.parameter.Parameter] = set()

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    # Maps 'lr_{lr}_wd_{wd}_moe_{moe}' -> param group dict.
    group_params_dict = {}

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    for module_name, module in model.named_modules():
        no_decay = module_name in no_decay_list
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                break

        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            if value in memo:
                continue
            memo.add(value)

            lr_of_this_param = base_lr
            wd_of_this_param = weight_decay
            moe_of_this_param = is_moe_param(value)

            if no_decay or (module_param_name in no_decay_list):
                wd_of_this_param = 0.
            elif is_wg_param and isinstance(module, torch.nn.Linear) and module_param_name != "bias":
                lr_of_this_param = base_lr * wg_lr_factor
                wd_of_this_param = weight_decay_wg
            elif isinstance(module, torch.nn.Embedding):
                wd_of_this_param = weight_decay_embedding
            elif isinstance(module, norm_module_types):
                if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
                    lr_of_this_param = base_lr * bias_lr_factor
                    wd_of_this_param = weight_decay_bias
                else:
                    wd_of_this_param = weight_decay_norm
            elif module_param_name == "bias":
                lr_of_this_param = base_lr * bias_lr_factor
                wd_of_this_param = weight_decay_bias

            param_group_name = f'lr_{lr_of_this_param}_wd_{wd_of_this_param}_moe_{moe_of_this_param}'
            if param_group_name not in group_params_dict:
                group_params_dict[param_group_name] = {
                    'params': [],
                    "lr": lr_of_this_param,
                    "weight_decay": wd_of_this_param,
                    'moe': moe_of_this_param,
                    'name': param_group_name,
                    'params_name': [],
                }
            group_params_dict[param_group_name]['params'].append(value)
            group_params_dict[param_group_name]['params_name'].append(
                f'{module_name}.{module_param_name}')

    valid_params_groups = list(group_params_dict.values())
    return valid_params_groups
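# Design note: grouping by the (lr, weight_decay, moe) key yields a handful
# of optimizer groups instead of the one-group-per-parameter layout of
# create_separate_moe_param_groups, at the cost of a coarser 'name' field.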
|
|
def create_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
):
    """
    Split parameters into a fixed set of named groups (weight-decay,
    no-decay, norm, bias, expert-gate), each with an MoE counterpart, and
    return only the non-empty groups.
    """
    from deepspeed.moe.utils import is_moe_param

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    # When embeddings use zero weight decay, handle torch.nn.Embedding
    # together with the norm modules below.
    if weight_decay_embedding == 0.0:
        norm_module_types = norm_module_types + (torch.nn.Embedding, )

    params_with_weight_decay = {
        'params': [],
        'name': 'weight_decay_params',
        'params_name': [],
    }
    params_without_weight_decay = {
        'params': [],
        "weight_decay": 0.0,
        'name': 'without_weight_decay_params',
        'params_name': [],
    }
    bias_params = {
        'params': [],
        "lr": base_lr * bias_lr_factor,
        "weight_decay": weight_decay_bias,
        'name': 'bias_params',
        'params_name': [],
    }
    wg_params = {
        'params': [],
        "lr": base_lr * wg_lr_factor,
        "weight_decay": weight_decay_wg,
        'name': 'wg_params',
        'params_name': [],
    }
    norm_params = {
        'params': [],
        "weight_decay": weight_decay_norm,
        'name': 'norm_params',
        'params_name': [],
    }
    moe_params_with_weight_decay = {
        'params': [],
        'moe': True,
        'name': 'weight_decay_moe_params',
        'params_name': [],
    }
    moe_params_without_weight_decay = {
        'params': [],
        "weight_decay": 0.0,
        'moe': True,
        'name': 'without_weight_decay_moe_params',
        'params_name': [],
    }
    moe_bias_params = {
        'params': [],
        "lr": base_lr * bias_lr_factor,
        "weight_decay": weight_decay_bias,
        'moe': True,
        'name': 'bias_moe_params',
        'params_name': [],
    }
    moe_norm_params = {
        'params': [],
        "weight_decay": weight_decay_norm,
        'moe': True,
        'name': 'norm_moe_params',
        'params_name': [],
    }

    params_groups = [
        params_with_weight_decay, params_without_weight_decay, norm_params,
        bias_params, wg_params, moe_params_with_weight_decay,
        moe_params_without_weight_decay, moe_norm_params, moe_bias_params
    ]

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    memo: Set[torch.nn.parameter.Parameter] = set()

    for module_name, module in model.named_modules():
        no_decay = module_name in no_decay_list
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                break

        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            if value in memo:
                continue
            memo.add(value)
            param_name = f'{module_name}.{module_param_name}'
            if is_moe_param(value):
                if no_decay or (module_param_name in no_decay_list):
                    moe_params_without_weight_decay['params'].append(value)
                    moe_params_without_weight_decay['params_name'].append(param_name)
                elif isinstance(module, norm_module_types):
                    moe_norm_params['params'].append(value)
                    moe_norm_params['params_name'].append(param_name)
                elif module_param_name == "bias":
                    moe_bias_params['params'].append(value)
                    moe_bias_params['params_name'].append(param_name)
                else:
                    moe_params_with_weight_decay['params'].append(value)
                    moe_params_with_weight_decay['params_name'].append(param_name)
            else:
                if no_decay or (module_param_name in no_decay_list):
                    params_without_weight_decay['params'].append(value)
                    params_without_weight_decay['params_name'].append(param_name)
                elif is_wg_param and isinstance(module, torch.nn.Linear) and module_param_name != "bias":
                    wg_params['params'].append(value)
                    wg_params['params_name'].append(param_name)
                elif isinstance(module, norm_module_types):
                    norm_params['params'].append(value)
                    norm_params['params_name'].append(param_name)
                elif module_param_name == "bias":
                    bias_params['params'].append(value)
                    bias_params['params_name'].append(param_name)
                else:
                    params_with_weight_decay['params'].append(value)
                    params_with_weight_decay['params_name'].append(param_name)

    # Empty groups would confuse some optimizers, so drop them.
    valid_params_groups = [
        group for group in params_groups if len(group['params']) > 0
    ]

    return valid_params_groups
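# The 'moe': True flag appears to follow DeepSpeed's param-group convention
# for expert parameters (cf. split_params_into_different_moe_groups_for_optimizer
# in deepspeed.moe.utils in recent DeepSpeed versions), so MoE-aware
# optimizers can treat these groups as expert-parallel.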
|
|
def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits the type of a given instance
    and overrides the `step` method to add gradient clipping.
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        else:
            # Global clipping (e.g. by norm) must see all parameters at once.
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            global_clipper(all_params)
        super(type(self), self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip
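# Usage sketch (hypothetical; the lambda stands in for a config-built clipper):
#
#   ClippedSGD = _generate_optimizer_class_with_gradient_clipping(
#       torch.optim.SGD,
#       per_param_clipper=lambda p: torch.nn.utils.clip_grad_value_(p, 1.0),
#   )
#   opt = ClippedSGD(model.parameters(), lr=0.1)  # clips on every step()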
|
|
|
def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer type to become a new dynamically created class OptimizerWithGradientClip
    that inherits the given optimizer and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode, configuration options
        optimizer: type. A subclass of torch.optim.Optimizer

    Return:
        type: either the input `optimizer` (if gradient clipping is disabled), or
        a subclass of it with gradient clipping included in the `step` method.
    """
    if cfg.SOLVER.GRAD_CLIP <= 0:
        return optimizer
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        optimizer_type = optimizer

    per_param_clipper, global_clipper = _create_gradient_clipper(cfg)
    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        optimizer_type, per_param_clipper=per_param_clipper, global_clipper=global_clipper
    )
    if isinstance(optimizer, torch.optim.Optimizer):
        # Mutate the instance's class in place so existing references keep
        # working; otherwise return the new subclass.
        optimizer.__class__ = OptimizerWithGradientClip
        return optimizer
    else:
        return OptimizerWithGradientClip
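# Note that build_optimizer below does not apply clipping itself; callers
# that want config-driven clipping compose the two, e.g.:
#
#   optimizer = maybe_add_gradient_clipping(cfg, build_optimizer(cfg, model))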
|
|
|
def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    params = create_separate_moe_param_groups(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        wg_lr_factor=cfg.SOLVER.WG_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
        weight_decay_embedding=cfg.SOLVER.WEIGHT_DECAY_EMBEDDING,
        weight_decay_wg=cfg.SOLVER.WEIGHT_DECAY_WG,
        cfg=cfg,
    )
    if cfg.SOLVER.NAME == 'LAMB':
        from uniperceiver.optim import LAMB
        optimizer = LAMB(
            params,
            lr=cfg.SOLVER.BASE_LR,
            betas=cfg.SOLVER.BETAS,
            eps=cfg.SOLVER.EPS,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    else:
        # AdamW is the default for any other cfg.SOLVER.NAME.
        optimizer = torch.optim.AdamW(
            params,
            lr=cfg.SOLVER.BASE_LR,
            betas=cfg.SOLVER.BETAS,
            eps=cfg.SOLVER.EPS,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    return optimizer
|
|