jbochi commited on
Commit
06e4cba
1 Parent(s): a3055fa

Fix mistake in parallel layers

Browse files
Files changed (1) hide show
  1. decoder_only_t5/modeling.py +42 -53
decoder_only_t5/modeling.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  from torch import nn
6
  from torch.nn import CrossEntropyLoss
7
  from transformers.models.t5 import modeling_t5
8
- from transformers.modeling_outputs import Seq2SeqLMOutput
9
  from transformers.utils import (
10
  add_start_docstrings_to_model_forward,
11
  logging,
@@ -167,22 +167,28 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
167
  ) # (batch_size, n_heads, seq_length, dim_per_head)
168
 
169
  # get key/value states
170
- key_states = project(
171
- hidden_states,
172
- self.k,
173
- key_value_states,
174
- past_key_value[0] if past_key_value is not None else None,
 
 
 
175
  )
176
- value_states = project(
177
- hidden_states,
178
- self.v,
179
- key_value_states,
180
- past_key_value[1] if past_key_value is not None else None,
 
 
 
181
  )
182
 
183
  # compute scores
184
  scores = torch.matmul(
185
- query_states, repeat_kv(key_states, self.n_kv_groups).transpose(3, 2)
186
  ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
187
 
188
  if position_bias is None:
@@ -345,8 +351,9 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
345
 
346
  ff_layer = self.layer[-1]
347
  if self.parallel_layers:
 
348
  x = self.layer[0].layer_norm(hidden_states)
349
- ff_output = ff_layer(hidden_states)
350
  else:
351
  x = hidden_states
352
 
@@ -418,7 +425,7 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
418
  attention_outputs = attention_outputs + cross_attention_outputs[2:]
419
 
420
  if self.parallel_layers:
421
- # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L295
422
  hidden_states = x + ff_output
423
  hidden_states *= 2**-0.5
424
  hidden_states = hidden_states + self.layer[0].dropout(hidden_states)
@@ -508,27 +515,21 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
508
 
509
  @add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
510
  @replace_return_docstrings(
511
- output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
512
  )
513
  def forward(
514
  self,
515
- _input_ids: Optional[torch.LongTensor] = None,
516
  attention_mask: Optional[torch.FloatTensor] = None,
517
- decoder_input_ids: Optional[torch.LongTensor] = None,
518
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
519
- head_mask: Optional[torch.FloatTensor] = None,
520
- decoder_head_mask: Optional[torch.FloatTensor] = None,
521
- cross_attn_head_mask: Optional[torch.Tensor] = None,
522
- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
523
  past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
524
- _inputs_embeds: Optional[torch.FloatTensor] = None,
525
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
526
  labels: Optional[torch.LongTensor] = None,
527
  use_cache: Optional[bool] = None,
528
  output_attentions: Optional[bool] = None,
529
  output_hidden_states: Optional[bool] = None,
530
  return_dict: Optional[bool] = None,
531
- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
532
  r"""
533
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
534
  Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
@@ -548,43 +549,31 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
548
  if self.model_parallel:
549
  torch.cuda.set_device(self.decoder.first_device)
550
 
551
- if (
552
- labels is not None
553
- and decoder_input_ids is None
554
- and decoder_inputs_embeds is None
555
- ):
556
- # get decoder inputs from shifting lm labels to the right
557
- decoder_input_ids = self._shift_right(labels)
558
-
559
  # Set device for model parallelism
560
  if self.model_parallel:
561
  torch.cuda.set_device(self.decoder.first_device)
562
- if decoder_input_ids is not None:
563
- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
564
  if attention_mask is not None:
565
  attention_mask = attention_mask.to(self.decoder.first_device)
566
- if decoder_attention_mask is not None:
567
- decoder_attention_mask = decoder_attention_mask.to(
568
- self.decoder.first_device
569
- )
570
 
571
  # Decode
572
- decoder_outputs = self.decoder(
573
- input_ids=decoder_input_ids,
574
- attention_mask=decoder_attention_mask,
575
- inputs_embeds=decoder_inputs_embeds,
576
  past_key_values=past_key_values,
577
- # encoder_hidden_states=hidden_states,
578
- encoder_attention_mask=attention_mask,
579
- head_mask=decoder_head_mask,
580
- cross_attn_head_mask=cross_attn_head_mask,
581
  use_cache=use_cache,
582
  output_attentions=output_attentions,
583
  output_hidden_states=output_hidden_states,
584
  return_dict=return_dict,
585
  )
586
 
587
- sequence_output = decoder_outputs[0]
588
 
589
  # Set device for model parallelism
590
  if self.model_parallel:
@@ -608,13 +597,13 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
608
  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
609
 
610
  if not return_dict:
611
- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
612
  return ((loss,) + output) if loss is not None else output
613
 
614
- return Seq2SeqLMOutput(
615
  loss=loss,
616
  logits=lm_logits,
617
- past_key_values=decoder_outputs.past_key_values,
618
- decoder_hidden_states=decoder_outputs.hidden_states,
619
- decoder_attentions=decoder_outputs.attentions,
620
  )
 
5
  from torch import nn
6
  from torch.nn import CrossEntropyLoss
7
  from transformers.models.t5 import modeling_t5
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
  from transformers.utils import (
10
  add_start_docstrings_to_model_forward,
11
  logging,
 
167
  ) # (batch_size, n_heads, seq_length, dim_per_head)
168
 
169
  # get key/value states
170
+ key_states = repeat_kv(
171
+ project(
172
+ hidden_states,
173
+ self.k,
174
+ key_value_states,
175
+ past_key_value[0] if past_key_value is not None else None,
176
+ ),
177
+ self.n_kv_groups,
178
  )
179
+ value_states = repeat_kv(
180
+ project(
181
+ hidden_states,
182
+ self.v,
183
+ key_value_states,
184
+ past_key_value[1] if past_key_value is not None else None,
185
+ ),
186
+ self.n_kv_groups,
187
  )
188
 
189
  # compute scores
190
  scores = torch.matmul(
191
+ query_states, key_states.transpose(3, 2)
192
  ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
193
 
194
  if position_bias is None:
 
351
 
352
  ff_layer = self.layer[-1]
353
  if self.parallel_layers:
354
+ # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L563-L568
355
  x = self.layer[0].layer_norm(hidden_states)
356
+ ff_output = ff_layer(x)
357
  else:
358
  x = hidden_states
359
 
 
425
  attention_outputs = attention_outputs + cross_attention_outputs[2:]
426
 
427
  if self.parallel_layers:
428
+ # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
429
  hidden_states = x + ff_output
430
  hidden_states *= 2**-0.5
431
  hidden_states = hidden_states + self.layer[0].dropout(hidden_states)
 
515
 
516
  @add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
517
  @replace_return_docstrings(
518
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
519
  )
520
  def forward(
521
  self,
522
+ input_ids: Optional[torch.LongTensor] = None,
523
  attention_mask: Optional[torch.FloatTensor] = None,
524
+ position_ids: Optional[torch.LongTensor] = None,
 
 
 
 
 
525
  past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
526
+ inputs_embeds: Optional[torch.FloatTensor] = None,
 
527
  labels: Optional[torch.LongTensor] = None,
528
  use_cache: Optional[bool] = None,
529
  output_attentions: Optional[bool] = None,
530
  output_hidden_states: Optional[bool] = None,
531
  return_dict: Optional[bool] = None,
532
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
533
  r"""
534
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
535
  Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
 
549
  if self.model_parallel:
550
  torch.cuda.set_device(self.decoder.first_device)
551
 
 
 
 
 
 
 
 
 
552
  # Set device for model parallelism
553
  if self.model_parallel:
554
  torch.cuda.set_device(self.decoder.first_device)
555
+ if input_ids is not None:
556
+ input_ids = input_ids.to(self.decoder.first_device)
557
  if attention_mask is not None:
558
  attention_mask = attention_mask.to(self.decoder.first_device)
 
 
 
 
559
 
560
  # Decode
561
+ outputs = self.decoder(
562
+ input_ids=input_ids,
563
+ attention_mask=attention_mask,
564
+ inputs_embeds=inputs_embeds,
565
  past_key_values=past_key_values,
566
+ encoder_hidden_states=None,
567
+ encoder_attention_mask=None,
568
+ head_mask=None,
569
+ cross_attn_head_mask=None,
570
  use_cache=use_cache,
571
  output_attentions=output_attentions,
572
  output_hidden_states=output_hidden_states,
573
  return_dict=return_dict,
574
  )
575
 
576
+ sequence_output = outputs[0]
577
 
578
  # Set device for model parallelism
579
  if self.model_parallel:
 
597
  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
598
 
599
  if not return_dict:
600
+ output = (lm_logits,) + outputs[1:]
601
  return ((loss,) + output) if loss is not None else output
602
 
603
+ return CausalLMOutputWithPast(
604
  loss=loss,
605
  logits=lm_logits,
606
+ past_key_values=outputs.past_key_values,
607
+ hidden_states=outputs.hidden_states,
608
+ attentions=outputs.attentions,
609
  )