Upload model
Browse files
modeling_ESGBertReddit.py
CHANGED
@@ -11,7 +11,7 @@ class ClassificationModel(PreTrainedModel):
|
|
11 |
self.W = nn.Linear(self.bert.config.hidden_size, config.num_classes)
|
12 |
self.num_classes = config.num_classes
|
13 |
|
14 |
-
def forward(self,
|
15 |
h, _, attn = self.bert(input_ids=input_ids,
|
16 |
attention_mask=attention_mask,
|
17 |
token_type_ids=token_type_ids).values()
|
@@ -26,9 +26,9 @@ class BertModelForESGClassification(PreTrainedModel):
|
|
26 |
super().__init__(config)
|
27 |
self.model = ClassificationModel(config)
|
28 |
|
29 |
-
def forward(self
|
30 |
logits,_ = self.model(**inputs)
|
31 |
-
if labels
|
32 |
loss = torch.nn.cross_entropy(logits, labels)
|
33 |
return {"loss": loss, "logits": logits}
|
34 |
return {"logits": logits}
|
|
|
11 |
self.W = nn.Linear(self.bert.config.hidden_size, config.num_classes)
|
12 |
self.num_classes = config.num_classes
|
13 |
|
14 |
+
def forward(self,input_ids,attention_mask,token_type_ids,**kw):
|
15 |
h, _, attn = self.bert(input_ids=input_ids,
|
16 |
attention_mask=attention_mask,
|
17 |
token_type_ids=token_type_ids).values()
|
|
|
26 |
super().__init__(config)
|
27 |
self.model = ClassificationModel(config)
|
28 |
|
29 |
+
def forward(self,**inputs):
|
30 |
logits,_ = self.model(**inputs)
|
31 |
+
if "labels" in inputs:
|
32 |
loss = torch.nn.cross_entropy(logits, labels)
|
33 |
return {"loss": loss, "logits": logits}
|
34 |
return {"logits": logits}
|