Text Generation
Transformers
PyTorch
code
gpt2
custom_code
Eval Results
text-generation-inference
mayank-mishra commited on
Commit
cbd1dd1
1 Parent(s): e464072

:bug: fix past_length in santacoder

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. modeling_gpt2_mq.py +200 -23
config.json CHANGED
@@ -14,7 +14,7 @@
14
  "eos_token_id": 50256,
15
  "initializer_range": 0.02,
16
  "layer_norm_epsilon": 1e-05,
17
- "model_type": "gpt2",
18
  "n_embd": 2048,
19
  "n_head": 16,
20
  "n_inner": 8192,
 
14
  "eos_token_id": 50256,
15
  "initializer_range": 0.02,
16
  "layer_norm_epsilon": 1e-05,
17
+ "model_type": "santacoder",
18
  "n_embd": 2048,
19
  "n_head": 16,
20
  "n_inner": 8192,
modeling_gpt2_mq.py CHANGED
@@ -1,39 +1,21 @@
1
  """PyTorch OpenAI GPT-2 model modified with MultiQuery attention"""
2
 
3
 
4
- import math
5
- import os
6
- from dataclasses import dataclass
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
10
  import torch.utils.checkpoint
11
  from torch import nn
12
  from torch.cuda.amp import autocast
13
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
-
15
- from transformers.activations import ACT2FN
16
- from transformers.modeling_outputs import (
17
- BaseModelOutputWithPastAndCrossAttentions,
18
- CausalLMOutputWithCrossAttentions,
19
- SequenceClassifierOutputWithPast,
20
- TokenClassifierOutput,
21
- )
22
- from transformers.modeling_utils import PreTrainedModel, SequenceSummary
23
  from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
24
 
25
- from transformers.utils import (
26
- ModelOutput,
27
- add_code_sample_docstrings,
28
- add_start_docstrings,
29
- add_start_docstrings_to_model_forward,
30
- logging,
31
- replace_return_docstrings,
32
- )
33
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
34
  from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
35
- from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY, MULTI_HEAD
36
 
 
37
 
38
 
39
  class GPT2MQAttention(nn.Module):
@@ -329,6 +311,201 @@ class GPT2CustomModel(GPT2Model):
329
  # Initialize weights and apply final processing
330
  self.post_init()
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
  class GPT2LMHeadCustomModel(GPT2LMHeadModel):
334
  config_class = GPT2CustomConfig
 
1
  """PyTorch OpenAI GPT-2 model modified with MultiQuery attention"""
2
 
3
 
 
 
 
4
  from typing import Optional, Tuple, Union
5
 
6
  import torch
7
  import torch.utils.checkpoint
8
  from torch import nn
9
  from torch.cuda.amp import autocast
10
+
11
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 
 
 
 
 
 
 
 
12
  from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
13
 
14
+ from transformers.utils import logging
 
 
 
 
 
 
 
 
15
  from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
16
+ from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY
17
 
18
+ logger = logging.get_logger(__name__)
19
 
20
 
21
  class GPT2MQAttention(nn.Module):
 
311
  # Initialize weights and apply final processing
312
  self.post_init()
313
 
