# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList

from mmocr.models.builder import DECODERS
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class NRTRDecoder(BaseDecoder):
"""Transformer Decoder block with self attention mechanism. | |
Args: | |
n_layers (int): Number of attention layers. | |
d_embedding (int): Language embedding dimension. | |
n_head (int): Number of parallel attention heads. | |
d_k (int): Dimension of the key vector. | |
d_v (int): Dimension of the value vector. | |
d_model (int): Dimension :math:`D_m` of the input from previous model. | |
d_inner (int): Hidden dimension of feedforward layers. | |
n_position (int): Length of the positional encoding vector. Must be | |
greater than ``max_seq_len``. | |
dropout (float): Dropout rate. | |
num_classes (int): Number of output classes :math:`C`. | |
max_seq_len (int): Maximum output sequence length :math:`T`. | |
start_idx (int): The index of `<SOS>`. | |
padding_idx (int): The index of `<PAD>`. | |
init_cfg (dict or list[dict], optional): Initialization configs. | |
Warning: | |
This decoder will not predict the final class which is assumed to be | |
`<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>` | |
is also ignored by loss as specified in | |
:obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`. | |
""" | |

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info."""
        len_s = seq.size(1)
        subsequent_mask = 1 - torch.triu(
            torch.ones((len_s, len_s), device=seq.device), diagonal=1)
        subsequent_mask = subsequent_mask.unsqueeze(0).bool()
        return subsequent_mask
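
    # Worked example: for ``len_s = 3`` the mask returned above is
    #     [[[ True, False, False],
    #       [ True,  True, False],
    #       [ True,  True,  True]]]
    # i.e. position ``i`` may only attend to positions ``<= i``, which
    # enforces autoregressive (left-to-right) decoding.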

    def _attention(self, trg_seq, src, src_mask=None):
        # Embed the target sequence and add positional encodings.
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        # Combine the padding mask (N, 1, T) with the causal mask (1, T, T);
        # broadcasting yields a per-sample (N, T, T) self-attention mask.
        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        return output

    def _get_mask(self, logit, img_metas):
        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ]
        N, T, _ = logit.size()
        mask = None
        if valid_ratios is not None:
            # Mark only the horizontally valid part of each encoder output;
            # columns beyond ``valid_ratio * T`` come from image padding.
            mask = logit.new_zeros((N, T))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(T, math.ceil(T * valid_ratio))
                mask[i, :valid_width] = 1

        return mask
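
    # Worked example: with ``T = 10`` and ``valid_ratio = 0.6``,
    # ``valid_width = ceil(10 * 0.6) = 6`` and the mask row becomes
    # [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], so cross-attention ignores encoder
    # positions produced by horizontal image padding.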

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        r"""
        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of
                a character.
            img_metas (list[dict]): A list of dicts containing the meta
                information of the input images, preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: The raw logit tensor of shape :math:`(N, T, C - 1)`; see
                the class-level warning on the dropped `<PAD>` class.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        targets = targets_dict['padded_targets'].to(out_enc.device)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        outputs = self.classifier(attn_output)

        return outputs

    def forward_test(self, feat, out_enc, img_metas):
        """Greedy autoregressive decoding: the argmax of each step is fed
        back as the input of the next step until ``max_seq_len`` steps."""
        src_mask = self._get_mask(out_enc, img_metas)
        N = out_enc.size(0)
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        # bsz * seq_len
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(0, self.max_seq_len):
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # bsz * seq_len * d_model
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), dim=-1)
            # bsz * (num_classes - 1)
            outputs.append(step_result)
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index

        outputs = torch.stack(outputs, dim=1)

        return outputs
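

# A minimal smoke-test sketch, assuming an mmocr 0.x environment in which the
# imports at the top of this file resolve. The tensor shapes and the dummy
# ``img_metas`` below are illustrative placeholders, not values from a real
# pipeline.
if __name__ == '__main__':
    decoder = NRTRDecoder(n_layers=2, max_seq_len=5).eval()
    out_enc = torch.rand(2, 25, 512)  # dummy encoder output, (N, T, d_model)
    img_metas = [{'valid_ratio': 1.0}, {'valid_ratio': 0.5}]
    with torch.no_grad():
        probs = decoder.forward_test(None, out_enc, img_metas)
    # Greedy decoding yields per-step class probabilities of shape
    # (N, max_seq_len, C - 1), i.e. (2, 5, 92) with the default 93 classes.
    print(probs.shape)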