# Copyright (c) OpenMMLab. All rights reserved. from typing import List, Optional, Tuple, Union import torch from mmengine.model import BaseModule from mmpretrain.registry import MODELS @MODELS.register_module() class CAEHead(BaseModule): """Head for CAE Pre-training. Compute the align loss and the main loss. In addition, this head also generates the prediction target generated by dalle. Args: loss (dict): The config of loss. tokenizer_path (str): The path of the tokenizer. init_cfg (dict or List[dict], optional): Initialization config dict. Defaults to None. """ def __init__(self, loss: dict, init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: super().__init__(init_cfg=init_cfg) self.loss_module = MODELS.build(loss) @torch.no_grad() def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor: """Generate the reconstruction target. Args: logits_target (torch.Tensor): The logits generated by DALL-E.s Returns: torch.Tensor: The logits target. """ target = torch.argmax(logits_target, dim=1) return target.flatten(1) def loss(self, logits: torch.Tensor, logits_target: torch.Tensor, latent_pred: torch.Tensor, latent_target: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Generate loss. Args: logits (torch.Tensor): Logits generated by decoder. logits_target (img_target): Target generated by dalle for decoder prediction. latent_pred (torch.Tensor): Latent prediction by regressor. latent_target (torch.Tensor): Target for latent prediction, generated by teacher. Returns: Tuple[torch.Tensor, torch.Tensor]: The tuple of loss. - ``loss_main`` (torch.Tensor): Cross entropy loss. - ``loss_align`` (torch.Tensor): MSE loss. """ target = self._generate_target(logits_target) # target features target = target[mask].detach() # loss main for decoder, loss align for regressor loss_main, loss_align = self.loss_module(logits, target, latent_pred, latent_target) return (loss_main, loss_align)