File size: 1,324 Bytes
25c44d8
 
 
 
 
 
 
 
 
 
 
 
 
4b73a79
25c44d8
 
 
 
 
 
 
 
 
 
 
 
 
 
4b73a79
25c44d8
4b73a79
25c44d8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from transformers import PreTrainedModel,BertModel
from ESGBertReddit_model.configuration_ESGBertReddit import ESGRedditConfig
class ClassificationModel(PreTrainedModel):
    """FinBERT-ESG encoder with a linear classification head on the [CLS] token.

    Loads the pretrained 'yiyanghkust/finbert-esg' backbone (with attention
    outputs enabled) and projects the [CLS] hidden state to
    ``config.num_classes`` logits.
    """

    config_class = ESGRedditConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained FinBERT-ESG backbone; output_attentions=True so forward()
        # can also return the per-layer attention maps for inspection.
        self.bert = BertModel.from_pretrained('yiyanghkust/finbert-esg',
                                              output_attentions=True)
        self.W = nn.Linear(self.bert.config.hidden_size, config.num_classes)
        self.num_classes = config.num_classes

    def forward(self, input_ids, attention_mask, token_type_ids, **kw):
        """Run the encoder and classification head.

        Args:
            input_ids: token id tensor, shape (batch, seq_len).
            attention_mask: padding mask, same shape as ``input_ids``.
            token_type_ids: segment ids, same shape as ``input_ids``.
            **kw: extra keys (e.g. ``labels``) are accepted and ignored.

        Returns:
            Tuple of (logits of shape (batch, num_classes), attention maps).
        """
        # Use the named output attributes instead of unpacking .values():
        # the content and ordering of the output mapping depend on config
        # flags and the transformers version, so positional unpacking of
        # exactly three values is fragile and can silently mis-assign fields.
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        h_cls = outputs.last_hidden_state[:, 0, :]  # [CLS] token representation
        logits = self.W(h_cls)
        return logits, outputs.attentions

class BertModelForESGClassification(PreTrainedModel):
    """HF-style wrapper around ClassificationModel returning a dict output.

    Produces ``{"logits": ...}`` and, when ``labels`` is supplied, also a
    cross-entropy ``"loss"`` — the dict shape expected by the Trainer API.
    """

    config_class = ESGRedditConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassificationModel(config)

    def forward(self, **inputs):
        """Forward pass; computes loss when ``labels`` is present in inputs.

        Args:
            **inputs: tokenizer outputs (``input_ids``, ``attention_mask``,
                ``token_type_ids``) and optionally ``labels``.

        Returns:
            dict with ``"logits"`` and, when labels were given, ``"loss"``.
        """
        # Pop labels so only encoder inputs are forwarded (the inner model's
        # **kw previously swallowed them anyway; this is equivalent but explicit).
        labels = inputs.pop("labels", None)
        logits, _ = self.model(**inputs)
        if labels is not None:
            # Bug fixes vs. the original:
            #   - `torch.nn.cross_entropy` does not exist; the functional
            #     form lives in torch.nn.functional.
            #   - the original referenced a bare `labels` name that was never
            #     bound, raising NameError whenever labels were provided.
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}