# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import Dict, Optional, Sequence, Union

import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import ModuleList

from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


def clones(module: nn.Module, N: int) -> nn.ModuleList:
    """Produce N identical layers.

    Args:
        module (nn.Module): A pytorch nn.module.
        N (int): Number of copies.

    Returns:
        nn.ModuleList: A pytorch nn.ModuleList with the copies.
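
    Example:
        A minimal sketch; ``nn.Linear`` is an arbitrary stand-in for the
        module to be replicated.

        >>> layers = clones(nn.Linear(8, 8), 3)
        >>> len(layers)
        3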
| """ | |
| return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) | |


class Embeddings(nn.Module):
    """Construct the word embeddings given vocab size and embed dim.

    Args:
        d_model (int): The embedding dimension.
        vocab (int): Vocabulary size.
    """

    def __init__(self, d_model: int, vocab: int):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, *input: torch.Tensor) -> torch.Tensor:
        """Forward the embeddings.

        Args:
            input (torch.Tensor): The input tensors.

        Returns:
            torch.Tensor: The embeddings, scaled by the square root of
            ``d_model``.
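
        Example:
            A minimal sketch with assumed sizes: a vocabulary of 10 tokens
            embedded into 512 dimensions.

            >>> emb = Embeddings(d_model=512, vocab=10)
            >>> emb(torch.zeros(2, 5, dtype=torch.long)).shape
            torch.Size([2, 5, 512])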
| """ | |
| x = input[0] | |
| return self.lut(x) * math.sqrt(self.d_model) | |


@MODELS.register_module()
class MasterDecoder(BaseDecoder):
    """Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_.

    Code is partially modified from https://github.com/wenwenyu/MASTER-pytorch.

    Args:
        n_layers (int): Number of attention layers. Defaults to 3.
        n_head (int): Number of parallel attention heads. Defaults to 8.
        d_model (int): Dimension :math:`E` of the input from previous model.
            Defaults to 512.
        feat_size (int): The size of the input feature from previous model,
            usually :math:`H * W`. Defaults to 6 * 40.
        d_inner (int): Hidden dimension of feedforward layers.
            Defaults to 2048.
        attn_drop (float): Dropout rate of the attention layer. Defaults to 0.
        ffn_drop (float): Dropout rate of the feedforward layer. Defaults to 0.
        feat_pe_drop (float): Dropout rate of the feature positional encoding
            layer. Defaults to 0.2.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`. Defaults to None.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        max_seq_len (int): Maximum output sequence length :math:`T`. Defaults
            to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(
        self,
        n_layers: int = 3,
        n_head: int = 8,
        d_model: int = 512,
        feat_size: int = 6 * 40,
        d_inner: int = 2048,
        attn_drop: float = 0.,
        ffn_drop: float = 0.,
        feat_pe_drop: float = 0.2,
        module_loss: Optional[Dict] = None,
        postprocessor: Optional[Dict] = None,
        dictionary: Optional[Union[Dict, Dictionary]] = None,
        max_seq_len: int = 30,
        init_cfg: Optional[Union[Dict, Sequence[Dict]]] = None,
    ):
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            dictionary=dictionary,
            init_cfg=init_cfg,
            max_seq_len=max_seq_len)
        operation_order = ('norm', 'self_attn', 'norm', 'cross_attn', 'norm',
                           'ffn')
        decoder_layer = BaseTransformerLayer(
            operation_order=operation_order,
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=d_model,
                num_heads=n_head,
                attn_drop=attn_drop,
                dropout_layer=dict(type='Dropout', drop_prob=attn_drop),
            ),
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=d_model,
                feedforward_channels=d_inner,
                ffn_drop=ffn_drop,
                dropout_layer=dict(type='Dropout', drop_prob=ffn_drop),
            ),
            norm_cfg=dict(type='LN'),
            batch_first=True,
        )
        self.decoder_layers = ModuleList(
            [copy.deepcopy(decoder_layer) for _ in range(n_layers)])

        self.cls = nn.Linear(d_model, self.dictionary.num_classes)

        self.SOS = self.dictionary.start_idx
        self.PAD = self.dictionary.padding_idx
        self.max_seq_len = max_seq_len
        self.feat_size = feat_size
        self.n_head = n_head

        self.embedding = Embeddings(
            d_model=d_model, vocab=self.dictionary.num_classes)

        # TODO:
        self.positional_encoding = PositionalEncoding(
            d_hid=d_model, n_position=self.max_seq_len + 1)
        self.feat_positional_encoding = PositionalEncoding(
            d_hid=d_model, n_position=self.feat_size, dropout=feat_pe_drop)
        self.norm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(dim=-1)

    def make_target_mask(self, tgt: torch.Tensor,
                         device: torch.device) -> torch.Tensor:
        """Make target mask for self attention.

        Args:
            tgt (Tensor): Shape [N, l_tgt].
            device (torch.device): Mask device.

        Returns:
            Tensor: Mask of shape [N * self.n_head, l_tgt, l_tgt].
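
        Example:
            A minimal sketch; it assumes a ``decoder`` built with
            ``n_head=8`` and a dictionary whose padding index is 0, so the
            padded last position of the single sample is masked out.

            >>> tgt = torch.tensor([[1, 2, 0]])
            >>> mask = decoder.make_target_mask(tgt, device=tgt.device)
            >>> mask.shape
            torch.Size([8, 3, 3])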
| """ | |
| trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3).bool() | |
| tgt_len = tgt.size(1) | |
| trg_sub_mask = torch.tril( | |
| torch.ones((tgt_len, tgt_len), dtype=torch.bool, device=device)) | |
| tgt_mask = trg_pad_mask & trg_sub_mask | |
| # inverse for mmcv's BaseTransformerLayer | |
| tril_mask = tgt_mask.clone() | |
| tgt_mask = tgt_mask.float().masked_fill_(tril_mask == 0, -1e9) | |
| tgt_mask = tgt_mask.masked_fill_(tril_mask, 0) | |
| tgt_mask = tgt_mask.repeat(1, self.n_head, 1, 1) | |
| tgt_mask = tgt_mask.view(-1, tgt_len, tgt_len) | |
| return tgt_mask | |

    def decode(self, tgt_seq: torch.Tensor, feature: torch.Tensor,
               src_mask: torch.BoolTensor,
               tgt_mask: torch.BoolTensor) -> torch.Tensor:
        """Decode the input sequence.

        Args:
            tgt_seq (Tensor): Target token indices of shape :math:`(N, T)`.
            feature (Tensor): Flattened feature map from the encoder of
                shape :math:`(N, H*W, C)`.
            src_mask (BoolTensor): The source mask of shape :math:`(N, H*W)`.
            tgt_mask (BoolTensor): The target mask of shape :math:`(N, T, T)`.

        Returns:
            Tensor: The decoded sequence.
        """
        tgt_seq = self.embedding(tgt_seq)
        x = self.positional_encoding(tgt_seq)
        attn_masks = [tgt_mask, src_mask]
        for layer in self.decoder_layers:
            x = layer(
                query=x, key=feature, value=feature, attn_masks=attn_masks)
        x = self.norm(x)
        return self.cls(x)

    def forward_train(self,
                      feat: Optional[torch.Tensor] = None,
                      out_enc: torch.Tensor = None,
                      data_samples: Sequence[TextRecogDataSample] = None
                      ) -> torch.Tensor:
        """Forward for training. Source mask will not be used here.

        Args:
            feat (Tensor, optional): Input feature map from backbone.
            out_enc (Tensor): Unused.
            data_samples (list[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text and valid_ratio
                information.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
            :math:`C` is ``num_classes``.
        """
        # flatten 2D feature map
        if len(feat.shape) > 3:
            b, c, h, w = feat.shape
            feat = feat.view(b, c, h * w)
            feat = feat.permute((0, 2, 1))
        feat = self.feat_positional_encoding(feat)

        trg_seq = []
        for target in data_samples:
            trg_seq.append(target.gt_text.padded_indexes.to(feat.device))
        trg_seq = torch.stack(trg_seq, dim=0)

        src_mask = None
        tgt_mask = self.make_target_mask(trg_seq, device=feat.device)
        return self.decode(trg_seq, feat, src_mask, tgt_mask)

    def forward_test(self,
                     feat: Optional[torch.Tensor] = None,
                     out_enc: torch.Tensor = None,
                     data_samples: Sequence[TextRecogDataSample] = None
                     ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (Tensor, optional): Input feature map from backbone.
            out_enc (Tensor): Unused.
            data_samples (list[TextRecogDataSample]): Unused.

        Returns:
            Tensor: Character probabilities of shape :math:`(N, T, C)`,
            where :math:`T` is ``max_seq_len`` and :math:`C` is
            ``num_classes``.
        """
        # flatten 2D feature map
        if len(feat.shape) > 3:
            b, c, h, w = feat.shape
            feat = feat.view(b, c, h * w)
            feat = feat.permute((0, 2, 1))
        feat = self.feat_positional_encoding(feat)

        N = feat.shape[0]
        # start every sequence with the SOS token and decode greedily,
        # feeding the predicted characters back in one step at a time
        input = torch.full((N, 1),
                           self.SOS,
                           device=feat.device,
                           dtype=torch.long)
        output = None
        for _ in range(self.max_seq_len):
            target_mask = self.make_target_mask(input, device=feat.device)
            out = self.decode(input, feat, None, target_mask)
            output = out
            _, next_word = torch.max(out, dim=-1)
            input = torch.cat([input, next_word[:, -1].unsqueeze(-1)], dim=1)
        return self.softmax(output)
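

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module: it builds the
    # decoder with an assumed character dictionary ('dicts/english.txt' is a
    # hypothetical path) and runs greedy decoding on random features of the
    # default size (N, 512, 6, 40).
    dictionary = Dictionary(
        dict_file='dicts/english.txt',
        with_start=True,
        with_end=True,
        same_start_end=True,
        with_padding=True,
        with_unknown=True)
    decoder = MasterDecoder(dictionary=dictionary, max_seq_len=30)
    decoder.eval()

    dummy_feat = torch.rand(1, 512, 6, 40)
    with torch.no_grad():
        probs = decoder.forward_test(feat=dummy_feat)
    print(probs.shape)  # expected: torch.Size([1, 30, num_classes])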