# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import itertools

from .lr_scheduler import WarmupMultiStepLR, WarmupCosineAnnealingLR, WarmupReduceLROnPlateau


def make_optimizer(cfg, model):
    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 doesn't have full model gradient clipping now
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim

    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            print(key, "no grad")
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY

        # per-module learning-rate and weight-decay overrides
        if "language_backbone" in key:
            lr = cfg.SOLVER.LANG_LR

        if "backbone.body" in key and "language_backbone.body" not in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BACKBONE_BODY_LR_FACTOR

        if "t2i" in key:
            lr = lr * cfg.SOLVER.FUSION_LR_FACTOR

        if "bias" in key:
            lr *= cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS

        if "norm" in key or "Norm" in key:
            weight_decay *= cfg.SOLVER.WEIGHT_DECAY_NORM_FACTOR
            print("Setting weight decay of {} to {}".format(key, weight_decay))
        print("Setting lr of {} to {}".format(key, lr))
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    # The lr passed to the constructor is only a default for param groups added
    # later; every group built above already carries its own lr.
    if cfg.SOLVER.OPTIMIZER == "SGD":
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
        )
    elif cfg.SOLVER.OPTIMIZER == "ADAMW":
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR
        )
    else:
        raise ValueError("Unsupported cfg.SOLVER.OPTIMIZER: {}".format(cfg.SOLVER.OPTIMIZER))

    return optimizer

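# Illustration of the per-group overrides above (hypothetical parameter names
# and cfg values, not taken from any real config):
#   BASE_LR=1e-3, BACKBONE_BODY_LR_FACTOR=0.1, BIAS_LR_FACTOR=2.0, LANG_LR=1e-5,
#   WEIGHT_DECAY=1e-4, WEIGHT_DECAY_BIAS=0.0
#   "backbone.body.layer1.0.conv1.weight"       -> lr 1e-4, weight_decay 1e-4
#   "language_backbone.body.pooler.dense.bias"  -> lr 2e-5, weight_decay 0.0
#   "rpn.head.cls_logits.weight"                -> lr 1e-3, weight_decay 1e-4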

def make_lr_scheduler(cfg, optimizer):
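    """Build the learning-rate scheduler selected by cfg.SOLVER.

    Returns a list of WarmupMultiStepLR schedulers (one per training stage)
    when MULTI_MAX_EPOCH is set; otherwise a single WarmupCosineAnnealingLR,
    WarmupReduceLROnPlateau, or WarmupMultiStepLR instance.
    """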
    if cfg.SOLVER.MULTI_MAX_EPOCH:
        # Multi-stage training: build one scheduler per stage. STEPS holds the
        # per-stage milestone fractions, and MULTI_MAX_ITER is expected to hold
        # the per-stage iteration budget corresponding to MULTI_MAX_EPOCH.
        assert len(cfg.SOLVER.MULTI_MAX_EPOCH) == len(cfg.SOLVER.STEPS)
        lr_scheduler = []

        for stage_step, stage_max_iter in zip(cfg.SOLVER.STEPS, cfg.SOLVER.MULTI_MAX_ITER):
            # Convert fractional milestones (e.g. 0.67) into absolute iterations.
            milestones = [round(step * stage_max_iter) for step in stage_step]
            lr_scheduler.append(
                WarmupMultiStepLR(
                    optimizer,
                    milestones,
                    cfg.SOLVER.GAMMA,
                    warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                    warmup_iters=cfg.SOLVER.WARMUP_ITERS,
                    warmup_method=cfg.SOLVER.WARMUP_METHOD,
                )
            )
        return lr_scheduler

    elif cfg.SOLVER.USE_COSINE:
        max_iters = cfg.SOLVER.MAX_ITER
        return WarmupCosineAnnealingLR(
            optimizer,
            max_iters,
            cfg.SOLVER.GAMMA,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
            eta_min=cfg.SOLVER.MIN_LR,
        )

    elif cfg.SOLVER.USE_AUTOSTEP:
        max_iters = cfg.SOLVER.MAX_ITER
        return WarmupReduceLROnPlateau(
            optimizer,
            max_iters,
            cfg.SOLVER.GAMMA,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
            eta_min=cfg.SOLVER.MIN_LR,
            patience=cfg.SOLVER.STEP_PATIENCE,
            verbose=True,
        )

    else:
        # STEPS entries below 1 are treated as fractions of MAX_ITER; larger
        # values are used directly as absolute iteration milestones.
        milestones = []
        for step in cfg.SOLVER.STEPS:
            if step < 1:
                milestones.append(round(step * cfg.SOLVER.MAX_ITER))
            else:
                milestones.append(step)
        return WarmupMultiStepLR(
            optimizer,
            milestones,
            cfg.SOLVER.GAMMA,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
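

# Typical wiring (a minimal sketch; `cfg` is the project's yacs-style config and
# `model` / `data_loader` are assumed to be built elsewhere by the caller):
#
#   optimizer = make_optimizer(cfg, model)
#   scheduler = make_lr_scheduler(cfg, optimizer)
#   for iteration, batch in enumerate(data_loader):
#       losses = model(*batch)          # hypothetical loss dict
#       loss = sum(losses.values())
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()                # clips the full-model grad norm first
#                                       # when the clipping wrapper is enabled
#       scheduler.step()
#
# When cfg.SOLVER.MULTI_MAX_EPOCH is set, make_lr_scheduler returns a *list* of
# per-stage schedulers, and the caller is expected to switch between them as
# training stages advance.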