# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from torch.nn import CrossEntropyLoss

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class MaskedCrossEntropyLoss(nn.Module):
"""The implementation of masked cross entropy loss.
The mask has 1 for real tokens and 0 for padding tokens,
which only keep active parts of the cross entropy loss.
Args:
num_labels (int): Number of classes in labels.
ignore_index (int): Specifies a target value that is ignored
and does not contribute to the input gradient.
"""

    def __init__(self, num_labels=None, ignore_index=0):
        super().__init__()
        self.num_labels = num_labels
        self.criterion = CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, img_metas):
        '''Loss forward.

        Args:
            logits: Model output with shape [N, C].
            img_metas (dict): A dict containing the following keys:

                - img (list): This parameter is reserved.
                - labels (list[int]): The labels for each word
                  of the sequence.
                - texts (list): The words of the sequence.
                - input_ids (list): The ids for each word of
                  the sequence.
                - attention_masks (list): The mask for each word
                  of the sequence. The mask has 1 for real tokens
                  and 0 for padding tokens. Only real tokens are
                  attended to.
                - token_type_ids (list): The tokens for each word
                  of the sequence.
        '''
        labels = img_metas['labels']
        attention_masks = img_metas['attention_masks']

        # Only keep active parts of the loss: flatten the logits to
        # [N, C] and select the positions whose mask is 1 (real tokens).
        if attention_masks is not None:
            active_loss = attention_masks.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)[active_loss]
            active_labels = labels.view(-1)[active_loss]
            loss = self.criterion(active_logits, active_labels)
        else:
            # No mask available: compute the loss over every position.
            loss = self.criterion(
                logits.view(-1, self.num_labels), labels.view(-1))
        return {'loss_cls': loss}
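

if __name__ == '__main__':
    # Minimal usage sketch, not part of the upstream MMOCR file: drives the
    # loss with dummy tensors whose shapes follow the docstring above. The
    # batch size, sequence length, and label count are made up for
    # illustration.
    import torch

    loss_fn = MaskedCrossEntropyLoss(num_labels=5)
    logits = torch.randn(2, 4, 5)  # [batch, seq_len, num_labels]
    img_metas = dict(
        # Labels drawn from 1..4 so none collide with ignore_index=0.
        labels=torch.randint(1, 5, (2, 4)),
        # 1 marks real tokens, 0 marks padding; padded positions are
        # excluded from the loss.
        attention_masks=torch.tensor([[1, 1, 1, 0],
                                      [1, 1, 0, 0]]))
    print(loss_fn(logits, img_metas))  # -> {'loss_cls': tensor(...)}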