File size: 5,154 Bytes
d0e1e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Original work:
https://github.com/sangHa0411/CloneDetection/blob/main/utils/preprocessor.py

Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head)

All credits to the original authors.
"""
import re
import torch
from transformers import Pipeline


class FunctionPreprocessor:
    """Remove top-level functions that are never referenced elsewhere in the code.

    Only functions defined at column 0 (a ``def`` directly after a newline) are
    considered. A function counts as "used" when its name occurs more than once
    in the code, the definition itself being one occurrence.
    """

    # Start of the next column-0 statement: a newline followed by an ASCII letter.
    _NEXT_TOP_LEVEL = re.compile(r"\n[a-zA-Z]")

    def get_function(self, code):
        """Return the names of all column-0 ``def`` statements in *code*, in order.

        A definition is only recognised when it directly follows a newline, which
        is why :meth:`preprocess` prepends a newline to the code first.
        """
        return re.findall(r"\ndef ([a-zA-Z0-9_]+)\(", code)

    def determine_function(self, code, function_name):
        """Return True if *function_name* occurs more than once in *code*.

        An occurrence must be surrounded by non-letter characters. Digits and
        underscores count as boundaries here, so ``foo`` is also matched inside
        ``foo_bar``; this is kept as-is for parity with the original preprocessing.
        """
        pattern = r"[^a-zA-Z]" + re.escape(function_name) + r"[^a-zA-Z]"
        return len(re.findall(pattern, code)) > 1

    def delete_function(self, code, name):
        """Remove the column-0 definition of *name* from *code*.

        The removed span runs from the ``def`` keyword up to (not including) the
        next newline that is followed by an ASCII letter, i.e. the next column-0
        statement. If no such line follows (the function is the last top-level
        statement) or the definition cannot be found, *code* is returned unchanged.
        """
        # Anchor at a line start and require "(" right after the name so that
        # e.g. "foo" does not match "def foobar(" or an indented method "def foo(".
        match = re.search(r"^def " + re.escape(name) + r"\(", code, re.MULTILINE)
        if match is None:
            return code
        start_id = match.start()

        end = self._NEXT_TOP_LEVEL.search(code, start_id)
        if end is None:
            return code
        return code[:start_id] + code[end.start():]

    def preprocess(self, code):
        """Return ``"\\n" + code`` with its unused top-level functions removed."""
        # The leading newline lets a `def` on the very first line be detected.
        code = "\n" + code
        for fn in self.get_function(code):
            if not self.determine_function(code, fn):
                code = self.delete_function(code, fn)
        return code


class AnnotationPreprocessor:
    """Heuristically strip docstring blocks, ``#`` comments and import lines.

    All operations are line-based text heuristics, not a real tokenizer.
    """

    def search(self, sen_list, string):
        """Return the index of the first line in *sen_list* containing *string*, or -1."""
        return next((i for i, sen in enumerate(sen_list) if string in sen), -1)

    def delete_annotation_block(self, code, string):
        """Delete the first *string*-delimited block (e.g. a docstring) from *code*.

        The first line containing *string* opens the block. If that line holds an
        even number of delimiters (e.g. a one-line docstring), it is
        self-contained and only that line is removed. Otherwise the block runs
        through the next line containing *string*; if there is none, only the
        opening line is removed. *code* is returned unchanged when *string* does
        not occur at all.
        """
        sens = code.split("\n")

        start_id = self.search(sens, string)
        if start_id == -1:
            return code

        if sens[start_id].count(string) % 2 == 0:
            # Opening and closing delimiter on the same line. Without this check the
            # next unrelated delimiter would be taken as the close, deleting real code.
            end_id = start_id
        else:
            rel_end = self.search(sens[start_id + 1 :], string)
            end_id = start_id if rel_end == -1 else start_id + 1 + rel_end

        return "\n".join(sens[:start_id] + sens[end_id + 1 :])

    def delete_block(self, code, string):
        """Repeatedly delete *string*-delimited blocks until *string* no longer occurs."""
        # Terminates: every pass removes at least the line holding the first delimiter.
        while string in code:
            code = self.delete_annotation_block(code, string)
        return code

    def delete_annotation(self, code):
        """Drop everything from the first ``#`` to the end of each line.

        A ``#`` inside a string literal is treated as a comment as well.
        """
        return "\n".join(sen.split("#", 1)[0] for sen in code.split("\n"))

    def delete_import(self, code):
        """Drop every line containing the substring ``"import"`` (anywhere in the line)."""
        return "\n".join(sen for sen in code.split("\n") if "import" not in sen)

    def preprocess(self, code):
        """Strip docstring blocks, comments and imports, then collapse all whitespace.

        Returns a single-line string in which every whitespace run is one space.
        """
        code = self.delete_block(code, '"""')
        code = self.delete_block(code, "'''")
        code = self.delete_annotation(code)
        code = self.delete_import(code)
        return re.sub(r"\s+", " ", code).strip()


