LizaKovtun
committed on
Commit
•
a58b94a
1
Parent(s):
f72d251
Upload 2 files
Browse files- configuration_ESGify.py +141 -0
- modeling_ESGify.py +38 -0
configuration_ESGify.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from typing import List, Dict
|
3 |
+
|
4 |
+
|
5 |
+
class ESGifyConfig(PretrainedConfig):
    """Configuration for the ESGify classifier.

    Holds the encoder hyper-parameters (the defaults below) together with the
    mapping between class ids and the 47 ESG risk category names.

    Args:
        attention_probs_dropout_prob, hidden_act, hidden_dropout_prob,
        hidden_size, initializer_range, intermediate_size, layer_norm_eps,
        max_position_embeddings, num_attention_heads, num_hidden_layers,
        relative_attention_num_buckets, vocab_size:
            Stored unchanged on the config under the same attribute names.
        bos_token_id, eos_token_id, pad_token_id, output_attentions:
            Forwarded to ``PretrainedConfig.__init__``.
        id2label: Mapping from class id to category name. ``None`` (default)
            selects the built-in ESG label set.
        label2id: Mapping from category name to class id. ``None`` (default)
            selects the built-in ESG label set.
        **kwargs: Any further options, passed through to
            ``PretrainedConfig.__init__``.
    """

    model_type = "mpnet"

    # Category names in class-id order: position ``i`` names class id ``i``.
    _ESG_LABELS = (
        "Legal Proceedings & Law Violations",
        "Biodiversity",
        "Communities Health and Safety",
        "Land Acquisition and Resettlement (S)",
        "Emergencies (Social)",
        "Corporate Governance",
        "Responsible Investment & Greenwashing",
        "Not Relevant to ESG",
        "Economic Crime",
        "Emergencies (Environmental)",
        "Hazardous Materials Management",
        "Environmental Management",
        "Landscape Transformation",
        "Human Rights",
        "Climate Risks",
        "Labor Relations Management",
        "Freedom of Association and Right to Organise",
        "Employee Health and Safety",
        "Surface Water Pollution",
        "Animal Welfare",
        "Water Consumption",
        "Disclosure",
        "Product Safety and Quality",
        "Greenhouse Gas Emissions",
        "Indigenous People",
        "Cultural Heritage",
        "Air Pollution",
        "Waste Management",
        "Soil and Groundwater Impact",
        "Forced Labour",
        "Wastewater Management",
        "Natural Resources",
        "Physical Impacts",
        "Values and Ethics",
        "Risk Management and Internal Control",
        "Supply Chain (Environmental)",
        "Supply Chain (Social)",
        "Discrimination",
        "Minimum Age and Child Labour",
        "Planning Limitations",
        "Data Safety",
        "Strategy Implementation",
        "Energy Efficiency and Renewables",
        "Land Acquisition and Resettlement (E)",
        "Supply Chain (Economic / Governance)",
        "Land Rehabilitation",
        "Retrenchment",
    )

    def __init__(
        self,
        attention_probs_dropout_prob: float = 0.1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        hidden_size: int = 768,
        initializer_range: float = 0.02,
        intermediate_size: int = 3072,
        layer_norm_eps: float = 1e-05,
        max_position_embeddings: int = 514,
        num_attention_heads: int = 12,
        num_hidden_layers: int = 12,
        output_attentions: bool = True,
        pad_token_id: int = 1,
        relative_attention_num_buckets: int = 32,
        vocab_size: int = 30531,
        id2label: Dict = None,
        label2id: Dict = None,
        **kwargs,
    ):
        # Build fresh dicts per instance instead of sharing one mutable default
        # object between every config created without explicit label maps.
        # NOTE(review): ids are kept as strings to match the original literals;
        # Hugging Face configs usually hold integer ids in label2id - confirm.
        if id2label is None:
            id2label = {
                str(index): name for index, name in enumerate(self._ESG_LABELS)
            }
        if label2id is None:
            label2id = {
                name: str(index) for index, name in enumerate(self._ESG_LABELS)
            }

        # These options are owned by the base class, so they must travel through
        # its constructor; assigning them here and then calling
        # ``super().__init__(**kwargs)`` without them would let the base class
        # replace them with its own defaults.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            output_attentions=output_attentions,
            id2label=id2label,
            label2id=label2id,
            **kwargs,
        )

        # Plain scalar assignments: the previous trailing commas turned most of
        # these values into 1-tuples such as ``(768,)``.
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.vocab_size = vocab_size
|
modeling_ESGify.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from transformers import MPNetPreTrainedModel, MPNetModel
|
3 |
+
from .configuration_ESGify import ESGifyConfig
|
4 |
+
import torch
|
5 |
+
|
6 |
+
class ESGify(MPNetPreTrainedModel):
    """Model for classification of ESG risks from text.

    An MPNet encoder (built without its pooling layer) followed by a small
    feed-forward head. ``forward`` returns one sigmoid score per ESG class.
    """

    config_class = ESGifyConfig

    def __init__(self, config):  # tuning only the head
        """Build the encoder and the classification head from ``config``."""
        super().__init__(config)
        # Instantiate parts of the model
        self.mpnet = MPNetModel(config, add_pooling_layer=False)
        self.id2label = config.id2label
        self.label2id = config.label2id
        # Head: 768-d pooled embedding -> 512 hidden units -> 47 class scores.
        self.classifier = torch.nn.Sequential(
            OrderedDict(
                [
                    ("norm", torch.nn.BatchNorm1d(768)),
                    ("linear", torch.nn.Linear(768, 512)),
                    ("act", torch.nn.ReLU()),
                    ("batch_n", torch.nn.BatchNorm1d(512)),
                    ("drop_class", torch.nn.Dropout(0.2)),
                    ("class_l", torch.nn.Linear(512, 47)),
                ]
            )
        )

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings over dimension 1, ignoring masked positions.

        Declared static because it uses no instance state; without this the
        ``self.mean_pooling(embeddings, mask)`` call in ``forward`` would pass
        the module itself as ``model_output`` and fail on the argument count.

        Args:
            model_output: Token embeddings from the encoder
                (assumes shape (batch, seq, hidden) - TODO confirm).
            attention_mask: Mask with one dimension fewer than
                ``model_output``; non-zero marks positions to include.

        Returns:
            The mask-weighted mean of ``model_output`` along dimension 1.
        """
        token_embeddings = model_output
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so a fully masked row divides by a tiny number instead of zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask):
        """Return per-class sigmoid scores for a batch of tokenized texts.

        Args:
            input_ids: Token ids passed to the MPNet encoder.
            attention_mask: Attention mask passed to the encoder and reused to
                weight the mean pooling.

        Returns:
            Tensor of values in [0, 1], one column per ESG class (47 columns,
            from the last linear layer of the head).
        """
        # Feed input to the MPNet model
        outputs = self.mpnet(input_ids=input_ids, attention_mask=attention_mask)

        # The first element holds the token embeddings (last hidden state);
        # positional access works for both tuple and dict-style encoder outputs.
        pooled = self.mean_pooling(outputs[0], attention_mask)

        # Mean-pool the token embeddings and feed them to the classifier
        logits = self.classifier(pooled)

        # Apply sigmoid to turn logits into independent per-class scores
        return torch.sigmoid(logits)
|