jRefactoring / predict.py
gautam-shetty's picture
fix: return loss and threshold
3deb082
raw
history blame
1.29 kB
from graphCodeBert import GraphCodeBert
from keras.models import load_model, Model
import numpy as np, json
class Predict:
    """Label a code snippet by comparing an autoencoder's reconstruction loss to a threshold.

    The snippet is embedded with ``GraphCodeBert``, the embedding is fed to a
    saved Keras model as both input and target, and the resulting loss is
    compared against the ``"Threshold"`` value in ``./results/metrics.json``.
    """

    def __generate_code_embedding(self, code_snippet):
        """Return the embedding of ``code_snippet`` as a NumPy array of shape ``(1, 768)``."""
        # NOTE(review): a fresh GraphCodeBert() is constructed on every call;
        # if that is expensive, consider caching it -- confirm with its owner.
        raw_embedding = GraphCodeBert().generate_individual_embedding(code_snippet)
        # The reshape raises ValueError if the helper does not yield exactly
        # 768 values.
        return np.array(raw_embedding).reshape((1, 768))

    def __calculate_loss(self, code_embedding, model_name):
        """Load ``results/<model_name>.hdf5`` and evaluate it on ``code_embedding``.

        The embedding is used as both input and target, so the returned value
        is the model's reconstruction loss for that single sample.
        """
        model: Model = load_model(f'results/{model_name}.hdf5')
        loss = model.evaluate(code_embedding, code_embedding)
        # Keras returns a list (loss first, then metrics) when the model was
        # compiled with metrics; keep only the loss so that the comparison
        # against the threshold in predict() stays well-defined.
        if isinstance(loss, (list, tuple)):
            loss = loss[0]
        return loss

    def predict(self, code_snippet, model_name="autoencoder_25"):
        """Classify ``code_snippet`` and report the numbers behind the decision.

        Args:
            code_snippet: Source code text to evaluate.
            model_name: Base name of the saved model under ``results/``
                (without the ``.hdf5`` extension). Defaults to the model
                that was previously hard-coded.

        Returns:
            A 3-tuple ``(verdict, threshold, loss)`` where ``verdict`` is
            ``"Not a candidate for refactoring"`` when ``loss > threshold``
            and ``"Is a candidate for refactoring"`` otherwise.

        Raises:
            KeyError: If ``./results/metrics.json`` has no ``"Threshold"`` key.
        """
        code_embedding = self.__generate_code_embedding(code_snippet)
        print("Input code snippet shape: ",code_embedding.shape)

        loss = self.__calculate_loss(code_embedding, model_name)
        print("Reconstruction Loss: ",loss)

        with open('./results/metrics.json', "r", encoding="utf-8") as fp:
            metric_json = json.load(fp)
        threshold = metric_json["Threshold"]

        # NOTE(review): a loss above the threshold maps to "not a candidate",
        # which presumably means the autoencoder was trained on refactoring
        # candidates -- confirm against the training code.
        if loss > threshold:
            verdict = "Not a candidate for refactoring"
        else:
            verdict = "Is a candidate for refactoring"
        return verdict, threshold, loss
if __name__ == "__main__":
    # Manual smoke run: feed one small Java method through the predictor.
    # The return value is intentionally not used here; predict() prints the
    # embedding shape and the reconstruction loss itself.
    sample_snippet = """ public void sleep(){
int s1 = 1;
int s2 = 2;
int s3 = 3;
int s4 = 4;
int s5 = 5;
int s6 = 6;
int s7 = 7;
int s8 = 8;
}"""
    predictor = Predict()
    predictor.predict(sample_snippet)