ErfanMoosaviMonazzah commited on
Commit
189fb58
1 Parent(s): 589b81c

Upload model

Browse files
config.json CHANGED
@@ -42,7 +42,7 @@
42
  "summary_type": "cls_index",
43
  "summary_use_proj": true,
44
  "torch_dtype": "float32",
45
- "transformers_version": "4.31.0",
46
  "use_cache": true,
47
  "vocab_size": 50264
48
  }
 
42
  "summary_type": "cls_index",
43
  "summary_use_proj": true,
44
  "torch_dtype": "float32",
45
+ "transformers_version": "4.35.2",
46
  "use_cache": true,
47
  "vocab_size": 50264
48
  }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a540f442266937e77756ce8730b0f0492dc80e51c21522a7e86984e3c6d21bfd
3
+ size 682724236
modeling_backpack_gpt2_nli.py CHANGED
@@ -52,4 +52,12 @@ class BackpackGPT2NLIModel(GPT2PreTrainedModel):
52
  loss = self.loss_func(flat_logits, flat_labels)
53
  return {'logits': logits, 'loss': loss}
54
  else:
55
- return {'logits': logits}
 
 
 
 
 
 
 
 
 
52
  loss = self.loss_func(flat_logits, flat_labels)
53
  return {'logits': logits, 'loss': loss}
54
  else:
55
+ return {'logits': logits}
56
+
57
+
58
+ def predict(self, input_ids=None, attention_mask=None):
59
+ logits = self.forward(input_ids, attention_mask, labels=None)
60
+ p = torch.argmax(p, axis=1)
61
+ labels = [self.config.id2label[index] for index in p]
62
+ return labels
63
+