| | |
| | from collections import OrderedDict |
| | from typing import Sequence |
| |
|
| | import torch |
| | from mmengine.model import BaseModel |
| | from torch import nn |
| |
|
| | try: |
| | from transformers import AutoTokenizer, BertConfig |
| | from transformers import BertModel as HFBertModel |
| | except ImportError: |
| | AutoTokenizer = None |
| | HFBertModel = None |
| |
|
| | from mmdet.registry import MODELS |
| |
|
| |
|
def generate_masks_with_special_tokens_and_transfer_map(
        tokenized, special_tokens_list):
    """Generate attention mask between each pair of special tokens.

    Only token pairs in between two special tokens are attended to
    and thus the attention mask for these pairs is positive.

    Args:
        tokenized (dict): tokenizer output; its ``'input_ids'`` entry is a
            tensor of shape [bs, num_token].
        special_tokens_list (list): ids of the special tokens that delimit
            sub-sentences.

    Returns:
        Tuple(Tensor, Tensor):
        - attention_mask is the attention mask between each tokens.
          Only token pairs in between two special tokens are positive.
          Shape: [bs, num_token, num_token].
        - position_ids is the position id of tokens within each valid sentence.
          The id starts from 0 whenever a special token is encountered.
          Shape: [bs, num_token]
    """
    input_ids = tokenized['input_ids']
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # Boolean map of every position that holds one of the special tokens.
    is_special = torch.zeros((batch_size, seq_len), device=device).bool()
    for token_id in special_tokens_list:
        is_special |= input_ids == token_id

    # Start from a diagonal mask (each token always attends to itself) and a
    # zeroed position-id grid, then fill per sub-sentence below.
    attention_mask = torch.eye(
        seq_len, device=device).bool().unsqueeze(0).repeat(batch_size, 1, 1)
    position_ids = torch.zeros((batch_size, seq_len), device=device)

    prev_col = 0
    for row, col in torch.nonzero(is_special):
        if col == 0 or col == seq_len - 1:
            # Leading/trailing special token: attends only to itself and
            # restarts the position counter.
            attention_mask[row, col, col] = True
            position_ids[row, col] = 0
        else:
            # All tokens since the previous special token (inclusive of this
            # one) attend to each other; positions restart at 0 per span.
            attention_mask[row, prev_col + 1:col + 1,
                           prev_col + 1:col + 1] = True
            position_ids[row, prev_col + 1:col + 1] = torch.arange(
                0, col - prev_col, device=device)
        prev_col = col

    return attention_mask, position_ids.to(torch.long)
