Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Optional, Tuple, Union | |
import torch | |
from mmengine.model import BaseModule | |
from mmpretrain.registry import MODELS | |
class CAEHead(BaseModule):
    """Head for CAE pre-training.

    Turns target logits into discrete prediction targets and delegates the
    computation of the main loss and the align loss to a loss module built
    from the ``MODELS`` registry.

    Args:
        loss (dict): The config used to build the loss module.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # The loss module is called with four tensors and is expected to
        # return two values (see ``loss`` below).
        self.loss_module = MODELS.build(loss)

    def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor:
        """Convert target logits into flattened index targets.

        Args:
            logits_target (torch.Tensor): The target logits (generated by
                DALL-E according to the original documentation).

        Returns:
            torch.Tensor: Indices of the maximum along dim 1, with every
            dimension after the first flattened into one.
        """
        # NOTE(review): dim 1 is presumably the class/vocabulary axis of the
        # tokenizer output - confirm against the caller.
        token_ids = logits_target.argmax(dim=1)
        return torch.flatten(token_ids, start_dim=1)

    def loss(self, logits: torch.Tensor, logits_target: torch.Tensor,
             latent_pred: torch.Tensor, latent_target: torch.Tensor,
             mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the main loss and the align loss.

        Args:
            logits (torch.Tensor): Logits generated by the decoder.
            logits_target (torch.Tensor): Target logits from which the
                decoder's prediction targets are derived.
            latent_pred (torch.Tensor): Latent prediction by the regressor.
            latent_target (torch.Tensor): Target for the latent prediction,
                generated by the teacher.
            mask (torch.Tensor): Tensor used to index the flattened targets
                (presumably a boolean mask selecting the masked positions -
                confirm against the caller).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The tuple of losses returned
            by the loss module.

            - ``loss_main`` (torch.Tensor): Main loss for the decoder.
            - ``loss_align`` (torch.Tensor): Align loss for the regressor.
        """
        token_ids = self._generate_target(logits_target)
        # Keep only the entries picked by ``mask`` and cut them off from the
        # autograd graph before they are used as targets.
        selected_ids = token_ids[mask].detach()

        # First value: main loss for the decoder; second: align loss for the
        # regressor.
        losses = self.loss_module(logits, selected_ids, latent_pred,
                                  latent_target)
        loss_main, loss_align = losses
        return (loss_main, loss_align)