admation commited on
Commit
25c44d8
1 Parent(s): 4c49264

Upload model

Browse files
config.json CHANGED
@@ -1,4 +1,11 @@
1
  {
 
 
 
 
 
 
 
2
  "id2label": {
3
  "0": "D",
4
  "1": "C",
@@ -13,5 +20,6 @@
13
  },
14
  "model_type": "ESGBertReddit",
15
  "num_classes": 4,
 
16
  "transformers_version": "4.24.0"
17
  }
 
1
  {
2
+ "architectures": [
3
+ "BertModelForESGClassification"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_ESGBertReddit.ESGRedditConfig",
7
+ "AutoModelForSequenceClassification": "modeling_ESGBertReddit.BertModelForESGClassification"
8
+ },
9
  "id2label": {
10
  "0": "D",
11
  "1": "C",
 
20
  },
21
  "model_type": "ESGBertReddit",
22
  "num_classes": 4,
23
+ "torch_dtype": "float32",
24
  "transformers_version": "4.24.0"
25
  }
configuration_ESGBertReddit.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ class ESGRedditConfig(PretrainedConfig):
3
+ model_type = "ESGBertReddit"
4
+
5
+ def __init__(
6
+ self,
7
+ architectures = ["BertForSequenceClassification"],
8
+ num_classes: int = 4,
9
+ **kwargs
10
+ ):
11
+ self.architectures = architectures
12
+ self.num_classes = num_classes
13
+ super().__init__(**kwargs)
modeling_ESGBertReddit.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel,BertModel
4
+ from ESGBertReddit_model.configuration_ESGBertReddit import ESGRedditConfig
5
+ class ClassificationModel(PreTrainedModel):
6
+ config_class = ESGRedditConfig
7
+
8
+ def __init__(self,config):
9
+ super().__init__(config)
10
+ self.bert = BertModel.from_pretrained('yiyanghkust/finbert-esg',output_attentions=True)
11
+ self.W = nn.Linear(self.bert.config.hidden_size, config.num_classes)
12
+ self.num_classes = config.num_classes
13
+
14
+ def forward(self, input_ids, attention_mask, token_type_ids):
15
+ h, _, attn = self.bert(input_ids=input_ids,
16
+ attention_mask=attention_mask,
17
+ token_type_ids=token_type_ids).values()
18
+ h_cls = h[:,0,:]
19
+ output = self.W(h_cls)
20
+ return output, attn
21
+
22
+ class BertModelForESGClassification(PreTrainedModel):
23
+ config_class = ESGRedditConfig
24
+
25
+ def __init__(self, config):
26
+ super().__init__(config)
27
+ self.model = ClassificationModel(config)
28
+
29
+ def forward(self, inputs, labels=None):
30
+ logits,_ = self.model(**inputs)
31
+ if labels is not None:
32
+ loss = torch.nn.cross_entropy(logits, labels)
33
+ return {"loss": loss, "logits": logits}
34
+ return {"logits": logits}
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cfd0f21b11a64ad7bcaa71932f94a52b5fb2feb5efeb4b50f35f6dda8d9a81e
3
+ size 439088877