Spaces:
Runtime error
Runtime error
File size: 2,093 Bytes
dbac20f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import logging
# Module-wide logger; getLogger with no name returns the root logger.
log = logging.getLogger(name=None)
def get_parameter_groups(model, cfg, print_log=False):
    """
    Build optimizer parameter groups for ``model``.

    Every parameter whose ``requires_grad`` is true is collected exactly once
    (in ``model.named_parameters()`` order) into a single group that uses
    ``cfg.learning_rate`` and ``cfg.weight_decay``. The returned list can be
    passed to an optimizer constructor as its parameter groups.

    Args:
        model: object exposing ``named_parameters()`` yielding
            ``(name, param)`` pairs (e.g. a ``torch.nn.Module``).
        cfg: config object providing ``weight_decay`` and ``learning_rate``.
        print_log: currently unused; kept for backward compatibility with
            existing callers.

    Returns:
        A one-element list containing a dict with keys ``'params'``,
        ``'lr'`` and ``'weight_decay'``.
    """
    # NOTE: an earlier design split parameters into a backbone group
    # (names starting with 'pixel_encoder.', scaled by cfg.backbone_lr_ratio)
    # and an embedding group (using cfg.embed_weight_decay). That split is
    # disabled, so every trainable parameter shares one group.
    weight_decay = cfg.weight_decay
    base_lr = cfg.learning_rate

    trainable = []
    # Deduplicate (approach inspired by detectron2) so a parameter shared
    # between submodules is handed to the optimizer only once.
    seen = set()
    for _, param in model.named_parameters():
        if not param.requires_grad or param in seen:
            continue
        seen.add(param)
        trainable.append(param)

    return [
        {
            'params': trainable,
            'lr': base_lr,
            'weight_decay': weight_decay,
        },
    ]
|