""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman
"""
import re
import torch
from torch import optim
from utils.distributed import is_main_process
import logging
logger = logging.getLogger(__name__)
try:
    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
    has_apex = True
except ImportError:
    has_apex = False


def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
    named_param_tuples = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
            named_param_tuples.append([name, param, 0])
        elif name in no_decay_list:
            named_param_tuples.append([name, param, 0])
        else:
            named_param_tuples.append([name, param, weight_decay])
    return named_param_tuples
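
# Illustrative sketch (layer names are hypothetical, nothing executed at import
# time): for a model with a single nn.Linear layer "fc" and weight_decay=0.01,
# add_weight_decay returns roughly
#   [["fc.weight", <Parameter>, 0.01],
#    ["fc.bias",   <Parameter>, 0]]
# i.e. biases and other 1-d parameters (e.g. norm weights) get weight decay 0.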


def add_different_lr(named_param_tuples, diff_lr_names, diff_lr, default_lr):
    """use lr=diff_lr for modules named found in diff_lr_names,
    otherwise use lr=default_lr

    Args:
        named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module
        diff_lr_names: List(str)
        diff_lr: float
        default_lr: float
    Returns:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])
    """
    named_param_tuples_with_lr = []
    logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
    for name, p, wd in named_param_tuples:
        use_diff_lr = False
        for diff_name in diff_lr_names:
            # entries in diff_lr_names are treated as regex patterns
            if re.search(diff_name, name) is not None:
                logger.info(f"param {name} use different_lr: {diff_lr}")
                use_diff_lr = True
                break

        named_param_tuples_with_lr.append(
            [name, p, wd, diff_lr if use_diff_lr else default_lr]
        )

    if is_main_process():
        for name, _, wd, lr in named_param_tuples_with_lr:
            logger.info(f"param {name}: wd: {wd}, lr: {lr}")

    return named_param_tuples_with_lr
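
# Illustrative sketch: with diff_lr_names=["^vision_encoder"] (a hypothetical
# pattern), diff_lr=1e-5 and default_lr=1e-4, any parameter whose name matches
# the regex "^vision_encoder" is tagged with lr=1e-5; all others get lr=1e-4.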


def create_optimizer_params_group(named_param_tuples_with_lr):
    """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
    group = {}
    for name, p, wd, lr in named_param_tuples_with_lr:
        if wd not in group:
            group[wd] = {}
        if lr not in group[wd]:
            group[wd][lr] = []
        group[wd][lr].append(p)

    optimizer_params_group = []
    for wd, lr_groups in group.items():
        for lr, p in lr_groups.items():
            optimizer_params_group.append(dict(
                params=p,
                weight_decay=wd,
                lr=lr
            ))
            logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
    return optimizer_params_group
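
# Illustrative sketch: parameters sharing the same (weight_decay, lr) pair are
# collapsed into a single param group, e.g. tuples with wd in {0.01, 0} and a
# common lr=1e-4 produce two groups:
#   [{"params": [...], "weight_decay": 0.01, "lr": 1e-4},
#    {"params": [...], "weight_decay": 0,    "lr": 1e-4}]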


def create_optimizer(args, model, filter_bias_and_bn=True):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    # check for modules that require a different lr
    if hasattr(args, "different_lr") and args.different_lr.enable:
        diff_lr_module_names = args.different_lr.module_names
        diff_lr = args.different_lr.lr
    else:
        diff_lr_module_names = []
        diff_lr = None

    no_decay = set()
    if hasattr(model, 'no_weight_decay'):
        no_decay = model.no_weight_decay()
    named_param_tuples = add_weight_decay(
        model, weight_decay, no_decay, filter_bias_and_bn)
    named_param_tuples = add_different_lr(
        named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
    parameters = create_optimizer_params_group(named_param_tuples)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower in ('sgd', 'nesterov'):
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")
    return optimizer
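

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original factory API:
    # build an AdamW optimizer for a toy model from a hypothetical
    # argparse-style namespace. Assumes utils.distributed.is_main_process()
    # returns True when torch.distributed has not been initialized.
    from types import SimpleNamespace

    from torch import nn

    logging.basicConfig(level=logging.INFO)
    toy_model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    toy_args = SimpleNamespace(opt="adamw", lr=1e-4, weight_decay=0.02)
    optimizer = create_optimizer(toy_args, toy_model)
    print(optimizer)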