314
+ def forward(
315
+ self,
316
+ input_ids: Optional[torch.LongTensor] = None,
317
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
318
+ attention_mask: Optional[torch.FloatTensor] = None,
319
+ token_type_ids: Optional[torch.LongTensor] = None,
320
+ position_ids: Optional[torch.LongTensor] = None,
321
+ head_mask: Optional[torch.FloatTensor] = None,
322
+ inputs_embeds: Optional[torch.FloatTensor] = None,
323
+ encoder_hidden_states: Optional[torch.Tensor] = None,
324
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
325
+ use_cache: Optional[bool] = None,
326
+ output_attentions: Optional[bool] = None,
327
+ output_hidden_states: Optional[bool] = None,
328
+ return_dict: Optional[bool] = None,
329
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
330
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
331
+ output_hidden_states = (
332
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
333
+ )
334
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
335
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
336
+
337
+ if input_ids is not None and inputs_embeds is not None:
338
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
339
+ elif input_ids is not None:
340
+ input_shape = input_ids.size()
341
+ input_ids = input_ids.view(-1, input_shape[-1])
342
+ batch_size = input_ids.shape[0]
343
+ elif inputs_embeds is not None:
344
+ input_shape = inputs_embeds.size()[:-1]
345
+ batch_size = inputs_embeds.shape[0]
346
+ else:
347
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
348
+
349
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
350
+
351
+ if token_type_ids is not None:
352
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
353
+ if position_ids is not None:
354
+ position_ids = position_ids.view(-1, input_shape[-1])
355
+
356
+ if past_key_values is None:
357
+ past_length = 0
358
+ past_key_values = tuple([None] * len(self.h))
359
+ else:
360
+ # this is different from GPT2
361
+ past_length = past_key_values[0][0].size(-1)
362
+ if position_ids is None:
363
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
364
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
365
+
366
+ # GPT2Attention mask.
367
+ if attention_mask is not None:
368
+ if batch_size <= 0:
369
+ raise ValueError("batch_size has to be defined and > 0")
370
+ attention_mask = attention_mask.view(batch_size, -1)
371
+ # We create a 3D attention mask from a 2D tensor mask.
372
+ # Sizes are [batch_size, 1, 1, to_seq_length]
373
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
374
+ # this attention mask is more simple than the triangular masking of causal attention
375
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
376
+ attention_mask = attention_mask[:, None, None, :]
377
+
378
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
379
+ # masked positions, this operation will create a tensor which is 0.0 for
380
+ # positions we want to attend and the dtype's smallest value for masked positions.
381
+ # Since we are adding it to the raw scores before the softmax, this is
382
+ # effectively the same as removing these entirely.
383
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
384
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
385
+
386
+ # If a 2D or 3D attention mask is provided for the cross-attention
387
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
388
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
389
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
390
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
391
+ if encoder_attention_mask is None:
392
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
393
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
394
+ else:
395
+ encoder_attention_mask = None
396
+
397
+ # Prepare head mask if needed
398
+ # 1.0 in head_mask indicate we keep the head
399
+ # attention_probs has shape bsz x n_heads x N x N
400
+ # head_mask has shape n_layer x batch x n_heads x N x N
401
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
402
+
403
+ if inputs_embeds is None:
404
+ inputs_embeds = self.wte(input_ids)
405
+ position_embeds = self.wpe(position_ids)
406
+ hidden_states = inputs_embeds + position_embeds
407
+
408
+ if token_type_ids is not None:
409
+ token_type_embeds = self.wte(token_type_ids)
410
+ hidden_states = hidden_states + token_type_embeds
411
+
412
+ hidden_states = self.drop(hidden_states)
413
+
414
+ output_shape = input_shape + (hidden_states.size(-1),)
415
+
416
+ presents = () if use_cache else None
417
+ all_self_attentions = () if output_attentions else None
418
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
419
+ all_hidden_states = () if output_hidden_states else None
420
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
421
+
422
+ # Model parallel
423
+ if self.model_parallel:
424
+ torch.cuda.set_device(hidden_states.device)
425
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
426
+ if layer_past is not None:
427
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
428
+ # Ensure that attention_mask is always on the same device as hidden_states
429
+ if attention_mask is not None:
430
+ attention_mask = attention_mask.to(hidden_states.device)
431
+ if isinstance(head_mask, torch.Tensor):
432
+ head_mask = head_mask.to(hidden_states.device)
433
+ if output_hidden_states:
434
+ all_hidden_states = all_hidden_states + (hidden_states,)
435
+
436
+ if self.gradient_checkpointing and self.training:
437
+
438
+ if use_cache:
439
+ logger.warning(
440
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
441
+ )
442
+ use_cache = False
443
+
444
+ def create_custom_forward(module):
445
+ def custom_forward(*inputs):
446
+ # None for past_key_value
447
+ return module(*inputs, use_cache, output_attentions)
448
+
449
+ return custom_forward
450
+
451
+ outputs = torch.utils.checkpoint.checkpoint(
452
+ create_custom_forward(block),
453
+ hidden_states,
454
+ None,
455
+ attention_mask,
456
+ head_mask[i],
457
+ encoder_hidden_states,
458
+ encoder_attention_mask,
459
+ )
460
+ else:
461
+ outputs = block(
462
+ hidden_states,
463
+ layer_past=layer_past,
464
+ attention_mask=attention_mask,
465
+ head_mask=head_mask[i],
466
+ encoder_hidden_states=encoder_hidden_states,
467
+ encoder_attention_mask=encoder_attention_mask,
468
+ use_cache=use_cache,
469
+ output_attentions=output_attentions,
470
+ )
471
+
472
+ hidden_states = outputs[0]
473
+ if use_cache is True:
474
+ presents = presents + (outputs[1],)
475
+
476
+ if output_attentions:
477
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
478
+ if self.config.add_cross_attention:
479
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
480
+
481
+ # Model Parallel: If it's the last layer for that device, put things on the next device
482
+ if self.model_parallel:
483
+ for k, v in self.device_map.items():
484
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
485
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
486
+
487
+ hidden_states = self.ln_f(hidden_states)
488
+
489
+ hidden_states = hidden_states.view(output_shape)
490
+ # Add last hidden state
491
+ if output_hidden_states:
492
+ all_hidden_states = all_hidden_states + (hidden_states,)
493
+
494
+ if not return_dict:
495
+ return tuple(
496
+ v
497
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
498
+ if v is not None
499
+ )
500
+
501
+ return BaseModelOutputWithPastAndCrossAttentions(
502
+ last_hidden_state=hidden_states,
503
+ past_key_values=presents,
504
+ hidden_states=all_hidden_states,
505
+ attentions=all_self_attentions,
506
+ cross_attentions=all_cross_attentions,
507
+ )
508
+
509
 
510
  class GPT2LMHeadCustomModel(GPT2LMHeadModel):
511
  config_class = GPT2CustomConfig