ESGify / modeling_ESGify.py
LizaKovtun's picture
Update modeling_ESGify.py
141b0b7
raw history blame
No virus
1.94 kB
from collections import OrderedDict
from transformers import MPNetPreTrainedModel, MPNetModel
from configuration_ESGify import ESGifyConfig
import torch
class ESGify(MPNetPreTrainedModel):
"""Model for Classification ESG risks from text."""
config_class = ESGifyConfig
def __init__(self, config): #tuning only the head
super().__init__(config)
# Instantiate Parts of model
self.mpnet = MPNetModel(config,add_pooling_layer=False)
self.id2label = config.id2label
self.label2id = config.label2id
self.classifier = torch.nn.Sequential(OrderedDict([('norm',torch.nn.BatchNorm1d(768)),
('linear',torch.nn.Linear(768,512)),
('act',torch.nn.ReLU()),
('batch_n',torch.nn.BatchNorm1d(512)),
('drop_class', torch.nn.Dropout(0.2)),
('class_l',torch.nn.Linear(512 ,47))]))
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def forward(self, input_ids, attention_mask):
# Feed input to mpnet model
outputs = self.mpnet(input_ids=input_ids,
attention_mask=attention_mask)
# mean pooling dataset and eed input to classifier to compute logits
logits = self.classifier(self.mean_pooling(outputs['last_hidden_state'],attention_mask))
# apply sigmoid
logits = 1.0 / (1.0 + torch.exp(-logits))
return logits