| |
|
| |
|
@MODELS.register_module()
class BertModel(BaseModel):
    """BERT model for language embedding only encoder.

    Args:
        name (str, optional): name of the pretrained BERT model from
            HuggingFace. Defaults to bert-base-uncased.
        max_tokens (int, optional): maximum number of tokens to be
            used for BERT. Defaults to 256.
        pad_to_max (bool, optional): whether to pad the tokens to max_tokens.
            Defaults to True.
        use_sub_sentence_represent (bool, optional): whether to use sub
            sentence represent introduced in `Grounding DINO
            <https://arxiv.org/abs/2303.05499>`. Defaults to False.
        special_tokens_list (list, optional): special tokens used to split
            subsentence. It cannot be None when `use_sub_sentence_represent`
            is True. Defaults to None.
        add_pooling_layer (bool, optional): whether to adding pooling
            layer in bert encoder. Defaults to False.
        num_layers_of_embedded (int, optional): number of layers of
            the embedded model. Defaults to 1.
        use_checkpoint (bool, optional): whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(self,
                 name: str = 'bert-base-uncased',
                 max_tokens: int = 256,
                 pad_to_max: bool = True,
                 use_sub_sentence_represent: bool = False,
                 special_tokens_list: list = None,
                 add_pooling_layer: bool = False,
                 num_layers_of_embedded: int = 1,
                 use_checkpoint: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.max_tokens = max_tokens
        self.pad_to_max = pad_to_max

        # Fail early with an actionable message when the optional
        # `transformers` dependency is missing.
        if AutoTokenizer is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.language_backbone = nn.Sequential(
            OrderedDict([('body',
                          BertEncoder(
                              name,
                              add_pooling_layer=add_pooling_layer,
                              num_layers_of_embedded=num_layers_of_embedded,
                              use_checkpoint=use_checkpoint))]))

        self.use_sub_sentence_represent = use_sub_sentence_represent
        if self.use_sub_sentence_represent:
            # Use an explicit exception instead of `assert`: asserts are
            # stripped when Python runs with -O, which would silently skip
            # this required configuration check.
            if special_tokens_list is None:
                raise ValueError(
                    'special_tokens_list should not be None '
                    'if use_sub_sentence_represent is True')
            self.special_tokens = self.tokenizer.convert_tokens_to_ids(
                special_tokens_list)

    def forward(self, captions: Sequence[str], **kwargs) -> dict:
        """Tokenize captions and embed them with the BERT backbone.

        Args:
            captions (Sequence[str]): batch of text prompts.

        Returns:
            dict: language features produced by the backbone; when
            `use_sub_sentence_represent` is True it additionally carries
            'position_ids' and 'text_token_mask'.
        """
        device = next(self.language_backbone.parameters()).device
        tokenized = self.tokenizer.batch_encode_plus(
            captions,
            max_length=self.max_tokens,
            padding='max_length' if self.pad_to_max else 'longest',
            return_special_tokens_mask=True,
            return_tensors='pt',
            truncation=True).to(device)
        input_ids = tokenized.input_ids
        if self.use_sub_sentence_represent:
            # Restrict attention to spans between special tokens and restart
            # position ids per sub-sentence (Grounding DINO style).
            attention_mask, position_ids = \
                generate_masks_with_special_tokens_and_transfer_map(
                    tokenized, self.special_tokens)
            token_type_ids = tokenized['token_type_ids']
        else:
            attention_mask = tokenized.attention_mask
            position_ids = None
            token_type_ids = None

        tokenizer_input = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'token_type_ids': token_type_ids
        }
        language_dict_features = self.language_backbone(tokenizer_input)
        if self.use_sub_sentence_represent:
            language_dict_features['position_ids'] = position_ids
            # The original (unrestricted) attention mask marks real tokens
            # vs. padding; expose it for downstream consumers.
            language_dict_features[
                'text_token_mask'] = tokenized.attention_mask.bool()
        return language_dict_features
| |
|
| |
|
class BertEncoder(nn.Module):
    """BERT encoder for language embedding.

    Args:
        name (str): name of the pretrained BERT model from HuggingFace.
            Defaults to bert-base-uncased.
        add_pooling_layer (bool): whether to add a pooling layer.
        num_layers_of_embedded (int): number of layers of the embedded model.
            Defaults to 1.
        use_checkpoint (bool): whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(self,
                 name: str,
                 add_pooling_layer: bool = False,
                 num_layers_of_embedded: int = 1,
                 use_checkpoint: bool = False):
        super().__init__()
        # Bug fix: check a name that the module-level ImportError fallback
        # actually binds. The `except ImportError` branch defines
        # `HFBertModel = None` but never `BertConfig`, so testing
        # `BertConfig is None` raised NameError (instead of the intended
        # RuntimeError) when transformers is absent.
        if HFBertModel is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')
        config = BertConfig.from_pretrained(name)
        config.gradient_checkpointing = use_checkpoint
        self.model = HFBertModel.from_pretrained(
            name, add_pooling_layer=add_pooling_layer, config=config)
        self.language_dim = config.hidden_size
        self.num_layers_of_embedded = num_layers_of_embedded

    def forward(self, x) -> dict:
        """Run BERT and average the last `num_layers_of_embedded` layers.

        Args:
            x (dict): with keys 'input_ids', 'attention_mask',
                'position_ids' and 'token_type_ids' (the latter two may
                be None).

        Returns:
            dict: 'embedded' (mask-zeroed averaged features), 'masks'
            (the attention mask passed in) and 'hidden' (last layer
            hidden states).
        """
        mask = x['attention_mask']

        outputs = self.model(
            input_ids=x['input_ids'],
            attention_mask=mask,
            position_ids=x['position_ids'],
            token_type_ids=x['token_type_ids'],
            output_hidden_states=True,
        )

        # hidden_states[0] is the embedding output; keep encoder layers only.
        encoded_layers = outputs.hidden_states[1:]
        features = torch.stack(encoded_layers[-self.num_layers_of_embedded:],
                               1).mean(1)
        # NOTE(review): `mean(1)` above already averages over the selected
        # layers, so this division scales features down by an additional
        # factor of `num_layers_of_embedded` (a no-op for the default of 1).
        # Kept as-is for compatibility with existing pretrained weights —
        # confirm intent before changing.
        features = features / self.num_layers_of_embedded
        if mask.dim() == 2:
            # Zero out feature vectors at padded token positions.
            embedded = features * mask.unsqueeze(-1).float()
        else:
            # Non-2D masks (e.g. per-pair attention maps) cannot be used to
            # zero padding here; pass features through unchanged.
            embedded = features

        results = {
            'embedded': embedded,
            'masks': mask,
            'hidden': encoded_layers[-1]
        }
        return results
| |
|