chenwuml committed on
Commit
9102d37
1 Parent(s): 0bcc5d8

Upload modelling_RW.py with huggingface_hub

Files changed (1)
  1. modelling_RW.py +1257 -0
modelling_RW.py ADDED
@@ -0,0 +1,1257 @@
1
+ # port of models described in RW
2
+ # We use the bloom model as a starting point for these models.
3
+ # Please refer to the bloom models for usage instructions.
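+ #
+ # A minimal loading sketch (assumes this file ships in a Hub repo next to configuration_RW.py;
+ # the repo id below is a placeholder):
+ #   from transformers import AutoModelForCausalLM, AutoTokenizer
+ #   tokenizer = AutoTokenizer.from_pretrained("org/rw-checkpoint")
+ #   model = AutoModelForCausalLM.from_pretrained("org/rw-checkpoint", trust_remote_code=True)
+ #   output = model.generate(**tokenizer("Hello", return_tensors="pt"), max_new_tokens=20)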
4
+
5
+ import math
6
+ import warnings
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
13
+ from torch.nn import functional as F
14
+
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPastAndCrossAttentions,
17
+ CausalLMOutputWithCrossAttentions,
18
+ QuestionAnsweringModelOutput,
19
+ SequenceClassifierOutputWithPast,
20
+ TokenClassifierOutput,
21
+ )
22
+ from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import logging
24
+ from .configuration_RW import RWConfig
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
30
+ # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
31
+ class Linear(nn.Linear):
32
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
33
+ ret = input @ self.weight.T
34
+ if self.bias is None:
35
+ return ret
36
+ else:
37
+ return ret + self.bias
38
+
39
+
40
+ from einops import rearrange
41
+
42
+
43
+ # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
44
+ def rotate_half(x):
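+ # Rotates the two halves of the last dimension; e.g. [x0, x1, x2, x3] -> [-x2, -x3, x0, x1].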
45
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
46
+ return torch.cat(
47
+ (-x2, x1), dim=x1.ndim - 1
48
+ ) # dim=-1 triggers a bug in torch < 1.8.0
49
+
50
+
51
+ class RotaryEmbedding(torch.nn.Module):
52
+ """Implementation of RotaryEmbedding from GPT-NeoX.
53
+ This implementation is designed to operate on queries and keys that are compatible with
54
+ [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ head_dim: int,
60
+ base=1000000,
61
+ ):
62
+ super().__init__()
63
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
64
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
65
+ self.head_dim = head_dim
66
+ self.seq_len_cached = None
67
+ self.batch_size_cached = None
68
+ self.cos_cached: torch.Tensor | None = None
69
+ self.sin_cached: torch.Tensor | None = None
70
+ self.max_position_embeddings = 2048
71
+ self.base = base
72
+
73
+ def cos_sin(
74
+ self,
75
+ seq_len: int,
76
+ device="cuda",
77
+ dtype=torch.bfloat16,
78
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
79
+ if seq_len != self.seq_len_cached:
80
+ self.seq_len_cached = seq_len
81
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
82
+
83
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
84
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
87
+ emb = torch.cat((freqs, freqs), dim=-1).to(device)
88
+
89
+ if dtype in [torch.float16, torch.bfloat16]:
90
+ emb = emb.float()
91
+
92
+ self.cos_cached = emb.cos()[None, :, :]
93
+ self.sin_cached = emb.sin()[None, :, :]
94
+
95
+ self.cos_cached = self.cos_cached.type(dtype)
96
+ self.sin_cached = self.sin_cached.type(dtype)
97
+
98
+ return self.cos_cached, self.sin_cached
99
+
100
+ def forward(self, q, k):
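+ # q and k arrive here already flattened to [batch_size * num_heads, seq_len, head_dim]
+ # (see Attention.forward), so `batch` below is really batch_size * num_heads.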
101
+ batch, seq_len, head_dim = q.shape
102
+ cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
103
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
104
+
105
+
106
+ def _make_causal_mask(
107
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
108
+ ) -> torch.BoolTensor:
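+ # Builds a [batch_size, 1, target_length, target_length + past_key_values_length] boolean mask
+ # where True marks future positions that may not be attended to, e.g. for target_length=3 and
+ # no past: [[F, T, T], [F, F, T], [F, F, F]].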
109
+ batch_size, target_length = input_ids_shape
110
+ mask = torch.empty(
111
+ (target_length, target_length + past_key_values_length),
112
+ dtype=torch.bool,
113
+ device=device,
114
+ )
115
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
116
+ seq_ids = torch.arange(target_length, device=device)
117
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
118
+
119
+ if past_key_values_length > 0:
120
+ mask[:, :past_key_values_length] = False
121
+
122
+ expanded_mask = mask[None, None, :, :].expand(
123
+ batch_size, 1, target_length, target_length + past_key_values_length
124
+ )
125
+ return expanded_mask
126
+
127
+
128
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
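+ # Expands a [batch_size, src_length] padding mask (1 = keep, 0 = pad) into a broadcastable
+ # [batch_size, 1, tgt_length, src_length] boolean mask where True marks positions to mask out.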
129
+ batch_size, src_length = mask.shape
130
+ tgt_length = tgt_length if tgt_length is not None else src_length
131
+
132
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
133
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
134
+
135
+
136
+ def build_alibi_tensor(
137
+ attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
138
+ ) -> torch.Tensor:
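+ # Per-head ALiBi slopes form a geometric sequence, e.g. for num_heads=4 they are
+ # 2**-2, 2**-4, 2**-6, 2**-8; each slope scales the key positions to build the bias that is
+ # later added to the attention scores. Output shape: [batch_size * num_heads, 1, seq_length].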
139
+ batch_size, seq_length = attention_mask.shape
140
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
141
+ base = torch.tensor(
142
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
143
+ device=attention_mask.device,
144
+ dtype=torch.float32,
145
+ )
146
+ powers = torch.arange(
147
+ 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
148
+ )
149
+ slopes = torch.pow(base, powers)
150
+
151
+ if closest_power_of_2 != num_heads:
152
+ extra_base = torch.tensor(
153
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
154
+ device=attention_mask.device,
155
+ dtype=torch.float32,
156
+ )
157
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
158
+ extra_powers = torch.arange(
159
+ 1,
160
+ 1 + 2 * num_remaining_heads,
161
+ 2,
162
+ device=attention_mask.device,
163
+ dtype=torch.int32,
164
+ )
165
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
166
+
167
+ # Note: alibi will be added to the attention bias that will be applied to the query-key product of attention
168
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
169
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
170
+ # => the query_length dimension will then be broadcasted correctly
171
+ # This is more or less identical to T5's relative position bias:
172
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
173
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
174
+ alibi = slopes[..., None].bfloat16() * arange_tensor
175
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
176
+
177
+
178
+ def dropout_add(
179
+ x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
180
+ ) -> torch.Tensor:
181
+ out = F.dropout(x, p=prob, training=training)
182
+ out = residual + out
183
+ return out
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, config: RWConfig):
188
+ super().__init__()
189
+
190
+ self.hidden_size = config.hidden_size
191
+ self.num_heads = config.n_head
192
+ self.head_dim = self.hidden_size // self.num_heads
193
+ self.split_size = self.hidden_size
194
+ self.hidden_dropout = config.hidden_dropout
195
+
196
+ if self.head_dim * self.num_heads != self.hidden_size:
197
+ raise ValueError(
198
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
199
+ f" {self.num_heads})."
200
+ )
201
+
202
+ self.maybe_rotary = (
203
+ RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
204
+ )
205
+
206
+ # Layer-wise attention scaling
207
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
208
+ self.beta = self.inv_norm_factor
209
+
210
+ self.query_key_value = Linear(
211
+ self.hidden_size,
212
+ (config.n_head_kv * 2 + config.n_head) * self.head_dim,
213
+ bias=config.bias,
214
+ )
215
+ self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
216
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
217
+ self.num_kv = config.n_head_kv
218
+
219
+ def _split_heads(
220
+ self, fused_qkv: torch.Tensor
221
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
222
+ """
223
+ Split the last dimension into (num_heads, head_dim), results share same memory
224
+ storage as `fused_qkv`
225
+
226
+ Args:
227
+ fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, (num_heads + 2 * num_kv) * head_dim]
228
+
229
+ Returns:
230
+ query: [batch_size, seq_length, num_heads, head_dim]
231
+ key: [batch_size, seq_length, num_heads, head_dim]
232
+ value: [batch_size, seq_length, num_heads, head_dim]
233
+ """
234
+ batch, seq_len, _ = fused_qkv.shape
235
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, self.head_dim)
236
+ q = qkv[:, :, :, :-2]
237
+ k = qkv[:, :, :, [-2]]
238
+ v = qkv[:, :, :, [-1]]
239
+ k = torch.broadcast_to(k, q.shape)
240
+ v = torch.broadcast_to(v, q.shape)
241
+
242
+ q, k, v = [
243
+ rearrange(
244
+ x,
245
+ "batch seq_len group num_heads head_dim ->\
246
+ batch seq_len (group num_heads) head_dim",
247
+ head_dim=self.head_dim,
248
+ )
249
+ for x in [q, k, v]
250
+ ]
251
+ return q, k, v
252
+
253
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
254
+ """
255
+ Merge heads together over the last dimension
256
+
257
+ Args:
258
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
259
+
260
+ Returns:
261
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
262
+ """
263
+ # What we want to achieve is:
264
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
265
+ batch_size_and_num_heads, seq_length, _ = x.shape
266
+ batch_size = batch_size_and_num_heads // self.num_heads
267
+
268
+ # First view to decompose the batch size
269
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
270
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
271
+
272
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
273
+ x = x.permute(0, 2, 1, 3)
274
+
275
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
276
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
277
+
278
+ def forward(
279
+ self,
280
+ hidden_states: torch.Tensor,
281
+ alibi: torch.Tensor,
282
+ attention_mask: torch.Tensor,
283
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
284
+ head_mask: Optional[torch.Tensor] = None,
285
+ use_cache: bool = False,
286
+ output_attentions: bool = False,
287
+ ):
288
+ fused_qkv = self.query_key_value(
289
+ hidden_states
290
+ ) # [batch_size, seq_length, 3 x hidden_size]
291
+
292
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
293
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
294
+
295
+ batch_size, q_length, _, _ = query_layer.shape
296
+
297
+ query_layer = query_layer.transpose(1, 2).reshape(
298
+ batch_size * self.num_heads, q_length, self.head_dim
299
+ )
300
+ key_layer = key_layer.transpose(1, 2).reshape(
301
+ batch_size * self.num_heads,
302
+ q_length,
303
+ self.head_dim,
304
+ )
305
+ value_layer = value_layer.transpose(1, 2).reshape(
306
+ batch_size * self.num_heads, q_length, self.head_dim
307
+ )
308
+
309
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
310
+
311
+ if layer_past is not None:
312
+ past_key, past_value = layer_past
313
+ # concatenate along seq_length dimension:
314
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
315
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
316
+ key_layer = torch.cat((past_key, key_layer), dim=1)
317
+ value_layer = torch.cat((past_value, value_layer), dim=1)
318
+
319
+ _, kv_length, _ = key_layer.shape
320
+
321
+ if use_cache is True:
322
+ present = (key_layer, value_layer)
323
+ else:
324
+ present = None
325
+
326
+ if alibi is None:
327
+ query_layer_ = query_layer.reshape(
328
+ batch_size, self.num_heads, -1, self.head_dim
329
+ )
330
+ key_layer_ = key_layer.reshape(
331
+ batch_size, self.num_heads, -1, self.head_dim
332
+ )
333
+ value_layer_ = value_layer.reshape(
334
+ batch_size, self.num_heads, -1, self.head_dim
335
+ )
336
+
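+ # PyTorch SDPA path (rotary / no-alibi case, requires torch >= 2.0); note that no padding mask
+ # is passed here, so masking is purely causal via is_causal=True.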
337
+ attn_output = F.scaled_dot_product_attention(
338
+ query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
339
+ )
340
+
341
+ x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
342
+ x = x.permute(0, 2, 1, 3)
343
+ attn_output = x.reshape(
344
+ batch_size, q_length, self.num_heads * self.head_dim
345
+ )
346
+
347
+ output_tensor = self.dense(attn_output)
348
+
349
+ outputs = (output_tensor, present)
350
+ assert not output_attentions # not supported.
351
+ return outputs
352
+ else:
353
+ attention_mask_float = (
354
+ (attention_mask * 1.0)
355
+ .masked_fill(attention_mask, -1e9)
356
+ .to(torch.bfloat16)
357
+ )
358
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
359
+
360
+ # change view to [batch_size, num_heads, q_length, kv_length]
361
+ attention_scores = matmul_result.view(
362
+ batch_size, self.num_heads, q_length, kv_length
363
+ )
364
+
365
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
366
+ input_dtype = attention_scores.dtype
367
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
368
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
369
+ attention_scores = attention_scores.to(torch.float32)
370
+ # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
371
+ attention_probs = F.softmax(
372
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
373
+ * self.inv_norm_factor
374
+ + attention_mask_float,
375
+ dim=-1,
376
+ dtype=hidden_states.dtype,
377
+ )
378
+ # [batch_size, num_heads, q_length, kv_length]
379
+ attention_probs = self.attention_dropout(attention_probs)
380
+
381
+ if head_mask is not None:
382
+ attention_probs = attention_probs * head_mask
383
+
384
+ # change view [batch_size x num_heads, q_length, kv_length]
385
+ attention_probs_reshaped = attention_probs.view(
386
+ batch_size * self.num_heads, q_length, kv_length
387
+ )
388
+
389
+ # matmul: [batch_size * num_heads, q_length, head_dim]
390
+ context_layer = attention_probs_reshaped @ value_layer
391
+
392
+ # change view [batch_size, num_heads, q_length, head_dim]
393
+ context_layer = self._merge_heads(context_layer)
394
+
395
+ output_tensor = self.dense(context_layer)
396
+
397
+ outputs = (output_tensor, present)
398
+ if output_attentions:
399
+ outputs += (attention_probs,)
400
+
401
+ return outputs
402
+
403
+
404
+ class MLP(nn.Module):
405
+ def __init__(self, config: RWConfig):
406
+ super().__init__()
407
+ hidden_size = config.hidden_size
408
+
409
+ self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
410
+ self.act = nn.GELU()
411
+ self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
412
+ self.hidden_dropout = config.hidden_dropout
413
+
414
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
415
+ x = self.act(self.dense_h_to_4h(x))
416
+ x = self.dense_4h_to_h(x)
417
+ return x
418
+
419
+
420
+ class DecoderLayer(nn.Module):
421
+ def __init__(self, config: RWConfig):
422
+ super().__init__()
423
+ hidden_size = config.hidden_size
424
+
425
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
426
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
+
428
+ self.num_heads = config.n_head
429
+ self.self_attention = Attention(config)
430
+
431
+ self.mlp = MLP(config)
432
+
433
+ self.apply_residual_connection_post_layernorm = (
434
+ config.apply_residual_connection_post_layernorm
435
+ )
436
+ self.hidden_dropout = config.hidden_dropout
437
+
438
+ self.config = config
439
+
440
+ def forward(
441
+ self,
442
+ hidden_states: torch.Tensor,
443
+ alibi: torch.Tensor,
444
+ attention_mask: torch.Tensor,
445
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
446
+ head_mask: Optional[torch.Tensor] = None,
447
+ use_cache: bool = False,
448
+ output_attentions: bool = False,
449
+ ):
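+ # Parallel attention/MLP block: both branches read layernorms of the same input and their
+ # outputs are summed into a single residual connection below (via dropout_add).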
450
+ ln_attn = self.ln_attn(hidden_states)
451
+ ln_mlp = self.ln_mlp(hidden_states)
452
+
453
+ residual = hidden_states
454
+
455
+ # Self attention.
456
+ attn_outputs = self.self_attention(
457
+ ln_attn,
458
+ layer_past=layer_past,
459
+ attention_mask=attention_mask,
460
+ alibi=alibi,
461
+ head_mask=head_mask,
462
+ use_cache=use_cache,
463
+ output_attentions=output_attentions,
464
+ )
465
+
466
+ attention_output = attn_outputs[0]
467
+
468
+ outputs = attn_outputs[1:]
469
+
470
+ # MLP.
471
+ mlp_output = self.mlp(ln_mlp)
472
+
473
+ output = dropout_add(
474
+ mlp_output + attention_output,
475
+ residual,
476
+ self.config.hidden_dropout,
477
+ training=self.training,
478
+ )
479
+
480
+ if use_cache:
481
+ outputs = (output,) + outputs
482
+ else:
483
+ outputs = (output,) + outputs[1:]
484
+
485
+ return outputs # hidden_states, present, attentions
486
+
487
+
488
+ class RWPreTrainedModel(PreTrainedModel):
489
+ _keys_to_ignore_on_load_missing = [
490
+ r"h.*.self_attention.scale_mask_softmax.causal_mask",
491
+ r"lm_head.weight",
492
+ ]
493
+ """
494
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
495
+ models.
496
+ """
497
+
498
+ config_class = RWConfig
499
+ base_model_prefix = "transformer"
500
+ supports_gradient_checkpointing = True
501
+ _no_split_modules = ["DecoderLayer"]
502
+
503
+ def __init__(self, *inputs, **kwargs):
504
+ super().__init__(*inputs, **kwargs)
505
+
506
+ def _init_weights(self, module: nn.Module):
507
+ """Initialize the weights."""
508
+ if isinstance(module, nn.Linear) or isinstance(module, Linear):
509
+ # Slightly different from the TF version which uses truncated_normal for initialization
510
+ # cf https://github.com/pytorch/pytorch/pull/5617
511
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
512
+ if module.bias is not None:
513
+ module.bias.data.zero_()
514
+ elif isinstance(module, nn.Embedding):
515
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
516
+ if module.padding_idx is not None:
517
+ module.weight.data[module.padding_idx].zero_()
518
+ elif isinstance(module, LayerNorm):
519
+ module.bias.data.zero_()
520
+ module.weight.data.fill_(1.0)
521
+
522
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
523
+ if isinstance(module, RWModel):
524
+ module.gradient_checkpointing = value
525
+
526
+ @staticmethod
527
+ def _convert_to_standard_cache(
528
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
529
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
530
+ """
531
+ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
532
+ num_heads, ...]))
533
+ """
534
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
535
+ num_heads = batch_size_times_num_heads // batch_size
536
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
537
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
538
+ return tuple(
539
+ (
540
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
541
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
542
+ )
543
+ for layer_past in past_key_value
544
+ )
545
+
546
+ @staticmethod
547
+ def _convert_to_rw_cache(
548
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
549
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
550
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
551
+ batch_size_times_num_heads = batch_size * num_heads
552
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
553
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
554
+ return tuple(
555
+ (
556
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
557
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
558
+ )
559
+ for layer_past in past_key_value
560
+ )
561
+
562
+
563
+ class RWModel(RWPreTrainedModel):
564
+ def __init__(self, config: RWConfig):
565
+ super().__init__(config)
566
+
567
+ self.embed_dim = config.hidden_size
568
+ self.num_heads = config.n_head
569
+ self.alibi = config.alibi
570
+
571
+ # Embedding + LN Embedding
572
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
573
+
574
+ # Transformer blocks
575
+ self.h = nn.ModuleList(
576
+ [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
577
+ )
578
+
579
+ # Final Layer Norm
580
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
581
+
582
+ self.gradient_checkpointing = False
583
+
584
+ # Initialize weights and apply final processing
585
+ self.post_init()
586
+
587
+ def get_input_embeddings(self):
588
+ return self.word_embeddings
589
+
590
+ def _prepare_attn_mask(
591
+ self,
592
+ attention_mask: torch.Tensor,
593
+ input_shape: Tuple[int, int],
594
+ past_key_values_length: int,
595
+ ) -> torch.BoolTensor:
596
+ # create causal mask
597
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
598
+ combined_attention_mask = None
599
+ device = attention_mask.device
600
+ _, src_length = input_shape
601
+
602
+ if src_length > 1:
603
+ combined_attention_mask = _make_causal_mask(
604
+ input_shape,
605
+ device=device,
606
+ past_key_values_length=past_key_values_length,
607
+ )
608
+
609
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
610
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
611
+ combined_attention_mask = (
612
+ expanded_attn_mask
613
+ if combined_attention_mask is None
614
+ else expanded_attn_mask | combined_attention_mask
615
+ )
616
+
617
+ return combined_attention_mask
618
+
619
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
620
+ self.word_embeddings = new_embeddings
621
+
622
+ def forward(
623
+ self,
624
+ input_ids: Optional[torch.LongTensor] = None,
625
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
626
+ attention_mask: Optional[torch.Tensor] = None,
627
+ head_mask: Optional[torch.LongTensor] = None,
628
+ inputs_embeds: Optional[torch.LongTensor] = None,
629
+ use_cache: Optional[bool] = None,
630
+ output_attentions: Optional[bool] = None,
631
+ output_hidden_states: Optional[bool] = None,
632
+ return_dict: Optional[bool] = None,
633
+ **deprecated_arguments,
634
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
635
+ if deprecated_arguments.pop("position_ids", False) is not False:
636
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
637
+ warnings.warn(
638
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
639
+ " passing `position_ids`.",
640
+ FutureWarning,
641
+ )
642
+ if len(deprecated_arguments) > 0:
643
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
644
+
645
+ output_attentions = (
646
+ output_attentions
647
+ if output_attentions is not None
648
+ else self.config.output_attentions
649
+ )
650
+ output_hidden_states = (
651
+ output_hidden_states
652
+ if output_hidden_states is not None
653
+ else self.config.output_hidden_states
654
+ )
655
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
656
+ return_dict = (
657
+ return_dict if return_dict is not None else self.config.use_return_dict
658
+ )
659
+
660
+ if input_ids is not None and inputs_embeds is not None:
661
+ raise ValueError(
662
+ "You cannot specify both input_ids and inputs_embeds at the same time"
663
+ )
664
+ elif input_ids is not None:
665
+ batch_size, seq_length = input_ids.shape
666
+ elif inputs_embeds is not None:
667
+ batch_size, seq_length, _ = inputs_embeds.shape
668
+ else:
669
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
670
+
671
+ if past_key_values is None:
672
+ past_key_values = tuple([None] * len(self.h))
673
+
674
+ # Prepare head mask if needed
675
+ # 1.0 in head_mask indicate we keep the head
676
+ # attention_probs has shape batch_size x num_heads x N x N
677
+ # head_mask has shape n_layer x batch x num_heads x N x N
678
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
679
+
680
+ if inputs_embeds is None:
681
+ inputs_embeds = self.word_embeddings(input_ids)
682
+
683
+ hidden_states = inputs_embeds
684
+
685
+ presents = () if use_cache else None
686
+ all_self_attentions = () if output_attentions else None
687
+ all_hidden_states = () if output_hidden_states else None
688
+
689
+ # Compute alibi tensor: check build_alibi_tensor documentation
690
+ seq_length_with_past = seq_length
691
+ past_key_values_length = 0
692
+ if past_key_values[0] is not None:
693
+ past_key_values_length = past_key_values[0][0].shape[2]
694
+ seq_length_with_past = seq_length_with_past + past_key_values_length
695
+ if attention_mask is None:
696
+ attention_mask = torch.ones(
697
+ (batch_size, seq_length_with_past), device=hidden_states.device
698
+ )
699
+ else:
700
+ attention_mask = attention_mask.to(hidden_states.device)
701
+
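+ # Position information comes from one mechanism at a time: ALiBi biases built here when
+ # config.alibi is set, otherwise positions are expected to come from the rotary embeddings
+ # applied inside Attention (config.rotary).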
702
+ if self.alibi:
703
+ alibi = build_alibi_tensor(
704
+ attention_mask, self.num_heads, dtype=hidden_states.dtype
705
+ )
706
+ else:
707
+ alibi = None
708
+
709
+ causal_mask = self._prepare_attn_mask(
710
+ attention_mask,
711
+ input_shape=(batch_size, seq_length),
712
+ past_key_values_length=past_key_values_length,
713
+ )
714
+
715
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
716
+ if output_hidden_states:
717
+ all_hidden_states = all_hidden_states + (hidden_states,)
718
+
719
+ if self.gradient_checkpointing and self.training:
720
+ if use_cache:
721
+ logger.warning(
722
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
723
+ )
724
+ use_cache = False
725
+
726
+ def create_custom_forward(module):
727
+ def custom_forward(*inputs):
728
+ # None for past_key_value
729
+ return module(
730
+ *inputs,
731
+ use_cache=use_cache,
732
+ output_attentions=output_attentions,
733
+ )
734
+
735
+ return custom_forward
736
+
737
+ outputs = torch.utils.checkpoint.checkpoint(
738
+ create_custom_forward(block),
739
+ hidden_states,
740
+ alibi,
741
+ causal_mask,
742
+ head_mask[i],
743
+ )
744
+ else:
745
+ outputs = block(
746
+ hidden_states,
747
+ layer_past=layer_past,
748
+ attention_mask=causal_mask,
749
+ head_mask=head_mask[i],
750
+ use_cache=use_cache,
751
+ output_attentions=output_attentions,
752
+ alibi=alibi,
753
+ )
754
+
755
+ hidden_states = outputs[0]
756
+ if use_cache is True:
757
+ presents = presents + (outputs[1],)
758
+
759
+ if output_attentions:
760
+ all_self_attentions = all_self_attentions + (
761
+ outputs[2 if use_cache else 1],
762
+ )
763
+
764
+ # Add last hidden state
765
+ hidden_states = self.ln_f(hidden_states)
766
+
767
+ if output_hidden_states:
768
+ all_hidden_states = all_hidden_states + (hidden_states,)
769
+
770
+ if not return_dict:
771
+ return tuple(
772
+ v
773
+ for v in [
774
+ hidden_states,
775
+ presents,
776
+ all_hidden_states,
777
+ all_self_attentions,
778
+ ]
779
+ if v is not None
780
+ )
781
+
782
+ return BaseModelOutputWithPastAndCrossAttentions(
783
+ last_hidden_state=hidden_states,
784
+ past_key_values=presents,
785
+ hidden_states=all_hidden_states,
786
+ attentions=all_self_attentions,
787
+ )
788
+
789
+
790
+ class RWForCausalLM(RWPreTrainedModel):
791
+ _keys_to_ignore_on_load_missing = [
792
+ r"h.*.self_attention.scale_mask_softmax.causal_mask",
793
+ r"lm_head.weight",
794
+ ]
795
+
796
+ def __init__(self, config: RWConfig):
797
+ super().__init__(config)
798
+ self.transformer = RWModel(config)
799
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
800
+
801
+ # Initialize weights and apply final processing
802
+ self.post_init()
803
+
804
+ def get_output_embeddings(self):
805
+ return self.lm_head
806
+
807
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
808
+ self.lm_head = new_embeddings
809
+
810
+ def prepare_inputs_for_generation(
811
+ self,
812
+ input_ids: torch.LongTensor,
813
+ past: Optional[torch.Tensor] = None,
814
+ attention_mask: Optional[torch.Tensor] = None,
815
+ **kwargs,
816
+ ) -> dict:
817
+ # only last token for input_ids if past is not None
818
+ if past:
819
+ input_ids = input_ids[:, -1].unsqueeze(-1)
820
+
821
+ # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
822
+ if past[0][0].shape[0] == input_ids.shape[0]:
823
+ past = self._convert_to_rw_cache(past)
824
+
825
+ return {
826
+ "input_ids": input_ids,
827
+ "past_key_values": past,
828
+ "use_cache": kwargs.get("use_cache"),
829
+ "attention_mask": attention_mask,
830
+ }
831
+
832
+ def forward(
833
+ self,
834
+ input_ids: Optional[torch.LongTensor] = None,
835
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
836
+ attention_mask: Optional[torch.Tensor] = None,
837
+ head_mask: Optional[torch.Tensor] = None,
838
+ inputs_embeds: Optional[torch.Tensor] = None,
839
+ labels: Optional[torch.Tensor] = None,
840
+ use_cache: Optional[bool] = None,
841
+ output_attentions: Optional[bool] = None,
842
+ output_hidden_states: Optional[bool] = None,
843
+ return_dict: Optional[bool] = None,
844
+ input_tokens: Optional[torch.LongTensor] = None,
845
+ **deprecated_arguments,
846
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
847
+ r"""
848
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
849
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
850
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
851
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
852
+ """
853
+ if deprecated_arguments.pop("position_ids", False) is not False:
854
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
855
+ warnings.warn(
856
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
857
+ " passing `position_ids`.",
858
+ FutureWarning,
859
+ )
860
+ if len(deprecated_arguments) > 0:
861
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
862
+
863
+ return_dict = (
864
+ return_dict if return_dict is not None else self.config.use_return_dict
865
+ )
866
+
867
+ transformer_outputs = self.transformer(
868
+ input_ids,
869
+ past_key_values=past_key_values,
870
+ attention_mask=attention_mask,
871
+ head_mask=head_mask,
872
+ inputs_embeds=inputs_embeds,
873
+ use_cache=use_cache,
874
+ output_attentions=output_attentions,
875
+ output_hidden_states=output_hidden_states,
876
+ return_dict=return_dict,
877
+ )
878
+ hidden_states = transformer_outputs[0]
879
+
880
+ lm_logits = self.lm_head(hidden_states)
881
+
882
+ loss = None
883
+ if labels is not None:
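+ # `input_tokens`, when provided, appears to be the prompt length: logits and labels are sliced
+ # from that offset so the loss is computed only on the generated continuation.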
884
+ if input_tokens is not None:
885
+ stt = input_tokens
886
+ # Shift so that tokens < n predict n, and compute the loss only past the prompt prefix
887
+ shift_logits = lm_logits[..., stt:-1, :].contiguous()
888
+ shift_labels = labels[..., stt + 1:].contiguous()
889
+ else:
890
+ # Shift so that tokens < n predict n
891
+ shift_logits = lm_logits[..., :-1, :].contiguous()
892
+ shift_labels = labels[..., 1:].contiguous()
893
+ batch_size, seq_length, vocab_size = shift_logits.shape
894
+
895
+ # Flatten the tokens
896
+ loss_fct = CrossEntropyLoss()
897
+ loss = loss_fct(
898
+ shift_logits.view(batch_size * seq_length, vocab_size),
899
+ shift_labels.view(batch_size * seq_length),
900
+ )
901
+
902
+ if not return_dict:
903
+ output = (lm_logits,) + transformer_outputs[1:]
904
+ return ((loss,) + output) if loss is not None else output
905
+
906
+ return CausalLMOutputWithCrossAttentions(
907
+ loss=loss,
908
+ logits=lm_logits,
909
+ past_key_values=transformer_outputs.past_key_values,
910
+ hidden_states=transformer_outputs.hidden_states,
911
+ attentions=transformer_outputs.attentions,
912
+ )
913
+
914
+ def _reorder_cache(
915
+ self,
916
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
917
+ beam_idx: torch.LongTensor,
918
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
919
+ """
920
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
921
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
922
+ beam_idx at every generation step.
923
+
924
+ Output shares the same memory storage as `past`.
925
+ """
926
+ standardized_past = self._convert_to_standard_cache(
927
+ past, batch_size=len(beam_idx)
928
+ )
929
+
930
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
931
+ device_to_beam_idx = {
932
+ past_state.device: beam_idx.to(past_state.device)
933
+ for layer_past in past
934
+ for past_state in layer_past
935
+ }
936
+ reordered_past = tuple(
937
+ (
938
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
939
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
940
+ )
941
+ for layer_past in standardized_past
942
+ )
943
+ return self._convert_to_rw_cache(reordered_past)
944
+
945
+
946
+ class RWForSequenceClassification(RWPreTrainedModel):
947
+ _keys_to_ignore_on_load_missing = [
948
+ r"h.*.self_attention.scale_mask_softmax.causal_mask",
949
+ r"lm_head.weight",
950
+ ]
951
+
952
+ def __init__(self, config: RWConfig):
953
+ super().__init__(config)
954
+ self.num_labels = config.num_labels
955
+ self.transformer = RWModel(config)
956
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
957
+
958
+ # Initialize weights and apply final processing
959
+ self.post_init()
960
+
961
+ def forward(
962
+ self,
963
+ input_ids: Optional[torch.LongTensor] = None,
964
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
965
+ attention_mask: Optional[torch.Tensor] = None,
966
+ head_mask: Optional[torch.Tensor] = None,
967
+ inputs_embeds: Optional[torch.Tensor] = None,
968
+ labels: Optional[torch.Tensor] = None,
969
+ use_cache: Optional[bool] = None,
970
+ output_attentions: Optional[bool] = None,
971
+ output_hidden_states: Optional[bool] = None,
972
+ return_dict: Optional[bool] = None,
973
+ **deprecated_arguments,
974
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
975
+ r"""
976
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
977
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
978
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
979
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
980
+ """
981
+ if deprecated_arguments.pop("position_ids", False) is not False:
982
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
983
+ warnings.warn(
984
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
985
+ " passing `position_ids`.",
986
+ FutureWarning,
987
+ )
988
+ if len(deprecated_arguments) > 0:
989
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
990
+
991
+ return_dict = (
992
+ return_dict if return_dict is not None else self.config.use_return_dict
993
+ )
994
+
995
+ transformer_outputs = self.transformer(
996
+ input_ids,
997
+ past_key_values=past_key_values,
998
+ attention_mask=attention_mask,
999
+ head_mask=head_mask,
1000
+ inputs_embeds=inputs_embeds,
1001
+ use_cache=use_cache,
1002
+ output_attentions=output_attentions,
1003
+ output_hidden_states=output_hidden_states,
1004
+ return_dict=return_dict,
1005
+ )
1006
+
1007
+ hidden_states = transformer_outputs[0]
1008
+ logits = self.score(hidden_states)
1009
+
1010
+ if input_ids is not None:
1011
+ batch_size = input_ids.shape[0]
1012
+ else:
1013
+ batch_size = inputs_embeds.shape[0]
1014
+
1015
+ if self.config.pad_token_id is None and batch_size != 1:
1016
+ raise ValueError(
1017
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1018
+ )
1019
+ if self.config.pad_token_id is None:
1020
+ sequence_lengths = -1
1021
+ else:
1022
+ if input_ids is not None:
1023
+ sequence_lengths = (
1024
+ torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
1025
+ )
1026
+ else:
1027
+ sequence_lengths = -1
1028
+ logger.warning(
1029
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1030
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1031
+ )
1032
+
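+ # Pool on the logits of the last non-padding token of each sequence (or the final position
+ # when no pad token could be determined).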
1033
+ pooled_logits = logits[
1034
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1035
+ ]
1036
+
1037
+ loss = None
1038
+ if labels is not None:
1039
+ if self.config.problem_type is None:
1040
+ if self.num_labels == 1:
1041
+ self.config.problem_type = "regression"
1042
+ elif self.num_labels > 1 and (
1043
+ labels.dtype == torch.long or labels.dtype == torch.int
1044
+ ):
1045
+ self.config.problem_type = "single_label_classification"
1046
+ else:
1047
+ self.config.problem_type = "multi_label_classification"
1048
+
1049
+ if self.config.problem_type == "regression":
1050
+ loss_fct = MSELoss()
1051
+ if self.num_labels == 1:
1052
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1053
+ else:
1054
+ loss = loss_fct(pooled_logits, labels)
1055
+ elif self.config.problem_type == "single_label_classification":
1056
+ loss_fct = CrossEntropyLoss()
1057
+ loss = loss_fct(pooled_logits, labels)
1058
+ elif self.config.problem_type == "multi_label_classification":
1059
+ loss_fct = BCEWithLogitsLoss()
1060
+ loss = loss_fct(pooled_logits, labels)
1061
+ if not return_dict:
1062
+ output = (pooled_logits,) + transformer_outputs[1:]
1063
+ return ((loss,) + output) if loss is not None else output
1064
+
1065
+ return SequenceClassifierOutputWithPast(
1066
+ loss=loss,
1067
+ logits=pooled_logits,
1068
+ past_key_values=transformer_outputs.past_key_values,
1069
+ hidden_states=transformer_outputs.hidden_states,
1070
+ attentions=transformer_outputs.attentions,
1071
+ )
1072
+
1073
+
1074
+ class RWForTokenClassification(RWPreTrainedModel):
1075
+ _keys_to_ignore_on_load_missing = [
1076
+ r"h.*.self_attention.scale_mask_softmax.causal_mask",
1077
+ r"lm_head.weight",
1078
+ ]
1079
+
1080
+ def __init__(self, config: RWConfig):
1081
+ super().__init__(config)
1082
+ self.num_labels = config.num_labels
1083
+
1084
+ self.transformer = RWModel(config)
1085
+ if (
1086
+ hasattr(config, "classifier_dropout")
1087
+ and config.classifier_dropout is not None
1088
+ ):
1089
+ classifier_dropout = config.classifier_dropout
1090
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1091
+ classifier_dropout = config.hidden_dropout
1092
+ else:
1093
+ classifier_dropout = 0.1
1094
+ self.dropout = nn.Dropout(classifier_dropout)
1095
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1096
+
1097
+ # Initialize weights and apply final processing
1098
+ self.post_init()
1099
+
1100
+ def forward(
1101
+ self,
1102
+ input_ids: Optional[torch.LongTensor] = None,
1103
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1104
+ attention_mask: Optional[torch.Tensor] = None,
1105
+ head_mask: Optional[torch.Tensor] = None,
1106
+ inputs_embeds: Optional[torch.Tensor] = None,
1107
+ labels: Optional[torch.Tensor] = None,
1108
+ use_cache: Optional[bool] = None,
1109
+ output_attentions: Optional[bool] = None,
1110
+ output_hidden_states: Optional[bool] = None,
1111
+ return_dict: Optional[bool] = None,
1112
+ **deprecated_arguments,
1113
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1114
+ r"""
1115
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1116
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1117
+ config.num_labels - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed
1118
+ for labels in `[0, ..., config.num_labels - 1]`.
1119
+ """
1120
+ if deprecated_arguments.pop("position_ids", False) is not False:
1121
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
1122
+ warnings.warn(
1123
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
1124
+ " passing `position_ids`.",
1125
+ FutureWarning,
1126
+ )
1127
+ if len(deprecated_arguments) > 0:
1128
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
1129
+
1130
+ return_dict = (
1131
+ return_dict if return_dict is not None else self.config.use_return_dict
1132
+ )
1133
+
1134
+ transformer_outputs = self.transformer(
1135
+ input_ids,
1136
+ past_key_values=past_key_values,
1137
+ attention_mask=attention_mask,
1138
+ head_mask=head_mask,
1139
+ inputs_embeds=inputs_embeds,
1140
+ use_cache=use_cache,
1141
+ output_attentions=output_attentions,
1142
+ output_hidden_states=output_hidden_states,
1143
+ return_dict=return_dict,
1144
+ )
1145
+
1146
+ hidden_states = transformer_outputs[0]
1147
+ hidden_states = self.dropout(hidden_states)
1148
+ logits = self.classifier(hidden_states)
1149
+
1150
+ loss = None
1151
+ if labels is not None:
1152
+ batch_size, seq_length = labels.shape
1153
+ loss_fct = CrossEntropyLoss()
1154
+ loss = loss_fct(
1155
+ logits.view(batch_size * seq_length, self.num_labels),
1156
+ labels.view(batch_size * seq_length),
1157
+ )
1158
+
1159
+ if not return_dict:
1160
+ output = (logits,) + transformer_outputs[2:]
1161
+ return ((loss,) + output) if loss is not None else output
1162
+
1163
+ return TokenClassifierOutput(
1164
+ loss=loss,
1165
+ logits=logits,
1166
+ hidden_states=transformer_outputs.hidden_states,
1167
+ attentions=transformer_outputs.attentions,
1168
+ )
1169
+
1170
+
1171
+ class RWForQuestionAnswering(RWPreTrainedModel):
1172
+ _keys_to_ignore_on_load_missing = [
1173
+ r"h.*.self_attention.scale_mask_softmax.causal_mask",
1174
+ r"lm_head.weight",
1175
+ ]
1176
+
1177
+ def __init__(self, config):
1178
+ super().__init__(config)
1179
+ self.transformer = RWModel(config)
1180
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1181
+
1182
+ # Initialize weights and apply final processing
1183
+ self.post_init()
1184
+
1185
+ def forward(
1186
+ self,
1187
+ input_ids: Optional[torch.LongTensor] = None,
1188
+ attention_mask: Optional[torch.FloatTensor] = None,
1189
+ position_ids: Optional[torch.LongTensor] = None,
1190
+ head_mask: Optional[torch.FloatTensor] = None,
1191
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1192
+ start_positions: Optional[torch.LongTensor] = None,
1193
+ end_positions: Optional[torch.LongTensor] = None,
1194
+ output_attentions: Optional[bool] = None,
1195
+ output_hidden_states: Optional[bool] = None,
1196
+ return_dict: Optional[bool] = None,
1197
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1198
+ r"""
1199
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1200
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1201
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1202
+ are not taken into account for computing the loss.
1203
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1204
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1205
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1206
+ are not taken into account for computing the loss.
1207
+ """
1208
+ return_dict = (
1209
+ return_dict if return_dict is not None else self.config.use_return_dict
1210
+ )
1211
+
1212
+ outputs = self.transformer(
1213
+ input_ids,
1214
+ attention_mask=attention_mask,
1215
+ position_ids=position_ids,
1216
+ head_mask=head_mask,
1217
+ inputs_embeds=inputs_embeds,
1218
+ output_attentions=output_attentions,
1219
+ output_hidden_states=output_hidden_states,
1220
+ return_dict=return_dict,
1221
+ )
1222
+
1223
+ sequence_output = outputs[0]
1224
+
1225
+ logits = self.qa_outputs(sequence_output)
1226
+ start_logits, end_logits = logits.split(1, dim=-1)
1227
+ start_logits = start_logits.squeeze(-1).contiguous()
1228
+ end_logits = end_logits.squeeze(-1).contiguous()
1229
+
1230
+ total_loss = None
1231
+ if start_positions is not None and end_positions is not None:
1232
+ # If we are on multi-GPU, the start/end positions may carry an extra dimension; squeeze it
1233
+ if len(start_positions.size()) > 1:
1234
+ start_positions = start_positions.squeeze(-1)
1235
+ if len(end_positions.size()) > 1:
1236
+ end_positions = end_positions.squeeze(-1)
1237
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1238
+ ignored_index = start_logits.size(1)
1239
+ start_positions = start_positions.clamp(0, ignored_index)
1240
+ end_positions = end_positions.clamp(0, ignored_index)
1241
+
1242
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1243
+ start_loss = loss_fct(start_logits, start_positions)
1244
+ end_loss = loss_fct(end_logits, end_positions)
1245
+ total_loss = (start_loss + end_loss) / 2
1246
+
1247
+ if not return_dict:
1248
+ output = (start_logits, end_logits) + outputs[2:]
1249
+ return ((total_loss,) + output) if total_loss is not None else output
1250
+
1251
+ return QuestionAnsweringModelOutput(
1252
+ loss=total_loss,
1253
+ start_logits=start_logits,
1254
+ end_logits=end_logits,
1255
+ hidden_states=outputs.hidden_states,
1256
+ attentions=outputs.attentions,
1257
+ )