LizaKovtun committed on
Commit
a58b94a
1 Parent(s): f72d251

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_ESGify.py +141 -0
  2. modeling_ESGify.py +38 -0
configuration_ESGify.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List, Dict
3
+
4
+
5
class ESGifyConfig(PretrainedConfig):
    """Configuration for the ESGify text classifier.

    Holds the MPNet-style encoder hyper-parameters together with the
    mapping between the 47 class indices and their ESG risk category names.

    Args:
        attention_probs_dropout_prob: Stored as-is on the config.
        bos_token_id: Beginning-of-sequence token id (forwarded to the base
            class).
        eos_token_id: End-of-sequence token id (forwarded to the base class).
        hidden_act: Name of the activation function, stored as-is.
        hidden_dropout_prob: Stored as-is on the config.
        hidden_size: Stored as-is on the config.
        initializer_range: Stored as-is on the config.
        intermediate_size: Stored as-is on the config.
        layer_norm_eps: Stored as-is on the config.
        max_position_embeddings: Stored as-is on the config.
        num_attention_heads: Stored as-is on the config.
        num_hidden_layers: Stored as-is on the config.
        output_attentions: Forwarded to the base class.
        pad_token_id: Padding token id (forwarded to the base class).
        relative_attention_num_buckets: Stored as-is on the config.
        vocab_size: Stored as-is on the config.
        id2label: Mapping from class index to category name. The default
            dict is never stored directly; a shallow copy is taken.
        label2id: Mapping from category name to class index. Copied in the
            same way as ``id2label``.
        **kwargs: Any further options understood by ``PretrainedConfig``.
    """

    model_type = "mpnet"

    def __init__(
        self,
        attention_probs_dropout_prob: float = 0.1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        hidden_size: int = 768,
        initializer_range: float = 0.02,
        intermediate_size: int = 3072,
        layer_norm_eps: float = 1e-05,
        max_position_embeddings: int = 514,
        num_attention_heads: int = 12,
        num_hidden_layers: int = 12,
        output_attentions: bool = True,
        pad_token_id: int = 1,
        relative_attention_num_buckets: int = 32,
        vocab_size: int = 30531,
        id2label: Dict = {
            "0": "Legal Proceedings & Law Violations",
            "1": "Biodiversity",
            "2": "Communities Health and Safety",
            "3": "Land Acquisition and Resettlement (S)",
            "4": "Emergencies (Social)",
            "5": "Corporate Governance",
            "6": "Responsible Investment & Greenwashing",
            "7": "Not Relevant to ESG",
            "8": "Economic Crime",
            "9": "Emergencies (Environmental)",
            "10": "Hazardous Materials Management",
            "11": "Environmental Management",
            "12": "Landscape Transformation",
            "13": "Human Rights",
            "14": "Climate Risks",
            "15": "Labor Relations Management",
            "16": "Freedom of Association and Right to Organise",
            "17": "Employee Health and Safety",
            "18": "Surface Water Pollution",
            "19": "Animal Welfare",
            "20": "Water Consumption",
            "21": "Disclosure",
            "22": "Product Safety and Quality",
            "23": "Greenhouse Gas Emissions",
            "24": "Indigenous People",
            "25": "Cultural Heritage",
            "26": "Air Pollution",
            "27": "Waste Management",
            "28": "Soil and Groundwater Impact",
            "29": "Forced Labour",
            "30": "Wastewater Management",
            "31": "Natural Resources",
            "32": "Physical Impacts",
            "33": "Values and Ethics",
            "34": "Risk Management and Internal Control",
            "35": "Supply Chain (Environmental)",
            "36": "Supply Chain (Social)",
            "37": "Discrimination",
            "38": "Minimum Age and Child Labour",
            "39": "Planning Limitations",
            "40": "Data Safety",
            "41": "Strategy Implementation",
            "42": "Energy Efficiency and Renewables",
            "43": "Land Acquisition and Resettlement (E)",
            "44": "Supply Chain (Economic / Governance)",
            "45": "Land Rehabilitation",
            "46": "Retrenchment",
        },
        label2id: Dict = {
            "Legal Proceedings & Law Violations": "0",
            "Biodiversity": "1",
            "Communities Health and Safety": "2",
            "Land Acquisition and Resettlement (S)": "3",
            "Emergencies (Social)": "4",
            "Corporate Governance": "5",
            "Responsible Investment & Greenwashing": "6",
            "Not Relevant to ESG": "7",
            "Economic Crime": "8",
            "Emergencies (Environmental)": "9",
            "Hazardous Materials Management": "10",
            "Environmental Management": "11",
            "Landscape Transformation": "12",
            "Human Rights": "13",
            "Climate Risks": "14",
            "Labor Relations Management": "15",
            "Freedom of Association and Right to Organise": "16",
            "Employee Health and Safety": "17",
            "Surface Water Pollution": "18",
            "Animal Welfare": "19",
            "Water Consumption": "20",
            "Disclosure": "21",
            "Product Safety and Quality": "22",
            "Greenhouse Gas Emissions": "23",
            "Indigenous People": "24",
            "Cultural Heritage": "25",
            "Air Pollution": "26",
            "Waste Management": "27",
            "Soil and Groundwater Impact": "28",
            "Forced Labour": "29",
            "Wastewater Management": "30",
            "Natural Resources": "31",
            "Physical Impacts": "32",
            "Values and Ethics": "33",
            "Risk Management and Internal Control": "34",
            "Supply Chain (Environmental)": "35",
            "Supply Chain (Social)": "36",
            "Discrimination": "37",
            "Minimum Age and Child Labour": "38",
            "Planning Limitations": "39",
            "Data Safety": "40",
            "Strategy Implementation": "41",
            "Energy Efficiency and Renewables": "42",
            "Land Acquisition and Resettlement (E)": "43",
            "Supply Chain (Economic / Governance)": "44",
            "Land Rehabilitation": "45",
            "Retrenchment": "46",
        },
        **kwargs,
    ):
        # Token ids, output_attentions and the label maps are attributes the
        # PretrainedConfig initializer manages itself, so they are handed to
        # it as keyword arguments rather than assigned beforehand (where the
        # base initializer could replace them with its own defaults).
        # The label maps are shallow-copied so that the default dict objects
        # in the signature are never shared between config instances.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            output_attentions=output_attentions,
            id2label=dict(id2label),
            label2id=dict(label2id),
            **kwargs,
        )

        # Encoder hyper-parameters. These must be plain scalars: a trailing
        # comma after any of these assignments would silently store a
        # one-element tuple such as (768,) instead of 768.
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.vocab_size = vocab_size
modeling_ESGify.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from transformers import MPNetPreTrainedModel, MPNetModel
3
+ from .configuration_ESGify import ESGifyConfig
4
+ import torch
5
+
6
class ESGify(MPNetPreTrainedModel):
    """Model for classifying ESG risks from text.

    An MPNet encoder (built without its pooling layer) produces token
    embeddings, which are mean-pooled over the attention mask and passed
    through a small feed-forward head that emits one sigmoid score per class.
    """

    config_class = ESGifyConfig

    def __init__(self, config):
        """Build the encoder and the classification head.

        Args:
            config: Configuration passed to ``MPNetModel``; its ``id2label``
                and ``label2id`` mappings are also exposed on the model.
        """
        # Original author note: "tuning only the head". Nothing in this class
        # freezes the encoder, so presumably the training code does that.
        super().__init__(config)

        # Encoder without the pooling layer; pooling is done in mean_pooling.
        self.mpnet = MPNetModel(config, add_pooling_layer=False)
        self.id2label = config.id2label
        self.label2id = config.label2id

        # Head: 768-dim pooled embedding -> 512 hidden units -> 47 class
        # scores. Sub-module names are kept unchanged because they form the
        # parameter keys of the module's state dict.
        self.classifier = torch.nn.Sequential(
            OrderedDict(
                [
                    ("norm", torch.nn.BatchNorm1d(768)),
                    ("linear", torch.nn.Linear(768, 512)),
                    ("act", torch.nn.ReLU()),
                    ("batch_n", torch.nn.BatchNorm1d(512)),
                    ("drop_class", torch.nn.Dropout(0.2)),
                    ("class_l", torch.nn.Linear(512, 47)),
                ]
            )
        )

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings over the positions enabled by the mask.

        Declared as a static method so that ``self.mean_pooling(x, mask)``
        works; as a plain function in the class body, that call would bind
        the instance to ``model_output`` and fail with a ``TypeError``.

        Args:
            model_output: Tensor of token embeddings (the caller passes the
                encoder's ``last_hidden_state``); pooled along dimension 1.
            attention_mask: Mask with one fewer dimension than
                ``model_output``; it is broadcast over the last dimension.

        Returns:
            The mask-weighted mean of ``model_output`` along dimension 1.
        """
        token_embeddings = model_output
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the denominator so a fully masked row does not divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask):
        """Score every ESG class for a batch of tokenized texts.

        Args:
            input_ids: Token ids passed to the MPNet encoder.
            attention_mask: Attention mask passed to the encoder and reused
                for mean pooling.

        Returns:
            Tensor of per-class sigmoid scores in [0, 1] (despite the
            historical name "logits", these are already squashed).
        """
        # Feed input to the MPNet encoder.
        outputs = self.mpnet(input_ids=input_ids, attention_mask=attention_mask)

        # Mean-pool the token embeddings and feed them to the classifier.
        pooled = self.mean_pooling(outputs["last_hidden_state"], attention_mask)
        logits = self.classifier(pooled)

        # Built-in sigmoid instead of 1 / (1 + exp(-x)): same function, but
        # it avoids the intermediate exp() overflowing for strongly negative
        # inputs.
        return torch.sigmoid(logits)