admation commited on
Commit
4b73a79
1 Parent(s): ff209c2

Upload model

Browse files
Files changed (1) hide show
  1. modeling_ESGBertReddit.py +3 -3
modeling_ESGBertReddit.py CHANGED
@@ -11,7 +11,7 @@ class ClassificationModel(PreTrainedModel):
11
  self.W = nn.Linear(self.bert.config.hidden_size, config.num_classes)
12
  self.num_classes = config.num_classes
13
 
14
- def forward(self, input_ids, attention_mask, token_type_ids):
15
  h, _, attn = self.bert(input_ids=input_ids,
16
  attention_mask=attention_mask,
17
  token_type_ids=token_type_ids).values()
@@ -26,9 +26,9 @@ class BertModelForESGClassification(PreTrainedModel):
26
  super().__init__(config)
27
  self.model = ClassificationModel(config)
28
 
29
- def forward(self, inputs, labels=None):
30
  logits,_ = self.model(**inputs)
31
- if labels is not None:
32
  loss = torch.nn.cross_entropy(logits, labels)
33
  return {"loss": loss, "logits": logits}
34
  return {"logits": logits}
 
11
  self.W = nn.Linear(self.bert.config.hidden_size, config.num_classes)
12
  self.num_classes = config.num_classes
13
 
14
+ def forward(self,input_ids,attention_mask,token_type_ids,**kw):
15
  h, _, attn = self.bert(input_ids=input_ids,
16
  attention_mask=attention_mask,
17
  token_type_ids=token_type_ids).values()
 
26
  super().__init__(config)
27
  self.model = ClassificationModel(config)
28
 
29
+ def forward(self,**inputs):
30
  logits,_ = self.model(**inputs)
31
+ if "labels" in inputs:
32
  loss = torch.nn.cross_entropy(logits, labels)
33
  return {"loss": loss, "logits": logits}
34
  return {"logits": logits}