zxdu20 commited on
Commit
507dfe3
·
1 Parent(s): 5c07af4

Fix batch beam search

Browse files
Files changed (1) hide show
  1. modeling_glm.py +90 -34
modeling_glm.py CHANGED
@@ -29,13 +29,13 @@ from transformers.utils import (
29
  )
30
  from transformers.modeling_outputs import (
31
  BaseModelOutputWithPastAndCrossAttentions,
 
32
  SequenceClassifierOutput,
33
- ModelOutput
34
  )
 
35
  from transformers.modeling_utils import (
36
  PreTrainedModel,
37
  )
38
- from transformers.utils import logging
39
  from .configuration_glm import GLMConfig
40
  from torch.nn.parameter import Parameter
41
 
@@ -781,20 +781,60 @@ class GLMModel(GLMPreTrainedModel):
781
  attention_mask = torch.zeros(batch_size)
782
  # Transformer.
783
  transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
784
- logits, hidden_layers = transformer_output
785
- # outputs = hidden_layers
786
  if self.output_predict:
787
- # Parallel logits.
788
- # logits_parallel = mpu.copy_to_model_parallel_region(
789
- # logits)
790
- logits = F.linear(logits, self.word_embeddings.weight)
791
 
792
  return ModelOutput(
 
793
  logits=logits,
794
- mems=hidden_layers,
795
  )
796
 
797
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
  @add_start_docstrings(
799
  """GLM Model transformer with a `language modeling` head on top""",
800
  GLM_START_DOCSTRING,
@@ -833,6 +873,16 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
833
  position_ids = position_ids[:, :, :seq_length]
834
  if attention_mask is not None:
835
  attention_mask = attention_mask[:, :, :seq_length, :seq_length]
 
 
 
 
 
 
 
 
 
 
836
  return {
837
  "input_ids": input_ids,
838
  "position_ids": position_ids,
@@ -845,10 +895,21 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
845
  input_ids=None,
846
  position_ids=None,
847
  attention_mask=None,
 
848
  mems=None,
849
  **kwargs
850
  ):
851
- return self.glm.forward(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
 
 
 
 
 
 
 
 
 
 
852
 
853
 
854
  @add_start_docstrings(
@@ -857,16 +918,19 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
857
  GLM_START_DOCSTRING,
858
  )
859
  class GLMForSequenceClassification(GLMPreTrainedModel):
860
- def __init__(self, config, hidden_dropout=False, num_class=1):
861
  super().__init__(config)
862
  self.pool_token = config.pool_token
863
  self.glm = GLMModel(config)
864
  self.glm.output_predict = False
865
  self.num_class = num_class
866
  # Multi-choice head.
867
- self.pool_layer = torch.nn.Linear(config.hidden_size, config.hidden_size)
868
- self.multichoice_dropout = torch.nn.Dropout(hidden_dropout)
869
- self.multichoice_head = torch.nn.Linear(config.hidden_size, num_class)
 
 
 
870
 
871
  # Initialize weights and apply final processing
872
  self.post_init()
@@ -891,29 +955,21 @@ class GLMForSequenceClassification(GLMPreTrainedModel):
891
  input_ids = input_ids.reshape(-1, input_ids.size(-1))
892
  attention_mask = attention_mask.reshape(-1, *attention_mask.size()[2:])
893
  position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
894
- model_out = self.glm.forward(input_ids, position_ids, attention_mask)
895
- outputs, mems = model_out.last_hidden_state, model_out.hidden_states
896
-
897
- if self.pool_token == 'start':
898
- output = outputs[
899
- torch.arange(outputs.size(0), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask]
900
- elif self.pool_token == 'pad':
901
- output = outputs[torch.arange(outputs.size(0), dtype=attention_mask.dtype,
902
- device=attention_mask.device), attention_mask - 1]
903
- elif self.pool_token == 'cls':
904
- output = outputs[:, 0]
905
- else:
906
- raise NotImplementedError
907
 
908
- output = torch.tanh(self.pool_layer(output))
909
- multichoice_output = self.multichoice_dropout(output)
910
- logits = self.multichoice_head(multichoice_output)
911
- loss_fct = CrossEntropyLoss()
 
912
  if num_choices is not None:
913
  logits = logits.view(-1, num_choices)
914
- # assert (labels is not None, "labels must not None!")
915
- loss = loss_fct(logits, labels)
 
 
916
  # loss = F.cross_entropy(logits.contiguous().float(), labels.long())
917
  return SequenceClassifierOutput(loss=loss,
918
  logits=logits,
919
- hidden_states=mems)
 
29
  )
30
  from transformers.modeling_outputs import (
31
  BaseModelOutputWithPastAndCrossAttentions,
32
+ ModelOutput,
33
  SequenceClassifierOutput,
 
34
  )
35
+
36
  from transformers.modeling_utils import (
37
  PreTrainedModel,
38
  )
 
39
  from .configuration_glm import GLMConfig
40
  from torch.nn.parameter import Parameter
41
 
 
781
  attention_mask = torch.zeros(batch_size)
782
  # Transformer.
783
  transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
784
+ last_hidden_states, mems = transformer_output
785
+ logits = None
786
  if self.output_predict:
787
+ logits = F.linear(last_hidden_states, self.word_embeddings.weight)
 
 
 
788
 
789
  return ModelOutput(
790
+ last_hidden_states=last_hidden_states,
791
  logits=logits,
792
+ mems=mems,
793
  )
794
 
795
 
796
+ @add_start_docstrings(
797
+ """GLM Model transformer for multiple choice classification""",
798
+ GLM_START_DOCSTRING
799
+ )
800
+ class GLMForMultipleChoice(GLMPreTrainedModel):
801
+ def __init__(self, config):
802
+ super().__init__(config)
803
+ self.glm = GLMModel(config)
804
+ self.post_init()
805
+
806
+ def forward(
807
+ self,
808
+ input_ids=None,
809
+ position_ids=None,
810
+ attention_mask=None,
811
+ choice_ids=None,
812
+ choice_indices=None,
813
+ labels=None,
814
+ mems=None,
815
+ **kwargs
816
+ ):
817
+ model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
818
+ lm_logits = model_output.logits
819
+ log_probs = []
820
+ for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
821
+ log_probs_single = []
822
+ for choice, choice_target_id in zip(choices, choice_index):
823
+ tmp = output[choice_target_id, choice]
824
+ log_probs_single.append(tmp.sum())
825
+ log_probs.append(torch.stack(log_probs_single))
826
+ log_probs = torch.stack(log_probs)
827
+ loss = None
828
+ if labels is not None:
829
+ loss_fct = CrossEntropyLoss()
830
+ loss = loss_fct(log_probs, labels)
831
+ return ModelOutput(
832
+ loss=loss,
833
+ logits=log_probs,
834
+ lm_logits=lm_logits,
835
+ mems=model_output.mems
836
+ )
837
+
838
  @add_start_docstrings(
839
  """GLM Model transformer with a `language modeling` head on top""",
840
  GLM_START_DOCSTRING,
 
873
  position_ids = position_ids[:, :, :seq_length]
874
  if attention_mask is not None:
875
  attention_mask = attention_mask[:, :, :seq_length, :seq_length]
876
+ if position_ids is not None and input_ids.size(0) > position_ids.size(0):
877
+ batch_size = position_ids.size(0)
878
+ num_beams = input_ids.size(0) // batch_size
879
+ position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
880
+ position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])
881
+ if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):
882
+ batch_size = attention_mask.size(0)
883
+ num_beams = input_ids.size(0) // batch_size
884
+ attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)
885
+ attention_mask = attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])
886
  return {
887
  "input_ids": input_ids,
888
  "position_ids": position_ids,
 
895
  input_ids=None,
896
  position_ids=None,
897
  attention_mask=None,
898
+ labels=None,
899
  mems=None,
900
  **kwargs
901
  ):
