# PROTAC-Degradation-Predictor / src/collect_checkpoints.py
# Commit d8a186f by ribesstefano: "Added results with FP normalization removed"
import os
import sys
import re
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
def main():
    """Collect the best checkpoint of every model from ``logs`` into ``models``.

    Both directories are resolved relative to the parent of this script's
    directory. The ``logs`` tree is walked recursively; every ``*.ckpt`` file
    is given a destination name (see ``_destination_name``) and is then:

    * copied into ``models`` when no model with the same base name (the part
      of the name before the first ``-``) is stored there yet, or
    * swapped in for the stored model when BOTH its accuracy and its ROC-AUC,
      as encoded in the file names, are strictly higher.

    Example of a stored model name:
    ``cv_model_random_fold0-epoch=00-val_acc=0.00-val_roc_auc=0.00.ckpt``

    Raises:
        ValueError: if the directory path of a checkpoint names none of the
            known split types (``tanimoto``, ``random``, ``uniprot``).
    """
    # Imported here so the function does not rely on a module-level import.
    import shutil

    project_root = os.path.join(os.path.dirname(__file__), '..')
    models_dir = os.path.join(project_root, 'models')
    log_dir = os.path.join(project_root, 'logs')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(models_dir, exist_ok=True)

    for root, _, files in os.walk(log_dir):
        for file in files:
            if not file.endswith('.ckpt'):
                continue
            checkpoint_file = os.path.join(root, file)
            model_name = _destination_name(root, file)
            destination = os.path.join(models_dir, model_name)
            old_model_name = _find_existing_model(models_dir, model_name)

            if old_model_name is None:
                print(f'Copying {model_name}')
                # Native file operations instead of shelling out to cp/rm:
                # safe for paths with spaces and failures raise exceptions.
                shutil.copy(checkpoint_file, destination)
            elif _is_better(model_name, old_model_name):
                print(f'Replacing {old_model_name} with {model_name}')
                os.remove(os.path.join(models_dir, old_model_name))
                shutil.copy(checkpoint_file, destination)


def _split_type(root):
    """Return the split type named in ``root`` or raise ``ValueError``."""
    # Checked in this fixed order, so the first name found in the path wins.
    for candidate in ('tanimoto', 'random', 'uniprot'):
        if candidate in root:
            return candidate
    raise ValueError('Unknown split type')


def _fold_label(root):
    """Return what follows the last ``fold`` in ``root`` as a fold label."""
    tail = root.rsplit('fold', 1)[-1]
    digits = re.match(r'\d+', tail)
    # Take every leading digit so that e.g. fold12 is not truncated to fold1;
    # without leading digits fall back to the single following character.
    return digits.group(0) if digits else tail[:1]


def _destination_name(root, file):
    """Build the name a checkpoint gets inside the models directory."""
    split_type = _split_type(root)
    if 'fold' in root:
        # Cross-validation checkpoint: encode split type and fold in the name.
        return file.replace(
            'protac', f'cv_model_{split_type}_fold{_fold_label(root)}')
    # No fold in the path: the metrics in the name are relabelled as test ones.
    return file.replace('val_', 'test_')


def _find_existing_model(models_dir, model_name):
    """Return a stored model sharing ``model_name``'s base name, else None."""
    base_model_name = model_name.split('-')[0]
    for model in os.listdir(models_dir):
        # Compare whole base names: a substring test would let "...fold1"
        # match "...fold12-...".
        if model.split('-')[0] == base_model_name:
            return model
    return None


def _parse_metrics(name, stage):
    """Return ``(acc, roc_auc)`` parsed from ``name`` or None if absent."""
    acc = re.search(rf'{stage}_acc=(\d+\.\d+)', name)
    roc_auc = re.search(rf'{stage}_roc_auc=(\d+\.\d+)', name)
    if acc is None or roc_auc is None:
        return None
    return float(acc.group(1)), float(roc_auc.group(1))


def _is_better(model_name, old_model_name):
    """Tell whether ``model_name`` beats ``old_model_name`` on both metrics."""
    for stage in ('val', 'test'):
        if f'{stage}_acc' not in model_name:
            continue
        new_metrics = _parse_metrics(model_name, stage)
        old_metrics = _parse_metrics(old_model_name, stage)
        if new_metrics is None or old_metrics is None:
            # Without both sets of metrics there is nothing to compare, so
            # the stored model is kept rather than crashing on a None match.
            print(f'Skipping {model_name}: cannot compare {stage} metrics '
                  f'with {old_model_name}')
            continue
        new_acc, new_roc_auc = new_metrics
        old_acc, old_roc_auc = old_metrics
        if new_acc > old_acc and new_roc_auc > old_roc_auc:
            return True
    return False
# Run the checkpoint collection only when this file is executed as a script.
if __name__ == '__main__':
    main()