LizaKovtun commited on
Commit
fe6c45b
1 Parent(s): b8c0b06

Delete modeling_ESGify.py

Browse files
Files changed (1) hide show
  1. modeling_ESGify.py +0 -38
modeling_ESGify.py DELETED
@@ -1,38 +0,0 @@
1
- from collections import OrderedDict
2
- from transformers import MPNetPreTrainedModel, MPNetModel
3
- from configuration_ESGify import ESGifyConfig
4
- import torch
5
-
6
- class ESGify(MPNetPreTrainedModel):
7
- """Model for Classification ESG risks from text."""
8
- config_class = ESGifyConfig
9
-
10
- def __init__(self, config): #tuning only the head
11
- super().__init__(config)
12
- # Instantiate Parts of model
13
- self.mpnet = MPNetModel(config,add_pooling_layer=False)
14
- self.id2label = config.id2label
15
- self.label2id = config.label2id
16
- self.classifier = torch.nn.Sequential(OrderedDict([('norm',torch.nn.BatchNorm1d(768)),
17
- ('linear',torch.nn.Linear(768,512)),
18
- ('act',torch.nn.ReLU()),
19
- ('batch_n',torch.nn.BatchNorm1d(512)),
20
- ('drop_class', torch.nn.Dropout(0.2)),
21
- ('class_l',torch.nn.Linear(512 ,47))]))
22
-
23
- def mean_pooling(model_output, attention_mask):
24
- token_embeddings = model_output #First element of model_output contains all token embeddings
25
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
26
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
27
-
28
- def forward(self, input_ids, attention_mask):
29
- # Feed input to mpnet model
30
- outputs = self.mpnet(input_ids=input_ids,
31
- attention_mask=attention_mask)
32
-
33
- # mean pooling dataset and eed input to classifier to compute logits
34
- logits = self.classifier(self.mean_pooling(outputs['last_hidden_state'],attention_mask))
35
-
36
- # apply sigmoid
37
- logits = 1.0 / (1.0 + torch.exp(-logits))
38
- return logits