""" Original work: https://github.com/sangHa0411/CloneDetection/blob/main/models/codebert.py#L169 Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head) All credits to the original authors. """ import torch.nn as nn from transformers import ( RobertaPreTrainedModel, RobertaModel, ) from transformers.modeling_outputs import SequenceClassifierOutput class CloneDetectionModel(RobertaPreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.config = config self.roberta = RobertaModel(config, add_pooling_layer=False) self.net = nn.Sequential( nn.Dropout(config.hidden_dropout_prob), nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(), ) self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] batch_size, _, hidden_size = hidden_states.shape # CLS code1 SEP SEP code2 SEP cls_flag = input_ids == self.config.tokenizer_cls_token_id # cls token sep_flag = input_ids == self.config.tokenizer_sep_token_id # sep token special_token_states = hidden_states[cls_flag + sep_flag].view( batch_size, -1, hidden_size ) # (batch_size, 4, hidden_size) special_hidden_states = self.net( special_token_states ) # (batch_size, 4, hidden_size) pooled_output = special_hidden_states.view( batch_size, -1 ) # (batch_size, hidden_size * 4) logits = self.classifier(pooled_output) # (batch_size, num_labels) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )