# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, Optional, Sequence, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.decoders.base import BaseDecoder
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
@MODELS.register_module()
class ABCNetRecDecoder(BaseDecoder):
"""Decoder for ABCNet.
Args:
        in_channels (int): Number of input channels. Defaults to 256.
        dropout_prob (float): Probability of dropout. Defaults to 0.5.
        teach_prob (float): Probability of teacher forcing. Defaults to 0.5.
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
the instance of `Dictionary`.
module_loss (dict, optional): Config to build module_loss. Defaults
to None.
postprocessor (dict, optional): Config to build postprocessor.
Defaults to None.
max_seq_len (int, optional): Max sequence length. Defaults to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to ``dict(type='Xavier', layer='Conv2d')``.
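
    Example:
        A minimal construction and inference sketch. The ``dict_file`` path
        below is a placeholder and the tensor shapes are illustrative, not
        values shipped with this module::

            >>> import torch
            >>> decoder = ABCNetRecDecoder(
            ...     in_channels=256,
            ...     dictionary=dict(
            ...         type='Dictionary',
            ...         dict_file='path/to/dict_file.txt',
            ...         with_start=True,
            ...         with_end=True,
            ...         with_padding=True,
            ...         with_unknown=True),
            ...     max_seq_len=25)
            >>> # Encoder output is expected with shape (W, N, C),
            >>> # i.e. sequence first, batch second.
            >>> out_enc = torch.rand(32, 2, 256)
            >>> probs = decoder.forward_test(None, out_enc, None)
            >>> # probs has shape (N, max_seq_len, num_classes).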
"""
def __init__(self,
in_channels: int = 256,
dropout_prob: float = 0.5,
teach_prob: float = 0.5,
dictionary: Union[Dictionary, Dict] = None,
module_loss: Dict = None,
postprocessor: Dict = None,
max_seq_len: int = 30,
init_cfg=dict(type='Xavier', layer='Conv2d'),
**kwargs):
super().__init__(
init_cfg=init_cfg,
dictionary=dictionary,
module_loss=module_loss,
postprocessor=postprocessor,
max_seq_len=max_seq_len)
self.in_channels = in_channels
self.teach_prob = teach_prob
self.embedding = nn.Embedding(self.dictionary.num_classes, in_channels)
self.attn_combine = nn.Linear(in_channels * 2, in_channels)
self.dropout = nn.Dropout(dropout_prob)
self.gru = nn.GRU(in_channels, in_channels)
self.out = nn.Linear(in_channels, self.dictionary.num_classes)
self.vat = nn.Linear(in_channels, 1)
self.softmax = nn.Softmax(dim=-1)
def forward_train(
self,
feat: torch.Tensor,
out_enc: Optional[torch.Tensor] = None,
data_samples: Optional[Sequence[TextRecogDataSample]] = None
) -> torch.Tensor:
"""
Args:
feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(W, N, C)`, with the sequence dimension first.
                Defaults to None.
data_samples (list[TextRecogDataSample], optional): Batch of
TextRecogDataSample, containing gt_text information. Defaults
to None.
        Returns:
            Tensor: Raw logits of shape :math:`(N, T, C)`, where :math:`T`
            is the padded target sequence length and :math:`C` is
            ``num_classes``.
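
        Example:
            A minimal training-step sketch, assuming ``decoder`` was built as
            in the class-level example; the ground-truth indexes here are
            random placeholders rather than real annotations::

                >>> import torch
                >>> from mmengine.structures import LabelData
                >>> from mmocr.structures import TextRecogDataSample
                >>> bs, w, c = 2, 32, 256
                >>> out_enc = torch.rand(w, bs, c)  # (W, N, C)
                >>> feat = torch.rand(bs, c, 1, w)  # only .device is used
                >>> data_samples = []
                >>> for _ in range(bs):
                ...     gt_text = LabelData()
                ...     gt_text.padded_indexes = torch.randint(
                ...         0, decoder.dictionary.num_classes, (25, ))
                ...     sample = TextRecogDataSample()
                ...     sample.gt_text = gt_text
                ...     data_samples.append(sample)
                >>> out = decoder.forward_train(feat, out_enc, data_samples)
                >>> # out has shape (N, 25, num_classes).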
"""
bs = out_enc.size()[1]
trg_seq = []
for target in data_samples:
trg_seq.append(target.gt_text.padded_indexes.to(feat.device))
decoder_input = torch.zeros(bs).long().to(out_enc.device)
trg_seq = torch.stack(trg_seq, dim=0)
decoder_hidden = torch.zeros(1, bs,
self.in_channels).to(out_enc.device)
decoder_outputs = []
for index in range(trg_seq.shape[1]):
# decoder_output (nbatch, ncls)
decoder_output, decoder_hidden = self._attention(
decoder_input, decoder_hidden, out_enc)
            # Teacher forcing: randomly choose between feeding the
            # ground-truth character and the previous prediction as the
            # next decoder input.
            teach_forcing = random.random() > self.teach_prob
if teach_forcing:
decoder_input = trg_seq[:, index] # Teacher forcing
else:
_, topi = decoder_output.data.topk(1)
decoder_input = topi.squeeze()
decoder_outputs.append(decoder_output)
return torch.stack(decoder_outputs, dim=1)
def forward_test(
self,
feat: Optional[torch.Tensor] = None,
out_enc: Optional[torch.Tensor] = None,
data_samples: Optional[Sequence[TextRecogDataSample]] = None
) -> torch.Tensor:
"""
Args:
feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(W, N, C)`, with the sequence dimension first.
                Defaults to None.
data_samples (list[TextRecogDataSample]): Batch of
TextRecogDataSample, containing ``gt_text`` information.
Defaults to None.
Returns:
            Tensor: Character probabilities of shape
            :math:`(N, self.max_seq_len, C)`, where :math:`C` is
            ``num_classes``.
"""
bs = out_enc.size()[1]
outputs = []
decoder_input = torch.zeros(bs).long().to(out_enc.device)
decoder_hidden = torch.zeros(1, bs,
self.in_channels).to(out_enc.device)
for _ in range(self.max_seq_len):
# decoder_output (nbatch, ncls)
decoder_output, decoder_hidden = self._attention(
decoder_input, decoder_hidden, out_enc)
_, topi = decoder_output.data.topk(1)
decoder_input = topi.squeeze()
outputs.append(decoder_output)
outputs = torch.stack(outputs, dim=1)
return self.softmax(outputs)
def _attention(self, input, hidden, encoder_outputs):
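        """Run a single attention + GRU decoding step.

        Args:
            input (Tensor): Indexes of the previous characters, of shape
                :math:`(N, )` (a scalar when :math:`N = 1`).
            hidden (Tensor): GRU hidden state of shape :math:`(1, N, C)`.
            encoder_outputs (Tensor): Encoder output of shape
                :math:`(W, N, C)`.

        Returns:
            tuple(Tensor, Tensor): Raw logits of shape :math:`(N, C_{cls})`,
            where :math:`C_{cls}` is ``num_classes``, and the updated
            hidden state of shape :math:`(1, N, C)`.
        """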
embedded = self.embedding(input)
embedded = self.dropout(embedded)
        batch_size = encoder_outputs.shape[1]
        # Additive attention: broadcast-add the hidden state onto every
        # encoder position, then score each position with a linear layer.
        alpha = hidden + encoder_outputs
alpha = alpha.view(-1, alpha.shape[-1]) # (T * n, hidden_size)
attn_weights = self.vat(torch.tanh(alpha)) # (T * n, 1)
attn_weights = attn_weights.view(-1, 1, batch_size).permute(
(2, 1, 0)) # (T, 1, n) -> (n, 1, T)
attn_weights = F.softmax(attn_weights, dim=2)
attn_applied = torch.matmul(attn_weights,
encoder_outputs.permute((1, 0, 2)))
        # ``topi.squeeze()`` drops the batch dimension when N == 1, leaving a
        # 1-D embedding; restore the batch dimension before concatenation.
        if embedded.dim() == 1:
            embedded = embedded.unsqueeze(0)
output = torch.cat((embedded, attn_applied.squeeze(1)), 1)
output = self.attn_combine(output).unsqueeze(0) # (1, n, hidden_size)
output = F.relu(output)
output, hidden = self.gru(output, hidden) # (1, n, hidden_size)
output = self.out(output[0])
return output, hidden