# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class MAEPretrainHead(BaseModule):
    """Head for MAE Pre-training.

    The head turns the input image into per-patch reconstruction targets
    (optionally normalized per patch) and delegates the actual loss
    computation to the loss module built from ``loss``.

    Args:
        loss (dict): Config of loss.
        norm_pix (bool): Whether or not normalize target.
            Defaults to False.
        patch_size (int): Patch size. Defaults to 16.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = False,
                 patch_size: int = 16) -> None:
        super().__init__()
        self.norm_pix = norm_pix
        self.patch_size = patch_size
        self.loss_module = MODELS.build(loss)

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        r"""Split images into non-overlapped patches.

        The number of channels is taken from ``imgs`` itself, so inputs
        are not restricted to 3 channels.

        Args:
            imgs (torch.Tensor): A batch of square images. The shape should
                be :math:`(B, C, H, W)` with :math:`H = W` divisible by
                ``patch_size``.

        Returns:
            torch.Tensor: Patchified images. The shape is
            :math:`(B, L, \text{patch_size}^2 \times C)`.
        """
        p = self.patch_size
        batch_size = imgs.shape[0]
        channels = imgs.shape[1]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0, (
            'Images must be square with a side divisible by patch_size '
            f'({p}), but got shape {tuple(imgs.shape)}.')

        h = w = imgs.shape[2] // p
        # (B, C, H, W) -> (B, C, h, p, w, p): split each spatial axis into
        # "patch index" and "offset inside the patch".
        x = imgs.reshape(shape=(batch_size, channels, h, p, w, p))
        # Move the patch indices forward and the channel axis last so that
        # each patch is flattened in (row, col, channel) order.
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(batch_size, h * w, p**2 * channels))
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        r"""Combine non-overlapped patches into images.

        This is the inverse of :meth:`patchify`. The number of channels is
        inferred from the last dimension of ``x``.

        Args:
            x (torch.Tensor): The shape is
                :math:`(B, L, \text{patch_size}^2 \times C)`, where
                :math:`L` must be a perfect square.

        Returns:
            torch.Tensor: The shape is :math:`(B, C, H, W)`.
        """
        p = self.patch_size
        batch_size, num_patches = x.shape[0], x.shape[1]
        # ``round`` instead of ``int``: a float square root that lands just
        # below the true integer must not be truncated to the wrong value.
        h = w = round(num_patches**.5)
        assert h * w == num_patches, (
            'The number of patches must be a perfect square, '
            f'but got {num_patches}.')

        # If the last dimension is not a multiple of p**2 the reshape below
        # fails, exactly as a mismatching input did before.
        channels = x.shape[2] // (p * p)
        x = x.reshape(shape=(batch_size, h, w, p, p, channels))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(batch_size, channels, h * p, w * p))
        return imgs

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
        """Construct the reconstruction target.

        In addition to splitting images into tokens, this module will also
        normalize the image according to ``norm_pix``.

        Args:
            target (torch.Tensor): Image with the shape of B x C x H x W

        Returns:
            torch.Tensor: Tokenized images with the shape of B x L x C
        """
        target = self.patchify(target)
        if self.norm_pix:
            # Normalize every patch independently with its own mean and
            # variance over the last (pixel) dimension; the small epsilon
            # keeps the division stable for constant patches.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        return target

    def loss(self, pred: torch.Tensor, target: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """Generate loss.

        The raw image is first converted to patch targets with
        :meth:`construct_target`, then passed together with ``pred`` and
        ``mask`` to the configured loss module.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        target = self.construct_target(target)
        loss = self.loss_module(pred, target, mask)

        return loss