"""Compare per-residue flexibility profiles across the proteins in the data config.

For every PDB code listed in ``config['pdb_codes_path']`` this script builds a
table of per-residue profiles:

* RMSF, B-factor and AF2 pLDDT read from the ``atlas_out_dir`` analysis files,
* ESM pLDDT from a JSONL file,
* ProDy ANM/GNM fluctuations from pickles,
* optionally Flexpert-3D / Flexpert-Seq predictions.

It then computes the Pearson correlation matrix between the profiles of each
protein and prints the matrix averaged over all proteins.
"""
import argparse
import json
import os
import pickle

import numpy as np
import pandas as pd
import yaml
from tqdm import tqdm

from data.scripts.extract_rmsf_labels import (
    extract_rmsf_labels,
    extract_bfactor_labels,
    extract_plddt_labels,
)

# Profiles compared in every run; also the row/column order of the printed matrix.
BASE_COLUMNS = ('rmsf', 'bfactor', 'af2_plddt', 'esm_plddt', 'prody_GNM', 'prody_ANM')
# Extra profiles appended when --evaluate_flexpert is given.
FLEXPERT_COLUMNS = ('flexpert_3d', 'flexpert_seq')


def get_flucts_from_pickle(f):
    """Unpickle and return the object stored in the open binary file *f*.

    NOTE(review): ``pickle.load`` can execute arbitrary code. Only use it on
    trusted, locally precomputed files.
    """
    return pickle.load(f)


def get_flucts_from_jsonl(f):
    """Parse a JSON Lines file into a ``{pdb_name: np.ndarray}`` mapping.

    Each non-blank line must be a JSON object with ``pdb_name`` and
    ``fluctuations`` keys. Blank lines are skipped; ``json.loads`` would
    otherwise raise on them.

    Args:
        f: open file object, text or binary (``json.loads`` accepts both).

    Returns:
        dict mapping ``pdb_name`` to ``np.array(fluctuations)``.
    """
    pdb_code_to_fluct_dict = {}
    for line in f:
        line = line.strip()
        if not line:
            continue
        json_obj = json.loads(line)
        pdb_code_to_fluct_dict[json_obj['pdb_name']] = np.array(json_obj['fluctuations'])
    return pdb_code_to_fluct_dict


def read_flexpert_predictions(path):
    """Read Flexpert predictions from a FASTA-like text file.

    The file alternates header lines (``>name``) and lines of comma-separated
    per-residue values. Dots in names are replaced by underscores
    (``1abc.A`` -> ``1abc_A``). Presumably this matches the PDB-code naming
    of the code list; confirm.

    Blank lines are ignored so they cannot shift the header/value pairing.
    A trailing header without a values line is dropped (``zip`` truncation).

    Returns:
        dict mapping name to a float32 ndarray of predictions.
    """
    with open(path, 'r') as f:
        lines = [line.strip() for line in f if line.strip()]
    pdb_code_to_fluct_dict = {}
    for name_line, fluct_line in zip(lines[::2], lines[1::2]):
        name = name_line.strip('>').replace('.', '_')
        pdb_code_to_fluct_dict[name] = np.array(fluct_line.split(','), dtype=np.float32)
    return pdb_code_to_fluct_dict


def _mean_pearson_correlation(fluctuations, columns):
    """Return the element-wise mean of the per-protein Pearson correlation matrices.

    Proteins whose correlation matrix contains any NaN are reported and skipped.

    Args:
        fluctuations: mapping of PDB code to a DataFrame containing *columns*.
        columns: profile columns to correlate, in output order.

    Returns:
        ``len(columns) x len(columns)`` ndarray, or ``None`` when no protein
        yields a NaN-free matrix (avoids ``np.mean`` of an empty list).
    """
    matrices = []
    for pdb_code, df in fluctuations.items():
        pc = df[columns].corr(method='pearson').to_numpy()
        if np.isnan(pc).any():
            # E.g. a profile missing for this protein, or a constant (zero-variance) column.
            print(f'{pdb_code} has NaN values in Pearson correlation')
            continue
        matrices.append(pc)
    if not matrices:
        return None
    return np.mean(matrices, axis=0)


