| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | import logging |
| |
|
| | from dataclasses import dataclass |
| | from enum import Enum, auto |
| | from typing import Any, Optional |
| |
|
| | import numpy as np |
| | from omegaconf import II, MISSING |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from fairseq import checkpoint_utils, tasks |
| | from omegaconf import open_dict |
| |
|
| | from fairseq.dataclass import FairseqDataclass |
| | from fairseq.models import BaseFairseqModel, register_model |
| | from .mae import interpolate_pos_embed |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class PredictionMode(Enum):
    # How token-level encoder outputs are reduced to one vector per image
    # before the classification head (see MaeImageClassificationModel.forward).
    MEAN_POOLING = auto()  # average over all (patch) tokens
    CLS_TOKEN = auto()  # take the first token only
    LIN_SOFTMAX = auto()  # log-sigmoid / logsumexp aggregation over tokens
| |
|
| |
|
@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
    """Configuration for fine-tuning a pretrained MAE / data2vec image
    encoder on an image-classification task."""

    # path to the pretraining checkpoint to fine-tune from
    model_path: str = MISSING
    # keep the pretrained architecture but skip loading its weights
    no_pretrained_weights: bool = False
    # freeze the encoder and train only the head (linear probing)
    linear_classifier: bool = False
    num_classes: int = 1000
    # mixup/cutmix alphas; both <= 0 disables the timm Mixup augmentation
    mixup: float = 0.8
    cutmix: float = 1.0
    # label smoothing used by Mixup and (during training) by cross_entropy
    label_smoothing: float = 0.1

    # stochastic-depth rate injected into the pretrained encoder config
    drop_path_rate: float = 0.1
    # layer-wise lr decay factor (BEiT-style); <= 0 disables lr scaling
    layer_decay: float = 0.65

    # probability of applying mixup/cutmix to a batch
    mixup_prob: float = 1.0
    # probability of switching mixup -> cutmix when both are enabled
    mixup_switch_prob: float = 0.5
    # timm Mixup mode ("batch", "pair" or "elem")
    mixup_mode: str = "batch"

    # populated from the checkpoint on first construction; pass explicitly to
    # skip re-reading the checkpoint config
    pretrained_model_args: Any = None
    # interpolated from the task config at runtime
    data: str = II("task.data")

    # optional override for the pretrained model's layer-norm epsilon
    norm_eps: Optional[float] = None

    # drop alibi attention biases from the pretrained image encoder
    remove_alibi: bool = False

    # regularization overrides pushed into the pretrained (data2vec multi)
    # model config before it is rebuilt
    encoder_dropout: float = 0
    post_mlp_drop: float = 0
    attention_dropout: float = 0
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0

    # overrides for the image-modality prenet of data2vec multi models
    prenet_layerdrop: float = 0
    prenet_dropout: float = 0

    # apply a LayerNorm to the pooled features before the head
    use_fc_norm: bool = True
    prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING

    # if False, lr_scale replaces (rather than merges into) existing
    # per-parameter optimizer overrides on encoder blocks
    no_decay_blocks: bool = True
