shahrukhx01 committed
Commit
4e98941
1 Parent(s): 254528a

add model files

Files changed (2)
  1. multitask_model.py +144 -0
  2. test.py +0 -0
multitask_model.py ADDED
@@ -0,0 +1,144 @@
"""
Implementation borrowed from the transformers package and extended to support multiple prediction heads:

https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_bert.py
"""

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import (
    BertPreTrainedModel,
    BERT_INPUTS_DOCSTRING,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    _CONFIG_FOR_DOC,
    BertModel,
)
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
)


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config)
        # Mapping of task name -> label count, e.g. {"task_a": 2, "task_b": 3}.
        self.num_labels = kwargs.get("task_labels_map", {})
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Task-specific output heads: one linear classifier per task,
        # sized by that task's label count.
        self.classifier1 = nn.Linear(
            config.hidden_size, list(self.num_labels.values())[0]
        )
        self.classifier2 = nn.Linear(
            config.hidden_size, list(self.num_labels.values())[1]
        )

        self.init_weights()

    @add_start_docstrings_to_model_forward(
        BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        task_name=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in
            :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression
            loss is computed (Mean-Square loss); if :obj:`config.num_labels > 1` a classification
            loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled [CLS] representation, shared by all task heads.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        # Route the pooled output through the head that belongs to `task_name`.
        logits = None
        if task_name == list(self.num_labels.keys())[0]:
            logits = self.classifier1(pooled_output)
        elif task_name == list(self.num_labels.keys())[1]:
            logits = self.classifier2(pooled_output)
        if logits is None:
            raise ValueError(
                f"Unknown task_name {task_name!r}; expected one of {list(self.num_labels.keys())}"
            )

        loss = None
        if labels is not None:
            # Problem type is inferred once, from the first labeled batch
            # (same heuristic as the upstream single-task implementation).
            if self.config.problem_type is None:
                if self.num_labels[task_name] == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels[task_name] > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels[task_name] == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels[task_name]), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
test.py DELETED
(empty file)
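
For context, a minimal usage sketch of the multitask head dispatch follows. The task names, label counts, and the bert-base-uncased checkpoint are illustrative assumptions, not part of this commit:

import torch
from transformers import BertConfig, BertTokenizer
from multitask_model import BertForSequenceClassification

# Hypothetical two-task setup: a 2-way and a 3-way classification task.
task_labels_map = {"task_a": 2, "task_b": 3}
config = BertConfig.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification(config, task_labels_map=task_labels_map)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer("an example sentence", return_tensors="pt")

# The encoder is shared; `task_name` selects which classifier head is used.
out_a = model(**batch, task_name="task_a")  # out_a.logits has shape (1, 2)
out_b = model(**batch, task_name="task_b")  # out_b.logits has shape (1, 3)

# Integer labels trigger the single-label (cross-entropy) loss branch.
out = model(**batch, task_name="task_a", labels=torch.tensor([1]))
print(out.loss)

To load fine-tuned weights instead, `BertForSequenceClassification.from_pretrained(..., task_labels_map=...)` should work as well, since from_pretrained forwards unrecognized keyword arguments to the model's __init__.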