ErfanMoosaviMonazzah commited on
Commit
aa65e77
1 Parent(s): 3647818

Upload model

Browse files
Files changed (1) hide show
  1. modeling_backpack_gpt2_nli.py +1 -1
modeling_backpack_gpt2_nli.py CHANGED
@@ -58,6 +58,6 @@ class BackpackGPT2NLIModel(GPT2PreTrainedModel):
58
  def predict(self, input_ids=None, attention_mask=None):
59
  logits = self.forward(input_ids, attention_mask, labels=None)['logits']
60
  p = torch.argmax(logits, axis=1)
61
- labels = [self.config.id2label[index] for index in p]
62
  return labels
63
 
 
58
  def predict(self, input_ids=None, attention_mask=None):
59
  logits = self.forward(input_ids, attention_mask, labels=None)['logits']
60
  p = torch.argmax(logits, axis=1)
61
+ labels = [self.config.id2label[index.item()] for index in p]
62
  return labels
63