wanadzhar913 committed on
Commit de174dc
1 Parent(s): 441b842

Upload MistralForSequenceClassification

Files changed (2)
  1. classifier.py +429 -1
  2. config.json +1 -1
classifier.py CHANGED
@@ -2,12 +2,440 @@ from bidirectional_mistral import MistralBiModel
  from transformers import MistralPreTrainedModel
  import torch
  import numpy as np
- from typing import Optional, List
+ from typing import List, Optional, Tuple, Union
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+ from transformers import (
+     MistralModel,
+     MistralPreTrainedModel,
+     MistralForCausalLM,
+     MistralConfig,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.models.mistral.modeling_mistral import (
+     MistralDecoderLayer,
+     MistralRMSNorm,
+     MistralAttention,
+     MistralFlashAttention2,
+     MistralSdpaAttention,
+     MistralMLP,
+ )
+ from torch import nn
+ from transformers.utils import logging
+
+
+ def _prepare_4d_causal_attention_mask(
+     attention_mask: Optional[torch.Tensor],
+     input_shape: Union[torch.Size, Tuple, List],
+     inputs_embeds: torch.Tensor,
+     past_key_values_length: int,
+     sliding_window: Optional[int] = None,
+ ):
+     """
+     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+     `(batch_size, key_value_length)`
+
+     Args:
+         attention_mask (`torch.Tensor` or `None`):
+             A 2D attention mask of shape `(batch_size, key_value_length)`
+         input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+             The input shape should be a tuple that defines `(batch_size, query_length)`.
+         inputs_embeds (`torch.Tensor`):
+             The embedded inputs as a torch Tensor.
+         past_key_values_length (`int`):
+             The length of the key value cache.
+         sliding_window (`int`, *optional*):
+             If the model uses windowed attention, a sliding window should be passed.
+     """
+     attn_mask_converter = AttentionMaskConverter(
+         is_causal=False, sliding_window=sliding_window
+     )  # is_causal=True in original implementation
+
+     key_value_length = input_shape[-1] + past_key_values_length
+
+     # 4d mask is passed through the layers
+     if attention_mask is not None and len(attention_mask.shape) == 2:
+         attention_mask = attn_mask_converter.to_4d(
+             attention_mask,
+             input_shape[-1],
+             key_value_length=key_value_length,
+             dtype=inputs_embeds.dtype,
+         )
+     elif attention_mask is not None and len(attention_mask.shape) == 4:
+         expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+         if tuple(attention_mask.shape) != expected_shape:
+             raise ValueError(
+                 f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+             )
+         else:
+             # if the 4D mask has correct shape - invert it and fill with negative infinity
+             inverted_mask = 1.0 - attention_mask
+             attention_mask = inverted_mask.masked_fill(
+                 inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+             )
+     else:
+         attention_mask = attn_mask_converter.to_causal_4d(
+             input_shape[0],
+             input_shape[-1],
+             key_value_length,
+             dtype=inputs_embeds.dtype,
+             device=inputs_embeds.device,
+         )
+
+     return attention_mask
+
+
+ # Adapted from _prepare_4d_causal_attention_mask
+ def _prepare_4d_causal_attention_mask_for_sdpa(
+     attention_mask: Optional[torch.Tensor],
+     input_shape: Union[torch.Size, Tuple, List],
+     inputs_embeds: torch.Tensor,
+     past_key_values_length: int,
+     sliding_window: Optional[int] = None,
+ ):
+     """
+     Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
+
+     In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
+     `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
+     allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+     """
+     attn_mask_converter = AttentionMaskConverter(
+         is_causal=False, sliding_window=sliding_window
+     )  # is_causal=True in original implementation
+
+     key_value_length = input_shape[-1] + past_key_values_length
+     batch_size, query_length = input_shape
+
+     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+     # TODO: For dynamo, rather use a check on fullgraph=True once this is
+     # possible (https://github.com/pytorch/pytorch/pull/120400).
+     is_tracing = (
+         torch.jit.is_tracing()
+         or isinstance(inputs_embeds, torch.fx.Proxy)
+         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+     )
+
+     if attention_mask is not None:
+         # 4d mask is passed through
+         if len(attention_mask.shape) == 4:
+             expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+             if tuple(attention_mask.shape) != expected_shape:
+                 raise ValueError(
+                     f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                 )
+             else:
+                 # if the 4D mask has correct shape - invert it and fill with negative infinity
+                 inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                 attention_mask = inverted_mask.masked_fill(
+                     inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                 )
+             return attention_mask
+
+         elif not is_tracing and torch.all(attention_mask == 1):
+             if query_length == 1:
+                 # For query_length == 1, causal attention and bi-directional attention are the same.
+                 attention_mask = None
+             elif key_value_length == query_length:
+                 attention_mask = None
+             else:
+                 # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+                 # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+                 # Reference: https://github.com/pytorch/pytorch/issues/108108
+                 pass
+     elif query_length > 1 and key_value_length != query_length:
+         # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
+         # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
+         attention_mask = True
+     elif is_tracing:
+         raise ValueError(
+             'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
+         )
+
+     if attention_mask is None:
+         expanded_4d_mask = None
+     elif attention_mask is True:
+         expanded_4d_mask = attn_mask_converter.to_causal_4d(
+             input_shape[0],
+             input_shape[-1],
+             key_value_length,
+             dtype=inputs_embeds.dtype,
+             device=inputs_embeds.device,
+         )
+     else:
+         expanded_4d_mask = attn_mask_converter.to_4d(
+             attention_mask,
+             input_shape[-1],
+             dtype=inputs_embeds.dtype,
+             key_value_length=key_value_length,
+         )
+
+         # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+         # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+         # Details: https://github.com/pytorch/pytorch/issues/110213
+         if not is_tracing and expanded_4d_mask.device.type == "cuda":
+             expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                 expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+             )
+
+     return expanded_4d_mask
+
+ class ModifiedMistralAttention(MistralAttention):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.is_causal = False
+
+
+ class ModifiedMistralFlashAttention2(MistralFlashAttention2):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.is_causal = False
+
+
+ class ModifiedMistralSdpaAttention(MistralSdpaAttention):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.is_causal = False
+
+
+ MISTRAL_ATTENTION_CLASSES = {
+     "eager": ModifiedMistralAttention,
+     "flash_attention_2": ModifiedMistralFlashAttention2,
+     "sdpa": ModifiedMistralSdpaAttention,
+ }
+
+
+ class ModifiedMistralDecoderLayer(MistralDecoderLayer):
+     def __init__(self, config: MistralConfig, layer_idx: int):
+         nn.Module.__init__(self)
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](
+             config, layer_idx
+         )
+
+         self.mlp = MistralMLP(config)
+         self.input_layernorm = MistralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_attention_layernorm = MistralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+
+ class MistralBiModel(MistralModel):
+     def __init__(self, config: MistralConfig):
+         MistralPreTrainedModel.__init__(self, config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(
+             config.vocab_size, config.hidden_size, self.padding_idx
+         )
+         self.layers = nn.ModuleList(
+             [
+                 ModifiedMistralDecoderLayer(config, layer_idx)
+                 for layer_idx in range(config.num_hidden_layers)
+             ]
+         )
+         self._attn_implementation = config._attn_implementation
+         self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     # Copied from forward() in transformers.models.mistral.modeling_mistral.MistralModel
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError(
+                 "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+             )
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         past_key_values_length = 0
+
+         if use_cache:
+             use_legacy_cache = not isinstance(past_key_values, Cache)
+             if use_legacy_cache:
+                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+             past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length,
+                 seq_length + past_key_values_length,
+                 dtype=torch.long,
+                 device=device,
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
 
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
 
+         if (
+             attention_mask is not None
+             and self._attn_implementation == "flash_attention_2"
+             and use_cache
+         ):
+             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+             if is_padding_right:
+                 raise ValueError(
+                     "You are attempting to perform batched generation with padding_side='right'"
+                     " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. ")
+
+         if self._attn_implementation == "flash_attention_2":
+             # 2d mask is passed through the layers
+             attention_mask = (
+                 attention_mask
+                 if (attention_mask is not None and 0 in attention_mask)
+                 else None
+             )
+         elif self._attn_implementation == "sdpa" and not output_attentions:
+             # The original implementation is by-passed, see attn_mask_utils.py
+             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                 attention_mask,
+                 (batch_size, seq_length),
+                 inputs_embeds,
+                 past_key_values_length,
+             )
+         else:
+             # 4d mask is passed through the layers
+             attention_mask = _prepare_4d_causal_attention_mask(
+                 attention_mask,
+                 (batch_size, seq_length),
+                 inputs_embeds,
+                 past_key_values_length,
+                 sliding_window=self.config.sliding_window,
+             )
+
+         hidden_states = inputs_embeds
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = None
+         if use_cache:
+             next_cache = (
+                 next_decoder_cache.to_legacy_cache()
+                 if use_legacy_cache
+                 else next_decoder_cache
+             )
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ class MistralBiForMNTP(MistralForCausalLM):
+     def __init__(self, config):
+         MistralPreTrainedModel.__init__(self, config)
+         self.model = MistralBiModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
  class MistralForSequenceClassification(MistralPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
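
For reference, a minimal usage sketch (not part of this commit): it assumes the repository's files (classifier.py, bidirectional_mistral.py, config.json, weights, and tokenizer files) have been downloaded into a local directory that is also on the Python path; the ./checkpoint path below is a placeholder, not a name taken from this repo.

import torch
from transformers import AutoTokenizer
from classifier import MistralForSequenceClassification  # the file uploaded in this commit

checkpoint = "./checkpoint"  # placeholder: local directory holding this repo's files
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MistralForSequenceClassification.from_pretrained(checkpoint)
model.eval()

inputs = tokenizer("Example premise and hypothesis to score.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)
print(logits.softmax(dim=-1))
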
config.json CHANGED
@@ -26,7 +26,7 @@
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
- "transformers_version": "4.44.2",
+ "transformers_version": "4.43.3",
  "use_cache": true,
  "vocab_size": 32000
  }