dragonSwing commited on
Commit
5859a7b
1 Parent(s): 32335bd

Fix device bug

Browse files
Files changed (1) hide show
  1. modeling_seq2labels.py +1 -1
modeling_seq2labels.py CHANGED
@@ -119,5 +119,5 @@ class Seq2LabelsModel(BertPreTrainedModel):
119
  detect_logits=logits_d,
120
  hidden_states=outputs.hidden_states,
121
  attentions=outputs.attentions,
122
- max_error_probability=torch.ones(logits.size(0)),
123
  )
119
  detect_logits=logits_d,
120
  hidden_states=outputs.hidden_states,
121
  attentions=outputs.attentions,
122
+ max_error_probability=torch.ones(logits.size(0), device=logits.device),
123
  )