DeepLearning101 committed
Commit 0d88c28
1 Parent(s): b0ebb46

Upload 2 files

models/code/code_classification.py ADDED
@@ -0,0 +1,544 @@
+ # -*- coding: utf-8 -*-
+ # @Time    : 2023/3/11 8:02 AM
+ # @Author  : NuoChen
+ # @File    : code_classification.py
+
+ ## ======== Roberta ========
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+ from transformers import RobertaModel
+ from transformers.activations import ACT2FN
+ from transformers.models.electra import ElectraModel
+ from transformers.models.roformer import RoFormerModel
+ from transformers.models.albert import AlbertModel
+ from transformers.models.bert import BertModel, BertPreTrainedModel
+ from transformers.models.deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers.models.roberta import RobertaPreTrainedModel
+ from transformers.models.bert.modeling_bert import BertForSequenceClassification
+ from transformers.models.megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
+ import logging
+ from typing import Optional, List, Union, Tuple
+ import torch
+ from torch._C import NoopLogger
+ from torch.autograd import Variable
+ import copy
+ import torch.nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+
+ from transformers import RobertaModel, RobertaPreTrainedModel
+ from transformers.models.plbart.modeling_plbart import PLBartPreTrainedModel, PLBartClassificationHead, PLBartModel
+ from transformers.models.plbart.configuration_plbart import PLBartConfig
+ from transformers.models.t5.modeling_t5 import T5PreTrainedModel  # , T5ClassificationHead, T5Model
+ from transformers.models.t5.modeling_t5 import T5Config, T5Stack
+ from transformers.modeling_outputs import SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput, SequenceClassifierOutputWithPast
+ # Missing in the original file: needed by CodeT5ForCodeClassification.parallelize()/deparallelize() below.
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+ from models.basic_modules.prefix_encoder import PrefixEncoder
+
+ from models.basic_modules.adapter import BertAdaModel, RobertaAdaModel, init_adapter
+ from tools.model_utils.parameter_freeze import ParameterFreeze
+
+ freezer = ParameterFreeze()
+
+ ## ======== Roberta ========
+ # Vanilla Fine-tuning For Roberta
+ class RobertaForCodeClassification(RobertaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.roberta = RobertaModel(config)
+         if self.config.use_freezing:
+             self.roberta = freezer.freeze_lm(self.roberta)
+         self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+         self.init_weights()
+
+     def freeze_backbone(self, use_freezing: bool = True):
+         if use_freezing:
+             self.roberta = freezer.freeze_lm(self.roberta)
+         else:
+             self.roberta = freezer.unfreeze_lm(self.roberta)
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+             config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+             If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ ## ======== CodeBERT ========
+ # Vanilla Fine-tuning For CodeBERT
+ class CodeBERTForCodeClassification(RobertaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.roberta = RobertaModel(config)
+         if self.config.use_freezing:
+             self.roberta = freezer.freeze_lm(self.roberta)
+         self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+         self.init_weights()
+
+     def freeze_backbone(self, use_freezing: bool = True):
+         if use_freezing:
+             self.roberta = freezer.freeze_lm(self.roberta)
+         else:
+             self.roberta = freezer.unfreeze_lm(self.roberta)
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+             config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+             If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ ## ======== GraphCodeBERT ========
+
+ # Vanilla Fine-tuning For GraphCodeBERT
+ class GraphCodeBERTForCodeClassification(RobertaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.roberta = RobertaModel(config)
+         if self.config.use_freezing:
+             self.roberta = freezer.freeze_lm(self.roberta)
+         self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+         self.init_weights()
+
+     def freeze_backbone(self, use_freezing: bool = True):
+         if use_freezing:
+             self.roberta = freezer.freeze_lm(self.roberta)
+         else:
+             self.roberta = freezer.unfreeze_lm(self.roberta)
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+             config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+             If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ ## ======== PLBART ========
+
+ # Vanilla Fine-tuning For PLBART
+ class PLBARTForCodeClassification(PLBartPreTrainedModel):
+
+     _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+     def __init__(self, config: PLBartConfig, **kwargs):
+         super().__init__(config, **kwargs)
+         self.model = PLBartModel(config)
+         self.classification_head = PLBartClassificationHead(
+             config.d_model,
+             config.d_model,
+             config.num_labels,
+             config.classifier_dropout,
+         )
+         self.model._init_weights(self.classification_head.dense)
+         self.model._init_weights(self.classification_head.out_proj)
+
+     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         if labels is not None:
+             use_cache = False
+
+         if input_ids is None and inputs_embeds is not None:
+             raise NotImplementedError(
+                 f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+             )
+
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             decoder_input_ids=decoder_input_ids,
+             decoder_attention_mask=decoder_attention_mask,
+             head_mask=head_mask,
+             decoder_head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             encoder_outputs=encoder_outputs,
+             inputs_embeds=inputs_embeds,
+             decoder_inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = outputs[0]  # last hidden state
+
+         eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
+
+         if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+             raise ValueError("All examples must have the same number of <eos> tokens.")
+         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+             :, -1, :
+         ]
+         logits = self.classification_head(sentence_representation)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.config.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.config.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return Seq2SeqSequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             decoder_hidden_states=outputs.decoder_hidden_states,
+             decoder_attentions=outputs.decoder_attentions,
+             cross_attentions=outputs.cross_attentions,
+             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+             encoder_hidden_states=outputs.encoder_hidden_states,
+             encoder_attentions=outputs.encoder_attentions,
+         )
+
+
+ ## ======== CodeT5 ========
+
+ # Vanilla Fine-tuning For CodeT5
+ class CodeT5ForCodeClassification(T5PreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
+
+     def __init__(self, config: T5Config):
+         super().__init__(config)
+         self.model_dim = config.d_model
+         self.config.problem_type = None
+         self.config.is_encoder_decoder = False
+
+         self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+         encoder_config = copy.deepcopy(config)
+         encoder_config.is_decoder = False
+         encoder_config.is_encoder_decoder = False
+         encoder_config.use_cache = False
+         self.encoder = T5Stack(encoder_config, self.shared)
+
+         classifier_dropout = (
+             config.classifier_dropout if hasattr(config, "classifier_dropout") else config.dropout_rate
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.d_model, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+     def parallelize(self, device_map=None):
+         self.device_map = (
+             get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+             if device_map is None
+             else device_map
+         )
+         assert_device_map(self.device_map, len(self.encoder.block))
+         self.encoder.parallelize(self.device_map)
+         self.classifier.to(self.encoder.first_device)
+         self.model_parallel = True
+
+     def deparallelize(self):
+         self.encoder.deparallelize()
+         self.encoder = self.encoder.to("cpu")
+         self.classifier = self.classifier.to("cpu")
+         self.model_parallel = False
+         self.device_map = None
+         torch.cuda.empty_cache()
+
+     def get_input_embeddings(self):
+         return self.shared
+
+     def set_input_embeddings(self, new_embeddings):
+         self.shared = new_embeddings
+         self.encoder.set_input_embeddings(new_embeddings)
+
+     def get_encoder(self):
+         return self.encoder
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+         class PreTrainedModel.
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+         Returns:
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # Get last hidden indices
+         # (batch_size) -> (batch_size, 1) -> (batch_size, hidden_size) -> (batch_size, 1, hidden_size)
+         last_hidden_indices = (
+             (input_ids != self.config.pad_token_id).sum(dim=-1) - 1
+         ).unsqueeze(dim=-1).repeat(1, outputs[0].size(-1)).unsqueeze(1)
+         sequence_output = outputs[0].gather(dim=1, index=last_hidden_indices).squeeze(1)
+
+         sequence_output = self.dropout(sequence_output)
+         logits = self.classifier(sequence_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.config.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.config.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions
+         )
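
A minimal usage sketch for the classification heads above (editor's illustration, not part of the committed file). It assumes a stock `roberta-base` checkpoint with two labels and that the repo's `ParameterFreeze` helper behaves as imported at the top of the file; `use_freezing` is the custom config flag read in `__init__`:

import torch
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("roberta-base", num_labels=2)
config.use_freezing = True  # assumed flag: freeze the backbone, train only the classifier head
model = RobertaForCodeClassification.from_pretrained("roberta-base", config=config)

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
batch = tokenizer(["def add(a, b): return a + b"], return_tensors="pt")
labels = torch.tensor([1])

outputs = model(**batch, labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar loss, logits of shape (1, 2)

The same pattern should apply to CodeBERTForCodeClassification and GraphCodeBERTForCodeClassification, which reuse the RoBERTa backbone and differ only in the checkpoint they are loaded from.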
models/code/code_generation.py ADDED
@@ -0,0 +1,201 @@
+ # -*- coding: utf-8 -*-
+ # @Time    : 2023/4/05 18:02 PM
+ # @Author  : NuoChen
+ # @File    : code_generation.py
+
+ from transformers import PLBartTokenizer, PLBartForSequenceClassification, PLBartConfig, PLBartForConditionalGeneration
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPastAndCrossAttentions,
+     CausalLMOutputWithCrossAttentions,
+     Seq2SeqLMOutput,
+     Seq2SeqModelOutput,
+     Seq2SeqSequenceClassifierOutput,
+ )
+ import torch
+ from torch import nn
+ from typing import Optional, List, Union, Tuple
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+
+ from transformers import RobertaModel, RobertaPreTrainedModel
+ from transformers.models.plbart.modeling_plbart import PLBartPreTrainedModel, PLBartModel
+ from transformers.models.plbart.configuration_plbart import PLBartConfig
+
+
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
+     """
+     Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that MBart does not
+     have a single `decoder_start_token_id` in contrast to other Bart-like models.
+     """
+     prev_output_tokens = input_ids.clone()
+
+     if pad_token_id is None:
+         raise ValueError("self.model.config.pad_token_id has to be defined.")
+     # replace possible -100 values in labels by `pad_token_id`
+     prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
+
+     index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
+     decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
+     prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
+     prev_output_tokens[:, 0] = decoder_start_tokens
+
+     return prev_output_tokens
+
+ class PLBARTForCodeGeneration(PLBartPreTrainedModel):
+     base_model_prefix = "model"
+     _keys_to_ignore_on_load_missing = [
+         r"final_logits_bias",
+         r"encoder.version",
+         r"decoder.version",
+         r"lm_head.weight",
+     ]
+
+     def __init__(self, config: PLBartConfig):
+         super().__init__(config)
+         self.model = PLBartModel(config)
+         self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+         self.init_weights()
+
+     def get_encoder(self):
+         return self.model.get_encoder()
+
+     def get_decoder(self):
+         return self.model.get_decoder()
+
+     def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+         new_embeddings = super().resize_token_embeddings(new_num_tokens)
+         self._resize_final_logits_bias(new_num_tokens)
+         return new_embeddings
+
+     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+         old_num_tokens = self.final_logits_bias.shape[-1]
+         if new_num_tokens <= old_num_tokens:
+             new_bias = self.final_logits_bias[:, :new_num_tokens]
+         else:
+             extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+             new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+         self.register_buffer("final_logits_bias", new_bias)
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.LongTensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds=None,
+         labels: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         Returns:
+
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if labels is not None:
+             if decoder_input_ids is None:
+                 decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
+
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             decoder_input_ids=decoder_input_ids,
+             encoder_outputs=encoder_outputs,
+             decoder_attention_mask=decoder_attention_mask,
+             head_mask=head_mask,
+             decoder_head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             decoder_inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+         masked_lm_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (lm_logits,) + outputs[1:]
+             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+         return Seq2SeqLMOutput(
+             loss=masked_lm_loss,
+             logits=lm_logits,
+             past_key_values=outputs.past_key_values,
+             decoder_hidden_states=outputs.decoder_hidden_states,
+             decoder_attentions=outputs.decoder_attentions,
+             cross_attentions=outputs.cross_attentions,
+             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+             encoder_hidden_states=outputs.encoder_hidden_states,
+             encoder_attentions=outputs.encoder_attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         decoder_input_ids: torch.LongTensor,
+         past: Optional[List[torch.FloatTensor]] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+         **kwargs  # TODO: Check if this is needed. It is unused?
+     ) -> Dict[str, Any]:
+         # cut decoder_input_ids if past is used
+         if past is not None:
+             decoder_input_ids = decoder_input_ids[:, -1:]
+
+         return {
+             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+             "encoder_outputs": encoder_outputs,
+             "past_key_values": past,
+             "decoder_input_ids": decoder_input_ids,
+             "attention_mask": attention_mask,
+             "head_mask": head_mask,
+             "decoder_head_mask": decoder_head_mask,
+             "cross_attn_head_mask": cross_attn_head_mask,
+             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+         }
+
+     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+         return shift_tokens_right(labels, self.config.pad_token_id)
+
+     @staticmethod
+     def _reorder_cache(past, beam_idx):
+         reordered_past = ()
+         for layer_past in past:
+             # cached cross_attention states don't have to be reordered -> they are always the same
+             reordered_past += (
+                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+             )
+         return reordered_past
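
A minimal generation sketch for PLBARTForCodeGeneration (editor's illustration, not part of the committed file). The `uclanlp/plbart-base` checkpoint is an assumption, and the `past` argument in `prepare_inputs_for_generation` above targets older transformers releases, so caching behaviour may differ on newer versions:

import torch
from transformers import PLBartTokenizer

tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base")
model = PLBARTForCodeGeneration.from_pretrained("uclanlp/plbart-base")
model.eval()

inputs = tokenizer("def hello_world():", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**inputs, max_length=32, num_beams=4)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))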