vilsonrodrigues committed
Commit 0e7ea20
1 Parent(s): 5f278b3

deprecated

Files changed (1)
  1. modelling_RW.py +0 -1095
modelling_RW.py DELETED
@@ -1,1095 +0,0 @@
1
- # port of models described in RW
2
- # We use the bloom model as a starting point for these models.
3
- # Please refer to the bloom models for usage instructions.
4
-
5
- import math
6
- import warnings
7
- from typing import Optional, Tuple, Union
8
-
9
- import torch
10
- import torch.utils.checkpoint
11
- from torch import nn
12
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
13
- from torch.nn import functional as F
14
-
15
- from transformers.modeling_outputs import (
16
- BaseModelOutputWithPastAndCrossAttentions,
17
- CausalLMOutputWithCrossAttentions,
18
- QuestionAnsweringModelOutput,
19
- SequenceClassifierOutputWithPast,
20
- TokenClassifierOutput,
21
- )
22
- from transformers.modeling_utils import PreTrainedModel
23
- from transformers.utils import logging
24
- from .configuration_RW import RWConfig
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
- # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training; this means that there is one additional quantization to bfloat16 between the operations.
29
- # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
30
- class Linear(nn.Linear):
31
- def forward(self, input: torch.Tensor) -> torch.Tensor:
32
- ret = input @ self.weight.T
33
- if self.bias is None:
34
- return ret
35
- else:
36
- return ret + self.bias
37
-
38
-
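
To make the note above concrete, here is a small, self-contained sketch (an illustration of the effect, not part of the original file) comparing the fused `F.linear` path with the explicit matmul-plus-bias path that the custom `Linear` reproduces; depending on the backend the two can differ by a bfloat16 rounding step.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randn(64, 64, dtype=torch.bfloat16)
b = torch.randn(64, dtype=torch.bfloat16)

fused = F.linear(x, w, b)   # matmul and bias handled in a single call
unfused = (x @ w.T) + b     # matmul result materialized (rounded to bfloat16) before the bias add
print((fused.float() - unfused.float()).abs().max())  # zero or a few ULPs, backend-dependent
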
39
- from einops import rearrange
40
-
41
- # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
42
- def rotate_half(x):
43
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
44
- return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
45
-
46
-
47
- class RotaryEmbedding(torch.nn.Module):
48
- """Implementation of RotaryEmbedding from GPT-NeoX.
49
- This implementation is designed to operate on queries and keys that are compatible with
50
- [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
51
- """
52
-
53
- def __init__(
54
- self,
55
- head_dim: int,
56
- base=10000,
57
- ):
58
- super().__init__()
59
- inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
60
- self.register_buffer("inv_freq", inv_freq, persistent=False)
61
- self.head_dim = head_dim
62
- self.seq_len_cached = None
63
- self.batch_size_cached = None
64
- self.cos_cached: Optional[torch.Tensor] = None
65
- self.sin_cached: Optional[torch.Tensor] = None
66
-
67
- def cos_sin(
68
- self,
69
- seq_len: int,
70
- device="cuda",
71
- dtype=torch.bfloat16,
72
- ) -> Tuple[torch.Tensor, torch.Tensor]:
73
- if seq_len != self.seq_len_cached:
74
- self.seq_len_cached = seq_len
75
- t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
77
- emb = torch.cat((freqs, freqs), dim=-1).to(device)
78
-
79
- if dtype in [torch.float16, torch.bfloat16]:
80
- emb = emb.float()
81
-
82
- self.cos_cached = emb.cos()[None, :, :]
83
- self.sin_cached = emb.sin()[None, :, :]
84
-
85
- self.cos_cached = self.cos_cached.type(dtype)
86
- self.sin_cached = self.sin_cached.type(dtype)
87
-
88
- return self.cos_cached, self.sin_cached
89
-
90
- def forward(self, q, k):
91
- batch, seq_len, head_dim = q.shape
92
- cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
- return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
-
95
-
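
A usage sketch of the rotary module (assuming `torch` and the `RotaryEmbedding` class above are in scope): `Attention.forward` below calls it on queries and keys flattened to `[batch_size * num_heads, seq_len, head_dim]`.

import torch

rotary = RotaryEmbedding(head_dim=64)
q = torch.randn(2 * 8, 16, 64)  # batch_size=2, num_heads=8, seq_len=16, head_dim=64
k = torch.randn(2 * 8, 16, 64)
q_rot, k_rot = rotary(q, k)     # same shapes back, with position information rotated into q and k
assert q_rot.shape == q.shape and k_rot.shape == k.shape
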
96
- def _make_causal_mask(
97
- input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
98
- ) -> torch.BoolTensor:
99
- batch_size, target_length = input_ids_shape
100
- mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
- # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
- seq_ids = torch.arange(target_length, device=device)
103
- mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
-
105
- if past_key_values_length > 0:
106
- mask[:, :past_key_values_length] = False
107
-
108
- expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
- return expanded_mask
110
-
111
-
112
- def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
113
- batch_size, src_length = mask.shape
114
- tgt_length = tgt_length if tgt_length is not None else src_length
115
-
116
- expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
117
- return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
118
-
119
-
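
An illustrative check of the two mask helpers above (not from the original file): `True` entries mark positions that will be masked out.

import torch

causal = _make_causal_mask(torch.Size([1, 4]), device=torch.device("cpu"), past_key_values_length=0)
print(causal[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)

padding = _expand_mask(torch.tensor([[1, 1, 0, 0]]), tgt_length=4)  # last two tokens are padding
print(padding[0, 0, 0])
# tensor([False, False,  True,  True])
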
120
- def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
121
- batch_size, seq_length = attention_mask.shape
122
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
123
- base = torch.tensor(
124
- 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
125
- )
126
- powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
127
- slopes = torch.pow(base, powers)
128
-
129
- if closest_power_of_2 != num_heads:
130
- extra_base = torch.tensor(
131
- 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
132
- )
133
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
134
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
135
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
136
-
137
- # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
138
- # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
139
- # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
140
- # => the query_length dimension will then be broadcasted correctly
141
- # This is more or less identical to T5's relative position bias:
142
- # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
143
- arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
144
- alibi = slopes[..., None].bfloat16() * arange_tensor
145
- return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
146
-
147
-
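
A quick shape check of the ALiBi bias (illustrative, assuming `build_alibi_tensor` above is in scope): each head gets its own geometric slope, and position 0 always contributes zero bias.

import torch

attn_mask = torch.ones(2, 5)  # batch_size=2, seq_length=5, no padding
alibi = build_alibi_tensor(attn_mask, num_heads=4, dtype=torch.float32)
print(alibi.shape)            # torch.Size([8, 1, 5]) == (batch_size * num_heads, 1, seq_length)
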
148
- def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
149
- out = F.dropout(x, p=prob, training=training)
150
- out = residual + out
151
- return out
152
-
153
-
154
- class Attention(nn.Module):
155
- def __init__(self, config: RWConfig):
156
- super().__init__()
157
-
158
- self.hidden_size = config.hidden_size
159
- self.num_heads = config.n_head
160
- self.head_dim = self.hidden_size // self.num_heads
161
- self.split_size = self.hidden_size
162
- self.hidden_dropout = config.hidden_dropout
163
-
164
- if self.head_dim * self.num_heads != self.hidden_size:
165
- raise ValueError(
166
- f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
167
- f" {self.num_heads})."
168
- )
169
-
170
- self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
171
-
172
- # Layer-wise attention scaling
173
- self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
174
- self.beta = self.inv_norm_factor
175
-
176
- self.query_key_value = Linear(
177
- self.hidden_size,
178
- 3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
179
- bias=config.bias,
180
- )
181
- self.multi_query = config.multi_query
182
- self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
183
- self.attention_dropout = nn.Dropout(config.attention_dropout)
184
- self.num_kv = config.n_head if not self.multi_query else 1
185
-
186
- def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187
- """
188
- Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
189
- storage as `fused_qkv`
190
- Args:
191
- fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
192
- Returns:
193
- query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
194
- value: [batch_size, seq_length, num_heads, head_dim]
195
- """
196
- if not self.multi_query:
197
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
198
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
199
- return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
200
- else:
201
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
202
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
203
- return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
204
-
205
- def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
206
- """
207
- Merge heads together over the last dimension
208
- Args:
209
- x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
210
- Returns:
211
- torch.tensor: [batch_size, seq_length, num_heads * head_dim]
212
- """
213
- # What we want to achieve is:
214
- # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
215
- batch_size_and_num_heads, seq_length, _ = x.shape
216
- batch_size = batch_size_and_num_heads // self.num_heads
217
-
218
- # First view to decompose the batch size
219
- # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
220
- x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
221
-
222
- # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
223
- x = x.permute(0, 2, 1, 3)
224
-
225
- # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
226
- return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
227
-
228
- def forward(
229
- self,
230
- hidden_states: torch.Tensor,
231
- alibi: torch.Tensor,
232
- attention_mask: torch.Tensor,
233
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
234
- head_mask: Optional[torch.Tensor] = None,
235
- use_cache: bool = False,
236
- output_attentions: bool = False,
237
- ):
238
- fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
239
-
240
- # 3 x [batch_size, seq_length, num_heads, head_dim]
241
- (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
242
-
243
- batch_size, q_length, _, _ = query_layer.shape
244
-
245
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
246
- key_layer = key_layer.transpose(1, 2).reshape(
247
- batch_size * self.num_kv,
248
- q_length,
249
- self.head_dim,
250
- )
251
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
252
-
253
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
254
-
255
- if layer_past is not None:
256
- past_key, past_value = layer_past
257
- # concatenate along seq_length dimension:
258
- # - key: [batch_size * self.num_heads, head_dim, kv_length]
259
- # - value: [batch_size * self.num_heads, kv_length, head_dim]
260
- key_layer = torch.cat((past_key, key_layer), dim=1)
261
- value_layer = torch.cat((past_value, value_layer), dim=1)
262
-
263
- _, kv_length, _ = key_layer.shape
264
-
265
- if use_cache is True:
266
- present = (key_layer, value_layer)
267
- else:
268
- present = None
269
-
270
- if alibi is None:
271
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
272
- key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
273
- value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
274
-
275
- attn_output = F.scaled_dot_product_attention(
276
- query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
277
- )
278
-
279
- x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
280
- x = x.permute(0, 2, 1, 3)
281
- attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
282
-
283
- output_tensor = self.dense(attn_output)
284
-
285
- outputs = (output_tensor, present)
286
- assert not output_attentions # not supported.
287
- return outputs
288
- else:
289
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
290
- matmul_result = query_layer @ key_layer.transpose(-1, -2)
291
-
292
- # change view to [batch_size, num_heads, q_length, kv_length]
293
- attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
294
-
295
- # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
296
- input_dtype = attention_scores.dtype
297
- # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
298
- if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
299
- attention_scores = attention_scores.to(torch.float32)
300
- # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
301
- attention_probs = F.softmax(
302
- (attention_scores + alibi) * self.inv_norm_factor + attention_mask_float,
303
- dim=-1,
304
- dtype=hidden_states.dtype,
305
- )
306
- # [batch_size, num_heads, q_length, kv_length]
307
- attention_probs = self.attention_dropout(attention_probs)
308
-
309
- if head_mask is not None:
310
- attention_probs = attention_probs * head_mask
311
-
312
- # change view [batch_size x num_heads, q_length, kv_length]
313
- attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
314
-
315
- # matmul: [batch_size * num_heads, q_length, head_dim]
316
- context_layer = attention_probs_reshaped @ value_layer
317
-
318
- # change view [batch_size, num_heads, q_length, head_dim]
319
- context_layer = self._merge_heads(context_layer)
320
-
321
- output_tensor = self.dense(context_layer)
322
-
323
- outputs = (output_tensor, present)
324
- if output_attentions:
325
- outputs += (attention_probs,)
326
-
327
- return outputs
328
-
329
-
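
To visualize the multi-query split performed by `_split_heads` above, a standalone sketch with made-up sizes: every query head is kept, while a single key head and a single value head are shared across all of them.

import torch

hidden_size, num_heads = 128, 8
head_dim = hidden_size // num_heads
fused = torch.randn(2, 10, hidden_size + 2 * head_dim)  # [batch, seq, q heads + shared k + shared v]
qkv = fused.view(2, 10, num_heads + 2, head_dim)
q, k, v = qkv[..., :-2, :], qkv[..., [-2], :], qkv[..., [-1], :]
print(q.shape, k.shape, v.shape)
# torch.Size([2, 10, 8, 16]) torch.Size([2, 10, 1, 16]) torch.Size([2, 10, 1, 16])
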
330
- class MLP(nn.Module):
331
- def __init__(self, config: RWConfig):
332
- super().__init__()
333
- hidden_size = config.hidden_size
334
-
335
- self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
336
- self.act = nn.GELU()
337
- self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
338
- self.hidden_dropout = config.hidden_dropout
339
-
340
- def forward(self, x: torch.Tensor) -> torch.Tensor:
341
- x = self.act(self.dense_h_to_4h(x))
342
- x = self.dense_4h_to_h(x)
343
- return x
344
-
345
-
346
- class DecoderLayer(nn.Module):
347
- def __init__(self, config: RWConfig):
348
- super().__init__()
349
- hidden_size = config.hidden_size
350
-
351
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
352
- self.num_heads = config.n_head
353
- self.self_attention = Attention(config)
354
-
355
- if not config.parallel_attn:
356
- # unused if parallel attn
357
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
358
-
359
- self.mlp = MLP(config)
360
-
361
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
362
- self.hidden_dropout = config.hidden_dropout
363
-
364
- self.config = config
365
-
366
- def forward(
367
- self,
368
- hidden_states: torch.Tensor,
369
- alibi: torch.Tensor,
370
- attention_mask: torch.Tensor,
371
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
372
- head_mask: Optional[torch.Tensor] = None,
373
- use_cache: bool = False,
374
- output_attentions: bool = False,
375
- ):
376
-
377
- layernorm_output = self.input_layernorm(hidden_states)
378
- residual = hidden_states
379
-
380
- # Self attention.
381
- attn_outputs = self.self_attention(
382
- layernorm_output,
383
- layer_past=layer_past,
384
- attention_mask=attention_mask,
385
- alibi=alibi,
386
- head_mask=head_mask,
387
- use_cache=use_cache,
388
- output_attentions=output_attentions,
389
- )
390
-
391
- attention_output = attn_outputs[0]
392
-
393
- if not self.config.parallel_attn:
394
- residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
395
- layernorm_output = self.post_attention_layernorm(residual)
396
-
397
- outputs = attn_outputs[1:]
398
-
399
- # MLP.
400
- mlp_output = self.mlp(layernorm_output)
401
-
402
- if self.config.parallel_attn:
403
- mlp_output += attention_output
404
-
405
- output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
406
-
407
- if use_cache:
408
- outputs = (output,) + outputs
409
- else:
410
- outputs = (output,) + outputs[1:]
411
-
412
- return outputs # hidden_states, present, attentions
413
-
414
-
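
Schematically, the two residual layouts handled by `DecoderLayer.forward` reduce to the following (a summary of the code above written as comments; the dropout probabilities follow the config fields used in the code).

# parallel_attn=True (single layernorm; attention and MLP share one residual):
#     ln_out = input_layernorm(x)
#     out    = x + dropout(attention(ln_out) + mlp(ln_out))
#
# parallel_attn=False (classic post-attention layernorm):
#     x   = x + dropout(attention(input_layernorm(x)))
#     out = x + dropout(mlp(post_attention_layernorm(x)))
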
415
- class RWPreTrainedModel(PreTrainedModel):
416
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
417
- """
418
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
419
- models.
420
- """
421
-
422
- config_class = RWConfig
423
- base_model_prefix = "transformer"
424
- supports_gradient_checkpointing = True
425
- _no_split_modules = ["DecoderLayer"]
426
-
427
- def __init__(self, *inputs, **kwargs):
428
- super().__init__(*inputs, **kwargs)
429
-
430
- def _init_weights(self, module: nn.Module):
431
- """Initialize the weights."""
432
- if isinstance(module, nn.Linear) or isinstance(module, Linear):
433
- # Slightly different from the TF version which uses truncated_normal for initialization
434
- # cf https://github.com/pytorch/pytorch/pull/5617
435
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
436
- if module.bias is not None:
437
- module.bias.data.zero_()
438
- elif isinstance(module, nn.Embedding):
439
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
440
- if module.padding_idx is not None:
441
- module.weight.data[module.padding_idx].zero_()
442
- elif isinstance(module, LayerNorm):
443
- module.bias.data.zero_()
444
- module.weight.data.fill_(1.0)
445
-
446
- def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
447
- if isinstance(module, RWModel):
448
- module.gradient_checkpointing = value
449
-
450
- @staticmethod
451
- def _convert_to_standard_cache(
452
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
453
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
454
- """
455
- Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
456
- num_heads, ...]))
457
- """
458
- batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
459
- num_heads = batch_size_times_num_heads // batch_size
460
- # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
461
- # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
462
- return tuple(
463
- (
464
- layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
465
- layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
466
- )
467
- for layer_past in past_key_value
468
- )
469
-
470
- @staticmethod
471
- def _convert_to_rw_cache(
472
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
473
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
474
- batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
475
- batch_size_times_num_heads = batch_size * num_heads
476
- # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
477
- # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
478
- return tuple(
479
- (
480
- layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
481
- layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
482
- )
483
- for layer_past in past_key_value
484
- )
485
-
486
-
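
A round-trip sketch of the two cache layouts these helpers translate between (illustrative; shapes follow the comments above, with `RWPreTrainedModel` in scope).

import torch

batch, heads, head_dim, seq = 2, 4, 16, 5
rw_cache = (
    (torch.randn(batch * heads, head_dim, seq),   # key in RW layout
     torch.randn(batch * heads, seq, head_dim)),  # value in RW layout
)
std_cache = RWPreTrainedModel._convert_to_standard_cache(rw_cache, batch_size=batch)
print(std_cache[0][0].shape)  # torch.Size([2, 4, 16, 5])
back = RWPreTrainedModel._convert_to_rw_cache(std_cache)
assert back[0][0].shape == (batch * heads, head_dim, seq)
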
487
- class RWModel(RWPreTrainedModel):
488
- def __init__(self, config: RWConfig):
489
- super().__init__(config)
490
-
491
- self.embed_dim = config.hidden_size
492
- self.num_heads = config.n_head
493
- self.alibi = config.alibi
494
-
495
- # Embedding + LN Embedding
496
- self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
497
-
498
- # Transformer blocks
499
- self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
500
-
501
- # Final Layer Norm
502
- self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
503
-
504
- self.gradient_checkpointing = False
505
-
506
- # Initialize weights and apply final processing
507
- self.post_init()
508
-
509
- def get_input_embeddings(self):
510
- return self.word_embeddings
511
-
512
- def _prepare_attn_mask(
513
- self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
514
- ) -> torch.BoolTensor:
515
- # create causal mask
516
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
517
- combined_attention_mask = None
518
- device = attention_mask.device
519
- _, src_length = input_shape
520
-
521
- if src_length > 1:
522
- combined_attention_mask = _make_causal_mask(
523
- input_shape, device=device, past_key_values_length=past_key_values_length
524
- )
525
-
526
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
527
- expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
528
- combined_attention_mask = (
529
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
530
- )
531
-
532
- return combined_attention_mask
533
-
534
- def set_input_embeddings(self, new_embeddings: torch.Tensor):
535
- self.word_embeddings = new_embeddings
536
-
537
- def forward(
538
- self,
539
- input_ids: Optional[torch.LongTensor] = None,
540
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
541
- attention_mask: Optional[torch.Tensor] = None,
542
- head_mask: Optional[torch.LongTensor] = None,
543
- inputs_embeds: Optional[torch.LongTensor] = None,
544
- use_cache: Optional[bool] = None,
545
- output_attentions: Optional[bool] = None,
546
- output_hidden_states: Optional[bool] = None,
547
- return_dict: Optional[bool] = None,
548
- **deprecated_arguments,
549
- ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
550
- if deprecated_arguments.pop("position_ids", False) is not False:
551
- # `position_ids` could have been `torch.Tensor` or `None`, so defaulting `pop` to `False` lets us detect whether users explicitly passed `None`
552
- warnings.warn(
553
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
554
- " passing `position_ids`.",
555
- FutureWarning,
556
- )
557
- if len(deprecated_arguments) > 0:
558
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
559
-
560
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
561
- output_hidden_states = (
562
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
563
- )
564
- use_cache = use_cache if use_cache is not None else self.config.use_cache
565
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
566
-
567
- if input_ids is not None and inputs_embeds is not None:
568
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
569
- elif input_ids is not None:
570
- batch_size, seq_length = input_ids.shape
571
- elif inputs_embeds is not None:
572
- batch_size, seq_length, _ = inputs_embeds.shape
573
- else:
574
- raise ValueError("You have to specify either input_ids or inputs_embeds")
575
-
576
- if past_key_values is None:
577
- past_key_values = tuple([None] * len(self.h))
578
-
579
- # Prepare head mask if needed
580
- # 1.0 in head_mask indicates we keep the head
581
- # attention_probs has shape batch_size x num_heads x N x N
582
- # head_mask has shape n_layer x batch x num_heads x N x N
583
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
584
-
585
- if inputs_embeds is None:
586
- inputs_embeds = self.word_embeddings(input_ids)
587
-
588
- hidden_states = inputs_embeds
589
-
590
- presents = () if use_cache else None
591
- all_self_attentions = () if output_attentions else None
592
- all_hidden_states = () if output_hidden_states else None
593
-
594
- # Compute alibi tensor: check build_alibi_tensor documentation
595
- seq_length_with_past = seq_length
596
- past_key_values_length = 0
597
- if past_key_values[0] is not None:
598
- past_key_values_length = past_key_values[0][0].shape[2]
599
- seq_length_with_past = seq_length_with_past + past_key_values_length
600
- if attention_mask is None:
601
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
602
- else:
603
- attention_mask = attention_mask.to(hidden_states.device)
604
-
605
- if self.alibi:
606
- alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
607
- else:
608
- alibi = None
609
-
610
- causal_mask = self._prepare_attn_mask(
611
- attention_mask,
612
- input_shape=(batch_size, seq_length),
613
- past_key_values_length=past_key_values_length,
614
- )
615
-
616
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
617
-
618
- if output_hidden_states:
619
- all_hidden_states = all_hidden_states + (hidden_states,)
620
-
621
- if self.gradient_checkpointing and self.training:
622
-
623
- if use_cache:
624
- logger.warning(
625
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
626
- )
627
- use_cache = False
628
-
629
- def create_custom_forward(module):
630
- def custom_forward(*inputs):
631
- # None for past_key_value
632
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
633
-
634
- return custom_forward
635
-
636
- outputs = torch.utils.checkpoint.checkpoint(
637
- create_custom_forward(block),
638
- hidden_states,
639
- alibi,
640
- causal_mask,
641
- head_mask[i],
642
- )
643
- else:
644
- outputs = block(
645
- hidden_states,
646
- layer_past=layer_past,
647
- attention_mask=causal_mask,
648
- head_mask=head_mask[i],
649
- use_cache=use_cache,
650
- output_attentions=output_attentions,
651
- alibi=alibi,
652
- )
653
-
654
- hidden_states = outputs[0]
655
- if use_cache is True:
656
- presents = presents + (outputs[1],)
657
-
658
- if output_attentions:
659
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
660
-
661
- # Add last hidden state
662
- hidden_states = self.ln_f(hidden_states)
663
-
664
- if output_hidden_states:
665
- all_hidden_states = all_hidden_states + (hidden_states,)
666
-
667
- if not return_dict:
668
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
669
-
670
- return BaseModelOutputWithPastAndCrossAttentions(
671
- last_hidden_state=hidden_states,
672
- past_key_values=presents,
673
- hidden_states=all_hidden_states,
674
- attentions=all_self_attentions,
675
- )
676
-
677
-
678
- class RWForCausalLM(RWPreTrainedModel):
679
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
680
-
681
- def __init__(self, config: RWConfig):
682
- super().__init__(config)
683
- self.transformer = RWModel(config)
684
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
685
-
686
- # Initialize weights and apply final processing
687
- self.post_init()
688
-
689
- def get_output_embeddings(self):
690
- return self.lm_head
691
-
692
- def set_output_embeddings(self, new_embeddings: torch.Tensor):
693
- self.lm_head = new_embeddings
694
-
695
- def prepare_inputs_for_generation(
696
- self,
697
- input_ids: torch.LongTensor,
698
- past: Optional[torch.Tensor] = None,
699
- attention_mask: Optional[torch.Tensor] = None,
700
- **kwargs,
701
- ) -> dict:
702
- # only last token for input_ids if past is not None
703
- if past:
704
- input_ids = input_ids[:, -1].unsqueeze(-1)
705
-
706
- # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
707
- if past[0][0].shape[0] == input_ids.shape[0]:
708
- past = self._convert_to_rw_cache(past)
709
-
710
- return {
711
- "input_ids": input_ids,
712
- "past_key_values": past,
713
- "use_cache": kwargs.get("use_cache"),
714
- "attention_mask": attention_mask,
715
- }
716
-
717
- def forward(
718
- self,
719
- input_ids: Optional[torch.LongTensor] = None,
720
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
721
- attention_mask: Optional[torch.Tensor] = None,
722
- head_mask: Optional[torch.Tensor] = None,
723
- inputs_embeds: Optional[torch.Tensor] = None,
724
- labels: Optional[torch.Tensor] = None,
725
- use_cache: Optional[bool] = None,
726
- output_attentions: Optional[bool] = None,
727
- output_hidden_states: Optional[bool] = None,
728
- return_dict: Optional[bool] = None,
729
- **deprecated_arguments,
730
- ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
731
- r"""
732
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
733
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
734
- `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
735
- are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
736
- """
737
- if deprecated_arguments.pop("position_ids", False) is not False:
738
- # `position_ids` could have been `torch.Tensor` or `None`, so defaulting `pop` to `False` lets us detect whether users explicitly passed `None`
739
- warnings.warn(
740
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
741
- " passing `position_ids`.",
742
- FutureWarning,
743
- )
744
- if len(deprecated_arguments) > 0:
745
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
746
-
747
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
748
-
749
- transformer_outputs = self.transformer(
750
- input_ids,
751
- past_key_values=past_key_values,
752
- attention_mask=attention_mask,
753
- head_mask=head_mask,
754
- inputs_embeds=inputs_embeds,
755
- use_cache=use_cache,
756
- output_attentions=output_attentions,
757
- output_hidden_states=output_hidden_states,
758
- return_dict=return_dict,
759
- )
760
- hidden_states = transformer_outputs[0]
761
-
762
- lm_logits = self.lm_head(hidden_states)
763
-
764
- loss = None
765
- if labels is not None:
766
- # Shift so that tokens < n predict n
767
- shift_logits = lm_logits[..., :-1, :].contiguous()
768
- shift_labels = labels[..., 1:].contiguous()
769
- batch_size, seq_length, vocab_size = shift_logits.shape
770
- # Flatten the tokens
771
- loss_fct = CrossEntropyLoss()
772
- loss = loss_fct(
773
- shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
774
- )
775
-
776
- if not return_dict:
777
- output = (lm_logits,) + transformer_outputs[1:]
778
- return ((loss,) + output) if loss is not None else output
779
-
780
- return CausalLMOutputWithCrossAttentions(
781
- loss=loss,
782
- logits=lm_logits,
783
- past_key_values=transformer_outputs.past_key_values,
784
- hidden_states=transformer_outputs.hidden_states,
785
- attentions=transformer_outputs.attentions,
786
- )
787
-
788
- def _reorder_cache(
789
- self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
790
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
791
- """
792
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
793
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
794
- beam_idx at every generation step.
795
- Output shares the same memory storage as `past`.
796
- """
797
- standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
798
-
799
- # Get a copy of `beam_idx` on all the devices where we need those indices.
800
- device_to_beam_idx = {
801
- past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
802
- }
803
- reordered_past = tuple(
804
- (
805
- layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
806
- layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
807
- )
808
- for layer_past in standardized_past
809
- )
810
- return self._convert_to_rw_cache(reordered_past)
811
-
812
-
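
How the causal-LM loss in `forward` above lines up logits and labels (an illustrative sketch of the shift, not taken from the file): the logits at position `i` are scored against the token at position `i + 1`.

import torch
import torch.nn.functional as F

vocab_size = 50
lm_logits = torch.randn(1, 6, vocab_size)  # [batch, seq, vocab]
labels = torch.randint(0, vocab_size, (1, 6))
shift_logits = lm_logits[..., :-1, :]      # predictions for positions 0..4
shift_labels = labels[..., 1:]             # targets are tokens 1..5
loss = F.cross_entropy(shift_logits.reshape(-1, vocab_size), shift_labels.reshape(-1))
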
813
- class RWForSequenceClassification(RWPreTrainedModel):
814
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
815
-
816
- def __init__(self, config: RWConfig):
817
- super().__init__(config)
818
- self.num_labels = config.num_labels
819
- self.transformer = RWModel(config)
820
- self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
821
-
822
- # Initialize weights and apply final processing
823
- self.post_init()
824
-
825
- def forward(
826
- self,
827
- input_ids: Optional[torch.LongTensor] = None,
828
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
829
- attention_mask: Optional[torch.Tensor] = None,
830
- head_mask: Optional[torch.Tensor] = None,
831
- inputs_embeds: Optional[torch.Tensor] = None,
832
- labels: Optional[torch.Tensor] = None,
833
- use_cache: Optional[bool] = None,
834
- output_attentions: Optional[bool] = None,
835
- output_hidden_states: Optional[bool] = None,
836
- return_dict: Optional[bool] = None,
837
- **deprecated_arguments,
838
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
839
- r"""
840
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
841
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
842
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
843
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
844
- """
845
- if deprecated_arguments.pop("position_ids", False) is not False:
846
- # `position_ids` could have been `torch.Tensor` or `None`, so defaulting `pop` to `False` lets us detect whether users explicitly passed `None`
847
- warnings.warn(
848
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
849
- " passing `position_ids`.",
850
- FutureWarning,
851
- )
852
- if len(deprecated_arguments) > 0:
853
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
854
-
855
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
856
-
857
- transformer_outputs = self.transformer(
858
- input_ids,
859
- past_key_values=past_key_values,
860
- attention_mask=attention_mask,
861
- head_mask=head_mask,
862
- inputs_embeds=inputs_embeds,
863
- use_cache=use_cache,
864
- output_attentions=output_attentions,
865
- output_hidden_states=output_hidden_states,
866
- return_dict=return_dict,
867
- )
868
-
869
- hidden_states = transformer_outputs[0]
870
- logits = self.score(hidden_states)
871
-
872
- if input_ids is not None:
873
- batch_size = input_ids.shape[0]
874
- else:
875
- batch_size = inputs_embeds.shape[0]
876
-
877
- if self.config.pad_token_id is None and batch_size != 1:
878
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
879
- if self.config.pad_token_id is None:
880
- sequence_lengths = -1
881
- else:
882
- if input_ids is not None:
883
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
884
- else:
885
- sequence_lengths = -1
886
- logger.warning(
887
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
888
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
889
- )
890
-
891
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
892
-
893
- loss = None
894
- if labels is not None:
895
- if self.config.problem_type is None:
896
- if self.num_labels == 1:
897
- self.config.problem_type = "regression"
898
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
899
- self.config.problem_type = "single_label_classification"
900
- else:
901
- self.config.problem_type = "multi_label_classification"
902
-
903
- if self.config.problem_type == "regression":
904
- loss_fct = MSELoss()
905
- if self.num_labels == 1:
906
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
907
- else:
908
- loss = loss_fct(pooled_logits, labels)
909
- elif self.config.problem_type == "single_label_classification":
910
- loss_fct = CrossEntropyLoss()
911
- loss = loss_fct(pooled_logits, labels)
912
- elif self.config.problem_type == "multi_label_classification":
913
- loss_fct = BCEWithLogitsLoss()
914
- loss = loss_fct(pooled_logits, labels)
915
- if not return_dict:
916
- output = (pooled_logits,) + transformer_outputs[1:]
917
- return ((loss,) + output) if loss is not None else output
918
-
919
- return SequenceClassifierOutputWithPast(
920
- loss=loss,
921
- logits=pooled_logits,
922
- past_key_values=transformer_outputs.past_key_values,
923
- hidden_states=transformer_outputs.hidden_states,
924
- attentions=transformer_outputs.attentions,
925
- )
926
-
927
-
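
The pooling rule in `RWForSequenceClassification.forward` above picks the logits of the last non-padding token in each row; a tiny sketch (assuming a pad token id of 0 for illustration).

import torch

input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 1, 2, 3]])
sequence_lengths = torch.ne(input_ids, 0).sum(dim=-1) - 1
print(sequence_lengths)  # tensor([2, 4]) -> index of the last real token in each row
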
928
- class RWForTokenClassification(RWPreTrainedModel):
929
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
930
-
931
- def __init__(self, config: RWConfig):
932
- super().__init__(config)
933
- self.num_labels = config.num_labels
934
-
935
- self.transformer = RWModel(config)
936
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
937
- classifier_dropout = config.classifier_dropout
938
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
939
- classifier_dropout = config.hidden_dropout
940
- else:
941
- classifier_dropout = 0.1
942
- self.dropout = nn.Dropout(classifier_dropout)
943
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
944
-
945
- # Initialize weights and apply final processing
946
- self.post_init()
947
-
948
- def forward(
949
- self,
950
- input_ids: Optional[torch.LongTensor] = None,
951
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
952
- attention_mask: Optional[torch.Tensor] = None,
953
- head_mask: Optional[torch.Tensor] = None,
954
- inputs_embeds: Optional[torch.Tensor] = None,
955
- labels: Optional[torch.Tensor] = None,
956
- use_cache: Optional[bool] = None,
957
- output_attentions: Optional[bool] = None,
958
- output_hidden_states: Optional[bool] = None,
959
- return_dict: Optional[bool] = None,
960
- **deprecated_arguments,
961
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
962
- r"""
963
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
964
- Labels for computing the token classification loss. Indices should be in `[0, ...,
965
- config.num_labels - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
966
- labels in `[0, ..., config.num_labels - 1]`.
967
- """
968
- if deprecated_arguments.pop("position_ids", False) is not False:
969
- # `position_ids` could have been `torch.Tensor` or `None`, so defaulting `pop` to `False` lets us detect whether users explicitly passed `None`
970
- warnings.warn(
971
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
972
- " passing `position_ids`.",
973
- FutureWarning,
974
- )
975
- if len(deprecated_arguments) > 0:
976
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
977
-
978
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
979
-
980
- transformer_outputs = self.transformer(
981
- input_ids,
982
- past_key_values=past_key_values,
983
- attention_mask=attention_mask,
984
- head_mask=head_mask,
985
- inputs_embeds=inputs_embeds,
986
- use_cache=use_cache,
987
- output_attentions=output_attentions,
988
- output_hidden_states=output_hidden_states,
989
- return_dict=return_dict,
990
- )
991
-
992
- hidden_states = transformer_outputs[0]
993
- hidden_states = self.dropout(hidden_states)
994
- logits = self.classifier(hidden_states)
995
-
996
- loss = None
997
- if labels is not None:
998
- batch_size, seq_length = labels.shape
999
- loss_fct = CrossEntropyLoss()
1000
- loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
1001
-
1002
- if not return_dict:
1003
- output = (logits,) + transformer_outputs[2:]
1004
- return ((loss,) + output) if loss is not None else output
1005
-
1006
- return TokenClassifierOutput(
1007
- loss=loss,
1008
- logits=logits,
1009
- hidden_states=transformer_outputs.hidden_states,
1010
- attentions=transformer_outputs.attentions,
1011
- )
1012
-
1013
-
1014
- class RWForQuestionAnswering(RWPreTrainedModel):
1015
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1016
-
1017
- def __init__(self, config):
1018
- super().__init__(config)
1019
- self.transformer = RWModel(config)
1020
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1021
-
1022
- # Initialize weights and apply final processing
1023
- self.post_init()
1024
-
1025
- def forward(
1026
- self,
1027
- input_ids: Optional[torch.LongTensor] = None,
1028
- attention_mask: Optional[torch.FloatTensor] = None,
1029
- position_ids: Optional[torch.LongTensor] = None,
1030
- head_mask: Optional[torch.FloatTensor] = None,
1031
- inputs_embeds: Optional[torch.FloatTensor] = None,
1032
- start_positions: Optional[torch.LongTensor] = None,
1033
- end_positions: Optional[torch.LongTensor] = None,
1034
- output_attentions: Optional[bool] = None,
1035
- output_hidden_states: Optional[bool] = None,
1036
- return_dict: Optional[bool] = None,
1037
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1038
- r"""
1039
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1040
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1041
- Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1042
- are not taken into account for computing the loss.
1043
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1044
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1045
- Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1046
- are not taken into account for computing the loss.
1047
- """
1048
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1049
-
1050
- outputs = self.transformer(
1051
- input_ids,
1052
- attention_mask=attention_mask,
1053
- position_ids=position_ids,
1054
- head_mask=head_mask,
1055
- inputs_embeds=inputs_embeds,
1056
- output_attentions=output_attentions,
1057
- output_hidden_states=output_hidden_states,
1058
- return_dict=return_dict,
1059
- )
1060
-
1061
- sequence_output = outputs[0]
1062
-
1063
- logits = self.qa_outputs(sequence_output)
1064
- start_logits, end_logits = logits.split(1, dim=-1)
1065
- start_logits = start_logits.squeeze(-1).contiguous()
1066
- end_logits = end_logits.squeeze(-1).contiguous()
1067
-
1068
- total_loss = None
1069
- if start_positions is not None and end_positions is not None:
1070
- # If we are on multi-GPU, splitting may add an extra dimension; squeeze it away
1071
- if len(start_positions.size()) > 1:
1072
- start_positions = start_positions.squeeze(-1)
1073
- if len(end_positions.size()) > 1:
1074
- end_positions = end_positions.squeeze(-1)
1075
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1076
- ignored_index = start_logits.size(1)
1077
- start_positions = start_positions.clamp(0, ignored_index)
1078
- end_positions = end_positions.clamp(0, ignored_index)
1079
-
1080
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1081
- start_loss = loss_fct(start_logits, start_positions)
1082
- end_loss = loss_fct(end_logits, end_positions)
1083
- total_loss = (start_loss + end_loss) / 2
1084
-
1085
- if not return_dict:
1086
- output = (start_logits, end_logits) + outputs[2:]
1087
- return ((total_loss,) + output) if total_loss is not None else output
1088
-
1089
- return QuestionAnsweringModelOutput(
1090
- loss=total_loss,
1091
- start_logits=start_logits,
1092
- end_logits=end_logits,
1093
- hidden_states=outputs.hidden_states,
1094
- attentions=outputs.attentions,
1095
- )
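
Finally, a shape-only sketch of how the QA head above turns its two output channels into start and end logits (illustrative, with made-up sizes).

import torch
from torch import nn

sequence_output = torch.randn(2, 12, 64)             # [batch, seq, hidden]
qa_outputs = nn.Linear(64, 2)
logits = qa_outputs(sequence_output)                 # [batch, seq, 2]
start_logits, end_logits = logits.split(1, dim=-1)   # two [batch, seq, 1] tensors
start_logits = start_logits.squeeze(-1)              # [batch, seq]
end_logits = end_logits.squeeze(-1)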