| """ | |
| Original work: | |
| https://github.com/sangHa0411/CloneDetection/blob/main/utils/preprocessor.py | |
| Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head) | |
| All credits to the original authors. | |
| """ | |
| import re | |
| import torch | |
| from transformers import Pipeline | |
class FunctionPreprocessor:
    """Remove function definitions that are declared but never referenced.

    A function counts as unused when its name occurs at most once in the
    code (i.e. only at its own ``def`` line); such definitions are stripped
    so only referenced functions remain.
    """

    def get_function(self, code):
        """Return the names of ``def`` statements that directly follow a newline.

        *code* is expected to be prefixed with ``"\\n"`` (``preprocess`` does
        this) so a leading ``def`` on the first line is also matched.
        """
        # Each match looks like "\ndef name(" -> slice off "\ndef " and "(".
        return [fn[4:-1].strip() for fn in re.findall(r"\ndef [a-zA-Z0-9_]+\(", code)]

    def determine_function(self, code, function_name):
        """Return True when *function_name* appears beyond its own definition.

        A non-letter character is required on both sides of the name, so the
        name is not counted inside longer alphabetic identifiers.  NOTE: digits
        and underscores still act as boundaries here, so e.g. ``foo`` would
        also match inside ``foo2`` — kept as-is to preserve behavior.
        """
        # re.escape is a no-op for names matched by get_function's pattern,
        # but protects against regex metacharacters in arbitrary input.
        pattern = "[^a-zA-Z]" + re.escape(function_name) + "[^a-zA-Z]"
        return len(re.findall(pattern, code)) > 1

    def delete_function(self, code, name):
        """Remove the definition of *name* from *code*.

        The deleted span runs from ``def name`` up to the next newline that is
        immediately followed by a letter (i.e. the next top-level statement).
        A definition at the very end of the code is left in place, matching
        the original behavior.
        """
        match = re.search("def " + re.escape(name), code)
        if match is None:  # defensive: nothing to delete (was an AttributeError)
            return code
        start_id = match.span()[0]
        ptr = start_id
        while ptr < len(code) - 1:
            if code[ptr] == "\n" and re.search("[a-zA-Z]", code[ptr + 1]) is not None:
                break
            ptr += 1
        if ptr != len(code) - 1:
            code = code[:start_id] + code[ptr:]
        return code

    def preprocess(self, code):
        """Return *code* (newline-prefixed) with unused function definitions removed."""
        code = "\n" + code  # ensure the first `def` is also preceded by a newline
        for fn in self.get_function(code):
            if not self.determine_function(code, fn):
                code = self.delete_function(code, fn)
        return code
class AnnotationPreprocessor:
    """Strip comments, triple-quoted blocks, and import lines; collapse whitespace."""

    def search(self, sen_list, string):
        """Return the index of the first element of *sen_list* containing *string*, or -1."""
        for i, sen in enumerate(sen_list):
            if string in sen:
                return i
        return -1

    def delete_annotation_block(self, code, string):
        """Delete one block delimited by *string* (e.g. a triple-quoted docstring).

        Removes the lines from the opening delimiter through the closing one,
        inclusive.  If no closing delimiter exists, only the opening line is
        removed.
        """
        sens = code.split("\n")
        start_id = self.search(sens, string)
        end_id = self.search(sens[start_id + 1 :], string)
        if end_id != -1:
            end_id += start_id + 1
            remaining = sens[:start_id] + sens[end_id + 1 :]
        else:
            remaining = sens[:start_id] + sens[start_id + 1 :]
        return "\n".join(remaining)

    def delete_block(self, code, string):
        """Repeatedly delete delimited blocks until *string* no longer occurs."""
        while string in code:
            code = self.delete_annotation_block(code, string)
        return code

    def delete_annotation(self, code):
        """Truncate every line at its first ``#``.

        NOTE(review): this also truncates a ``#`` inside a string literal —
        acceptable for this normalization, kept to preserve behavior.
        """
        lines = []
        for sen in code.split("\n"):
            if "#" in sen:
                sen = sen[: sen.index("#")]
            lines.append(sen)
        return "\n".join(lines)

    def delete_import(self, code):
        """Drop every line containing the substring ``import`` (deliberately broad)."""
        return "\n".join(sen for sen in code.split("\n") if "import" not in sen)

    def preprocess(self, code):
        """Apply all deletions, then collapse all whitespace runs to single spaces."""
        code = self.delete_block(code, '"""')
        code = self.delete_block(code, "'''")
        code = self.delete_annotation(code)
        code = self.delete_import(code)
        # raw string: "\s+" was an invalid escape sequence (SyntaxWarning on 3.12+)
        return re.sub(r"\s+", " ", code).strip()
def preprocessor(code, instance):
    """Apply *instance*'s ``preprocess`` to *code*.

    Falls back to the original code whenever preprocessing yields an
    empty or whitespace-only result, so a snippet is never lost entirely.
    """
    cleaned = instance.preprocess(code)
    if cleaned.strip():
        return cleaned
    return code
def token_to_inputs(feature):
    """Convert a mapping of token lists into batch-of-one torch tensors.

    Each value is wrapped in a tensor and given a leading batch dimension,
    matching the shape the model expects for a single example.
    """
    return {key: torch.tensor(values).unsqueeze(0) for key, values in feature.items()}
class CloneDetectionPipeline(Pipeline):
    """Pipeline scoring whether two code snippets are clones of each other.

    Both snippets are cleaned (unused functions, comments/docstrings, and
    import lines removed), then scored in both input orders and the logits
    averaged, making the prediction symmetric in its two inputs.
    """

    fn_preprocessor = FunctionPreprocessor()
    an_preprocessor = AnnotationPreprocessor()

    def _sanitize_parameters(self, **kwargs):
        # No extra parameters supported: every stage receives empty kwargs.
        return {}, {}, {}

    def preprocess(self, inputs):
        code1, code2 = inputs[0], inputs[1]

        # Guard clause: a blank snippet skips the model entirely; the pair is
        # a clone only when both sides are (equally) blank.
        if not code1.strip() or not code2.strip():
            same_prob = float(code1.strip() == code2.strip())
            return {"skip": True, "output": {False: 1 - same_prob, True: same_prob}}

        code1 = preprocessor(
            preprocessor(code1, self.fn_preprocessor), self.an_preprocessor
        )
        code2 = preprocessor(
            preprocessor(code2, self.fn_preprocessor), self.an_preprocessor
        )

        def encode(first, second):
            return self.tokenizer(
                first,
                second,
                max_length=512,
                return_token_type_ids=False,
                truncation=True,
            )

        # Tokenize both orderings so _forward can average the two directions.
        return {
            "inputs1": token_to_inputs(encode(code1, code2)),
            "inputs2": token_to_inputs(encode(code2, code1)),
        }

    def _forward(self, model_inputs):
        if model_inputs.get("skip", False):
            return model_inputs
        logits_fwd = self.model(**model_inputs["inputs1"]).logits[0]
        logits_rev = self.model(**model_inputs["inputs2"]).logits[0]
        # Symmetrize: average the logits of the two input orderings.
        return {"logits": (logits_fwd + logits_rev) / 2}

    def postprocess(self, model_outputs):
        if model_outputs.get("skip", False):
            return model_outputs["output"]
        probs = model_outputs["logits"].softmax(-1).tolist()
        # Index 0 = not-clone probability, index 1 = clone probability.
        return {False: probs[0], True: probs[1]}