compatible with DirectML

#71
by davinwang - opened
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -1135,7 +1135,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1135
  )
1136
  logits_warper = self._get_logits_warper(generation_config)
1137
 
1138
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1139
  scores = None
1140
  while True:
1141
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
1135
  )
1136
  logits_warper = self._get_logits_warper(generation_config)
1137
 
1138
+ unfinished_sequences = torch.ones(input_ids.shape[0], device=input_ids.device, dtype=input_ids.dtype)
1139
  scores = None
1140
  while True:
1141
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)