GuysRGithub committed
Commit 312b0de · 1 Parent(s): cbbca27
Files changed (2)
  1. app.py +3 -2
  2. bart.py +1969 -0
app.py CHANGED
@@ -5,12 +5,13 @@ import re
 import textwrap
 from transformers import AutoModelForSeq2SeqLM
 from transformers import AutoTokenizer
+from bart import BartForConditionalGeneration
 from langdetect import detect
 import subprocess
 
-tokenizer = AutoTokenizer.from_pretrained("GuysTrans/bart-base-re-attention-seq-512")
+tokenizer = BartForConditionalGeneration.from_pretrained("GuysTrans/bart-base-re-attention-seq-512")
 
-vn_tokenizer = AutoTokenizer.from_pretrained("GuysTrans/bart-base-vn-re-attention-vn-tokenizer")
+vn_tokenizer = BartForConditionalGeneration.from_pretrained("GuysTrans/bart-base-vn-re-attention-vn-tokenizer")
 
 model = AutoModelForSeq2SeqLM.from_pretrained(
     "GuysTrans/bart-base-re-attention-seq-512")
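Review note: `BartForConditionalGeneration.from_pretrained(...)` returns a model, so after this change `tokenizer` and `vn_tokenizer` hold model objects rather than tokenizers, while `model` is still loaded through `AutoModelForSeq2SeqLM`. If the intent was to load the re-attention model through the custom class and keep `AutoTokenizer` for tokenization, the wiring might look like the sketch below. `load_components` and the two stub classes are hypothetical names used only so the example runs offline; in `app.py` the arguments would presumably be `AutoTokenizer` and `bart.BartForConditionalGeneration`.

```python
def load_components(checkpoint, tokenizer_cls, model_cls):
    """Load a tokenizer and a model from one checkpoint, each with its own class."""
    tokenizer = tokenizer_cls.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint)
    return tokenizer, model


class _StubBase:
    """Hypothetical offline stand-in that mimics the `from_pretrained` classmethod."""

    @classmethod
    def from_pretrained(cls, checkpoint):
        obj = cls()
        obj.checkpoint = checkpoint
        return obj


class StubTokenizer(_StubBase):
    pass


class StubModel(_StubBase):
    pass


# With the real classes this call would download weights, so stubs are used here.
tokenizer, model = load_components("some/checkpoint", StubTokenizer, StubModel)
print(type(tokenizer).__name__, type(model).__name__)
```

Keeping the tokenizer class and the model class as separate arguments makes it harder to bind a model to a variable named `tokenizer` by accident.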
bart.py ADDED
@@ -0,0 +1,1969 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch BART model."""
+import copy
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn, einsum
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
+    Seq2SeqSequenceClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from transformers.models.bart.configuration_bart import BartConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/bart-base"
+_CONFIG_FOR_DOC = "BartConfig"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
+
+# SequenceClassification docstring
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2"
+_SEQ_CLASS_EXPECTED_LOSS = 0.0
+_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"
+
+# QuestionAnswering docstring
+_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1"
+_QA_EXPECTED_LOSS = 0.59
+_QA_EXPECTED_OUTPUT = "' nice puppet'"
+
+
+BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "facebook/bart-large",
+    # see all BART models at https://huggingface.co/models?filter=bart
+]
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
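Review note: the effect of `shift_tokens_right` is easy to check by hand on a tiny batch. The sketch below mirrors the same three steps (shift, prepend the start token, replace `-100`) with plain Python lists instead of tensors. `shift_right` and the token ids are illustrative assumptions, not part of the commit.

```python
def shift_right(rows, pad_token_id, decoder_start_token_id):
    """List-based analogue of shift_tokens_right for a batch of label rows."""
    if pad_token_id is None:
        raise ValueError("pad_token_id has to be defined.")
    shifted = []
    for row in rows:
        # Drop the last token and put the decoder start token in front.
        new_row = [decoder_start_token_id] + list(row[:-1])
        # Labels use -100 for ignored positions; decoder inputs need a real id.
        shifted.append([pad_token_id if tok == -100 else tok for tok in new_row])
    return shifted


labels = [[10, 11, 12, 2], [20, 2, -100, -100]]
decoder_inputs = shift_right(labels, pad_token_id=1, decoder_start_token_id=2)
print(decoder_inputs)  # [[2, 10, 11, 12], [2, 20, 2, 1]]
```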
+
+
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    """
+    Make causal mask used for uni-directional (causal) self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
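Review note: the mask built by `_make_causal_mask` is additive. Position `i` may attend to positions `j <= i` (value 0), and everything to the right receives a very large negative value. Cached positions are prepended as all-zero columns. The list-based sketch below reproduces that layout for a single batch element. `MIN_VALUE` is an assumed stand-in for `torch.finfo(dtype).min`, and `causal_mask` is a hypothetical helper name.

```python
MIN_VALUE = -1e9  # assumed stand-in for torch.finfo(dtype).min


def causal_mask(tgt_len, past_key_values_length=0):
    """Return a [tgt_len, past + tgt_len] additive causal mask as nested lists."""
    rows = []
    for i in range(tgt_len):
        past = [0.0] * past_key_values_length          # cached tokens are always visible
        current = [0.0 if j <= i else MIN_VALUE for j in range(tgt_len)]
        rows.append(past + current)
    return rows


for row in causal_mask(3):
    print(row)
```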
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
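Review note: `_expand_mask` turns a 1/0 padding mask into an additive one, where kept tokens become 0 and padded tokens become a very large negative value, and repeats it for every target position. A list-based sketch of that transformation follows. `expand_mask` and `MIN_VALUE` are illustrative assumptions, the latter standing in for `torch.finfo(dtype).min`.

```python
MIN_VALUE = -1e9  # assumed stand-in for torch.finfo(dtype).min


def expand_mask(mask, tgt_len=None):
    """Expand a [bsz, src_len] 1/0 mask to [bsz, 1, tgt_len, src_len] additive form."""
    expanded = []
    for row in mask:
        rows_needed = tgt_len if tgt_len is not None else len(row)
        additive = [0.0 if keep else MIN_VALUE for keep in row]
        # One copy of the additive row per target position, under a singleton head axis.
        expanded.append([[list(additive) for _ in range(rows_needed)]])
    return expanded


out = expand_mask([[1, 1, 0]], tgt_len=2)
print(out[0][0])
```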
+
+
+class BartLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids` shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+        ).expand(bsz, -1)
+
+        return super().forward(positions + self.offset)
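Review note: the offset of 2 means the embedding table has `num_embeddings + 2` rows and position `p` reads row `p + 2`. During incremental decoding the positions start at `past_key_values_length` rather than 0. The sketch below computes just the row indices. `position_rows` is a hypothetical helper, and the value 1024 in the usage line is only an example size.

```python
OFFSET = 2  # same constant as self.offset in BartLearnedPositionalEmbedding


def position_rows(seq_len, past_key_values_length=0):
    """Rows of the positional embedding table read for one sequence."""
    start = past_key_values_length
    return [pos + OFFSET for pos in range(start, start + seq_len)]


print(position_rows(4))                              # [2, 3, 4, 5]
print(position_rows(1, past_key_values_length=4))    # [6]

# The table is allocated with OFFSET extra rows so the last position still fits.
num_embeddings = 1024  # example size only
table_rows = num_embeddings + OFFSET
print(position_rows(num_embeddings)[-1] < table_rows)  # True
```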
+
+
+class BartAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        # Re-attention
+        self.reatten_matrix = nn.Parameter(torch.randn(self.num_heads, self.num_heads))
+        self.var_norm = nn.BatchNorm2d(self.num_heads)
+
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.reatten_scale = self.scaling
+        self.is_decoder = is_decoder
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.proj_drop = nn.Dropout(0.0)
+        self.attn_drop = nn.Dropout(0.0)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        re_attention = False
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            re_attention = True
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        # Re-attention
+        if re_attention:
+            # attn_weights = self.attn_drop(attn_weights)
+            attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = einsum('b h i j, h g -> b g i j', attn_weights, self.reatten_matrix) * self.reatten_scale
+            # attn_weights = self.var_norm(attn_weights) * self.reatten_scale
+            attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
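Review note: the re-attention step is the only functional change to `BartAttention` in this file. On the plain self-attention path (no `key_value_states`, no `past_key_value`) the post-softmax weights are mixed across heads by `reatten_matrix` through `einsum('b h i j, h g -> b g i j', ...)` and then scaled by `reatten_scale`. Each output head `g` is therefore a weighted sum of all input heads `h`, and the mixed rows are no longer guaranteed to sum to 1. The list-based sketch below spells out that contraction. `mix_heads` is a hypothetical helper and the numbers are made up for illustration.

```python
def mix_heads(attn, mix, scale=1.0):
    """out[b][g][i][j] = scale * sum_h attn[b][h][i][j] * mix[h][g]."""
    num_in = len(mix)        # input heads h
    num_out = len(mix[0])    # output heads g
    out = []
    for per_batch in attn:   # per_batch[h][i][j]
        tgt_len = len(per_batch[0])
        src_len = len(per_batch[0][0])
        mixed = [
            [
                [
                    scale * sum(per_batch[h][i][j] * mix[h][g] for h in range(num_in))
                    for j in range(src_len)
                ]
                for i in range(tgt_len)
            ]
            for g in range(num_out)
        ]
        out.append(mixed)
    return out


# One batch element, two heads, one query position, two key positions.
attn = [[[[0.25, 0.75]], [[0.5, 0.5]]]]
identity = [[1.0, 0.0], [0.0, 1.0]]
swap = [[0.0, 1.0], [1.0, 0.0]]

print(mix_heads(attn, identity))  # unchanged
print(mix_heads(attn, swap))      # the two heads trade places
```

With an identity matrix and a scale of 1 the step is a no-op, which is a convenient sanity check when comparing against the stock attention.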
+
+
+class BartEncoderLayer(nn.Module):
+    def __init__(self, config: BartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.self_attn = BartAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class BartDecoderLayer(nn.Module):
+    def __init__(self, config: BartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = BartAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = BartAttention(
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
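Review note: the per-layer cache handled by `BartDecoderLayer.forward` is a flat 4-tuple. The first two entries are the self-attention key and value, and the last two are the cross-attention key and value. The sketch below shows the same slicing and concatenation with placeholder strings in place of tensors. `split_cache` and `merge_cache` are hypothetical helper names.

```python
def split_cache(past_key_value):
    """Split a layer cache into (self-attention pair, cross-attention pair)."""
    if past_key_value is None:
        return None, None
    return past_key_value[:2], past_key_value[-2:]


def merge_cache(self_attn_kv, cross_attn_kv):
    """Rebuild the flat 4-tuple that the layer returns when use_cache is set."""
    return self_attn_kv + cross_attn_kv


cache = ("self_k", "self_v", "cross_k", "cross_v")
self_kv, cross_kv = split_cache(cache)
print(self_kv)                         # ('self_k', 'self_v')
print(cross_kv)                        # ('cross_k', 'cross_v')
print(merge_cache(self_kv, cross_kv))  # same layout as the input cache
```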
+
+
+class BartClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        inner_dim: int,
+        num_classes: int,
+        pooler_dropout: float,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.dropout = nn.Dropout(p=pooler_dropout)
+        self.out_proj = nn.Linear(inner_dim, num_classes)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
+class BartPreTrainedModel(PreTrainedModel):
+    config_class = BartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
+    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (BartDecoder, BartEncoder)):
+            module.gradient_checkpointing = value
+
+    @property
+    def dummy_inputs(self):
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+        dummy_inputs = {
+            "attention_mask": input_ids.ne(pad_token),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+
+class PretrainedBartModel(BartPreTrainedModel):
+    def __init_subclass__(self):
+        warnings.warn(
+            "The class `PretrainedBartModel` has been deprecated, please use `BartPreTrainedModel` instead.",
+            FutureWarning,
+        )
+
+
+class BartPretrainedModel(BartPreTrainedModel):
+    def __init_subclass__(self):
+        warnings.warn(
+            "The class `BartPretrainedModel` has been deprecated, please use `BartPreTrainedModel` instead.",
+            FutureWarning,
+        )
+
+
+BART_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`BartConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+BART_GENERATION_EXAMPLE = r"""
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, BartForConditionalGeneration
+
+    >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
+
+    >>> ARTICLE_TO_SUMMARIZE = (
+    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+    ... )
+    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
+    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, BartForConditionalGeneration
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
+    >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+
+    >>> TXT = "My friends are <mask> but they eat too many carbs."
+    >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
+    >>> logits = model(input_ids).logits
+
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+    >>> probs = logits[0, masked_index].softmax(dim=0)
+    >>> values, predictions = probs.topk(5)
+
+    >>> tokenizer.decode(predictions).split()
+    ['not', 'good', 'healthy', 'great', 'very']
+    ```
+"""
+
+BART_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+
+            For translation and summarization training, `decoder_input_ids` should be provided. If no
+            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+            for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
663
+ be used by default.
664
+
665
+ If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
666
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
667
+ information on the default strategy.
668
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
669
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
670
+
671
+ - 1 indicates the head is **not masked**,
672
+ - 0 indicates the head is **masked**.
673
+
674
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
675
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
676
+
677
+ - 1 indicates the head is **not masked**,
678
+ - 0 indicates the head is **masked**.
679
+
680
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
681
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
682
+ 1]`:
683
+
684
+ - 1 indicates the head is **not masked**,
685
+ - 0 indicates the head is **masked**.
686
+
687
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
688
+ Tuple consisting of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
689
+ `last_hidden_state`, of shape `(batch_size, sequence_length, hidden_size)`, is the sequence of
690
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
691
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
692
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
693
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
694
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
695
+
696
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
697
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
698
+
699
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
700
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
701
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
701
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
702
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
703
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
705
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
706
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
707
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
708
+ input (see `past_key_values`). This is useful if you want more control over how to convert
709
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
710
+
711
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
712
+ of `inputs_embeds`.
713
+ use_cache (`bool`, *optional*):
714
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
715
+ `past_key_values`).
716
+ output_attentions (`bool`, *optional*):
717
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
718
+ tensors for more detail.
719
+ output_hidden_states (`bool`, *optional*):
720
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
721
+ more detail.
722
+ return_dict (`bool`, *optional*):
723
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
724
+ """
725
+
726
+
727
+ class BartEncoder(BartPreTrainedModel):
728
+ """
729
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
730
+ [`BartEncoderLayer`].
731
+
732
+ Args:
733
+ config: BartConfig
734
+ embed_tokens (nn.Embedding, *optional*): embedding whose weight is shared with this module's token embedding
735
+ """
736
+
737
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
738
+ super().__init__(config)
739
+
740
+ self.dropout = config.dropout
741
+ self.layerdrop = config.encoder_layerdrop
742
+
743
+ embed_dim = config.d_model
744
+ self.padding_idx = config.pad_token_id
745
+ self.max_source_positions = config.max_position_embeddings
746
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
747
+
748
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
749
+
750
+ if embed_tokens is not None:
751
+ self.embed_tokens.weight = embed_tokens.weight
752
+
753
+ self.embed_positions = BartLearnedPositionalEmbedding(
754
+ config.max_position_embeddings,
755
+ embed_dim,
756
+ )
757
+ self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
758
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
759
+
760
+ self.gradient_checkpointing = False
761
+ # Initialize weights and apply final processing
762
+ self.post_init()
763
+
764
+ def get_input_embeddings(self):
765
+ return self.embed_tokens
766
+
767
+ def set_input_embeddings(self, value):
768
+ self.embed_tokens = value
769
+
770
+ def forward(
771
+ self,
772
+ input_ids: torch.LongTensor = None,
773
+ attention_mask: Optional[torch.Tensor] = None,
774
+ head_mask: Optional[torch.Tensor] = None,
775
+ inputs_embeds: Optional[torch.FloatTensor] = None,
776
+ output_attentions: Optional[bool] = None,
777
+ output_hidden_states: Optional[bool] = None,
778
+ return_dict: Optional[bool] = None,
779
+ ) -> Union[Tuple, BaseModelOutput]:
780
+ r"""
781
+ Args:
782
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
783
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
784
+ provide it.
785
+
786
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
787
+ [`PreTrainedTokenizer.__call__`] for details.
788
+
789
+ [What are input IDs?](../glossary#input-ids)
790
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
791
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
792
+
793
+ - 1 for tokens that are **not masked**,
794
+ - 0 for tokens that are **masked**.
795
+
796
+ [What are attention masks?](../glossary#attention-mask)
797
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
798
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
799
+
800
+ - 1 indicates the head is **not masked**,
801
+ - 0 indicates the head is **masked**.
802
+
803
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
804
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
805
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
806
+ than the model's internal embedding lookup matrix.
807
+ output_attentions (`bool`, *optional*):
808
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
809
+ returned tensors for more detail.
810
+ output_hidden_states (`bool`, *optional*):
811
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
812
+ for more detail.
813
+ return_dict (`bool`, *optional*):
814
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
815
+ """
816
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
817
+ output_hidden_states = (
818
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
819
+ )
820
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
821
+
822
+ # retrieve input_ids and inputs_embeds
823
+ if input_ids is not None and inputs_embeds is not None:
824
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
825
+ elif input_ids is not None:
826
+ input = input_ids
827
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
828
+ elif inputs_embeds is not None:
829
+ input = inputs_embeds[:, :, -1]
830
+ else:
831
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
832
+
833
+ if inputs_embeds is None:
834
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
835
+
836
+ embed_pos = self.embed_positions(input)
837
+ embed_pos = embed_pos.to(inputs_embeds.device)
838
+
839
+ hidden_states = inputs_embeds + embed_pos
840
+ hidden_states = self.layernorm_embedding(hidden_states)
841
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
842
+
843
+ # expand attention_mask
844
+ if attention_mask is not None:
845
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
846
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
847
+
848
+ encoder_states = () if output_hidden_states else None
849
+ all_attentions = () if output_attentions else None
850
+
851
+ # check if head_mask has a correct number of layers specified if desired
852
+ if head_mask is not None:
853
+ if head_mask.size()[0] != (len(self.layers)):
854
+ raise ValueError(
855
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
856
+ f" {head_mask.size()[0]}."
857
+ )
858
+
859
+ for idx, encoder_layer in enumerate(self.layers):
860
+ if output_hidden_states:
861
+ encoder_states = encoder_states + (hidden_states,)
862
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
863
+ to_drop = False
864
+ if self.training:
865
+ dropout_probability = torch.rand([])
866
+ if dropout_probability < self.layerdrop: # skip the layer
867
+ to_drop = True
868
+
869
+ if to_drop:
870
+ layer_outputs = (None, None)
871
+ else:
872
+ if self.gradient_checkpointing and self.training:
873
+
874
+ def create_custom_forward(module):
875
+ def custom_forward(*inputs):
876
+ return module(*inputs, output_attentions)
877
+
878
+ return custom_forward
879
+
880
+ layer_outputs = torch.utils.checkpoint.checkpoint(
881
+ create_custom_forward(encoder_layer),
882
+ hidden_states,
883
+ attention_mask,
884
+ (head_mask[idx] if head_mask is not None else None),
885
+ )
886
+ else:
887
+ layer_outputs = encoder_layer(
888
+ hidden_states,
889
+ attention_mask,
890
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
891
+ output_attentions=output_attentions,
892
+ )
893
+
894
+ hidden_states = layer_outputs[0]
895
+
896
+ if output_attentions:
897
+ all_attentions = all_attentions + (layer_outputs[1],)
898
+
899
+ if output_hidden_states:
900
+ encoder_states = encoder_states + (hidden_states,)
901
+
902
+ if not return_dict:
903
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
904
+ return BaseModelOutput(
905
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
906
+ )
907
+
908
+
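The encoder loop above applies LayerDrop: during training, a layer is skipped whenever a random draw falls below `layerdrop`. A minimal plain-Python sketch of that control flow follows. The name `run_layers_with_layerdrop` and its `rng` argument are hypothetical, and layers are modelled as simple callables rather than `BartEncoderLayer` modules.

```python
import random


def run_layers_with_layerdrop(hidden, layers, layerdrop, training, rng=None):
    """Apply `layers` in order, skipping each with probability `layerdrop` while training.

    Hypothetical helper for illustration only; it is not part of bart.py.
    """
    if rng is None:
        rng = random.Random()
    applied = []
    for idx, layer in enumerate(layers):
        # Same test as `dropout_probability < self.layerdrop`; layers are never skipped at inference.
        if training and rng.random() < layerdrop:
            continue
        hidden = layer(hidden)
        applied.append(idx)
    return hidden, applied


if __name__ == "__main__":
    toy_layers = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3]
    print(run_layers_with_layerdrop(5, toy_layers, layerdrop=0.5, training=False))
```

Passing a seeded `random.Random` as `rng` makes the set of skipped layers reproducible, which can help when debugging a training run.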
909
+ class BartDecoder(BartPreTrainedModel):
910
+ """
911
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`].
912
+
913
+ Args:
914
+ config: BartConfig
915
+ embed_tokens (nn.Embedding, *optional*): embedding whose weight is shared with this module's token embedding
916
+ """
917
+
918
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
919
+ super().__init__(config)
920
+ self.dropout = config.dropout
921
+ self.layerdrop = config.decoder_layerdrop
922
+ self.padding_idx = config.pad_token_id
923
+ self.max_target_positions = config.max_position_embeddings
924
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
925
+
926
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
927
+
928
+ if embed_tokens is not None:
929
+ self.embed_tokens.weight = embed_tokens.weight
930
+
931
+ self.embed_positions = BartLearnedPositionalEmbedding(
932
+ config.max_position_embeddings,
933
+ config.d_model,
934
+ )
935
+ self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
936
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
937
+
938
+ self.gradient_checkpointing = False
939
+ # Initialize weights and apply final processing
940
+ self.post_init()
941
+
942
+ def get_input_embeddings(self):
943
+ return self.embed_tokens
944
+
945
+ def set_input_embeddings(self, value):
946
+ self.embed_tokens = value
947
+
948
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
949
+ # create causal mask
950
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
951
+ combined_attention_mask = None
952
+ if input_shape[-1] > 1:
953
+ combined_attention_mask = _make_causal_mask(
954
+ input_shape,
955
+ inputs_embeds.dtype,
956
+ device=inputs_embeds.device,
957
+ past_key_values_length=past_key_values_length,
958
+ )
959
+
960
+ if attention_mask is not None:
961
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
962
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
963
+ inputs_embeds.device
964
+ )
965
+ combined_attention_mask = (
966
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
967
+ )
968
+
969
+ return combined_attention_mask
970
+
971
+ def forward(
972
+ self,
973
+ input_ids: torch.LongTensor = None,
974
+ attention_mask: Optional[torch.Tensor] = None,
975
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
976
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
977
+ head_mask: Optional[torch.Tensor] = None,
978
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
979
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
980
+ inputs_embeds: Optional[torch.FloatTensor] = None,
981
+ use_cache: Optional[bool] = None,
982
+ output_attentions: Optional[bool] = None,
983
+ output_hidden_states: Optional[bool] = None,
984
+ return_dict: Optional[bool] = None,
985
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
986
+ r"""
987
+ Args:
988
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
989
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
990
+ provide it.
991
+
992
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
993
+ [`PreTrainedTokenizer.__call__`] for details.
994
+
995
+ [What are input IDs?](../glossary#input-ids)
996
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
997
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
998
+
999
+ - 1 for tokens that are **not masked**,
1000
+ - 0 for tokens that are **masked**.
1001
+
1002
+ [What are attention masks?](../glossary#attention-mask)
1003
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1004
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1005
+ of the decoder.
1006
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1007
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
1008
+ selected in `[0, 1]`:
1009
+
1010
+ - 1 for tokens that are **not masked**,
1011
+ - 0 for tokens that are **masked**.
1012
+
1013
+ [What are attention masks?](../glossary#attention-mask)
1014
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1015
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1016
+
1017
+ - 1 indicates the head is **not masked**,
1018
+ - 0 indicates the head is **masked**.
1019
+
1020
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1021
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1022
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1023
+
1024
+ - 1 indicates the head is **not masked**,
1025
+ - 0 indicates the head is **masked**.
1026
+
1027
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1028
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1029
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
1030
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1031
+
1032
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1033
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1034
+
1035
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1036
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1037
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1037
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1038
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1039
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1040
+ than the model's internal embedding lookup matrix.
1042
+ output_attentions (`bool`, *optional*):
1043
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1044
+ returned tensors for more detail.
1045
+ output_hidden_states (`bool`, *optional*):
1046
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1047
+ for more detail.
1048
+ return_dict (`bool`, *optional*):
1049
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1050
+ """
1051
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1052
+ output_hidden_states = (
1053
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1054
+ )
1055
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1056
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1057
+
1058
+ # retrieve input_ids and inputs_embeds
1059
+ if input_ids is not None and inputs_embeds is not None:
1060
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1061
+ elif input_ids is not None:
1062
+ input = input_ids
1063
+ input_shape = input.shape
1064
+ input_ids = input_ids.view(-1, input_shape[-1])
1065
+ elif inputs_embeds is not None:
1066
+ input_shape = inputs_embeds.size()[:-1]
1067
+ input = inputs_embeds[:, :, -1]
1068
+ else:
1069
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1070
+
1071
+ # past_key_values_length
1072
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1073
+
1074
+ if inputs_embeds is None:
1075
+ inputs_embeds = self.embed_tokens(input) * self.embed_scale
1076
+
1077
+ attention_mask = self._prepare_decoder_attention_mask(
1078
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1079
+ )
1080
+
1081
+ # expand encoder attention mask
1082
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1083
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1084
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1085
+
1086
+ # embed positions
1087
+ positions = self.embed_positions(input, past_key_values_length)
1088
+ positions = positions.to(inputs_embeds.device)
1089
+
1090
+ hidden_states = inputs_embeds + positions
1091
+ hidden_states = self.layernorm_embedding(hidden_states)
1092
+
1093
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1094
+
1095
+ if self.gradient_checkpointing and self.training:
1096
+ if use_cache:
1097
+ logger.warning_once(
1098
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1099
+ )
1100
+ use_cache = False
1101
+
1102
+ # decoder layers
1103
+ all_hidden_states = () if output_hidden_states else None
1104
+ all_self_attns = () if output_attentions else None
1105
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1106
+ next_decoder_cache = () if use_cache else None
1107
+
1108
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1109
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1110
+ if attn_mask is not None:
1111
+ if attn_mask.size()[0] != (len(self.layers)):
1112
+ raise ValueError(
1113
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1114
+ f" {attn_mask.size()[0]}."
1115
+ )
1116
+
1117
+ for idx, decoder_layer in enumerate(self.layers):
1118
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1119
+ if output_hidden_states:
1120
+ all_hidden_states += (hidden_states,)
1121
+ if self.training:
1122
+ dropout_probability = torch.rand([])
1123
+ if dropout_probability < self.layerdrop:
1124
+ continue
1125
+
1126
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1127
+
1128
+ if self.gradient_checkpointing and self.training:
1129
+
1130
+ def create_custom_forward(module):
1131
+ def custom_forward(*inputs):
1132
+ # None for past_key_value
1133
+ return module(*inputs, output_attentions, use_cache)
1134
+
1135
+ return custom_forward
1136
+
1137
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1138
+ create_custom_forward(decoder_layer),
1139
+ hidden_states,
1140
+ attention_mask,
1141
+ encoder_hidden_states,
1142
+ encoder_attention_mask,
1143
+ head_mask[idx] if head_mask is not None else None,
1144
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
1145
+ None,
1146
+ )
1147
+ else:
1148
+ layer_outputs = decoder_layer(
1149
+ hidden_states,
1150
+ attention_mask=attention_mask,
1151
+ encoder_hidden_states=encoder_hidden_states,
1152
+ encoder_attention_mask=encoder_attention_mask,
1153
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1154
+ cross_attn_layer_head_mask=(
1155
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1156
+ ),
1157
+ past_key_value=past_key_value,
1158
+ output_attentions=output_attentions,
1159
+ use_cache=use_cache,
1160
+ )
1161
+ hidden_states = layer_outputs[0]
1162
+
1163
+ if use_cache:
1164
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1165
+
1166
+ if output_attentions:
1167
+ all_self_attns += (layer_outputs[1],)
1168
+
1169
+ if encoder_hidden_states is not None:
1170
+ all_cross_attentions += (layer_outputs[2],)
1171
+
1172
+ # add hidden states from the last decoder layer
1173
+ if output_hidden_states:
1174
+ all_hidden_states += (hidden_states,)
1175
+
1176
+ next_cache = next_decoder_cache if use_cache else None
1177
+ if not return_dict:
1178
+ return tuple(
1179
+ v
1180
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
1181
+ if v is not None
1182
+ )
1183
+ return BaseModelOutputWithPastAndCrossAttentions(
1184
+ last_hidden_state=hidden_states,
1185
+ past_key_values=next_cache,
1186
+ hidden_states=all_hidden_states,
1187
+ attentions=all_self_attns,
1188
+ cross_attentions=all_cross_attentions,
1189
+ )
1190
+
1191
+
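`_prepare_decoder_attention_mask` in the decoder above adds a causal mask to the expanded padding mask, so a position is attendable only if both masks allow it. The helpers `_make_causal_mask` and `_expand_mask` are not shown in this section, so the following is only a sketch of the idea using nested lists. The function names and the use of `-inf` as the fill value are assumptions (the real helpers work on tensors and may use a large negative finite value instead).

```python
NEG_INF = float("-inf")


def causal_mask(tgt_len, past_len=0):
    """Additive mask of shape [tgt_len, past_len + tgt_len]; 0.0 means attend, -inf means blocked."""
    return [
        [0.0 if col <= row + past_len else NEG_INF for col in range(past_len + tgt_len)]
        for row in range(tgt_len)
    ]


def expand_padding_mask(attention_mask, tgt_len):
    """Turn a 1/0 keep-mask of length src_len into an additive [tgt_len, src_len] mask."""
    row = [0.0 if keep else NEG_INF for keep in attention_mask]
    return [list(row) for _ in range(tgt_len)]


def combine(causal, padding):
    """Element-wise sum: a position stays at 0.0 only if neither mask blocks it."""
    return [[c + p for c, p in zip(c_row, p_row)] for c_row, p_row in zip(causal, padding)]


if __name__ == "__main__":
    # Three decoder tokens, the last of which is padding.
    for row in combine(causal_mask(3), expand_padding_mask([1, 1, 0], 3)):
        print(row)
```

In the source, the causal mask is built only when the target length is greater than 1; with a single new token and cached keys there is nothing in the future to hide, so only the padding mask (if any) is used.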
1192
+ @add_start_docstrings(
1193
+ "The bare BART Model outputting raw hidden-states without any specific head on top.",
1194
+ BART_START_DOCSTRING,
1195
+ )
1196
+ class BartModel(BartPreTrainedModel):
1197
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1198
+
1199
+ def __init__(self, config: BartConfig):
1200
+ super().__init__(config)
1201
+
1202
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1203
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
1204
+
1205
+ self.encoder = BartEncoder(config, self.shared)
1206
+ self.decoder = BartDecoder(config, self.shared)
1207
+
1208
+ # Initialize weights and apply final processing
1209
+ self.post_init()
1210
+
1211
+ def get_input_embeddings(self):
1212
+ return self.shared
1213
+
1214
+ def set_input_embeddings(self, value):
1215
+ self.shared = value
1216
+ self.encoder.embed_tokens = self.shared
1217
+ self.decoder.embed_tokens = self.shared
1218
+
1219
+ def get_encoder(self):
1220
+ return self.encoder
1221
+
1222
+ def get_decoder(self):
1223
+ return self.decoder
1224
+
1225
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1226
+ @add_code_sample_docstrings(
1227
+ checkpoint=_CHECKPOINT_FOR_DOC,
1228
+ output_type=Seq2SeqModelOutput,
1229
+ config_class=_CONFIG_FOR_DOC,
1230
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1231
+ )
1232
+ def forward(
1233
+ self,
1234
+ input_ids: torch.LongTensor = None,
1235
+ attention_mask: Optional[torch.Tensor] = None,
1236
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1237
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1238
+ head_mask: Optional[torch.Tensor] = None,
1239
+ decoder_head_mask: Optional[torch.Tensor] = None,
1240
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1241
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1242
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1243
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1244
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1245
+ use_cache: Optional[bool] = None,
1246
+ output_attentions: Optional[bool] = None,
1247
+ output_hidden_states: Optional[bool] = None,
1248
+ return_dict: Optional[bool] = None,
1249
+ ) -> Union[Tuple, Seq2SeqModelOutput]:
1250
+ # Unlike other models, Bart automatically creates decoder_input_ids from
1251
+ # input_ids if neither decoder_input_ids nor decoder_inputs_embeds is provided
1252
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1253
+ if input_ids is None:
1254
+ raise ValueError(
1255
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
1256
+ "passed, `input_ids` cannot be `None`. Please pass either "
1257
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
1258
+ )
1259
+
1260
+ decoder_input_ids = shift_tokens_right(
1261
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
1262
+ )
1263
+
1264
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1265
+ output_hidden_states = (
1266
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1267
+ )
1268
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1269
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1270
+
1271
+ if encoder_outputs is None:
1272
+ encoder_outputs = self.encoder(
1273
+ input_ids=input_ids,
1274
+ attention_mask=attention_mask,
1275
+ head_mask=head_mask,
1276
+ inputs_embeds=inputs_embeds,
1277
+ output_attentions=output_attentions,
1278
+ output_hidden_states=output_hidden_states,
1279
+ return_dict=return_dict,
1280
+ )
1281
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1282
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1283
+ encoder_outputs = BaseModelOutput(
1284
+ last_hidden_state=encoder_outputs[0],
1285
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1286
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1287
+ )
1288
+
1289
+ # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
1290
+ decoder_outputs = self.decoder(
1291
+ input_ids=decoder_input_ids,
1292
+ attention_mask=decoder_attention_mask,
1293
+ encoder_hidden_states=encoder_outputs[0],
1294
+ encoder_attention_mask=attention_mask,
1295
+ head_mask=decoder_head_mask,
1296
+ cross_attn_head_mask=cross_attn_head_mask,
1297
+ past_key_values=past_key_values,
1298
+ inputs_embeds=decoder_inputs_embeds,
1299
+ use_cache=use_cache,
1300
+ output_attentions=output_attentions,
1301
+ output_hidden_states=output_hidden_states,
1302
+ return_dict=return_dict,
1303
+ )
1304
+
1305
+ if not return_dict:
1306
+ return decoder_outputs + encoder_outputs
1307
+
1308
+ return Seq2SeqModelOutput(
1309
+ last_hidden_state=decoder_outputs.last_hidden_state,
1310
+ past_key_values=decoder_outputs.past_key_values,
1311
+ decoder_hidden_states=decoder_outputs.hidden_states,
1312
+ decoder_attentions=decoder_outputs.attentions,
1313
+ cross_attentions=decoder_outputs.cross_attentions,
1314
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1315
+ encoder_hidden_states=encoder_outputs.hidden_states,
1316
+ encoder_attentions=encoder_outputs.attentions,
1317
+ )
1318
+
1319
+
1320
+ @add_start_docstrings(
1321
+ "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
1322
+ )
1323
+ class BartForConditionalGeneration(BartPreTrainedModel):
1324
+ base_model_prefix = "model"
1325
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1326
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
1327
+
1328
+ def __init__(self, config: BartConfig):
1329
+ super().__init__(config)
1330
+ self.model = BartModel(config)
1331
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1332
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1333
+
1334
+ # Initialize weights and apply final processing
1335
+ self.post_init()
1336
+
1337
+ def get_encoder(self):
1338
+ return self.model.get_encoder()
1339
+
1340
+ def get_decoder(self):
1341
+ return self.model.get_decoder()
1342
+
1343
+ def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
1344
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
1345
+ self._resize_final_logits_bias(new_embeddings.weight.shape[0])
1346
+ return new_embeddings
1347
+
1348
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
1349
+ old_num_tokens = self.final_logits_bias.shape[-1]
1350
+ if new_num_tokens <= old_num_tokens:
1351
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
1352
+ else:
1353
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1354
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1355
+ self.register_buffer("final_logits_bias", new_bias)
1356
+
1357
+ def get_output_embeddings(self):
1358
+ return self.lm_head
1359
+
1360
+ def set_output_embeddings(self, new_embeddings):
1361
+ self.lm_head = new_embeddings
1362
+
1363
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1364
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1365
+ @add_end_docstrings(BART_GENERATION_EXAMPLE)
1366
+ def forward(
1367
+ self,
1368
+ input_ids: torch.LongTensor = None,
1369
+ attention_mask: Optional[torch.Tensor] = None,
1370
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1371
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1372
+ head_mask: Optional[torch.Tensor] = None,
1373
+ decoder_head_mask: Optional[torch.Tensor] = None,
1374
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1375
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1376
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1377
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1378
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1379
+ labels: Optional[torch.LongTensor] = None,
1380
+ use_cache: Optional[bool] = None,
1381
+ output_attentions: Optional[bool] = None,
1382
+ output_hidden_states: Optional[bool] = None,
1383
+ return_dict: Optional[bool] = None,
1384
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
1385
+ r"""
1386
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1387
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1388
+ config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1388
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
1390
+
1391
+ Returns:
1392
+ """
1393
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1394
+
1395
+ if labels is not None:
1396
+ if use_cache:
1397
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
1398
+ use_cache = False
1399
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1400
+ decoder_input_ids = shift_tokens_right(
1401
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1402
+ )
1403
+
1404
+ outputs = self.model(
1405
+ input_ids,
1406
+ attention_mask=attention_mask,
1407
+ decoder_input_ids=decoder_input_ids,
1408
+ encoder_outputs=encoder_outputs,
1409
+ decoder_attention_mask=decoder_attention_mask,
1410
+ head_mask=head_mask,
1411
+ decoder_head_mask=decoder_head_mask,
1412
+ cross_attn_head_mask=cross_attn_head_mask,
1413
+ past_key_values=past_key_values,
1414
+ inputs_embeds=inputs_embeds,
1415
+ decoder_inputs_embeds=decoder_inputs_embeds,
1416
+ use_cache=use_cache,
1417
+ output_attentions=output_attentions,
1418
+ output_hidden_states=output_hidden_states,
1419
+ return_dict=return_dict,
1420
+ )
1421
+
1422
+ lm_logits = self.lm_head(outputs[0])
1423
+ lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
1424
+
1425
+ masked_lm_loss = None
1426
+ if labels is not None:
1427
+ labels = labels.to(lm_logits.device)
1428
+ loss_fct = CrossEntropyLoss()
1429
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1430
+
1431
+ if not return_dict:
1432
+ output = (lm_logits,) + outputs[1:]
1433
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1434
+
1435
+ return Seq2SeqLMOutput(
1436
+ loss=masked_lm_loss,
1437
+ logits=lm_logits,
1438
+ past_key_values=outputs.past_key_values,
1439
+ decoder_hidden_states=outputs.decoder_hidden_states,
1440
+ decoder_attentions=outputs.decoder_attentions,
1441
+ cross_attentions=outputs.cross_attentions,
1442
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1443
+ encoder_hidden_states=outputs.encoder_hidden_states,
1444
+ encoder_attentions=outputs.encoder_attentions,
1445
+ )
1446
+
1447
+ def prepare_inputs_for_generation(
1448
+ self,
1449
+ decoder_input_ids,
1450
+ past_key_values=None,
1451
+ attention_mask=None,
1452
+ decoder_attention_mask=None,
1453
+ head_mask=None,
1454
+ decoder_head_mask=None,
1455
+ cross_attn_head_mask=None,
1456
+ use_cache=None,
1457
+ encoder_outputs=None,
1458
+ **kwargs,
1459
+ ):
1460
+ # cut decoder_input_ids if past_key_values is used
1461
+ if past_key_values is not None:
1462
+ decoder_input_ids = decoder_input_ids[:, -1:]
1463
+
1464
+ return {
1465
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1466
+ "encoder_outputs": encoder_outputs,
1467
+ "past_key_values": past_key_values,
1468
+ "decoder_input_ids": decoder_input_ids,
1469
+ "attention_mask": attention_mask,
1470
+ "decoder_attention_mask": decoder_attention_mask,
1471
+ "head_mask": head_mask,
1472
+ "decoder_head_mask": decoder_head_mask,
1473
+ "cross_attn_head_mask": cross_attn_head_mask,
1474
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1475
+ }
1476
+
1477
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1478
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1479
+
1480
+ @staticmethod
1481
+ def _reorder_cache(past_key_values, beam_idx):
1482
+ reordered_past = ()
1483
+ for layer_past in past_key_values:
1484
+ # cached cross_attention states don't have to be reordered -> they are always the same
1485
+ reordered_past += (
1486
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
1487
+ + layer_past[2:],
1488
+ )
1489
+ return reordered_past
1490
+
1491
+
1492
+ @add_start_docstrings(
1493
+ """
1494
+ Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
1495
+ tasks.
1496
+ """,
1497
+ BART_START_DOCSTRING,
1498
+ )
1499
+ class BartForSequenceClassification(BartPreTrainedModel):
1500
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1501
+
1502
+ def __init__(self, config: BartConfig, **kwargs):
1503
+ super().__init__(config, **kwargs)
1504
+ self.model = BartModel(config)
1505
+ self.classification_head = BartClassificationHead(
1506
+ config.d_model,
1507
+ config.d_model,
1508
+ config.num_labels,
1509
+ config.classifier_dropout,
1510
+ )
1511
+
1512
+ # Initialize weights and apply final processing
1513
+ self.post_init()
1514
+
1515
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1516
+ @add_code_sample_docstrings(
1517
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1518
+ output_type=Seq2SeqSequenceClassifierOutput,
1519
+ config_class=_CONFIG_FOR_DOC,
1520
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1521
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1522
+ )
1523
+ def forward(
1524
+ self,
1525
+ input_ids: torch.LongTensor = None,
1526
+ attention_mask: Optional[torch.Tensor] = None,
1527
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1528
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1529
+ head_mask: Optional[torch.Tensor] = None,
1530
+ decoder_head_mask: Optional[torch.Tensor] = None,
1531
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1532
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1533
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1534
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1535
+ labels: Optional[torch.LongTensor] = None,
1536
+ use_cache: Optional[bool] = None,
1537
+ output_attentions: Optional[bool] = None,
1538
+ output_hidden_states: Optional[bool] = None,
1539
+ return_dict: Optional[bool] = None,
1540
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
1541
+ r"""
1542
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1543
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1544
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1545
+ """
1546
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1547
+ if labels is not None:
1548
+ use_cache = False
1549
+
1550
+ if input_ids is None and inputs_embeds is not None:
1551
+ raise NotImplementedError(
1552
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1553
+ )
1554
+
1555
+ outputs = self.model(
1556
+ input_ids,
1557
+ attention_mask=attention_mask,
1558
+ decoder_input_ids=decoder_input_ids,
1559
+ decoder_attention_mask=decoder_attention_mask,
1560
+ head_mask=head_mask,
1561
+ decoder_head_mask=decoder_head_mask,
1562
+ cross_attn_head_mask=cross_attn_head_mask,
1563
+ encoder_outputs=encoder_outputs,
1564
+ inputs_embeds=inputs_embeds,
1565
+ decoder_inputs_embeds=decoder_inputs_embeds,
1566
+ use_cache=use_cache,
1567
+ output_attentions=output_attentions,
1568
+ output_hidden_states=output_hidden_states,
1569
+ return_dict=return_dict,
1570
+ )
1571
+ hidden_states = outputs[0] # last hidden state
1572
+
1573
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
1574
+
1575
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1576
+ raise ValueError("All examples must have the same number of <eos> tokens.")
1577
+ sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1578
+ :, -1, :
1579
+ ]
1580
+ logits = self.classification_head(sentence_representation)
1581
+
1582
+ loss = None
1583
+ if labels is not None:
1584
+ labels = labels.to(logits.device)
1585
+ if self.config.problem_type is None:
1586
+ if self.config.num_labels == 1:
1587
+ self.config.problem_type = "regression"
1588
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1589
+ self.config.problem_type = "single_label_classification"
1590
+ else:
1591
+ self.config.problem_type = "multi_label_classification"
1592
+
1593
+ if self.config.problem_type == "regression":
1594
+ loss_fct = MSELoss()
1595
+ if self.config.num_labels == 1:
1596
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1597
+ else:
1598
+ loss = loss_fct(logits, labels)
1599
+ elif self.config.problem_type == "single_label_classification":
1600
+ loss_fct = CrossEntropyLoss()
1601
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1602
+ elif self.config.problem_type == "multi_label_classification":
1603
+ loss_fct = BCEWithLogitsLoss()
1604
+ loss = loss_fct(logits, labels)
1605
+ if not return_dict:
1606
+ output = (logits,) + outputs[1:]
1607
+ return ((loss,) + output) if loss is not None else output
1608
+
1609
+ return Seq2SeqSequenceClassifierOutput(
1610
+ loss=loss,
1611
+ logits=logits,
1612
+ past_key_values=outputs.past_key_values,
1613
+ decoder_hidden_states=outputs.decoder_hidden_states,
1614
+ decoder_attentions=outputs.decoder_attentions,
1615
+ cross_attentions=outputs.cross_attentions,
1616
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1617
+ encoder_hidden_states=outputs.encoder_hidden_states,
1618
+ encoder_attentions=outputs.encoder_attentions,
1619
+ )
1620
+
1621
+
1622
+ @add_start_docstrings(
1623
+ """
1624
+ BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1625
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1626
+ """,
1627
+ BART_START_DOCSTRING,
1628
+ )
1629
+ class BartForQuestionAnswering(BartPreTrainedModel):
1630
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1631
+
1632
+ def __init__(self, config):
1633
+ super().__init__(config)
1634
+
1635
+ config.num_labels = 2
1636
+ self.num_labels = config.num_labels
1637
+
1638
+ self.model = BartModel(config)
1639
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1640
+
1641
+ # Initialize weights and apply final processing
1642
+ self.post_init()
1643
+
1644
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1645
+ @add_code_sample_docstrings(
1646
+ checkpoint=_CHECKPOINT_FOR_QA,
1647
+ output_type=Seq2SeqQuestionAnsweringModelOutput,
1648
+ config_class=_CONFIG_FOR_DOC,
1649
+ expected_loss=_QA_EXPECTED_LOSS,
1650
+ expected_output=_QA_EXPECTED_OUTPUT,
1651
+ )
1652
+ def forward(
1653
+ self,
1654
+ input_ids: torch.Tensor = None,
1655
+ attention_mask: Optional[torch.Tensor] = None,
1656
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1657
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1658
+ head_mask: Optional[torch.Tensor] = None,
1659
+ decoder_head_mask: Optional[torch.Tensor] = None,
1660
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1661
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1662
+ start_positions: Optional[torch.LongTensor] = None,
1663
+ end_positions: Optional[torch.LongTensor] = None,
1664
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1665
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1666
+ use_cache: Optional[bool] = None,
1667
+ output_attentions: Optional[bool] = None,
1668
+ output_hidden_states: Optional[bool] = None,
1669
+ return_dict: Optional[bool] = None,
1670
+ ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
1671
+ r"""
1672
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1673
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1674
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
1675
+ are not taken into account for computing the loss.
1676
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1677
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1678
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
1679
+ are not taken into account for computing the loss.
1680
+ """
1681
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1682
+ if start_positions is not None and end_positions is not None:
1683
+ use_cache = False
1684
+
1685
+ outputs = self.model(
1686
+ input_ids,
1687
+ attention_mask=attention_mask,
1688
+ decoder_input_ids=decoder_input_ids,
1689
+ decoder_attention_mask=decoder_attention_mask,
1690
+ head_mask=head_mask,
1691
+ decoder_head_mask=decoder_head_mask,
1692
+ cross_attn_head_mask=cross_attn_head_mask,
1693
+ encoder_outputs=encoder_outputs,
1694
+ inputs_embeds=inputs_embeds,
1695
+ decoder_inputs_embeds=decoder_inputs_embeds,
1696
+ use_cache=use_cache,
1697
+ output_attentions=output_attentions,
1698
+ output_hidden_states=output_hidden_states,
1699
+ return_dict=return_dict,
1700
+ )
1701
+
1702
+ sequence_output = outputs[0]
1703
+
1704
+ logits = self.qa_outputs(sequence_output)
1705
+ start_logits, end_logits = logits.split(1, dim=-1)
1706
+ start_logits = start_logits.squeeze(-1).contiguous()
1707
+ end_logits = end_logits.squeeze(-1).contiguous()
1708
+
1709
+ total_loss = None
1710
+ if start_positions is not None and end_positions is not None:
1711
+ # If we are on multi-GPU, splitting the batch adds a dimension; squeeze it away
1712
+ if len(start_positions.size()) > 1:
1713
+ start_positions = start_positions.squeeze(-1)
1714
+ if len(end_positions.size()) > 1:
1715
+ end_positions = end_positions.squeeze(-1)
1716
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1717
+ ignored_index = start_logits.size(1)
1718
+ start_positions = start_positions.clamp(0, ignored_index)
1719
+ end_positions = end_positions.clamp(0, ignored_index)
1720
+
1721
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1722
+ start_loss = loss_fct(start_logits, start_positions)
1723
+ end_loss = loss_fct(end_logits, end_positions)
1724
+ total_loss = (start_loss + end_loss) / 2
1725
+
1726
+ if not return_dict:
1727
+ output = (
1728
+ start_logits,
1729
+ end_logits,
1730
+ ) + outputs[1:]
1731
+ return ((total_loss,) + output) if total_loss is not None else output
1732
+
1733
+ return Seq2SeqQuestionAnsweringModelOutput(
1734
+ loss=total_loss,
1735
+ start_logits=start_logits,
1736
+ end_logits=end_logits,
1737
+ past_key_values=outputs.past_key_values,
1738
+ decoder_hidden_states=outputs.decoder_hidden_states,
1739
+ decoder_attentions=outputs.decoder_attentions,
1740
+ cross_attentions=outputs.cross_attentions,
1741
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1742
+ encoder_hidden_states=outputs.encoder_hidden_states,
1743
+ encoder_attentions=outputs.encoder_attentions,
1744
+ )
1745
+
1746
+
1747
+ class BartDecoderWrapper(BartPreTrainedModel):
1748
+ """
1749
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
1750
+ used in combination with the [`EncoderDecoderModel`] framework.
1751
+ """
1752
+
1753
+ def __init__(self, config):
1754
+ super().__init__(config)
1755
+ self.decoder = BartDecoder(config)
1756
+
1757
+ def forward(self, *args, **kwargs):
1758
+ return self.decoder(*args, **kwargs)
1759
+
1760
+
1761
+ @add_start_docstrings(
1762
+ """
1763
+ BART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
1764
+ """,
1765
+ BART_START_DOCSTRING,
1766
+ )
1767
+ class BartForCausalLM(BartPreTrainedModel):
1768
+ _tied_weights_keys = ["lm_head.weight"]
1769
+
1770
+ def __init__(self, config):
1771
+ config = copy.deepcopy(config)
1772
+ config.is_decoder = True
1773
+ config.is_encoder_decoder = False
1774
+ super().__init__(config)
1775
+ self.model = BartDecoderWrapper(config)
1776
+
1777
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1778
+
1779
+ # Initialize weights and apply final processing
1780
+ self.post_init()
1781
+
1782
+ def get_input_embeddings(self):
1783
+ return self.model.decoder.embed_tokens
1784
+
1785
+ def set_input_embeddings(self, value):
1786
+ self.model.decoder.embed_tokens = value
1787
+
1788
+ def get_output_embeddings(self):
1789
+ return self.lm_head
1790
+
1791
+ def set_output_embeddings(self, new_embeddings):
1792
+ self.lm_head = new_embeddings
1793
+
1794
+ def set_decoder(self, decoder):
1795
+ self.model.decoder = decoder
1796
+
1797
+ def get_decoder(self):
1798
+ return self.model.decoder
1799
+
1800
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1801
+ def forward(
1802
+ self,
1803
+ input_ids: torch.LongTensor = None,
1804
+ attention_mask: Optional[torch.Tensor] = None,
1805
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1806
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1807
+ head_mask: Optional[torch.Tensor] = None,
1808
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1809
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1810
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1811
+ labels: Optional[torch.LongTensor] = None,
1812
+ use_cache: Optional[bool] = None,
1813
+ output_attentions: Optional[bool] = None,
1814
+ output_hidden_states: Optional[bool] = None,
1815
+ return_dict: Optional[bool] = None,
1816
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1817
+ r"""
1818
+ Args:
1819
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1820
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1821
+ provide it.
1822
+
1823
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1824
+ [`PreTrainedTokenizer.__call__`] for details.
1825
+
1826
+ [What are input IDs?](../glossary#input-ids)
1827
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1828
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1829
+
1830
+ - 1 for tokens that are **not masked**,
1831
+ - 0 for tokens that are **masked**.
1832
+
1833
+ [What are attention masks?](../glossary#attention-mask)
1834
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1835
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1836
+ if the model is configured as a decoder.
1837
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1838
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
1839
+ in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
1840
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1841
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1842
+
1843
+ - 1 indicates the head is **not masked**,
1844
+ - 0 indicates the head is **masked**.
1845
+
1846
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1847
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
1848
+
1849
+ - 1 indicates the head is **not masked**,
1850
+ - 0 indicates the head is **masked**.
1851
+
1852
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1853
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1854
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
1855
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1856
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
1857
+
1858
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1859
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1860
+
1861
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1862
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1863
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1864
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1865
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1866
+ config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1866
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
1868
+ use_cache (`bool`, *optional*):
1869
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1870
+ (see `past_key_values`).
1874
+ output_attentions (`bool`, *optional*):
1875
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1876
+ returned tensors for more detail.
1877
+ output_hidden_states (`bool`, *optional*):
1878
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1879
+ for more detail.
1880
+ return_dict (`bool`, *optional*):
1881
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1882
+
1883
+ Returns:
1884
+
1885
+ Example:
1886
+
1887
+ ```python
1888
+ >>> from transformers import AutoTokenizer, BartForCausalLM
1889
+
1890
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
1891
+ >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
1892
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
1893
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1894
+ >>> outputs = model(**inputs)
1895
+
1896
+ >>> logits = outputs.logits
1897
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
1898
+ >>> list(logits.shape) == expected_shape
1899
+ True
1900
+ ```"""
1901
+
1902
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1903
+ output_hidden_states = (
1904
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1905
+ )
1906
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1907
+
1908
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1909
+ outputs = self.model.decoder(
1910
+ input_ids=input_ids,
1911
+ attention_mask=attention_mask,
1912
+ encoder_hidden_states=encoder_hidden_states,
1913
+ encoder_attention_mask=encoder_attention_mask,
1914
+ head_mask=head_mask,
1915
+ cross_attn_head_mask=cross_attn_head_mask,
1916
+ past_key_values=past_key_values,
1917
+ inputs_embeds=inputs_embeds,
1918
+ use_cache=use_cache,
1919
+ output_attentions=output_attentions,
1920
+ output_hidden_states=output_hidden_states,
1921
+ return_dict=return_dict,
1922
+ )
1923
+
1924
+ logits = self.lm_head(outputs[0])
1925
+
1926
+ loss = None
1927
+ if labels is not None:
1928
+ labels = labels.to(logits.device)
1929
+ loss_fct = CrossEntropyLoss()
1930
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1931
+
1932
+ if not return_dict:
1933
+ output = (logits,) + outputs[1:]
1934
+ return (loss,) + output if loss is not None else output
1935
+
1936
+ return CausalLMOutputWithCrossAttentions(
1937
+ loss=loss,
1938
+ logits=logits,
1939
+ past_key_values=outputs.past_key_values,
1940
+ hidden_states=outputs.hidden_states,
1941
+ attentions=outputs.attentions,
1942
+ cross_attentions=outputs.cross_attentions,
1943
+ )
1944
+
1945
+ def prepare_inputs_for_generation(
1946
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
1947
+ ):
1948
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1949
+ if attention_mask is None:
1950
+ attention_mask = input_ids.new_ones(input_ids.shape)
1951
+
1952
+ if past_key_values:
1953
+ input_ids = input_ids[:, -1:]
1954
+ # first step, decoder_cached_states are empty
1955
+ return {
1956
+ "input_ids": input_ids, # decoder-only model: these are the decoder tokens (only the last one when a cache is used)
1957
+ "attention_mask": attention_mask,
1958
+ "past_key_values": past_key_values,
1959
+ "use_cache": use_cache,
1960
+ }
1961
+
1962
+ @staticmethod
1963
+ def _reorder_cache(past_key_values, beam_idx):
1964
+ reordered_past = ()
1965
+ for layer_past in past_key_values:
1966
+ reordered_past += (
1967
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1968
+ )
1969
+ return reordered_past
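
Note on the `labels` path in `BartForConditionalGeneration.forward` above: when `labels` are given without `decoder_input_ids`, the decoder inputs are derived by calling `shift_tokens_right`. That helper is defined outside this part of the diff, so the following is only a minimal, list-based sketch of what it is assumed to do (mirroring the upstream transformers helper: shift one position right, prepend the decoder start token, replace `-100` with the pad token). The function name `shift_tokens_right_lists` and the token ids used are hypothetical, and rows are assumed to be non-empty.

```python
from typing import List

IGNORE_INDEX = -100  # label value that the loss skips


def shift_tokens_right_lists(
    labels: List[List[int]], pad_token_id: int, decoder_start_token_id: int
) -> List[List[int]]:
    """Hypothetical list-based illustration of shifting labels one step right."""
    shifted = []
    for row in labels:
        # Drop the last label and prepend the decoder start token.
        new_row = [decoder_start_token_id] + row[:-1]
        # -100 is not a valid token id, so ignored positions become padding.
        new_row = [pad_token_id if tok == IGNORE_INDEX else tok for tok in new_row]
        shifted.append(new_row)
    return shifted


# Two label rows; the second one has ignored (-100) positions at the end.
labels = [[10, 11, 12, 2], [20, 2, -100, -100]]
decoder_input_ids = shift_tokens_right_lists(labels, pad_token_id=1, decoder_start_token_id=2)
print(decoder_input_ids)
```

With this shift, the decoder input at step `t` is the label from step `t - 1`, which is why `forward` can compute the loss directly between `lm_logits` and the unshifted `labels`.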