""" Original work: https://github.com/sangHa0411/CloneDetection/blob/main/utils/preprocessor.py Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head) All credits to the original authors. """ import re import torch from transformers import Pipeline class FunctionPreprocessor: def get_function(self, code): results = [] fn_list = re.findall("\ndef [a-zA-Z0-9_]+\(", code) for fn in fn_list: results.append(fn[4:-1].strip()) return results def determine_function(self, code, function_name): num = len(re.findall("[^a-zA-Z]" + function_name + "[^a-zA-Z]", code)) return False if num <= 1 else True def delete_function(self, code, name): start_id, _ = re.search("def " + name, code).span() ptr = start_id while ptr < len(code) - 1: if code[ptr] == "\n" and re.search("[a-zA-Z]", code[ptr + 1]) is not None: break ptr += 1 if ptr != len(code) - 1: end_id = ptr code = code[:start_id] + code[end_id:] return code def preprocess(self, code): code = "\n" + code fn_list = self.get_function(code) if len(fn_list) == 0: return code for fn in fn_list: flag = self.determine_function(code, fn) if flag == False: code = self.delete_function(code, fn) return code class AnnotationPreprocessor: def search(self, sen_list, string): for i, sen in enumerate(sen_list): if string in sen: return i return -1 def delete_annotation_block(self, code, string): sens = [sen for sen in code.split("\n")] start_id = self.search(sens, string) end_id = self.search(sens[start_id + 1 :], string) if end_id != -1: end_id += start_id + 1 code = sens[:start_id] + sens[end_id + 1 :] else: code = sens[:start_id] + sens[start_id + 1 :] code = "\n".join(code) return code def delete_block(self, code, string): while string in code: code = self.delete_annotation_block(code, string) return code def delete_annotation(self, code): sens = code.split("\n") sens_processed = [] for sen in sens: if "#" in sen: index = sen.index("#") sen = sen[:index] sens_processed.append(sen) return "\n".join(sens_processed) def delete_import(self, code): sens = code.split("\n") sens_processed = [] for sen in sens: if "import" not in sen: sens_processed.append(sen) return "\n".join(sens_processed) def preprocess(self, code): code = self.delete_block(code, '"""') code = self.delete_block(code, "'''") code = self.delete_annotation(code) code = self.delete_import(code) code = re.sub("\s+", " ", code).strip() return code def preprocessor(code, instance): processed_code = instance.preprocess(code) return processed_code if processed_code.strip() else code def token_to_inputs(feature): inputs = {} for k, v in feature.items(): inputs[k] = torch.tensor(v).unsqueeze(0) return inputs class CloneDetectionPipeline(Pipeline): fn_preprocessor = FunctionPreprocessor() an_preprocessor = AnnotationPreprocessor() def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} return preprocess_kwargs, {}, {} def preprocess(self, inputs): code1 = inputs[0] code2 = inputs[1] if code1.strip() == "" or code2.strip() == "": ture_prob = float(code1.strip() == code2.strip()) return {"skip": True, "output": {False: 1 - ture_prob, True: ture_prob}} code1 = preprocessor( preprocessor(code1, self.fn_preprocessor), self.an_preprocessor ) code2 = preprocessor( preprocessor(code2, self.fn_preprocessor), self.an_preprocessor ) feature1 = self.tokenizer( code1, code2, max_length=512, return_token_type_ids=False, truncation=True ) feature2 = self.tokenizer( code2, code1, max_length=512, return_token_type_ids=False, truncation=True ) return { "inputs1": token_to_inputs(feature1), "inputs2": 
token_to_inputs(feature2), } def _forward(self, model_inputs): if model_inputs.get("skip", False): return model_inputs inputs1 = model_inputs["inputs1"] inputs2 = model_inputs["inputs2"] logits1 = self.model(**inputs1).logits[0] logits2 = self.model(**inputs2).logits[0] logits = (logits1 + logits2) / 2 return {"logits": logits} def postprocess(self, model_outputs): if model_outputs.get("skip", False): return model_outputs["output"] probs = model_outputs["logits"].softmax(-1).tolist() return {False: probs[0], True: probs[1]}
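

# Usage sketch (not part of the original module): this assumes the custom
# Pipeline subclass can be instantiated directly with a sequence-classification
# model fine-tuned for clone detection and its tokenizer, and that inputs are
# passed as a (code1, code2) pair, as CloneDetectionPipeline.preprocess expects.
# "your-clone-detection-checkpoint" is a placeholder, not a real model name.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    checkpoint = "your-clone-detection-checkpoint"  # placeholder checkpoint name
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    pipe = CloneDetectionPipeline(model=model, tokenizer=tokenizer)

    code_a = "def add(a, b):\n    return a + b"
    code_b = "def plus(x, y):\n    return x + y"

    # Returns a dict mapping False/True to the predicted probability that the
    # two snippets are not/are clones of each other.
    print(pipe((code_a, code_b)))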