davinwang commited on
Commit
e99e8a1
1 Parent(s): b1502f4

compatible with DirectML

Browse files

Tensor.new is a deprecated constructor and does not support PrivateUse1 in pytorch 1.13.1/2.0.0, use torch.ones instead. Please refer to https://github.com/microsoft/DirectML/issues/400 and https://github.com/pytorch/pytorch/issues/95734

Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -1135,7 +1135,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1135
  )
1136
  logits_warper = self._get_logits_warper(generation_config)
1137
 
1138
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1139
  scores = None
1140
  while True:
1141
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
1135
  )
1136
  logits_warper = self._get_logits_warper(generation_config)
1137
 
1138
+ unfinished_sequences = torch.ones(input_ids.shape[0], device=input_ids.device, dtype=input_ids.dtype)
1139
  scores = None
1140
  while True:
1141
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)