|
|
|
|
|
|
|
|
|
|
| import logging
|
| from collections import defaultdict
|
|
|
| logger = logging.getLogger("dinov3")  # logger obtained under the fixed name "dinov3"; used by the functions in this file
|
|
|
|
|
def get_vit_lr_decay_rate(
    name,
    lr_decay_rate=1.0,
    num_layers=12,
    force_is_backbone=False,
    chunked_blocks=False,
):
    """
    Compute the layer-wise lr decay multiplier for a single ViT parameter.

    The parameter name is mapped to a layer id and the multiplier is
    ``lr_decay_rate ** (num_layers + 1 - layer_id)``. A name that matches no
    rule keeps layer id ``num_layers + 1``, i.e. a multiplier of
    ``lr_decay_rate ** 0``.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
        force_is_backbone (bool): apply the backbone rules even when ``name``
            does not start with "backbone"; also lets the embedding / token
            names match without a leading dot.
        chunked_blocks (bool): only matters for names containing "blocks."
            but not ".blocks."; the block index is then read from the second
            component after "blocks" instead of the first.
    Returns:
        lr decay rate for the given parameter.
    """
    embed_keys = ("pos_embed", "patch_embed", "mask_token", "cls_token", "storage_tokens")
    top_layer = num_layers + 1
    layer_id = top_layer

    if name.startswith("backbone") or force_is_backbone:
        # Embeddings and special tokens sit below the first block (layer 0).
        dotted_hit = any("." + key in name for key in embed_keys)
        bare_hit = force_is_backbone and any(key in name for key in embed_keys)

        if dotted_hit or bare_hit:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            # "<prefix>.blocks.<i>..." splits to ["", "blocks", "<i>", ...].
            tail = name[name.find(".blocks.") :]
            layer_id = int(tail.split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            # "blocks.<i>..." or, when chunked, "blocks.<chunk>.<i>...".
            tail = name[name.find("blocks.") :]
            index_pos = 2 if chunked_blocks else 1
            layer_id = int(tail.split(".")[index_pos]) + 1

    return lr_decay_rate ** (top_layer - layer_id)
|
|
|
|
|
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0, dino_head_wd_multiplier=1.0):
    """
    Describe every trainable parameter of ``model`` with its lr / weight-decay multipliers.

    The block count is taken from ``model.n_blocks`` (together with
    ``model.chunked_blocks``) when present, otherwise from
    ``len(model.blocks)``, otherwise from ``len(model.backbone.blocks)``,
    otherwise it is 0.
    Args:
        model: object exposing ``named_parameters()``.
        lr_decay_rate (float): base rate forwarded to ``get_vit_lr_decay_rate``.
        patch_embed_lr_mult (float): extra factor applied to the lr multiplier
            of parameters whose name contains "patch_embed".
        dino_head_wd_multiplier (float): weight-decay multiplier for parameters
            whose name contains "dino_head" (bias / norm / gamma / fourier_w
            parameters still get 0.0).
    Returns:
        list of dicts with keys "name", "params", "is_last_layer",
        "lr_multiplier" and "wd_multiplier", one per parameter that has
        ``requires_grad`` set.
    """
    uses_chunks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        depth = model.n_blocks
        uses_chunks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        depth = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        depth = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        depth = 0

    collected = []
    for raw_name, param in model.named_parameters():
        name = remove_fsdp_compile_names(raw_name)
        if not param.requires_grad:
            continue

        lr_mult = get_vit_lr_decay_rate(
            name,
            lr_decay_rate,
            num_layers=depth,
            force_is_backbone=depth > 0,
            chunked_blocks=uses_chunks,
        )

        # Bias / norm / gamma / fourier_w parameters are never weight-decayed;
        # this takes precedence over the dino_head multiplier.
        if name.endswith("bias") or any(tag in name for tag in ("norm", "gamma", "fourier_w")):
            wd_mult = 0.0
        elif "dino_head" in name:
            wd_mult = dino_head_wd_multiplier
        else:
            wd_mult = 1.0

        d = {
            "name": name,
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": lr_mult,
            "wd_multiplier": wd_mult,
        }
        if "patch_embed" in name:
            d["lr_multiplier"] *= patch_embed_lr_mult

        collected.append(d)
        logger.info(f"{name}: lr_multiplier: {d['lr_multiplier']}, wd_multiplier: {d['wd_multiplier']}")

    return collected
