Spaces:
Runtime error
Runtime error
File size: 4,659 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmocr.models.builder import LOSSES
@LOSSES.register_module()
class CELoss(nn.Module):
    """Cross-entropy loss for encoder-decoder based text recognition.

    Args:
        ignore_index (int): Target value that is ignored and contributes
            nothing to the input gradient.
        reduction (str): Reduction applied to the output; one of
            ('none', 'mean', 'sum').
        ignore_first_char (bool): If ``True``, drop the first target token
            (usually the start token) and the last output step so the two
            sequences stay aligned in length.
    """

    def __init__(self,
                 ignore_index=-1,
                 reduction='none',
                 ignore_first_char=False):
        super().__init__()
        # Fail fast on misconfiguration.
        assert isinstance(ignore_index, int)
        assert isinstance(reduction, str)
        assert reduction in ['none', 'mean', 'sum']
        assert isinstance(ignore_first_char, bool)

        self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction)
        self.ignore_first_char = ignore_first_char

    def format(self, outputs, targets_dict):
        """Align ``outputs``/targets for ``nn.CrossEntropyLoss``."""
        targets = targets_dict['padded_targets']
        if self.ignore_first_char:
            # Drop the leading target token and the trailing output step so
            # both sequences end up with length T-1.
            targets = targets[:, 1:].contiguous()
            outputs = outputs[:, :-1, :]
        # CrossEntropyLoss expects class scores on dim 1: (N, C, T).
        return outputs.permute(0, 2, 1).contiguous(), targets

    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            targets_dict (dict): A dict with a key ``padded_targets``, which is
                a tensor of shape :math:`(N, T)`. Each element is the index of
                a character.
            img_metas (None): Unused.

        Returns:
            dict: A loss dict with the key ``loss_ce``.
        """
        formatted_outputs, targets = self.format(outputs, targets_dict)
        # Targets may live on CPU; move them next to the logits.
        loss = self.loss_ce(formatted_outputs, targets.to(formatted_outputs.device))
        return dict(loss_ce=loss)
@LOSSES.register_module()
class SARLoss(CELoss):
    """Loss module used in `SAR.
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        ignore_index (int): Target value that is ignored and contributes
            nothing to the input gradient.
        reduction (str): Reduction applied to the output; one of
            ("none", "mean", "sum").

    Warning:
        SARLoss assumes that the first input token is always `<SOS>`.
    """

    def __init__(self, ignore_index=0, reduction='mean', **kwargs):
        super().__init__(ignore_index, reduction)

    def format(self, outputs, targets_dict):
        """Drop the `<SOS>` target token and align output length."""
        # targets row: [start_idx, idx1, idx2, ..., end_idx, pad_idx, ...]
        # outputs row: predictions for [idx1, idx2, ..., end_idx, ...]
        padded = targets_dict['padded_targets']
        # Skip the first (start) index of each target sequence.
        shifted_targets = padded[:, 1:].contiguous()
        # Trim the final output step so lengths match, then move the class
        # dimension to dim 1 as nn.CrossEntropyLoss expects.
        trimmed_outputs = outputs[:, :-1, :].permute(0, 2, 1).contiguous()
        return trimmed_outputs, shifted_targets
@LOSSES.register_module()
class TFLoss(CELoss):
    """Loss module for transformer-based recognition.

    Args:
        ignore_index (int, optional): The character index to be ignored in
            loss computation.
        reduction (str): Type of reduction to apply to the output; one of
            ("none", "mean", "sum").
        flatten (bool): Whether to flatten the vectors for loss computation.

    Warning:
        TFLoss assumes that the first input token is always `<SOS>`.
    """

    def __init__(self,
                 ignore_index=-1,
                 reduction='none',
                 flatten=True,
                 **kwargs):
        super().__init__(ignore_index, reduction)
        assert isinstance(flatten, bool)
        self.flatten = flatten

    def format(self, outputs, targets_dict):
        """Trim/flatten outputs and targets for loss computation."""
        # Drop the last output step and the first (start) target token so
        # both sequences share the same length T-1.
        trimmed_outputs = outputs[:, :-1, :].contiguous()
        trimmed_targets = targets_dict['padded_targets'][:, 1:].contiguous()
        if not self.flatten:
            # Keep the batch layout; class scores on dim 1: (N, C, T-1).
            return trimmed_outputs.permute(0, 2, 1).contiguous(), trimmed_targets
        # Collapse batch and time: (N*(T-1), C) scores vs (N*(T-1),) labels.
        return (trimmed_outputs.view(-1, trimmed_outputs.size(-1)),
                trimmed_targets.view(-1))
|