# jRefactoring / autoencoder.py
# Author: gautam-shetty
# Commit a5fb347: "Initial commit"
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dotenv import dotenv_values
from keras.layers import Input, Dense, Flatten
from keras.models import Model
from sklearn.metrics import mean_squared_error
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
# from tensorflow.python.ops.confusion_matrix import confusion_matrix

from Database import Database
class Autoencoder:
    """Train a dense autoencoder on positive embeddings and classify by
    reconstruction error.

    The model is fit only on the ``embedding_pos`` vectors of "Extract Method"
    documents. At test time a low reconstruction error is predicted as
    positive (label 1) and a high error as negative (label 0).
    """

    def __get_autoencoder(self, input_dim: int) -> "Model":
        """Build and compile the autoencoder.

        Architecture: input_dim -> 128 -> 64 -> 128 -> input_dim. Hidden
        layers use ReLU and the output layer uses sigmoid. The model is
        compiled with the Adam optimizer and mean-squared-error loss.

        Args:
            input_dim: Length of each input vector.

        Returns:
            The compiled Keras model.
        """
        input_layer = Input(shape=(input_dim,))
        # Flatten does not change a 1-D input; kept so the saved graph is unchanged.
        encoder = Flatten()(input_layer)
        encoder = Dense(128, activation='relu')(encoder)
        encoder = Dense(64, activation='relu')(encoder)  # bottleneck
        decoder = Dense(128, activation='relu')(encoder)
        # NOTE(review): a sigmoid output only produces values in (0, 1). If the
        # embeddings contain values outside that range they can never be
        # reconstructed exactly. Confirm the embedding value range.
        decoder = Dense(input_dim, activation='sigmoid')(decoder)
        autoencoder = Model(inputs=input_layer, outputs=decoder)
        autoencoder.compile(optimizer='adam', loss='mse')
        return autoencoder

    def __print_summary(self, model: "Model"):
        """Print the Keras layer summary of *model*."""
        # Model.summary() prints on its own and returns None. Wrapping it in
        # print() would emit a stray "None" line.
        model.summary()

    def __fit_autoencoder(self, epochs, batch_size, model: "Model", train_var, valid_var=None):
        """Fit *model* to reconstruct its own input.

        Args:
            epochs: Number of training epochs.
            batch_size: Mini-batch size.
            model: Compiled autoencoder.
            train_var: Training inputs, which are also the targets.
            valid_var: Optional validation inputs. When given, Keras also
                reports the validation reconstruction loss each epoch.

        Returns:
            ``(history, model)``: the Keras History object and the fitted model.
        """
        validation_data = None if valid_var is None else (valid_var, valid_var)
        history = model.fit(train_var, train_var,
                            validation_data=validation_data,
                            epochs=epochs, batch_size=batch_size)
        return history, model

    def __split_train_test_val(self, data):
        """Split *data* into seeded train/validation/test partitions.

        20% of *data* goes to test. 10% of the remainder goes to validation,
        so the overall split is roughly 72% / 8% / 20%.

        Returns:
            ``(train, valid, test)``
        """
        train_array, test_array = train_test_split(data, test_size=0.2, random_state=42)
        train_array, valid_array = train_test_split(train_array, test_size=0.1, random_state=42)
        return train_array, valid_array, test_array

    @staticmethod
    def __compute_metrics(conf_matrix):
        """Compute precision, recall and F1 for the positive class (index 1).

        Args:
            conf_matrix: 2x2 matrix indexed ``[true][predicted]``, so
                ``[1][1]`` is TP, ``[0][1]`` is FP and ``[1][0]`` is FN.

        Returns:
            ``(precision, recall, f1)``. Any ratio whose denominator is zero
            is reported as 0.0 instead of raising ZeroDivisionError.
        """
        tp = conf_matrix[1][1]
        fp = conf_matrix[0][1]
        fn = conf_matrix[1][0]
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0.0
        return precision, recall, f1

    def __find_optimal_modified(self, error_df: pd.DataFrame, steps=50):
        """Score the midpoint of the observed error range as the threshold.

        No search is performed. The threshold is ``(min_error + max_error) / 2``.
        Samples whose error is above it are predicted negative (0) and all
        others positive (1).

        Args:
            error_df: Frame with ``Reconstruction_error`` and ``True_class``
                columns.
            steps: Unused; kept so the signature matches ``__find_optimal``.

        Returns:
            ``(threshold, precision, recall, f1)`` with macro-averaged metrics.
        """
        errors = error_df["Reconstruction_error"].to_numpy()
        optimal_threshold = (errors.min() + errors.max()) / 2
        y_pred = np.where(errors > optimal_threshold, 0, 1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            error_df["True_class"], y_pred, average='macro', zero_division=0)
        return optimal_threshold, precision, recall, f1

    def __find_optimal(self, error_df: pd.DataFrame, steps=50):
        """Grid-search the error threshold that maximises binary F1.

        ``steps`` evenly spaced thresholds between the minimum and maximum
        reconstruction error (inclusive) are tried. Samples whose error is
        above a threshold are predicted negative (0) and all others
        positive (1).

        Args:
            error_df: Frame with ``Reconstruction_error`` and ``True_class``
                columns.
            steps: Number of candidate thresholds.

        Returns:
            ``(threshold, precision, recall, f1)`` for the best candidate. If no
            candidate reaches F1 > 0, the threshold is the minimum error and
            the metrics are 0.
        """
        errors = error_df["Reconstruction_error"].to_numpy()
        y_true = error_df["True_class"].to_numpy()
        optimal_threshold = errors.min()
        max_f1 = 0
        max_pr = 0
        max_re = 0
        # Use `steps` candidates. A fixed step size produces no candidates when
        # the error range is smaller than the step, and too many when it is large.
        for threshold in np.linspace(errors.min(), errors.max(), steps):
            y_pred = np.where(errors > threshold, 0, 1)
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average='binary', zero_division=0)
            if f1 > max_f1:
                max_f1 = f1
                optimal_threshold = threshold
                max_pr = precision
                max_re = recall
        print(f"Result optimal_threshold={optimal_threshold}, max_precision={max_pr}, max_recall={max_re}, max_f1={max_f1}")
        return optimal_threshold, max_pr, max_re, max_f1

    @staticmethod
    def __split_by_percent(data, percent):
        """Split *data* into ``(train, test)`` with a seeded shuffle.

        Args:
            data: Sequence to split.
            percent: Fraction of *data* (between 0 and 1) assigned to the
                test partition.

        Returns:
            ``(train, test)`` as returned by ``train_test_split``.
        """
        return train_test_split(data, test_size=percent, random_state=42)

    def train_autoencoder(self, epochs=25, batch_size=32):
        """Train, evaluate and persist the autoencoder.

        1. Load ``embedding_pos`` / ``embedding_neg`` of every "Extract Method"
           document from the collection named by ``COLLECTION_NAME`` in ``.env``.
        2. Hold out 30% of the documents for testing and train on the
           remaining positive embeddings only.
        3. Score each test vector by its reconstruction MSE and pick a
           classification threshold.
        4. Write the model, the metrics and the training-loss curve to
           ``./results/``.

        Args:
            epochs: Training epochs; also used in the saved model filename.
            batch_size: Mini-batch size.

        Raises:
            ValueError: If no matching documents are found.
        """
        # 768 = GraphCodeBERT embedding size.
        autoencoder = self.__get_autoencoder(768)
        self.__print_summary(autoencoder)

        db = Database(dotenv_values(".env")['COLLECTION_NAME'])
        pos_emb_list, neg_emb_list = [], []
        for doc in db.find_docs({"refactoring_type": "Extract Method"}):
            pos_emb_list.append(doc['embedding_pos'])
            neg_emb_list.append(doc['embedding_neg'])
        if not pos_emb_list:
            raise ValueError("No 'Extract Method' documents found in the collection")

        # Both lists have the same length and use the same seed, so both splits
        # select the same document indices. Negatives from training documents
        # are discarded.
        pos_emb_list_train, pos_emb_list_test = self.__split_by_percent(pos_emb_list, 0.3)
        _, neg_emb_list_test = self.__split_by_percent(neg_emb_list, 0.3)
        x_train = np.array(pos_emb_list_train)
        x_test = np.array(pos_emb_list_test + neg_emb_list_test)
        # Label 1 = positive embedding, 0 = negative embedding.
        y_test = np.array([1] * len(pos_emb_list_test) + [0] * len(neg_emb_list_test))

        history, trained_model = self.__fit_autoencoder(epochs, batch_size, autoencoder, x_train)
        # Make sure the output directory exists before anything is written there.
        os.makedirs('./results', exist_ok=True)
        trained_model.save('./results/autoencoder_' + str(epochs) + '.hdf5')

        # Evaluate with per-sample reconstruction MSE, averaged over the embedding axis.
        test_predict = trained_model.predict(x_test)
        mse = np.mean(np.power(x_test - test_predict, 2), axis=1)
        error_df = pd.DataFrame({'Reconstruction_error': mse,
                                 'True_class': y_test})
        print("Max: ", error_df["Reconstruction_error"].max())
        print("Min: ", error_df["Reconstruction_error"].min())

        optimal_threshold, precision, recall, f1 = self.__find_optimal_modified(error_df, 100)
        print(f"Result optimal_threshold={optimal_threshold}, max_precision={precision}, max_recall={recall}, max_f1={f1}")
        # Cast to built-in float: json cannot serialise some NumPy scalars (e.g. float32).
        metrics = {
            "Threshold": float(optimal_threshold),
            "Precision": float(precision),
            "Recall": float(recall),
            "F1": float(f1),
        }
        with open('./results/metrics.json', 'w') as fp:
            json.dump(metrics, fp)

        # Use a fresh figure and close it, so repeated runs do not overlay curves.
        plt.figure()
        plt.plot(history.history['loss'])
        plt.xlabel("Epoch")
        plt.ylabel("Training loss (MSE)")
        plt.savefig("./results/training_graph.png")
        plt.close()
# Script entry point: build a trainer and run the full train/evaluate pipeline.
if __name__ == "__main__":
    trainer = Autoencoder()
    trainer.train_autoencoder()