KyanChen's picture
Upload 303 files
4d0eb62
raw
history blame
2.43 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class CAEHead(BaseModule):
    """Head for CAE Pre-training.

    Compute the main (decoder) loss and the align (regressor) loss. In
    addition, this head derives the decoder's prediction target from the
    logits generated by the DALL-E tokenizer.

    Args:
        loss (dict): The config of the loss module. It is built with
            ``MODELS.build`` and, when called, must return a pair
            ``(loss_main, loss_align)``.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)

    @torch.no_grad()
    def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor:
        """Generate the reconstruction target.

        The target is the index of the highest-scoring entry along dim 1,
        flattened over every dimension after the first (batch) dimension.

        Args:
            logits_target (torch.Tensor): The logits generated by DALL-E.
                Presumably shaped (B, vocab_size, H, W); confirm against
                the caller.

        Returns:
            torch.Tensor: 2-D tensor of integer token indices, one row per
            sample.
        """
        target = torch.argmax(logits_target, dim=1)
        return target.flatten(1)

    def loss(self, logits: torch.Tensor, logits_target: torch.Tensor,
             latent_pred: torch.Tensor, latent_target: torch.Tensor,
             mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate loss.

        Args:
            logits (torch.Tensor): Logits generated by decoder. Presumably
                one row per masked position; confirm against the caller.
            logits_target (torch.Tensor): Logits generated by DALL-E, from
                which the decoder's prediction target is derived.
            latent_pred (torch.Tensor): Latent prediction by regressor.
            latent_target (torch.Tensor): Target for latent prediction,
                generated by teacher.
            mask (torch.Tensor): Mask used to index the flattened target so
                that only the selected positions are supervised. Assumed to
                be a boolean tensor with the same shape as the flattened
                target; verify against the caller.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The tuple of loss.

            - ``loss_main`` (torch.Tensor): Loss for the decoder prediction
              (cross entropy in the original design; the actual criterion
              depends on the ``loss`` config).
            - ``loss_align`` (torch.Tensor): Loss aligning the regressor's
              latent prediction with the teacher target (MSE in the
              original design; depends on the ``loss`` config).
        """
        # 2-D token indices derived from the DALL-E logits.
        target = self._generate_target(logits_target)
        # Keep only the positions selected by ``mask``; the target carries
        # no gradient.
        target = target[mask].detach()
        # loss main for decoder, loss align for regressor
        loss_main, loss_align = self.loss_module(logits, target, latent_pred,
                                                 latent_target)
        return (loss_main, loss_align)