# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Zhou Bo
# Date: 01/15/2020
#

import copy
import torch
from transformers.configuration_utils import PretrainedConfig
from .ops import *
from .bert import *
from .config import DebertaConfig
from .cache_utils import load_model_state
from transformers import PreTrainedModel

__all__ = ['DeBERTa']


class DebertaPretrainedModel(PreTrainedModel):
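  """ Base class for DeBERTa models; it wires :class:`DebertaConfig` into Hugging Face's
  :class:`~transformers.PreTrainedModel` machinery.
  """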
  config_class = DebertaConfig
  base_model_prefix = 'deberta'
  
  def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)


class DeBERTa(DebertaPretrainedModel):
  """ DeBERTa encoder
  This module is composed of the input embedding layer and stacked transformer layers with disentangled attention.

  Parameters:
    config:
      A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`; \
          for more details, please refer to :class:`~DeBERTa.deberta.ModelConfig`

    pre_trained:
      The pre-trained DeBERTa model; it can be a local path to a pre-trained DeBERTa model or the name of a released configuration, \
          i.e. [**base, large, base_mnli, large_mnli**]
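
  Example (a construction sketch; assumes :class:`DebertaConfig` provides usable defaults \
      and that the released **base** checkpoint can be resolved by `load_model_state`)::

    config = DebertaConfig()

    # randomly initialized encoder built from `config`
    model = DeBERTa(config)

    # encoder initialized from a released pre-trained checkpoint; structural fields
    # of `config` (hidden size, number of layers, vocab size, ...) are replaced by
    # those stored with the checkpoint
    model = DeBERTa(config, pre_trained='base')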

  """

  def __init__(self, config, pre_trained=None):
    super().__init__(config)
    state = None
    if pre_trained is not None:
      state, model_config = load_model_state(pre_trained)
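      # Merge the user-supplied config into the checkpoint's config: every field comes
      # from `config` except the structural sizes listed below, which must match the
      # pre-trained weights and are therefore kept from the checkpoint.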
      if config is not None and model_config is not None:
        for k in config.__dict__:
          if k not in ['hidden_size',
            'intermediate_size',
            'num_attention_heads',
            'num_hidden_layers',
            'vocab_size',
            'max_position_embeddings']:
            model_config.__dict__[k] = config.__dict__[k]
      config = copy.copy(model_config)
    self.embeddings = BertEmbeddings(config)
    self.encoder = BertEncoder(config)
    self.config = config
    self.pre_trained = pre_trained
    self.apply_state(state)

  def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids = None, return_att = False):
    """
    Args:
      input_ids:
        a torch.LongTensor of shape [batch_size, sequence_length] \
      with the word token indices in the vocabulary

      attention_mask:
        an optional mask, given either as an input mask or as an attention mask.

        - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
      selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
      input sequence length in the current batch. It's the mask that we typically use for attention when \
      a batch has varying length sentences.

        - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
      In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.

      token_type_ids:
        an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
      types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
      a `sentence B` token (see BERT paper for more details).

      output_all_encoded_layers:
        whether to output the results of all encoder layers; default: `True`

      position_ids:
        an optional torch.LongTensor of shape [batch_size, sequence_length] with the position \
      indices of the input tokens

      return_att:
        whether to also return the attention matrices of the self-attention layers; default: `False`

    Returns:
      A dictionary that combines the embedding-layer outputs with:

      - the outputs of all stacked transformer layers if `output_all_encoded_layers=True`, else \
      only the output of the last transformer layer

      - the attention matrices of the self-attention layers if `return_att=True`


    Example::

      # Batch of WordPiece token ids.
      # Each sample is padded with zeros to the maximum length of the batch
      input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
      # Mask of valid input ids
      attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

      # DeBERTa model initialized from the released pre-trained base model
      # (a `DebertaConfig` instance is required; its structural fields are taken from the checkpoint)
      bert = DeBERTa(DebertaConfig(), pre_trained='base')

      encoder_layers = bert(input_ids, attention_mask=attention_mask)
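
      # An equivalent 3D attention mask of shape [batch_size, seq_len, seq_len] may be
      # passed instead (a sketch: token i attends to token j only if both are valid)
      attention_mask_3d = attention_mask[:, :, None] * attention_mask[:, None, :]
      encoder_layers = bert(input_ids, attention_mask=attention_mask_3d)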

    """

    # Default masks: attend to every position and treat all tokens as `sentence A`.
    if attention_mask is None:
      attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
      token_type_ids = torch.zeros_like(input_ids)
      token_mask = torch.ones_like(input_ids)
    else:
      # A reversed cumulative sum is positive for every position at or before the last
      # `sentence B` token, so `token_mask` marks all positions up to that token as valid.
      idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), dim=1), [-1])
      token_mask = idxs > 0
      token_mask = token_mask.byte()
    ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, token_mask)
    embedding_output = ebd_output['embeddings']
    encoder_output = self.encoder(embedding_output,
                   attention_mask,
                   output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
    encoder_output.update(ebd_output)
    return encoder_output

  def apply_state(self, state = None):
    """ Load state from previous loaded model state dictionary.

      Args:
        state (:obj:`dict`, optional): State dictionary as returned by :meth:`torch.nn.Module.state_dict`, default: `None`. \
            If it's `None`, the pre-trained state loaded via the constructor will be used to re-initialize \
            the `DeBERTa` model
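
      Example::

        # a usage sketch; assumes the model was constructed with `pre_trained` set,
        # so the same checkpoint can be re-loaded here
        model = DeBERTa(config, pre_trained='base')
        model.apply_state()   # re-initialize the encoder from the pre-trained state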
    """
    if self.pre_trained is None and state is None:
      return
    if state is None:
      state, config = load_model_state(self.pre_trained)
      self.config = config
    
    # Checkpoints may store parameters under a prefix (e.g. `deberta.`); recover it from
    # the first embedding key so the state keys line up with this module's parameters.
    prefix = ''
    for k in state:
      if 'embeddings.' in k:
        if not k.startswith('embeddings.'):
          prefix = k[:k.index('embeddings.')]
        break

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
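    # `_load_from_state_dict` looks up each parameter as `prefix + name`; missing keys,
    # unexpected keys and loading errors are appended to the lists above rather than
    # raised here.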
    self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)