Spaces:
Running
Running
import json | |
import torch | |
import torch.nn as nn | |
import re | |
def get_optim_params(cfg: list, model: nn.Module) -> list:
    """Build optimizer parameter groups by matching parameter names to regexes.

    Each entry of ``cfg`` is a dict whose ``'params'`` value is a regular
    expression; every other key is copied unchanged into the resulting group.
    A trainable parameter (``requires_grad`` is true) joins a group when
    ``re.search`` finds the group's pattern anywhere in its name. Trainable
    parameters matched by no pattern are gathered into one trailing group
    that carries only a ``'params'`` key.

    Pattern examples:
        ^(?=.*a)(?=.*b).*$      name includes a and b
        ^(?=.*(?:a|b)).*$       name includes a or b
        ^(?=.*a)(?!.*b).*$      name includes a, but not b

    Args:
        cfg: List of group dicts, each holding a regex under ``'params'``.
            It is not modified, so the same config can be reused.
        model: Module whose ``named_parameters()`` are partitioned.

    Returns:
        A list of new dicts, one per ``cfg`` entry in order (plus the optional
        trailing group), each with ``'params'`` set to a list of parameters
        in ``named_parameters()`` order.

    Raises:
        AssertionError: If a parameter name is matched by more than one
            pattern.
    """
    # Walk the model once instead of once per group.
    trainable = [
        (name, param)
        for name, param in model.named_parameters()
        if param.requires_grad
    ]

    param_groups = []
    claimed_by = {}  # parameter name -> pattern of the group that took it
    for pg in cfg:
        pattern = pg['params']
        matcher = re.compile(pattern)
        matched = [(name, param) for name, param in trainable if matcher.search(name)]

        for name, _ in matched:
            if name in claimed_by:
                # Raised explicitly (not via `assert`) so the check survives
                # `python -O`; the exception type matches the original assert.
                raise AssertionError(
                    f"parameter {name!r} is matched by more than one pattern: "
                    f"{claimed_by[name]!r} and {pattern!r}"
                )
            claimed_by[name] = pattern

        # Copy the group rather than overwriting pg['params'], so the caller's
        # config keeps its regex and can be passed in again.
        param_groups.append({**pg, 'params': [param for _, param in matched]})

    # Everything no pattern claimed falls into one group with default options.
    leftovers = [param for name, param in trainable if name not in claimed_by]
    if leftovers:
        param_groups.append({'params': leftovers})

    return param_groups