Fix batch beam search
modeling_glm.py CHANGED (+82 -10)
@@ -30,6 +30,7 @@ from transformers.utils import (
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     ModelOutput,
+    SequenceClassifierOutput,
 )
 
 from transformers.modeling_utils import (
@@ -780,17 +781,15 @@ class GLMModel(GLMPreTrainedModel):
             attention_mask = torch.zeros(batch_size)
         # Transformer.
         transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
-
-
+        last_hidden_states, mems = transformer_output
+        logits = None
         if self.output_predict:
-
-            # logits_parallel = mpu.copy_to_model_parallel_region(
-            # logits)
-            logits = F.linear(logits, self.word_embeddings.weight)
+            logits = F.linear(last_hidden_states, self.word_embeddings.weight)
 
         return ModelOutput(
+            last_hidden_states=last_hidden_states,
             logits=logits,
-            mems=
+            mems=mems,
         )
 
 
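Note (not part of the diff): GLMModel.forward now unpacks transformer_output into last_hidden_states and mems, and only computes logits when output_predict is set, by projecting the hidden states onto the tied input embedding matrix. A minimal standalone sketch of that projection; all sizes are illustrative assumptions, only the F.linear(hidden_states, embedding_weight) pattern comes from the diff above:

import torch
import torch.nn.functional as F

# Illustrative sizes; only the F.linear(hidden_states, embedding_weight)
# pattern mirrors the change above.
batch, seq_len, hidden, vocab = 2, 8, 16, 100
last_hidden_states = torch.randn(batch, seq_len, hidden)
word_embedding_weight = torch.randn(vocab, hidden)  # same layout as nn.Embedding.weight

# Weight tying: vocabulary scores computed from the input embedding matrix.
logits = F.linear(last_hidden_states, word_embedding_weight)
print(logits.shape)  # torch.Size([2, 8, 100])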
@@ -815,7 +814,7 @@ class GLMForMultipleChoice(GLMPreTrainedModel):
                 mems=None,
                 **kwargs
                 ):
-        model_output = self.glm
+        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
         lm_logits = model_output.logits
         log_probs = []
         for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
@@ -874,6 +873,16 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
             position_ids = position_ids[:, :, :seq_length]
         if attention_mask is not None:
             attention_mask = attention_mask[:, :, :seq_length, :seq_length]
+        if position_ids is not None and input_ids.size(0) > position_ids.size(0):
+            batch_size = position_ids.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
+            position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])
+        if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):
+            batch_size = attention_mask.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)
+            attention_mask = attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])
         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
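Note (not part of the diff): this block is the actual batch beam search fix. During beam search, generate() repeats input_ids num_beams times per sample, while position_ids and attention_mask passed through model_kwargs keep their original batch dimension, so prepare_inputs_for_generation now tiles them to match. A standalone sketch of the same expansion with toy tensors; the shapes (GLM-style 2D position ids, a 4D attention mask) are assumptions for illustration:

import torch

# 2 samples, 4 beams, sequence length 5 (toy sizes).
batch_size, num_beams, seq_length = 2, 4, 5
input_ids = torch.zeros(batch_size * num_beams, seq_length, dtype=torch.long)
position_ids = torch.zeros(batch_size, 2, seq_length, dtype=torch.long)   # assumed (batch, 2, seq)
attention_mask = torch.zeros(batch_size, 1, seq_length, seq_length)       # assumed (batch, 1, seq, seq)

# Same tiling as in the diff: repeat each sample's tensors once per beam.
if input_ids.size(0) > position_ids.size(0):
    beams = input_ids.size(0) // position_ids.size(0)
    position_ids = position_ids.unsqueeze(1).expand(-1, beams, -1, -1)
    position_ids = position_ids.reshape(-1, *position_ids.shape[-2:])
if input_ids.size(0) > attention_mask.size(0):
    beams = input_ids.size(0) // attention_mask.size(0)
    attention_mask = attention_mask.unsqueeze(1).expand(-1, beams, -1, -1, -1)
    attention_mask = attention_mask.reshape(-1, *attention_mask.shape[-3:])

print(position_ids.shape)    # torch.Size([8, 2, 5])
print(attention_mask.shape)  # torch.Size([8, 1, 5, 5])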
@@ -890,7 +899,7 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
                 mems=None,
                 **kwargs
                 ):
-        model_output = self.glm
+        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
         lm_logits = model_output.logits
         loss = None
         if labels is not None:
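Note (not part of the diff): with forward and prepare_inputs_for_generation both calling self.glm consistently, batched beam search goes through the standard generate() API. A hypothetical invocation; the checkpoint name is a placeholder, and the preprocessing a real GLM checkpoint needs (mask tokens, position ids) is assumed rather than taken from this commit:

from transformers import AutoTokenizer

# Placeholder checkpoint; preprocessing details are assumptions.
tokenizer = AutoTokenizer.from_pretrained("your-org/glm-checkpoint", trust_remote_code=True)
model = GLMForConditionalGeneration.from_pretrained("your-org/glm-checkpoint")

prompts = ["The capital of France is [MASK].", "GLM stands for [MASK]."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# batch_size > 1 together with num_beams > 1 is the case the
# prepare_inputs_for_generation change above is meant to handle.
outputs = model.generate(**inputs, max_new_tokens=32, num_beams=4, num_return_sequences=1)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))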
@@ -900,4 +909,67 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
             loss=loss,
             logits=lm_logits,
             mems=model_output.mems
-            )
+        )
+
+
+@add_start_docstrings(
+    """GLM Model transformer with a sequence classification/regression head on top (a linear layer on top of
+    the pooled output) e.g. for GLUE tasks. """,
+    GLM_START_DOCSTRING,
+)
+class GLMForSequenceClassification(GLMPreTrainedModel):
+    def __init__(self, config: GLMConfig, hidden_dropout=None, num_class=1):
+        super().__init__(config)
+        self.pool_token = config.pool_token
+        self.glm = GLMModel(config)
+        self.glm.output_predict = False
+        self.num_class = num_class
+        # Multi-choice head.
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.output_dropout_prob
+        )
+        self.dropout = torch.nn.Dropout(classifier_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(self,
+                input_ids=None,
+                position_ids=None,
+                attention_mask=None,
+                labels=None):
+
+        num_choices = None
+
+        if len(input_ids.shape) == 3:
+            batch_size, num_choices = input_ids.shape[:2]
+            input_ids = input_ids.reshape(-1, input_ids.size(-1))
+            attention_mask = attention_mask.reshape(-1, *attention_mask.size()[2:])
+            position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
+        model_out = self.glm(input_ids, position_ids, attention_mask)
+        outputs, mems = model_out.last_hidden_states, model_out.mems
+
+        output = outputs[:, 0, :]
+        output = self.dropout(output)
+        output = torch.tanh(self.dense(output))
+        output = self.dropout(output)
+        logits = self.out_proj(output)
+        if num_choices is not None:
+            logits = logits.view(-1, num_choices)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits, labels)
+            # loss = F.cross_entropy(logits.contiguous().float(), labels.long())
+        return SequenceClassifierOutput(loss=loss,
+                                        logits=logits,
+                                        hidden_states=outputs)
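Note (not part of the diff): the new GLMForSequenceClassification head pools the hidden state of the first token and runs it through dense -> tanh -> dropout -> out_proj, with a cross-entropy loss when labels are given. A standalone sketch of that path; the layer names mirror the diff, while the sizes and dropout rate are assumptions:

import torch

batch, seq_len, hidden_size, num_labels = 2, 16, 32, 3
last_hidden_states = torch.randn(batch, seq_len, hidden_size)

dense = torch.nn.Linear(hidden_size, hidden_size)
dropout = torch.nn.Dropout(0.1)           # assumed rate; the model reads it from the config
out_proj = torch.nn.Linear(hidden_size, num_labels)

output = last_hidden_states[:, 0, :]      # pool on the first token
output = dropout(output)
output = torch.tanh(dense(output))
output = dropout(output)
logits = out_proj(output)

labels = torch.tensor([0, 2])
loss = torch.nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, loss.item())          # logits: torch.Size([2, 3])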