DeepLearning101 committed on
Commit b0ebb46
1 Parent(s): 45311fe

Upload 4 files

models/sequence_classification/causal_prompt_cls.py ADDED
@@ -0,0 +1,199 @@
1
+ import sys
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import transformers
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from typing import Optional, Tuple, Union
9
+ from torch.nn import CrossEntropyLoss
10
+ from transformers import AutoModelForCausalLM
11
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
12
+ from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model, GPT2LMHeadModel
13
+ from transformers.modeling_outputs import ModelOutput
14
+ from tools.runner_utils.log_util import logging
15
+ from tools.model_utils.parameter_freeze import ParameterFreeze
16
+
17
+ logger = logging.getLogger(__name__)
18
+ freezer = ParameterFreeze()
19
+
20
+
21
+ """
22
+ Function: Use Causal LM to prompt for cls
23
+ Notes:
24
+ - For classification, the model only calculate the loss at the position of label, the other position is set as -100
25
+ - During inference, generate result at the last position.
26
+ """
27
+ class PromptGPT2ForSequenceClassification(GPT2PreTrainedModel):
28
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
29
+
30
+ def __init__(self, config):
31
+ super().__init__(config)
32
+ self.transformer = GPT2Model(config)
33
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
34
+
35
+ if self.config.use_freezing:
36
+ self.transformer = freezer.freeze_lm(self.transformer)
37
+
38
+ # Model parallel
39
+ self.model_parallel = False
40
+ self.device_map = None
41
+
42
+ # These attributes should be assigned once the model is initialized
43
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.transformer.device)
44
+
45
+ # Initialize weights and apply final processing
46
+ self.post_init()
47
+
48
+ def freeze_backbone(self, use_freezing: bool=True):
49
+ if use_freezing:
50
+ self.transformer = freezer.freeze_lm(self.transformer)
51
+ else:
52
+ self.transformer = freezer.unfreeze_lm(self.transformer)
53
+
54
+ def get_output_embeddings(self):
55
+ return self.lm_head
56
+
57
+ def set_output_embeddings(self, new_embeddings):
58
+ self.lm_head = new_embeddings
59
+
60
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
61
+ token_type_ids = kwargs.get("token_type_ids", None)
62
+ # only keep the last token of input_ids if past is defined in kwargs
63
+ if past:
64
+ input_ids = input_ids[:, -1].unsqueeze(-1)
65
+ if token_type_ids is not None:
66
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
67
+
68
+ attention_mask = kwargs.get("attention_mask", None)
69
+ position_ids = kwargs.get("position_ids", None)
70
+
71
+ if attention_mask is not None and position_ids is None:
72
+ # create position_ids on the fly for batch generation
73
+ position_ids = attention_mask.long().cumsum(-1) - 1
74
+ position_ids.masked_fill_(attention_mask == 0, 1)
75
+ if past:
76
+ position_ids = position_ids[:, -1].unsqueeze(-1)
77
+ else:
78
+ position_ids = None
79
+ return {
80
+ "input_ids": input_ids,
81
+ "past_key_values": past,
82
+ "use_cache": kwargs.get("use_cache"),
83
+ "position_ids": position_ids,
84
+ "attention_mask": attention_mask,
85
+ "token_type_ids": token_type_ids,
86
+ }
87
+
88
+ def forward(
89
+ self,
90
+ input_ids: Optional[torch.LongTensor] = None,
91
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
92
+ attention_mask: Optional[torch.FloatTensor] = None,
93
+ token_type_ids: Optional[torch.LongTensor] = None,
94
+ position_ids: Optional[torch.LongTensor] = None,
95
+ head_mask: Optional[torch.FloatTensor] = None,
96
+ inputs_embeds: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.Tensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ labels: Optional[torch.LongTensor] = None,
100
+ use_cache: Optional[bool] = None,
101
+ output_attentions: Optional[bool] = None,
102
+ output_hidden_states: Optional[bool] = None,
103
+ return_dict: Optional[bool] = None,
104
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
105
+ r"""
106
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
107
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
108
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
109
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
110
+ """
111
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
112
+
113
+ transformer_outputs = self.transformer(
114
+ input_ids,
115
+ past_key_values=past_key_values,
116
+ attention_mask=attention_mask,
117
+ token_type_ids=token_type_ids,
118
+ position_ids=position_ids,
119
+ head_mask=head_mask,
120
+ inputs_embeds=inputs_embeds,
121
+ encoder_hidden_states=encoder_hidden_states,
122
+ encoder_attention_mask=encoder_attention_mask,
123
+ use_cache=use_cache,
124
+ output_attentions=output_attentions,
125
+ output_hidden_states=output_hidden_states,
126
+ return_dict=return_dict,
127
+ )
128
+ hidden_states = transformer_outputs[0]
129
+
130
+ # Set device for model parallelism
131
+ if self.model_parallel:
132
+ torch.cuda.set_device(self.transformer.first_device)
133
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
134
+
135
+ lm_logits = self.lm_head(hidden_states)
136
+
137
+ loss = None
138
+ if labels is not None:
139
+ # Shift so that tokens < n predict n
140
+ shift_logits = lm_logits[..., :-1, :].contiguous()
141
+ shift_labels = labels[..., 1:].contiguous()
142
+ # print("shift_labels=", shift_labels)
143
+ # Flatten the tokens
144
+ loss_fct = CrossEntropyLoss()
145
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
146
+
147
+ if not return_dict:
148
+ output = (lm_logits,) + transformer_outputs[1:]
149
+ return ((loss,) + output) if loss is not None else output
150
+
151
+ return CausalLMOutputWithCrossAttentions(
152
+ loss=loss,
153
+ logits=lm_logits,
154
+ past_key_values=transformer_outputs.past_key_values,
155
+ hidden_states=transformer_outputs.hidden_states,
156
+ attentions=transformer_outputs.attentions,
157
+ cross_attentions=transformer_outputs.cross_attentions,
158
+ )
159
+
160
+ @staticmethod
161
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
162
+ """
163
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
164
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
165
+ beam_idx at every generation step.
166
+ """
167
+ return tuple(
168
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
169
+ for layer_past in past
170
+ )
171
+
172
+
173
+
174
+
175
+ # if __name__ == "__main__":
176
+ # from transformers import GPT2Tokenizer
177
+ # tokenizer = GPT2Tokenizer.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
178
+ # model = PromptGPT2ForSequenceClassification.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
179
+
180
+ # # In-Context Learning for classification
181
+ # # input_text = "The capital city of China is Beijing. \n\n The capital city of Japan is Tokyo. \n\n The capital city of America is"
182
+ # input_text = "What are follows emotions? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output:"
183
+ # # input_text = "This film is wonderful.\n Great."
184
+ # tokenizer.pad_token = tokenizer.eos_token
185
+ # inputs = tokenizer(input_text, return_tensors="pt")
186
+ # input_len = inputs["input_ids"].shape[-1]
187
+ # gen_output = model.generate(**inputs, max_length=input_len + 10)
188
+ # gen_result = tokenizer.decode(gen_output[0])
189
+ # print("classification result:\n", gen_result)
190
+
191
+ # # In-Context Learning for generation
192
+ # input_text = "Please tell me what is the transformer? "
193
+ # # input_text = "This film is wonderful.\n Great."
194
+ # tokenizer.pad_token = tokenizer.eos_token
195
+ # inputs = tokenizer(input_text, return_tensors="pt")
196
+ # input_len = inputs["input_ids"].shape[-1]
197
+ # gen_output = model.generate(**inputs, max_length=input_len + 60)
198
+ # gen_result = tokenizer.decode(gen_output[0])
199
+ # print("generation result:\n", gen_result)
models/sequence_classification/classification.py ADDED
@@ -0,0 +1,175 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/8/19 10:54 AM
3
+ # @Author : JianingWang
4
+ # @File : classification.py
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
8
+ from transformers import RobertaModel
9
+ from transformers.activations import ACT2FN
10
+ from transformers.models.electra import ElectraModel
11
+ from transformers.models.roformer import RoFormerModel
12
+ from transformers.models.albert import AlbertModel
13
+ from transformers.models.bert import BertModel, BertPreTrainedModel
14
+ from transformers.models.deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel
15
+ from transformers.modeling_outputs import SequenceClassifierOutput
16
+ from transformers.models.roberta import RobertaPreTrainedModel
17
+ from transformers.models.bert.modeling_bert import BertForSequenceClassification
18
+ from transformers.models.megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
19
+
20
+ PRETRAINED_MODEL_MAP = {
21
+ "bert": BertPreTrainedModel,
22
+ "deberta-v2": DebertaV2PreTrainedModel,
23
+ "roberta": RobertaPreTrainedModel,
24
+ "erlangshen": MegatronBertPreTrainedModel
25
+ }
26
+
27
+
28
+ class BertPooler(nn.Module):
29
+ def __init__(self, hidden_size, hidden_act, hidden_dropout_prob):
30
+ super().__init__()
31
+ self.dense = nn.Linear(hidden_size, hidden_size)
32
+ # self.activation = nn.Tanh()
33
+ self.activation = ACT2FN[hidden_act]
34
+ # self.dropout = nn.Dropout(hidden_dropout_prob)
35
+
36
+ def forward(self, features):
37
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
38
+ # x = self.dropout(x)
39
+ x = self.dense(x)
40
+ x = self.activation(x)
41
+ return x
42
+
43
+
44
+ def build_cls_model(config):
45
+ BaseClass = PRETRAINED_MODEL_MAP[config.model_type]
46
+
47
+ class BertForClassification(BaseClass):
48
+
49
+ def __init__(self, config):
50
+ super().__init__(config)
51
+ self.num_labels = config.num_labels
52
+ self.config = config
53
+ self.model_type = config.model_type
54
+ self.problem_type = config.problem_type
55
+
56
+ if self.model_type == "bert":
57
+ self.bert = BertModel(config)
58
+ elif self.model_type == "albert":
59
+ self.albert = AlbertModel(config)
60
+ # elif self.model_type == "chinesebert":
61
+ # self.bert = ChineseBertModel(config)
62
+ elif self.model_type == "roformer":
63
+ self.roformer = RoFormerModel(config)
64
+ elif self.model_type == "electra":
65
+ self.electra = ElectraModel(config)
66
+ elif self.model_type == "deberta-v2":
67
+ self.deberta = DebertaV2Model(config)
68
+ elif self.model_type == "roberta":
69
+ self.roberta = RobertaModel(config)
70
+ elif self.model_type == "erlangshen":
71
+ self.bert = MegatronBertModel(config)
72
+ self.pooler = BertPooler(config.hidden_size, config.hidden_act, config.hidden_dropout_prob)
73
+ if hasattr(config, "cls_dropout_rate"):
74
+ cls_dropout_rate = config.cls_dropout_rate
75
+ else:
76
+ cls_dropout_rate = config.hidden_dropout_prob
77
+ self.dropout = nn.Dropout(cls_dropout_rate)
78
+ add_feature_dims = config.additional_feature_dims if hasattr(config, "additional_feature_dims") else 0
79
+ # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
80
+ cls_hidden = config.hidden_size + add_feature_dims
81
+ if hasattr(config, "is_relation_task"):
82
+ cls_hidden = config.hidden_size * 2
83
+ self.classifier = nn.Linear(cls_hidden, config.num_labels)
84
+
85
+ self.init_weights()
86
+
87
+ def forward(
88
+ self,
89
+ input_ids=None,
90
+ attention_mask=None,
91
+ token_type_ids=None,
92
+ position_ids=None,
93
+ head_mask=None,
94
+ inputs_embeds=None,
95
+ labels=None,
96
+ output_attentions=None,
97
+ output_hidden_states=None,
98
+ return_dict=None,
99
+ pseudo_label=None,
100
+ pinyin_ids=None,
101
+ additional_features=None
102
+ ):
103
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
104
+ logits, outputs = None, None
105
+ inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids,
106
+ "head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions,
107
+ "output_hidden_states": output_hidden_states, "return_dict": return_dict, "pinyin_ids": pinyin_ids}
108
+ inputs = {k: v for k, v in inputs.items() if v is not None}
109
+ if self.model_type == "chinesebert":
110
+ outputs = self.bert(**inputs)
111
+ elif self.model_type == "bert":
112
+ outputs = self.bert(**inputs)
113
+ elif self.model_type == "albert":
114
+ outputs = self.albert(**inputs)
115
+ elif self.model_type == "electra":
116
+ outputs = self.electra(**inputs)
117
+ elif self.model_type == "roformer":
118
+ outputs = self.roformer(**inputs)
119
+ elif self.model_type == "deberta-v2":
120
+ outputs = self.deberta(**inputs)
121
+ elif self.model_type == "roberta":
122
+ outputs = self.roberta(**inputs)
123
+ elif self.model_type == "erlangshen":
124
+ outputs = self.bert(**inputs)
125
+
126
+ if hasattr(self.config, "is_relation_task"):
127
+ w = torch.logical_and(input_ids >= min(self.config.start_token_ids), input_ids <= max(self.config.start_token_ids))
128
+ start_index = w.nonzero()[:, 1].view(-1, 2)
129
+ # classify on <start_entity> + <end_entity>
130
+ pooler_output = torch.cat([torch.cat([x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y in zip(outputs.last_hidden_state, start_index)])
131
+ # classify on [CLS] + <start_entity> + <end_entity>
132
+ # pooler_output = torch.cat([torch.cat([z, x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y, z in zip(outputs.last_hidden_state, start_index, outputs.last_hidden_state[:, 0])])
133
+
134
+ elif "pooler_output" in outputs:
135
+ pooler_output = outputs.pooler_output
136
+ else:
137
+ pooler_output = self.pooler(outputs[0])
138
+ pooler_output = self.dropout(pooler_output)
139
+ # pooler_output = self.LayerNorm(pooler_output)
140
+ if additional_features is not None:
141
+ pooler_output = torch.cat((pooler_output, additional_features), dim=1)
142
+ logits = self.classifier(pooler_output)
143
+
144
+ loss = None
145
+ if labels is not None:
146
+ if self.problem_type == "regression":
147
+ loss_fct = MSELoss()
148
+ if self.num_labels == 1:
149
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
150
+ else:
151
+ loss = loss_fct(logits, labels)
152
+ elif self.problem_type == "multi_label_classification":
153
+ loss_fct = BCEWithLogitsLoss()
154
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.float().view(-1, self.num_labels))
155
+ # elif self.problem_type in ["single_label_classification"] or hasattr(self.config, "is_relation_task"):
156
+ else:
157
+ # loss_fct = FocalLoss()
158
+ loss_fct = CrossEntropyLoss()
159
+ # pseudo labels
160
+ if pseudo_label is not None:
161
+ train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
162
+ train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
163
+ train_loss = loss_fct(train_logits.view(-1, self.num_labels), train_labels.view(-1)) if train_labels.nelement() else 0
164
+ pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels), pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
165
+ loss = 0.9 * train_loss + 0.1 * pseudo_loss
166
+ else:
167
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
168
+ return SequenceClassifierOutput(
169
+ loss=loss,
170
+ logits=logits,
171
+ hidden_states=outputs.hidden_states,
172
+ attentions=outputs.attentions,
173
+ )
174
+
175
+ return BertForClassification
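
For reference, a minimal usage sketch of the build_cls_model factory defined above (assumptions: build_cls_model imported from this file, a standard bert-base-uncased checkpoint, and single-label classification; the snippet is illustrative and not part of the uploaded files).

from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=3)
config.problem_type = "single_label_classification"   # consumed in forward() above
cls_model_class = build_cls_model(config)             # picks BertPreTrainedModel via PRETRAINED_MODEL_MAP["bert"]
model = cls_model_class.from_pretrained("bert-base-uncased", config=config)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["an example sentence"], return_tensors="pt")
outputs = model(**batch)                               # SequenceClassifierOutput; logits shape [1, 3]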
models/sequence_classification/head_cls.py ADDED
@@ -0,0 +1,1284 @@
1
+ """
2
+ Head Tuning with Prefix / Adapter
3
+ """
4
+ from typing import Optional, List, Union, Tuple
5
+ import torch
6
+ from torch._C import NoopLogger
7
+ import torch.nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
11
+
12
+ from transformers import BertModel, BertPreTrainedModel
13
+ from transformers import RobertaModel, RobertaPreTrainedModel
14
+ from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout
15
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel
16
+ from transformers.models.bart.modeling_bart import BartPretrainedModel, BartClassificationHead, BartModel
17
+ from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
18
+ from transformers.models.bart.configuration_bart import BartConfig
19
+ from transformers.modeling_outputs import SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput, SequenceClassifierOutputWithPast
20
+
21
+ from models.basic_modules.prefix_encoder import PrefixEncoder
22
+
23
+ from models.basic_modules.adapter import BertAdaModel, RobertaAdaModel, init_adapter
24
+ from tools.model_utils.parameter_freeze import ParameterFreeze
25
+
26
+ from tools.runner_utils.log_util import logging
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ freezer = ParameterFreeze()
31
+
32
+ # ========= BERT =========
33
+
34
+ # Vanilla Fine-tuning For BERT
35
+ class BertForSequenceClassification(BertPreTrainedModel):
36
+ def __init__(self, config):
37
+ super().__init__(config)
38
+ self.num_labels = config.num_labels
39
+ self.config = config
40
+
41
+ self.bert = BertModel(config)
42
+ if self.config.use_freezing:
43
+ self.bert = freezer.freeze_lm(self.bert)
44
+
45
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
46
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
47
+
48
+ self.init_weights()
49
+
50
+ def freeze_backbone(self, use_freezing: bool=True):
51
+ if use_freezing:
52
+ self.bert = freezer.freeze_lm(self.bert)
53
+ else:
54
+ self.bert = freezer.unfreeze_lm(self.bert)
55
+
56
+ def forward(
57
+ self,
58
+ input_ids=None,
59
+ attention_mask=None,
60
+ token_type_ids=None,
61
+ position_ids=None,
62
+ head_mask=None,
63
+ inputs_embeds=None,
64
+ labels=None,
65
+ output_attentions=None,
66
+ output_hidden_states=None,
67
+ return_dict=None,
68
+ ):
69
+ r"""
70
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
71
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
72
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
73
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
74
+ """
75
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
76
+
77
+ # print("input_ids.shape=", input_ids.shape) # e.g., [8, 128]
78
+ # print("attention_mask.shape=", attention_mask.shape) # e.g., [8, 128]
79
+ # print("token_type_ids.shape=", token_type_ids.shape) # e.g., [8, 128]
80
+
81
+ outputs = self.bert(
82
+ input_ids,
83
+ attention_mask=attention_mask,
84
+ token_type_ids=token_type_ids,
85
+ position_ids=position_ids,
86
+ head_mask=head_mask,
87
+ inputs_embeds=inputs_embeds,
88
+ output_attentions=output_attentions,
89
+ output_hidden_states=output_hidden_states,
90
+ return_dict=return_dict,
91
+ )
92
+
93
+ pooled_output = outputs[1]
94
+
95
+ pooled_output = self.dropout(pooled_output)
96
+ logits = self.classifier(pooled_output)
97
+
98
+ loss = None
99
+ if labels is not None:
100
+ if self.config.problem_type is None:
101
+ if self.num_labels == 1:
102
+ self.config.problem_type = "regression"
103
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
104
+ self.config.problem_type = "single_label_classification"
105
+ else:
106
+ self.config.problem_type = "multi_label_classification"
107
+
108
+ if self.config.problem_type == "regression":
109
+ loss_fct = MSELoss()
110
+ if self.num_labels == 1:
111
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
112
+ else:
113
+ loss = loss_fct(logits, labels)
114
+ elif self.config.problem_type == "single_label_classification":
115
+ loss_fct = CrossEntropyLoss()
116
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
117
+ elif self.config.problem_type == "multi_label_classification":
118
+ loss_fct = BCEWithLogitsLoss()
119
+ loss = loss_fct(logits, labels)
120
+ if not return_dict:
121
+ output = (logits,) + outputs[2:]
122
+ return ((loss,) + output) if loss is not None else output
123
+
124
+ return SequenceClassifierOutput(
125
+ loss=loss,
126
+ logits=logits,
127
+ hidden_states=outputs.hidden_states,
128
+ attentions=outputs.attentions,
129
+ )
130
+
131
+ # Prefix-tuning For BERT
132
+ class BertPrefixForSequenceClassification(BertPreTrainedModel):
133
+ def __init__(self, config):
134
+ super().__init__(config)
135
+ self.num_labels = config.num_labels
136
+ self.config = config
137
+ self.bert = BertModel(config)
138
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
139
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
140
+
141
+ # for param in self.bert.parameters():
142
+ # param.requires_grad = False
143
+
144
+ if self.config.use_freezing:
145
+ self.bert = freezer.freeze_lm(self.bert)
146
+
147
+ self.pre_seq_len = config.pre_seq_len
148
+ self.n_layer = config.num_hidden_layers
149
+ self.n_head = config.num_attention_heads
150
+ self.n_embd = config.hidden_size // config.num_attention_heads
151
+
152
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
153
+
154
+ self.prefix_encoder = PrefixEncoder(config)
155
+
156
+ bert_param = 0
157
+ for name, param in self.bert.named_parameters():
158
+ bert_param += param.numel()
159
+ all_param = 0
160
+ for name, param in self.named_parameters():
161
+ all_param += param.numel()
162
+ total_param = all_param - bert_param
163
+ print("total param is {}".format(total_param)) # 9860105
164
+
165
+ def freeze_backbone(self, use_freezing: bool=True):
166
+ if use_freezing:
167
+ self.bert = freezer.freeze_lm(self.bert)
168
+ else:
169
+ self.bert = freezer.unfreeze_lm(self.bert)
170
+
171
+ def get_prompt(self, batch_size):
172
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
173
+ past_key_values = self.prefix_encoder(prefix_tokens)
174
+ # bsz, seqlen, _ = past_key_values.shape
175
+ past_key_values = past_key_values.view(
176
+ batch_size,
177
+ self.pre_seq_len,
178
+ self.n_layer * 2,
179
+ self.n_head,
180
+ self.n_embd
181
+ )
182
+ past_key_values = self.dropout(past_key_values)
183
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
184
+ return past_key_values
185
+
186
+ def forward(
187
+ self,
188
+ input_ids=None,
189
+ attention_mask=None,
190
+ token_type_ids=None,
191
+ position_ids=None,
192
+ head_mask=None,
193
+ inputs_embeds=None,
194
+ labels=None,
195
+ output_attentions=None,
196
+ output_hidden_states=None,
197
+ return_dict=None,
198
+ ):
199
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
200
+
201
+ # print("input_ids.shape=", input_ids.shape) # e.g., [8, 128]
202
+ # print("attention_mask.shape=", attention_mask.shape) # e.g., [8, 128]
203
+ # print("token_type_ids.shape=", token_type_ids.shape) # e.g., [8, 128]
204
+
205
+ batch_size = input_ids.shape[0]
206
+ past_key_values = self.get_prompt(batch_size=batch_size)
207
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
208
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
209
+
210
+ if position_ids is None:
211
+ position_ids = torch.tensor([i for i in range(input_ids.shape[-1])]).expand(batch_size, -1).to(self.bert.device)
212
+
213
+ outputs = self.bert(
214
+ input_ids,
215
+ attention_mask=attention_mask,
216
+ token_type_ids=token_type_ids,
217
+ position_ids=position_ids,
218
+ head_mask=head_mask,
219
+ inputs_embeds=inputs_embeds,
220
+ output_attentions=output_attentions,
221
+ output_hidden_states=output_hidden_states,
222
+ return_dict=return_dict,
223
+ past_key_values=past_key_values,
224
+ )
225
+
226
+ pooled_output = outputs[1]
227
+
228
+ pooled_output = self.dropout(pooled_output)
229
+ logits = self.classifier(pooled_output)
230
+
231
+ loss = None
232
+ if labels is not None:
233
+ if self.config.problem_type is None:
234
+ if self.num_labels == 1:
235
+ self.config.problem_type = "regression"
236
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
237
+ self.config.problem_type = "single_label_classification"
238
+ else:
239
+ self.config.problem_type = "multi_label_classification"
240
+
241
+ if self.config.problem_type == "regression":
242
+ loss_fct = MSELoss()
243
+ if self.num_labels == 1:
244
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
245
+ else:
246
+ loss = loss_fct(logits, labels)
247
+ elif self.config.problem_type == "single_label_classification":
248
+ loss_fct = CrossEntropyLoss()
249
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
250
+ elif self.config.problem_type == "multi_label_classification":
251
+ loss_fct = BCEWithLogitsLoss()
252
+ loss = loss_fct(logits, labels)
253
+ if not return_dict:
254
+ output = (logits,) + outputs[2:]
255
+ return ((loss,) + output) if loss is not None else output
256
+
257
+ return SequenceClassifierOutput(
258
+ loss=loss,
259
+ logits=logits,
260
+ hidden_states=outputs.hidden_states,
261
+ attentions=outputs.attentions,
262
+ )
263
+
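# Illustrative note (comment only, not from the original file) on BertPrefixForSequenceClassification.get_prompt above:
# prefix_tokens                  -> [batch, pre_seq_len]
# prefix_encoder(prefix_tokens)  -> [batch, pre_seq_len, n_layer * 2 * hidden_size]
# .view(...)                     -> [batch, pre_seq_len, n_layer * 2, n_head, n_embd]
# .permute([2, 0, 3, 1, 4])      -> [n_layer * 2, batch, n_head, pre_seq_len, n_embd]
# .split(2)                      -> tuple of n_layer tensors, each [2, batch, n_head, pre_seq_len, n_embd],
#                                   i.e. one (key, value) pair per layer in the past_key_values format
#                                   consumed by BertModel, with the prefix positions covered by the extra
#                                   ones concatenated to attention_mask in forward().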
264
+
265
+ # Prompt-tuning For BERT
266
+ class BertPtuningForSequenceClassification(BertPreTrainedModel):
267
+ def __init__(self, config):
268
+ super().__init__(config)
269
+ self.num_labels = config.num_labels
270
+ self.bert = BertModel(config)
271
+ self.embeddings = self.bert.embeddings
272
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
273
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
274
+
275
+ # for param in self.bert.parameters():
276
+ # param.requires_grad = False
277
+
278
+ if self.config.use_freezing:
279
+ self.bert = freezer.freeze_lm(self.bert)
280
+
281
+ self.pre_seq_len = config.pre_seq_len
282
+ self.n_layer = config.num_hidden_layers
283
+ self.n_head = config.num_attention_heads
284
+ self.n_embd = config.hidden_size // config.num_attention_heads
285
+
286
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
287
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
288
+
289
+ def freeze_backbone(self, use_freezing: bool=True):
290
+ if use_freezing:
291
+ self.bert = freezer.freeze_lm(self.bert)
292
+ else:
293
+ self.bert = freezer.unfreeze_lm(self.bert)
294
+
295
+ def get_prompt(self, batch_size):
296
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
297
+ prompts = self.prefix_encoder(prefix_tokens)
298
+ return prompts
299
+
300
+ def forward(
301
+ self,
302
+ input_ids=None,
303
+ attention_mask=None,
304
+ token_type_ids=None,
305
+ position_ids=None,
306
+ head_mask=None,
307
+ inputs_embeds=None,
308
+ labels=None,
309
+ output_attentions=None,
310
+ output_hidden_states=None,
311
+ return_dict=None,
312
+ ):
313
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
314
+
315
+ batch_size = input_ids.shape[0]
316
+ raw_embedding = self.embeddings(
317
+ input_ids=input_ids,
318
+ position_ids=position_ids,
319
+ token_type_ids=token_type_ids,
320
+ )
321
+ prompts = self.get_prompt(batch_size=batch_size)
322
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
323
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
324
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
325
+
326
+ outputs = self.bert(
327
+ # input_ids,
328
+ attention_mask=attention_mask,
329
+ # token_type_ids=token_type_ids,
330
+ # position_ids=position_ids,
331
+ head_mask=head_mask,
332
+ inputs_embeds=inputs_embeds,
333
+ output_attentions=output_attentions,
334
+ output_hidden_states=output_hidden_states,
335
+ return_dict=return_dict,
336
+ # past_key_values=past_key_values,
337
+ )
338
+
339
+ # pooled_output = outputs[1]
340
+ sequence_output = outputs[0]
341
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
342
+ first_token_tensor = sequence_output[:, 0]
343
+ pooled_output = self.bert.pooler.dense(first_token_tensor)
344
+ pooled_output = self.bert.pooler.activation(pooled_output)
345
+
346
+ pooled_output = self.dropout(pooled_output)
347
+ logits = self.classifier(pooled_output)
348
+
349
+ loss = None
350
+ if labels is not None:
351
+ if self.config.problem_type is None:
352
+ if self.num_labels == 1:
353
+ self.config.problem_type = "regression"
354
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
355
+ self.config.problem_type = "single_label_classification"
356
+ else:
357
+ self.config.problem_type = "multi_label_classification"
358
+
359
+ if self.config.problem_type == "regression":
360
+ loss_fct = MSELoss()
361
+ if self.num_labels == 1:
362
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
363
+ else:
364
+ loss = loss_fct(logits, labels)
365
+ elif self.config.problem_type == "single_label_classification":
366
+ loss_fct = CrossEntropyLoss()
367
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
368
+ elif self.config.problem_type == "multi_label_classification":
369
+ loss_fct = BCEWithLogitsLoss()
370
+ loss = loss_fct(logits, labels)
371
+ if not return_dict:
372
+ output = (logits,) + outputs[2:]
373
+ return ((loss,) + output) if loss is not None else output
374
+
375
+ return SequenceClassifierOutput(
376
+ loss=loss,
377
+ logits=logits,
378
+ hidden_states=outputs.hidden_states,
379
+ attentions=outputs.attentions,
380
+ )
381
+
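# Illustrative note (comment only, not from the original file) on BertPtuningForSequenceClassification above:
# the learned prompt embeddings are prepended to the token embeddings
# (inputs_embeds = [prompts; raw_embedding]) and attention_mask is padded with ones, so the encoder
# output carries pre_seq_len extra leading positions; forward() drops them via
# sequence_output[:, self.pre_seq_len:, :] before re-applying the BERT pooler to the first remaining
# token and feeding the classification head.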
382
+ # Adapter-tuning For BERT
383
+ class BertAdapterForSequenceClassification(BertPreTrainedModel):
384
+ def __init__(self, config):
385
+ super().__init__(config)
386
+ self.num_labels = config.num_labels
387
+ self.bert = BertAdaModel(config)
388
+ self.embeddings = self.bert.embeddings
389
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
390
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
391
+
392
+ # for param in self.bert.parameters():
393
+ # param.requires_grad = False
394
+ if self.config.use_freezing:
395
+ self.bert = freezer.freeze_lm_component(self.bert, "adapter")
396
+
397
+ def freeze_backbone(self, use_freezing: bool=True):
398
+ if use_freezing:
399
+ self.bert = freezer.freeze_lm_component(self.bert, "adapter")
400
+ else:
401
+ self.bert = freezer.unfreeze_lm(self.bert)
402
+
403
+
404
+ def forward(
405
+ self,
406
+ input_ids=None,
407
+ attention_mask=None,
408
+ token_type_ids=None,
409
+ position_ids=None,
410
+ head_mask=None,
411
+ inputs_embeds=None,
412
+ labels=None,
413
+ output_attentions=None,
414
+ output_hidden_states=None,
415
+ return_dict=None,
416
+ ):
417
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
418
+
419
+ batch_size = input_ids.shape[0]
420
+ inputs_embeds = self.embeddings(
421
+ input_ids=input_ids,
422
+ position_ids=position_ids,
423
+ token_type_ids=token_type_ids,
424
+ )
425
+ outputs = self.bert(
426
+ # input_ids,
427
+ attention_mask=attention_mask,
428
+ # token_type_ids=token_type_ids,
429
+ # position_ids=position_ids,
430
+ head_mask=head_mask,
431
+ inputs_embeds=inputs_embeds,
432
+ output_attentions=output_attentions,
433
+ output_hidden_states=output_hidden_states,
434
+ return_dict=return_dict,
435
+ # past_key_values=past_key_values,
436
+ )
437
+
438
+ # pooled_output = outputs[1]
439
+ sequence_output = outputs[0]
440
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
441
+ first_token_tensor = sequence_output[:, 0]
442
+ pooled_output = self.bert.pooler.dense(first_token_tensor)
443
+ pooled_output = self.bert.pooler.activation(pooled_output)
444
+
445
+ pooled_output = self.dropout(pooled_output)
446
+ logits = self.classifier(pooled_output)
447
+
448
+ loss = None
449
+ if labels is not None:
450
+ if self.config.problem_type is None:
451
+ if self.num_labels == 1:
452
+ self.config.problem_type = "regression"
453
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
454
+ self.config.problem_type = "single_label_classification"
455
+ else:
456
+ self.config.problem_type = "multi_label_classification"
457
+
458
+ if self.config.problem_type == "regression":
459
+ loss_fct = MSELoss()
460
+ if self.num_labels == 1:
461
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
462
+ else:
463
+ loss = loss_fct(logits, labels)
464
+ elif self.config.problem_type == "single_label_classification":
465
+ loss_fct = CrossEntropyLoss()
466
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
467
+ elif self.config.problem_type == "multi_label_classification":
468
+ loss_fct = BCEWithLogitsLoss()
469
+ loss = loss_fct(logits, labels)
470
+ if not return_dict:
471
+ output = (logits,) + outputs[2:]
472
+ return ((loss,) + output) if loss is not None else output
473
+
474
+ return SequenceClassifierOutput(
475
+ loss=loss,
476
+ logits=logits,
477
+ hidden_states=outputs.hidden_states,
478
+ attentions=outputs.attentions,
479
+ )
480
+
481
+
482
+
483
+ # ========= RoBERTa =========
484
+
485
+ # Vanilla Fine-tuning For RoBERTa
486
+ class RobertaForSequenceClassification(RobertaPreTrainedModel):
487
+ def __init__(self, config):
488
+ super().__init__(config)
489
+ self.num_labels = config.num_labels
490
+ self.config = config
491
+ self.roberta = RobertaModel(config)
492
+ if self.config.use_freezing:
493
+ self.roberta = freezer.freeze_lm(self.roberta)
494
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
495
+ # self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
496
+ self.classifier = RobertaClassificationHead(config)
497
+ self.init_weights()
498
+
499
+ def freeze_backbone(self, use_freezing: bool=True):
500
+ if use_freezing:
501
+ self.roberta = freezer.freeze_lm(self.roberta)
502
+ else:
503
+ self.roberta = freezer.unfreeze_lm(self.roberta)
504
+
505
+ def forward(
506
+ self,
507
+ input_ids=None,
508
+ attention_mask=None,
509
+ token_type_ids=None,
510
+ position_ids=None,
511
+ head_mask=None,
512
+ inputs_embeds=None,
513
+ labels=None,
514
+ output_attentions=None,
515
+ output_hidden_states=None,
516
+ return_dict=None,
517
+ ):
518
+ r"""
519
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
520
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
521
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
522
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
523
+ """
524
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
525
+
526
+ outputs = self.roberta(
527
+ input_ids,
528
+ attention_mask=attention_mask,
529
+ token_type_ids=token_type_ids,
530
+ position_ids=position_ids,
531
+ head_mask=head_mask,
532
+ inputs_embeds=inputs_embeds,
533
+ output_attentions=output_attentions,
534
+ output_hidden_states=output_hidden_states,
535
+ return_dict=return_dict,
536
+ )
537
+
538
+ pooled_output = outputs[1]
539
+
540
+ pooled_output = self.dropout(pooled_output)
541
+ logits = self.classifier(pooled_output)
542
+
543
+ loss = None
544
+ if labels is not None:
545
+ if self.config.problem_type is None:
546
+ if self.num_labels == 1:
547
+ self.config.problem_type = "regression"
548
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
549
+ self.config.problem_type = "single_label_classification"
550
+ else:
551
+ self.config.problem_type = "multi_label_classification"
552
+
553
+ if self.config.problem_type == "regression":
554
+ loss_fct = MSELoss()
555
+ if self.num_labels == 1:
556
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
557
+ else:
558
+ loss = loss_fct(logits, labels)
559
+ elif self.config.problem_type == "single_label_classification":
560
+ loss_fct = CrossEntropyLoss()
561
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
562
+ elif self.config.problem_type == "multi_label_classification":
563
+ loss_fct = BCEWithLogitsLoss()
564
+ loss = loss_fct(logits, labels)
565
+ if not return_dict:
566
+ output = (logits,) + outputs[2:]
567
+ return ((loss,) + output) if loss is not None else output
568
+
569
+ return SequenceClassifierOutput(
570
+ loss=loss,
571
+ logits=logits,
572
+ hidden_states=outputs.hidden_states,
573
+ attentions=outputs.attentions,
574
+ )
575
+
576
+ # Prefix-tuning For RoBERTa
577
+ class RobertaPrefixForSequenceClassification(RobertaPreTrainedModel):
578
+ def __init__(self, config):
579
+ super().__init__(config)
580
+ self.num_labels = config.num_labels
581
+ self.config = config
582
+ self.roberta = RobertaModel(config)
583
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
584
+ # self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
585
+ self.classifier = RobertaClassificationHead(config)
586
+ self.init_weights()
587
+
588
+ for param in self.roberta.parameters():
589
+ param.requires_grad = False
590
+
591
+ self.pre_seq_len = config.pre_seq_len
592
+ self.n_layer = config.num_hidden_layers
593
+ self.n_head = config.num_attention_heads
594
+ self.n_embd = config.hidden_size // config.num_attention_heads
595
+
596
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
597
+ self.prefix_encoder = PrefixEncoder(config)
598
+
599
+ bert_param = 0
600
+ for name, param in self.roberta.named_parameters():
601
+ bert_param += param.numel()
602
+ all_param = 0
603
+ for name, param in self.named_parameters():
604
+ all_param += param.numel()
605
+ total_param = all_param - bert_param
606
+ print("total param is {}".format(total_param)) # 9860105
607
+
608
+ def freeze_backbone(self, use_freezing: bool=True):
609
+ if use_freezing:
610
+ self.roberta = freezer.freeze_lm(self.roberta)
611
+ else:
612
+ self.roberta = freezer.unfreeze_lm(self.roberta)
613
+
614
+
615
+ def get_prompt(self, batch_size):
616
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
617
+ # print("prefix_tokens.shape=", prefix_tokens.shape)
618
+ past_key_values = self.prefix_encoder(prefix_tokens)
619
+ # print("past_key_values[0].shape=", past_key_values[0].shape)
620
+ past_key_values = past_key_values.view(
621
+ batch_size,
622
+ self.pre_seq_len,
623
+ self.n_layer * 2,
624
+ self.n_head,
625
+ self.n_embd
626
+ )
627
+ # print("past_key_values[0].shape=", past_key_values[0].shape)
628
+ past_key_values = self.dropout(past_key_values)
629
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
630
+ # print("past_key_values[0].shape=", past_key_values[0].shape)
631
+ return past_key_values
632
+
633
+ def forward(
634
+ self,
635
+ input_ids=None,
636
+ attention_mask=None,
637
+ token_type_ids=None,
638
+ position_ids=None,
639
+ head_mask=None,
640
+ inputs_embeds=None,
641
+ labels=None,
642
+ output_attentions=None,
643
+ output_hidden_states=None,
644
+ return_dict=None,
645
+ ):
646
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
647
+
648
+ batch_size = input_ids.shape[0]
649
+ past_key_values = self.get_prompt(batch_size=batch_size)
650
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
651
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
652
+
653
+ if position_ids is None:
654
+ position_ids = torch.tensor([i for i in range(input_ids.shape[-1])]).expand(batch_size, -1).to(self.roberta.device)
655
+
656
+ outputs = self.roberta(
657
+ input_ids,
658
+ attention_mask=attention_mask,
659
+ token_type_ids=token_type_ids,
660
+ position_ids=position_ids,
661
+ head_mask=head_mask,
662
+ inputs_embeds=inputs_embeds,
663
+ output_attentions=output_attentions,
664
+ output_hidden_states=output_hidden_states,
665
+ return_dict=return_dict,
666
+ past_key_values=past_key_values,
667
+ )
668
+
669
+ pooled_output = outputs[1]
670
+
671
+ pooled_output = self.dropout(pooled_output)
672
+ logits = self.classifier(pooled_output)
673
+
674
+ loss = None
675
+ if labels is not None:
676
+ labels = (labels < 0).long().to(labels.device) + labels
677
+
678
+ if self.config.problem_type is None:
679
+ if self.num_labels == 1:
680
+ self.config.problem_type = "regression"
681
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
682
+ self.config.problem_type = "single_label_classification"
683
+ else:
684
+ self.config.problem_type = "multi_label_classification"
685
+
686
+ if self.config.problem_type == "regression":
687
+ loss_fct = MSELoss()
688
+ if self.num_labels == 1:
689
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
690
+ else:
691
+ loss = loss_fct(logits, labels)
692
+ elif self.config.problem_type == "single_label_classification":
693
+ loss_fct = CrossEntropyLoss()
694
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
695
+ elif self.config.problem_type == "multi_label_classification":
696
+ loss_fct = BCEWithLogitsLoss()
697
+ loss = loss_fct(logits, labels)
698
+ if not return_dict:
699
+ output = (logits,) + outputs[2:]
700
+ return ((loss,) + output) if loss is not None else output
701
+
702
+ return SequenceClassifierOutput(
703
+ loss=loss,
704
+ logits=logits,
705
+ hidden_states=outputs.hidden_states,
706
+ attentions=outputs.attentions,
707
+ )
708
+
709
+ # Prompt-tuning For RoBERTa
710
+ class RobertaPtuningForSequenceClassification(RobertaPreTrainedModel):
711
+ def __init__(self, config):
712
+ super().__init__(config)
713
+ self.num_labels = config.num_labels
714
+ self.roberta = RobertaModel(config)
715
+ self.embeddings = self.roberta.embeddings
716
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
717
+ # self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
718
+ self.classifier = RobertaClassificationHead(config)
719
+
720
+ # for param in self.roberta.parameters():
721
+ # param.requires_grad = False
722
+
723
+ if self.config.use_freezing:
724
+ self.roberta = freezer.freeze_lm(self.roberta)
725
+
726
+ self.pre_seq_len = config.pre_seq_len
727
+ self.n_layer = config.num_hidden_layers
728
+ self.n_head = config.num_attention_heads
729
+ self.n_embd = config.hidden_size // config.num_attention_heads
730
+
731
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
732
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
733
+
734
+ def freeze_backbone(self, use_freezing: bool=True):
735
+ if use_freezing:
736
+ self.roberta = freezer.freeze_lm(self.roberta)
737
+ else:
738
+ self.roberta = freezer.unfreeze_lm(self.roberta)
739
+
740
+ def get_prompt(self, batch_size):
741
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
742
+ prompts = self.prefix_encoder(prefix_tokens)
743
+ return prompts
744
+
745
+ def forward(
746
+ self,
747
+ input_ids=None,
748
+ attention_mask=None,
749
+ token_type_ids=None,
750
+ position_ids=None,
751
+ head_mask=None,
752
+ inputs_embeds=None,
753
+ labels=None,
754
+ output_attentions=None,
755
+ output_hidden_states=None,
756
+ return_dict=None,
757
+ ):
758
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
759
+
760
+ batch_size = input_ids.shape[0]
761
+ raw_embedding = self.embeddings(
762
+ input_ids=input_ids,
763
+ position_ids=position_ids,
764
+ token_type_ids=token_type_ids,
765
+ )
766
+ prompts = self.get_prompt(batch_size=batch_size)
767
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
768
+ # print(input_embeddings.shape)
769
+ # exit()
770
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
771
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
772
+
773
+ outputs = self.roberta(
774
+ # input_ids,
775
+ attention_mask=attention_mask,
776
+ # token_type_ids=token_type_ids,
777
+ # position_ids=position_ids,
778
+ head_mask=head_mask,
779
+ inputs_embeds=inputs_embeds,
780
+ output_attentions=output_attentions,
781
+ output_hidden_states=output_hidden_states,
782
+ return_dict=return_dict,
783
+ # past_key_values=past_key_values,
784
+ )
785
+
786
+ # pooled_output = outputs[1]
787
+ sequence_output = outputs[0]
788
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
789
+ first_token_tensor = sequence_output[:, 0]
790
+ pooled_output = self.roberta.pooler.dense(first_token_tensor)
791
+ pooled_output = self.roberta.pooler.activation(pooled_output)
792
+
793
+ pooled_output = self.dropout(pooled_output)
794
+ logits = self.classifier(pooled_output)
795
+
796
+ loss = None
797
+ if labels is not None:
798
+ if self.config.problem_type is None:
799
+ if self.num_labels == 1:
800
+ self.config.problem_type = "regression"
801
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
802
+ self.config.problem_type = "single_label_classification"
803
+ else:
804
+ self.config.problem_type = "multi_label_classification"
805
+
806
+ if self.config.problem_type == "regression":
807
+ loss_fct = MSELoss()
808
+ if self.num_labels == 1:
809
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
810
+ else:
811
+ loss = loss_fct(logits, labels)
812
+ elif self.config.problem_type == "single_label_classification":
813
+ loss_fct = CrossEntropyLoss()
814
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
815
+ elif self.config.problem_type == "multi_label_classification":
816
+ loss_fct = BCEWithLogitsLoss()
817
+ loss = loss_fct(logits, labels)
818
+ if not return_dict:
819
+ output = (logits,) + outputs[2:]
820
+ return ((loss,) + output) if loss is not None else output
821
+
822
+ return SequenceClassifierOutput(
823
+ loss=loss,
824
+ logits=logits,
825
+ hidden_states=outputs.hidden_states,
826
+ attentions=outputs.attentions,
827
+ )
828
+
829
+ # Adapter-tuning For RoBERTa
830
+ class RobertaAdapterForSequenceClassification(RobertaPreTrainedModel):
831
+ def __init__(self, config):
832
+ super().__init__(config)
833
+ self.num_labels = config.num_labels
834
+ self.roberta = RobertaAdaModel(config)
835
+ self.embeddings = self.roberta.embeddings
836
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
837
+ # self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
838
+ self.classifier = RobertaClassificationHead(config)
839
+
840
+ self.init_weights()
841
+ # for param in self.roberta.parameters():
842
+ # param.requires_grad = False
843
+ self.roberta = init_adapter(self.roberta)
844
+ if self.config.use_freezing:
845
+ self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")
846
+
847
+ def freeze_backbone(self, use_freezing: bool=True):
848
+ if use_freezing:
849
+ self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")
850
+ else:
851
+ self.roberta = freezer.unfreeze_lm(self.roberta)
852
+
853
+ def forward(
854
+ self,
855
+ input_ids=None,
856
+ attention_mask=None,
857
+ token_type_ids=None,
858
+ position_ids=None,
859
+ head_mask=None,
860
+ inputs_embeds=None,
861
+ labels=None,
862
+ output_attentions=None,
863
+ output_hidden_states=None,
864
+ return_dict=None,
865
+ ):
866
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
867
+
868
+ batch_size = input_ids.shape[0]
869
+ inputs_embeds = self.embeddings(
870
+ input_ids=input_ids,
871
+ position_ids=position_ids,
872
+ token_type_ids=token_type_ids,
873
+ )
874
+
875
+ outputs = self.roberta(
876
+ # input_ids,
877
+ attention_mask=attention_mask,
878
+ # token_type_ids=token_type_ids,
879
+ # position_ids=position_ids,
880
+ head_mask=head_mask,
881
+ inputs_embeds=inputs_embeds,
882
+ output_attentions=output_attentions,
883
+ output_hidden_states=output_hidden_states,
884
+ return_dict=return_dict,
885
+ # past_key_values=past_key_values,
886
+ )
887
+
888
+ # pooled_output = outputs[1]
889
+ sequence_output = outputs[0]
890
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
891
+ first_token_tensor = sequence_output[:, 0]
892
+ pooled_output = self.roberta.pooler.dense(first_token_tensor)
893
+ pooled_output = self.roberta.pooler.activation(pooled_output)
894
+
895
+ pooled_output = self.dropout(pooled_output)
896
+ logits = self.classifier(pooled_output)
897
+
898
+ loss = None
899
+ if labels is not None:
900
+ if self.config.problem_type is None:
901
+ if self.num_labels == 1:
902
+ self.config.problem_type = "regression"
903
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
904
+ self.config.problem_type = "single_label_classification"
905
+ else:
906
+ self.config.problem_type = "multi_label_classification"
907
+
908
+ if self.config.problem_type == "regression":
909
+ loss_fct = MSELoss()
910
+ if self.num_labels == 1:
911
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
912
+ else:
913
+ loss = loss_fct(logits, labels)
914
+ elif self.config.problem_type == "single_label_classification":
915
+ loss_fct = CrossEntropyLoss()
916
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
917
+ elif self.config.problem_type == "multi_label_classification":
918
+ loss_fct = BCEWithLogitsLoss()
919
+ loss = loss_fct(logits, labels)
920
+ if not return_dict:
921
+ output = (logits,) + outputs[2:]
922
+ return ((loss,) + output) if loss is not None else output
923
+
924
+ return SequenceClassifierOutput(
925
+ loss=loss,
926
+ logits=logits,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ )
930
+
931
+
932
+ # ========= DeBERTa =========
933
+
934
+ # Prefix-tuning For DeBERTa
935
+ class DebertaPrefixForSequenceClassification(DebertaPreTrainedModel):
936
+ def __init__(self, config):
937
+ super().__init__(config)
938
+ self.num_labels = config.num_labels
939
+ self.config = config
940
+ self.deberta = DebertaModel(config)
941
+ self.pooler = ContextPooler(config)
942
+ output_dim = self.pooler.output_dim
943
+ self.classifier = torch.nn.Linear(output_dim, self.num_labels)
944
+ self.dropout = StableDropout(config.hidden_dropout_prob)
945
+ self.init_weights()
946
+
947
+ # for param in self.deberta.parameters():
948
+ # param.requires_grad = False
949
+
950
+ if self.config.use_freezing:
951
+ self.deberta = freezer.freeze_lm(self.deberta)
952
+
953
+ self.pre_seq_len = config.pre_seq_len
954
+ self.n_layer = config.num_hidden_layers
955
+ self.n_head = config.num_attention_heads
956
+ self.n_embd = config.hidden_size // config.num_attention_heads
957
+
958
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
959
+ self.prefix_encoder = PrefixEncoder(config)
960
+
961
+ deberta_param = 0
962
+ for name, param in self.deberta.named_parameters():
963
+ deberta_param += param.numel()
964
+ all_param = 0
965
+ for name, param in self.named_parameters():
966
+ all_param += param.numel()
967
+ total_param = all_param - deberta_param
968
+ print("total param is {}".format(total_param)) # 9860105
969
+
970
+ def freeze_backbone(self, use_freezing: bool=True):
971
+ if use_freezing:
972
+ self.deberta = freezer.freeze_lm(self.deberta)
973
+ else:
974
+ self.deberta = freezer.unfreeze_lm(self.deberta)
975
+
976
+ def get_prompt(self, batch_size):
977
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
978
+ past_key_values = self.prefix_encoder(prefix_tokens)
979
+ # bsz, seqlen, _ = past_key_values.shape
980
+ past_key_values = past_key_values.view(
981
+ batch_size,
982
+ self.pre_seq_len,
983
+ self.n_layer * 2,
984
+ self.n_head,
985
+ self.n_embd
986
+ )
987
+ past_key_values = self.dropout(past_key_values)
988
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
989
+ return past_key_values
990
+
991
+ def forward(
992
+ self,
993
+ input_ids=None,
994
+ attention_mask=None,
995
+ token_type_ids=None,
996
+ position_ids=None,
997
+ head_mask=None,
998
+ inputs_embeds=None,
999
+ labels=None,
1000
+ output_attentions=None,
1001
+ output_hidden_states=None,
1002
+ return_dict=None,
1003
+ ):
1004
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1005
+
1006
+ batch_size = input_ids.shape[0]
1007
+ past_key_values = self.get_prompt(batch_size=batch_size)
1008
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
1009
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
1010
+
1011
+ outputs = self.deberta(
1012
+ input_ids,
1013
+ attention_mask=attention_mask,
1014
+ token_type_ids=token_type_ids,
1015
+ position_ids=position_ids,
1016
+ inputs_embeds=inputs_embeds,
1017
+ output_attentions=output_attentions,
1018
+ output_hidden_states=output_hidden_states,
1019
+ return_dict=return_dict,
1020
+ past_key_values=past_key_values,
1021
+ )
1022
+
1023
+ encoder_layer = outputs[0]
1024
+ pooled_output = self.pooler(encoder_layer)
1025
+ pooled_output = self.dropout(pooled_output)
1026
+ logits = self.classifier(pooled_output)
1027
+
1028
+ loss = None
1029
+ if labels is not None:
1030
+ if self.num_labels == 1:
1031
+ # regression task
1032
+ loss_fn = torch.nn.MSELoss()
1033
+ logits = logits.view(-1).to(labels.dtype)
1034
+ loss = loss_fn(logits, labels.view(-1))
1035
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1036
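+ # Only examples with a non-negative label contribute to the loss below; rows labelled
+ # with a negative value (e.g. -1 for "unlabelled") are filtered out via `label_index`
+ # before the cross-entropy is computed.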
+ label_index = (labels >= 0).nonzero()
1037
+ labels = labels.long()
1038
+ if label_index.size(0) > 0:
1039
+ labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
1040
+ labels = torch.gather(labels, 0, label_index.view(-1))
1041
+ loss_fct = CrossEntropyLoss()
1042
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
1043
+ else:
1044
+ loss = torch.tensor(0).to(logits)
1045
+ else:
1046
+ log_softmax = torch.nn.LogSoftmax(-1)
1047
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1048
+ if not return_dict:
1049
+ output = (logits,) + outputs[1:]
1050
+ return ((loss,) + output) if loss is not None else output
1051
+ else:
1052
+ return SequenceClassifierOutput(
1053
+ loss=loss,
1054
+ logits=logits,
1055
+ hidden_states=outputs.hidden_states,
1056
+ attentions=outputs.attentions,
1057
+ )
1058
+
1059
+
1060
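+ # Minimal usage sketch for the prefix-tuned DeBERTa classifier above (hypothetical values;
+ # `pre_seq_len`, the PrefixEncoder-related config fields, and `use_freezing` are assumed to
+ # be attached to the config by the surrounding runner scripts, not by this file, and `batch`
+ # stands for one collated batch from the task DataLoader):
+ #
+ # config = AutoConfig.from_pretrained("microsoft/deberta-base", num_labels=2)
+ # config.pre_seq_len = 16
+ # config.use_freezing = True
+ # model = DebertaPrefixForSequenceClassification.from_pretrained("microsoft/deberta-base", config=config)
+ # outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])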
+ # GPT2 for classification
1061
+ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1062
+
1063
+ def __init__(self, config):
1064
+ super().__init__(config)
1065
+ self.num_labels = config.num_labels
1066
+ self.transformer = GPT2Model(config)
1067
+ self.score = torch.nn.Linear(config.n_embd, self.num_labels, bias=False)
1068
+
1069
+ # Model parallel
1070
+ self.model_parallel = False
1071
+ self.device_map = None
1072
+
1073
+ # Initialize weights and apply final processing
1074
+ self.post_init()
1075
+
1076
+ def forward(
1077
+ self,
1078
+ input_ids: Optional[torch.LongTensor] = None,
1079
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1080
+ attention_mask: Optional[torch.FloatTensor] = None,
1081
+ token_type_ids: Optional[torch.LongTensor] = None,
1082
+ position_ids: Optional[torch.LongTensor] = None,
1083
+ head_mask: Optional[torch.FloatTensor] = None,
1084
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1085
+ labels: Optional[torch.LongTensor] = None,
1086
+ use_cache: Optional[bool] = None,
1087
+ output_attentions: Optional[bool] = None,
1088
+ output_hidden_states: Optional[bool] = None,
1089
+ return_dict: Optional[bool] = None,
1090
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1091
+ r"""
1092
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1093
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1094
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1095
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1096
+ """
1097
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1098
+
1099
+ transformer_outputs = self.transformer(
1100
+ input_ids,
1101
+ past_key_values=past_key_values,
1102
+ attention_mask=attention_mask,
1103
+ token_type_ids=token_type_ids,
1104
+ position_ids=position_ids,
1105
+ head_mask=head_mask,
1106
+ inputs_embeds=inputs_embeds,
1107
+ use_cache=use_cache,
1108
+ output_attentions=output_attentions,
1109
+ output_hidden_states=output_hidden_states,
1110
+ return_dict=return_dict,
1111
+ )
1112
+ hidden_states = transformer_outputs[0]
1113
+ logits = self.score(hidden_states)
1114
+
1115
+ if input_ids is not None:
1116
+ batch_size, sequence_length = input_ids.shape[:2]
1117
+ else:
1118
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1119
+
1120
+ assert (
1121
+ self.config.pad_token_id is not None or batch_size == 1
1122
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1123
+ if self.config.pad_token_id is None:
1124
+ sequence_lengths = -1
1125
+ else:
1126
+ if input_ids is not None:
1127
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1128
+ else:
1129
+ sequence_lengths = -1
1130
+ logger.warning(
1131
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1132
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1133
+ )
1134
+
1135
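+ # Pool by taking, for each sequence, the logits at its last non-padding position
+ # (`sequence_lengths`); when no pad token is defined this falls back to the final position.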
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1136
+
1137
+ loss = None
1138
+ if labels is not None:
1139
+ if self.config.problem_type is None:
1140
+ if self.num_labels == 1:
1141
+ self.config.problem_type = "regression"
1142
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1143
+ self.config.problem_type = "single_label_classification"
1144
+ else:
1145
+ self.config.problem_type = "multi_label_classification"
1146
+
1147
+ if self.config.problem_type == "regression":
1148
+ loss_fct = MSELoss()
1149
+ if self.num_labels == 1:
1150
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1151
+ else:
1152
+ loss = loss_fct(pooled_logits, labels)
1153
+ elif self.config.problem_type == "single_label_classification":
1154
+ loss_fct = CrossEntropyLoss()
1155
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1156
+ elif self.config.problem_type == "multi_label_classification":
1157
+ loss_fct = BCEWithLogitsLoss()
1158
+ loss = loss_fct(pooled_logits, labels)
1159
+ if not return_dict:
1160
+ output = (pooled_logits,) + transformer_outputs[1:]
1161
+ return ((loss,) + output) if loss is not None else output
1162
+
1163
+ return SequenceClassifierOutputWithPast(
1164
+ loss=loss,
1165
+ logits=pooled_logits,
1166
+ past_key_values=transformer_outputs.past_key_values,
1167
+ hidden_states=transformer_outputs.hidden_states,
1168
+ attentions=transformer_outputs.attentions,
1169
+ )
1170
+
1171
+
1172
+
1173
+
1174
+ # Bart for classification
1175
+ class BartForSequenceClassification(BartPretrainedModel):
1176
+ def __init__(self, config: BartConfig, **kwargs):
1177
+ super().__init__(config, **kwargs)
1178
+ self.model = BartModel(config)
1179
+ self.classification_head = BartClassificationHead(
1180
+ config.d_model,
1181
+ config.d_model,
1182
+ config.num_labels,
1183
+ config.classifier_dropout,
1184
+ )
1185
+ self.model._init_weights(self.classification_head.dense)
1186
+ self.model._init_weights(self.classification_head.out_proj)
1187
+
1188
+ def forward(
1189
+ self,
1190
+ input_ids: torch.LongTensor = None,
1191
+ attention_mask: Optional[torch.Tensor] = None,
1192
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1193
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1194
+ head_mask: Optional[torch.Tensor] = None,
1195
+ decoder_head_mask: Optional[torch.Tensor] = None,
1196
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1197
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1198
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1199
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1200
+ labels: Optional[torch.LongTensor] = None,
1201
+ use_cache: Optional[bool] = None,
1202
+ output_attentions: Optional[bool] = None,
1203
+ output_hidden_states: Optional[bool] = None,
1204
+ return_dict: Optional[bool] = None,
1205
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
1206
+ r"""
1207
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1208
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1209
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1210
+ """
1211
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1212
+ if labels is not None:
1213
+ use_cache = False
1214
+
1215
+ if input_ids is None and inputs_embeds is not None:
1216
+ raise NotImplementedError(
1217
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1218
+ )
1219
+
1220
+ outputs = self.model(
1221
+ input_ids,
1222
+ attention_mask=attention_mask,
1223
+ decoder_input_ids=decoder_input_ids,
1224
+ decoder_attention_mask=decoder_attention_mask,
1225
+ head_mask=head_mask,
1226
+ decoder_head_mask=decoder_head_mask,
1227
+ cross_attn_head_mask=cross_attn_head_mask,
1228
+ encoder_outputs=encoder_outputs,
1229
+ inputs_embeds=inputs_embeds,
1230
+ decoder_inputs_embeds=decoder_inputs_embeds,
1231
+ use_cache=use_cache,
1232
+ output_attentions=output_attentions,
1233
+ output_hidden_states=output_hidden_states,
1234
+ return_dict=return_dict,
1235
+ )
1236
+ hidden_states = outputs[0] # last hidden state
1237
+ # print("hidden_states.shape=", hidden_states.shape) # [bz, seq_len, dim]
1238
+
1239
+ eos_mask = input_ids.eq(self.config.eos_token_id)
1240
+
1241
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1242
+ raise ValueError("All examples must have the same number of <eos> tokens.")
1243
+ sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1244
+ :, -1, :
1245
+ ]
1246
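+ # The hidden state of the last <eos> token in each sequence is used as the sentence
+ # representation for the classification head, hence the same-number-of-<eos> check above.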
+ logits = self.classification_head(sentence_representation)
1247
+
1248
+ loss = None
1249
+ if labels is not None:
1250
+ if self.config.problem_type is None:
1251
+ if self.config.num_labels == 1:
1252
+ self.config.problem_type = "regression"
1253
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1254
+ self.config.problem_type = "single_label_classification"
1255
+ else:
1256
+ self.config.problem_type = "multi_label_classification"
1257
+
1258
+ if self.config.problem_type == "regression":
1259
+ loss_fct = MSELoss()
1260
+ if self.config.num_labels == 1:
1261
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1262
+ else:
1263
+ loss = loss_fct(logits, labels)
1264
+ elif self.config.problem_type == "single_label_classification":
1265
+ loss_fct = CrossEntropyLoss()
1266
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1267
+ elif self.config.problem_type == "multi_label_classification":
1268
+ loss_fct = BCEWithLogitsLoss()
1269
+ loss = loss_fct(logits, labels)
1270
+ if not return_dict:
1271
+ output = (logits,) + outputs[1:]
1272
+ return ((loss,) + output) if loss is not None else output
1273
+
1274
+ return Seq2SeqSequenceClassifierOutput(
1275
+ loss=loss,
1276
+ logits=logits,
1277
+ past_key_values=outputs.past_key_values,
1278
+ decoder_hidden_states=outputs.decoder_hidden_states,
1279
+ decoder_attentions=outputs.decoder_attentions,
1280
+ cross_attentions=outputs.cross_attentions,
1281
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1282
+ encoder_hidden_states=outputs.encoder_hidden_states,
1283
+ encoder_attentions=outputs.encoder_attentions,
1284
+ )
models/sequence_classification/masked_prompt_cls.py ADDED
@@ -0,0 +1,2016 @@
1
+ """Custom models for few-shot learning specific operations."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import transformers
6
+ import torch.nn.functional as F
7
+ from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
8
+ from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
9
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead
10
+ from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead, RobertaPreTrainedModel
11
+ from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2PreTrainedModel, DebertaV2Model, StableDropout, ContextPooler, DebertaV2OnlyMLMHead
12
+ from transformers.models.deberta.modeling_deberta import DebertaPreTrainedModel, DebertaModel, StableDropout, ContextPooler, DebertaOnlyMLMHead
13
+ from transformers.modeling_outputs import SequenceClassifierOutput
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.models.bert.configuration_bert import BertConfig
16
+ import logging
17
+ from models.basic_modules.adapter import RobertaAdaModel, BertAdaModel
18
+ import os
19
+ from models.basic_modules.prefix_encoder import PrefixEncoder
20
+ from tools.model_utils.parameter_freeze import ParameterFreeze
21
+
22
+ freezer = ParameterFreeze()
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Note: if mask_pos is None, check that the input template contains a <mask> token and that the data_collator file has been adapted accordingly
27
+
28
+ """
29
+ Vanilla Prompt-tuning BERT
30
+ """
31
+ class PromptBertForSequenceClassification(BertPreTrainedModel):
32
+
33
+ def __init__(self, config):
34
+ super().__init__(config)
35
+ self.num_labels = config.num_labels
36
+ self.pre_seq_len = self.config.pre_seq_len
37
+ self.hidden_size = self.config.hidden_size
38
+ # backbone
39
+ self.bert = BertModel(config)
40
+ if self.config.use_freezing:
41
+ self.bert = freezer.freeze_lm(self.bert)
42
+ # mlm head
43
+ self.cls = BertOnlyMLMHead(config)
44
+
45
+ self.init_weights()
46
+
47
+ # These attributes should be assigned once the model is initialized
48
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
49
+
50
+ # For regression
51
+ self.lb = None
52
+ self.ub = None
53
+
54
+ # For label search.
55
+ self.return_full_softmax = None
56
+
57
+ def freeze_backbone(self, use_freezing: bool=True):
58
+ if use_freezing:
59
+ self.bert = freezer.freeze_lm(self.bert)
60
+ else:
61
+ self.bert = freezer.unfreeze_lm(self.bert)
62
+
63
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
64
+ """
65
+ Encode the inputs and obtain the logits at the masked position
66
+ """
67
+ if mask_pos is not None:
68
+ mask_pos = mask_pos.squeeze()
69
+ # Encode everything
70
+ if inputs_embeds is None:
71
+ outputs = self.bert(
72
+ input_ids,
73
+ attention_mask=attention_mask,
74
+ token_type_ids=token_type_ids
75
+ )
76
+ else:
77
+ outputs = self.bert(
78
+ None,
79
+ attention_mask=attention_mask,
80
+ token_type_ids=token_type_ids,
81
+ inputs_embeds=inputs_embeds
82
+ )
83
+ # Get <mask> token representation
84
+ sequence_output, pooled_output = outputs[:2]
85
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
86
+ # Logits over vocabulary tokens
87
+ prediction_mask_scores = self.cls(sequence_mask_output)
88
+
89
+ # Exit early and only return mask logits.
90
+ if return_full_softmax:
91
+ return prediction_mask_scores
92
+
93
+ # Return logits for each label
94
+ logits = []
95
+ for label_id in range(len(self.label_word_list)):
96
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
97
+ logits = torch.cat(logits, -1)
98
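+ # `logits` now has shape [batch_size, len(label_word_list)]: the MLM score of each label
+ # word (verbalizer token) at the <mask> position is used directly as the class logit.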
+
99
+ # Regression task
100
+ if self.config.num_labels == 1:
101
+ logsoftmax = nn.LogSoftmax(-1)
102
+ logits = logsoftmax(logits) # Log prob of right polarity
103
+
104
+ return logits, sequence_mask_output
105
+
106
+ def forward(
107
+ self,
108
+ input_ids=None,
109
+ attention_mask=None,
110
+ token_type_ids=None,
111
+ mask_pos=None,
112
+ labels=None,
113
+ inputs_embeds=None,
114
+ block_flag=None,
115
+ return_dict=None,
116
+ ):
117
+
118
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
119
+ loss = None
120
+ if labels is not None:
121
+ if self.num_labels == 1:
122
+ # Regression task
123
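+ # The scalar regression target is rescaled into [0, 1] with the task bounds (self.lb, self.ub)
+ # and stacked into a two-way (low, high) target, which is compared against the 2-class
+ # scores at the <mask> position with a KL-divergence loss.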
+ loss_fct = nn.KLDivLoss(log_target=True)
124
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
125
+ loss = loss_fct(logits.view(-1, 2), labels)
126
+ else:
127
+
128
+ if labels.shape == logits.shape:
129
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
130
+ labels, reduction="batchmean")
131
+ else:
132
+ loss_fct = nn.CrossEntropyLoss()
133
+
134
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
135
+
136
+ output = (logits,)
137
+ if self.num_labels == 1:
138
+ # Regression output
139
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
140
+
141
+ if not return_dict:
142
+ return ((loss,) + output) if loss is not None else output
143
+
144
+ return SequenceClassifierOutput(
145
+ loss=loss,
146
+ logits=logits,
147
+ )
148
+
149
+
150
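+ # Usage sketch for the prompt-based classifiers in this file (hypothetical values; the runner
+ # is assumed to attach `label_word_list`, `pre_seq_len` and `use_freezing` to the config, the
+ # data collator is assumed to supply `mask_pos`, the index of <mask> in each input, and
+ # `tokenizer` / `batch` are placeholders for the task tokenizer and one collated batch):
+ #
+ # config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=2)
+ # config.label_word_list = [tokenizer.convert_tokens_to_ids(w) for w in ["terrible", "great"]]
+ # config.pre_seq_len = 16
+ # config.use_freezing = False
+ # model = PromptBertForSequenceClassification.from_pretrained("bert-base-uncased", config=config)
+ # loss, logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"],
+ #                      token_type_ids=batch["token_type_ids"], mask_pos=batch["mask_pos"],
+ #                      labels=batch["labels"])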
+
151
+ """
152
+ P-tuning BERT
153
+ """
154
+ class PromptBertPtuningForSequenceClassification(BertPreTrainedModel):
155
+
156
+ def __init__(self, config):
157
+ super().__init__(config)
158
+ self.num_labels = config.num_labels
159
+ self.pre_seq_len = self.config.pre_seq_len
160
+ self.hidden_size = self.config.hidden_size
161
+ # backbone
162
+ self.bert = BertModel(config)
163
+ if self.config.use_freezing:
164
+ self.bert = freezer.freeze_lm(self.bert)
165
+ # mlm head
166
+ self.cls = BertOnlyMLMHead(config)
167
+ # prompt encoder
168
+ self.prompt_encoder = None
169
+ # plm embedding layer
170
+ self.backbone_embeddings = self.bert.embeddings.word_embeddings
171
+ # prompt embedding layer
172
+ self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
173
+
174
+ self.init_weights()
175
+
176
+ # These attributes should be assigned once the model is initialized
177
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
178
+
179
+ # For regression
180
+ self.lb = None
181
+ self.ub = None
182
+
183
+ # For label search.
184
+ self.return_full_softmax = None
185
+
186
+ def freeze_backbone(self, use_freezing: bool=True):
187
+ if use_freezing:
188
+ self.bert = freezer.freeze_lm(self.bert)
189
+ else:
190
+ self.bert = freezer.unfreeze_lm(self.bert)
191
+
192
+
193
+ def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False):
194
+ """
195
+ Generate continuous prompt embedding
196
+ """
197
+ inputs_embeds = self.backbone_embeddings(input_ids)
198
+
199
+ batch_size = inputs_embeds.shape[0]
200
+ if block_flag is None:
201
+ # the first token is set to 1, the others to 0
202
+ block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device)
203
+ block_flag[:, 0] = 1
204
+ try:
205
+ replace_embeds = self.prompt_embeddings(
206
+ torch.LongTensor(list(range(self.pre_seq_len))).to(inputs_embeds.device))
207
+ except RuntimeError:
+ # Fallback: build the prompt indices on CPU if the device transfer fails.
+ replace_embeds = self.prompt_embeddings(
+ torch.LongTensor(list(range(self.pre_seq_len))))
212
+ replace_embeds = replace_embeds.unsqueeze(0) # [batch_size, prompt_length, embed_size]
213
+
214
+ if self.prompt_encoder is not None:
215
+ replace_embeds = self.prompt_encoder(replace_embeds)
216
+
217
+ # edit by wjn
218
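+ # Two injection modes: with reparameterization the learned prompt embeddings overwrite the
+ # positions marked by `block_flag` inside the sequence; otherwise they are simply prepended,
+ # and the caller has to widen the attention mask accordingly (done in `encode`).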
+ if reparameterization:
219
+ # blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape((batch_size, self.pre_seq_len, 2))[:, :, 1]
220
+ blocked_indices = (block_flag == 1).nonzero()
221
+ # reparameterization
222
+ for bidx in range(batch_size):
223
+ for i in range(blocked_indices.shape[1]):
224
+ inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze()
225
+ else:
226
+ replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device)
227
+ inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1)
228
+ return inputs_embeds
229
+
230
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
231
+ """
232
+ Encode the inputs and obtain the logits at the masked position
233
+ """
234
+ batch_size = inputs_embeds.shape[0]
235
+ if mask_pos is not None:
236
+ mask_pos = mask_pos.squeeze()
237
+ # Encode everything
238
+ if inputs_embeds is None:
239
+ outputs = self.bert(
240
+ input_ids,
241
+ attention_mask=attention_mask,
242
+ token_type_ids=token_type_ids
243
+ )
244
+ else:
245
+
246
+ if inputs_embeds.shape[1] == attention_mask.shape[1]:
247
+ outputs = self.bert(
248
+ None,
249
+ attention_mask=attention_mask,
250
+ token_type_ids=token_type_ids,
251
+ inputs_embeds=inputs_embeds
252
+ )
253
+ # Get <mask> token representation
254
+ sequence_output, pooled_output = outputs[:2]
255
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
256
+ else:
257
+ if attention_mask is not None:
258
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.bert.device)
259
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
260
+ if token_type_ids is not None:
261
+ prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.bert.device)
262
+ token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1)
263
+ outputs = self.bert(
264
+ None,
265
+ attention_mask=attention_mask,
266
+ token_type_ids=token_type_ids,
267
+ inputs_embeds=inputs_embeds
268
+ )
269
+ # Get <mask> token representation
270
+ sequence_output, pooled_output = outputs[:2]
271
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
272
+
273
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
274
+ # Logits over vocabulary tokens
275
+ prediction_mask_scores = self.cls(sequence_mask_output)
276
+
277
+ # Exit early and only return mask logits.
278
+ if return_full_softmax:
279
+ return prediction_mask_scores
280
+
281
+ # Return logits for each label
282
+ logits = []
283
+ for label_id in range(len(self.label_word_list)):
284
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
285
+ logits = torch.cat(logits, -1)
286
+
287
+ # Regression task
288
+ if self.config.num_labels == 1:
289
+ logsoftmax = nn.LogSoftmax(-1)
290
+ logits = logsoftmax(logits) # Log prob of right polarity
291
+
292
+ return logits, sequence_mask_output
293
+
294
+ def forward(
295
+ self,
296
+ input_ids=None,
297
+ attention_mask=None,
298
+ token_type_ids=None,
299
+ mask_pos=None,
300
+ labels=None,
301
+ inputs_embeds=None,
302
+ block_flag=None,
303
+ return_dict=None,
304
+ ):
305
+
306
+ inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
307
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
308
+ loss = None
309
+ if labels is not None:
310
+ if self.num_labels == 1:
311
+ # Regression task
312
+ loss_fct = nn.KLDivLoss(log_target=True)
313
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
314
+ loss = loss_fct(logits.view(-1, 2), labels)
315
+ else:
316
+
317
+ if labels.shape == logits.shape:
318
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
319
+ labels, reduction="batchmean")
320
+ else:
321
+ loss_fct = nn.CrossEntropyLoss()
322
+
323
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
324
+
325
+ output = (logits,)
326
+ if self.num_labels == 1:
327
+ # Regression output
328
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
329
+
330
+ if not return_dict:
331
+ return ((loss,) + output) if loss is not None else output
332
+
333
+ return SequenceClassifierOutput(
334
+ loss=loss,
335
+ logits=logits,
336
+ )
337
+
338
+
339
+
340
+ """
341
+ Prefix-tuning BERT
342
+ """
343
+ class PromptBertPrefixForSequenceClassification(BertPreTrainedModel):
344
+
345
+ def __init__(self, config):
346
+ super().__init__(config)
347
+
348
+
349
+ self.num_labels = config.num_labels
350
+ self.pre_seq_len = self.config.pre_seq_len
351
+ self.hidden_size = self.config.hidden_size
352
+
353
+ self.n_layer = config.num_hidden_layers
354
+ self.n_head = config.num_attention_heads
355
+ self.n_embd = config.hidden_size // config.num_attention_heads
356
+
357
+ # backbone
358
+ self.bert = BertModel(config)
359
+ if self.config.use_freezing:
360
+ self.bert = freezer.freeze_lm(self.bert)
361
+ # mlm head
362
+ self.cls = BertOnlyMLMHead(config)
363
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
364
+ # plm embedding layer
365
+ self.backbone_embeddings = self.bert.embeddings.word_embeddings
366
+ # prompt embedding layer
367
+ self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
368
+ # prefix encoder
369
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
370
+ self.prefix_encoder = PrefixEncoder(config)
371
+
372
+ self.init_weights()
373
+
374
+ # These attributes should be assigned once the model is initialized
375
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
376
+
377
+ # For regression
378
+ self.lb = None
379
+ self.ub = None
380
+
381
+ # For label search.
382
+ self.return_full_softmax = None
383
+
390
+
391
+ def freeze_backbone(self, use_freezing: bool=True):
392
+ if use_freezing:
393
+ self.bert = freezer.freeze_lm(self.bert)
394
+ else:
395
+ self.bert = freezer.unfreeze_lm(self.bert)
396
+
397
+ def get_prompt(self, batch_size):
398
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
399
+ past_key_values = self.prefix_encoder(prefix_tokens)
400
+ # bsz, seqlen, _ = past_key_values.shape
401
+ past_key_values = past_key_values.view(
402
+ batch_size,
403
+ self.pre_seq_len,
404
+ self.n_layer * 2,
405
+ self.n_head,
406
+ self.n_embd
407
+ )
408
+ past_key_values = self.dropout(past_key_values)
409
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
410
+ return past_key_values
411
+
412
+ def embed_encode(self, input_ids):
413
+ embedding_output = self.bert.embeddings.word_embeddings(input_ids)
414
+ return embedding_output
415
+
416
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
417
+ batch_size = input_ids.size(0)
418
+
419
+ # add prefix for prompt-tuning
420
+ past_key_values = self.get_prompt(batch_size=batch_size)
421
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
422
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
423
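+ # The prefix is injected purely through `past_key_values`, so the input tokens themselves are
+ # unchanged; only the attention mask has to be widened by `pre_seq_len` so the real tokens can
+ # attend to the virtual prefix keys/values.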
+
424
+ if mask_pos is not None:
425
+ mask_pos = mask_pos.squeeze()
426
+
427
+ # Encode everything
428
+ outputs = self.bert(
429
+ input_ids,
430
+ attention_mask=attention_mask,
431
+ token_type_ids=token_type_ids,
432
+ past_key_values=past_key_values,
433
+ )
434
+ # Get <mask> token representation
435
+ sequence_output, pooled_output = outputs[:2]
436
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
437
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
438
+
439
+ # Logits over vocabulary tokens
440
+ prediction_mask_scores = self.cls(sequence_mask_output)
441
+
442
+ # Exit early and only return mask logits.
443
+ if return_full_softmax:
444
+ return prediction_mask_scores
445
+
446
+ # print("prediction_mask_scores.shape=", prediction_mask_scores.shape) # [batch_size, seq_len, vocab_size]
447
+
448
+ # Return logits for each label
449
+ logits = []
450
+ for label_id in range(len(self.label_word_list)):
451
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
452
+ logits = torch.cat(logits, -1)
453
+
454
+ # Regression task
455
+ if self.config.num_labels == 1:
456
+ logsoftmax = nn.LogSoftmax(-1)
457
+ logits = logsoftmax(logits) # Log prob of right polarity
458
+
459
+ return logits, sequence_mask_output
460
+
461
+
462
+ def forward(
463
+ self,
464
+ input_ids=None,
465
+ attention_mask=None,
466
+ token_type_ids=None,
467
+ mask_pos=None,
468
+ labels=None,
469
+ inputs_embeds=None,
470
+ block_flag=None,
471
+ return_dict=None,
472
+ ):
473
+
474
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
475
+
476
+ loss = None
477
+ if labels is not None:
478
+ if self.num_labels == 1:
479
+ # Regression task
480
+ loss_fct = nn.KLDivLoss(log_target=True)
481
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
482
+ loss = loss_fct(logits.view(-1, 2), labels)
483
+ else:
484
+
485
+ if labels.shape == logits.shape:
486
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
487
+ labels, reduction="batchmean")
488
+ else:
489
+ loss_fct = nn.CrossEntropyLoss()
490
+
491
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
492
+
493
+ output = (logits,)
494
+ if self.num_labels == 1:
495
+ # Regression output
496
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
497
+
498
+ if not return_dict:
499
+ return ((loss,) + output) if loss is not None else output
500
+
501
+ return SequenceClassifierOutput(
502
+ loss=loss,
503
+ logits=logits,
504
+ )
505
+
506
+
507
+ """
508
+ Adapter-tuning BERT
509
+ """
510
+ class PromptBertAdapterForSequenceClassification(BertPreTrainedModel):
511
+
512
+ def __init__(self, config):
513
+ super().__init__(config)
514
+ self.num_labels = config.num_labels
515
+ self.bert = BertAdaModel(config)
516
+ self.cls = BertOnlyMLMHead(config)
517
+ self.init_weights()
518
+
519
+ if self.config.use_freezing:
520
+ self.bert = freezer.freeze_lm_component(self.bert, "adapter")
521
+
522
+ # These attributes should be assigned once the model is initialized
523
+ self.model_args = None
524
+ self.data_args = None
525
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)
526
+
527
+ # For regression
528
+ self.lb = None
529
+ self.ub = None
530
+
531
+ # For label search.
532
+ self.return_full_softmax = None
533
+
534
+ def freeze_backbone(self, use_freezing: bool=True):
535
+ if use_freezing:
536
+ self.bert = freezer.freeze_lm_component(self.bert, "adapter")
537
+ else:
538
+ self.bert = freezer.unfreeze_lm(self.bert)
539
+
540
+ def embed_encode(self, input_ids):
541
+ embedding_output = self.bert.embeddings.word_embeddings(input_ids)
542
+ return embedding_output
543
+
544
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
545
+ batch_size = input_ids.size(0)
546
+
547
+ if mask_pos is not None:
548
+ mask_pos = mask_pos.squeeze()
549
+
550
+ # Encode everything
551
+ if inputs_embeds is None:
552
+ outputs = self.bert(
553
+ input_ids,
554
+ attention_mask=attention_mask,
555
+ token_type_ids=token_type_ids
556
+ )
557
+ else:
558
+ outputs = self.bert(
559
+ None,
560
+ attention_mask=attention_mask,
561
+ token_type_ids=token_type_ids,
562
+ inputs_embeds=inputs_embeds
563
+ )
564
+
565
+ # Get <mask> token representation
566
+ sequence_output, pooled_output = outputs[:2]
567
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
568
+
569
+ # Logits over vocabulary tokens
570
+ prediction_mask_scores = self.cls(sequence_mask_output)
571
+
572
+ # Exit early and only return mask logits.
573
+ if return_full_softmax:
574
+ return prediction_mask_scores
575
+
576
+ # Return logits for each label
577
+ logits = []
578
+ for label_id in range(len(self.label_word_list)):
579
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
580
+ logits = torch.cat(logits, -1)
581
+
582
+ # Regression task
583
+ if self.config.num_labels == 1:
584
+ logsoftmax = nn.LogSoftmax(-1)
585
+ logits = logsoftmax(logits) # Log prob of right polarity
586
+
587
+ return logits, sequence_mask_output
588
+
589
+
590
+ def forward(
591
+ self,
592
+ input_ids=None,
593
+ attention_mask=None,
594
+ token_type_ids=None,
595
+ mask_pos=None,
596
+ labels=None,
597
+ inputs_embeds=None,
598
+ block_flag=None,
599
+ return_dict=None,
600
+ ):
601
+
602
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
603
+
604
+ loss = None
605
+ if labels is not None:
606
+ if self.num_labels == 1:
607
+ # Regression task
608
+ loss_fct = nn.KLDivLoss(log_target=True)
609
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
610
+ loss = loss_fct(logits.view(-1, 2), labels)
611
+ else:
612
+
613
+ if labels.shape == logits.shape:
614
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
615
+ labels, reduction="batchmean")
616
+ else:
617
+ loss_fct = nn.CrossEntropyLoss()
618
+
619
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
620
+
621
+ output = (logits,)
622
+ if self.num_labels == 1:
623
+ # Regression output
624
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
625
+
626
+ if not return_dict:
627
+ return ((loss,) + output) if loss is not None else output
628
+
629
+ return SequenceClassifierOutput(
630
+ loss=loss,
631
+ logits=logits,
632
+ )
633
+
634
+
635
+
636
+ """
637
+ Vanilla Prompt-tuning RoBERTa
638
+ """
639
+ class PromptRobertaForSequenceClassification(RobertaPreTrainedModel):
640
+
641
+ def __init__(self, config):
642
+ super().__init__(config)
643
+ self.num_labels = config.num_labels
644
+ self.pre_seq_len = self.config.pre_seq_len
645
+ self.hidden_size = self.config.hidden_size
646
+ # backbone
647
+ self.roberta = RobertaModel(config)
648
+ if self.config.use_freezing:
649
+ self.roberta = freezer.freeze_lm(self.roberta)
650
+ # mlm head
651
+ self.cls = RobertaLMHead(config)
652
+
653
+ self.init_weights()
654
+
655
+ # These attributes should be assigned once the model is initialized
656
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
657
+
658
+ # For regression
659
+ self.lb = None
660
+ self.ub = None
661
+
662
+ # For label search.
663
+ self.return_full_softmax = None
664
+
665
+ def freeze_backbone(self, use_freezing: bool=True):
666
+ if use_freezing:
667
+ self.roberta = freezer.freeze_lm(self.roberta)
668
+ else:
669
+ self.roberta = freezer.unfreeze_lm(self.roberta)
670
+
671
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
672
+ """
673
+ Encode the inputs and obtain the logits at the masked position
674
+ """
675
+ if mask_pos is not None:
676
+ mask_pos = mask_pos.squeeze()
677
+ # Encode everything
678
+ if inputs_embeds is None:
679
+ outputs = self.roberta(
680
+ input_ids,
681
+ attention_mask=attention_mask,
682
+ token_type_ids=token_type_ids
683
+ )
684
+ else:
685
+ outputs = self.roberta(
686
+ None,
687
+ attention_mask=attention_mask,
688
+ token_type_ids=token_type_ids,
689
+ inputs_embeds=inputs_embeds
690
+ )
691
+ # Get <mask> token representation
692
+ sequence_output, pooled_output = outputs[:2]
693
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
694
+ # Logits over vocabulary tokens
695
+ prediction_mask_scores = self.cls(sequence_mask_output)
696
+
697
+ # Exit early and only return mask logits.
698
+ if return_full_softmax:
699
+ return prediction_mask_scores
700
+
701
+ # Return logits for each label
702
+ logits = []
703
+ for label_id in range(len(self.label_word_list)):
704
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
705
+ logits = torch.cat(logits, -1)
706
+
707
+ # Regression task
708
+ if self.config.num_labels == 1:
709
+ logsoftmax = nn.LogSoftmax(-1)
710
+ logits = logsoftmax(logits) # Log prob of right polarity
711
+
712
+ return logits, sequence_mask_output
713
+
714
+ def forward(
715
+ self,
716
+ input_ids=None,
717
+ attention_mask=None,
718
+ token_type_ids=None,
719
+ mask_pos=None,
720
+ labels=None,
721
+ inputs_embeds=None,
722
+ block_flag=None,
723
+ return_dict=None,
724
+ ):
725
+
726
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
727
+ loss = None
728
+ if labels is not None:
729
+ if self.num_labels == 1:
730
+ # Regression task
731
+ loss_fct = nn.KLDivLoss(log_target=True)
732
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
733
+ loss = loss_fct(logits.view(-1, 2), labels)
734
+ else:
735
+
736
+ if labels.shape == logits.shape:
737
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
738
+ labels, reduction="batchmean")
739
+ else:
740
+ loss_fct = nn.CrossEntropyLoss()
741
+
742
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
743
+
744
+ output = (logits,)
745
+ if self.num_labels == 1:
746
+ # Regression output
747
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
748
+
749
+ if not return_dict:
750
+ return ((loss,) + output) if loss is not None else output
751
+
752
+ return SequenceClassifierOutput(
753
+ loss=loss,
754
+ logits=logits,
755
+ )
756
+
757
+
758
+ """
759
+ P-tuning RoBERTa
760
+ """
761
+ class PromptRobertaPtuningForSequenceClassification(RobertaPreTrainedModel):
762
+
763
+ def __init__(self, config):
764
+ super().__init__(config)
765
+ self.num_labels = config.num_labels
766
+ self.pre_seq_len = self.config.pre_seq_len
767
+ self.hidden_size = self.config.hidden_size
768
+ # backbone
769
+ self.roberta = RobertaModel(config)
770
+ if self.config.use_freezing:
771
+ self.roberta = freezer.freeze_lm(self.roberta)
772
+ # mlm head
773
+ self.cls = RobertaLMHead(config)
774
+ # prompt encoder
775
+ self.prompt_encoder = None
776
+ # plm embedding layer
777
+ self.backbone_embeddings = self.roberta.embeddings.word_embeddings
778
+ # prompt embedding layer
779
+ self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
780
+
781
+ self.init_weights()
782
+
783
+ # These attributes should be assigned once the model is initialized
784
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
785
+
786
+ # For regression
787
+ self.lb = None
788
+ self.ub = None
789
+
790
+ # For label search.
791
+ self.return_full_softmax = None
792
+
793
+ def freeze_backbone(self, use_freezing: bool=True):
794
+ if use_freezing:
795
+ self.roberta = freezer.freeze_lm(self.roberta)
796
+ else:
797
+ self.roberta = freezer.unfreeze_lm(self.roberta)
798
+
799
+
800
+ def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False):
801
+ """
802
+ Generate continuous prompt embedding
803
+ """
804
+ inputs_embeds = self.backbone_embeddings(input_ids)
805
+
806
+ batch_size = inputs_embeds.shape[0]
807
+ if block_flag is None:
808
+ # the first token is set to 1, the others to 0
809
+ block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device)
810
+ block_flag[:, 0] = 1
811
+ try:
812
+ replace_embeds = self.prompt_embeddings(
813
+ torch.LongTensor(list(range(self.pre_seq_len))).to(inputs_embeds.device))
814
+ except RuntimeError:
+ # Fallback: build the prompt indices on CPU if the device transfer fails.
+ replace_embeds = self.prompt_embeddings(torch.LongTensor(list(range(self.pre_seq_len))))
818
+ replace_embeds = replace_embeds.unsqueeze(0) # [batch_size, prompt_length, embed_size]
819
+
820
+ if self.prompt_encoder is not None:
821
+ replace_embeds = self.prompt_encoder(replace_embeds)
822
+
823
+ # edit by wjn
824
+ if reparameterization:
825
+ # blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape((batch_size, self.pre_seq_len, 2))[:, :, 1]
826
+ blocked_indices = (block_flag == 1).nonzero()
827
+ # reparameterization
828
+ for bidx in range(batch_size):
829
+ for i in range(blocked_indices.shape[1]):
830
+ inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze()
831
+ else:
832
+ replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device)
833
+ inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1)
834
+ return inputs_embeds
835
+
836
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
837
+ """
838
+ Encode the inputs and obtain the logits at the masked position
839
+ """
840
+ batch_size = inputs_embeds.shape[0]
841
+ if mask_pos is not None:
842
+ mask_pos = mask_pos.squeeze()
843
+ # Encode everything
844
+ if inputs_embeds is None:
845
+ outputs = self.roberta(
846
+ input_ids,
847
+ attention_mask=attention_mask,
848
+ token_type_ids=token_type_ids
849
+ )
850
+ else:
851
+
852
+ if inputs_embeds.shape[1] == attention_mask.shape[1]:
853
+ outputs = self.roberta(
854
+ None,
855
+ attention_mask=attention_mask,
856
+ token_type_ids=token_type_ids,
857
+ inputs_embeds=inputs_embeds
858
+ )
859
+ # Get <mask> token representation
860
+ sequence_output, pooled_output = outputs[:2]
861
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
862
+ else:
863
+ if attention_mask is not None:
864
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.roberta.device)
865
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
866
+ if token_type_ids is not None:
867
+ prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.roberta.device)
868
+ token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1)
869
+ outputs = self.roberta(
870
+ None,
871
+ attention_mask=attention_mask,
872
+ token_type_ids=token_type_ids,
873
+ inputs_embeds=inputs_embeds
874
+ )
875
+ # Get <mask> token representation
876
+ sequence_output, pooled_output = outputs[:2]
877
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
878
+
879
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
880
+ # Logits over vocabulary tokens
881
+ prediction_mask_scores = self.cls(sequence_mask_output)
882
+
883
+ # Exit early and only return mask logits.
884
+ if return_full_softmax:
885
+ return prediction_mask_scores
886
+
887
+ # Return logits for each label
888
+ logits = []
889
+ for label_id in range(len(self.label_word_list)):
890
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
891
+ logits = torch.cat(logits, -1)
892
+
893
+ # Regression task
894
+ if self.config.num_labels == 1:
895
+ logsoftmax = nn.LogSoftmax(-1)
896
+ logits = logsoftmax(logits) # Log prob of right polarity
897
+
898
+ return logits, sequence_mask_output
899
+
900
+ def forward(
901
+ self,
902
+ input_ids=None,
903
+ attention_mask=None,
904
+ token_type_ids=None,
905
+ mask_pos=None,
906
+ labels=None,
907
+ inputs_embeds=None,
908
+ block_flag=None,
909
+ return_dict=None,
910
+ ):
911
+
912
+ inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
913
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
914
+ loss = None
915
+ if labels is not None:
916
+ if self.num_labels == 1:
917
+ # Regression task
918
+ loss_fct = nn.KLDivLoss(log_target=True)
919
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
920
+ loss = loss_fct(logits.view(-1, 2), labels)
921
+ else:
922
+
923
+ if labels.shape == logits.shape:
924
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
925
+ labels, reduction="batchmean")
926
+ else:
927
+ loss_fct = nn.CrossEntropyLoss()
928
+
929
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
930
+
931
+ output = (logits,)
932
+ if self.num_labels == 1:
933
+ # Regression output
934
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
935
+
936
+ if not return_dict:
937
+ return ((loss,) + output) if loss is not None else output
938
+
939
+ return SequenceClassifierOutput(
940
+ loss=loss,
941
+ logits=logits,
942
+ )
943
+
944
+
945
+ """
946
+ Prefix-tuning RoBERTa
947
+ """
948
+ class PromptRobertaPrefixForSequenceClassification(RobertaPreTrainedModel):
949
+
950
+ def __init__(self, config):
951
+ super().__init__(config)
952
+
953
+
954
+ self.num_labels = config.num_labels
955
+ self.pre_seq_len = self.config.pre_seq_len
956
+ self.hidden_size = self.config.hidden_size
957
+
958
+ self.n_layer = config.num_hidden_layers
959
+ self.n_head = config.num_attention_heads
960
+ self.n_embd = config.hidden_size // config.num_attention_heads
961
+
962
+ # backbone
963
+ self.roberta = RobertaModel(config)
964
+ if self.config.use_freezing:
965
+ self.roberta = freezer.freeze_lm(self.roberta)
966
+ # mlm head
967
+ self.cls = RobertaLMHead(config)
968
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
969
+ # plm embedding layer
970
+ self.backbone_embeddings = self.roberta.embeddings.word_embeddings
971
+ # prompt embedding layer
972
+ self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
973
+ # prefix encoder
974
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
975
+ self.prefix_encoder = PrefixEncoder(config)
976
+
977
+ self.init_weights()
978
+
979
+ # These attributes should be assigned once the model is initialized
980
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
981
+
982
+ # For regression
983
+ self.lb = None
984
+ self.ub = None
985
+
986
+ # For label search.
987
+ self.return_full_softmax = None
988
+
995
+
996
+ def freeze_backbone(self, use_freezing: bool=True):
997
+ if use_freezing:
998
+ self.roberta = freezer.freeze_lm(self.roberta)
999
+ else:
1000
+ self.roberta = freezer.unfreeze_lm(self.roberta)
1001
+
1002
+ def get_prompt(self, batch_size):
1003
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
1004
+ past_key_values = self.prefix_encoder(prefix_tokens)
1005
+ # bsz, seqlen, _ = past_key_values.shape
1006
+ past_key_values = past_key_values.view(
1007
+ batch_size,
1008
+ self.pre_seq_len,
1009
+ self.n_layer * 2,
1010
+ self.n_head,
1011
+ self.n_embd
1012
+ )
1013
+ past_key_values = self.dropout(past_key_values)
1014
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
1015
+ return past_key_values
1016
+
1017
+ def embed_encode(self, input_ids):
1018
+ embedding_output = self.roberta.embeddings.word_embeddings(input_ids)
1019
+ return embedding_output
1020
+
1021
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
1022
+ batch_size = input_ids.size(0)
1023
+
1024
+ # add prefix for prompt-tuning
1025
+ past_key_values = self.get_prompt(batch_size=batch_size)
1026
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
1027
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
1028
+
1029
+ if mask_pos is not None:
1030
+ mask_pos = mask_pos.squeeze()
1031
+
1032
+ # Encode everything
1033
+ outputs = self.roberta(
1034
+ input_ids,
1035
+ attention_mask=attention_mask,
1036
+ token_type_ids=token_type_ids,
1037
+ past_key_values=past_key_values,
1038
+ )
1039
+ # Get <mask> token representation
1040
+ sequence_output, pooled_output = outputs[:2]
1041
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
1042
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
1043
+
1044
+ # Logits over vocabulary tokens
1045
+ prediction_mask_scores = self.cls(sequence_mask_output)
1046
+
1047
+ # Exit early and only return mask logits.
1048
+ if return_full_softmax:
1049
+ return prediction_mask_scores
1050
+
1051
+ # Return logits for each label
1052
+ logits = []
1053
+ for label_id in range(len(self.label_word_list)):
1054
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
1055
+ logits = torch.cat(logits, -1)
1056
+
1057
+ # Regression task
1058
+ if self.config.num_labels == 1:
1059
+ logsoftmax = nn.LogSoftmax(-1)
1060
+ logits = logsoftmax(logits) # Log prob of right polarity
1061
+
1062
+ return logits, sequence_mask_output
1063
+
1064
+
1065
+ def forward(
1066
+ self,
1067
+ input_ids=None,
1068
+ attention_mask=None,
1069
+ token_type_ids=None,
1070
+ mask_pos=None,
1071
+ labels=None,
1072
+ inputs_embeds=None,
1073
+ block_flag=None,
1074
+ return_dict=None,
1075
+ ):
1076
+
1077
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
1078
+
1079
+ loss = None
1080
+ if labels is not None:
1081
+ if self.num_labels == 1:
1082
+ # Regression task
1083
+ loss_fct = nn.KLDivLoss(log_target=True)
1084
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
1085
+ loss = loss_fct(logits.view(-1, 2), labels)
1086
+ else:
1087
+
1088
+ if labels.shape == logits.shape:
1089
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
1090
+ labels, reduction="batchmean")
1091
+ else:
1092
+ loss_fct = nn.CrossEntropyLoss()
1093
+
1094
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
1095
+
1096
+ output = (logits,)
1097
+ if self.num_labels == 1:
1098
+ # Regression output
1099
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
1100
+
1101
+ if not return_dict:
1102
+ return ((loss,) + output) if loss is not None else output
1103
+
1104
+ return SequenceClassifierOutput(
1105
+ loss=loss,
1106
+ logits=logits,
1107
+ )
1108
+
1109
+ """
1110
+ Adapter-tuning RoBERTa
1111
+ """
1112
+ class PromptRobertaAdapterForSequenceClassification(RobertaPreTrainedModel):
1113
+
1114
+ def __init__(self, config):
1115
+ super().__init__(config)
1116
+ self.num_labels = config.num_labels
1117
+ self.roberta = RobertaAdaModel(config)
1118
+ self.cls = RobertaLMHead(config)
1119
+ self.init_weights()
1120
+
1121
+ if self.config.use_freezing:
1122
+ self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")
1123
+
1124
+ # These attributes should be assigned once the model is initialized
1125
+ self.model_args = None
1126
+ self.data_args = None
1127
+ self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)
1128
+
1129
+ # For regression
1130
+ self.lb = None
1131
+ self.ub = None
1132
+
1133
+ # For label search.
1134
+ self.return_full_softmax = None
1135
+
1136
+ def freeze_backbone(self, use_freezing: bool=True):
1137
+ if use_freezing:
1138
+ self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")
1139
+ else:
1140
+ self.roberta = freezer.unfreeze_lm(self.berobertart)
1141
+
1142
+ def embed_encode(self, input_ids):
1143
+ embedding_output = self.roberta.embeddings.word_embeddings(input_ids)
1144
+ return embedding_output
1145
+
1146
+ def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
1147
+ batch_size = input_ids.size(0)
1148
+
1149
+ if mask_pos is not None:
1150
+ mask_pos = mask_pos.squeeze()
1151
+
1152
+ # Encode everything
1153
+ if inputs_embeds is None:
1154
+ outputs = self.roberta(
1155
+ input_ids,
1156
+ attention_mask=attention_mask,
1157
+ token_type_ids=token_type_ids
1158
+ )
1159
+ else:
1160
+ outputs = self.roberta(
1161
+ None,
1162
+ attention_mask=attention_mask,
1163
+ token_type_ids=token_type_ids,
1164
+ inputs_embeds=inputs_embeds
1165
+ )
1166
+
1167
+ # Get <mask> token representation
1168
+ sequence_output, pooled_output = outputs[:2]
1169
+ sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
1170
+
1171
+ # Logits over vocabulary tokens
1172
+ prediction_mask_scores = self.cls(sequence_mask_output)
1173
+
1174
+ # Exit early and only return mask logits.
1175
+ if return_full_softmax:
1176
+ return prediction_mask_scores
1177
+
1178
+ # Return logits for each label
1179
+ logits = []
1180
+ for label_id in range(len(self.label_word_list)):
1181
+ logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
1182
+ logits = torch.cat(logits, -1)
1183
+
1184
+ # Regression task
1185
+ if self.config.num_labels == 1:
1186
+ logsoftmax = nn.LogSoftmax(-1)
1187
+ logits = logsoftmax(logits) # Log prob of right polarity
1188
+
1189
+ return logits, sequence_mask_output
1190
+
1191
+
1192
+ def forward(
1193
+ self,
1194
+ input_ids=None,
1195
+ attention_mask=None,
1196
+ token_type_ids=None,
1197
+ mask_pos=None,
1198
+ labels=None,
1199
+ inputs_embeds=None,
1200
+ block_flag=None,
1201
+ return_dict=None,
1202
+ ):
1203
+
1204
+ logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
1205
+
1206
+ loss = None
1207
+ if labels is not None:
1208
+ if self.num_labels == 1:
1209
+ # Regression task
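+ # Labels in [lb, ub] are rescaled to [0, 1] and encoded as a two-point distribution over the
+ # (lower, upper) label words; the KL divergence against the log-softmaxed two-way logits then
+ # serves as the regression loss. `self.lb` and `self.ub` must be assigned before regression
+ # training, otherwise the expressions below fail on `None`.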
1210
+ loss_fct = nn.KLDivLoss(log_target=True)
1211
+ labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
1212
+ loss = loss_fct(logits.view(-1, 2), labels)
1213
+ else:
1214
+
1215
+ if labels.shape == logits.shape:
1216
+ loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
1217
+ labels, reduction="batchmean")
1218
+ else:
1219
+ loss_fct = nn.CrossEntropyLoss()
1220
+
1221
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
1222
+
1223
+ output = (logits,)
1224
+ if self.num_labels == 1:
1225
+ # Regression output
1226
+ output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
1227
+
1228
+ if not return_dict:
1229
+ return ((loss,) + output) if loss is not None else output
1230
+
1231
+ return SequenceClassifierOutput(
1232
+ loss=loss,
1233
+ logits=logits,
1234
+ )
1235
+
1236
+
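+ # Minimal usage sketch for the adapter-tuned class above. Illustrative only: `label_word_list`
+ # and `use_freezing` are custom config attributes expected to be set by the surrounding training
+ # pipeline, and the checkpoint name and label words below are assumptions, not part of this repo.
+ #
+ # from transformers import AutoConfig, AutoTokenizer
+ #
+ # config = AutoConfig.from_pretrained("roberta-base")
+ # tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+ # config.label_word_list = tokenizer.convert_tokens_to_ids(["Ġterrible", "Ġgreat"])
+ # config.use_freezing = True
+ # model = PromptRobertaAdapterForSequenceClassification.from_pretrained(
+ #     "roberta-base", config=config
+ # )
+ # # Inputs should contain the prompt template, e.g. "<text> It was <mask>.", with `mask_pos`
+ # # giving the index of the <mask> token in each sequence.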
1237
+ # class DebertaForPromptFinetuning(DebertaPreTrainedModel):
1238
+ # _keys_to_ignore_on_load_unexpected = [r"pooler"]
1239
+ # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1240
+
1241
+ # def __init__(self, config):
1242
+ # super().__init__(config)
1243
+ # self.num_labels = config.num_labels
1244
+ # #self.deberta = DebertaV2Model(config)
1245
+
1246
+ # self.deberta = DebertaModel(config)
1247
+ # self.cls = DebertaOnlyMLMHead(config)
1248
+
1249
+ # if self.config.use_freezing:
1250
+ # self.deberta = freezer.freeze_lm(self.deberta)
1251
+
1252
+ # self.pooler = ContextPooler(config)
1253
+ # output_dim = self.pooler.output_dim
1254
+
1255
+ # self.classifier = torch.nn.Linear(output_dim, self.num_labels)
1256
+ # drop_out = getattr(config, "cls_dropout", None)
1257
+ # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1258
+
1259
+ # self.dropout = StableDropout(drop_out)
1260
+
1261
+ # classification_list = [self.pooler, self.dropout,self.classifier]
1262
+
1263
+ # self.classifier = nn.Sequential(*classification_list)
1264
+ # # self.cls = DebertaV2OnlyMLMHead(config)
1265
+
1266
+ # self.map = nn.Linear(config.hidden_size, config.hidden_size)
1267
+ # self.init_weights()
1268
+
1269
+ # # These attributes should be assigned once the model is initialized
1270
+ # self.model_args = None
1271
+ # self.data_args = None
1272
+ # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.deberta.device)
1273
+ # self.K = 1
1274
+ # self.step_size=1e-5
1275
+ # # import pdb
1276
+ # # pdb.set_trace()
1277
+ # #self.step_size=config.step_size
1278
+
1279
+ # # For regression
1280
+ # self.lb = None
1281
+ # self.ub = None
1282
+
1283
+ # self.pre_seq_len = self.config.pre_seq_len
1284
+ # # For auto label search.
1285
+ # self.return_full_softmax = None
1286
+
1287
+ # def freeze_backbone(self, use_freezing: bool=True):
1288
+ # if use_freezing:
1289
+ # self.deberta = freezer.freeze_lm(self.deberta)
1290
+ # else:
1291
+ # self.deberta = freezer.unfreeze_lm(self.deberta)
1292
+
1293
+
1294
+
1295
+ # def embed_encode(self, input_ids):
1296
+ # embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
1297
+ # return embedding_output
1298
+
1299
+ # def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None,
1300
+ # return_full_softmax=False):
1301
+ # batch_size = input_ids.size(0)
1302
+
1303
+ # if mask_pos is not None:
1304
+ # mask_pos = mask_pos.squeeze()
1305
+
1306
+
1307
+ # # Encode everything
1308
+ # if inputs_embeds is None:
1309
+ # outputs = self.deberta(
1310
+ # input_ids,
1311
+ # attention_mask=attention_mask,
1312
+ # token_type_ids=token_type_ids
1313
+ # )
1314
+ # else:
1315
+ # outputs = self.deberta(
1316
+ # None,
1317
+ # attention_mask=attention_mask,
1318
+ # token_type_ids=token_type_ids,
1319
+ # inputs_embeds=inputs_embeds
1320
+ # )
1321
+
1322
+ # # Get <mask> token representation
1323
+ # sequence_output = outputs[0]
1324
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
1325
+ # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
1326
+
1327
+ # # Logits over vocabulary tokens
1328
+ # prediction_mask_scores = self.cls(sequence_mask_output)
1329
+
1330
+ # # sequence_mask_output = self.lm_head.dense(sequence_mask_output)
1331
+
1332
+ # # Exit early and only return mask logits.
1333
+ # if return_full_softmax:
1334
+ # return prediction_mask_scores
1335
+
1336
+ # # Return logits for each label
1337
+ # logits = []
1338
+ # for label_id in range(len(self.label_word_list)):
1339
+ # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
1340
+ # logits = torch.cat(logits, -1)
1341
+
1342
+ # # Regression task
1343
+ # if self.config.num_labels == 1:
1344
+ # logsoftmax = nn.LogSoftmax(-1)
1345
+ # logits = logsoftmax(logits) # Log prob of right polarity
1346
+
1347
+ # if self.model_args.hybrid == 1:
1348
+ # cls_logits = self.classifier(sequence_output)
1349
+ # return (logits, cls_logits), sequence_mask_output
1350
+
1351
+ # return logits, sequence_mask_output
1352
+
1353
+ # def forward(
1354
+ # self,
1355
+ # input_ids=None,
1356
+ # attention_mask=None,
1357
+ # token_type_ids=None,
1358
+ # mask_pos=None,
1359
+ # labels=None,
1360
+ # inputs_embeds=None,
1361
+ # fwd_type=0,
1362
+ # block_flag=None
1363
+ # ):
1364
+
1365
+ # if fwd_type == 2:
1366
+ # assert inputs_embeds is not None
1367
+ # return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
1368
+ # mask_pos=mask_pos, inputs_embeds=inputs_embeds)
1369
+
1370
+ # elif fwd_type == 1:
1371
+ # return self.embed_encode(input_ids)
1372
+
1373
+
1374
+
1375
+ # if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None:
1376
+ # inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
1377
+
1378
+ # logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
1379
+
1380
+ # if self.model_args.hybrid == 1:
1381
+ # logits = logits[0]
1382
+ # cls_logits = logits[1]
1383
+
1384
+ # loss = None
1385
+ # if labels is not None:
1386
+ # if self.num_labels == 1:
1387
+ # # Regression task
1388
+ # loss_fct = nn.KLDivLoss(log_target=True)
1389
+ # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb),
1390
+ # (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
1391
+ # loss = loss_fct(logits.view(-1, 2), labels)
1392
+ # else:
1393
+
1394
+ # if labels.shape == logits.shape:
1395
+ # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
1396
+ # labels, reduction="batchmean")
1397
+ # else:
1398
+ # loss_fct = nn.CrossEntropyLoss()
1399
+
1400
+ # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
1401
+
1402
+ # output = (logits,)
1403
+ # if self.num_labels == 1:
1404
+ # # Regression output
1405
+ # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
1406
+
1407
+ # return ((loss,) + output) if loss is not None else output
1408
+
1409
+
1410
+
1411
+ # # add by wjn
1412
+ # # Prefix-tuning for Deberta
1413
+ # class DebertaPrefixForPromptFinetuning(DebertaPreTrainedModel):
1414
+
1415
+ # def __init__(self, config):
1416
+ # super().__init__(config)
1417
+ # self.num_labels = config.num_labels
1418
+ # #self.deberta = DebertaV2Model(config)
1419
+
1420
+ # self.deberta = DebertaModel(config)
1421
+ # self.cls = DebertaOnlyMLMHead(config)
1422
+
1423
+ # self.pooler = ContextPooler(config)
1424
+ # output_dim = self.pooler.output_dim
1425
+
1426
+ # self.classifier = torch.nn.Linear(output_dim, self.num_labels)
1427
+ # drop_out = getattr(config, "cls_dropout", None)
1428
+ # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1429
+
1430
+ # self.dropout = StableDropout(drop_out)
1431
+
1432
+ # classification_list = [self.pooler, self.dropout,self.classifier]
1433
+
1434
+ # self.classifier = nn.Sequential(*classification_list)
1435
+ # # self.cls = DebertaV2OnlyMLMHead(config)
1436
+
1437
+ # self.map = nn.Linear(config.hidden_size, config.hidden_size)
1438
+ # self.init_weights()
1439
+
1440
+ # if self.config.use_freezing:
1441
+ # self.deberta = freezer.freeze_lm(self.deberta)
1442
+
1443
+ # self.pre_seq_len = config.pre_seq_len
1444
+ # self.n_layer = config.num_hidden_layers
1445
+ # self.n_head = config.num_attention_heads
1446
+ # self.n_embd = config.hidden_size // config.num_attention_heads
1447
+
1448
+ # self.prefix_tokens = torch.arange(self.pre_seq_len).long()
1449
+ # self.prefix_encoder = PrefixEncoder(config)
1450
+
1451
+ # # These attributes should be assigned once the model is initialized
1452
+ # self.model_args = None
1453
+ # self.data_args = None
1454
+ # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.deberta.device)
1455
+ # self.K = 1
1456
+ # self.step_size=1e-5
1457
+ # # import pdb
1458
+ # # pdb.set_trace()
1459
+ # #self.step_size=config.step_size
1460
+
1461
+ # # For regression
1462
+ # self.lb = None
1463
+ # self.ub = None
1464
+
1465
+
1466
+ # # For auto label search.
1467
+ # self.return_full_softmax = None
1468
+
1469
+ # def freeze_backbone(self, use_freezing: bool=True):
1470
+ # if use_freezing:
1471
+ # self.deberta = freezer.freeze_lm(self.deberta)
1472
+ # else:
1473
+ # self.deberta = freezer.unfreeze_lm(self.deberta)
1474
+
1475
+ # def get_prompt(self, batch_size):
1476
+ # prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
1477
+ # past_key_values = self.prefix_encoder(prefix_tokens)
1478
+ # # bsz, seqlen, _ = past_key_values.shape
1479
+ # past_key_values = past_key_values.view(
1480
+ # batch_size,
1481
+ # self.pre_seq_len,
1482
+ # self.n_layer * 2,
1483
+ # self.n_head,
1484
+ # self.n_embd
1485
+ # )
1486
+ # past_key_values = self.dropout(past_key_values)
1487
+ # past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
1488
+ # return past_key_values
1489
+
1490
+
1491
+ # def get_constrast_loss(self,
1492
+ # input_ids=None,
1493
+ # attention_mask=None,
1494
+ # mask_pos=None,
1495
+ # labels=None,
1496
+ # inputs_embeds=None):
1497
+
1498
+ # self.cos = nn.CosineSimilarity(dim=-1)
1499
+
1500
+
1501
+ # _, sequence_mask_output_1 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
1502
+ # _, sequence_mask_output_2 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
1503
+
1504
+ # sequence_mask_output_1= self.lm_head.dense(sequence_mask_output_1)
1505
+ # sequence_mask_output_2 = self.lm_head.dense(sequence_mask_output_2)
1506
+ # # input_args = [input_ids, attention_mask, mask_pos, labels, None, 1]
1507
+ # # embed = self.forward(*input_args)
1508
+ # #
1509
+ # # vat_args = [input_ids, attention_mask, mask_pos, labels, embed, 2]
1510
+ # #
1511
+ # # adv_logits, outputs = self.forward(*vat_args)
1512
+ # #
1513
+ # # logit_mask = F.softmax(logits, dim=-1)[torch.arange(adv_logits.size(0)), labels] > 0.7
1514
+ # #
1515
+ # # outputs = outputs[logit_mask]
1516
+ # # seq_outputs = sequence_mask_output[logit_mask]
1517
+ # # new_label = labels[logit_mask]
1518
+ # # #
1519
+ # # #
1520
+ # # rand_perm = torch.randperm(outputs.size(0))
1521
+ # # rand_outputs = outputs[rand_perm, :]
1522
+ # # rand_label = new_label[rand_perm]
1523
+ # # pair_label = (new_label == rand_label).long()
1524
+ # #
1525
+ # # seq_outputs = self.map(seq_outputs)
1526
+ # # rand_outputs = self.map(rand_outputs)
1527
+
1528
+ # pair_labels = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
1529
+
1530
+ # # import pdb
1531
+ # # pdb.set_trace()
1532
+
1533
+ # contra_loss = self.contra_lc(sequence_mask_output_1.unsqueeze(1), sequence_mask_output_2.unsqueeze(0), pair_labels)
1534
+
1535
+ # if torch.isnan(contra_loss):
1536
+ # return 0
1537
+
1538
+ # return contra_loss
1539
+
1540
+ # def embed_encode(self, input_ids):
1541
+ # embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
1542
+ # return embedding_output
1543
+
1544
+ # def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
1545
+ # batch_size = input_ids.size(0)
1546
+
1547
+ # # add prefix for prompt-tuning
1548
+ # past_key_values = self.get_prompt(batch_size=batch_size)
1549
+ # prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
1550
+ # attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
1551
+
1552
+ # if mask_pos is not None:
1553
+ # mask_pos = mask_pos.squeeze()
1554
+
1555
+ # # Encode everything
1556
+
1557
+ # outputs = self.deberta(
1558
+ # input_ids,
1559
+ # attention_mask=attention_mask,
1560
+ # token_type_ids=token_type_ids,
1561
+ # past_key_values=past_key_values,
1562
+ # )
1563
+
1564
+
1565
+ # # Get <mask> token representation
1566
+ # sequence_output, pooled_output = outputs[:2]
1567
+ # # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
1568
+ # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
1569
+
1570
+ # # Logits over vocabulary tokens
1571
+ # prediction_mask_scores = self.cls(sequence_mask_output)
1572
+
1573
+ # #sequence_mask_output = self.lm_head.dense(sequence_mask_output)
1574
+
1575
+ # # Exit early and only return mask logits.
1576
+ # if return_full_softmax:
1577
+ # return prediction_mask_scores
1578
+
1579
+ # # Return logits for each label
1580
+ # logits = []
1581
+ # for label_id in range(len(self.label_word_list)):
1582
+ # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
1583
+ # logits = torch.cat(logits, -1)
1584
+
1585
+ # # Regression task
1586
+ # if self.config.num_labels == 1:
1587
+ # logsoftmax = nn.LogSoftmax(-1)
1588
+ # logits = logsoftmax(logits) # Log prob of right polarity
1589
+
1590
+ # if self.model_args.hybrid == 1:
1591
+ # cls_logits = self.classifier(sequence_output)
1592
+ # return (logits, cls_logits), sequence_mask_output
1593
+
1594
+ # return logits, sequence_mask_output
1595
+
1596
+
1597
+ # def forward(
1598
+ # self,
1599
+ # input_ids=None,
1600
+ # attention_mask=None,
1601
+ # token_type_ids=None,
1602
+ # mask_pos=None,
1603
+ # labels=None,
1604
+ # inputs_embeds=None,
1605
+ # fwd_type=0,
1606
+ # block_flag=None,
1607
+ # return_dict=None,
1608
+ # ):
1609
+
1610
+ # if fwd_type == 2:
1611
+ # assert inputs_embeds is not None
1612
+ # return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
1613
+ # mask_pos=mask_pos, inputs_embeds=inputs_embeds)
1614
+
1615
+ # elif fwd_type == 1:
1616
+ # return self.embed_encode(input_ids)
1617
+
1618
+
1619
+
1620
+ # if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None:
1621
+ # inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
1622
+
1623
+ # logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)
1624
+
1625
+ # if self.model_args.hybrid == 1:
1626
+ # logits = logits[0]
1627
+ # cls_logits = logits[1]
1628
+
1629
+ # loss = None
1630
+ # if labels is not None:
1631
+ # if self.num_labels == 1:
1632
+ # # Regression task
1633
+ # loss_fct = nn.KLDivLoss(log_target=True)
1634
+ # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb),
1635
+ # (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
1636
+ # loss = loss_fct(logits.view(-1, 2), labels)
1637
+ # else:
1638
+
1639
+ # if labels.shape == logits.shape:
1640
+ # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
1641
+ # labels, reduction="batchmean")
1642
+ # else:
1643
+ # loss_fct = nn.CrossEntropyLoss()
1644
+
1645
+ # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
1646
+
1647
+ # output = (logits,)
1648
+ # if self.num_labels == 1:
1649
+ # # Regression output
1650
+ # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
1651
+
1652
+ # if not return_dict:
1653
+ # return ((loss,) + output) if loss is not None else output
1654
+
1655
+ # return SequenceClassifierOutput(
1656
+ # loss=loss,
1657
+ # logits=logits,
1658
+ # )
1659
+
1660
+
1661
+
1662
+
1663
+ # class Debertav2ForPromptFinetuning(DebertaV2PreTrainedModel):
1664
+ # _keys_to_ignore_on_load_unexpected = [r"pooler"]
1665
+ # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1666
+
1667
+ # def __init__(self, config):
1668
+ # super().__init__(config)
1669
+ # self.num_labels = config.num_labels
1670
+ # self.deberta = DebertaV2Model(config)
1671
+
1672
+ # if self.config.use_freezing:
1673
+ # self.deberta = freezer.freeze_lm(self.deberta)
1674
+ # self.cls = DebertaV2OnlyMLMHead(config)
1675
+
1676
+ # #self.deberta = DebertaModel(config)
1677
+ # #self.cls = DebertaOnlyMLMHead(config)
1678
+
1679
+ # self.pooler = ContextPooler(config)
1680
+ # output_dim = self.pooler.output_dim
1681
+
1682
+ # self.classifier = torch.nn.Linear(output_dim, self.num_labels)
1683
+ # drop_out = getattr(config, "cls_dropout", None)
1684
+ # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1685
+
1686
+ # self.dropout = StableDropout(drop_out)
1687
+
1688
+ # classification_list = [self.pooler, self.dropout,self.classifier]
1689
+
1690
+ # self.classifier = nn.Sequential(*classification_list)
1691
+ # # self.cls = DebertaV2OnlyMLMHead(config)
1692
+
1693
+ # self.map = nn.Linear(config.hidden_size, config.hidden_size)
1694
+ # self.init_weights()
1695
+
1696
+ # # These attributes should be assigned once the model is initialized
1697
+ # self.model_args = None
1698
+ # self.data_args = None
1699
+ # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.deberta.device)
1700
+ # self.K = 1
1701
+ # self.step_size=1e-5
1702
+ # # import pdb
1703
+ # # pdb.set_trace()
1704
+ # #self.step_size=config.step_size
1705
+
1706
+ # # For regression
1707
+ # self.lb = None
1708
+ # self.ub = None
1709
+
1710
+ # self.pre_seq_len = self.config.pre_seq_len
1711
+ # # For auto label search.
1712
+ # self.return_full_softmax = None
1713
+
1714
+ # def freeze_backbone(self, use_freezing: bool=True):
1715
+ # if use_freezing:
1716
+ # self.deberta = freezer.freeze_lm(self.deberta)
1717
+ # else:
1718
+ # self.deberta = freezer.unfreeze_lm(self.deberta)
1719
+
1720
+ # def embed_encode(self, input_ids):
1721
+ # embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
1722
+ # return embedding_output
1723
+
1724
+ # def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
1725
+ # batch_size = input_ids.size(0)
1726
+
1727
+ # if mask_pos is not None:
1728
+ # mask_pos = mask_pos.squeeze()
1729
+
1730
+ # # Encode everything
1731
+ # if inputs_embeds is None:
1732
+ # outputs = self.deberta(
1733
+ # input_ids,
1734
+ # attention_mask=attention_mask
1735
+ # )
1736
+ # else:
1737
+ # outputs = self.deberta(
1738
+ # None,
1739
+ # attention_mask=attention_mask,
1740
+ # inputs_embeds=inputs_embeds
1741
+ # )
1742
+
1743
+
1744
+ # # Get <mask> token representation
1745
+ # sequence_output = outputs[0]
1746
+ # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
1747
+ # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
1748
+
1749
+
1750
+ # # Logits over vocabulary tokens
1751
+ # prediction_mask_scores = self.cls(sequence_mask_output)
1752
+
1753
+ # #sequence_mask_output = self.lm_head.dense(sequence_mask_output)
1754
+
1755
+ # # Exit early and only return mask logits.
1756
+ # if return_full_softmax:
1757
+ # return prediction_mask_scores
1758
+
1759
+ # # Return logits for each label
1760
+ # logits = []
1761
+ # for label_id in range(len(self.label_word_list)):
1762
+ # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
1763
+ # logits = torch.cat(logits, -1)
1764
+
1765
+ # # Regression task
1766
+ # if self.config.num_labels == 1:
1767
+ # logsoftmax = nn.LogSoftmax(-1)
1768
+ # logits = logsoftmax(logits) # Log prob of right polarity
1769
+
1770
+ # return logits, sequence_mask_output
1771
+
1772
+
1773
+ # def forward(
1774
+ # self,
1775
+ # input_ids=None,
1776
+ # attention_mask=None,
1777
+ # mask_pos=None,
1778
+ # labels=None,
1779
+ # inputs_embeds=None,
1780
+ # fwd_type=0,
1781
+ # block_flag=None,
1782
+ # return_dict=None
1783
+ # ):
1784
+ # if fwd_type == 2:
1785
+ # assert inputs_embeds is not None
1786
+ # return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds)
1787
+
1788
+ # elif fwd_type == 1:
1789
+ # return self.embed_encode(input_ids)
1790
+
1791
+ # logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
1792
+
1793
+ # loss = None
1794
+
1795
+
1796
+ # if labels is not None:
1797
+ # if self.num_labels == 1:
1798
+ # # Regression task
1799
+ # loss_fct = nn.KLDivLoss(log_target=True)
1800
+ # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
1801
+ # loss = loss_fct(logits.view(-1, 2), labels)
1802
+ # else:
1803
+
1804
+ # if labels.shape == logits.shape:
1805
+ # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
1806
+ # labels, reduction="batchmean")
1807
+ # else:
1808
+ # loss_fct = nn.CrossEntropyLoss()
1809
+
1810
+ # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
1811
+ # if self.model_args.hybrid == 1:
1812
+ # cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1))
1813
+ # loss = loss + cls_loss
1814
+
1815
+ # output = (logits,)
1816
+ # if self.num_labels == 1:
1817
+ # # Regression output
1818
+ # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
1819
+
1820
+ # if not return_dict:
1821
+ # return ((loss,) + output) if loss is not None else output
1822
+
1823
+ # return SequenceClassifierOutput(
1824
+ # loss=loss,
1825
+ # logits=logits,
1826
+ # )
1827
+
1828
+
1829
+ # class Debertav2PrefixForPromptFinetuning(DebertaV2PreTrainedModel):
1830
+ # _keys_to_ignore_on_load_unexpected = [r"pooler"]
1831
+ # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1832
+
1833
+ # def __init__(self, config):
1834
+ # super().__init__(config)
1835
+ # self.num_labels = config.num_labels
1836
+ # self.deberta = DebertaV2Model(config)
1837
+ # self.cls = DebertaV2OnlyMLMHead(config)
1838
+
1839
+ # #self.deberta = DebertaModel(config)
1840
+ # #self.cls = DebertaOnlyMLMHead(config)
1841
+
1842
+ # self.pooler = ContextPooler(config)
1843
+ # output_dim = self.pooler.output_dim
1844
+
1845
+ # self.classifier = torch.nn.Linear(output_dim, self.num_labels)
1846
+ # drop_out = getattr(config, "cls_dropout", None)
1847
+ # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1848
+
1849
+ # self.dropout = StableDropout(drop_out)
1850
+
1851
+ # classification_list = [self.pooler, self.dropout,self.classifier]
1852
+
1853
+ # self.classifier = nn.Sequential(*classification_list)
1854
+ # # self.cls = DebertaV2OnlyMLMHead(config)
1855
+
1856
+ # self.map = nn.Linear(config.hidden_size, config.hidden_size)
1857
+ # self.init_weights()
1858
+
1859
+ # if self.config.use_freezing:
1860
+ # self.deberta = freezer.freeze_lm(self.deberta)
1861
+
1862
+ # self.pre_seq_len = config.pre_seq_len
1863
+ # self.n_layer = config.num_hidden_layers
1864
+ # self.n_head = config.num_attention_heads
1865
+ # self.n_embd = config.hidden_size // config.num_attention_heads
1866
+
1867
+ # self.prefix_tokens = torch.arange(self.pre_seq_len).long()
1868
+ # self.prefix_encoder = PrefixEncoder(config)
1869
+
1870
+ # # These attributes should be assigned once the model is initialized
1871
+ # self.model_args = None
1872
+ # self.data_args = None
1873
+ # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.deberta.device)
1874
+ # self.K = 1
1875
+ # self.step_size=1e-5
1876
+ # # import pdb
1877
+ # # pdb.set_trace()
1878
+ # #self.step_size=config.step_size
1879
+
1880
+ # # For regression
1881
+ # self.lb = None
1882
+ # self.ub = None
1883
+
1884
+
1885
+ # # For auto label search.
1886
+ # self.return_full_softmax = None
1887
+
1888
+ # def freeze_backbone(self, use_freezing: bool=True):
1889
+ # if use_freezing:
1890
+ # self.deberta = freezer.freeze_lm(self.deberta)
1891
+ # else:
1892
+ # self.deberta = freezer.unfreeze_lm(self.deberta)
1893
+
1894
+ # def get_prompt(self, batch_size):
1895
+ # prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
1896
+ # past_key_values = self.prefix_encoder(prefix_tokens)
1897
+ # # bsz, seqlen, _ = past_key_values.shape
1898
+ # past_key_values = past_key_values.view(
1899
+ # batch_size,
1900
+ # self.pre_seq_len,
1901
+ # self.n_layer * 2,
1902
+ # self.n_head,
1903
+ # self.n_embd
1904
+ # )
1905
+ # past_key_values = self.dropout(past_key_values)
1906
+ # past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
1907
+ # return past_key_values
1908
+
1909
+
1910
+ # def embed_encode(self, input_ids):
1911
+ # embedding_output = self.deberta.embeddings.word_embeddings(input_ids)
1912
+ # return embedding_output
1913
+
1914
+ # def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
1915
+ # batch_size = input_ids.size(0)
1916
+
1917
+ # # add prefix for prompt-tuning
1918
+ # past_key_values = self.get_prompt(batch_size=batch_size)
1919
+ # prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
1920
+ # attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
1921
+
1922
+
1923
+ # if mask_pos is not None:
1924
+ # mask_pos = mask_pos.squeeze()
1925
+
1926
+ # # Encode everything
1927
+ # outputs = self.deberta(
1928
+ # input_ids,
1929
+ # attention_mask=attention_mask,
1930
+ # past_key_values=past_key_values,
1931
+ # )
1932
+
1933
+
1934
+ # # Get <mask> token representation
1935
+ # sequence_output = outputs[0]
1936
+ # # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
1937
+ # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
1938
+
1939
+
1940
+ # # Logits over vocabulary tokens
1941
+ # prediction_mask_scores = self.cls(sequence_mask_output)
1942
+
1943
+ # #sequence_mask_output = self.lm_head.dense(sequence_mask_output)
1944
+
1945
+ # # Exit early and only return mask logits.
1946
+ # if return_full_softmax:
1947
+ # return prediction_mask_scores
1948
+
1949
+ # # Return logits for each label
1950
+ # logits = []
1951
+ # for label_id in range(len(self.label_word_list)):
1952
+ # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
1953
+ # logits = torch.cat(logits, -1)
1954
+
1955
+ # # Regression task
1956
+ # if self.config.num_labels == 1:
1957
+ # logsoftmax = nn.LogSoftmax(-1)
1958
+ # logits = logsoftmax(logits) # Log prob of right polarity
1959
+
1960
+ # return logits, sequence_mask_output
1961
+
1962
+
1963
+ # def forward(
1964
+ # self,
1965
+ # input_ids=None,
1966
+ # attention_mask=None,
1967
+ # mask_pos=None,
1968
+ # labels=None,
1969
+ # inputs_embeds=None,
1970
+ # fwd_type=0,
1971
+ # block_flag=None,
1972
+ # return_dict=None,
1973
+ # ):
1974
+ # if fwd_type == 2:
1975
+ # assert inputs_embeds is not None
1976
+ # return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds)
1977
+
1978
+ # elif fwd_type == 1:
1979
+ # return self.embed_encode(input_ids)
1980
+
1981
+ # logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds)
1982
+
1983
+ # loss = None
1984
+
1985
+
1986
+ # if labels is not None:
1987
+ # if self.num_labels == 1:
1988
+ # # Regression task
1989
+ # loss_fct = nn.KLDivLoss(log_target=True)
1990
+ # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
1991
+ # loss = loss_fct(logits.view(-1, 2), labels)
1992
+ # else:
1993
+
1994
+ # if labels.shape == logits.shape:
1995
+ # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
1996
+ # labels, reduction="batchmean")
1997
+ # else:
1998
+ # loss_fct = nn.CrossEntropyLoss()
1999
+
2000
+ # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
2001
+ # if self.model_args.hybrid == 1:
2002
+ # cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1))
2003
+ # loss = loss + cls_loss
2004
+
2005
+ # output = (logits,)
2006
+ # if self.num_labels == 1:
2007
+ # # Regression output
2008
+ # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
2009
+
2010
+ # if not return_dict:
2011
+ # return ((loss,) + output) if loss is not None else output
2012
+
2013
+ # return SequenceClassifierOutput(
2014
+ # loss=loss,
2015
+ # logits=logits,
2016
+ # )