| |
|
| |
|
def get_layer_id_for_vit(name, num_layers):
    """
    Map a ViT parameter name to its layer id for layer-wise lr decay.

    Embedding-level parameters map to 0, transformer blocks to their block
    index + 1, and everything else (e.g. the head) to ``num_layers``.
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    # cls token, positional embedding and patch projection sit before block 0
    if name in ("cls_token", "pos_embed") or name.startswith("patch_embed"):
        return 0
    # shared relative position bias is tied to the last transformer layer
    if name.startswith("rel_pos_bias"):
        return num_layers - 1
    # "blocks.<i>.<...>" -> layer i + 1
    if name.startswith("blocks"):
        return int(name.split(".")[1]) + 1
    # anything else (head, norms outside blocks) decays least
    return num_layers
| |
|
| |
|
@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
class MaeImageClassificationModel(BaseFairseqModel):
    """Image classifier built on top of a pretrained MAE or data2vec (multi)
    encoder: the pretrained trunk is rebuilt from its checkpoint config,
    optionally loaded with pretrained weights, and topped with an optional
    LayerNorm plus a linear classification head."""

    def __init__(self, cfg: MaeImageClassificationConfig):
        super().__init__()
        self.cfg = cfg

        # Rebuild the pretraining config, either from the checkpoint on disk
        # or from an explicitly supplied `pretrained_model_args`.
        if cfg.pretrained_model_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)

            # criterion/lr_scheduler from pretraining are irrelevant here and
            # may not even resolve; drop them before task setup
            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None

            logger.info(pretrained_args.model)

            with open_dict(pretrained_args.model):
                pretrained_args.model.drop_path_rate = cfg.drop_path_rate
                if cfg.norm_eps is not None:
                    pretrained_args.model.norm_eps = cfg.norm_eps

            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            # no checkpoint was read, so there are no weights to load later
            state = None
            pretrained_args = cfg.pretrained_model_args

        # point the pretraining task at the fine-tuning data directory
        if "data" in pretrained_args.task:
            pretrained_args.task.data = cfg.data
        elif "image" in pretrained_args.task:
            pretrained_args.task.image.data = cfg.data

        if "modalities" in pretrained_args.model:
            # data2vec "multi" model: push fine-tuning regularization settings
            # into both the image prenet and the shared trunk
            prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
            model_blocks = pretrained_args.model["depth"]
            with open_dict(pretrained_args):
                # linearly increasing drop-path schedule split between the
                # prenet blocks and the trunk blocks
                dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
                pretrained_args.model["modalities"]["image"][
                    "start_drop_path_rate"
                ] = dpr[0]
                pretrained_args.model["modalities"]["image"][
                    "end_drop_path_rate"
                ] = max(0, dpr[prenet_blocks - 1])
                pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
                pretrained_args.model["end_drop_path_rate"] = dpr[-1]

                # mae-style masking is a pretraining-only option
                if "mae_masking" in pretrained_args.model["modalities"]["image"]:
                    del pretrained_args.model["modalities"]["image"]["mae_masking"]

                if cfg.remove_alibi:
                    pretrained_args.model["modalities"]["image"][
                        "use_alibi_encoder"
                    ] = False
                    # also drop the stored bias so strict state-dict loading
                    # does not fail on the now-absent parameter
                    if (
                        state is not None
                        and "modality_encoders.IMAGE.alibi_bias" in state["model"]
                    ):
                        del state["model"]["modality_encoders.IMAGE.alibi_bias"]

                pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
                pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
                pretrained_args.model["attention_dropout"] = cfg.attention_dropout
                pretrained_args.model["activation_dropout"] = cfg.activation_dropout
                pretrained_args.model["dropout_input"] = cfg.dropout_input
                pretrained_args.model["layerdrop"] = cfg.layerdrop

                pretrained_args.model["modalities"]["image"][
                    "prenet_layerdrop"
                ] = cfg.prenet_layerdrop
                pretrained_args.model["modalities"]["image"][
                    "prenet_dropout"
                ] = cfg.prenet_dropout
        else:
            # plain MAE-style model: only a subset of the overrides applies
            with open_dict(pretrained_args):
                pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
                pretrained_args.model["block_dropout"] = cfg.encoder_dropout
                pretrained_args.model["attention_dropout"] = cfg.attention_dropout
                pretrained_args.model["activation_dropout"] = cfg.activation_dropout

        # rebuild the pretrained trunk from its (patched) config
        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
        self.linear_classifier = cfg.linear_classifier

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            # resize positional embeddings if the fine-tuning resolution
            # differs from pretraining
            interpolate_pos_embed(model, state)

            # rename/remove keys so the pretraining checkpoint matches the
            # fine-tuning module layout (strict loading below)
            if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
                state["model"][
                    "modality_encoders.IMAGE.positional_encoder.positions"
                ] = state["model"][
                    "modality_encoders.IMAGE.positional_encoder.pos_embed"
                ]
                del state["model"][
                    "modality_encoders.IMAGE.positional_encoder.pos_embed"
                ]
            if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
                del state["model"]["modality_encoders.IMAGE.encoder_mask"]

            model.load_state_dict(state["model"], strict=True)

        # drop decoder/EMA-teacher modules that are only needed in pretraining
        if self.d2v_multi:
            model.remove_pretraining_modules(modality="image")
        else:
            model.remove_pretraining_modules()

        # linear probing: freeze the whole trunk
        if self.linear_classifier:
            model.requires_grad_(False)

        self.fc_norm = None
        if self.cfg.use_fc_norm:
            self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
            nn.init.constant_(self.fc_norm.bias, 0)
            nn.init.constant_(self.fc_norm.weight, 1.0)

        self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)

        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)

        self.mixup_fn = None

        if cfg.mixup > 0 or cfg.cutmix > 0:
            from timm.data import Mixup

            self.mixup_fn = Mixup(
                mixup_alpha=cfg.mixup,
                cutmix_alpha=cfg.cutmix,
                cutmix_minmax=None,
                prob=cfg.mixup_prob,
                switch_prob=cfg.mixup_switch_prob,
                mode=cfg.mixup_mode,
                label_smoothing=cfg.label_smoothing,
                num_classes=cfg.num_classes,
            )

        # exclude norm/bias parameters from weight decay via per-parameter
        # optimizer overrides (consumed by the fairseq optimizer)
        if self.model.norm is not None:
            for pn, p in self.model.norm.named_parameters():
                if len(p.shape) == 1 or pn.endswith(".bias"):
                    p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}

        if self.fc_norm is not None:
            for pn, p in self.fc_norm.named_parameters():
                if len(p.shape) == 1 or pn.endswith(".bias"):
                    p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}

        for pn, p in self.head.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias"):
                p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}

        # collect the transformer blocks in depth order for layer-wise lr decay
        if self.d2v_multi:
            mod_encs = list(model.modality_encoders.values())
            assert len(mod_encs) == 1, len(mod_encs)
            blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
        else:
            blocks = model.blocks

        num_layers = len(blocks) + 1
        # layer_scales[i] = layer_decay ** (num_layers - i): deeper layers get
        # lr scales closer to 1
        layer_scales = list(
            cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
        )

        if self.d2v_multi:
            for n, p in self.model.named_parameters():
                optimizer_override_dict = {}

                # 1-D params (norm scales) and biases: no weight decay
                if len(p.shape) == 1 or n.endswith(".bias"):
                    optimizer_override_dict["weight_decay_scale"] = 0

                p.optim_overrides = {"optimizer": optimizer_override_dict}

            if cfg.layer_decay > 0:
                for i, b in enumerate(blocks):
                    lid = i + 1
                    if layer_scales[lid] == 1.0:
                        continue

                    for n, p in b.named_parameters():
                        optim_override = getattr(p, "optim_overrides", {})
                        if "optimizer" not in optim_override:
                            optim_override["optimizer"] = {}

                        if cfg.no_decay_blocks:
                            # merge lr_scale into the existing overrides
                            optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
                            p.optim_overrides = optim_override
                        else:
                            # replace overrides with lr_scale only (this drops
                            # any weight_decay_scale set above)
                            optim_override["optimizer"] = {
                                "lr_scale": layer_scales[lid]
                            }
                            p.optim_overrides = optim_override

        else:
            # BEiT-style per-parameter layer ids for plain ViT/MAE models
            for n, p in self.model.named_parameters():
                optimizer_override_dict = {}
                layer_id = get_layer_id_for_vit(n, num_layers)

                if len(p.shape) == 1 or n.endswith(".bias"):
                    optimizer_override_dict["weight_decay_scale"] = 0

                if cfg.layer_decay > 0:
                    optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
                p.optim_overrides = {"optimizer": optimizer_override_dict}

    @classmethod
    def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def forward(
        self,
        imgs,
        labels=None,
    ):
        """Classify a batch of images.

        Returns raw logits when ``labels`` is None; otherwise a dict with
        per-sample losses, sample_size and (in eval) a correct-prediction
        count. ``imgs`` is assumed to be a (batch, C, H, W) image tensor —
        TODO confirm against the task's batch format.
        """
        # mixup/cutmix only during training and only when labels are present
        if self.training and self.mixup_fn is not None and labels is not None:
            imgs, labels = self.mixup_fn(imgs, labels)

        if self.linear_classifier:
            # frozen trunk: skip building the graph for the encoder
            with torch.no_grad():
                x = self.model_forward(imgs)
        else:
            x = self.model_forward(imgs)

        # reduce token sequence (B, T, D) to a single feature vector per image
        if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
            x = x.mean(dim=1)
        elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
            x = x[:, 0]
        elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
            # numerically-guarded log-sigmoid aggregation over tokens,
            # computed in float32 and cast back at the end
            dtype = x.dtype
            x = F.logsigmoid(x.float())
            x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
            x = x.clamp(max=0)
            x = x - torch.log(-(torch.expm1(x)))
            x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
            x = x.to(dtype=dtype)
        else:
            raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")

        if self.fc_norm is not None:
            x = self.fc_norm(x)

        x = self.head(x)

        if labels is None:
            return x

        if self.training and self.mixup_fn is not None:
            # mixup produces soft labels -> soft cross entropy
            loss = -labels * F.log_softmax(x.float(), dim=-1)
        else:
            loss = F.cross_entropy(
                x.float(),
                labels,
                label_smoothing=self.cfg.label_smoothing if self.training else 0,
                reduction="none",
            )

        result = {
            "losses": {"regression": loss},
            "sample_size": imgs.size(0),
        }

        if not self.training:
            with torch.no_grad():
                pred = x.argmax(-1)
                correct = (pred == labels).sum()
                result["correct"] = correct

        return result

    def model_forward(self, imgs):
        """Run the pretrained trunk and return token features (B, T, D)."""
        if self.d2v_multi:
            x = self.model.extract_features(
                imgs,
                mode="IMAGE",
                mask=False,
                # keep the cls token only when it will be used for prediction
                remove_extra_tokens=(
                    self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
                ),
            )["x"]
        else:
            x = self.model(imgs, predictions_only=True)
            # strip the cls token unless it is the prediction target
            if (
                "no_cls" not in self.model.cfg or not self.model.cfg.no_cls
            ) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
                x = x[:, 1:]
        return x
| |
|