|
|
|
|
|
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """
    Merge per-parameter dicts that share the same values for ``keys``.

    Two entries land in the same group when the string forms of their values
    agree for every key in ``keys``.
    Args:
        all_params_groups (iterable of dict): entries holding a "params" item
            plus one item for every key in ``keys``.
        keys (sequence of string): entry keys that define a group.
    Returns:
        a view over the fused group dicts (in order of first appearance); each
        has a "params" list collecting the entries' "params" values, plus the
        items named in ``keys``.
    """
    merged = {}
    for entry in all_params_groups:
        # Textual signature of the grouping values; equal signatures are fused.
        signature = "".join(key + str(entry[key]) + "_" for key in keys)

        bucket = merged.setdefault(signature, {"params": []})
        for key in keys:
            bucket[key] = entry[key]
        bucket["params"].append(entry["params"])

    return merged.values()
|
|
|
|
|
def get_params_groups_with_decay_fsdp(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0, dino_head_wd_multiplier=1.0):
    """
    Describe every trainable parameter of ``model`` with its lr / weight-decay multipliers.

    The block count is ``len(<root>.blocks)`` where ``<root>`` is
    ``model.module`` when that attribute exists and ``model`` otherwise; it is
    0 when ``<root>`` has no ``blocks`` attribute. Block names are always
    parsed as non-chunked.
    Args:
        model: object exposing ``named_parameters()``.
        lr_decay_rate (float): base rate forwarded to ``get_vit_lr_decay_rate``.
        patch_embed_lr_mult (float): extra factor applied to the lr multiplier
            of parameters whose name contains "patch_embed".
        dino_head_wd_multiplier (float): weight-decay multiplier for parameters
            whose name contains "dino_head" (bias / norm / gamma / fourier_w
            parameters still get 0.0).
    Returns:
        list of dicts with keys "name", "params", "is_last_layer",
        "lr_multiplier" and "wd_multiplier", one per parameter that has
        ``requires_grad`` set.
    """
    root = model.module if hasattr(model, "module") else model
    depth = len(root.blocks) if hasattr(root, "blocks") else 0

    collected = []
    for raw_name, param in model.named_parameters():
        name = remove_fsdp_compile_names(raw_name)
        if not param.requires_grad:
            continue

        lr_mult = get_vit_lr_decay_rate(
            name,
            lr_decay_rate,
            num_layers=depth,
            force_is_backbone=depth > 0,
            chunked_blocks=False,
        )

        # Start from the head-specific (or default) multiplier, then zero it
        # for bias / norm / gamma / fourier_w parameters.
        wd_mult = dino_head_wd_multiplier if "dino_head" in name else 1.0
        if name.endswith("bias") or any(tag in name for tag in ("norm", "gamma", "fourier_w")):
            wd_mult = 0.0

        d = {
            "name": name,
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": lr_mult,
            "wd_multiplier": wd_mult,
        }
        if "patch_embed" in name:
            d["lr_multiplier"] *= patch_embed_lr_mult

        collected.append(d)
        logger.info(f"{name}: lr_multiplier: {d['lr_multiplier']}, wd_multiplier: {d['wd_multiplier']}")

    return collected
|
|
|
|
|
def remove_fsdp_compile_names(name: str):
    """
    Strip wrapper-related fragments from a parameter name.

    Removes, in this order: every "_fsdp_wrapped_module.",
    "_checkpoint_wrapped_module." and "parametrizations." occurrence, a
    trailing ".original", then every "module." and "_orig_mod." occurrence.
    Matching is plain substring replacement, so any name that merely contains
    "module." (e.g. "submodule.weight") is shortened as well.
    Args:
        name (string): parameter name.
    Returns:
        the name with the fragments above removed.
    """
    # Order matters: the "*_wrapped_module." fragments must go before the
    # generic "module." replacement would cut them in half.
    for fragment in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module.", "parametrizations."):
        name = name.replace(fragment, "")

    name = name.removesuffix(".original")

    for fragment in ("module.", "_orig_mod."):
        name = name.replace(fragment, "")
    return name
|
|
|