python-clone-detection / clone_detection_pipeline.py
Lazyhope's picture
Add pipeline for clone detection
d0e1e46
raw
history blame contribute delete
No virus
5.15 kB
"""
Original work:
https://github.com/sangHa0411/CloneDetection/blob/main/utils/preprocessor.py
Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head)
All credits to the original authors.
"""
import re
import torch
from transformers import Pipeline
class FunctionPreprocessor:
    """Remove top-level functions that are defined but never referenced.

    A function counts as "referenced" when its name occurs more than once
    in the source (the definition itself being the first occurrence).
    """

    # Raw strings: "\(" inside a normal literal is an invalid escape sequence.
    # The leading "\n" restricts matches to ``def`` statements at column 0.
    _DEF_PATTERN = re.compile(r"\ndef [a-zA-Z0-9_]+\(")
    # A newline directly followed by an ASCII letter, i.e. the next line that
    # starts at column 0 with a letter; used as the end of a function body.
    _NEXT_TOP_LEVEL = re.compile(r"\n[a-zA-Z]")

    def get_function(self, code):
        """Return the names of all functions defined at column 0 of *code*.

        Only definitions preceded by a newline are found, so a definition on
        the very first line is missed unless the caller prepends a newline
        (``preprocess`` does this).
        """
        # Each match looks like "\ndef name("; [4:-1] keeps " name".
        return [match[4:-1].strip() for match in self._DEF_PATTERN.findall(code)]

    def determine_function(self, code, function_name):
        """Return True when *function_name* occurs more than once in *code*.

        An occurrence is the name surrounded by non-letter characters, so
        digits and underscores count as boundaries (``foo`` is also counted
        inside ``foo_bar`` or ``foo2``).
        """
        pattern = "[^a-zA-Z]" + function_name + "[^a-zA-Z]"
        return len(re.findall(pattern, code)) > 1

    def delete_function(self, code, name):
        """Return *code* with the body of function *name* cut out.

        The cut starts at the first ``def <name>`` and ends just before the
        next newline that is followed by an ASCII letter. If no such newline
        exists (the function runs to the end of the text), *code* is returned
        unchanged.

        Raises:
            AttributeError: if ``def <name>`` does not occur in *code*.
        """
        # NOTE(review): this is a prefix match, so "def foo" also matches an
        # earlier "def fooBar". Kept as-is so the output stays identical to
        # the existing preprocessing; confirm before tightening.
        start_id = re.search("def " + name, code).start()
        boundary = self._NEXT_TOP_LEVEL.search(code, start_id)
        if boundary is None:
            return code
        return code[:start_id] + code[boundary.start():]

    def preprocess(self, code):
        """Return *code* prefixed with a newline and without unused functions."""
        # The leading newline lets get_function see a def on the first line.
        code = "\n" + code
        for fn in self.get_function(code):
            if not self.determine_function(code, fn):
                code = self.delete_function(code, fn)
        return code
class AnnotationPreprocessor:
    """Strip docstring blocks, ``#`` comments and import lines from source text.

    All operations are purely line/substring based; nothing is parsed as
    Python, so text inside string literals is treated like any other text.
    """

    # Raw string: "\s" inside a normal literal is an invalid escape sequence.
    _WHITESPACE = re.compile(r"\s+")

    def search(self, sen_list, string):
        """Return the index of the first line containing *string*, or -1."""
        return next(
            (index for index, sen in enumerate(sen_list) if string in sen),
            -1,
        )

    def delete_annotation_block(self, code, string):
        """Remove the first block delimited by *string* from *code*.

        The block runs from the first line containing *string* through the
        next later line containing it (both inclusive). When no later line
        contains *string*, only the first matching line is removed.
        """
        sens = code.split("\n")
        start_id = self.search(sens, string)
        offset = self.search(sens[start_id + 1:], string)
        # Without a closing line the block is just the opening line itself.
        end_id = start_id if offset == -1 else offset + start_id + 1
        return "\n".join(sens[:start_id] + sens[end_id + 1:])

    def delete_block(self, code, string):
        """Remove blocks delimited by *string* until none is left in *code*."""
        while string in code:
            code = self.delete_annotation_block(code, string)
        return code

    def delete_annotation(self, code):
        """Cut every line at its first ``#`` (the ``#`` itself is dropped too)."""
        return "\n".join(sen.split("#", 1)[0] for sen in code.split("\n"))

    def delete_import(self, code):
        """Drop every line that contains the substring ``import``."""
        return "\n".join(sen for sen in code.split("\n") if "import" not in sen)

    def preprocess(self, code):
        """Return *code* without docstrings, comments and imports, on one line.

        After the removals every run of whitespace is collapsed to a single
        space and the result is stripped.
        """
        code = self.delete_block(code, '"""')
        code = self.delete_block(code, "'''")
        code = self.delete_annotation(code)
        code = self.delete_import(code)
        return self._WHITESPACE.sub(" ", code).strip()