902
+ model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
903
+ lm_logits = model_output.logits
904
+ loss = None
905
+ if labels is not None:
906
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
907
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
908
+ return ModelOutput(
909
+ loss=loss,
910
+ logits=lm_logits,
911
+ mems=model_output.mems
912
+ )
913
 
914
 
915
  @add_start_docstrings(
 
918
  GLM_START_DOCSTRING,
919
  )
920
  class GLMForSequenceClassification(GLMPreTrainedModel):
921
+ def __init__(self, config: GLMConfig, hidden_dropout=None, num_class=1):
922
  super().__init__(config)
923
  self.pool_token = config.pool_token
924
  self.glm = GLMModel(config)
925
  self.glm.output_predict = False
926
  self.num_class = num_class
927
  # Multi-choice head.
928
+ self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
929
+ classifier_dropout = (
930
+ config.classifier_dropout if config.classifier_dropout is not None else config.output_dropout_prob
931
+ )
932
+ self.dropout = torch.nn.Dropout(classifier_dropout)
933
+ self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)
934
 
935
  # Initialize weights and apply final processing
936
  self.post_init()
 
955
  input_ids = input_ids.reshape(-1, input_ids.size(-1))
956
  attention_mask = attention_mask.reshape(-1, *attention_mask.size()[2:])
957
  position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
958
+ model_out = self.glm(input_ids, position_ids, attention_mask)
959
+ outputs, mems = model_out.last_hidden_states, model_out.mems
 
 
 
 
 
 
 
 
 
 
 
960
 
961
+ output = outputs[:, 0, :]
962
+ output = self.dropout(output)
963
+ output = torch.tanh(self.dense(output))
964
+ output = self.dropout(output)
965
+ logits = self.out_proj(output)
966
  if num_choices is not None:
967
  logits = logits.view(-1, num_choices)
968
+ loss = None
969
+ if labels is not None:
970
+ loss_fct = CrossEntropyLoss()
971
+ loss = loss_fct(logits, labels)
972
  # loss = F.cross_entropy(logits.contiguous().float(), labels.long())
973
  return SequenceClassifierOutput(loss=loss,
974
  logits=logits,
975
+ hidden_states=outputs)