angepapa committed · Commit fec9f61 · verified · 1 parent: 293e2cf

Upload 404 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full list.

Files changed (50)
  1. .gitattributes +26 -0
  2. Dockerfile +38 -0
  3. ParaSurf/create_datasets_from_csv/README.md +37 -0
  4. ParaSurf/create_datasets_from_csv/__pycache__/split_pdb2chains_only.cpython-39.pyc +0 -0
  5. ParaSurf/create_datasets_from_csv/final_dataset_preparation.py +146 -0
  6. ParaSurf/create_datasets_from_csv/process_csv_dataset.py +130 -0
  7. ParaSurf/create_datasets_from_csv/split_pdb2chains_only.py +43 -0
  8. ParaSurf/model/ParaSurf_model.py +173 -0
  9. ParaSurf/model/__pycache__/ParaSurf_model.cpython-310.pyc +0 -0
  10. ParaSurf/model/__pycache__/ParaSurf_model.cpython-39.pyc +0 -0
  11. ParaSurf/model/__pycache__/dataset.cpython-310.pyc +0 -0
  12. ParaSurf/model/__pycache__/dataset.cpython-39.pyc +0 -0
  13. ParaSurf/model/dataset.py +107 -0
  14. ParaSurf/model_weights/README.md +11 -0
  15. ParaSurf/preprocess/README.md +71 -0
  16. ParaSurf/preprocess/__pycache__/check_empty_features.cpython-310.pyc +0 -0
  17. ParaSurf/preprocess/__pycache__/check_empty_features.cpython-39.pyc +0 -0
  18. ParaSurf/preprocess/__pycache__/clean_dataset.cpython-310.pyc +0 -0
  19. ParaSurf/preprocess/__pycache__/clean_dataset.cpython-39.pyc +0 -0
  20. ParaSurf/preprocess/check_empty_features.py +68 -0
  21. ParaSurf/preprocess/check_rec_ant_touch.py +89 -0
  22. ParaSurf/preprocess/clean_dataset.py +27 -0
  23. ParaSurf/preprocess/create_input_features.py +230 -0
  24. ParaSurf/preprocess/create_proteins_file.py +23 -0
  25. ParaSurf/preprocess/create_sample_files.py +31 -0
  26. ParaSurf/preprocess/create_surfpoints.py +57 -0
  27. ParaSurf/train/V_domain_results.py +159 -0
  28. ParaSurf/train/__pycache__/V_domain_results.cpython-310.pyc +0 -0
  29. ParaSurf/train/__pycache__/V_domain_results.cpython-39.pyc +0 -0
  30. ParaSurf/train/__pycache__/bsite_extraction.cpython-310.pyc +0 -0
  31. ParaSurf/train/__pycache__/bsite_extraction.cpython-39.pyc +0 -0
  32. ParaSurf/train/__pycache__/distance_coords.cpython-310.pyc +0 -0
  33. ParaSurf/train/__pycache__/distance_coords.cpython-39.pyc +0 -0
  34. ParaSurf/train/__pycache__/features.cpython-310.pyc +0 -0
  35. ParaSurf/train/__pycache__/features.cpython-39.pyc +0 -0
  36. ParaSurf/train/__pycache__/network.cpython-310.pyc +0 -0
  37. ParaSurf/train/__pycache__/network.cpython-39.pyc +0 -0
  38. ParaSurf/train/__pycache__/protein.cpython-310.pyc +0 -0
  39. ParaSurf/train/__pycache__/protein.cpython-39.pyc +0 -0
  40. ParaSurf/train/__pycache__/utils.cpython-310.pyc +0 -0
  41. ParaSurf/train/__pycache__/utils.cpython-39.pyc +0 -0
  42. ParaSurf/train/__pycache__/validation.cpython-310.pyc +0 -0
  43. ParaSurf/train/__pycache__/validation.cpython-39.pyc +0 -0
  44. ParaSurf/train/bsite_extraction.py +48 -0
  45. ParaSurf/train/distance_coords.py +173 -0
  46. ParaSurf/train/features.py +37 -0
  47. ParaSurf/train/network.py +58 -0
  48. ParaSurf/train/protein.py +92 -0
  49. ParaSurf/train/train.py +172 -0
  50. ParaSurf/train/utils.py +497 -0
.gitattributes CHANGED
@@ -33,3 +33,29 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/_codecs_cn.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/_codecs_hk.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/_codecs_jp.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/_codecs_kr.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/_codecs_tw.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/_ctypes.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/datetime.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/doc/images/flowchart.png filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libcrypto.so.1.0.0 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libexpat.so.1 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libncursesw.so.5 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libpython2.7.so.1.0 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libreadline.so.6 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libssl.so.1.0.0 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libstdc++.so.6 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libtinfo.so.5 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/libz.so.1 filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/numpy.core.multiarray.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/numpy.core.umath.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/numpy.fft.fftpack_lite.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/numpy.linalg._umath_linalg.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/numpy.linalg.lapack_lite.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/numpy.random.mtrand.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/pdb2pka._apbslib.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/pdb2pka._pMC_mult.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ pdb2pqr-linux-bin64-2.1.1/pdb2pqr filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ FROM continuumio/miniconda3
+
+ WORKDIR /app
+ COPY . /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y build-essential wget
+
+ # Create Conda environment and install dependencies
+ RUN conda create -n ParaSurf python=3.10 openbabel -c conda-forge -y
+ RUN conda run -n ParaSurf pip install -r requirements.txt
+
+ # Increase file descriptor limit to prevent FD_SETSIZE error
+ RUN echo "ulimit -n 65535" >> ~/.bashrc
+
+ # Download missing Gradio frpc binary
+ RUN wget https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 -O /opt/conda/envs/ParaSurf/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3 && \
+     chmod +x /opt/conda/envs/ParaSurf/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3
+
+ # Install DMS software
+ WORKDIR /app/dms
+ RUN make install
+
+ # Ensure necessary binaries are executable
+ RUN chmod +x /app/pdb2pqr-linux-bin64-2.1.1/pdb2pqr && \
+     chmod +x /opt/conda/envs/ParaSurf/bin/* && \
+     chmod -R 755 /app
+
+ # Set writable directories for Matplotlib cache
+ ENV MPLCONFIGDIR=/tmp/matplotlib
+ ENV XDG_CACHE_HOME=/tmp
+
+ WORKDIR /app
+ EXPOSE 7860
+
+ # Run the app with higher file descriptor limits
+ CMD ["bash", "-c", "ulimit -n 65535 && conda run --no-capture-output -n ParaSurf python app.py"]
+
ParaSurf/create_datasets_from_csv/README.md ADDED
@@ -0,0 +1,37 @@
+ # **Dataset preparation**
+
+ ### Steps for Dataset Preparation
+ #### Step 1
+ Download the specific PDB files: use the process_csv_dataset.py script to download the PDB files listed in the .csv files.
+ ```bash
+ # Step 1: Download specified PDB files
+ python process_csv_dataset.py
+ ```
+
+ #### Step 2
+ Generate the final complexes: run final_dataset_preparation.py to arrange the files into complexes with the chain IDs specified in the .csv files.
+ ```bash
+ # Step 2: Organize files into final complexes
+ python final_dataset_preparation.py
+ ```
+
+ After running these scripts, you will find a test_data/pdbs folder organized as follows:
+ ```bash
+ ├── PECAN
+ │   ├── TRAIN
+ │   │   ├── 1A3R_receptor_1.pdb
+ │   │   ├── 1A3R_antigen_1_1.pdb
+ │   │   ├── ...
+ │   │   ├── 5WUX_receptor_1.pdb
+ │   │   └── 5WUX_antigen_1_1.pdb
+ │   ├── VAL
+ │   └── TEST
+ ├── Paragraph_Expanded
+ │   ├── TRAIN
+ │   ├── VAL
+ │   └── TEST
+ └── MIPE
+     ├── TRAIN_VAL
+     └── TEST
+ ```
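
For reference, both scripts read the per-split .csv files with pandas and expect the four columns that process_csv_dataset.py assigns (`pdb_code`, `Light_chain`, `Heavy_chain`, `ag`), with multiple antigen chains separated by `;`. A minimal sketch of what a row is assumed to look like (the chain values below are illustrative placeholders, not taken from the real splits):

```python
import pandas as pd

# Hypothetical example row; real rows come from the PECAN / Paragraph_Expanded / MIPE csv files.
example = pd.DataFrame(
    [{"pdb_code": "1A3R", "Light_chain": "L", "Heavy_chain": "H", "ag": "A;B"}]
)
example.to_csv("example_split.csv", index=False)

df = pd.read_csv("example_split.csv")
for _, row in df.iterrows():
    # same ';' convention that final_dataset_preparation.py uses for the 'ag' column
    antigen_chains = [c.strip() for c in row["ag"].split(";")]
    print(row["pdb_code"], row["Heavy_chain"], row["Light_chain"], antigen_chains)
```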
ParaSurf/create_datasets_from_csv/__pycache__/split_pdb2chains_only.cpython-39.pyc ADDED
Binary file (1.54 kB).
 
ParaSurf/create_datasets_from_csv/final_dataset_preparation.py ADDED
@@ -0,0 +1,146 @@
+ import os
+ import shutil
+ import pandas as pd
+ from split_pdb2chains_only import extract_chains_from_pdb
+ from tqdm import tqdm
+
+
+ def process_raw_pdb_data(info_df, initial_raw_pdb_files, final_folder):
+     """
+     Processes the raw PDB files by extracting the specific antibody and antigen chains listed in the .csv file,
+     merging them, and saving the merged files in the final train/val/test folder.
+
+     Parameters:
+     info_df (DataFrame): DataFrame containing the PDB codes, antibody chains, and antigen chains.
+     initial_raw_pdb_files (str): Path to the initial raw PDB files directory.
+     final_folder (str): Path to the folder where the processed files will be saved.
+     """
+     if not os.path.exists(final_folder):
+         os.makedirs(final_folder)
+
+     for i, row in tqdm(info_df.iterrows(), total=len(info_df)):
+         pdb_id = row['pdb_code']
+         ab_heavy_chain = row['Heavy_chain']  # use only this line to construct the heavy-chain-only dataset
+         ab_light_chain = row['Light_chain']  # use only this line to construct the light-chain-only dataset
+         ag_chain = row['ag']
+
+         pdb_file = os.path.join(initial_raw_pdb_files, pdb_id + '.pdb')
+         # Extract all the chains from the pdb file and save them to /tmp
+         chain_files, all_chains = extract_chains_from_pdb(pdb_file, '/tmp')
+
+         # Assign the correct chains
+         ab_heavy_chain_path = f'/tmp/{pdb_id}_chain{ab_heavy_chain}.pdb'
+         ab_light_chain_path = f'/tmp/{pdb_id}_chain{ab_light_chain}.pdb'
+
+         # Merge antibody chains into one file
+         receptor_output_path = f'{final_folder}/{pdb_id}_receptor_1.pdb'
+         with open(receptor_output_path, 'w') as receptor_file:
+             # drop ab_heavy_chain_path or ab_light_chain_path here when building a heavy-/light-chain-only dataset
+             for ab_file in [ab_heavy_chain_path, ab_light_chain_path]:
+                 with open(ab_file, 'r') as infile:
+                     receptor_file.write(infile.read())
+
+         print(f"Successfully merged {ab_heavy_chain} and {ab_light_chain} into {receptor_output_path}")
+
+         ag_chain_list = ag_chain.split(';')
+
+         if len(ag_chain_list) == 1:
+             # If there's only one antigen chain
+             ag_chain_1 = ag_chain_list[0].strip()
+             ag_chain_1_path = f'/tmp/{pdb_id}_chain{ag_chain_1}.pdb'
+             print(f"Handling one antigen chain: {ag_chain_1}")
+
+             # Copy the single antigen chain to the output
+             antigen_output_path = f'{final_folder}/{pdb_id}_antigen_1_1.pdb'
+             shutil.copyfile(ag_chain_1_path, antigen_output_path)
+
+             print(f"Successfully copied {ag_chain_1} to {antigen_output_path}")
+
+         elif len(ag_chain_list) == 2:
+             # If there are two antigen chains
+             ag_chain_1, ag_chain_2 = [c.strip() for c in ag_chain_list]
+             ag_chain_1_path = f'/tmp/{pdb_id}_chain{ag_chain_1}.pdb'
+             ag_chain_2_path = f'/tmp/{pdb_id}_chain{ag_chain_2}.pdb'
+             print(f"Handling two antigen chains: {ag_chain_1}, {ag_chain_2}")
+
+             # Merge the antigen chains into a single PDB file
+             antigen_output_path = f'{final_folder}/{pdb_id}_antigen_1_1.pdb'
+             with open(antigen_output_path, 'w') as outfile:
+                 for ag_file in [ag_chain_1_path, ag_chain_2_path]:
+                     with open(ag_file, 'r') as infile:
+                         outfile.write(infile.read())
+
+             print(f"Successfully merged {ag_chain_1} and {ag_chain_2} into {antigen_output_path}")
+
+         elif len(ag_chain_list) == 3:
+             # If there are three antigen chains
+             ag_chain_1, ag_chain_2, ag_chain_3 = [c.strip() for c in ag_chain_list]
+             ag_chain_1_path = f'/tmp/{pdb_id}_chain{ag_chain_1}.pdb'
+             ag_chain_2_path = f'/tmp/{pdb_id}_chain{ag_chain_2}.pdb'
+             ag_chain_3_path = f'/tmp/{pdb_id}_chain{ag_chain_3}.pdb'
+             print(f"Handling three antigen chains: {ag_chain_1}, {ag_chain_2}, {ag_chain_3}")
+
+             # Merge the antigen chains into a single PDB file
+             antigen_output_path = f'{final_folder}/{pdb_id}_antigen_1_1.pdb'
+             with open(antigen_output_path, 'w') as outfile:
+                 for ag_file in [ag_chain_1_path, ag_chain_2_path, ag_chain_3_path]:
+                     with open(ag_file, 'r') as infile:
+                         outfile.write(infile.read())
+
+             print(f"Successfully merged {ag_chain_1}, {ag_chain_2}, and {ag_chain_3} into {antigen_output_path}")
+
+         # At the end, remove all the chain pdb files from the /tmp folder
+         for chain_file in chain_files:
+             os.remove(chain_file)
+
+
+ if __name__ == '__main__':
+     user = os.getenv('USER')
+
+     datasets = ['PECAN', 'Paragraph_Expanded', 'MIPE']
+
+     for dataset in datasets:
+         if dataset == 'MIPE':  # here the split is train-val and test, following the MIPE paper
+             # csv paths
+             train_val_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_val.csv')
+             test_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv')
+
+             # paths to the initial raw PDB storage
+             init_pdb_files_train_val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_val_data_initial_raw_pdb_files'
+             init_pdb_files_test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'
+
+             # final folders
+             final_train_val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TRAIN_VAL'
+             final_test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TEST'
+
+             process_raw_pdb_data(train_val_info, init_pdb_files_train_val, final_train_val_folder)
+             process_raw_pdb_data(test_info, init_pdb_files_test, final_test_folder)
+
+             shutil.rmtree(init_pdb_files_train_val)
+             shutil.rmtree(init_pdb_files_test)
+
+         else:
+             # Paths to dataset csv files
+             train_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_set.csv')
+             val_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/val_set.csv')
+             test_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv')
+
+             # Paths to the initial raw pdb files
+             initial_pdb_files_train = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_data_initial_raw_pdb_files'
+             initial_pdb_files_val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/val_data_initial_raw_pdb_files'
+             initial_pdb_files_test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'
+
+             # Final folders for the merged files that contain the final PDB complexes
+             final_train_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TRAIN'
+             final_val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/VAL'
+             final_test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TEST'
+
+             # Process the train-val-test data
+             process_raw_pdb_data(train_info, initial_pdb_files_train, final_train_folder)
+             process_raw_pdb_data(val_info, initial_pdb_files_val, final_val_folder)
+             process_raw_pdb_data(test_info, initial_pdb_files_test, final_test_folder)
+
+             # Remove the initial raw pdb files
+             shutil.rmtree(initial_pdb_files_train)
+             shutil.rmtree(initial_pdb_files_val)
+             shutil.rmtree(initial_pdb_files_test)
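
The 1/2/3-antigen-chain branches in process_raw_pdb_data repeat the same merge logic. Purely as an illustration (this helper is not part of the file above; it assumes the same /tmp chain-file naming produced by extract_chains_from_pdb), the merging could be generalized to any number of ';'-separated chains:

```python
def merge_antigen_chains(pdb_id, ag_chain, final_folder, tmp_dir="/tmp"):
    """Sketch: merge an arbitrary number of ';'-separated antigen chains into one PDB file."""
    chains = [c.strip() for c in ag_chain.split(';')]
    antigen_output_path = f"{final_folder}/{pdb_id}_antigen_1_1.pdb"
    with open(antigen_output_path, "w") as outfile:
        for chain_id in chains:
            # chain files written to tmp_dir by extract_chains_from_pdb
            chain_path = f"{tmp_dir}/{pdb_id}_chain{chain_id}.pdb"
            with open(chain_path, "r") as infile:
                outfile.write(infile.read())
    return antigen_output_path
```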
ParaSurf/create_datasets_from_csv/process_csv_dataset.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ import pandas as pd
+ from Bio.PDB import PDBList
+ from tqdm import tqdm
+
+
+ def add_headers_if_not_present(csv_file, headerlist):
+     """
+     Add headers to the CSV file if they are not already present.
+
+     Parameters:
+     csv_file (str): Path to the CSV file.
+     headerlist (list): List of headers to add.
+     """
+     # Read the first row to check for headers
+     first_row = pd.read_csv(csv_file, nrows=1)
+
+     # Check if the first row contains the expected headers
+     if list(first_row.columns) != headerlist:
+         print(f"Headers not found in {csv_file}. Adding headers...")
+         # Load the full data without headers
+         data = pd.read_csv(csv_file, header=None)
+         # Assign the correct headers
+         data.columns = headerlist
+         # Save the file with the correct headers
+         data.to_csv(csv_file, header=True, index=False)
+         print(f"Headers added to {csv_file}")
+     else:
+         print(f"Headers already present in {csv_file}. No changes made.")
+
+
+ def download_pdb(pdb_code, output_dir):
+     pdbl = PDBList()
+     pdbl.retrieve_pdb_file(pdb_code, pdir=output_dir, file_format='pdb')
+
+
+ def download_and_rename_pdb_files(pdb_list, folder):
+     """
+     Downloads PDB files from the provided list and renames them from `.ent` to `{pdb_code}.pdb`.
+
+     Parameters:
+     pdb_list (list): List of PDB codes to be downloaded.
+     folder (str): Directory where the PDB files will be saved and renamed.
+     """
+     # Download PDB files
+     for pdb_code in pdb_list:
+         download_pdb(pdb_code, folder)
+
+     # Rename files to {pdb_code}.pdb
+     for pdb_file in os.listdir(folder):
+         if pdb_file.endswith('.ent'):
+             old_file_path = os.path.join(folder, pdb_file)
+             new_file_name = pdb_file.split('.')[0][-4:].upper() + '.pdb'  # upper case because the csv gives the PDB codes in capitals
+             new_file_path = os.path.join(folder, new_file_name)
+             os.rename(old_file_path, new_file_path)
+             print(f"Renamed {old_file_path} to {new_file_path}")
+
+
+ def process_dataset(csv_file, folder):
+     """
+     Processes a dataset by adding headers, extracting PDB codes, and downloading/renaming PDB files.
+
+     Parameters:
+     csv_file (str): Path to the CSV file.
+     folder (str): Directory where the PDB files will be saved and renamed.
+     """
+     # Add headers if not present
+     add_headers_if_not_present(csv_file, headerlist)
+
+     # Read the CSV file
+     dataset = pd.read_csv(csv_file)
+
+     # Create folder if it doesn't exist
+     if not os.path.exists(folder):
+         os.makedirs(folder)
+
+     # Initialize the PDB list
+     pdb_list = []
+
+     # Process each row
+     for i, row in dataset.iterrows():
+         pdb_list.append(row['pdb_code'])
+
+     # Download and rename PDB files
+     download_and_rename_pdb_files(pdb_list, folder)
+
+
+ if __name__ == '__main__':
+
+     # ALL datasets follow the same process
+     user = os.getenv('USER')
+     datasets = ['PECAN', 'Paragraph_Expanded', 'MIPE']
+
+     # Define the correct headers
+     headerlist = ['pdb_code', 'Light_chain', 'Heavy_chain', 'ag']
+
+     for dataset in datasets:
+
+         if dataset == 'MIPE':  # here the split is train-val and test, following the MIPE paper
+             # csv paths
+             train_val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_val.csv'
+             test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv'
+
+             # paths to the initial raw PDB storage
+             train_val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_val_data_initial_raw_pdb_files'
+             test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'
+
+             process_dataset(train_val, train_val_folder)
+             process_dataset(test, test_folder)
+
+         else:
+             # Paths to your CSV files. Download the dataset from here: https://github.com/oxpig/Paragraph/tree/main/training_data/Expanded
+             train = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_set.csv'
+             val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/val_set.csv'
+             test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv'
+
+             # Paths for the initial raw PDB file storage
+             train_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_data_initial_raw_pdb_files'
+             val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/val_data_initial_raw_pdb_files'
+             test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'
+
+             # Process each dataset split
+             process_dataset(train, train_folder)
+             process_dataset(val, val_folder)
+             process_dataset(test, test_folder)
ParaSurf/create_datasets_from_csv/split_pdb2chains_only.py ADDED
@@ -0,0 +1,43 @@
+ import os
+
+
+ def extract_chains_from_pdb(pdb_file, output_dir):
+     """
+     Extract and save the chains from a PDB file as separate chain-specific PDB files.
+
+     Args:
+         pdb_file (str): Path to the PDB file.
+         output_dir (str): Path to the directory where the chain-specific files should be saved.
+
+     Returns:
+         tuple: (paths to the chain-specific PDB files, corresponding chain IDs).
+     """
+     chain_dict = {}
+
+     with open(pdb_file, 'r') as f:
+         for line in f:
+             if line.startswith('ATOM'):
+                 chain_id = line[21]
+                 if chain_id in chain_dict:
+                     chain_dict[chain_id].append(line)
+                 else:
+                     chain_dict[chain_id] = [line]
+
+     chain_files = []
+     for chain_id, lines in chain_dict.items():
+         chain_file = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(pdb_file))[0]}_chain{chain_id}.pdb")
+         with open(chain_file, 'w') as f:
+             f.writelines(lines)
+         # print(f'Chain {chain_id} saved as {chain_file}.')
+         chain_files.append(chain_file)
+
+     chain_ids = [chain.split("/")[-1].split(".")[0][-1] for chain in chain_files]
+
+     return chain_files, chain_ids
+
+
+ if __name__ == '__main__':
+     pdb_file = '/home/angepapa/PycharmProjects/DeepSurf2.0/3bgf.pdb'
+     output_dir = "/".join(pdb_file.split('/')[:-1])
+     chain_files, chain_ids = extract_chains_from_pdb(pdb_file, output_dir)
+     print(chain_files)
+     print(chain_ids)
ParaSurf/model/ParaSurf_model.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchsummary import summary
+ import time
+
+
+ class GeM(nn.Module):
+     def __init__(self, p=3.0, eps=1e-6):
+         super(GeM, self).__init__()
+         # Initialize p as a learnable parameter
+         self.p = nn.Parameter(torch.ones(1) * p)
+         self.eps = eps
+
+     def forward(self, x):
+         return self.gem(x, self.p, self.eps)
+
+     def gem(self, x, p, eps):
+         # Clamp all elements in x to a minimum of eps and then raise them to the power of p.
+         # Apply avg_pool3d with the kernel size equal to the spatial dimensions of the feature map (entire depth, height, width).
+         # Finally, take the power of 1/p to invert the earlier power of p operation.
+         return F.avg_pool3d(x.clamp(min=eps).pow(p), (x.size(2), x.size(3), x.size(4))).pow(1. / p)
+
+     def __repr__(self):
+         # This helps in identifying the layer characteristics when printing the model or layer
+         return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', eps=' + str(
+             self.eps) + ')'
+
+
+ # Define a custom Bottleneck module with optional dilation
+ class DilatedBottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, in_planes, planes, stride=1, dilation=1, dropout_prob=0.25):
+         super(DilatedBottleneck, self).__init__()
+         self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm3d(planes)
+         self.dropout1 = nn.Dropout3d(dropout_prob)
+         self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
+                                bias=False)
+         self.bn2 = nn.BatchNorm3d(planes)
+         self.dropout2 = nn.Dropout3d(dropout_prob)
+         self.conv3 = nn.Conv3d(planes, self.expansion * planes, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm3d(self.expansion * planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             self.shortcut = nn.Sequential(
+                 nn.Conv3d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm3d(self.expansion * planes)
+             )
+
+     def forward(self, x):
+         out = self.dropout1(F.relu(self.bn1(self.conv1(x))))
+         out = self.dropout2(F.relu(self.bn2(self.conv2(out))))
+         out = self.bn3(self.conv3(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+ # Define the Transformer Block
+ class TransformerBlock(nn.Module):
+     def __init__(self, feature_size, nhead, num_layers):
+         super(TransformerBlock, self).__init__()
+         self.transformer = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(d_model=feature_size, nhead=nhead),
+             num_layers=num_layers
+         )
+
+     def forward(self, x):
+         orig_shape = x.shape  # Save original shape
+         x = x.flatten(2)  # Flatten spatial dimensions
+         x = x.permute(2, 0, 1)  # Reshape for the transformer (Seq, Batch, Features)
+         x = self.transformer(x)
+         x = x.permute(1, 2, 0).view(*orig_shape)  # Restore original shape
+         return x
+
+
+ # Define a Compression Layer
+ class CompressionLayer(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(CompressionLayer, self).__init__()
+         self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
+         self.bn = nn.BatchNorm3d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv1x1(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         return x
+
+
+ # Define the enhanced ResNet with a hybrid CNN-Transformer architecture
+ class ResNet3D_Transformer(nn.Module):
+     def __init__(self, in_channels, block, num_blocks, num_classes=1, dropout_prob=0.1):
+         super(ResNet3D_Transformer, self).__init__()
+         self.in_planes = 64
+
+         self.initial_layers = nn.Sequential(
+             nn.Conv3d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
+             nn.BatchNorm3d(64),
+             nn.ReLU(inplace=True),
+             nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
+         )
+
+         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, dilation=2)
+         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+
+         self.compression = CompressionLayer(512 * block.expansion, 256)
+         self.transformer_block = TransformerBlock(feature_size=256, nhead=8, num_layers=1)  # change to 4
+         # self.gem_pooling = GeM(p=3.0, eps=1e-6)
+         self.dropout = nn.Dropout(dropout_prob)
+
+         self.classifier = nn.Linear(256, num_classes)
+
+     def _make_layer(self, block, planes, num_blocks, stride, dilation=1):
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for s in strides:
+             layers.append(block(self.in_planes, planes, s, dilation))
+             self.in_planes = planes * block.expansion
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = x.permute(0, 4, 3, 2, 1)
+         x = self.initial_layers(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         # Compress and transform
+         x = self.compression(x)
+         x = self.transformer_block(x)
+         # Global average pooling
+         x = torch.mean(x, dim=[2, 3, 4])
+
+         # Classify
+         x = self.dropout(x)  # Apply dropout before classification
+         x = self.classifier(x)
+         return x
+
+
+ def count_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ if __name__ == "__main__":
+     # device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     device = 'cpu'
+     num_classes = 1
+     num_input_channels = 20  # Number of input channels
+     model = ResNet3D_Transformer(num_input_channels, DilatedBottleneck, [3, 4, 6, 3], num_classes=num_classes).to(device)
+     grid_size = 41  # Assuming the input grid size (for example, 41x41x41x20)
+
+     start = time.time()
+     num_params = count_parameters(model)
+     print(f"Number of parameters in the model: {num_params}")
+     print(model)
+
+     dummy_input = torch.randn(64, grid_size, grid_size, grid_size, num_input_channels).to(device)
+     dummy_input = dummy_input.float().to(device)
+
+     output = model(dummy_input)
+
+     print("Output shape:", output.shape)
+     print(output)
+     print(f'total time: {(time.time() - start) / 60} mins')
+
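
For reference, the GeM module defined at the top of this file (it is only instantiated in the commented-out `gem_pooling` line of ResNet3D_Transformer) implements generalized-mean pooling over the spatial dimensions. The clamp / power / average-pool / root sequence in `gem()` corresponds to:

```latex
\mathrm{GeM}_p(X) \;=\; \left( \frac{1}{|X|} \sum_{x \in X} \max(x,\, \epsilon)^{\,p} \right)^{1/p}
```

which reduces to average pooling for p = 1 and approaches max pooling as p grows; here p is a learnable `nn.Parameter` initialized to 3.0.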
ParaSurf/model/__pycache__/ParaSurf_model.cpython-310.pyc ADDED
Binary file (6.29 kB).
 
ParaSurf/model/__pycache__/ParaSurf_model.cpython-39.pyc ADDED
Binary file (6.36 kB).
 
ParaSurf/model/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (2.76 kB).
 
ParaSurf/model/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (2.76 kB).
 
ParaSurf/model/dataset.py ADDED
@@ -0,0 +1,107 @@
+ import numpy as np
+ import random, os
+ from scipy import sparse
+ from torch.utils.data import Dataset
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ import h5py
+
+
+ class dataset(Dataset):
+     def __init__(self, train_file, batch_size, data_path, grid_size, training, feature_vector_lentgh, feature_names=['deepsite']):
+
+         super(dataset, self).__init__()
+         self.training = training
+         self.feature_vector_lentgh = feature_vector_lentgh
+         # in testing mode the training file is not read
+         if self.training:
+             with open(train_file) as f:
+                 self.train_lines = f.readlines()
+             random.shuffle(self.train_lines)
+         else:
+             self.train_lines = []
+
+         random.shuffle(self.train_lines)
+
+         self.pointer_tr = 0
+         self.pointer_val = 0
+
+         self.batch_size = batch_size
+         self.data_path = data_path
+         self.grid_size = grid_size
+         self.feature_names = feature_names
+         # if added_features is None:  # resolved outside
+         self.nAtomTypes = 0
+
+         self.nfeats = {
+             'deepsite': 8,
+             'kalasanty': feature_vector_lentgh,
+             'kalasanty_with_force_fields': feature_vector_lentgh,
+             'kalasanty_norotgrid': 18,
+             'spat_protr': 1,
+             'spat_protr_norotgrid': 1
+         }
+
+         for name in feature_names:
+             self.nAtomTypes += self.nfeats[name]
+
+     def __len__(self):
+         if self.training:
+             return len(self.train_lines)
+
+     def __getitem__(self, index):
+         if self.training:
+             samples = self.train_lines
+
+         label, sample_file = samples[index].split()
+         label = int(label)
+         base_name, prot, sample = sample_file.split('/')
+
+         feats = np.zeros((self.grid_size, self.grid_size, self.grid_size, self.nAtomTypes))
+         feat_cnt = 0
+
+         for name in self.feature_names:
+             if 'deepsite' == name:
+                 data = np.load(os.path.join(self.data_path, base_name + '_' + name, prot, sample), allow_pickle=True)
+             elif 'kalasanty' == name:
+                 data = sparse.load_npz(os.path.join(self.data_path, base_name, prot, sample[:-1] + 'z'))
+                 data = np.reshape(np.array(data.todense()), (self.grid_size, self.grid_size, self.grid_size, self.nfeats['kalasanty']))
+             elif 'kalasanty_with_force_fields' == name:
+                 data = sparse.load_npz(os.path.join(self.data_path, base_name, prot, sample[:-1] + 'z'))
+                 data = np.reshape(np.array(data.todense()), (self.grid_size, self.grid_size, self.grid_size, self.nfeats['kalasanty_with_force_fields']))
+             elif 'spat_protr' in name:
+                 data = np.load(os.path.join(self.data_path, base_name + '_' + name, prot, sample), allow_pickle=True)
+             else:
+                 print('unknown feat')
+
+             if len(data) == 3:
+                 data = data[2]  # note: only for scPDB for now (the kalasanty features do not store points/normals)
+
+             feats[:, :, :, feat_cnt:feat_cnt + self.nfeats[name]] = data
+             feat_cnt += self.nfeats[name]
+
+         if feat_cnt != self.nAtomTypes:
+             print('error !')
+
+         # Random 90-degree rotation augmentation, followed by an explicit contiguous copy,
+         # because PyTorch cannot handle the negative strides that np.rot90 produces
+         if self.training:
+             rot_axis = random.randint(1, 3)
+             feats_copy = feats.copy()
+             if rot_axis == 1:
+                 feats_copy = np.rot90(feats_copy, random.randint(0, 3), axes=(0, 1))
+             elif rot_axis == 2:
+                 feats_copy = np.rot90(feats_copy, random.randint(0, 3), axes=(0, 2))
+             elif rot_axis == 3:
+                 feats_copy = np.rot90(feats_copy, random.randint(0, 3), axes=(1, 2))
+             feats = np.ascontiguousarray(feats_copy)
+
+         if np.isnan(np.sum(feats)):
+             print('nan input')
+
+         return feats, label
+
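
Since `__getitem__` returns a (grid_size, grid_size, grid_size, channels) NumPy grid plus an integer label, the class plugs directly into a standard PyTorch DataLoader. A minimal usage sketch (the file paths and batch size below are placeholders; the actual values are configured in the training script, which is not shown in this 50-file view):

```python
from torch.utils.data import DataLoader
# from ParaSurf.model.dataset import dataset

train_set = dataset(
    train_file="test_data/datasets/PECAN_TRAIN.samples",  # placeholder path to a .samples file
    batch_size=64,
    data_path="test_data/feats",                          # placeholder path to the feature grids
    grid_size=41,
    training=True,
    feature_vector_lentgh=22,                             # (sic) keyword as spelled in the class above
    feature_names=["kalasanty_with_force_fields"],
)
loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

for feats, labels in loader:
    feats = feats.float()  # numpy float64 grids are collated as double tensors; cast for the model
    break
```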
ParaSurf/model_weights/README.md ADDED
@@ -0,0 +1,11 @@
+ ## *ParaSurf Model Weights*
+
+ [Download ParaSurf model weights](https://drive.google.com/drive/folders/1Kpehru9SnWsl7_Wq93WuI_o7f8wrPgpI?usp=drive_link)
+
+ Best model weights for the 3 benchmark datasets:
+ * PECAN Dataset
+ * Paragraph Expanded
+ * Paragraph Expanded (Heavy Chains Only)
+ * Paragraph Expanded (Light Chains Only)
+ * MIPE Dataset
+
ParaSurf/preprocess/README.md ADDED
@@ -0,0 +1,71 @@
+ # **Feature Extraction - Preprocessing phase**
+
+ This guide outlines the steps needed to generate the ParaSurf 41x41x41x22 input feature vector for training. By following these steps, you will create a dataset ready for training, organized in the specified folder structure.
+ ### Step 1: Clean the Antibody-Antigen Complex
+ Remove ions, ligands, and water molecules from the antibody-antigen complex and rearrange atom IDs within the PDB structure.
+ ```bash
+ # Clean the antibody-antigen complex
+ python clean_dataset.py
+ ```
+
+ ### Step 2: Sanity Check for Interaction
+ Verify that at least one antibody heavy atom is within 4.5 Å of any antigen heavy atom, ensuring proximity-based interactions.
+ ```bash
+ # Run sanity check
+ python check_rec_ant_touch.py
+ ```
+ ### Step 3: Generate Molecular Surface Points
+ Create the molecular surface for each receptor in the training folder using the DMS software. These surface points will serve as a basis for feature extraction.
+ ```bash
+ # Generate molecular surface points
+ python create_surfpoints.py
+ ```
+
+ ### Step 4: Generate ParaSurf Input Feature Grids (41x41x41x22)
+ Create the 3D feature grids for each surface point generated in Step 3. Each feature grid includes 22 channels with essential structural and electrostatic information.
+ ```bash
+ # Create the 41x41x41x22 input feature grids
+ python create_input_features.py
+ ```
+
+ ### Step 5: Prepare .proteins Files
+ Generate .proteins files for training, validation, and testing. These files list all receptors (antibodies) to be used in each dataset split.
+ ```bash
+ # Create train/val/test .proteins files
+ python create_proteins_file.py
+ ```
+
+ ### Step 6: Create .samples Files
+ Generate .samples files, each listing paths to the feature files created in Step 4. These files act as a link between the features and the training pipeline.
+ ```bash
+ # Generate .samples files for network training
+ python create_sample_files.py
+ ```
+
+ ## **Folder Structure After Preprocessing**
+
+ After completing the above steps, the resulting folder structure should be organized as follows:
+ ```bash
+ ├── test_data
+ │   ├── datasets
+ │   │   ├── PECAN_TRAIN.samples
+ │   │   ├── PECAN_TRAIN.proteins
+ │   │   ├── PECAN_VAL.proteins
+ │   │   ├── PECAN_TEST.proteins
+ │   │   └── ...
+ │   ├── feats
+ │   │   ├── PECAN_22
+ │   │   ├── Paragraph_Expanded_22
+ │   │   └── MIPE_22
+ │   ├── surfpoints
+ │   │   ├── PECAN
+ │   │   │   └── TRAIN
+ │   │   ├── Paragraph_Expanded
+ │   │   │   └── TRAIN
+ │   │   └── MIPE
+ │   │       └── TRAIN
+ │   └── pdbs   # already created from ParaSurf/create_datasets_from_csv
+ ```
+
+ Now we are ready for training!
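
Steps 1–6 are separate scripts, each with its input/output paths configured at the top of the file. If you prefer to run the whole preprocessing pass in one go, a simple wrapper like the following is one option (a sketch only; it assumes the scripts are already configured for your paths and that it is executed from ParaSurf/preprocess):

```python
import subprocess

# Order matters: each step consumes the output of the previous one.
steps = [
    "clean_dataset.py",          # Step 1: clean the antibody-antigen complexes
    "check_rec_ant_touch.py",    # Step 2: sanity-check receptor-antigen contacts
    "create_surfpoints.py",      # Step 3: DMS molecular surface points
    "create_input_features.py",  # Step 4: 41x41x41x22 feature grids
    "create_proteins_file.py",   # Step 5: .proteins files
    "create_sample_files.py",    # Step 6: .samples files
]

for script in steps:
    print(f"Running {script} ...")
    subprocess.run(["python", script], check=True)  # stop on the first failing step
```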
ParaSurf/preprocess/__pycache__/check_empty_features.cpython-310.pyc ADDED
Binary file (2.72 kB).
 
ParaSurf/preprocess/__pycache__/check_empty_features.cpython-39.pyc ADDED
Binary file (2.53 kB).
 
ParaSurf/preprocess/__pycache__/clean_dataset.cpython-310.pyc ADDED
Binary file (1.04 kB).
 
ParaSurf/preprocess/__pycache__/clean_dataset.cpython-39.pyc ADDED
Binary file (1.01 kB).
 
ParaSurf/preprocess/check_empty_features.py ADDED
@@ -0,0 +1,68 @@
+ import os
+
+ def remove_empty_features(feats_folder, pdbs_path, surf_path, log_file_path="removed_complexes_log.txt"):
+     """
+     Checks each subfolder in the features folder for files. If a subfolder is empty, removes it along with the
+     associated files from `pdbs_path` and `surf_path` and logs the removals.
+
+     Parameters:
+     - feats_folder (str): The main directory containing subfolders with the features to check.
+     - pdbs_path (str): Path where receptor and antigen PDB files are located.
+     - surf_path (str): Path where surface points files are located.
+     - log_file_path (str): Path for the log file to track removed folders and files. Default is 'removed_complexes_log.txt'.
+
+     Returns:
+     - total_empty_folders (int): Count of empty folders removed.
+     """
+
+     # Identify all subfolders in the base folder
+     subfolders = [d for d in os.listdir(feats_folder) if os.path.isdir(os.path.join(feats_folder, d))]
+     empty_folders = []
+
+     # Open log file to record removed folders
+     with open(log_file_path, 'w') as log_file:
+         log_file.write("Log of Removed Folders and Files\n")
+         log_file.write("=" * 30 + "\n")
+
+         # Check each subfolder and remove if empty
+         for folder in subfolders:
+             path = os.path.join(feats_folder, folder)
+             if not any(os.path.isfile(os.path.join(path, i)) for i in os.listdir(path)):
+                 empty_folders.append(folder)
+                 pdb_code = folder.split('_')[0]
+
+                 # Define paths to the files to be removed
+                 rec_file = os.path.join(pdbs_path, pdb_code + '_receptor_1.pdb')
+                 antigen_file = os.path.join(pdbs_path, pdb_code + '_antigen_1_1.pdb')
+                 surf_file = os.path.join(surf_path, pdb_code + '_receptor_1.surfpoints')
+
+                 # Remove the empty folder and associated files
+                 os.rmdir(path)
+                 if os.path.exists(rec_file):
+                     os.remove(rec_file)
+                 if os.path.exists(antigen_file):
+                     os.remove(antigen_file)
+                 if os.path.exists(surf_file):
+                     os.remove(surf_file)
+
+                 # Log each removal
+                 log_file.write(f"{pdb_code} complex removed since no features found.\n")
+
+     total_empty_folders = len(empty_folders)
+     # Delete the log file if no folders were removed
+     if total_empty_folders == 0:
+         os.remove(log_file_path)
+         print("\nAll complexes have features!!!")
+     else:
+         print(f"Total empty folders removed: {total_empty_folders}")
+         print(f"Details logged in {log_file_path}")
+
+     return total_empty_folders
+
+
+ # Example usage
+ if __name__ == '__main__':
+     user = os.getenv('USER')
+     pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN'
+     surf_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/surf_points/eraseme/TRAIN'
+     feats_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/feats/eraseme_22'
+     remove_empty_features(feats_path, pdbs_path, surf_path)
ParaSurf/preprocess/check_rec_ant_touch.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ import math
+ from tqdm import tqdm
+
+
+ def locate_receptor_binding_site_atoms(receptor_pdb_file, antigen_pdb_file, distance_cutoff=4):
+     rec_coordinates = []
+     with open(receptor_pdb_file, 'r') as file:
+         for line in file:
+             if line.startswith("ATOM"):
+                 x = float(line[30:38].strip())
+                 y = float(line[38:46].strip())
+                 z = float(line[46:54].strip())
+                 rec_coordinates.append((x, y, z))
+
+     ant_coordinates = []
+     with open(antigen_pdb_file, 'r') as file:
+         for line in file:
+             if line.startswith("ATOM"):
+                 x = float(line[30:38].strip())
+                 y = float(line[38:46].strip())
+                 z = float(line[46:54].strip())
+                 ant_coordinates.append((x, y, z))
+
+     # Create a list to store the final coordinates
+     final_coordinates = []
+
+     # Compare each coordinate from rec_coordinates with each coordinate from ant_coordinates
+     for rec_coord in rec_coordinates:
+         for ant_coord in ant_coordinates:
+             if math.dist(rec_coord, ant_coord) < distance_cutoff:
+                 final_coordinates.append(rec_coord)
+                 break  # Break the inner loop if a match is found to avoid duplicate entries
+
+     # sanity check
+     for coor in final_coordinates:
+         if coor not in rec_coordinates:
+             print('BINDING SITE COORDINATE NOT IN RECEPTOR COORDINATES!!!!!!')
+     return final_coordinates, rec_coordinates
+
+
+ def check_receptor_antigen_interactions(pdb_dir, distance_cutoff=6, log_file="interaction_issues.txt"):
+     """
+     :param pdb_dir: directory with receptor and antigen pdb files
+     :param distance_cutoff: the distance cutoff for the binding site
+     :param log_file: the file where issues will be logged
+     :return: checks whether each receptor and its antigen are in contact with each other
+     """
+     all_successful = True  # A flag to track if all pairs are correct
+
+     # Open the log file for writing
+     with open(log_file, 'w') as log:
+         log.write("Receptor-Antigen Interaction Issues Log\n")
+         log.write("=====================================\n")
+
+         non_interacting_pdbs = 0
+         for pdb_file in tqdm(os.listdir(pdb_dir)):
+             pdb_id = pdb_file.split('_')[0]
+             cur_rec_pdb = os.path.join(pdb_dir, f'{pdb_id}_receptor_1.pdb')
+             cur_ant_pdb = os.path.join(pdb_dir, f'{pdb_id}_antigen_1_1.pdb')
+
+             if os.path.exists(cur_rec_pdb) and os.path.exists(cur_ant_pdb):
+                 final, rec = locate_receptor_binding_site_atoms(cur_rec_pdb, cur_ant_pdb, distance_cutoff)
+                 if len(final) == 0:
+                     non_interacting_pdbs += 1
+                     log.write(f'\nNON-INTERACTING PAIR!!!: problem with {pdb_id}.pdb. {pdb_id}_receptor_1.pdb and '
+                               f'{pdb_id}_antigen_1_1.pdb files are removed.\n')
+                     os.remove(cur_rec_pdb)
+                     os.remove(cur_ant_pdb)
+                     all_successful = False  # Mark as unsuccessful if any issue is found
+
+         # Check if everything was successful
+         if all_successful:
+             print("Success! All receptors interact with their associated antigens.")
+             # since no issues were found we can remove the log file
+             os.remove(log_file)
+         else:
+             print(f'\n ~~~~~ Total pdbs found with issues: {non_interacting_pdbs}; they were removed from the folder ~~~~~\n')
+             log.write(f'\n\n ~~~~~ Total pdbs found with issues: {non_interacting_pdbs} ~~~~~')
+             print(f"Some receptors do not interact with their antigens. Issues logged in {log_file}.")
+
+
+ # example usage
+ if __name__ == '__main__':
+     user = os.getenv('USER')
+     pdb_dir = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN'
+     index = pdb_dir.split('/')[-1]
+     check_receptor_antigen_interactions(pdb_dir, distance_cutoff=4.5, log_file=f'{pdb_dir}/{index}_interaction_issues.txt')
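
The nested loop in locate_receptor_binding_site_atoms compares every receptor atom against every antigen atom, which is O(N·M) per complex. Purely for illustration (this helper is not part of the repository; the SciPy k-d tree approach is a swapped-in alternative), the same contact test can be vectorized:

```python
import numpy as np
from scipy.spatial import cKDTree

def contact_atoms_kdtree(rec_coords, ant_coords, cutoff=4.5):
    """Sketch: return the receptor atoms lying within `cutoff` Å of any antigen atom.

    rec_coords / ant_coords are (N, 3) and (M, 3) coordinate arrays, parsed the same
    way as in locate_receptor_binding_site_atoms above.
    """
    rec = np.asarray(rec_coords, dtype=float)
    tree = cKDTree(np.asarray(ant_coords, dtype=float))  # index antigen atoms once
    dists, _ = tree.query(rec)                           # nearest antigen atom per receptor atom
    mask = dists < cutoff                                # contact if closer than the cutoff
    return rec[mask], mask
```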
ParaSurf/preprocess/clean_dataset.py ADDED
@@ -0,0 +1,27 @@
+ import os
+ from ParaSurf.utils.remove_hydrogens_from_pdb import remove_hydrogens_from_pdb_folder
+ from ParaSurf.utils.remove_HETATMS_from_receptors import remove_hetatm_from_pdb_folder
+ from ParaSurf.utils.reaarange_atom_id import process_pdb_files_in_folder
+
+
+ def clean_dataset(dataset_path_with_pdbs):
+     """
+     :param dataset_path_with_pdbs: folder with the raw PDB complexes
+     :return: a cleaned dataset ready to be processed for training purposes, after 3 filtering steps
+     """
+     data_path = dataset_path_with_pdbs
+
+     # step 1: remove hydrogens
+     remove_hydrogens_from_pdb_folder(input_folder=data_path,
+                                      output_folder=data_path)
+
+     # step 2: remove HETATMs, only from the receptors
+     remove_hetatm_from_pdb_folder(input_folder=data_path,
+                                   output_folder=data_path)
+
+     # step 3: re-arrange the atom_id of each pdb
+     process_pdb_files_in_folder(folder_path=data_path)
+
+
+ if __name__ == "__main__":
+     user = os.getenv('USER')
+     clean_dataset(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN')
ParaSurf/preprocess/create_input_features.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing import Pool
2
+ from multiprocessing import Lock
3
+ import time, os
4
+ import numpy as np
5
+ from Bio.PDB.PDBParser import PDBParser
6
+ from ParaSurf.utils.bsite_lib import readSurfpoints, readSurfpoints_with_residues, dist_point_from_lig
7
+ from ParaSurf.utils.features import KalasantyFeaturizer, KalasantyFeaturizer_with_force_fields
8
+ from scipy import sparse
9
+ from tqdm import tqdm
10
+ import warnings
11
+ from ParaSurf.utils.distance_coords import locate_receptor_binding_site_residues
12
+ from check_empty_features import remove_empty_features
13
+
14
+ # Ignore warnings
15
+ warnings.filterwarnings('ignore')
16
+
17
+
18
+ lock = Lock() # Instantiate a Lock for thread safety.
19
+
20
+
21
+ def balanced_sampling(surf_file, protein_file, lig_files, cutoff=4.5):
22
+ """
23
+ Returns a subset of equal positive and negative samples from surface points in `surf_file`, with positive samples
24
+ selected from residues close to the antigen.
25
+
26
+ Parameters:
27
+ - surf_file: Path to the file with protein surface points.
28
+ - protein_file: Path to the protein structure file (e.g., PDB format).
29
+ - lig_files: List of ligand (antigen) structure file paths.
30
+ - cutoff: Distance cutoff in Ångstroms for defining binding residues (default is 4).
31
+
32
+ Returns:
33
+ - Balanced samples including features and labels for the selected surface points.
34
+ """
35
+ all_lig_coords = []
36
+ for lig_file in lig_files:
37
+ with lock: # Locks the parser to ensure thread safety.
38
+ lig = parser.get_structure('antigen', lig_file)
39
+ lig_coords = np.array([atom.get_coord() for atom in lig.get_atoms()])
40
+ all_lig_coords.append(lig_coords)
41
+
42
+ points, normals = readSurfpoints(surf_file) # modified by me
43
+ # create the residue groups for the whole protein
44
+ all_rec_residues = readSurfpoints_with_residues(surf_file)
45
+
46
+
47
+ # find the bind site residues
48
+ bind_site_rec_residues = locate_receptor_binding_site_residues(protein_file, lig_file, distance_cutoff=cutoff)
49
+
50
+ # gather all distances
51
+ # Create an array to store the minimum distance of each point to any ligand atom
52
+ dist_from_lig = np.full(len(points), np.inf)
53
+ near_lig = np.zeros(len(points), dtype=bool)
54
+
55
+ # Update distances only for points in bind site residues
56
+ bind_site_indices = [item for i in bind_site_rec_residues for item in all_rec_residues[i]['idx']]
57
+ bind_site_indices_set = set(bind_site_indices) # Convert to set for fast lookup
58
+
59
+ # Loop through ligand coordinates and update distances for binding site points
60
+ # IMPORTANT Step because if a residue belongs to the binding site that DOES NOT mean that all the atom of this
61
+ # residue belongs the binding site (<6 armstrong to the ligand). So here we check from the binding site residues which
62
+ # atoms actually bind (<6 armstrong to the ligand)
63
+ for lig_coords in all_lig_coords:
64
+ for i, p in enumerate(points):
65
+ if i in bind_site_indices_set:
66
+ dist = dist_point_from_lig(p, lig_coords) # Adjust this function if necessary
67
+ if dist < dist_from_lig[i]:
68
+ dist_from_lig[i] = dist
69
+ near_lig[i] = dist < cutoff
70
+
71
+ # Filter positive indices to include only those near a ligand
72
+ pos_idxs = np.array([idx for idx in bind_site_indices if near_lig[idx]])
73
+
74
+ # If there are more positive indices than allowed, select the best ones based on the distance
75
+ if len(pos_idxs) > maxPosSamples:
76
+ pos_idxs = pos_idxs[np.argsort(dist_from_lig[pos_idxs])[:maxPosSamples]]
77
+
78
+ # Select the negative samples
79
+ all_neg_samples = [idx for idx, i in enumerate(points) if idx not in pos_idxs]
80
+
81
+ # Calculate number of negative samples to match the number of positive samples
82
+ num_neg_samples = min(len(all_neg_samples), len(pos_idxs))
83
+
84
+ neg_idxs = np.array(all_neg_samples)
85
+ if len(neg_idxs) > num_neg_samples:
86
+ neg_downsampled = np.random.choice(neg_idxs, num_neg_samples, replace=False)
87
+ else:
88
+ neg_downsampled = neg_idxs
89
+
90
+ # Concatenate positive and negative indices
91
+ sample_idxs = np.concatenate((pos_idxs, neg_downsampled))
92
+
93
+ # Shuffle the indices to ensure randomness
94
+ np.random.shuffle(sample_idxs)
95
+
96
+ # create the sample labels
97
+ # Convert pos_idxs to a set for faster membership testing
98
+ pos_set = set(pos_idxs)
99
+
100
+ # Use list comprehension to create labels
101
+ sample_labels = [i in pos_set for i in sample_idxs]
102
+ if feat_type == 'kalasanty':
103
+ featurizer = KalasantyFeaturizer(protein_file, protonate, gridSize, voxelSize, use_protrusion, protr_radius)
104
+ elif feat_type == 'kalasanty_with_force_fields':
105
+ featurizer = KalasantyFeaturizer_with_force_fields(protein_file, protonate, gridSize, voxelSize, use_protrusion, protr_radius,
106
+ add_atom_radius_features=add_atoms_radius_ff_features)
107
+
108
+ feature_vector_length = featurizer.channels.shape[1]
109
+ with open(feature_vector_length_tmp_path, 'w') as file:
110
+ file.write(str(feature_vector_length))
111
+
112
+ for i, sample in enumerate(sample_idxs):
113
+ features = featurizer.grid_feats(points[sample], normals[sample], rotate_grid)
114
+ if np.count_nonzero(features) == 0:
115
+ print('Zero features', protein_file.rsplit('/', 1)[1][:-4], i, points[sample], normals[sample])
116
+
117
+ yield features, sample_labels[i], points[sample], normals[sample]
118
+
119
+
120
+ def samples_per_prot(prot):
121
+ """
122
+ Generates and saves balanced surface point samples for a given protein.
123
+
124
+ Parameters:
125
+ - prot: Protein identifier for which features are being generated.
126
+
127
+ Saves each sample as a sparse matrix or NumPy array in `feats_path`.
128
+ """
129
+ prot_path = os.path.join(feats_path, prot)
130
+
131
+ # Check if directory exists and has files, if not create it
132
+ if not os.path.exists(prot_path):
133
+ os.makedirs(prot_path)
134
+ elif os.listdir(prot_path):
135
+ return
136
+
137
+ surf_file = os.path.join(surf_path, f"{prot}.surfpoints")
138
+ protein_file = os.path.join(pdbs_path, f"{prot}.pdb")
139
+ if not os.path.exists(protein_file):
140
+ protein_file = os.path.join(pdbs_path, f"{prot}.mol2")
141
+
142
+ receptor_id = prot_path.split('_')[-1]
143
+ antigen_prefix = prot.split('_')[0]
144
+
145
+ # Using set for faster membership checks
146
+ files_set = set(os.listdir(pdbs_path))
147
+ lig_files = [os.path.join(pdbs_path, f) for f in files_set if f"{antigen_prefix}_antigen_{receptor_id}" in f]
148
+
149
+ try:
150
+ cnt = 0
151
+ for features, y, point, normal in balanced_sampling(surf_file, protein_file, lig_files, cutoff=cutoff):
152
+ samples_file_name = os.path.join(prot_path, f"sample{cnt}_{int(y)}")
153
+
154
+ if feat_type == 'deepsite':
155
+ with open(f"{samples_file_name}.npy", 'w') as f:
156
+ np.save(f, (point, normal, features.astype(np.float16)))
157
+ elif feat_type == 'kalasanty' or feat_type == 'kalasanty_with_force_fields':
158
+ sparse_mat = sparse.coo_matrix(features.flatten())
159
+ sparse.save_npz(f"{samples_file_name}.npz", sparse_mat)
160
+
161
+ cnt += 1
162
+
163
+ print(f'Saved "{cnt}" samples for "{prot}".')
164
+
165
+ except Exception as e:
166
+ print(f'Exception occurred while processing "{prot}". Error message: "{e}".')
167
+
168
+
169
+ seed = 10 # random seed
170
+ num_cores = 6 # Set this to the number of cores you wish to use
171
+ maxPosSamples = 800 # maximum number of positive samples per protein
172
+ gridSize = 41 # size of grid (16x16x16)
173
+ voxelSize = 1 # size of voxel, e.g. 1 angstrom, if 2A we lose details, so leave it to 1
174
+ cutoff = 4.5 # cutoff threshold in Armstrong's 6 for general PPIs, 4.5 for antibody antigen databases
175
+ feature_vector_length_tmp_path = '/tmp/feature_vector_length.txt'
176
+ # feat_type = 'kalasanty' # select featurizer
177
+ feat_type = 'kalasanty_with_force_fields'
178
+ add_atoms_radius_ff_features = True # If you want to add the atom radius features that correspond to the force fields
179
+ rotate_grid = True # whether to rotate the grid (ignore)
180
+ use_protrusion = False # ignore
181
+ protr_radius = 10 # ignore
182
+ protonate = True # if protein pdbs are not protonated (do not have Hydrogens) set it to True
183
+
184
+ user = os.getenv('USER')
185
+ pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN' # input folder with protein pdbs for training
186
+ surf_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/surfpoints/eraseme/TRAIN' # input folder with surface points for training
187
+ feats_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/feats/eraseme' # training features folder
188
+
189
+
190
+ if not os.path.exists(feats_path):
191
+ os.makedirs(feats_path)
192
+
193
+ np.random.seed(seed)
194
+
195
+ all_proteins = [f.rsplit('.', 1)[0] for f in os.listdir(surf_path)]
196
+ #
197
+ # in case the procedure stacks use the 3 lines below
198
+ completed = [f.rsplit('.', 1)[0] for f in os.listdir(feats_path)]
199
+ all_proteins = [i for i in all_proteins if i not in completed]
200
+ print(len(all_proteins))
201
+
202
+
203
+ parser = PDBParser(PERMISSIVE=1) # PERMISSIVE=1 allows more flexibility in handling non-standard or problematic entries in PDB files during parsing.
204
+
205
+ start = time.time()
206
+ with Pool(num_cores) as pool: # Use a specified number of CPU cores
207
+ list(tqdm(pool.imap(samples_per_prot, all_proteins), total=len(all_proteins)))
208
+ print(f'Total preprocess time: {(time.time() - start)/60} mins')
209
+
210
+ ###################################################################################
211
+ # Instead of using Pool and imap, iterate through all_proteins with a for loop for easy debugging
212
+ # for prot in all_proteins:
213
+ # try:
214
+ # samples_per_prot(prot)
215
+ # except Exception as e:
216
+ # print(f'Error processing protein {prot}: {e}')
217
+ # break
218
+
219
+
220
+ # the number appended to the features folder name is the length of the feature vector
221
+ if os.path.exists(feature_vector_length_tmp_path):
222
+ with open(feature_vector_length_tmp_path, 'r') as file:
223
+ feature_vector_length = int(file.read().strip())
224
+ feats_path_new = f'{feats_path}_{feature_vector_length}'
225
+ os.rename(feats_path, feats_path_new)
226
+ os.remove(feature_vector_length_tmp_path)
227
+
228
+
229
+ # remove empty features if found
230
+ remove_empty_features(feats_folder=feats_path_new, pdbs_path=pdbs_path, surf_path=surf_path)
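For reference, a minimal sketch of loading a saved sample back into its dense grid, assuming gridSize=41, a channels-last layout, and that the channel count matches the number appended to the features folder name (e.g. '_22'); the path below is hypothetical:

    import numpy as np
    from scipy import sparse

    grid_size, n_channels = 41, 22                                       # assumed values
    sample_file = 'feats/eraseme_22/example_receptor_1/sample0_1.npz'    # hypothetical sample path
    flat = sparse.load_npz(sample_file).toarray()                        # features were saved flattened
    grid = flat.reshape(grid_size, grid_size, grid_size, n_channels)     # assumed channels-last grid
    label = int(sample_file.rsplit('_', 1)[-1].split('.')[0])            # class label encoded in the file name
    print(grid.shape, label)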
ParaSurf/preprocess/create_proteins_file.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ # Create a .proteins file listing all the proteins (i.e. receptors) we are working with.
5
+ cases = ['TRAIN', 'VAL', 'TEST'] # change to ['TRAIN_VAL', 'TEST'] for MIPE
6
+ user = os.getenv('USER')
7
+ for case in cases:
8
+ pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/{case}'
9
+ proteins_file = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/datasets/eraseme_{case}.proteins' # run for train val and test
10
+
11
+ # Create directories if they don't exist
12
+ os.makedirs(pdbs_path, exist_ok=True)
13
+ os.makedirs(os.path.dirname(proteins_file), exist_ok=True)
14
+
15
+ receptors = []
16
+
17
+ for prot in os.listdir(pdbs_path):
18
+ prot_name = prot.split('.')[0]
19
+ if 'rec' in prot:
20
+ receptors.append(prot_name + '\n')
21
+
22
+ with open(proteins_file,'w') as f:
23
+ f.writelines(receptors)
ParaSurf/preprocess/create_sample_files.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, random
2
+
3
+
4
+ user = os.getenv('USER')
5
+
6
+ feats_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/feats/eraseme_22' # input folder with protein grids (training features)
7
+ proteins_file = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/datasets/eraseme_TRAIN.proteins' # input file with a list of train proteins
8
+ samples_file = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/datasets/eraseme_TRAIN.samples' # output file with respective training samples info (class_label + sample_path)
9
+ seed = 1
10
+
11
+ with open(proteins_file, 'r') as f:
12
+ proteins = f.readlines()
13
+
14
+ sample_lines = []
15
+ feats_prefix = feats_path.rsplit('/')[-1]
16
+
17
+ for prot in proteins:
18
+ prot = prot[:-1]
19
+ prot_feats_path = os.path.join(feats_path, prot)
20
+ if not os.path.isdir(prot_feats_path):
21
+ print('No features for ', prot)
22
+ continue
23
+ for sample in os.listdir(prot_feats_path):
24
+ cls_idx = sample[-5]
25
+ sample_lines.append(cls_idx + ' ' + feats_prefix + '/' + prot + '/' + sample + '\n')
26
+
27
+ random.seed(seed)
28
+ random.shuffle(sample_lines)
29
+
30
+ with open(samples_file, 'w') as f:
31
+ f.writelines(sample_lines)
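Each line of the resulting .samples file pairs a class label with a sample path relative to the features root. An illustrative (hypothetical) excerpt:

    1 eraseme_22/1abc_receptor_1/sample17_1.npz
    0 eraseme_22/1abc_receptor_1/sample42_0.npz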
ParaSurf/preprocess/create_surfpoints.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import os
3
+ from tqdm import tqdm
4
+ from ParaSurf.utils.fix_surfpoints_format_issues import process_surfpoints_directory
5
+
6
+ def generate_molecular_surface(input_path, out_path):
7
+ """
8
+ Generates the molecular surface for protein structures in PDB files using the DMS tool.
9
+
10
+ Parameters:
11
+ - input_path (str): Path to the input directory containing protein PDB files.
12
+ - out_path (str): Path to the output directory where generated surface points files will be saved.
13
+
14
+ Process:
15
+ - The function iterates over receptor PDB files in the input path.
16
+ - For each receptor file, it checks if a corresponding surface points file already exists in the output directory.
17
+ - If the surface points file does not exist, it generates the file using the DMS tool with a density of 0.5 Å.
18
+
19
+ Outputs:
20
+ - Each receptor file generates a surface points file saved in `out_path`.
21
+ """
22
+
23
+ if not os.path.exists(out_path):
24
+ os.makedirs(out_path)
25
+
26
+ start = time.time()
27
+ for f in tqdm(os.listdir(input_path), desc="Generating surface points"):
28
+ if 'antigen' in f:
29
+ continue
30
+
31
+ surfpoints_file = os.path.join(out_path, f[:-3] + 'surfpoints')
32
+ if os.path.exists(surfpoints_file):
33
+ continue
34
+
35
+ print(f"Processing {f}")
36
+ os.system(f'dms {os.path.join(input_path, f)} -d 0.5 -n -o {surfpoints_file}')
37
+
38
+ # Calculate and print statistics
39
+ rec_count = sum(1 for receptor in os.listdir(input_path) if 'receptor' in receptor)
40
+ total_time = (time.time() - start) / 60 # Convert time to minutes
41
+ print(f'Total time to create surfpoints for {rec_count} receptors: {total_time:.2f} mins')
42
+
43
+
44
+ # Example usage
45
+ if __name__ == '__main__':
46
+ user = os.getenv('USER')
47
+ pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN' # input folder with protein pdbs for training
48
+ surfpoints_path =f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/surfpoints/eraseme/TRAIN'
49
+
50
+ # create the molecular surface
51
+ generate_molecular_surface(
52
+ input_path= pdbs_path,
53
+ out_path= surfpoints_path
54
+ )
55
+
56
+ # fix some format issues with the .surfpoints files
57
+ process_surfpoints_directory(surfpoints_path)
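Because the surface generation shells out to the external DMS tool via os.system, a missing binary fails silently; a minimal sketch of a preflight check (not part of the original script):

    import shutil

    # os.system does not raise if the command is missing, so check for DMS explicitly.
    if shutil.which('dms') is None:
        raise RuntimeError('The "dms" executable was not found on PATH; install DMS before creating surface points.')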
ParaSurf/train/V_domain_results.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from utils import calculate_metrics
4
+ import os, re
5
+
6
+
7
+ def calculate_Fv_and_cdr_regions(residues_best, gt_true_label_residues, rec_name, output_path, epoch, test_csv=None,
8
+ thres=0.5):
9
+ # Load the CSV file to identify heavy and light chains, if provided
10
+ heavy_chain_name, light_chain_name = None, None
11
+ calculate_individual_cdrs = False
12
+
13
+ if test_csv:
14
+ test_csv = pd.read_csv(test_csv)
15
+ rec_info = test_csv[test_csv['pdb_code'] == rec_name]
16
+ if not rec_info.empty:
17
+ heavy_chain_name = rec_info['Heavy_chain'].iloc[0]
18
+ light_chain_name = rec_info['Light_chain'].iloc[0]
19
+
20
+ calculate_individual_cdrs = True
21
+ else:
22
+ print(f"Receptor {rec_name} not found in the test CSV.")
23
+
24
+ # Define the CDR+-2 ranges for heavy and light chains
25
+ cdr1 = list(range(25, 41)) # CDR-H/L1: 25-40
26
+ cdr2 = list(range(54, 68)) # CDR-H/L2: 54-67
27
+ cdr3 = list(range(103, 120)) # CDR-H/L3: 103-119
28
+ framework_ranges = list(range(1, 25)) + list(range(41, 54)) + list(range(68, 103)) + list(range(120, 129))
29
+
30
+ # Initialize dictionaries
31
+ CDRH1, CDRH2, CDRH3 = {}, {}, {}
32
+ CDRL1, CDRL2, CDRL3 = {}, {}, {}
33
+ FRAMEWORK = {}
34
+
35
+ # Loop over the predictions to populate the dictionaries
36
+ for residue, data in residues_best.items():
37
+ # Split the residue into components (e.g., '30_L_C' -> ['30', 'L', 'C'])
38
+ residue_parts = residue.split('_')
39
+ residue_num = int(re.findall(r'\d+', residue_parts[0])[0])
40
+ chain_name = residue_parts[1]
41
+
42
+ # Assign residue to the corresponding CDR or FRAMEWORK based on chain and residue number
43
+ if (not heavy_chain_name or chain_name == heavy_chain_name): # If no csv or matching heavy chain
44
+ if residue_num in cdr1:
45
+ CDRH1[residue] = data
46
+ elif residue_num in cdr2:
47
+ CDRH2[residue] = data
48
+ elif residue_num in cdr3:
49
+ CDRH3[residue] = data
50
+ elif residue_num in framework_ranges:
51
+ FRAMEWORK[residue] = data
52
+
53
+ if (not light_chain_name or chain_name == light_chain_name): # If no csv or matching light chain
54
+ if residue_num in cdr1:
55
+ CDRL1[residue] = data
56
+ elif residue_num in cdr2:
57
+ CDRL2[residue] = data
58
+ elif residue_num in cdr3:
59
+ CDRL3[residue] = data
60
+ elif residue_num in framework_ranges:
61
+ FRAMEWORK[residue] = data
62
+
63
+ # Helper function to calculate and save metrics for each CDR and FRAMEWORK
64
+ def calculate_and_save_metrics(cdr_dict, cdr_name, threshold=thres):
65
+ if len(cdr_dict) > 0: # To check if CDR exists in the antibody
66
+ pred_scores = np.array([[i[1]['scores']] for i in cdr_dict.items()])
67
+ pred_labels = (pred_scores > threshold).astype(int)
68
+ gt_labels = np.array([1 if residue in gt_true_label_residues else 0 for residue in cdr_dict.keys()])
69
+
70
+ if len(np.unique(gt_labels)) > 1: # Ensure both classes are present
71
+ output_results_path = os.path.join(output_path, f'{cdr_name}_results_epoch_{epoch}_{threshold}.txt')
72
+ auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc, _, _, _ = \
73
+ calculate_metrics(gt_labels, pred_labels, pred_scores, to_save_metrics_path=output_results_path)
74
+ return auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc
75
+ return None
76
+
77
+ # Calculate and save metrics for each CDR and FRAMEWORK only if .csv is provided
78
+ if calculate_individual_cdrs:
79
+ calculate_and_save_metrics(CDRH1, 'CDRH1')
80
+ calculate_and_save_metrics(CDRH2, 'CDRH2')
81
+ calculate_and_save_metrics(CDRH3, 'CDRH3')
82
+ calculate_and_save_metrics(CDRL1, 'CDRL1')
83
+ calculate_and_save_metrics(CDRL2, 'CDRL2')
84
+ calculate_and_save_metrics(CDRL3, 'CDRL3')
85
+ calculate_and_save_metrics(FRAMEWORK, 'FRAMEWORK')
86
+
87
+ # Calculate the metrics for the CDR+-2 region (CDRH1 + CDRH2 + CDRH3 + CDRL1 + CDRL2 + CDRL3)
88
+ cdr_plus_minus_2 = {**CDRH1, **CDRH2, **CDRH3, **CDRL1, **CDRL2, **CDRL3}
89
+ calculate_and_save_metrics(cdr_plus_minus_2, 'CDR_plus_minus_2')
90
+
91
+ # Calculate the metrics for the Fv region (CDRs + FRAMEWORK)
92
+ fv_region = {**CDRH1, **CDRH2, **CDRH3, **CDRL1, **CDRL2, **CDRL3, **FRAMEWORK}
93
+ calculate_and_save_metrics(fv_region, 'Fv')
94
+
95
+
96
+ def calculate_Fv_and_cdr_regions_only_one_chain(residues_best, gt_true_label_residues, rec_name, output_path, epoch, thres=0.5):
97
+ """
98
+ This function calculates metrics for the Fv and CDR+-2 regions, but only for a PDB file with one chain.
99
+ The CSV file is not needed in this case, as there is only one chain.
100
+
101
+ Args:
102
+ - residues_best: Dictionary containing residue information.
103
+ - gt_true_label_residues: List of ground truth binding residues.
104
+ - rec_name: Name of the receptor.
105
+ - output_path: Directory to save output results.
106
+ - epoch: Current epoch for model validation.
107
+ - thres: Threshold for classification (default is 0.5).
108
+
109
+ Returns:
110
+ - Metrics calculated and saved for CDR+-2 and Fv regions.
111
+ """
112
+
113
+ # Define the CDR+-2 and framework ranges for the single chain
114
+ cdr1 = list(range(25, 41)) # CDR1: 25-40
115
+ cdr2 = list(range(54, 68)) # CDR2: 54-67
116
+ cdr3 = list(range(103, 120)) # CDR3: 103-119
117
+ framework_ranges = list(range(1, 25)) + list(range(41, 54)) + list(range(68, 103)) + list(range(120, 129))
118
+
119
+ # Initialize dictionaries
120
+ CDR1, CDR2, CDR3 = {}, {}, {}
121
+ FRAMEWORK = {}
122
+
123
+ # Loop over the predictions to populate the dictionaries
124
+ for residue, data in residues_best.items():
125
+ # Split the residue into components (e.g., '30_L_C' -> ['30', 'L', 'C'])
126
+ residue_parts = residue.split('_')
127
+ residue_num = int(re.findall(r'\d+', residue_parts[0])[0])
128
+
129
+ # Assign residue to the corresponding CDR or FRAMEWORK based on residue number
130
+ if residue_num in cdr1:
131
+ CDR1[residue] = data
132
+ elif residue_num in cdr2:
133
+ CDR2[residue] = data
134
+ elif residue_num in cdr3:
135
+ CDR3[residue] = data
136
+ elif residue_num in framework_ranges:
137
+ FRAMEWORK[residue] = data
138
+
139
+ # Helper function to calculate and save metrics for each CDR and FRAMEWORK
140
+ def calculate_and_save_metrics(cdr_dict, cdr_name, threshold=thres):
141
+ if len(cdr_dict) > 0: # Check if CDR exists in the antibody
142
+ pred_scores = np.array([[i[1]['scores']] for i in cdr_dict.items()])
143
+ pred_labels = (pred_scores > threshold).astype(int)
144
+ gt_labels = np.array([1 if residue in gt_true_label_residues else 0 for residue in cdr_dict.keys()])
145
+
146
+ if len(np.unique(gt_labels)) > 1: # Ensure both classes are present
147
+ output_results_path = os.path.join(output_path, f'{cdr_name}_results_epoch_{epoch}_{threshold}.txt')
148
+ auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc, _, _, _ = \
149
+ calculate_metrics(gt_labels, pred_labels, pred_scores, to_save_metrics_path=output_results_path)
150
+ return auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc
151
+ return None
152
+
153
+ # Calculate the metrics for the CDR+-2 region (CDR1 + CDR2 + CDR3)
154
+ cdr_plus_minus_2 = {**CDR1, **CDR2, **CDR3}
155
+ calculate_and_save_metrics(cdr_plus_minus_2, 'CDR_plus_minus_2')
156
+
157
+ # Calculate the metrics for the Fv region (CDR1 + CDR2 + CDR3 + FRAMEWORK)
158
+ fv_region = {**CDR1, **CDR2, **CDR3, **FRAMEWORK}
159
+ calculate_and_save_metrics(fv_region, 'Fv')
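A minimal usage sketch, assuming it runs alongside V_domain_results.py so the function above is in scope; residue keys follow the '<residue_number>_<chain>' convention produced by receptor_info() in ParaSurf/train/utils.py, and every name and value below is illustrative:

    import os

    residues_best = {'105_H': {'scores': 0.91},   # residue 105 falls in the CDR3 range (103-119)
                     '31_L': {'scores': 0.12}}    # residue 31 falls in the CDR1 range (25-40)
    gt_true_label_residues = ['105_H']            # ground-truth binding residues
    output_path = '/tmp/parasurf_cdr_results'     # hypothetical output folder
    os.makedirs(output_path, exist_ok=True)

    # Without a test CSV only the pooled CDR+-2 and Fv metrics are written.
    calculate_Fv_and_cdr_regions(residues_best, gt_true_label_residues,
                                 rec_name='1abc', output_path=output_path,
                                 epoch=1, test_csv=None, thres=0.5)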
ParaSurf/train/__pycache__/V_domain_results.cpython-310.pyc ADDED
Binary file (4.54 kB). View file
 
ParaSurf/train/__pycache__/V_domain_results.cpython-39.pyc ADDED
Binary file (4.81 kB). View file
 
ParaSurf/train/__pycache__/bsite_extraction.cpython-310.pyc ADDED
Binary file (1.82 kB). View file
 
ParaSurf/train/__pycache__/bsite_extraction.cpython-39.pyc ADDED
Binary file (1.79 kB). View file
 
ParaSurf/train/__pycache__/distance_coords.cpython-310.pyc ADDED
Binary file (4.55 kB). View file
 
ParaSurf/train/__pycache__/distance_coords.cpython-39.pyc ADDED
Binary file (4.6 kB). View file
 
ParaSurf/train/__pycache__/features.cpython-310.pyc ADDED
Binary file (1.91 kB). View file
 
ParaSurf/train/__pycache__/features.cpython-39.pyc ADDED
Binary file (1.9 kB). View file
 
ParaSurf/train/__pycache__/network.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
ParaSurf/train/__pycache__/network.cpython-39.pyc ADDED
Binary file (1.94 kB). View file
 
ParaSurf/train/__pycache__/protein.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
ParaSurf/train/__pycache__/protein.cpython-39.pyc ADDED
Binary file (4.5 kB). View file
 
ParaSurf/train/__pycache__/utils.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
ParaSurf/train/__pycache__/utils.cpython-39.pyc ADDED
Binary file (14 kB). View file
 
ParaSurf/train/__pycache__/validation.cpython-310.pyc ADDED
Binary file (7.12 kB). View file
 
ParaSurf/train/__pycache__/validation.cpython-39.pyc ADDED
Binary file (7.19 kB). View file
 
ParaSurf/train/bsite_extraction.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.cluster import MeanShift
3
+
4
+
5
+ class Bsite_extractor():
6
+ def __init__(self, lig_thres=0.9, bw=15):
7
+ self.T = lig_thres
8
+ self.ms = MeanShift(bandwidth=bw,bin_seeding=True,cluster_all=False,n_jobs=4)
9
+
10
+ def _cluster_points(self,prot,lig_scores):
11
+ T_new = self.T
12
+ while sum(lig_scores>=T_new) < 10 and T_new>0.3001: # at least 10 points with prob>P and P>=0.3
13
+ T_new -= 0.1
14
+
15
+ # filtered_points = prot.surf_points[lig_scores>T_new]
16
+ filtered_points = prot.surf_points[lig_scores.flatten() > T_new]
17
+ filtered_scores = lig_scores[lig_scores>T_new]
18
+ if len(filtered_points)<5:
19
+ return ()
20
+
21
+ clustering = self.ms.fit(filtered_points)
22
+ labels = clustering.labels_
23
+
24
+ unique_l,freq = np.unique(labels,return_counts=True)
25
+
26
+ if len(unique_l[freq>=5])!=0:
27
+ unique_l = unique_l[freq>=5] # keep clusters with 5 points and more
28
+ else:
29
+ return ()
30
+
31
+ if unique_l[0]==-1: # discard the "unclustered" cluster
32
+ unique_l = unique_l[1:]
33
+
34
+ clusters = [(filtered_points[labels==l],filtered_scores[labels==l]) for l in unique_l]
35
+
36
+ return clusters
37
+
38
+ def extract_bsites(self,prot,lig_scores):
39
+ clusters = self._cluster_points(prot,lig_scores)
40
+ if len(clusters)==0:
41
+ print('No binding site found!!!')
42
+ return
43
+ for cluster in clusters:
44
+ prot.add_bsite(cluster)
45
+ prot.sort_bsites()
46
+ prot.write_bsites()
47
+
48
+
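A short usage sketch: lig_scores would normally come from Network.get_lig_scores; the random scores and file paths below are only illustrative of the expected shapes.

    import numpy as np
    from ParaSurf.train.bsite_extraction import Bsite_extractor
    from ParaSurf.train.protein import Protein_pred

    prot = Protein_pred('/path/to/receptor.pdb', save_path='/tmp/parasurf_pockets')  # hypothetical paths
    lig_scores = np.random.rand(len(prot.surf_points), 1)   # placeholder: one score per surface point

    extractor = Bsite_extractor(lig_thres=0.9, bw=15)
    extractor.extract_bsites(prot, lig_scores)   # writes centers.txt and pocket*.pdb under prot.save_path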
ParaSurf/train/distance_coords.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import re
3
+ import numpy as np
4
+
5
+ def locate_surface_binding_site_atoms(receptor_surf_file, antigen_pdb_file, distance_cutoff=4):
6
+ rec_coordinates = []
7
+ with open(receptor_surf_file, 'r') as file:
8
+ for line in file:
9
+ parts = line.split()
10
+
11
+ # Check for the presence of a numeric value in the 3rd element of parts
12
+ match = re.search(r'([-+]?\d*\.\d+|\d+)(?=\.)', parts[2])
13
+ if match:
14
+ numeric_value = match.group(0)
15
+ non_numeric_value = parts[2].replace(numeric_value, "")
16
+
17
+ # Update the 'parts' list
18
+ parts[2:3] = [non_numeric_value, numeric_value]
19
+
20
+ if len(parts) >= 7: # Since we added an extra element to parts, its length increased by 1
21
+ x = float(parts[3])
22
+ y = float(parts[4])
23
+ z = float(parts[5])
24
+ rec_coordinates.append((x, y, z))
25
+
26
+ ant_coordinates = []
27
+ with open(antigen_pdb_file, 'r') as file:
28
+ for line in file:
29
+ if line.startswith("ATOM"):
30
+ x = float(line[30:38].strip())
31
+ y = float(line[38:46].strip())
32
+ z = float(line[46:54].strip())
33
+ ant_coordinates.append((x, y, z))
34
+
35
+ # Create a list to store the final coordinates
36
+ final_coordinates = []
37
+
38
+ # Compare each coordinate from rec_coordinates with each coordinate from ant_coordinates
39
+ for rec_coord in rec_coordinates:
40
+ for ant_coord in ant_coordinates:
41
+ if math.dist(rec_coord, ant_coord) < distance_cutoff:
42
+ final_coordinates.append(rec_coord)
43
+ break # Break the inner loop if a match is found to avoid duplicate entries
44
+
45
+ # sanity check
46
+ for coor in final_coordinates:
47
+ if coor not in rec_coordinates:
48
+ print('BINDING SITE COORDINATE NOT IN RECEPTORs COORDINATES!!!!!!')
49
+
50
+ return final_coordinates, rec_coordinates
51
+
52
+
53
+ def locate_receptor_binding_site_atoms(receptor_pdb_file, antigen_pdb_file, distance_cutoff=4):
54
+ rec_coordinates = []
55
+ with open(receptor_pdb_file, 'r') as file:
56
+ for line in file:
57
+ if line.startswith("ATOM"):
58
+ x = float(line[30:38].strip())
59
+ y = float(line[38:46].strip())
60
+ z = float(line[46:54].strip())
61
+ rec_coordinates.append((x, y, z))
62
+
63
+ ant_coordinates = []
64
+ with open(antigen_pdb_file, 'r') as file:
65
+ for line in file:
66
+ if line.startswith("ATOM"):
67
+ x = float(line[30:38].strip())
68
+ y = float(line[38:46].strip())
69
+ z = float(line[46:54].strip())
70
+ ant_coordinates.append((x, y, z))
71
+
72
+ # Create a list to store the final coordinates
73
+ final_coordinates = []
74
+
75
+ # Compare each coordinate from rec_coordinates with each coordinate from ant_coordinates
76
+ for rec_coord in rec_coordinates:
77
+ for ant_coord in ant_coordinates:
78
+ if math.dist(rec_coord, ant_coord) < distance_cutoff:
79
+ final_coordinates.append(rec_coord)
80
+ break # Break the inner loop if a match is found to avoid duplicate entries
81
+
82
+ # sanity check
83
+ for coor in final_coordinates:
84
+ if coor not in rec_coordinates:
85
+ print('BINDING SITE COORDINATE NOT IN RECEPTORs COORDINATES!!!!!!')
86
+ return final_coordinates, rec_coordinates
87
+
88
+
89
+ def coords2pdb(coordinates, tosavepath):
90
+ with open(tosavepath, 'w') as pdb_file:
91
+ atom_number = 1
92
+ for coord in coordinates:
93
+ x, y, z = coord
94
+ pdb_file.write(f"ATOM {atom_number:5} DUM DUM A{atom_number:4} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n")
95
+
96
+ atom_number += 1
97
+ if atom_number == 9999:
98
+ atom_number = 1
99
+ pdb_file.write("END")
100
+
101
+
102
+ def locate_receptor_binding_site_atoms_residue_level(receptor_file, antigen_pdb_file, distance_cutoff=4):
103
+ rec_atoms = []
104
+ chain_elements = []
105
+ with open(receptor_file, 'r') as file:
106
+ for line in file:
107
+ if line.startswith("ATOM"):
108
+ atom_id = line[6:11].strip()
109
+ atom_type = line[12:16].strip()
110
+ res_id = line[22:26].strip()
111
+ # check if there is Code for insertions of residues
112
+ insertion_code = line[26].strip()
113
+ if insertion_code:
114
+ res_id = res_id + insertion_code
115
+ res_name = line[17:20].strip()
116
+ chain_id = line[21].strip()
117
+ x = float(line[30:38].strip())
118
+ y = float(line[38:46].strip())
119
+ z = float(line[46:54].strip())
120
+ rec_atoms.append((atom_id, atom_type, res_id, res_name, chain_id, x, y, z))
121
+ chain_elements.append((atom_id, atom_type, res_id, chain_id))
122
+
123
+ ant_atoms = []
124
+ with open(antigen_pdb_file, 'r') as file:
125
+ for line in file:
126
+ if line.startswith("ATOM"):
127
+ atom_id = line[6:11].strip()
128
+ atom_type = line[12:16].strip()
129
+ res_id = line[22:26].strip()
130
+ res_name = line[17:20].strip()
131
+ chain_id = line[21].strip()
132
+ x = float(line[30:38].strip())
133
+ y = float(line[38:46].strip())
134
+ z = float(line[46:54].strip())
135
+ ant_atoms.append((atom_id, atom_type, res_id, res_name, chain_id, x, y, z))
136
+
137
+ final_atoms = []
138
+
139
+ for rec_atom in rec_atoms:
140
+ for ant_atom in ant_atoms:
141
+ if math.dist(rec_atom[5:], ant_atom[5:]) < distance_cutoff:
142
+ final_atoms.append(rec_atom)
143
+ break
144
+
145
+ rec_atoms = np.array([atom[5:] for atom in rec_atoms])
146
+ final_atoms_ = np.array([atom[5:] for atom in final_atoms])
147
+ final_elements = np.array([atom[:5] for atom in final_atoms])
148
+
149
+ return final_atoms_, rec_atoms, final_elements
150
+
151
+
152
+ def coords2pdb_residue_level(coordinates, tosavepath, elements):
153
+ with open(tosavepath, 'w') as pdb_file:
154
+ for i, atom in enumerate(coordinates):
155
+ atom_id, atom_type, res_id, res_name, chain_id = elements[i]
156
+
157
+ # Separate the numeric part from the insertion code (if any)
158
+ if res_id[-1].isalpha(): # Check if the last character is an insertion code
159
+ res_num = res_id[:-1] # Numeric part of the residue
160
+ insertion_code = res_id[-1] # Insertion code (e.g., 'A' in '30A')
161
+ else:
162
+ res_num = res_id
163
+ insertion_code = " " # No insertion code
164
+
165
+ x, y, z = atom
166
+
167
+ # Write to the PDB file with the correct formatting
168
+ pdb_file.write(
169
+ f"ATOM {int(atom_id):5} {atom_type:<4} {res_name} {chain_id}{int(res_num):4}{insertion_code:1} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n")
170
+
171
+ pdb_file.write("END\n")
172
+
173
+
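A minimal sketch (hypothetical file paths): extract the receptor atoms lying within 4.5 A of any antigen atom, as done for the antibody-antigen ground truth, and dump them to a dummy PDB for inspection.

    from ParaSurf.train.distance_coords import (locate_receptor_binding_site_atoms_residue_level,
                                                coords2pdb_residue_level)

    gt_atoms, rec_atoms, gt_elements = locate_receptor_binding_site_atoms_residue_level(
        '/path/to/receptor.pdb', '/path/to/antigen.pdb', distance_cutoff=4.5)
    coords2pdb_residue_level(gt_atoms, '/tmp/gt_binding_site.pdb', gt_elements)
    print(f'{len(gt_atoms)} of {len(rec_atoms)} receptor atoms lie within the cutoff')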
ParaSurf/train/features.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ParaSurf.utils import bio_data_featurizer
2
+ from ParaSurf.train.utils import rotation
3
+ import numpy as np
4
+
5
+
6
+ class KalasantyFeaturizer:
7
+ def __init__(self, gridSize, voxelSize):
8
+ grid_limit = (gridSize / 2 - 0.5) * voxelSize
9
+ grid_radius = grid_limit * np.sqrt(3)
10
+ self.neigh_radius = 4 + grid_radius # 4 > 2*R_vdw
11
+ # self.neigh_radius = 2*grid_radius # 4 > 2*R_vdw
12
+ self.featurizer = bio_data_featurizer.Featurizer(save_molecule_codes=False)
13
+ self.grid_resolution = voxelSize
14
+ self.max_dist = (gridSize - 1) * voxelSize / 2
15
+
16
+ def get_channels(self, mol, add_forcefields, add_atom_radius_features=False):
17
+ if not add_forcefields:
18
+ self.coords, self.channels = self.featurizer.get_features(mol) # returns only heavy atoms
19
+ else:
20
+ self.coords, self.channels = self.featurizer.get_features_with_force_fields(mol, add_atom_radius=add_atom_radius_features) # returns only heavy atoms
21
+
22
+
23
+ def get_channels_with_forcefields(self, mol):
24
+ self.coords, self.channels = self.featurizer.get_features_with_force_fields(mol, add_atom_radius=True) # returns only heavy atoms
25
+
26
+ def grid_feats(self, point, normal, mol_coords):
27
+ neigh_atoms = np.sqrt(np.sum((mol_coords - point) ** 2, axis=1)) < self.neigh_radius
28
+ Q = rotation(normal)
29
+ Q_inv = np.linalg.inv(Q)
30
+ transf_coords = np.transpose(mol_coords[neigh_atoms] - point)
31
+ rotated_mol_coords = np.matmul(Q_inv, transf_coords)
32
+ features = \
33
+ bio_data_featurizer.make_grid(np.transpose(rotated_mol_coords), self.channels[neigh_atoms], self.grid_resolution,
34
+ self.max_dist)[0]
35
+
36
+ return features
37
+
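For the default ParaSurf settings (gridSize=41, voxelSize=1 A) the geometry above works out as follows; this is only a worked example of the constructor arithmetic.

    import numpy as np

    gridSize, voxelSize = 41, 1
    grid_limit = (gridSize / 2 - 0.5) * voxelSize   # 20.0 A: half-width of the cube
    grid_radius = grid_limit * np.sqrt(3)           # ~34.64 A: half-diagonal of the cube
    neigh_radius = 4 + grid_radius                  # ~38.64 A: atoms inside this sphere get voxelized
    max_dist = (gridSize - 1) * voxelSize / 2       # 20.0 A: passed to make_grid
    print(grid_limit, round(grid_radius, 2), round(neigh_radius, 2), max_dist)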
ParaSurf/train/network.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+ from ParaSurf.train.features import KalasantyFeaturizer
5
+ from ParaSurf.model import ParaSurf_model
6
+
7
+
8
+
9
+ class Network:
10
+ def __init__(self, model_path, gridSize, feature_channels, voxelSize=1, device="cuda"):
11
+ self.gridSize = gridSize  # grid edge length in voxels (41 is used throughout ParaSurf)
12
+
13
+ if device == 'cuda' and torch.cuda.is_available():
14
+ self.device = torch.device("cuda")
15
+ else:
16
+ self.device = torch.device("cpu")
17
+
18
+ # load model
19
+ self.model = ParaSurf_model.ResNet3D_Transformer(in_channels=feature_channels, block=ParaSurf_model.DilatedBottleneck,
20
+ num_blocks=[3, 4, 6, 3], num_classes=1)
21
+
22
+ # load weights
23
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device)) #
24
+ # model to eval mode and to device
25
+ self.model = self.model.to(self.device).eval()
26
+
27
+ self.featurizer = KalasantyFeaturizer(gridSize, voxelSize)  # defines how each surface point is turned into a voxel grid
28
+ self.feature_channels = feature_channels
29
+
30
+ def get_lig_scores(self, prot, batch_size, add_forcefields, add_atom_radius_features):
31
+
32
+ self.featurizer.get_channels(prot.mol, add_forcefields, add_atom_radius_features)
33
+
34
+
35
+ lig_scores = []
36
+ input_data = torch.zeros((batch_size, self.gridSize, self.gridSize, self.gridSize, self.feature_channels), device=self.device)
37
+
38
+ batch_cnt = 0
39
+ for p, n in zip(prot.surf_points, prot.surf_normals):
40
+ input_data[batch_cnt,:,:,:,:] = torch.tensor(self.featurizer.grid_feats(p, n, prot.heavy_atom_coords), device=self.device)
41
+ batch_cnt += 1
42
+ if batch_cnt == batch_size:
43
+ with torch.no_grad():
44
+ output = self.model(input_data)
45
+ output = torch.sigmoid(output)
46
+ lig_scores.extend(output.cpu().numpy())
47
+ batch_cnt = 0
48
+
49
+ if batch_cnt > 0:
50
+ with torch.no_grad():
51
+ output = self.model(input_data[:batch_cnt])
52
+ output = torch.sigmoid(output)
53
+ lig_scores.extend(output.cpu().numpy())
54
+
55
+ print(np.array(lig_scores).shape)
56
+ return np.array(lig_scores)
57
+
58
+
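A minimal inference sketch: the weight path is hypothetical and feature_channels=22 is an assumption taken from the example features folder suffix; it must match the feature vector length the model was trained with.

    from ParaSurf.train.network import Network
    from ParaSurf.train.protein import Protein_pred

    prot = Protein_pred('/path/to/receptor.pdb', save_path='/tmp/parasurf_pred')    # hypothetical path
    net = Network('/path/to/model_weights.pth', gridSize=41, feature_channels=22)   # assumed channel count
    lig_scores = net.get_lig_scores(prot, batch_size=64,
                                    add_forcefields=True, add_atom_radius_features=True)
    print(lig_scores.shape)   # one sigmoid score per surface point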
ParaSurf/train/protein.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, numpy as np
2
+ import shutil
3
+ # import pybel
4
+ from openbabel import pybel
5
+ from ParaSurf.train.utils import simplify_dms
6
+ from ParaSurf.utils.fix_surfpoints_format_issues import process_surfpoints_directory
7
+
8
+
9
+ class Protein_pred:
10
+ def __init__(self, prot_file, save_path, seed=None, atom_points_threshold=5, locate_only_surface=False):
11
+
12
+ prot_id = prot_file.split('/')[-1].split('.')[0]
13
+ self.save_path = os.path.join(save_path, prot_id)
14
+
15
+ if not os.path.exists(self.save_path):
16
+ os.makedirs(self.save_path)
17
+
18
+ self.mol = next(pybel.readfile(prot_file.split('.')[-1], prot_file))
19
+ self.atom_points_thresh = atom_points_threshold
20
+
21
+ surfpoints_file = os.path.join(self.save_path, prot_id + '.surfpoints')
22
+
23
+ # we have all the surfpoints ready from the preprocessing step
24
+ if not os.path.exists(surfpoints_file):
25
+ os.system('dms ' + prot_file + ' -d 0.1 -n -o ' + surfpoints_file) #default value for d is 0.2
26
+ # fix any format issues
27
+ print('\nfixing surfpoints format ...')
28
+ process_surfpoints_directory(self.save_path)
29
+ # raise Exception('probably DMS not installed')
30
+
31
+ # locate surface: decide whether the final coordinates include all receptor atoms or only the surface points
32
+ self.surf_points, self.surf_normals = simplify_dms(surfpoints_file, seed=seed,
33
+ locate_surface=locate_only_surface)
34
+
35
+ self.heavy_atom_coords = np.array([atom.coords for atom in self.mol.atoms if atom.atomicnum > 1])
36
+
37
+ self.binding_sites = []
38
+ if prot_file.endswith('pdb'):
39
+ with open(prot_file, 'r') as f:
40
+ lines = f.readlines()
41
+ self.heavy_atom_lines = [line for line in lines if line[:4] == 'ATOM' and line.split()[2][0] != 'H']
42
+ if len(self.heavy_atom_lines) != len(self.heavy_atom_coords):
43
+ ligand_in_pdb = len([line for line in lines if line.startswith('HETATM')]) > 0
44
+ if ligand_in_pdb:
45
+ raise Exception('Ligand found in PDB file. Please remove it to proceed.')
46
+ else:
47
+ raise Exception('Inconsistency between Coords and PDBLines')
48
+ else:
49
+ raise IOError('Protein file should be .pdb')
50
+
51
+ def _surfpoints_to_atoms(self, surfpoints):
52
+ close_atoms = np.zeros(len(surfpoints), dtype=int)
53
+ for p, surf_coord in enumerate(surfpoints):
54
+ dist = np.sqrt(np.sum((self.heavy_atom_coords - surf_coord) ** 2, axis=1))
55
+ close_atoms[p] = np.argmin(dist)
56
+
57
+ return np.unique(close_atoms)
58
+
59
+ def add_bsite(self, cluster): # cluster -> tuple: (surf_points,scores)
60
+ atom_idxs = self._surfpoints_to_atoms(cluster[0])
61
+ self.binding_sites.append(Bsite(self.heavy_atom_coords, atom_idxs, cluster[1]))
62
+
63
+ def sort_bsites(self):
64
+ avg_scores = np.array([bsite.score for bsite in self.binding_sites])
65
+ sorted_idxs = np.flip(np.argsort(avg_scores), axis=0)
66
+ self.binding_sites = [self.binding_sites[idx] for idx in sorted_idxs]
67
+
68
+ def write_bsites(self):
69
+ if not os.path.exists(self.save_path):
70
+ os.makedirs(self.save_path)
71
+
72
+ centers = np.array([bsite.center for bsite in self.binding_sites])
73
+ np.savetxt(os.path.join(self.save_path, 'centers.txt'), centers, delimiter=' ', fmt='%10.3f')
74
+
75
+ pocket_count = 0
76
+ for i, bsite in enumerate(self.binding_sites):
77
+ outlines = [self.heavy_atom_lines[idx] for idx in bsite.atom_idxs]
78
+ if len(outlines) > self.atom_points_thresh:
79
+ pocket_count += 1
80
+ with open(os.path.join(self.save_path, 'pocket' + str(pocket_count) + '.pdb'), 'w') as f:
81
+ f.writelines(outlines)
82
+
83
+
84
+
85
+
86
+ class Bsite:
87
+ def __init__(self, mol_coords, atom_idxs, scores):
88
+ self.coords = mol_coords[atom_idxs]
89
+ self.center = np.average(self.coords, axis=0)
90
+ self.score = np.average(scores)
91
+ self.atom_idxs = atom_idxs
92
+
ParaSurf/train/train.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time, random
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import DataLoader
8
+ from torch.optim.lr_scheduler import StepLR
9
+ from ParaSurf.model import ParaSurf_model
10
+ from ParaSurf.model.dataset import dataset
11
+ from validation import validate_residue_level
12
+ import wandb
13
+ from tqdm import tqdm
14
+
15
+
16
+ user = os.getenv('USER')
17
+ base_dir = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data'
18
+ CFG = {
19
+ 'name': 'ParaSurf train dummy eraseme folder',
20
+ 'initial_lr': 0.0001,
21
+ 'epochs': 100,
22
+ 'batch_size': 64,
23
+ 'grid': 41, # don't change
24
+ 'seed': 42,
25
+ 'wandb': False,
26
+ 'debug': False,
27
+ 'model_weights': None, # if '' or None is given, training starts from scratch
28
+ 'num_workers': 8,
29
+ 'feat_type': ['kalasanty_with_force_fields'],
30
+ 'feats_path': os.path.join(base_dir, 'feats'),
31
+ 'TRAIN_samples': os.path.join(base_dir, 'datasets/eraseme_TRAIN.samples'),
32
+ 'VAL_proteins_list': os.path.join(base_dir, 'datasets/eraseme_VAL.proteins'),
33
+ 'VAL_proteins': os.path.join(base_dir, 'pdbs/eraseme/VAL'),
34
+ 'save_dir': f'/home/{user}/PycharmProjects/github_projects/ParaSurf/ParaSurf/train/eraseme/model_weights'
35
+ }
36
+
37
+ if CFG['wandb']:
38
+ wandb.init(project='ParaSurf', entity='your_project_name', config=CFG, name=CFG['name'])
39
+
40
+
41
+ # Set random seed for repeatability
42
+ def set_seed(seed_value):
43
+ """Set seed for reproducibility."""
44
+ random.seed(seed_value)
45
+ np.random.seed(seed_value)
46
+ torch.manual_seed(seed_value)
47
+ if torch.cuda.is_available():
48
+ torch.cuda.manual_seed_all(seed_value)
49
+
50
+
51
+ set_seed(CFG['seed'])
52
+
53
+ with open(CFG['TRAIN_samples']) as f:
54
+ lines = f.readlines()
55
+ feature_vector_lentgh = int(lines[0].split()[1].split('/')[0].split('_')[-1])
56
+
57
+ # model
58
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
59
+ model = ParaSurf_model.ResNet3D_Transformer(in_channels=feature_vector_lentgh,
60
+ block=ParaSurf_model.DilatedBottleneck,
61
+ num_blocks=[3, 4, 6, 3], num_classes=1).to(device)
62
+ print(model)
63
+ criterion = nn.BCEWithLogitsLoss()
64
+ optimizer = optim.Adam(model.parameters(), lr=CFG['initial_lr'])
65
+
66
+ scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
67
+
68
+ # Load Dataset
69
+ train_set = dataset(CFG['TRAIN_samples'], CFG['batch_size'], CFG['feats_path'], CFG['grid'], True,
70
+ feature_vector_lentgh, CFG['feat_type'])
71
+ train_loader = DataLoader(dataset=train_set, batch_size=CFG['batch_size'], shuffle=True,
72
+ num_workers=CFG['num_workers'])
73
+
74
+ # Training
75
+ if not os.path.exists(CFG['save_dir']):
76
+ os.makedirs(CFG['save_dir'])
77
+
78
+ # check whether pretrained weights are loaded and resume from the corresponding epoch
79
+ if CFG['model_weights'] and os.path.exists(CFG['model_weights']):
80
+ model.load_state_dict(torch.load(CFG['model_weights']))
81
+ start_epoch = int(CFG['model_weights'].split('/')[-1].split('.')[0].split('_')[1]) + 1
82
+ print(f"\nLoading weights from epoch {start_epoch-1} ...\n")
83
+ print(f"Start training for epoch {start_epoch} ...")
84
+ else:
85
+ print('\nStart training from scratch ...')
86
+ start_epoch = 0
87
+
88
+
89
+ train_losses = [] # to keep track of training losses
90
+
91
+ for epoch in range(start_epoch, CFG['epochs']):
92
+ start = time.time()
93
+ model.train()
94
+ total_loss = 0.0
95
+
96
+ correct_train_predictions = 0 # Reset for each epoch
97
+ total_train_samples = 0 # Reset for each epoch
98
+
99
+ for i, (inputs, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
100
+ inputs, labels = inputs.float().to(device), labels.to(device).unsqueeze(1)
101
+ total_train_samples += labels.shape[0]
102
+ optimizer.zero_grad()
103
+
104
+ # scaler option
105
+ # with torch.cuda.amp.autocast():
106
+ outputs = model(inputs)
107
+ loss = criterion(outputs, labels.float())
108
+
109
+ loss.backward()
110
+
111
+ # Apply gradient clipping
112
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
113
+
114
+ optimizer.step()
115
+
116
+ total_loss += loss.item()
117
+
118
+ predicted_train = torch.sigmoid(outputs) > 0.5
119
+ correct_train_predictions += (predicted_train == labels).sum().item()
120
+
121
+ # Print the training loss every 100 batches
122
+ if (i + 1) % 100 == 0:
123
+ print(f"Epoch: {epoch + 1} Batch: {i + 1} Train Loss: {loss.item():.3f}")
124
+
125
+ # if CFG['wandb']:
126
+ # wandb.log({'Mini Batch Train Loss': loss.item()})
127
+
128
+ if CFG['debug']:
129
+ break
130
+
131
+
132
+ avg_train_loss = total_loss / len(train_loader)
133
+ train_accuracy = correct_train_predictions / total_train_samples # Calculate training accuracy
134
+
135
+ train_losses.append(avg_train_loss)
136
+
137
+ cur_model_weight_path = os.path.join(CFG['save_dir'], f'epoch_{epoch}.pth')
138
+ torch.save(model.state_dict(), cur_model_weight_path)
139
+
140
+ avg_auc_roc, avg_precision, avg_recall, avg_auc_pr, avg_f1 = validate_residue_level(valset=CFG['VAL_proteins_list'],
141
+ modelweights=cur_model_weight_path,
142
+ test_folder=CFG['VAL_proteins'],
143
+ epoch=epoch + 1,
144
+ feat_type=CFG['feat_type'],
145
+ feature_vector_lentgh=feature_vector_lentgh)
146
+
147
+
148
+ print(
149
+ f"Epoch {epoch + 1}/{CFG['epochs']} - Train Loss: {avg_train_loss:.3f}, Train Accuracy: {train_accuracy:.3f},"
150
+ f"Val_AUC-ROC: {avg_auc_roc:.3f}, Val_Precision: {avg_precision:.3f}, Val_Recall: {avg_recall:.3f},"
151
+ f" Val_AUC_pr: {avg_auc_pr:.3f}, Val_F1: {avg_f1}")
152
+
153
+ print(f"Total epoch time: {(time.time() - start) / 60:.3f} mins")
154
+
155
+ if CFG['wandb']:
156
+ wandb.log({'Epoch': epoch,
157
+ 'Train Loss': avg_train_loss,
158
+ 'Train Accuracy': train_accuracy,
159
+ 'Valid AUC-ROC': avg_auc_roc,
160
+ 'Valid Precision': avg_precision,
161
+ 'Valid Recall': avg_recall,
162
+ 'Valid AUC-pr': avg_auc_pr,
163
+ 'Valid F1': avg_f1
164
+ })
165
+
166
+
167
+ # Step the scheduler
168
+ scheduler.step()
169
+
170
+ # # Finish the wandb run at the end of all epochs for the current iteration
171
+ if CFG['wandb']:
172
+ wandb.finish()
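The resume logic above derives the starting epoch purely from the checkpoint file name saved by this script (epoch_{N}.pth); a small worked example of that parsing with a hypothetical path:

    weights = '/path/to/model_weights/epoch_7.pth'   # hypothetical checkpoint
    start_epoch = int(weights.split('/')[-1].split('.')[0].split('_')[1]) + 1
    print(start_epoch)   # 8: the training loop resumes at epoch index 8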
ParaSurf/train/utils.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings, os
2
+ import numpy as np
3
+ from scipy.spatial.distance import euclidean
4
+ from sklearn.cluster import KMeans
5
+ import matplotlib.pyplot as plt
6
+ from sklearn.metrics import roc_auc_score, roc_curve
7
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, auc, precision_recall_curve, \
8
+ confusion_matrix, matthews_corrcoef
9
+
10
+
11
+ def mol2_reader(mol_file): # does not handle H2
12
+ if mol_file[-4:] != 'mol2':
13
+ raise Exception("File's extension is not .mol2")
14
+
15
+ with open(mol_file, 'r') as f:
16
+ lines = f.readlines()
17
+
18
+ for i, line in enumerate(lines):
19
+ if '@<TRIPOS>ATOM' in line:
20
+ first_atom_idx = i + 1
21
+ if '@<TRIPOS>BOND' in line:
22
+ last_atom_idx = i - 1
23
+
24
+ return lines[first_atom_idx:last_atom_idx + 1]
25
+
26
+
27
+ # maybe change to read_surfpoints_new ??
28
+ def readSurfPoints(surf_file):
29
+ with open(surf_file, 'r') as f:
30
+ lines = f.readlines()
31
+
32
+ lines = [l for l in lines if len(l.split()) > 7]
33
+
34
+ if len(lines) > 100000:
35
+ warnings.warn('{} has too many points'.format(surf_file))
36
+ return
37
+ if len(lines) == 0:
38
+ warnings.warn('{} is empty'.format(surf_file))
39
+ return
40
+
41
+ coords = np.zeros((len(lines), 3))
42
+ normals = np.zeros((len(lines), 3))
43
+ for i, l in enumerate(lines):
44
+ parts = l.split()
45
+
46
+ try:
47
+ coords[i, 0] = float(parts[3])
48
+ coords[i, 1] = float(parts[4])
49
+ coords[i, 2] = float(parts[5])
50
+ normals[i, 0] = float(parts[8])
51
+ normals[i, 1] = float(parts[9])
52
+ normals[i, 2] = float(parts[10])
53
+ except:
54
+ coords[i, 0] = float(parts[2][-8:])
55
+ coords[i, 1] = float(parts[3])
56
+ coords[i, 2] = float(parts[4])
57
+ normals[i, 0] = float(parts[7])
58
+ normals[i, 1] = float(parts[8])
59
+ normals[i, 2] = float(parts[9])
60
+
61
+ return coords, normals
62
+
63
+
64
+ def readSurfPoints_with_receptor_atoms(surf_file):
65
+ with open(surf_file, 'r') as f:
66
+ lines = f.readlines()
67
+
68
+ # lines = [l for l in lines if len(l.split()) > 7]
69
+ lines = [l for l in lines]
70
+ if len(lines) > 100000:
71
+ warnings.warn('{} has too many points'.format(surf_file))
72
+ return
73
+ if len(lines) == 0:
74
+ warnings.warn('{} is empty'.format(surf_file))
75
+ return
76
+
77
+ coords = np.zeros((len(lines), 3))
78
+ normals = np.zeros((len(lines), 3))
79
+
80
+ # First, ensure each line has at least 11 parts by filling with zeros
81
+ for i in range(len(lines)):
82
+ parts = lines[i].split()
83
+ while len(parts) < 11:
84
+ # Fill with '0' initially
85
+ parts.append('0')
86
+ lines[i] = ' '.join(parts)
87
+
88
+ # Modify lines according to the specified rules
89
+ for i in range(len(lines)):
90
+ parts = lines[i].split()
91
+ # Check if there are zeros that need to be replaced
92
+ if '0' in parts:
93
+ if i > 0: # Use previous line if not the first line
94
+ prev_parts = lines[i - 1].split()
95
+ parts = [prev_parts[j] if part == '0' else part for j, part in enumerate(parts)]
96
+ elif i < len(lines) - 1: # Use next line if not the last line
97
+ next_parts = lines[i + 1].split()
98
+ parts = [next_parts[j] if part == '0' else part for j, part in enumerate(parts)]
99
+ lines[i] = ' '.join(parts)
100
+
101
+ try:
102
+ coords[i, 0] = float(parts[3])
103
+ coords[i, 1] = float(parts[4])
104
+ coords[i, 2] = float(parts[5])
105
+ normals[i, 0] = float(parts[8])
106
+ normals[i, 1] = float(parts[9])
107
+ normals[i, 2] = float(parts[10])
108
+ except:
109
+ coords[i, 0] = float(parts[2][-8:])
110
+ coords[i, 1] = float(parts[3])
111
+ coords[i, 2] = float(parts[4])
112
+ normals[i, 0] = float(parts[7])
113
+ normals[i, 1] = float(parts[8])
114
+ normals[i, 2] = float(parts[9])
115
+
116
+ return coords, normals
117
+
118
+
119
+ def simplify_dms(init_surf_file, seed=None, locate_surface=True):
120
+ # Here we decide if we want the final coordinates to have the receptor atoms or we want just
121
+ # the surface atoms
122
+ if locate_surface:
123
+ coords, normals = readSurfPoints(init_surf_file)
124
+ else:
125
+ coords, normals = readSurfPoints_with_receptor_atoms(init_surf_file) # to also get the receptor points
126
+
127
+ return coords, normals  # NOTE: returns here, so the K-means simplification below is currently bypassed
128
+
129
+ nCl = len(coords)
130
+
131
+ kmeans = KMeans(n_clusters=nCl, max_iter=300, n_init=1, random_state=seed).fit(coords)
132
+ point_labels = kmeans.labels_
133
+ centers = kmeans.cluster_centers_
134
+ cluster_idx, freq = np.unique(point_labels, return_counts=True)
135
+ if len(cluster_idx) != nCl:
136
+ raise Exception('Number of created clusters should be equal to nCl')
137
+
138
+ idxs = []
139
+ for cl in cluster_idx:
140
+ cluster_points_idxs = np.where(point_labels == cl)[0]
141
+ closest_idx_to_center = np.argmin([euclidean(centers[cl], coords[idx]) for idx in cluster_points_idxs])
142
+ idxs.append(cluster_points_idxs[closest_idx_to_center])
143
+
144
+ return coords[idxs], normals[idxs]
145
+
146
+
147
+ def rotation(n):
148
+ if n[0] == 0.0 and n[1] == 0.0:
149
+ if n[2] == 1.0:
150
+ return np.identity(3)
151
+ elif n[2] == -1.0:
152
+ Q = np.identity(3)
153
+ Q[0, 0] = -1
154
+ return Q
155
+ else:
156
+ print('not possible')
157
+
158
+ rx = -n[1] / np.sqrt(n[0] * n[0] + n[1] * n[1])
159
+ ry = n[0] / np.sqrt(n[0] * n[0] + n[1] * n[1])
160
+ rz = 0
161
+ th = np.arccos(n[2])
162
+
163
+ q0 = np.cos(th / 2)
164
+ q1 = np.sin(th / 2) * rx
165
+ q2 = np.sin(th / 2) * ry
166
+ q3 = np.sin(th / 2) * rz
167
+
168
+ Q = np.zeros((3, 3))
169
+ Q[0, 0] = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3
170
+ Q[0, 1] = 2 * (q1 * q2 - q0 * q3)
171
+ Q[0, 2] = 2 * (q1 * q3 + q0 * q2)
172
+ Q[1, 0] = 2 * (q1 * q2 + q0 * q3)
173
+ Q[1, 1] = q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3
174
+ Q[1, 2] = 2 * (q3 * q2 - q0 * q1)
175
+ Q[2, 0] = 2 * (q1 * q3 - q0 * q2)
176
+ Q[2, 1] = 2 * (q3 * q2 + q0 * q1)
177
+ Q[2, 2] = q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3
178
+
179
+ return Q
180
+
181
+
182
+ def TP_TN_FP_FN_visualization2pdb(gt_binding_site_coordinates, lig_scores, to_save_path, gt_indexes):
183
+ '''
184
+ Create dummy PDB files to visualize the results (TP, TN, FP, FN) on the receptor PDB file
185
+ '''
186
+ threshold = 0.5
187
+
188
+ # Initialize lists
189
+ TP_coords = []
190
+ FP_coords = []
191
+ TN_coords = []
192
+ FN_coords = []
193
+
194
+ for i, score in enumerate(lig_scores):
195
+ # If the atom is a true binding site
196
+ if i in gt_indexes:
197
+ if score > threshold:
198
+ TP_coords.append(gt_binding_site_coordinates[i])
199
+ else:
200
+ FN_coords.append(gt_binding_site_coordinates[i])
201
+ # If the atom is not a binding site
202
+ else:
203
+ if score > threshold:
204
+ FP_coords.append(gt_binding_site_coordinates[i])
205
+ else:
206
+ TN_coords.append(gt_binding_site_coordinates[i])
207
+
208
+ def generate_pdb_file(coordinates, file_name):
209
+ """Generate a dummy PDB file using the provided coordinates."""
210
+ with open(os.path.join(to_save_path, file_name), 'w') as pdb_file:
211
+ atom_number = 1
212
+ for coord in coordinates:
213
+ x, y, z = coord
214
+ pdb_file.write(
215
+ f"ATOM {atom_number:5} DUM DUM A{atom_number:4} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n")
216
+ atom_number += 1
217
+ if atom_number == 9999:
218
+ atom_number = 1
219
+ pdb_file.write("END")
220
+
221
+ # Generate PDB files for TP, FP, TN, and FN
222
+ generate_pdb_file(TP_coords, "TP_atoms.pdb")  # generate_pdb_file already joins to_save_path
223
+ generate_pdb_file(FP_coords, "FP_atoms.pdb")
224
+ generate_pdb_file(TN_coords, "TN_atoms.pdb")
225
+ generate_pdb_file(FN_coords, "FN_atoms.pdb")
226
+
227
+ print('TP:', len(TP_coords), 'FP:', len(FP_coords), 'FN:', len(FN_coords), 'TN:', len(TN_coords))
228
+
229
+
230
+ def visualize_TP_TN_FP_FN_residue_level(lig_scores, gt_indexes, residues, receptor_path, tosavepath):
231
+ threshold = 0.5
232
+
233
+ # Initialize lists
234
+ tp_list = []
235
+ fp_list = []
236
+ tn_list = []
237
+ fn_list = []
238
+
239
+ with open(receptor_path, 'r') as f:
240
+ lines = f.readlines()
241
+
242
+ res_atoms = [len(i[1]['atoms']) for i in residues.items()]
243
+
244
+ for i, score in enumerate(lig_scores):
245
+ # If the atom is a true binding site
246
+ lines2add = res_atoms[i]
247
+ if i in gt_indexes:
248
+ if score > threshold:
249
+ tp_list.append(lines[:lines2add])
250
+ del lines[:lines2add]
251
+ else:
252
+ fn_list.append(lines[:lines2add])
253
+ del lines[:lines2add]
254
+ # If the atom is not a binding site
255
+ else:
256
+ if score > threshold:
257
+ fp_list.append(lines[:lines2add])
258
+ del lines[:lines2add]
259
+ else:
260
+ tn_list.append(lines[:lines2add])
261
+ del lines[:lines2add]
262
+
263
+ # Generate PDB files for TP, FP, TN, and FN
264
+ with open(os.path.join(tosavepath, 'TP_residues.pdb'), 'w') as f:
265
+ for l in tp_list:
266
+ for item in l:
267
+ f.write(item)
268
+ with open(os.path.join(tosavepath, 'FP_residues.pdb'), 'w') as f:
269
+ for l in fp_list:
270
+ for item in l:
271
+ f.write(item)
272
+ with open(os.path.join(tosavepath, 'FN_residues.pdb'), 'w') as f:
273
+ for l in fn_list:
274
+ for item in l:
275
+ f.write(item)
276
+ with open(os.path.join(tosavepath, 'TN_residues.pdb'), 'w') as f:
277
+ for l in tn_list:
278
+ for item in l:
279
+ f.write(item)
280
+
281
+ # print('TP:', len(tp_list), 'FP:', len(fp_list), 'FN:', len(fn_list), 'TN:', len(tn_list))
282
+
283
+
284
+ def calculate_TP_TN_FP_FN(lig_scores, gt_indexes):
285
+ threshold = 0.5
286
+
287
+ # Initialize lists
288
+ TP = 0
289
+ FP = 0
290
+ TN = 0
291
+ FN = 0
292
+
293
+ for i, score in enumerate(lig_scores):
294
+ # If the atom is a true binding site
295
+ if i in gt_indexes:
296
+ if score > threshold:
297
+ TP += 1
298
+ else:
299
+ FN += 1
300
+ # If the atom is not a binding site
301
+ else:
302
+ if score > threshold:
303
+ FP += 1
304
+ else:
305
+ TN += 1
306
+
307
+ print('TP:', TP, 'FP:', FP, 'FN:', FN, 'TN:', TN)
308
+
309
+
310
+ def show_roc_curve(true_labels, lig_scores, auc_roc):
311
+ # Calculate ROC curve
312
+ fpr, tpr, thresholds = roc_curve(true_labels, lig_scores)
313
+
314
+ # Plot ROC curve
315
+ plt.figure(figsize=(8, 6))
316
+ plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc_roc:.2f})')
317
+ plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
318
+ plt.xlim([0.0, 1.0])
319
+ plt.ylim([0.0, 1.05])
320
+ plt.xlabel('False Positive Rate')
321
+ plt.ylabel('True Positive Rate')
322
+ plt.title('Receiver Operating Characteristic (ROC) Curve')
323
+ plt.legend(loc="lower right")
324
+ plt.show()
325
+
326
+
327
+ def calculate_metrics(true_labels, predicted_labels, lig_scores, to_save_metrics_path):
328
+ auc_roc = roc_auc_score(true_labels, lig_scores)
329
+ accuracy = accuracy_score(true_labels, predicted_labels)
330
+ precision = precision_score(true_labels, predicted_labels)
331
+ recall = recall_score(true_labels, predicted_labels)
332
+ f1 = f1_score(true_labels, predicted_labels)
333
+ pr, re, _ = precision_recall_curve(true_labels, lig_scores)
334
+ auc_pr = auc(re, pr)
335
+ conf_matrix = confusion_matrix(true_labels, predicted_labels)
336
+ mcc = matthews_corrcoef(true_labels, predicted_labels)
337
+
338
+ tn, fp, fn, tp = conf_matrix.ravel()
339
+ fpr = fp / (fp + tn)
340
+ tpr = tp / (tp + fn) # True positive rate == sensitivity == recall
341
+ npv = tn / (tn + fn)
342
+ spc = tn / (fp + tn) # Specificity or True Negative Rate
343
+
344
+ with open(to_save_metrics_path, 'w') as f:
345
+ print(f"AUC-ROC: {auc_roc:.4f}", file=f)
346
+ print(f"Accuracy: {accuracy:.4f}", file=f)
347
+ print(f"Precision: {precision:.4f}", file=f)
348
+ print(f"Recall: {recall:.4f}", file=f)
349
+ print(f"F1 Score: {f1:.4f}", file=f)
350
+ print(f"AUC-PR: {auc_pr:.4f}", file=f)
351
+ print(f"Confusion Matrix:\n {conf_matrix}", file=f)
352
+ print(f"Matthews Correlation Coefficient: {mcc:.4f}", file=f)
353
+ print(f"False Positive Rate (FPR): {fpr:.4f}", file=f)
354
+ print(f"Negative Predictive Value (NPV): {npv:.4f}", file=f)
355
+ print(f"Specificity (SPC): {spc:.4f}", file=f)
356
+
357
+ return auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc, fpr, npv, spc
358
+
359
+
360
+ def filter_out_HETATMs(pdb_file_path):
361
+ with open(pdb_file_path, 'r') as infile:
362
+ lines = infile.readlines()
363
+
364
+ # Filter out lines starting with 'HETATM'
365
+ filtered_lines = [line for line in lines if not line.startswith('HETATM')]
366
+
367
+ # Write the filtered lines back to the file
368
+ with open(pdb_file_path, 'w') as outfile:
369
+ outfile.writelines(filtered_lines)
370
+
371
+
372
+ def write_residue_prediction_pdb(receptor, output_pdb_path, residues_best):
373
+ """
374
+ :param receptor: original receptor pdb file path
375
+ :param output_pdb_path: path where the residue-level prediction PDB file is written
376
+ :param residues_best: the residues dict with scores
377
+ :return: Write the prediction PDB file with scores at residue level (replaces B-factor for residues).
378
+ """
379
+ # rec_name = receptor.split('/')[-1].split('_')[0]
380
+ # output_pdb_path = os.path.join(results_save_path, f'{rec_name}_pred.pdb')
381
+ #
382
+ # # Ensure the directory exists
383
+ # os.makedirs(results_save_path, exist_ok=True)
384
+
385
+ # Open the original receptor PDB file and the output PDB file for writing the predictions
386
+ with open(receptor, 'r') as original_pdb, open(output_pdb_path, 'w') as pred_pdb:
387
+ for line in original_pdb:
388
+ if line.startswith("ATOM") or line.startswith("HETATM"): # Process only ATOM and HETATM records
389
+ # Extract residue info (residue number, chain ID, and insertion code)
390
+ chain_id = line[21]
391
+ res_num = line[22:26].strip()
392
+ insertion_code = line[26].strip()
393
+
394
+ # Create the residue ID in the same format as in residues_best
395
+ res_id = f'{res_num}_{chain_id}'
396
+ if insertion_code:
397
+ res_id = f'{res_id}_{insertion_code}'
398
+
399
+ # Check if the residue exists in residues_best
400
+ if res_id in residues_best:
401
+ # Extract the prediction score
402
+ pred_score = residues_best[res_id]['scores']
403
+
404
+ # Modify the line to replace the B-factor (position 61-66) with the prediction score
405
+ new_b_factor = f'{pred_score:6.3f}' # Format the prediction score with 3 decimal places
406
+ new_line = f'{line[:60]}{new_b_factor:>6}{line[66:]}'
407
+
408
+ # Write the modified line to the new PDB file
409
+ pred_pdb.write(new_line)
410
+ else:
411
+ # If no prediction score is found, write the original line
412
+ pred_pdb.write(line)
413
+ else:
414
+ # Write lines that do not start with ATOM or HETATM (like headers and footers) unchanged
415
+ pred_pdb.write(line)
416
+
417
+ print(f"Residue-level prediction PDB file saved as {output_pdb_path}")
418
+
419
+
420
+ def write_atom_prediction_pdb(receptor, results_save_path, lig_scores_only_receptor_atoms):
421
+ """
422
+ :param receptor: original receptor pdb file path
423
+ :param results_save_path: directory where the per-atom prediction PDB file is saved
424
+ :param lig_scores_only_receptor_atoms: per-atom ligandability scores, in the same order as the ATOM/HETATM records
425
+ :return: Write the prediction PDB file with scores at atom level (replaces B-factor for each atom).
426
+ """
427
+
428
+ rec_name = receptor.split('/')[-1].split('_')[0]
429
+ output_pdb_path = os.path.join(results_save_path, f'{rec_name}_pred_per_atom.pdb')
430
+
431
+ # Make sure the length of lig_scores matches the number of atoms in the PDB
432
+ assert len(lig_scores_only_receptor_atoms) == sum(1 for line in open(receptor) if line.startswith("ATOM") or line.startswith("HETATM")), \
433
+ "Number of scores doesn't match the number of atoms in the PDB file"
434
+
435
+ # Open the original receptor PDB file and the output PDB file for writing the predictions
436
+ with open(receptor, 'r') as original_pdb, open(output_pdb_path, 'w') as pred_pdb2:
437
+ atom_index = 0 # To track which score corresponds to which atom
438
+ for line in original_pdb:
439
+ if line.startswith("ATOM") or line.startswith("HETATM"): # Process only ATOM and HETATM records
440
+ # Extract the prediction score for the current atom
441
+ pred_score = lig_scores_only_receptor_atoms[atom_index][0] # Get the prediction score for this atom
442
+
443
+ # Modify the line to replace the B-factor (position 61-66) with the prediction score
444
+ new_b_factor = f'{pred_score:6.3f}' # Format the prediction score with 3 decimal places
445
+ new_line = f'{line[:60]}{new_b_factor:>6}{line[66:]}' # Insert the prediction score at the correct position
446
+
447
+ # Write the modified line to the new PDB file
448
+ pred_pdb2.write(new_line)
449
+
450
+ # Increment the atom index
451
+ atom_index += 1
452
+ else:
453
+ # Write lines that do not start with ATOM or HETATM (like headers and footers) unchanged
454
+ pred_pdb2.write(line)
455
+
456
+ print(f"Per-atom prediction PDB file saved as {output_pdb_path}")
457
+
458
+
459
+ def receptor_info(receptor, lig_scores_only_receptor_atoms):
460
+ """
461
+ Extract residue groups and compute the best scores for each residue in the receptor.
462
+
463
+ Args:
464
+ receptor (str): The path to the receptor PDB file.
465
+ lig_scores_only_receptor_atoms (ndarray): List of ligandability scores for each atom.
466
+
467
+ Returns:
468
+ residues (dict): Dictionary containing atom information and ligand scores for each residue.
469
+ residues_best (dict): Dictionary containing the best ligand score for each residue.
470
+ """
471
+ # Create the residue groups for the whole protein
472
+ residues = {}
473
+ with open(receptor, 'r') as file:
474
+ for line in file:
475
+ if line.startswith("ATOM"):
476
+ chain_id = line[21] # Extract chain identifier
477
+ atom_id = line[6:11].strip()
478
+ res_id = f'{line[22:26].strip()}_{chain_id}' # Concatenate residue ID with chain ID
479
+ insertion_code = line[26].strip()
480
+ if insertion_code:
481
+ res_id = f'{res_id}_{insertion_code}'
482
+ if res_id not in residues:
483
+ residues[res_id] = {"atoms": [], 'scores': []}
484
+ residues[res_id]["atoms"].append(atom_id)
485
+
486
+ atom2check = int(atom_id) - 1
487
+ residues[res_id]['scores'].append(lig_scores_only_receptor_atoms[atom2check][0])
488
+
489
+ # Take the best scores for the whole protein
490
+ residues_best = {}
491
+ for res_id, res_data in residues.items():
492
+ residues_best[res_id] = {'scores': []}
493
+ check_best = res_data['scores']
494
+ best_atom = check_best.index(max(check_best)) # We take the best atom score of the residue
495
+ residues_best[res_id]['scores'] = check_best[best_atom]
496
+
497
+ return residues, residues_best
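A small numerical check (illustrative) of rotation(): it builds the matrix Q that rotates the z-axis onto a unit surface normal, which is how grid_feats in features.py aligns each local grid before voxelization.

    import numpy as np
    from ParaSurf.train.utils import rotation

    n = np.array([0.6, 0.0, 0.8])                           # a unit-length surface normal
    Q = rotation(n)
    print(np.allclose(Q @ np.array([0.0, 0.0, 1.0]), n))    # True: Q maps +z onto n
    print(np.allclose(Q @ Q.T, np.eye(3)))                  # True: Q is orthogonal, so inv(Q) == Q.T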