def preprocessor(code, instance):
    """Run ``instance.preprocess`` on *code*, keeping *code* if nothing is left.

    Returns the preprocessed text unless it is empty or whitespace-only, in
    which case the untouched input is returned instead.
    """
    cleaned = instance.preprocess(code)
    if cleaned.strip():
        return cleaned
    return code
def token_to_inputs(feature):
    """Turn every value of *feature* into a tensor with a new leading axis.

    Each value is passed to ``torch.tensor`` and then ``unsqueeze(0)``, so a
    flat list of length N becomes a tensor of shape (1, N). Keys are kept.
    """
    return {
        key: torch.tensor(value).unsqueeze(0)
        for key, value in feature.items()
    }
class CloneDetectionPipeline(Pipeline):
    """Pipeline that scores whether two code snippets are clones.

    The input is an indexable pair of source strings. The result is a dict
    with the keys ``False`` and ``True`` mapping to the two probabilities.
    """

    # Shared preprocessors, created once at class-definition time.
    fn_preprocessor = FunctionPreprocessor()
    an_preprocessor = AnnotationPreprocessor()

    def _sanitize_parameters(self, **kwargs):
        """Ignore all keyword arguments and return three empty dicts."""
        return {}, {}, {}

    def preprocess(self, inputs):
        """Clean and tokenize the snippet pair ``inputs[0]``, ``inputs[1]``.

        If either snippet is blank, no tokenization happens: a ready-made
        result is returned under ``"output"`` together with ``"skip": True``
        (probability 1.0 for ``True`` only when both are blank). Otherwise
        the pair is tokenized in both orders and returned as ``"inputs1"``
        and ``"inputs2"``.
        """
        first, second = inputs[0], inputs[1]

        if not first.strip() or not second.strip():
            true_prob = float(first.strip() == second.strip())
            verdict = {False: 1 - true_prob, True: true_prob}
            return {"skip": True, "output": verdict}

        # Unused-function removal first, then comment/docstring/import removal.
        first, second = (
            preprocessor(
                preprocessor(snippet, self.fn_preprocessor), self.an_preprocessor
            )
            for snippet in (first, second)
        )

        tokenizer_options = dict(
            max_length=512, return_token_type_ids=False, truncation=True
        )
        forward_pair = self.tokenizer(first, second, **tokenizer_options)
        reverse_pair = self.tokenizer(second, first, **tokenizer_options)

        return {
            "inputs1": token_to_inputs(forward_pair),
            "inputs2": token_to_inputs(reverse_pair),
        }

    def _forward(self, model_inputs):
        """Run the model on both orderings and average their logits.

        A payload flagged with ``"skip"`` is passed through untouched.
        """
        if model_inputs.get("skip", False):
            return model_inputs

        forward_pair = model_inputs["inputs1"]
        reverse_pair = model_inputs["inputs2"]
        # [0] selects the first (only) row of each logits output.
        forward_logits = self.model(**forward_pair).logits[0]
        reverse_logits = self.model(**reverse_pair).logits[0]
        return {"logits": (forward_logits + reverse_logits) / 2}

    def postprocess(self, model_outputs):
        """Convert averaged logits into ``{False: p0, True: p1}``.

        For skipped inputs the precomputed ``"output"`` dict is returned.
        """
        if model_outputs.get("skip", False):
            return model_outputs["output"]

        negative, positive = model_outputs["logits"].softmax(-1).tolist()[:2]
        return {False: negative, True: positive}