def preprocessor(code, instance):
    """Apply ``instance.preprocess`` to *code*, keeping the raw code if nothing is left.

    Args:
        code: source text to clean.
        instance: any object exposing a ``preprocess(code) -> str`` method.

    Returns:
        The cleaned text, or *code* itself when cleaning yields only whitespace.
    """
    cleaned = instance.preprocess(code)
    if not cleaned.strip():
        # Never hand an empty string downstream: fall back to the original input.
        return code
    return cleaned


def token_to_inputs(feature):
    """Turn one tokenizer output mapping into a dict of batch-of-one tensors.

    Every value is converted with ``torch.tensor`` and given a leading batch
    dimension of size 1 via ``unsqueeze(0)``; keys are kept unchanged.
    """
    return {name: torch.tensor(values).unsqueeze(0) for name, values in feature.items()}


class CloneDetectionPipeline(Pipeline):
    """Pipeline scoring whether two code snippets are clones of each other.

    Input: a pair ``(code1, code2)`` (indexed as ``inputs[0]`` / ``inputs[1]``).
    Output: ``{False: p_not_clone, True: p_clone}``.

    Both snippets are cleaned with ``FunctionPreprocessor`` and then
    ``AnnotationPreprocessor``. The model is run on both orderings of the pair and
    the two logit vectors are averaged, so the score does not depend on the order
    of the arguments.

    Call-time option: ``max_length`` (default 512) caps the tokenized pair length.
    """

    fn_preprocessor = FunctionPreprocessor()
    an_preprocessor = AnnotationPreprocessor()

    def _sanitize_parameters(self, max_length=None, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        Only ``max_length`` is recognised; it is forwarded to :meth:`preprocess`.
        Other keyword arguments are ignored.
        """
        preprocess_kwargs = {}
        if max_length is not None:
            preprocess_kwargs["max_length"] = max_length
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, max_length=512):
        """Clean and tokenize the pair in both orders.

        Returns either ``{"inputs1": ..., "inputs2": ...}`` (tensor dicts for
        ``code1, code2`` and ``code2, code1``) or, when either snippet is blank,
        ``{"skip": True, "output": ...}`` carrying the final answer directly.
        """
        code1 = inputs[0]
        code2 = inputs[1]
        if not code1.strip() or not code2.strip():
            # Nothing to tokenize: two blank snippets count as identical, a blank
            # vs. a non-blank one as different. The model is bypassed (see _forward).
            true_prob = float(code1.strip() == code2.strip())
            return {"skip": True, "output": {False: 1 - true_prob, True: true_prob}}

        code1 = preprocessor(
            preprocessor(code1, self.fn_preprocessor), self.an_preprocessor
        )
        code2 = preprocessor(
            preprocessor(code2, self.fn_preprocessor), self.an_preprocessor
        )

        feature1 = self.tokenizer(
            code1,
            code2,
            max_length=max_length,
            return_token_type_ids=False,
            truncation=True,
        )
        feature2 = self.tokenizer(
            code2,
            code1,
            max_length=max_length,
            return_token_type_ids=False,
            truncation=True,
        )

        return {
            "inputs1": token_to_inputs(feature1),
            "inputs2": token_to_inputs(feature2),
        }

    def _forward(self, model_inputs):
        """Run the model on both pair orderings and average their logits.

        A ``skip`` payload from :meth:`preprocess` is passed through untouched.
        """
        if model_inputs.get("skip", False):
            return model_inputs

        inputs1 = model_inputs["inputs1"]
        inputs2 = model_inputs["inputs2"]

        # [0] drops the batch-of-one dimension added by token_to_inputs.
        logits1 = self.model(**inputs1).logits[0]
        logits2 = self.model(**inputs2).logits[0]
        logits = (logits1 + logits2) / 2

        return {"logits": logits}

    def postprocess(self, model_outputs):
        """Map averaged logits to ``{False: probs[0], True: probs[1]}`` via softmax."""
        if model_outputs.get("skip", False):
            return model_outputs["output"]

        probs = model_outputs["logits"].softmax(-1).tolist()

        # Label index 0 is reported as "not a clone", index 1 as "clone".
        return {False: probs[0], True: probs[1]}