# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Zhou Bo
# Date: 01/15/2020
#
import copy
import torch
from transformers.configuration_utils import PretrainedConfig
from .ops import *
from .bert import *
from .config import DebertaConfig
from .cache_utils import load_model_state
from transformers import PreTrainedModel

__all__ = ['DeBERTa']


class DebertaPretrainedModel(PreTrainedModel):
  config_class = DebertaConfig
  base_model_prefix = 'deberta'

  def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)

class DeBERTa(DebertaPretrainedModel):
""" DeBERTa encoder
This module is composed of the input embedding layer with stacked transformer layers with disentangled attention.
Parameters:
config:
A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
for more details, please refer :class:`~DeBERTa.deberta.ModelConfig`
pre_trained:
The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configurations, \
i.e. [**base, large, base_mnli, large_mnli**]
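
  Example (an illustrative sketch, not part of the original docs; it assumes a default \
  ``DebertaConfig`` is compatible with the released **base** checkpoint)::

    # build the encoder from a config, then overwrite its weights with the
    # released 'base' checkpoint fetched via load_model_state
    config = DebertaConfig()
    model = DeBERTa(config, pre_trained='base')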
"""
  def __init__(self, config, pre_trained=None):
    super().__init__(config)
    state = None
    if pre_trained is not None:
      state, model_config = load_model_state(pre_trained)
      if config is not None and model_config is not None:
        # Keep the architecture-defining fields of the pre-trained checkpoint and
        # copy every other setting from the user-provided config.
        for k in config.__dict__:
          if k not in ['hidden_size',
                       'intermediate_size',
                       'num_attention_heads',
                       'num_hidden_layers',
                       'vocab_size',
                       'max_position_embeddings']:
            model_config.__dict__[k] = config.__dict__[k]
      config = copy.copy(model_config)
    self.embeddings = BertEmbeddings(config)
    self.encoder = BertEncoder(config)
    self.config = config
    self.pre_trained = pre_trained
    self.apply_state(state)

  def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids=None, return_att=False):
    """
    Args:
      input_ids:
        a torch.LongTensor of shape [batch_size, sequence_length] \
        with the word token indices in the vocabulary

      attention_mask:
        an optional parameter for input mask or attention mask.

        - If it's an input mask, then it will be a torch.LongTensor of shape [batch_size, sequence_length] with indices \
          selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
          input sequence length in the current batch. It's the mask that we typically use for attention when \
          a batch has varying length sentences.
        - If it's an attention mask, then it will be a torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
          In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.

      token_type_ids:
        an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
        type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` token and type 1 corresponds to \
        a `sentence B` token (see the BERT paper for more details).

      output_all_encoded_layers:
        whether to output the results of all encoder layers; default `True`

    Returns:

      - The outputs of all stacked transformer layers if `output_all_encoded_layers=True`, else \
        only the last layer of the stacked transformer layers
      - The attention matrices of the self-attention layers if `return_att=True`

    Example::

      # Batch of WordPiece token ids.
      # Each sample was padded with zeros to the maximum length of the batch.
      input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
      # Mask of valid input ids
      attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
      # DeBERTa model initialized with a pre-trained base model
      bert = DeBERTa(pre_trained='base')
      encoder_layers = bert(input_ids, attention_mask=attention_mask)
    """
    if attention_mask is None:
      attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
      token_type_ids = torch.zeros_like(input_ids)
      token_mask = torch.ones_like(input_ids)
    else:
      # Mark every position up to (and including) the last non-zero token type;
      # positions after the last `sentence B` token (e.g. trailing padding) are masked out.
      idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
      token_mask = idxs > 0
      token_mask = token_mask.byte()
    ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, token_mask)
    embedding_output = ebd_output['embeddings']
    encoder_output = self.encoder(embedding_output,
                                  attention_mask,
                                  output_all_encoded_layers=output_all_encoded_layers, return_att=return_att)
    encoder_output.update(ebd_output)
    return encoder_output

  def apply_state(self, state=None):
    """ Load state from a previously loaded model state dictionary.

    Args:
      state (:obj:`dict`, optional): State dictionary as returned by torch.nn.Module.state_dict(), default: `None`. \
        If it's `None`, then the pre-trained state loaded via the constructor will be used to re-initialize \
        the `DeBERTa` model
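
    Example (an illustrative sketch, not part of the original docs; the checkpoint path is a \
    placeholder and ``model`` refers to an already-constructed `DeBERTa` instance)::

      # re-initialize the encoder weights from an explicit state dictionary
      state = torch.load('/path/to/pytorch_model.bin', map_location='cpu')
      model.apply_state(state)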
"""
    if self.pre_trained is None and state is None:
      return
    if state is None:
      state, config = load_model_state(self.pre_trained)
      self.config = config

    # Detect the prefix (if any) that precedes 'embeddings.' in the checkpoint keys,
    # e.g. 'deberta.', so that the state dictionary keys match this module's parameters.
    prefix = ''
    for k in state:
      if 'embeddings.' in k:
        if not k.startswith('embeddings.'):
          prefix = k[:k.index('embeddings.')]
        break

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    self._load_from_state_dict(state, prefix=prefix, local_metadata=None, strict=True,
                               missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)