# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class ABILoss(nn.Module):
"""Implementation of ABINet multiloss that allows mixing different types of | |
losses with weights. | |
Args: | |
enc_weight (float): The weight of encoder loss. Defaults to 1.0. | |
dec_weight (float): The weight of decoder loss. Defaults to 1.0. | |
fusion_weight (float): The weight of fuser (aligner) loss. | |
Defaults to 1.0. | |
num_classes (int): Number of unique output language tokens. | |
Returns: | |
A dictionary whose key/value pairs are the losses of three modules. | |
""" | |

    def __init__(self,
                 enc_weight=1.0,
                 dec_weight=1.0,
                 fusion_weight=1.0,
                 num_classes=37,
                 **kwargs):
        assert isinstance(enc_weight, (int, float))
        assert isinstance(dec_weight, (int, float))
        assert isinstance(fusion_weight, (int, float))
        super().__init__()
        self.enc_weight = enc_weight
        self.dec_weight = dec_weight
        self.fusion_weight = fusion_weight
        self.num_classes = num_classes

    def _flatten(self, logits, target_lens):
        """Concatenate the valid (unpadded) timesteps of every sequence,
        yielding a tensor of shape (sum(target_lens), C)."""
        flatten_logits = torch.cat(
            [s[:target_lens[i]] for i, s in enumerate(logits)])
        return flatten_logits

    def _ce_loss(self, logits, targets):
        """Mean cross-entropy between flattened logits and integer targets."""
        targets_one_hot = F.one_hot(targets, self.num_classes)
        log_prob = F.log_softmax(logits, dim=-1)
        loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
        return loss.mean()
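
    # Note: since ``targets`` holds integer class indices, the one-hot
    # formulation above is numerically equivalent to
    # ``F.cross_entropy(logits, targets)``; the explicit form is kept to
    # mirror the original implementation.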

    def _loss_over_iters(self, outputs, targets):
        """
        Args:
            outputs (list[Tensor]): One flattened logit tensor per iteration,
                each of shape (D, C) where D is the total number of valid
                (unpadded) timesteps in the batch and C is the number of
                classes.
            targets (Tensor): Flattened targets of shape (D, ).
        """
        iter_num = len(outputs)
        dec_outputs = torch.cat(outputs, dim=0)
        flatten_targets_iternum = targets.repeat(iter_num)
        return self._ce_loss(dec_outputs, flatten_targets_iternum)
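
    # ABINet refines its prediction over multiple iterations; averaging the
    # cross-entropy over the concatenated per-iteration logits (with targets
    # tiled once per iteration) weights every iteration equally.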

    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_enc``, ``out_decs`` and ``out_fusers`` specified.
            targets_dict (dict): The target dictionary containing the key
                ``targets``, a list of unpadded per-sample target tensors.

        Returns:
            A loss dictionary with up to three keys: ``loss_visual``,
            ``loss_lang`` and ``loss_fusion``. A key is present only when
            the output of its corresponding module is given.
        """
        assert 'out_enc' in outputs or \
            'out_decs' in outputs or 'out_fusers' in outputs
        losses = {}

        target_lens = [len(t) for t in targets_dict['targets']]
        flatten_targets = torch.cat(targets_dict['targets'])

        if outputs.get('out_enc', None):
            enc_input = self._flatten(outputs['out_enc']['logits'],
                                      target_lens)
            enc_loss = self._ce_loss(enc_input,
                                     flatten_targets) * self.enc_weight
            losses['loss_visual'] = enc_loss
        if outputs.get('out_decs', None):
            dec_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_decs']
            ]
            dec_loss = self._loss_over_iters(
                dec_logits, flatten_targets) * self.dec_weight
            losses['loss_lang'] = dec_loss
        if outputs.get('out_fusers', None):
            fusion_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_fusers']
            ]
            fusion_loss = self._loss_over_iters(
                fusion_logits, flatten_targets) * self.fusion_weight
            losses['loss_fusion'] = fusion_loss

        return losses
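

# Minimal usage sketch (not part of the original module): it fabricates the
# nested output structure that ABILoss expects, with dummy logits and
# targets. The batch size, sequence length and two-iteration setup below are
# illustrative assumptions, not values mandated by ABINet.
if __name__ == '__main__':
    batch_size, seq_len, num_classes = 2, 10, 37
    criterion = ABILoss(num_classes=num_classes)

    # One unpadded tensor of target indices per sample.
    targets_dict = {
        'targets': [
            torch.randint(0, num_classes, (5, )),
            torch.randint(0, num_classes, (7, )),
        ]
    }

    def rand_logits():
        # (N, T, C) logits, as produced by each ABINet module.
        return torch.randn(batch_size, seq_len, num_classes)

    outputs = {
        'out_enc': {'logits': rand_logits()},
        # The decoder and fuser emit one logit tensor per refinement
        # iteration; two iterations are assumed here.
        'out_decs': [{'logits': rand_logits()} for _ in range(2)],
        'out_fusers': [{'logits': rand_logits()} for _ in range(2)],
    }

    losses = criterion(outputs, targets_dict)
    print({name: loss.item() for name, loss in losses.items()})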