Lazyhope committed on
Commit 92c76d1
1 Parent(s): 7c92b66

Upload model

Files changed (3)
  1. CloneDetectionModel.py +96 -0
  2. config.json +5 -2
  3. pytorch_model.bin +2 -2
CloneDetectionModel.py ADDED
@@ -0,0 +1,96 @@
+"""
+Original work:
+https://github.com/sangHa0411/CloneDetection/blob/main/models/codebert.py#L169
+
+Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head)
+
+All credits to the original authors.
+"""
+import torch.nn as nn
+from transformers import (
+    RobertaPreTrainedModel,
+    RobertaModel,
+)
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+class CloneDetectionModel(RobertaPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
+        self.net = nn.Sequential(
+            nn.Dropout(config.hidden_dropout_prob),
+            nn.Linear(config.hidden_size, config.hidden_size),
+            nn.ReLU(),
+        )
+        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        batch_size, _, hidden_size = hidden_states.shape
+
+        # Input layout: CLS code1 SEP SEP code2 SEP
+        cls_flag = input_ids == self.config.tokenizer_cls_token_id  # cls token
+        sep_flag = input_ids == self.config.tokenizer_sep_token_id  # sep token
+
+        special_token_states = hidden_states[cls_flag + sep_flag].view(
+            batch_size, -1, hidden_size
+        )  # (batch_size, 4, hidden_size)
+        special_hidden_states = self.net(
+            special_token_states
+        )  # (batch_size, 4, hidden_size)
+
+        pooled_output = special_hidden_states.view(
+            batch_size, -1
+        )  # (batch_size, hidden_size * 4)
+        logits = self.classifier(pooled_output)  # (batch_size, num_labels)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
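The forward pass pools the hidden states of the four special tokens in the CLS code1 SEP SEP code2 SEP layout, concatenates them into one (hidden_size * 4) vector, and classifies the pair. A minimal usage sketch, not part of the commit: the repo id and label order are placeholders/assumptions, and loading relies on the auto_map entry added to config.json below.

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "Lazyhope/clone-detection"  # hypothetical repo id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets AutoModel follow the auto_map entry in
# config.json and instantiate CloneDetectionModel from this commit
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

code1 = "def add(a, b): return a + b"
code2 = "def sum_two(x, y): return x + y"
# Encoding a RoBERTa text pair yields the CLS code1 SEP SEP code2 SEP layout,
# so the cls/sep masks in forward() select exactly four positions per example
inputs = tokenizer(code1, code2, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (1, num_labels)
probs = logits.softmax(dim=-1)  # which index means "clone" is an assumption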
config.json CHANGED
@@ -1,9 +1,12 @@
 {
-  "_name_or_path": "microsoft/graphcodebert-base",
+  "_name_or_path": "./checkpoint",
   "architectures": [
-    "RobertaRBERT"
+    "CloneDetectionModel"
   ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoModel": "CloneDetectionModel.CloneDetectionModel"
+  },
   "bos_token_id": 0,
   "classifier_dropout": null,
   "dropout_rate": 0.1,
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ee6809a890981926df5d49e47747fc9912290459d9d678e7f89a49ceb133200d
-size 498680501
+oid sha256:cd61ea43ac55f9dcb691449f3489fbc90638a96a958289b24c7abf6306642f02
+size 498675949
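pytorch_model.bin is a Git LFS pointer file, not the weights themselves: the repo tracks only this small text stub, while the ~499 MB payload lives in LFS storage keyed by the sha256 oid. A sketch of verifying a fully downloaded copy (e.g. after git lfs pull) against the updated pointer:

import hashlib

# oid taken from the new pointer above
expected = "cd61ea43ac55f9dcb691449f3489fbc90638a96a958289b24c7abf6306642f02"

h = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)

assert h.hexdigest() == expected, "weights do not match the LFS pointer"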