def main():
    """Parse CLI args, build per-protein profile tables and print the averaged correlations."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--evaluate_flexpert', action='store_true', default=False)
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed.
    with open('configs/data_config.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    data_dir = config['precomputed_flexibility_profiles_dir']

    flexpert_3d_predictions = {}
    flexpert_seq_predictions = {}
    if args.evaluate_flexpert:
        flexpert_3d_predictions_path = config['flexpert_3d_predictions_path']
        flexpert_seq_predictions_path = config['flexpert_seq_predictions_path']
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(flexpert_3d_predictions_path):
            raise FileNotFoundError(
                f"Flexpert-3D predictions file does not exist: {flexpert_3d_predictions_path}")
        if not os.path.exists(flexpert_seq_predictions_path):
            raise FileNotFoundError(
                f"Flexpert-Seq predictions file does not exist: {flexpert_seq_predictions_path}")
        flexpert_3d_predictions = read_flexpert_predictions(flexpert_3d_predictions_path)
        flexpert_seq_predictions = read_flexpert_predictions(flexpert_seq_predictions_path)

    with open(os.path.join(data_dir, 'anm_square_fluctuations.pickle'), 'rb') as f:
        anm_sq_flucts = get_flucts_from_pickle(f)
    with open(os.path.join(data_dir, 'gnm_square_fluctuations.pickle'), 'rb') as f:
        gnm_sq_flucts = get_flucts_from_pickle(f)
    with open(os.path.join(data_dir, 'atlas_esm_plddt.jsonl'), 'rb') as f:
        esm_plddt = get_flucts_from_jsonl(f)

    # Per-protein analysis file templates; both placeholders receive the PDB code.
    atlas_analyses_dir = config['atlas_out_dir']
    atlas_bfactor_path = os.path.join(atlas_analyses_dir, "{0}_analysis", "{0}_Bfactor.tsv")
    atlas_plddt_path = os.path.join(atlas_analyses_dir, "{0}_analysis", "{0}_pLDDT.tsv")
    atlas_rmsf_path = os.path.join(atlas_analyses_dir, "{0}_analysis", "{0}_RMSF.tsv")

    with open(config['pdb_codes_path'], 'r') as f:
        # Skip blank lines (e.g. a trailing empty line); they are not PDB codes.
        atlas_list = [line.strip() for line in f if line.strip()]

    if args.evaluate_flexpert:
        print("Filtering to testset only, to evaluate Flexpert-3D and Flexpert-Seq predictions")
        # Require both prediction sets. A code missing from one would otherwise
        # get a column of None and an all-NaN correlation matrix.
        atlas_list = [a for a in atlas_list
                      if a in flexpert_seq_predictions and a in flexpert_3d_predictions]

    fluctuations = {}
    for key in tqdm(atlas_list):
        df = pd.DataFrame()
        # A missing ANM/GNM/ESM entry falls back to NaN. That makes the protein's
        # correlation matrix contain NaN, so it is skipped when averaging.
        # sqrt turns the stored square fluctuations into fluctuations.
        df['prody_ANM'] = np.sqrt(anm_sq_flucts.get(key, np.nan))
        df['prody_GNM'] = np.sqrt(gnm_sq_flucts.get(key, np.nan))
        # 1 - pLDDT, presumably so larger means more flexible (assumes pLDDT in [0, 1]; TODO confirm).
        df['esm_plddt'] = 1 - esm_plddt.get(key, np.nan)
        # [1] takes the second element returned by the extract_* helpers;
        # presumably the per-residue values (confirm in data/scripts/extract_rmsf_labels.py).
        df['rmsf'] = extract_rmsf_labels(atlas_rmsf_path.format(key))[1]
        df['bfactor'] = extract_bfactor_labels(atlas_bfactor_path.format(key))[1]
        df['af2_plddt'] = 1 - extract_plddt_labels(atlas_plddt_path.format(key))[1]
        if args.evaluate_flexpert:
            df['flexpert_3d'] = flexpert_3d_predictions[key]
            df['flexpert_seq'] = flexpert_seq_predictions[key]
        fluctuations[key] = df

    columns = list(BASE_COLUMNS)
    if args.evaluate_flexpert:
        columns.extend(FLEXPERT_COLUMNS)

    pearson_mean = _mean_pearson_correlation(fluctuations, columns)
    if pearson_mean is None:
        print("No protein produced a NaN-free Pearson correlation matrix; nothing to average.")
        return

    print("Pearson correlations:")
    pearson_mean_rounded = np.round(pearson_mean, 2)
    print(pd.DataFrame(pearson_mean_rounded, index=columns, columns=columns))
    print("\n")


if __name__ == "__main__":
    main()