File size: 6,534 Bytes
14c9181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.decoders.base import BaseDecoder
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample


@MODELS.register_module()
class ABCNetRecDecoder(BaseDecoder):
    """Decoder for ABCNet.

    An attention-based GRU decoder: at each step, additive attention over the
    encoder output produces a context vector, which is combined with the
    embedding of the previous character to predict the next one.

    Args:
        in_channels (int): Number of input channels. Defaults to 256.
        dropout_prob (float): Probability of dropout. Defaults to 0.5.
        teach_prob (float): Probability of teacher forcing. Defaults to 0.5.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        max_seq_len (int, optional): Max sequence length. Defaults to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to ``dict(type='Xavier', layer='Conv2d')``.
    """

    def __init__(self,
                 in_channels: int = 256,
                 dropout_prob: float = 0.5,
                 teach_prob: float = 0.5,
                 dictionary: Union[Dictionary, Dict] = None,
                 module_loss: Dict = None,
                 postprocessor: Dict = None,
                 max_seq_len: int = 30,
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(
            init_cfg=init_cfg,
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len)
        self.in_channels = in_channels
        self.teach_prob = teach_prob
        # Character embedding for the previously-decoded token.
        self.embedding = nn.Embedding(self.dictionary.num_classes, in_channels)
        # Fuses [embedded token, attention context] back to `in_channels`.
        self.attn_combine = nn.Linear(in_channels * 2, in_channels)
        self.dropout = nn.Dropout(dropout_prob)
        self.gru = nn.GRU(in_channels, in_channels)
        self.out = nn.Linear(in_channels, self.dictionary.num_classes)
        # Projects each attention energy vector to a scalar score.
        self.vat = nn.Linear(in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: torch.Tensor,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(T, N, C)` (time-major; see ``_attention``). Defaults
                to None.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where
            :math:`C` is ``num_classes``.
        """
        bs = out_enc.size(1)
        trg_seq = torch.stack([
            sample.gt_text.padded_indexes.to(feat.device)
            for sample in data_samples
        ], dim=0)
        # Step 0 feeds index 0 (start token) with a zero hidden state.
        decoder_input = torch.zeros(bs, dtype=torch.long, device=out_enc.device)
        decoder_hidden = torch.zeros(
            1, bs, self.in_channels, device=out_enc.device)
        decoder_outputs = []
        for index in range(trg_seq.shape[1]):
            # decoder_output: (N, num_classes)
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            # Teacher forcing with probability ``teach_prob``. NOTE: the
            # original compared with ``>``, which applied teacher forcing
            # with probability 1 - teach_prob, contradicting the documented
            # semantics; ``<`` matches the docstring (identical at the 0.5
            # default).
            if random.random() < self.teach_prob:
                decoder_input = trg_seq[:, index]  # Teacher forcing
            else:
                _, topi = decoder_output.data.topk(1)
                # squeeze(1) keeps shape (N,) even when N == 1, unlike
                # squeeze(), which would collapse to a 0-dim tensor.
                decoder_input = topi.squeeze(1)
            decoder_outputs.append(decoder_output)

        return torch.stack(decoder_outputs, dim=1)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(T, N, C)` (time-major; see ``_attention``). Defaults
                to None.
            data_samples (list[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing ``gt_text`` information.
                Defaults to None.

        Returns:
            Tensor: Character probabilities. of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        bs = out_enc.size(1)
        outputs = []
        decoder_input = torch.zeros(bs, dtype=torch.long, device=out_enc.device)
        decoder_hidden = torch.zeros(
            1, bs, self.in_channels, device=out_enc.device)
        for _ in range(self.max_seq_len):
            # Greedy decoding: feed back the argmax prediction each step.
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            _, topi = decoder_output.data.topk(1)
            # squeeze(1) keeps shape (N,) even when N == 1.
            decoder_input = topi.squeeze(1)
            outputs.append(decoder_output)
        outputs = torch.stack(outputs, dim=1)
        return self.softmax(outputs)

    def _attention(self, input: torch.Tensor, hidden: torch.Tensor,
                   encoder_outputs: torch.Tensor):
        """Run one decoding step of additive attention + GRU.

        Args:
            input (Tensor): Previous token indices of shape :math:`(N,)`.
            hidden (Tensor): GRU hidden state of shape
                :math:`(1, N, C)`.
            encoder_outputs (Tensor): Encoder output; indexed as
                :math:`(T, N, C)` (time-major) throughout this method.

        Returns:
            tuple(Tensor, Tensor): Logits of shape :math:`(N, num_classes)`
            and the updated hidden state of shape :math:`(1, N, C)`.
        """
        embedded = self.embedding(input)
        embedded = self.dropout(embedded)

        batch_size = encoder_outputs.shape[1]

        # Additive attention energies: hidden (1, N, C) broadcasts over the
        # T encoder positions.
        alpha = hidden + encoder_outputs
        alpha = alpha.view(-1, alpha.shape[-1])  # (T * n, hidden_size)
        attn_weights = self.vat(torch.tanh(alpha))  # (T * n, 1)
        attn_weights = attn_weights.view(-1, 1, batch_size).permute(
            (2, 1, 0))  # (T, 1, n)  -> (n, 1, T)
        attn_weights = F.softmax(attn_weights, dim=2)

        # (n, 1, T) x (n, T, C) -> (n, 1, C) context vector.
        attn_applied = torch.matmul(attn_weights,
                                    encoder_outputs.permute((1, 0, 2)))

        # Guard against a 0-dim `input` producing a 1-dim embedding.
        if embedded.dim() == 1:
            embedded = embedded.unsqueeze(0)
        output = torch.cat((embedded, attn_applied.squeeze(1)), 1)
        output = self.attn_combine(output).unsqueeze(0)  # (1, n, hidden_size)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)  # (1, n, hidden_size)
        output = self.out(output[0])
        return output, hidden