dudududukim committed on
Commit
d7784c4
1 Parent(s): 2b383ee
model_bert_concat/__init__.py ADDED
File without changes
model_bert_concat/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (181 Bytes). View file
 
model_bert_concat/__pycache__/configuration_bert_concat.cpython-310.pyc ADDED
Binary file (698 Bytes). View file
 
model_bert_concat/__pycache__/modeling_bert_concat.cpython-310.pyc ADDED
Binary file (1.9 kB). View file
 
model_bert_concat/configuration_bert_concat.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class BertConcatConfig(PretrainedConfig):
    """Configuration for the BERT concat classifier.

    Holds the backbone checkpoint name and the number of target labels on top
    of everything ``PretrainedConfig`` already manages.

    Args:
        bert_model_name: Checkpoint identifier stored as ``bert_model_name``.
        num_labels: Number of classes. Applied only when ``id2label`` is not
            passed in ``kwargs`` (see the comment in ``__init__``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, bert_model_name='klue/bert-base', num_labels=2, **kwargs):
        # ``PretrainedConfig.num_labels`` is a property backed by ``id2label``;
        # assigning a value that disagrees with the current label map replaces
        # the map with generic ``LABEL_i`` entries. A reloaded config arrives
        # with ``id2label`` in ``kwargs`` while ``num_labels`` falls back to
        # its default here, so assigning it unconditionally would shrink a
        # saved N-label config back to 2 labels. Let an explicit label map win.
        has_label_map = kwargs.get("id2label") is not None
        super().__init__(**kwargs)
        self.bert_model_name = bert_model_name
        if not has_label_map:
            self.num_labels = num_labels
model_bert_concat/modeling_bert_concat.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, AutoModel
2
+ from transformers.modeling_outputs import SequenceClassifierOutput
3
+ import torch
4
+ import torch.nn as nn
5
+ from .configuration_bert_concat import BertConcatConfig
6
+
7
+
8
class BertConcatClassifier(PreTrainedModel):
    """Sequence classifier that fuses three pooled views of a backbone encoder.

    The three views are the position-0 vector of the last hidden state, the
    position-0 vector of the fourth-from-last hidden state, and the mean of
    the last hidden state over the sequence axis. They are stacked as three
    channels, mixed into one channel by a kernel-size-1 ``Conv1d``, passed
    through ReLU and projected to ``num_labels`` logits.
    """

    config_class = BertConcatConfig

    def __init__(self, config):
        """Build the backbone and the classification head from ``config``.

        Args:
            config: Configuration providing ``bert_model_name`` and
                ``num_labels``.
        """
        super().__init__(config)
        # Hidden states of every layer are needed by ``forward``.
        self.bert = AutoModel.from_pretrained(config.bert_model_name, output_hidden_states=True)
        self.num_labels = config.num_labels

        # Size the head from the backbone instead of assuming 768, so that
        # backbones with another width work; 768 remains the fallback.
        hidden_size = getattr(self.bert.config, "hidden_size", 768)

        # Classification layers
        self.conv = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)  # 3 x hidden -> 1 x hidden
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None, token_type_ids=None):
        """Run the backbone and the fusion head.

        Args:
            input_ids: Token ids passed positionally to the backbone.
            attention_mask: Optional mask forwarded to the backbone.
            labels: Optional class indices; when given, a cross-entropy loss
                is computed against the logits.
            token_type_ids: Optional segment ids. Forwarded to the backbone
                only when provided, so that unpacking a tokenizer's output
                into this method works.

        Returns:
            SequenceClassifierOutput with ``loss`` (or ``None``), ``logits``,
            and the backbone's ``hidden_states`` and ``attentions``.
        """
        backbone_kwargs = {"attention_mask": attention_mask}
        if token_type_ids is not None:
            backbone_kwargs["token_type_ids"] = token_type_ids
        outputs = self.bert(input_ids, **backbone_kwargs)
        hidden_states = outputs.hidden_states

        # Three pooled vectors as per custom model design.
        last_cls_vector = hidden_states[-1][:, 0, :]
        fourth_last_cls_vector = hidden_states[-4][:, 0, :]
        # NOTE(review): this mean runs over every position, including any
        # padding, because attention_mask is not applied here. Kept as-is so
        # existing checkpoints produce the same outputs - confirm intent.
        mean_pooled_vector = hidden_states[-1].mean(dim=1)

        # Stack the vectors along a new channel axis (dim=1) for the Conv1d.
        stacked = torch.stack(
            (last_cls_vector, fourth_last_cls_vector, mean_pooled_vector),
            dim=1,
        )

        # The conv leaves a single channel at dim=1; drop that axis here so
        # the linear layer directly yields one row of logits per example.
        fused = self.conv(stacked).squeeze(1)
        logits = self.classifier(self.relu(fused))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )