# ESG-BERT-Reddit / modeling_ESGBertReddit.py
# Uploaded by admation ("Upload model", commit 4b73a79, 1.32 kB).
import torch
import torch.nn as nn
from transformers import PreTrainedModel,BertModel
from ESGBertReddit_model.configuration_ESGBertReddit import ESGRedditConfig
class ClassificationModel(PreTrainedModel):
    """Sequence classifier: a FinBERT-ESG BERT encoder plus a linear head.

    The encoder is loaded from the ``yiyanghkust/finbert-esg`` checkpoint with
    ``output_attentions=True`` so that ``forward`` can return the encoder's
    attention weights alongside the class logits.
    """

    config_class = ESGRedditConfig

    def __init__(self, config):
        """Build the encoder and the classification head.

        Args:
            config: An ``ESGRedditConfig``; ``config.num_classes`` sets the
                number of output logits.
        """
        super().__init__(config)
        # NOTE(review): the encoder is always initialised from the hub
        # checkpoint, even when this model is itself restored through
        # ``from_pretrained`` (the restored state dict then overwrites these
        # weights). Construction therefore needs access to that checkpoint.
        # Attribute names ``bert`` / ``W`` are part of the state-dict keys;
        # do not rename them or saved checkpoints stop loading.
        self.bert = BertModel.from_pretrained('yiyanghkust/finbert-esg', output_attentions=True)
        self.W = nn.Linear(self.bert.config.hidden_size, config.num_classes)
        self.num_classes = config.num_classes

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kw):
        """Compute class logits for a batch of token sequences.

        Args:
            input_ids: Token ids, presumably ``(batch, seq_len)`` -- TODO confirm.
            attention_mask: Optional padding mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            **kw: Extra keyword arguments (e.g. ``labels``) are accepted and
                ignored.

        Returns:
            A ``(logits, attentions)`` tuple. ``logits`` is ``self.W`` applied
            to the encoder's hidden state at sequence position 0 (the [CLS]
            token under BERT-style tokenization); ``attentions`` is the
            attention output reported by the encoder.
        """
        # Read outputs by field name instead of unpacking ``.values()``:
        # a transformers ``ModelOutput`` only yields its non-None fields, so
        # the positional layout shifts (and unpacking fails) if, e.g., hidden
        # states are also enabled in the encoder config.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        h_cls = outputs.last_hidden_state[:, 0, :]
        output = self.W(h_cls)
        return output, outputs.attentions
class BertModelForESGClassification(PreTrainedModel):
    """Wrapper around ``ClassificationModel`` that returns a result dict.

    ``forward`` always returns ``{"logits": ...}`` and additionally a
    cross-entropy ``"loss"`` when ``labels`` are supplied. The attention
    weights produced by the inner model are discarded.
    """

    config_class = ESGRedditConfig

    def __init__(self, config):
        """Wrap a ``ClassificationModel`` built from the same ``config``."""
        super().__init__(config)
        self.model = ClassificationModel(config)

    def forward(self, **inputs):
        """Run the classifier and optionally compute the training loss.

        Args:
            **inputs: Keyword arguments forwarded to ``ClassificationModel``
                (``input_ids``, ``attention_mask``, ``token_type_ids``, ...).
                An optional ``labels`` entry is used only for the loss;
                presumably class indices of shape ``(batch,)`` -- TODO confirm.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given and
            not None, otherwise ``{"logits": logits}``.
        """
        # ``inputs`` is a fresh dict built from **kwargs, so popping is safe;
        # the inner model has no use for the targets.
        labels = inputs.pop("labels", None)
        logits, _ = self.model(**inputs)
        if labels is not None:
            # ``cross_entropy`` lives in ``torch.nn.functional`` (there is no
            # ``torch.nn.cross_entropy``).
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}