Upload 404 files
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +26 -0
- Dockerfile +38 -0
- ParaSurf/create_datasets_from_csv/README.md +37 -0
- ParaSurf/create_datasets_from_csv/__pycache__/split_pdb2chains_only.cpython-39.pyc +0 -0
- ParaSurf/create_datasets_from_csv/final_dataset_preparation.py +146 -0
- ParaSurf/create_datasets_from_csv/process_csv_dataset.py +130 -0
- ParaSurf/create_datasets_from_csv/split_pdb2chains_only.py +43 -0
- ParaSurf/model/ParaSurf_model.py +173 -0
- ParaSurf/model/__pycache__/ParaSurf_model.cpython-310.pyc +0 -0
- ParaSurf/model/__pycache__/ParaSurf_model.cpython-39.pyc +0 -0
- ParaSurf/model/__pycache__/dataset.cpython-310.pyc +0 -0
- ParaSurf/model/__pycache__/dataset.cpython-39.pyc +0 -0
- ParaSurf/model/dataset.py +107 -0
- ParaSurf/model_weights/README.md +11 -0
- ParaSurf/preprocess/README.md +71 -0
- ParaSurf/preprocess/__pycache__/check_empty_features.cpython-310.pyc +0 -0
- ParaSurf/preprocess/__pycache__/check_empty_features.cpython-39.pyc +0 -0
- ParaSurf/preprocess/__pycache__/clean_dataset.cpython-310.pyc +0 -0
- ParaSurf/preprocess/__pycache__/clean_dataset.cpython-39.pyc +0 -0
- ParaSurf/preprocess/check_empty_features.py +68 -0
- ParaSurf/preprocess/check_rec_ant_touch.py +89 -0
- ParaSurf/preprocess/clean_dataset.py +27 -0
- ParaSurf/preprocess/create_input_features.py +230 -0
- ParaSurf/preprocess/create_proteins_file.py +23 -0
- ParaSurf/preprocess/create_sample_files.py +31 -0
- ParaSurf/preprocess/create_surfpoints.py +57 -0
- ParaSurf/train/V_domain_results.py +159 -0
- ParaSurf/train/__pycache__/V_domain_results.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/V_domain_results.cpython-39.pyc +0 -0
- ParaSurf/train/__pycache__/bsite_extraction.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/bsite_extraction.cpython-39.pyc +0 -0
- ParaSurf/train/__pycache__/distance_coords.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/distance_coords.cpython-39.pyc +0 -0
- ParaSurf/train/__pycache__/features.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/features.cpython-39.pyc +0 -0
- ParaSurf/train/__pycache__/network.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/network.cpython-39.pyc +0 -0
- ParaSurf/train/__pycache__/protein.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/protein.cpython-39.pyc +0 -0
- ParaSurf/train/__pycache__/utils.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/utils.cpython-39.pyc +0 -0
- ParaSurf/train/__pycache__/validation.cpython-310.pyc +0 -0
- ParaSurf/train/__pycache__/validation.cpython-39.pyc +0 -0
- ParaSurf/train/bsite_extraction.py +48 -0
- ParaSurf/train/distance_coords.py +173 -0
- ParaSurf/train/features.py +37 -0
- ParaSurf/train/network.py +58 -0
- ParaSurf/train/protein.py +92 -0
- ParaSurf/train/train.py +172 -0
- ParaSurf/train/utils.py +497 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,29 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/_codecs_cn.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/_codecs_hk.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/_codecs_jp.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/_codecs_kr.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/_codecs_tw.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/_ctypes.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/datetime.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/doc/images/flowchart.png filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libcrypto.so.1.0.0 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libexpat.so.1 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libncursesw.so.5 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libpython2.7.so.1.0 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libreadline.so.6 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libssl.so.1.0.0 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libstdc++.so.6 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libtinfo.so.5 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/libz.so.1 filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/numpy.core.multiarray.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/numpy.core.umath.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/numpy.fft.fftpack_lite.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/numpy.linalg._umath_linalg.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/numpy.linalg.lapack_lite.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/numpy.random.mtrand.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/pdb2pka._apbslib.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/pdb2pka._pMC_mult.x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+pdb2pqr-linux-bin64-2.1.1/pdb2pqr filter=lfs diff=lfs merge=lfs -text
```
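These attribute lines route the bundled pdb2pqr binaries through Git LFS. As a side note, entries like this are usually generated with `git lfs track` rather than written by hand; a minimal sketch, assuming git-lfs is installed (the patterns shown are illustrative):

```bash
# Each command appends a matching "filter=lfs diff=lfs merge=lfs -text" line to .gitattributes
git lfs track "pdb2pqr-linux-bin64-2.1.1/*.so"
git lfs track "pdb2pqr-linux-bin64-2.1.1/pdb2pqr"
```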
Dockerfile
ADDED
```dockerfile
FROM continuumio/miniconda3

WORKDIR /app
COPY . /app

# Install system dependencies
RUN apt-get update && apt-get install -y build-essential wget

# Create Conda environment and install dependencies
RUN conda create -n ParaSurf python=3.10 openbabel -c conda-forge -y
RUN conda run -n ParaSurf pip install -r requirements.txt

# Increase file descriptor limit to prevent FD_SETSIZE error
RUN echo "ulimit -n 65535" >> ~/.bashrc

# Download missing Gradio frpc binary
RUN wget https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 -O /opt/conda/envs/ParaSurf/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3 && \
    chmod +x /opt/conda/envs/ParaSurf/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3

# Install DMS software
WORKDIR /app/dms
RUN make install

# Ensure necessary binaries are executable
RUN chmod +x /app/pdb2pqr-linux-bin64-2.1.1/pdb2pqr && \
    chmod +x /opt/conda/envs/ParaSurf/bin/* && \
    chmod -R 755 /app

# Set writable directories for Matplotlib cache
ENV MPLCONFIGDIR=/tmp/matplotlib
ENV XDG_CACHE_HOME=/tmp

WORKDIR /app
EXPOSE 7860

# Run the app with higher file descriptor limits
CMD ["bash", "-c", "ulimit -n 65535 && conda run --no-capture-output -n ParaSurf python app.py"]
```
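A minimal usage sketch for this image (the tag `parasurf` is an arbitrary choice; port 7860 matches the `EXPOSE` directive above):

```bash
# Build the image from the repo root and run it, mapping the Gradio port
docker build -t parasurf .
docker run -p 7860:7860 parasurf
```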
ParaSurf/create_datasets_from_csv/README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# **Dataset preperation**
|
2 |
+
|
3 |
+
### Steps for Dataset Preparation
|
4 |
+
#### Step 1
|
5 |
+
Download Specific PDB Files Use the process_csv_dataset.py script to download the PDB files listed in the .csv files.
|
6 |
+
```bash
|
7 |
+
# Step 1: Download specified PDB files
|
8 |
+
python process_csv_dataset.py
|
9 |
+
```
|
10 |
+
|
11 |
+
#### Step 2
|
12 |
+
Generate Final Complexes Run final_dataset_preparation.py to arrange the files into complexes with the specified chain IDs from the .csv files.
|
13 |
+
```bash
|
14 |
+
# Step 2: Organize files into final complexes
|
15 |
+
python final_dataset_preparation.py
|
16 |
+
```
|
17 |
+
|
18 |
+
|
19 |
+
After running these scripts, you will find a test_data/pdbs folder organized as follows:
|
20 |
+
```bash
|
21 |
+
├── PECAN
|
22 |
+
│ ├── TRAIN
|
23 |
+
│ │ ├── 1A3R_receptor_1.pdb
|
24 |
+
│ │ ├── 1A3R_antigen_1_1.pdb
|
25 |
+
│ │ ├── ...
|
26 |
+
│ │ ├── 5WUX_receptor_1.pdb
|
27 |
+
│ │ └── 5WUX_antigen_1_1.pdb
|
28 |
+
│ ├── VAL
|
29 |
+
│ └── TEST
|
30 |
+
├── Paragraph_Expanded
|
31 |
+
│ ├── TRAIN
|
32 |
+
│ ├── VAL
|
33 |
+
│ └── TEST
|
34 |
+
└── MIPE
|
35 |
+
├── TRAIN_VAL
|
36 |
+
└── TEST
|
37 |
+
```
|
ParaSurf/create_datasets_from_csv/__pycache__/split_pdb2chains_only.cpython-39.pyc
ADDED
Binary file (1.54 kB).
ParaSurf/create_datasets_from_csv/final_dataset_preparation.py
ADDED
```python
import os
import shutil
import pandas as pd
from split_pdb2chains_only import extract_chains_from_pdb
from tqdm import tqdm


def process_raw_pdb_data(info_df, initial_raw_pdb_files, final_folder):
    """
    Processes the raw PDB files by extracting the specific antibody and antigen chains from the .csv file,
    merging them, and saving the merged files in the final train_val and test folders.

    Parameters:
    info_df (DataFrame): DataFrame containing the PDB codes, antibody chains, and antigen chains.
    initial_raw_pdb_files (str): Path to the initial raw PDB files directory.
    final_folder (str): Path to the folder where the processed files will be saved.
    """
    if not os.path.exists(final_folder):
        os.makedirs(final_folder)

    for i, row in tqdm(info_df.iterrows(), total=len(info_df)):
        pdb_id = row['pdb_code']
        ab_heavy_chain = row['Heavy_chain']  # Use only this line if you want to construct the heavy-chain-only dataset
        ab_light_chain = row['Light_chain']  # Use only this line if you want to construct the light-chain-only dataset
        ag_chain = row['ag']

        pdb_file = os.path.join(initial_raw_pdb_files, pdb_id + '.pdb')
        # Extract all the chains from the pdb file and save them to /tmp
        chain_files, all_chains = extract_chains_from_pdb(pdb_file, '/tmp')

        # Assign the correct chains
        ab_heavy_chain_path = f'/tmp/{pdb_id}_chain{ab_heavy_chain}.pdb'
        ab_light_chain_path = f'/tmp/{pdb_id}_chain{ab_light_chain}.pdb'

        # Merge antibody chains into one file
        receptor_output_path = f'{final_folder}/{pdb_id}_receptor_1.pdb'
        with open(receptor_output_path, 'w') as receptor_file:
            # Drop ab_heavy_chain_path or ab_light_chain_path here when constructing a heavy/light-chain-only dataset
            for ab_file in [ab_heavy_chain_path, ab_light_chain_path]:
                with open(ab_file, 'r') as infile:
                    receptor_file.write(infile.read())

        print(f"Successfully merged {ab_heavy_chain} and {ab_light_chain} into {receptor_output_path}")

        ag_chain_list = ag_chain.split(';')

        if len(ag_chain_list) == 1:
            # If there's only one antigen chain
            ag_chain_1 = ag_chain_list[0].strip()
            ag_chain_1_path = f'/tmp/{pdb_id}_chain{ag_chain_1}.pdb'
            print(f"Handling one antigen chain: {ag_chain_1}")

            # Copy the single antigen chain to the output
            antigen_output_path = f'{final_folder}/{pdb_id}_antigen_1_1.pdb'
            shutil.copyfile(ag_chain_1_path, antigen_output_path)

            print(f"Successfully copied {ag_chain_1} to {antigen_output_path}")

        elif len(ag_chain_list) == 2:
            # If there are two antigen chains
            ag_chain_1, ag_chain_2 = ag_chain_list
            ag_chain_1_path = f'/tmp/{pdb_id}_chain{ag_chain_1}.pdb'
            ag_chain_2_path = f'/tmp/{pdb_id}_chain{ag_chain_2}.pdb'
            print(f"Handling two antigen chains: {ag_chain_1}, {ag_chain_2}")

            # Merge the antigen chains into a single PDB file
            antigen_output_path = f'{final_folder}/{pdb_id}_antigen_1_1.pdb'
            with open(antigen_output_path, 'w') as outfile:
                for ag_file in [ag_chain_1_path, ag_chain_2_path]:
                    with open(ag_file, 'r') as infile:
                        outfile.write(infile.read())

            print(f"Successfully merged {ag_chain_1} and {ag_chain_2} into {antigen_output_path}")

        elif len(ag_chain_list) == 3:
            # If there are three antigen chains
            ag_chain_1, ag_chain_2, ag_chain_3 = ag_chain_list
            ag_chain_1_path = f'/tmp/{pdb_id}_chain{ag_chain_1}.pdb'
            ag_chain_2_path = f'/tmp/{pdb_id}_chain{ag_chain_2}.pdb'
            ag_chain_3_path = f'/tmp/{pdb_id}_chain{ag_chain_3}.pdb'
            print(f"Handling three antigen chains: {ag_chain_1}, {ag_chain_2}, {ag_chain_3}")

            # Merge the antigen chains into a single PDB file
            antigen_output_path = f'{final_folder}/{pdb_id}_antigen_1_1.pdb'
            with open(antigen_output_path, 'w') as outfile:
                for ag_file in [ag_chain_1_path, ag_chain_2_path, ag_chain_3_path]:
                    with open(ag_file, 'r') as infile:
                        outfile.write(infile.read())

            print(f"Successfully merged {ag_chain_1}, {ag_chain_2}, and {ag_chain_3} into {antigen_output_path}")

        # At the end, remove all the chain pdb files from the /tmp folder
        for chain_file in chain_files:
            os.remove(chain_file)


if __name__ == '__main__':
    user = os.getenv('USER')

    datasets = ['PECAN', 'Paragraph_Expanded', 'MIPE']

    for dataset in datasets:
        if dataset == 'MIPE':  # here the split is train-val and test, following the MIPE paper
            # csv paths
            train_val_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_val.csv')
            test_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv')

            # paths to the initial raw PDB storage
            init_pdb_files_train_val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_val_data_initial_raw_pdb_files'
            init_pdb_files_test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'

            # final folders
            final_train_val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TRAIN_VAL'
            final_test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TEST'

            process_raw_pdb_data(train_val_info, init_pdb_files_train_val, final_train_val_folder)
            process_raw_pdb_data(test_info, init_pdb_files_test, final_test_folder)

            shutil.rmtree(init_pdb_files_train_val)
            shutil.rmtree(init_pdb_files_test)

        else:
            # Paths to dataset csv files
            train_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_set.csv')
            val_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/val_set.csv')
            test_info = pd.read_csv(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv')

            # Paths to the initial raw pdb files
            initial_pdb_files_train = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_data_initial_raw_pdb_files'
            initial_pdb_files_val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/val_data_initial_raw_pdb_files'
            initial_pdb_files_test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'

            # Final folders for the merged files that contain the final PDB complexes
            final_train_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TRAIN'
            final_val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/VAL'
            final_test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/TEST'

            # Process the train-val-test data
            process_raw_pdb_data(train_info, initial_pdb_files_train, final_train_folder)
            process_raw_pdb_data(val_info, initial_pdb_files_val, final_val_folder)
            process_raw_pdb_data(test_info, initial_pdb_files_test, final_test_folder)

            # REMOVE the initial raw pdb files
            shutil.rmtree(initial_pdb_files_train)
            shutil.rmtree(initial_pdb_files_val)
            shutil.rmtree(initial_pdb_files_test)
```
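The three antigen branches above differ only in how many chain files they concatenate. A minimal sketch of an equivalent loop handling any number of `;`-separated chains (the helper name is illustrative, not part of the repo; note it also strips whitespace from every chain ID, which the two- and three-chain branches above do not):

```python
import shutil


def merge_antigen_chains(pdb_id, ag_chain, final_folder):
    """Concatenate any number of ';'-separated antigen chains into one PDB file."""
    chain_ids = [c.strip() for c in ag_chain.split(';')]
    antigen_output_path = f'{final_folder}/{pdb_id}_antigen_1_1.pdb'
    with open(antigen_output_path, 'w') as outfile:
        for chain_id in chain_ids:
            # Chain files are the per-chain PDBs written by extract_chains_from_pdb into /tmp
            with open(f'/tmp/{pdb_id}_chain{chain_id}.pdb') as infile:
                shutil.copyfileobj(infile, outfile)
    return antigen_output_path
```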
ParaSurf/create_datasets_from_csv/process_csv_dataset.py
ADDED
```python
import os
import pandas as pd
from Bio.PDB import PDBList
from tqdm import tqdm


def add_headers_if_not_present(csv_file, headerlist):
    """
    Add headers to the CSV file if they are not already present.

    Parameters:
    csv_file (str): Path to the CSV file.
    headerlist (list): List of headers to add.
    """
    # Read the first row to check for headers
    first_row = pd.read_csv(csv_file, nrows=1)

    # Check if the first row contains the expected headers
    if list(first_row.columns) != headerlist:
        print(f"Headers not found in {csv_file}. Adding headers...")
        # Load the full data without headers
        data = pd.read_csv(csv_file, header=None)
        # Assign the correct headers
        data.columns = headerlist
        # Save the file with the correct headers
        data.to_csv(csv_file, header=True, index=False)
        print(f"Headers added to {csv_file}")
    else:
        print(f"Headers already present in {csv_file}. No changes made.")


def download_pdb(pdb_code, output_dir):
    pdbl = PDBList()
    pdbl.retrieve_pdb_file(pdb_code, pdir=output_dir, file_format='pdb')


def download_and_rename_pdb_files(pdb_list, folder):
    """
    Downloads PDB files from the provided list and renames them from `.ent` to `{pdb_code}.pdb`.

    Parameters:
    pdb_list (list): List of PDB codes to be downloaded.
    folder (str): Directory where the PDB files will be saved and renamed.
    """
    # Download PDB files
    for pdb_code in pdb_list:
        download_pdb(pdb_code, folder)

    # Rename files to {pdb_code}.pdb
    for pdb_file in os.listdir(folder):
        if pdb_file.endswith('.ent'):
            old_file_path = os.path.join(folder, pdb_file)
            new_file_name = pdb_file.split('.')[0][-4:].upper() + '.pdb'  # Uppercase because the csv gives the pdb names in capitals
            new_file_path = os.path.join(folder, new_file_name)
            os.rename(old_file_path, new_file_path)
            print(f"Renamed {old_file_path} to {new_file_path}")


def process_dataset(csv_file, folder):
    """
    Processes a dataset by adding headers, extracting PDB codes, and downloading/renaming PDB files.

    Parameters:
    csv_file (str): Path to the CSV file.
    folder (str): Directory where the PDB files will be saved and renamed.
    """
    # Add headers if not present
    add_headers_if_not_present(csv_file, headerlist)

    # Read the CSV file
    dataset = pd.read_csv(csv_file)

    # Create folder if it doesn't exist
    if not os.path.exists(folder):
        os.makedirs(folder)

    # Initialize the PDB list
    pdb_list = []

    # Process each row
    for i, row in dataset.iterrows():
        pdb_list.append(row['pdb_code'])

    # Download and rename PDB files
    download_and_rename_pdb_files(pdb_list, folder)


if __name__ == '__main__':
    # ALL datasets follow the same process
    user = os.getenv('USER')
    datasets = ['PECAN', 'Paragraph_Expanded', 'MIPE']

    # Define the correct headers
    headerlist = ['pdb_code', 'Light_chain', 'Heavy_chain', 'ag']

    for dataset in datasets:
        if dataset == 'MIPE':  # here the split is train-val and test, following the MIPE paper
            # csv paths
            train_val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_val.csv'
            test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv'

            # paths to the initial raw PDB storage
            train_val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_val_data_initial_raw_pdb_files'
            test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'

            process_dataset(train_val, train_val_folder)
            process_dataset(test, test_folder)

        else:
            # Paths to your CSV files. Download dataset from here: https://github.com/oxpig/Paragraph/tree/main/training_data/Expanded
            train = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/train_set.csv'
            val = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/val_set.csv'
            test = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/training_data/{dataset}/test_set.csv'

            # Paths for the initial raw PDB file storage
            train_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/train_data_initial_raw_pdb_files'
            val_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/val_data_initial_raw_pdb_files'
            test_folder = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/{dataset}/test_data_initial_raw_pdb_files'

            # Process each dataset
            process_dataset(train, train_folder)
            process_dataset(val, val_folder)
            process_dataset(test, test_folder)
```
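Two behaviors worth noting: Biopython's `PDBList.retrieve_pdb_file` saves `pdb`-format downloads as `pdb<code>.ent`, which is why the rename pass above exists, and `process_dataset` reads `headerlist` as a module-level global, so it only resolves when the script is run directly. A minimal sketch of reusing the download step on its own (the codes and path are illustrative):

```python
from process_csv_dataset import download_and_rename_pdb_files

# Downloads pdb1a3r.ent etc. via Biopython, then renames each file to e.g. 1A3R.pdb
download_and_rename_pdb_files(['1A3R', '5WUX'], './raw_pdbs')
```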
ParaSurf/create_datasets_from_csv/split_pdb2chains_only.py
ADDED
```python
import os


def extract_chains_from_pdb(pdb_file, output_dir):
    """
    Extract and save the chains from a PDB file as separate chain-specific PDB files.

    Args:
        pdb_file (str): Path to the PDB file.
        output_dir (str): Path to the directory where the chain-specific files should be saved.

    Returns:
        tuple: Paths to the chain-specific PDB files and the corresponding chain IDs.
    """
    chain_dict = {}

    with open(pdb_file, 'r') as f:
        for line in f:
            if line.startswith('ATOM'):
                chain_id = line[21]
                if chain_id in chain_dict:
                    chain_dict[chain_id].append(line)
                else:
                    chain_dict[chain_id] = [line]

    chain_files = []
    for chain_id, lines in chain_dict.items():
        chain_file = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(pdb_file))[0]}_chain{chain_id}.pdb")
        with open(chain_file, 'w') as f:
            f.writelines(lines)
        # print(f'Chain {chain_id} saved as {chain_file}.')
        chain_files.append(chain_file)

    chain_ids = [chain.split("/")[-1].split(".")[0][-1] for chain in chain_files]

    return chain_files, chain_ids


if __name__ == '__main__':
    pdb_file = '/home/angepapa/PycharmProjects/DeepSurf2.0/3bgf.pdb'
    output_dir = "/".join(pdb_file.split('/')[:-1])
    chain_files, chain_ids = extract_chains_from_pdb(pdb_file, output_dir)
    print(chain_files)
    print(chain_ids)
```
ParaSurf/model/ParaSurf_model.py
ADDED
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import time


class GeM(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super(GeM, self).__init__()
        # Initialize p as a learnable parameter
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, self.p, self.eps)

    def gem(self, x, p, eps):
        # Clamp all elements in x to a minimum of eps, then raise them to the power of p.
        # Apply avg_pool3d with the kernel size spanning the entire feature map (full depth, height, width).
        # Finally, take the power of 1/p to invert the earlier power-of-p operation.
        return F.avg_pool3d(x.clamp(min=eps).pow(p), (x.size(2), x.size(3), x.size(4))).pow(1. / p)

    def __repr__(self):
        # This helps in identifying the layer characteristics when printing the model or layer
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', eps=' + str(
            self.eps) + ')'


# Define a custom Bottleneck module with optional dilation
class DilatedBottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, dilation=1, dropout_prob=0.25):
        super(DilatedBottleneck, self).__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.dropout1 = nn.Dropout3d(dropout_prob)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
                               bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.dropout2 = nn.Dropout3d(dropout_prob)
        self.conv3 = nn.Conv3d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(self.expansion * planes)
            )

    def forward(self, x):
        out = self.dropout1(F.relu(self.bn1(self.conv1(x))))
        out = self.dropout2(F.relu(self.bn2(self.conv2(out))))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


# Define the Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, feature_size, nhead, num_layers):
        super(TransformerBlock, self).__init__()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=feature_size, nhead=nhead),
            num_layers=num_layers
        )

    def forward(self, x):
        orig_shape = x.shape  # Save original shape
        x = x.flatten(2)  # Flatten spatial dimensions
        x = x.permute(2, 0, 1)  # Reshape for the transformer (Seq, Batch, Features)
        x = self.transformer(x)
        x = x.permute(1, 2, 0).view(*orig_shape)  # Restore original shape
        return x


# Define a Compression Layer
class CompressionLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CompressionLayer, self).__init__()
        self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1x1(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


# Define the Enhanced ResNet with hybrid architecture
class ResNet3D_Transformer(nn.Module):
    def __init__(self, in_channels, block, num_blocks, num_classes=1, dropout_prob=0.1):
        super(ResNet3D_Transformer, self).__init__()
        self.in_planes = 64

        self.initial_layers = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        )

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, dilation=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.compression = CompressionLayer(512 * block.expansion, 256)
        self.transformer_block = TransformerBlock(feature_size=256, nhead=8, num_layers=1)  # change to 4
        # self.gem_pooling = GeM(p=3.0, eps=1e-6)
        self.dropout = nn.Dropout(dropout_prob)

        self.classifier = nn.Linear(256, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride, dilation=1):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s, dilation))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 4, 3, 2, 1)
        x = self.initial_layers(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Compress and transform
        x = self.compression(x)
        x = self.transformer_block(x)
        # Global average pooling
        x = torch.mean(x, dim=[2, 3, 4])

        # Classify
        x = self.dropout(x)  # Apply dropout before classification
        x = self.classifier(x)
        return x


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'
    num_classes = 1
    num_input_channels = 20  # Number of input channels
    model = ResNet3D_Transformer(num_input_channels, DilatedBottleneck, [3, 4, 6, 3], num_classes=num_classes).to(device)
    grid_size = 41  # Assuming the input grid size (for example, 41x41x41x20)

    start = time.time()
    num_params = count_parameters(model)
    print(f"Number of parameters in the model: {num_params}")
    print(model)

    dummy_input = torch.randn(64, grid_size, grid_size, grid_size, num_input_channels).to(device)
    dummy_input = dummy_input.float().to(device)

    output = model(dummy_input)

    print("Output shape:", output.shape)
    print(output)
    print(f'total time: {(time.time() - start)/60} mins')
```
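For reference, the `GeM` layer above computes, per channel, the generalized mean over the $N$ voxels of the feature map; the learnable exponent $p$ interpolates between average pooling ($p = 1$) and max pooling ($p \to \infty$):

$$\mathrm{GeM}(x) = \left(\frac{1}{N}\sum_{i=1}^{N}\max(x_i, \epsilon)^{p}\right)^{1/p}$$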
ParaSurf/model/__pycache__/ParaSurf_model.cpython-310.pyc
ADDED
Binary file (6.29 kB).
ParaSurf/model/__pycache__/ParaSurf_model.cpython-39.pyc
ADDED
Binary file (6.36 kB).
ParaSurf/model/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (2.76 kB).
ParaSurf/model/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (2.76 kB).
ParaSurf/model/dataset.py
ADDED
```python
import numpy as np
import random, os
from scipy import sparse
from torch.utils.data import Dataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import h5py


class dataset(Dataset):
    def __init__(self, train_file, batch_size, data_path, grid_size, training, feature_vector_lentgh, feature_names=['deepsite']):

        super(dataset, self).__init__()
        self.training = training
        self.feature_vector_lentgh = feature_vector_lentgh
        # in testing mode the training file is not read
        if self.training:
            with open(train_file) as f:
                self.train_lines = f.readlines()
            random.shuffle(self.train_lines)
        else:
            self.train_lines = []

        random.shuffle(self.train_lines)

        self.pointer_tr = 0
        self.pointer_val = 0

        self.batch_size = batch_size
        self.data_path = data_path
        self.grid_size = grid_size
        self.feature_names = feature_names
        # if added_features is None:  # resolved outside
        self.nAtomTypes = 0

        self.nfeats = {
            'deepsite': 8,
            'kalasanty': feature_vector_lentgh,
            'kalasanty_with_force_fields': feature_vector_lentgh,
            'kalasanty_norotgrid': 18,
            'spat_protr': 1,
            'spat_protr_norotgrid': 1
        }

        for name in feature_names:
            self.nAtomTypes += self.nfeats[name]

    def __len__(self):
        if self.training:
            return len(self.train_lines)

    def __getitem__(self, index):
        if self.training:
            samples = self.train_lines

        label, sample_file = samples[index].split()
        label = int(label)
        base_name, prot, sample = sample_file.split('/')

        feats = np.zeros((self.grid_size, self.grid_size, self.grid_size, self.nAtomTypes))
        feat_cnt = 0

        for name in self.feature_names:
            if 'deepsite' == name:
                data = np.load(os.path.join(self.data_path, base_name + '_' + name, prot, sample), allow_pickle=True)
            elif 'kalasanty' == name:
                data = sparse.load_npz(os.path.join(self.data_path, base_name, prot, sample[:-1] + 'z'))
                data = np.reshape(np.array(data.todense()), (self.grid_size, self.grid_size, self.grid_size, self.nfeats['kalasanty']))
            elif 'kalasanty_with_force_fields' == name:
                data = sparse.load_npz(os.path.join(self.data_path, base_name, prot, sample[:-1] + 'z'))
                data = np.reshape(np.array(data.todense()), (self.grid_size, self.grid_size, self.grid_size, self.nfeats['kalasanty_with_force_fields']))
            elif 'spat_protr' in name:
                data = np.load(os.path.join(self.data_path, base_name + '_' + name, prot, sample), allow_pickle=True)
            else:
                print('unknown feat')

            if len(data) == 3:
                data = data[2]  # note: only for scPDB for now (kalasanty has no points/normals)

            feats[:, :, :, feat_cnt:feat_cnt + self.nfeats[name]] = data
            feat_cnt += self.nfeats[name]

        if feat_cnt != self.nAtomTypes:
            print('error !')

        # Modified code with explicit strides: PyTorch does not handle negatively-strided arrays,
        # so we copy and make the rotated array contiguous below
        if self.training:
            rot_axis = random.randint(1, 3)
            feats_copy = feats.copy()
            if rot_axis == 1:
                feats_copy = np.rot90(feats_copy, random.randint(0, 3), axes=(0, 1))
            elif rot_axis == 2:
                feats_copy = np.rot90(feats_copy, random.randint(0, 3), axes=(0, 2))
            elif rot_axis == 3:
                feats_copy = np.rot90(feats_copy, random.randint(0, 3), axes=(1, 2))
            feats = np.ascontiguousarray(feats_copy)

        if np.isnan(np.sum(feats)):
            print('nan input')

        return feats, label
```
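A minimal sketch of wiring this class into a PyTorch `DataLoader` (the `.samples` path and the feature-vector length of 22 follow the preprocess README but are assumptions here, not confirmed defaults):

```python
from torch.utils.data import DataLoader
from ParaSurf.model.dataset import dataset

# Hypothetical paths; .samples files come from ParaSurf/preprocess/create_sample_files.py
train_set = dataset(train_file='test_data/datasets/PECAN_TRAIN.samples',
                    batch_size=64,
                    data_path='test_data/feats',
                    grid_size=41,
                    training=True,
                    feature_vector_lentgh=22,  # spelling follows the class signature
                    feature_names=['kalasanty_with_force_fields'])
loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
```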
ParaSurf/model_weights/README.md
ADDED
## *ParaSurf Model Weights*

[Download ParaSurf model weights](https://drive.google.com/drive/folders/1Kpehru9SnWsl7_Wq93WuI_o7f8wrPgpI?usp=drive_link)

Best model weights for the 3 benchmark datasets:
* PECAN Dataset
* Paragraph Expanded
* Paragraph Expanded (Heavy Chains Only)
* Paragraph Expanded (Light Chains Only)
* MIPE Dataset
ParaSurf/preprocess/README.md
ADDED
# **Feature Extraction - Preprocessing phase**

This guide outlines the steps needed to generate the ParaSurf 41x41x41x22 input feature vector for training. By following these steps, you will create a dataset ready for training, organized in the specified folder structure.

### Step 1: Clean the Antibody-Antigen Complex
Remove ions, ligands, and water molecules from the antibody-antigen complex and rearrange atom IDs within the PDB structure.
```bash
# Clean the antibody-antigen complex
python clean_dataset.py
```

### Step 2: Sanity Check for Interaction
Verify that at least one antibody heavy atom is within 4.5 Å of any antigen heavy atom, ensuring proximity-based interactions.
```bash
# Run sanity check
python check_rec_ant_touch.py
```

### Step 3: Generate Molecular Surface Points
Create the molecular surface for each receptor in the training folder using the DMS software. These surface points will serve as a basis for feature extraction.
```bash
# Generate molecular surface points
python create_surfpoints.py
```

### Step 4: Generate ParaSurf Input Feature Grids (41x41x41x22)
Create the 3D feature grids for each surface point generated in Step 3. Each feature grid includes 22 channels with essential structural and electrostatic information.
```bash
python create_input_features.py
```

### Step 5: Prepare .proteins Files
Generate .proteins files for training, validation, and testing. These files list all receptors (antibodies) to be used in each dataset split.
```bash
# Create train/val/test .proteins files
python create_proteins_file.py
```

### Step 6: Create .samples Files
Generate .samples files, each listing paths to the feature files created in Step 4. These files link the features to the training pipeline.
```bash
# Generate .samples files for network training
python create_sample_files.py
```

## **Folder Structure After Preprocessing**

After completing the above steps, the resulting folder structure should be organized as follows:
```bash
├── test_data
│   ├── datasets
│   │   ├── PECAN_TRAIN.samples
│   │   ├── PECAN_TRAIN.proteins
│   │   ├── PECAN_VAL.proteins
│   │   ├── PECAN_TEST.proteins
│   │   └── ...
├── feats
│   ├── PECAN_22
│   ├── Paragraph_Expanded_22
│   └── MIPE_22
├── surfpoints
│   ├── PECAN
│   │   └── TRAIN
│   ├── Paragraph_Expanded
│   │   └── TRAIN
│   ├── MIPE
│   │   └── TRAIN
└── pdbs   # already created from ParaSurf/create_datasets_from_csv
```

Now we are ready for training!
ParaSurf/preprocess/__pycache__/check_empty_features.cpython-310.pyc
ADDED
Binary file (2.72 kB).
ParaSurf/preprocess/__pycache__/check_empty_features.cpython-39.pyc
ADDED
Binary file (2.53 kB).
ParaSurf/preprocess/__pycache__/clean_dataset.cpython-310.pyc
ADDED
Binary file (1.04 kB).
ParaSurf/preprocess/__pycache__/clean_dataset.cpython-39.pyc
ADDED
Binary file (1.01 kB).
ParaSurf/preprocess/check_empty_features.py
ADDED
```python
import os


def remove_empty_features(feats_folder, pdbs_path, surf_path, log_file_path="removed_complexes_log.txt"):
    """
    Checks each subfolder in the base folder for files. If a subfolder is empty, removes it along with
    associated files from `pdbs_path` and `surf_path` and logs the removals.

    Parameters:
    - feats_folder (str): The main directory containing subfolders with the features to check.
    - pdbs_path (str): Path where receptor and antigen PDB files are located.
    - surf_path (str): Path where surface points files are located.
    - log_file_path (str): Path for the log file to track removed folders and files. Default is 'removed_complexes_log.txt'.

    Returns:
    - total_empty_folders (int): Count of empty folders removed.
    """

    # Identify all subfolders in the base folder
    subfolders = [d for d in os.listdir(feats_folder) if os.path.isdir(os.path.join(feats_folder, d))]
    empty_folders = []

    # Open log file to record removed folders
    with open(log_file_path, 'w') as log_file:
        log_file.write("Log of Removed Folders and Files\n")
        log_file.write("=" * 30 + "\n")

        # Check each subfolder and remove if empty
        for folder in subfolders:
            path = os.path.join(feats_folder, folder)
            if not any(os.path.isfile(os.path.join(path, i)) for i in os.listdir(path)):
                empty_folders.append(folder)
                pdb_code = folder.split('_')[0]

                # Define paths to the files to be removed
                rec_file = os.path.join(pdbs_path, pdb_code + '_receptor_1.pdb')
                antigen_file = os.path.join(pdbs_path, pdb_code + '_antigen_1_1.pdb')
                surf_file = os.path.join(surf_path, pdb_code + '_receptor_1.surfpoints')

                # Remove the empty folder and associated files
                os.rmdir(path)
                if os.path.exists(rec_file):
                    os.remove(rec_file)
                if os.path.exists(antigen_file):
                    os.remove(antigen_file)
                if os.path.exists(surf_file):
                    os.remove(surf_file)

                # Log each removal
                log_file.write(f"{pdb_code} complex removed since no features found.\n")

    total_empty_folders = len(empty_folders)
    # Delete the log file if no folders were removed
    if total_empty_folders == 0:
        os.remove(log_file_path)
        print("\nAll complexes have features!!!")
    else:
        print(f"Total empty folders removed: {total_empty_folders}")
        print(f"Details logged in {log_file_path}")

    return total_empty_folders


# Example usage
if __name__ == '__main__':
    user = os.getenv('USER')
    pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN'
    surf_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/surf_points/eraseme/TRAIN'
    feats_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/feats/eraseme_22'
    remove_empty_features(feats_path, pdbs_path, surf_path)
```
ParaSurf/preprocess/check_rec_ant_touch.py
ADDED
```python
import os
import math
from tqdm import tqdm


def locate_receptor_binding_site_atoms(receptor_pdb_file, antigen_pdb_file, distance_cutoff=4):
    rec_coordinates = []
    with open(receptor_pdb_file, 'r') as file:
        for line in file:
            if line.startswith("ATOM"):
                x = float(line[30:38].strip())
                y = float(line[38:46].strip())
                z = float(line[46:54].strip())
                rec_coordinates.append((x, y, z))

    ant_coordinates = []
    with open(antigen_pdb_file, 'r') as file:
        for line in file:
            if line.startswith("ATOM"):
                x = float(line[30:38].strip())
                y = float(line[38:46].strip())
                z = float(line[46:54].strip())
                ant_coordinates.append((x, y, z))

    # Create a list to store the final coordinates
    final_coordinates = []

    # Compare each coordinate from rec_coordinates with each coordinate from ant_coordinates
    for rec_coord in rec_coordinates:
        for ant_coord in ant_coordinates:
            if math.dist(rec_coord, ant_coord) < distance_cutoff:
                final_coordinates.append(rec_coord)
                break  # Break the inner loop if a match is found to avoid duplicate entries

    # sanity check
    for coor in final_coordinates:
        if coor not in rec_coordinates:
            print('BINDING SITE COORDINATE NOT IN RECEPTOR COORDINATES!!!!!!')
    return final_coordinates, rec_coordinates


def check_receptor_antigen_interactions(pdb_dir, distance_cutoff=6, log_file="interaction_issues.txt"):
    """
    :param pdb_dir: directory with receptor and antigen pdb files
    :param distance_cutoff: the distance cutoff for the binding site
    :param log_file: the file where issues will be logged
    :return: checks that each receptor and its antigen are in contact with each other
    """
    all_successful = True  # A flag to track if all pairs are correct

    # Open the log file for writing
    with open(log_file, 'w') as log:
        log.write("Receptor-Antigen Interaction Issues Log\n")
        log.write("=====================================\n")

        non_interacting_pdbs = 0
        for pdb_file in tqdm(os.listdir(pdb_dir)):
            pdb_id = pdb_file.split('_')[0]
            cur_rec_pdb = os.path.join(pdb_dir, f'{pdb_id}_receptor_1.pdb')
            cur_ant_pdb = os.path.join(pdb_dir, f'{pdb_id}_antigen_1_1.pdb')

            if os.path.exists(cur_rec_pdb) and os.path.exists(cur_ant_pdb):
                final, rec = locate_receptor_binding_site_atoms(cur_rec_pdb, cur_ant_pdb, distance_cutoff)
                if len(final) == 0:
                    non_interacting_pdbs += 1
                    log.write(f'\nNON-INTERACTING PAIR!!!: problem with {pdb_id}.pdb. {pdb_id}_receptor_1.pdb and '
                              f'{pdb_id}_antigen_1_1.pdb files are removed.\n')
                    os.remove(cur_rec_pdb)
                    os.remove(cur_ant_pdb)
                    all_successful = False  # Mark as unsuccessful if any issue is found

        # Check if everything was successful
        if all_successful:
            print("Success! All receptors interact with their associated antigens.")
            # since no issues were found we can remove the log file
            os.remove(log_file)
        else:
            print(f'\n ~~~~~ Total pdbs found with issues: {non_interacting_pdbs}; they are removed from the folder ~~~~~\n')
            log.write(f'\n\n ~~~~~ Total pdbs found with issues: {non_interacting_pdbs} ~~~~~')
            print(f"Some receptors do not interact with their antigens. Issues logged in {log_file}.")


# example usage
if __name__ == '__main__':
    user = os.getenv('USER')
    pdb_dir = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN'
    index = pdb_dir.split('/')[-1]
    check_receptor_antigen_interactions(pdb_dir, distance_cutoff=4.5, log_file=f'{pdb_dir}/{index}_interaction_issues.txt')
```
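The pairwise check above is O(N·M) per complex. A minimal sketch of an equivalent test with a k-d tree via `scipy.spatial.cKDTree` (the function name here is illustrative; the repo's loop-based version is the authoritative one):

```python
import numpy as np
from scipy.spatial import cKDTree


def complexes_touch(rec_coordinates, ant_coordinates, cutoff=4.5):
    """Return True if any receptor atom lies within `cutoff` Å of any antigen atom."""
    tree = cKDTree(np.asarray(ant_coordinates))
    # For each receptor atom, list the antigen atoms within the cutoff radius
    hits = tree.query_ball_point(np.asarray(rec_coordinates), r=cutoff)
    return any(len(h) > 0 for h in hits)
```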
ParaSurf/preprocess/clean_dataset.py
ADDED
```python
import os
from ParaSurf.utils.remove_hydrogens_from_pdb import remove_hydrogens_from_pdb_folder
from ParaSurf.utils.remove_HETATMS_from_receptors import remove_hetatm_from_pdb_folder
from ParaSurf.utils.reaarange_atom_id import process_pdb_files_in_folder


def clean_dataset(dataset_path_with_pdbs):
    """
    :param dataset_path_with_pdbs: path to the folder with the raw PDB files
    :return: a cleaned dataset ready to be processed for training purposes, via 3 filtering steps
    """
    data_path = dataset_path_with_pdbs

    # step1: remove hydrogens
    remove_hydrogens_from_pdb_folder(input_folder=data_path,
                                     output_folder=data_path)

    # step2: remove HETATMs only from the receptors
    remove_hetatm_from_pdb_folder(input_folder=data_path,
                                  output_folder=data_path)

    # step3: re-arrange the atom_id of each pdb
    process_pdb_files_in_folder(folder_path=data_path)


if __name__ == "__main__":
    user = os.getenv('USER')
    clean_dataset(f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN')
```
ParaSurf/preprocess/create_input_features.py
ADDED
```python
from multiprocessing import Pool
from multiprocessing import Lock
import time, os
import numpy as np
from Bio.PDB.PDBParser import PDBParser
from ParaSurf.utils.bsite_lib import readSurfpoints, readSurfpoints_with_residues, dist_point_from_lig
from ParaSurf.utils.features import KalasantyFeaturizer, KalasantyFeaturizer_with_force_fields
from scipy import sparse
from tqdm import tqdm
import warnings
from ParaSurf.utils.distance_coords import locate_receptor_binding_site_residues
from check_empty_features import remove_empty_features

# Ignore warnings
warnings.filterwarnings('ignore')


lock = Lock()  # Instantiate a Lock for thread safety.


def balanced_sampling(surf_file, protein_file, lig_files, cutoff=4.5):
    """
    Returns a subset of equal positive and negative samples from surface points in `surf_file`, with positive samples
    selected from residues close to the antigen.

    Parameters:
    - surf_file: Path to the file with protein surface points.
    - protein_file: Path to the protein structure file (e.g., PDB format).
    - lig_files: List of ligand (antigen) structure file paths.
    - cutoff: Distance cutoff in Ångstroms for defining binding residues (default is 4.5).

    Returns:
    - Balanced samples including features and labels for the selected surface points.
    """
    all_lig_coords = []
    for lig_file in lig_files:
        with lock:  # Locks the parser to ensure thread safety.
            lig = parser.get_structure('antigen', lig_file)
        lig_coords = np.array([atom.get_coord() for atom in lig.get_atoms()])
        all_lig_coords.append(lig_coords)

    points, normals = readSurfpoints(surf_file)  # modified by me
    # create the residue groups for the whole protein
    all_rec_residues = readSurfpoints_with_residues(surf_file)

    # find the binding site residues
    bind_site_rec_residues = locate_receptor_binding_site_residues(protein_file, lig_file, distance_cutoff=cutoff)

    # gather all distances
    # Create an array to store the minimum distance of each point to any ligand atom
    dist_from_lig = np.full(len(points), np.inf)
    near_lig = np.zeros(len(points), dtype=bool)

    # Update distances only for points in binding site residues
    bind_site_indices = [item for i in bind_site_rec_residues for item in all_rec_residues[i]['idx']]
    bind_site_indices_set = set(bind_site_indices)  # Convert to set for fast lookup

    # Loop through ligand coordinates and update distances for binding site points.
    # IMPORTANT step: if a residue belongs to the binding site, that does NOT mean every atom of this
    # residue belongs to the binding site (within the cutoff distance of the ligand). So here we check,
    # among the binding site residues, which atoms actually bind (within the cutoff distance of the ligand).
    for lig_coords in all_lig_coords:
        for i, p in enumerate(points):
            if i in bind_site_indices_set:
                dist = dist_point_from_lig(p, lig_coords)  # Adjust this function if necessary
                if dist < dist_from_lig[i]:
                    dist_from_lig[i] = dist
                    near_lig[i] = dist < cutoff

    # Filter positive indices to include only those near a ligand
    pos_idxs = np.array([idx for idx in bind_site_indices if near_lig[idx]])

    # If there are more positive indices than allowed, select the best ones based on the distance
    if len(pos_idxs) > maxPosSamples:
        pos_idxs = pos_idxs[np.argsort(dist_from_lig[pos_idxs])[:maxPosSamples]]

    # Select the negative samples
    all_neg_samples = [idx for idx, i in enumerate(points) if idx not in pos_idxs]

    # Calculate the number of negative samples to match the number of positive samples
    num_neg_samples = min(len(all_neg_samples), len(pos_idxs))

    neg_idxs = np.array(all_neg_samples)
    if len(neg_idxs) > num_neg_samples:
        neg_downsampled = np.random.choice(neg_idxs, num_neg_samples, replace=False)
    else:
        neg_downsampled = neg_idxs

    # Concatenate positive and negative indices
    sample_idxs = np.concatenate((pos_idxs, neg_downsampled))

    # Shuffle the indices to ensure randomness
    np.random.shuffle(sample_idxs)

    # create the sample labels
    # Convert pos_idxs to a set for faster membership testing
    pos_set = set(pos_idxs)

    # Use a list comprehension to create labels
    sample_labels = [i in pos_set for i in sample_idxs]
    if feat_type == 'kalasanty':
        featurizer = KalasantyFeaturizer(protein_file, protonate, gridSize, voxelSize, use_protrusion, protr_radius)
    elif feat_type == 'kalasanty_with_force_fields':
        featurizer = KalasantyFeaturizer_with_force_fields(protein_file, protonate, gridSize, voxelSize, use_protrusion, protr_radius,
                                                           add_atom_radius_features=add_atoms_radius_ff_features)

    feature_vector_length = featurizer.channels.shape[1]
    with open(feature_vector_length_tmp_path, 'w') as file:
        file.write(str(feature_vector_length))

    for i, sample in enumerate(sample_idxs):
        features = featurizer.grid_feats(points[sample], normals[sample], rotate_grid)
        if np.count_nonzero(features) == 0:
            print('Zero features', protein_file.rsplit('/', 1)[1][:-4], i, points[sample], normals[sample])

        yield features, sample_labels[i], points[sample], normals[sample]


def samples_per_prot(prot):
    """
    Generates and saves balanced surface point samples for a given protein.

    Parameters:
    - prot: Protein identifier for which features are being generated.

    Saves each sample as a sparse matrix or NumPy array in `feats_path`.
    """
    prot_path = os.path.join(feats_path, prot)

    # Check if the directory exists and has files; if not, create it
    if not os.path.exists(prot_path):
        os.makedirs(prot_path)
    elif os.listdir(prot_path):
        return

    surf_file = os.path.join(surf_path, f"{prot}.surfpoints")
    protein_file = os.path.join(pdbs_path, f"{prot}.pdb")
```
6 |
+
from ParaSurf.utils.bsite_lib import readSurfpoints, readSurfpoints_with_residues, dist_point_from_lig
|
7 |
+
from ParaSurf.utils.features import KalasantyFeaturizer, KalasantyFeaturizer_with_force_fields
|
8 |
+
from scipy import sparse
|
9 |
+
from tqdm import tqdm
|
10 |
+
import warnings
|
11 |
+
from ParaSurf.utils.distance_coords import locate_receptor_binding_site_residues
|
12 |
+
from check_empty_features import remove_empty_features
|
13 |
+
|
14 |
+
# Ignore warnings
|
15 |
+
warnings.filterwarnings('ignore')
|
16 |
+
|
17 |
+
|
18 |
+
lock = Lock() # Instantiate a Lock for thread safety.
|
19 |
+
|
20 |
+
|
21 |
+
def balanced_sampling(surf_file, protein_file, lig_files, cutoff=4.5):
|
22 |
+
"""
|
23 |
+
Returns a subset of equal positive and negative samples from surface points in `surf_file`, with positive samples
|
24 |
+
selected from residues close to the antigen.
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
- surf_file: Path to the file with protein surface points.
|
28 |
+
- protein_file: Path to the protein structure file (e.g., PDB format).
|
29 |
+
- lig_files: List of ligand (antigen) structure file paths.
|
30 |
+
- cutoff: Distance cutoff in Ångstroms for defining binding residues (default is 4).
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
- Balanced samples including features and labels for the selected surface points.
|
34 |
+
"""
|
35 |
+
all_lig_coords = []
|
36 |
+
for lig_file in lig_files:
|
37 |
+
with lock: # Locks the parser to ensure thread safety.
|
38 |
+
lig = parser.get_structure('antigen', lig_file)
|
39 |
+
lig_coords = np.array([atom.get_coord() for atom in lig.get_atoms()])
|
40 |
+
all_lig_coords.append(lig_coords)
|
41 |
+
|
42 |
+
points, normals = readSurfpoints(surf_file) # modified by me
|
43 |
+
# create the residue groups for the whole protein
|
44 |
+
all_rec_residues = readSurfpoints_with_residues(surf_file)
|
45 |
+
|
46 |
+
|
47 |
+
# find the bind site residues
|
48 |
+
bind_site_rec_residues = locate_receptor_binding_site_residues(protein_file, lig_file, distance_cutoff=cutoff)
|
49 |
+
|
50 |
+
# gather all distances
|
51 |
+
# Create an array to store the minimum distance of each point to any ligand atom
|
52 |
+
dist_from_lig = np.full(len(points), np.inf)
|
53 |
+
near_lig = np.zeros(len(points), dtype=bool)
|
54 |
+
|
55 |
+
# Update distances only for points in bind site residues
|
56 |
+
bind_site_indices = [item for i in bind_site_rec_residues for item in all_rec_residues[i]['idx']]
|
57 |
+
bind_site_indices_set = set(bind_site_indices) # Convert to set for fast lookup
|
58 |
+
|
59 |
+
# Loop through ligand coordinates and update distances for binding site points
|
60 |
+
# IMPORTANT Step because if a residue belongs to the binding site that DOES NOT mean that all the atom of this
|
61 |
+
# residue belongs the binding site (<6 armstrong to the ligand). So here we check from the binding site residues which
|
62 |
+
# atoms actually bind (<6 armstrong to the ligand)
|
63 |
+
for lig_coords in all_lig_coords:
|
64 |
+
for i, p in enumerate(points):
|
65 |
+
if i in bind_site_indices_set:
|
66 |
+
dist = dist_point_from_lig(p, lig_coords) # Adjust this function if necessary
|
67 |
+
if dist < dist_from_lig[i]:
|
68 |
+
dist_from_lig[i] = dist
|
69 |
+
near_lig[i] = dist < cutoff
|
70 |
+
|
71 |
+
# Filter positive indices to include only those near a ligand
|
72 |
+
pos_idxs = np.array([idx for idx in bind_site_indices if near_lig[idx]])
|
73 |
+
|
74 |
+
# If there are more positive indices than allowed, select the best ones based on the distance
|
75 |
+
if len(pos_idxs) > maxPosSamples:
|
76 |
+
pos_idxs = pos_idxs[np.argsort(dist_from_lig[pos_idxs])[:maxPosSamples]]
|
77 |
+
|
78 |
+
# Select the negative samples
|
79 |
+
all_neg_samples = [idx for idx, i in enumerate(points) if idx not in pos_idxs]
|
80 |
+
|
81 |
+
# Calculate number of negative samples to match the number of positive samples
|
82 |
+
num_neg_samples = min(len(all_neg_samples), len(pos_idxs))
|
83 |
+
|
84 |
+
neg_idxs = np.array(all_neg_samples)
|
85 |
+
if len(neg_idxs) > num_neg_samples:
|
86 |
+
neg_downsampled = np.random.choice(neg_idxs, num_neg_samples, replace=False)
|
87 |
+
else:
|
88 |
+
neg_downsampled = neg_idxs
|
89 |
+
|
90 |
+
# Concatenate positive and negative indices
|
91 |
+
sample_idxs = np.concatenate((pos_idxs, neg_downsampled))
|
92 |
+
|
93 |
+
# Shuffle the indices to ensure randomness
|
94 |
+
np.random.shuffle(sample_idxs)
|
95 |
+
|
96 |
+
# create the sample labels
|
97 |
+
# Convert pos_idxs to a set for faster membership testing
|
98 |
+
pos_set = set(pos_idxs)
|
99 |
+
|
100 |
+
# Use list comprehension to create labels
|
101 |
+
sample_labels = [i in pos_set for i in sample_idxs]
|
102 |
+
if feat_type == 'kalasanty':
|
103 |
+
featurizer = KalasantyFeaturizer(protein_file, protonate, gridSize, voxelSize, use_protrusion, protr_radius)
|
104 |
+
elif feat_type == 'kalasanty_with_force_fields':
|
105 |
+
featurizer = KalasantyFeaturizer_with_force_fields(protein_file, protonate, gridSize, voxelSize, use_protrusion, protr_radius,
|
106 |
+
add_atom_radius_features=add_atoms_radius_ff_features)
|
107 |
+
|
108 |
+
feature_vector_length = featurizer.channels.shape[1]
|
109 |
+
with open(feature_vector_length_tmp_path, 'w') as file:
|
110 |
+
file.write(str(feature_vector_length))
|
111 |
+
|
112 |
+
for i, sample in enumerate(sample_idxs):
|
113 |
+
features = featurizer.grid_feats(points[sample], normals[sample], rotate_grid)
|
114 |
+
if np.count_nonzero(features) == 0:
|
115 |
+
print('Zero features', protein_file.rsplit('/', 1)[1][:-4], i, points[sample], normals[sample])
|
116 |
+
|
117 |
+
yield features, sample_labels[i], points[sample], normals[sample]
|
118 |
+
|
119 |
+
|
120 |
+
def samples_per_prot(prot):
|
121 |
+
"""
|
122 |
+
Generates and saves balanced surface point samples for a given protein.
|
123 |
+
|
124 |
+
Parameters:
|
125 |
+
- prot: Protein identifier for which features are being generated.
|
126 |
+
|
127 |
+
Saves each sample as a sparse matrix or NumPy array in `feats_path`.
|
128 |
+
"""
|
129 |
+
prot_path = os.path.join(feats_path, prot)
|
130 |
+
|
131 |
+
# Check if directory exists and has files, if not create it
|
132 |
+
if not os.path.exists(prot_path):
|
133 |
+
os.makedirs(prot_path)
|
134 |
+
elif os.listdir(prot_path):
|
135 |
+
return
|
136 |
+
|
137 |
+
surf_file = os.path.join(surf_path, f"{prot}.surfpoints")
|
138 |
+
protein_file = os.path.join(pdbs_path, f"{prot}.pdb")
|
139 |
+
if not os.path.exists(protein_file):
|
140 |
+
protein_file = os.path.join(pdbs_path, f"{prot}.mol2")
|
141 |
+
|
142 |
+
receptor_id = prot_path.split('_')[-1]
|
143 |
+
antigen_prefix = prot.split('_')[0]
|
144 |
+
|
145 |
+
# Using set for faster membership checks
|
146 |
+
files_set = set(os.listdir(pdbs_path))
|
147 |
+
lig_files = [os.path.join(pdbs_path, f) for f in files_set if f"{antigen_prefix}_antigen_{receptor_id}" in f]
|
148 |
+
|
149 |
+
try:
|
150 |
+
cnt = 0
|
151 |
+
for features, y, point, normal in balanced_sampling(surf_file, protein_file, lig_files, cutoff=cutoff):
|
152 |
+
samples_file_name = os.path.join(prot_path, f"sample{cnt}_{int(y)}")
|
153 |
+
|
154 |
+
if feat_type == 'deepsite':
|
155 |
+
with open(f"{samples_file_name}.npy", 'w') as f:
|
156 |
+
np.save(f, (point, normal, features.astype(np.float16)))
|
157 |
+
elif feat_type == 'kalasanty' or feat_type == 'kalasanty_with_force_fields':
|
158 |
+
sparse_mat = sparse.coo_matrix(features.flatten())
|
159 |
+
sparse.save_npz(f"{samples_file_name}.npz", sparse_mat)
|
160 |
+
|
161 |
+
cnt += 1
|
162 |
+
|
163 |
+
print(f'Saved "{cnt}" samples for "{prot}".')
|
164 |
+
|
165 |
+
except Exception as e:
|
166 |
+
print(f'Exception occurred while processing "{prot}". Error message: "{e}".')
|
167 |
+
|
168 |
+
|
169 |
+
seed = 10 # random seed
|
170 |
+
num_cores = 6 # Set this to the number of cores you wish to use
|
171 |
+
maxPosSamples = 800 # maximum number of positive samples per protein
|
172 |
+
gridSize = 41 # size of grid (16x16x16)
|
173 |
+
voxelSize = 1 # size of voxel, e.g. 1 angstrom, if 2A we lose details, so leave it to 1
|
174 |
+
cutoff = 4.5 # cutoff threshold in Armstrong's 6 for general PPIs, 4.5 for antibody antigen databases
|
175 |
+
feature_vector_length_tmp_path = '/tmp/feature_vector_length.txt'
|
176 |
+
# feat_type = 'kalasanty' # select featurizer
|
177 |
+
feat_type = 'kalasanty_with_force_fields'
|
178 |
+
add_atoms_radius_ff_features = True # If you want to add the atom radius features that correspond to the force fields
|
179 |
+
rotate_grid = True # whether to rotate the grid (ignore)
|
180 |
+
use_protrusion = False # ignore
|
181 |
+
protr_radius = 10 # ignore
|
182 |
+
protonate = True # if protein pdbs are not protonated (do not have Hydrogens) set it to True
|
183 |
+
|
184 |
+
user = os.getenv('USER')
|
185 |
+
pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN' # input folder with protein pdbs for training
|
186 |
+
surf_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/surfpoints/eraseme/TRAIN' # input folder with surface points for training
|
187 |
+
feats_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/feats/eraseme' # training features folder
|
188 |
+
|
189 |
+
|
190 |
+
if not os.path.exists(feats_path):
|
191 |
+
os.makedirs(feats_path)
|
192 |
+
|
193 |
+
np.random.seed(seed)
|
194 |
+
|
195 |
+
all_proteins = [f.rsplit('.', 1)[0] for f in os.listdir(surf_path)]
|
196 |
+
#
|
197 |
+
# in case the procedure stacks use the 3 lines below
|
198 |
+
completed = [f.rsplit('.', 1)[0] for f in os.listdir(feats_path)]
|
199 |
+
all_proteins = [i for i in all_proteins if i not in completed]
|
200 |
+
print(len(all_proteins))
|
201 |
+
|
202 |
+
|
203 |
+
parser = PDBParser(PERMISSIVE=1) # PERMISSIVE=1 allowing more flexibility in handling non-standard or problematic entries in PDB files during parsing.
|
204 |
+
|
205 |
+
start = time.time()
|
206 |
+
with Pool(num_cores) as pool: # Use a specified number of CPU cores
|
207 |
+
list(tqdm(pool.imap(samples_per_prot, all_proteins), total=len(all_proteins)))
|
208 |
+
print(f'Total preprocess time: {(time.time() - start)/60} mins')
|
209 |
+
|
210 |
+
###################################################################################
|
211 |
+
# Instead of using Pool and imap, iterate through all_proteins with a for loop for easy debugging
|
212 |
+
# for prot in all_proteins:
|
213 |
+
# try:
|
214 |
+
# samples_per_prot(prot)
|
215 |
+
# except Exception as e:
|
216 |
+
# print(f'Error processing protein {prot}: {e}')
|
217 |
+
# break
|
218 |
+
|
219 |
+
|
220 |
+
# the last number at the out_path will be the total number of the feature vector
|
221 |
+
if os.path.exists(feature_vector_length_tmp_path):
|
222 |
+
with open(feature_vector_length_tmp_path, 'r') as file:
|
223 |
+
feature_vector_length = int(file.read().strip())
|
224 |
+
feats_path_new = f'{feats_path}_{feature_vector_length}'
|
225 |
+
os.rename(feats_path, feats_path_new)
|
226 |
+
os.remove(feature_vector_length_tmp_path)
|
227 |
+
|
228 |
+
|
229 |
+
# remove empty features if found
|
230 |
+
remove_empty_features(feats_folder=feats_path_new, pdbs_path=pdbs_path, surf_path=surf_path)
|
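The core of `balanced_sampling` above is the pairing of every accepted positive surface point with one randomly drawn negative. A minimal sketch of that balancing step on dummy indices (the arrays below are invented for illustration, not repo data):

    import numpy as np

    pos_idxs = np.array([2, 5, 7])      # hypothetical binding-site point indices
    all_points = np.arange(10)          # ten surface points in total

    # negatives = everything not positive, downsampled to the positive count
    neg_pool = np.setdiff1d(all_points, pos_idxs)
    neg_idxs = np.random.choice(neg_pool, size=len(pos_idxs), replace=False)

    sample_idxs = np.concatenate((pos_idxs, neg_idxs))
    np.random.shuffle(sample_idxs)
    labels = [i in set(pos_idxs) for i in sample_idxs]  # True for positives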
ParaSurf/preprocess/create_proteins_file.py
ADDED
@@ -0,0 +1,23 @@
+import os
+
+
+# Here we create a .proteins file that lists all the proteins (i.e. receptors) that we are working with.
+cases = ['TRAIN', 'VAL', 'TEST']  # change to ['TRAIN_VAL', 'TEST'] for MIPE
+user = os.getenv('USER')
+for case in cases:
+    pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/{case}'
+    proteins_file = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/datasets/eraseme_{case}.proteins'  # run for train, val and test
+
+    # Create the directories if they don't exist
+    os.makedirs(pdbs_path, exist_ok=True)
+    os.makedirs(os.path.dirname(proteins_file), exist_ok=True)
+
+    receptors = []
+
+    for prot in os.listdir(pdbs_path):
+        prot_name = prot.split('.')[0]
+        if 'rec' in prot:
+            receptors.append(prot_name + '\n')
+
+    with open(proteins_file, 'w') as f:
+        f.writelines(receptors)
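For orientation: because the filter above keys on the substring 'rec' in the file name, a pdbs folder containing, say, `4xnq_receptor_1.pdb` and `4xnq_antigen_1.pdb` (hypothetical names) would yield a `.proteins` file with the single line `4xnq_receptor_1`.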
ParaSurf/preprocess/create_sample_files.py
ADDED
@@ -0,0 +1,31 @@
+import os, random
+
+
+user = os.getenv('USER')
+
+feats_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/feats/eraseme_22'  # input folder with protein grids (training features)
+proteins_file = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/datasets/eraseme_TRAIN.proteins'  # input file with a list of train proteins
+samples_file = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/datasets/eraseme_TRAIN.samples'  # output file with the corresponding training-sample info (class_label + sample_path)
+seed = 1
+
+with open(proteins_file, 'r') as f:
+    proteins = f.readlines()
+
+sample_lines = []
+feats_prefix = feats_path.rsplit('/')[-1]
+
+for prot in proteins:
+    prot = prot[:-1]  # strip the trailing newline
+    prot_feats_path = os.path.join(feats_path, prot)
+    if not os.path.isdir(prot_feats_path):
+        print('No features for ', prot)
+        continue
+    for sample in os.listdir(prot_feats_path):
+        cls_idx = sample[-5]  # the class-label digit in names like 'sample42_1.npz'
+        sample_lines.append(cls_idx + ' ' + feats_prefix + '/' + prot + '/' + sample + '\n')
+
+random.seed(seed)
+random.shuffle(sample_lines)
+
+with open(samples_file, 'w') as f:
+    f.writelines(sample_lines)
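Each line of the resulting `.samples` file is `<class_label> <feats_prefix>/<protein>/<sample_file>`, e.g. `1 eraseme_22/4xnq_receptor_1/sample42_1.npz` (protein name invented for illustration). The `cls_idx = sample[-5]` trick works because the label digit always sits five characters from the end of names like `sample42_1.npz`, as produced by `create_input_features.py` above.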
ParaSurf/preprocess/create_surfpoints.py
ADDED
@@ -0,0 +1,57 @@
+import time
+import os
+from tqdm import tqdm
+from ParaSurf.utils.fix_surfpoints_format_issues import process_surfpoints_directory
+
+def generate_molecular_surface(input_path, out_path):
+    """
+    Generates the molecular surface for protein structures in PDB files using the DMS tool.
+
+    Parameters:
+    - input_path (str): Path to the input directory containing protein PDB files.
+    - out_path (str): Path to the output directory where generated surface points files will be saved.
+
+    Process:
+    - The function iterates over receptor PDB files in the input path.
+    - For each receptor file, it checks if a corresponding surface points file already exists in the output directory.
+    - If the surface points file does not exist, it generates the file using the DMS tool with a density of 0.5 Å.
+
+    Outputs:
+    - Each receptor file generates a surface points file saved in `out_path`.
+    """
+
+    if not os.path.exists(out_path):
+        os.makedirs(out_path)
+
+    start = time.time()
+    for f in tqdm(os.listdir(input_path), desc="Generating surface points"):
+        if 'antigen' in f:
+            continue
+
+        surfpoints_file = os.path.join(out_path, f[:-3] + 'surfpoints')
+        if os.path.exists(surfpoints_file):
+            continue
+
+        print(f"Processing {f}")
+        os.system(f'dms {os.path.join(input_path, f)} -d 0.5 -n -o {surfpoints_file}')
+
+    # Calculate and print statistics
+    rec_count = sum(1 for receptor in os.listdir(input_path) if 'receptor' in receptor)
+    total_time = (time.time() - start) / 60  # convert time to minutes
+    print(f'Total time to create surfpoints for {rec_count} receptors: {total_time:.2f} mins')
+
+
+# Example usage
+if __name__ == '__main__':
+    user = os.getenv('USER')
+    pdbs_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/pdbs/eraseme/TRAIN'  # input folder with protein pdbs for training
+    surfpoints_path = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data/surfpoints/eraseme/TRAIN'
+
+    # create the molecular surface
+    generate_molecular_surface(
+        input_path=pdbs_path,
+        out_path=surfpoints_path
+    )
+
+    # fix some format issues with the .surfpoints files
+    process_surfpoints_directory(surfpoints_path)
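For a single receptor, the loop above shells out to DMS as, e.g., `dms .../1abc_receptor_1.pdb -d 0.5 -n -o .../1abc_receptor_1.surfpoints` (file name invented): `-d 0.5` sets the surface point density and `-n` asks DMS to also emit the surface normals that the featurizer later needs.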
ParaSurf/train/V_domain_results.py
ADDED
@@ -0,0 +1,159 @@
+import numpy as np
+import pandas as pd
+from utils import calculate_metrics
+import os, re
+
+
+def calculate_Fv_and_cdr_regions(residues_best, gt_true_label_residues, rec_name, output_path, epoch, test_csv=None,
+                                 thres=0.5):
+    # Load the CSV file to identify heavy and light chains, if provided
+    heavy_chain_name, light_chain_name = None, None
+    calculate_individual_cdrs = False
+
+    if test_csv:
+        test_csv = pd.read_csv(test_csv)
+        rec_info = test_csv[test_csv['pdb_code'] == rec_name]
+        if not rec_info.empty:
+            heavy_chain_name = rec_info['Heavy_chain'].iloc[0]
+            light_chain_name = rec_info['Light_chain'].iloc[0]
+
+            calculate_individual_cdrs = True
+        else:
+            print(f"Receptor {rec_name} not found in the test CSV.")
+
+    # Define the CDR+-2 ranges for heavy and light chains
+    cdr1 = list(range(25, 41))  # CDR-H/L1: 25-40
+    cdr2 = list(range(54, 68))  # CDR-H/L2: 54-67
+    cdr3 = list(range(103, 120))  # CDR-H/L3: 103-119
+    framework_ranges = list(range(1, 25)) + list(range(41, 54)) + list(range(68, 103)) + list(range(120, 129))
+
+    # Initialize the dictionaries
+    CDRH1, CDRH2, CDRH3 = {}, {}, {}
+    CDRL1, CDRL2, CDRL3 = {}, {}, {}
+    FRAMEWORK = {}
+
+    # Loop over the predictions to populate the dictionaries
+    for residue, data in residues_best.items():
+        # Split the residue into components (e.g., '30_L_C' -> ['30', 'L', 'C'])
+        residue_parts = residue.split('_')
+        residue_num = int(re.findall(r'\d+', residue_parts[0])[0])
+        chain_name = residue_parts[1]
+
+        # Assign the residue to the corresponding CDR or FRAMEWORK based on chain and residue number
+        if (not heavy_chain_name or chain_name == heavy_chain_name):  # if no csv or matching heavy chain
+            if residue_num in cdr1:
+                CDRH1[residue] = data
+            elif residue_num in cdr2:
+                CDRH2[residue] = data
+            elif residue_num in cdr3:
+                CDRH3[residue] = data
+            elif residue_num in framework_ranges:
+                FRAMEWORK[residue] = data
+
+        if (not light_chain_name or chain_name == light_chain_name):  # if no csv or matching light chain
+            if residue_num in cdr1:
+                CDRL1[residue] = data
+            elif residue_num in cdr2:
+                CDRL2[residue] = data
+            elif residue_num in cdr3:
+                CDRL3[residue] = data
+            elif residue_num in framework_ranges:
+                FRAMEWORK[residue] = data
+
+    # Helper function to calculate and save metrics for each CDR and FRAMEWORK
+    def calculate_and_save_metrics(cdr_dict, cdr_name, threshold=thres):
+        if len(cdr_dict) > 0:  # check if the CDR exists in the antibody
+            pred_scores = np.array([[i[1]['scores']] for i in cdr_dict.items()])
+            pred_labels = (pred_scores > threshold).astype(int)
+            gt_labels = np.array([1 if residue in gt_true_label_residues else 0 for residue in cdr_dict.keys()])
+
+            if len(np.unique(gt_labels)) > 1:  # ensure both classes are present
+                output_results_path = os.path.join(output_path, f'{cdr_name}_results_epoch_{epoch}_{threshold}.txt')
+                auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc, _, _, _ = \
+                    calculate_metrics(gt_labels, pred_labels, pred_scores, to_save_metrics_path=output_results_path)
+                return auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc
+        return None
+
+    # Calculate and save metrics for each CDR and FRAMEWORK only if a .csv is provided
+    if calculate_individual_cdrs:
+        calculate_and_save_metrics(CDRH1, 'CDRH1')
+        calculate_and_save_metrics(CDRH2, 'CDRH2')
+        calculate_and_save_metrics(CDRH3, 'CDRH3')
+        calculate_and_save_metrics(CDRL1, 'CDRL1')
+        calculate_and_save_metrics(CDRL2, 'CDRL2')
+        calculate_and_save_metrics(CDRL3, 'CDRL3')
+        calculate_and_save_metrics(FRAMEWORK, 'FRAMEWORK')
+
+    # Calculate the metrics for the CDR+-2 region (CDRH1 + CDRH2 + CDRH3 + CDRL1 + CDRL2 + CDRL3)
+    cdr_plus_minus_2 = {**CDRH1, **CDRH2, **CDRH3, **CDRL1, **CDRL2, **CDRL3}
+    calculate_and_save_metrics(cdr_plus_minus_2, 'CDR_plus_minus_2')
+
+    # Calculate the metrics for the Fv region (CDRs + FRAMEWORK)
+    fv_region = {**CDRH1, **CDRH2, **CDRH3, **CDRL1, **CDRL2, **CDRL3, **FRAMEWORK}
+    calculate_and_save_metrics(fv_region, 'Fv')
+
+
+def calculate_Fv_and_cdr_regions_only_one_chain(residues_best, gt_true_label_residues, rec_name, output_path, epoch, thres=0.5):
+    """
+    Calculates metrics for the Fv and CDR+-2 regions, but only for a PDB file with one chain.
+    The CSV file is not needed in this case, as there is only one chain.
+
+    Args:
+    - residues_best: Dictionary containing residue information.
+    - gt_true_label_residues: List of ground-truth binding residues.
+    - rec_name: Name of the receptor.
+    - output_path: Directory to save output results.
+    - epoch: Current epoch for model validation.
+    - thres: Threshold for classification (default is 0.5).
+
+    Returns:
+    - Metrics calculated and saved for the CDR+-2 and Fv regions.
+    """
+
+    # Define the CDR+-2 and framework ranges for the single chain
+    cdr1 = list(range(25, 41))  # CDR1: 25-40
+    cdr2 = list(range(54, 68))  # CDR2: 54-67
+    cdr3 = list(range(103, 120))  # CDR3: 103-119
+    framework_ranges = list(range(1, 25)) + list(range(41, 54)) + list(range(68, 103)) + list(range(120, 129))
+
+    # Initialize the dictionaries
+    CDR1, CDR2, CDR3 = {}, {}, {}
+    FRAMEWORK = {}
+
+    # Loop over the predictions to populate the dictionaries
+    for residue, data in residues_best.items():
+        # Split the residue into components (e.g., '30_L_C' -> ['30', 'L', 'C'])
+        residue_parts = residue.split('_')
+        residue_num = int(re.findall(r'\d+', residue_parts[0])[0])
+
+        # Assign the residue to the corresponding CDR or FRAMEWORK based on residue number
+        if residue_num in cdr1:
+            CDR1[residue] = data
+        elif residue_num in cdr2:
+            CDR2[residue] = data
+        elif residue_num in cdr3:
+            CDR3[residue] = data
+        elif residue_num in framework_ranges:
+            FRAMEWORK[residue] = data
+
+    # Helper function to calculate and save metrics for each CDR and FRAMEWORK
+    def calculate_and_save_metrics(cdr_dict, cdr_name, threshold=thres):
+        if len(cdr_dict) > 0:  # check if the CDR exists in the antibody
+            pred_scores = np.array([[i[1]['scores']] for i in cdr_dict.items()])
+            pred_labels = (pred_scores > threshold).astype(int)
+            gt_labels = np.array([1 if residue in gt_true_label_residues else 0 for residue in cdr_dict.keys()])
+
+            if len(np.unique(gt_labels)) > 1:  # ensure both classes are present
+                output_results_path = os.path.join(output_path, f'{cdr_name}_results_epoch_{epoch}_{threshold}.txt')
+                auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc, _, _, _ = \
+                    calculate_metrics(gt_labels, pred_labels, pred_scores, to_save_metrics_path=output_results_path)
+                return auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc
+        return None
+
+    # Calculate the metrics for the CDR+-2 region (CDR1 + CDR2 + CDR3)
+    cdr_plus_minus_2 = {**CDR1, **CDR2, **CDR3}
+    calculate_and_save_metrics(cdr_plus_minus_2, 'CDR_plus_minus_2')
+
+    # Calculate the metrics for the Fv region (CDR1 + CDR2 + CDR3 + FRAMEWORK)
+    fv_region = {**CDR1, **CDR2, **CDR3, **FRAMEWORK}
+    calculate_and_save_metrics(fv_region, 'Fv')
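To make the bucketing above concrete, here is how a single residue key flows through the range test; the key value is invented, the ranges are the CDR+-2 windows defined in the functions:

    import re

    residue = '105_H_C'  # the '<position>_<chain>_<amino acid>' key format used above
    residue_num = int(re.findall(r'\d+', residue.split('_')[0])[0])  # -> 105
    print(residue_num in range(103, 120))  # True: bucketed into CDRH3 if 'H' is the heavy chain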
ParaSurf/train/__pycache__/V_domain_results.cpython-310.pyc
ADDED
Binary file (4.54 kB)
ParaSurf/train/__pycache__/V_domain_results.cpython-39.pyc
ADDED
Binary file (4.81 kB)
ParaSurf/train/__pycache__/bsite_extraction.cpython-310.pyc
ADDED
Binary file (1.82 kB)
ParaSurf/train/__pycache__/bsite_extraction.cpython-39.pyc
ADDED
Binary file (1.79 kB)
ParaSurf/train/__pycache__/distance_coords.cpython-310.pyc
ADDED
Binary file (4.55 kB)
ParaSurf/train/__pycache__/distance_coords.cpython-39.pyc
ADDED
Binary file (4.6 kB)
ParaSurf/train/__pycache__/features.cpython-310.pyc
ADDED
Binary file (1.91 kB)
ParaSurf/train/__pycache__/features.cpython-39.pyc
ADDED
Binary file (1.9 kB)
ParaSurf/train/__pycache__/network.cpython-310.pyc
ADDED
Binary file (1.93 kB)
ParaSurf/train/__pycache__/network.cpython-39.pyc
ADDED
Binary file (1.94 kB)
ParaSurf/train/__pycache__/protein.cpython-310.pyc
ADDED
Binary file (4.47 kB)
ParaSurf/train/__pycache__/protein.cpython-39.pyc
ADDED
Binary file (4.5 kB)
ParaSurf/train/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (14 kB)
ParaSurf/train/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (14 kB)
ParaSurf/train/__pycache__/validation.cpython-310.pyc
ADDED
Binary file (7.12 kB)
ParaSurf/train/__pycache__/validation.cpython-39.pyc
ADDED
Binary file (7.19 kB)
ParaSurf/train/bsite_extraction.py
ADDED
@@ -0,0 +1,48 @@
+import numpy as np
+from sklearn.cluster import MeanShift
+
+
+class Bsite_extractor():
+    def __init__(self, lig_thres=0.9, bw=15):
+        self.T = lig_thres
+        self.ms = MeanShift(bandwidth=bw, bin_seeding=True, cluster_all=False, n_jobs=4)
+
+    def _cluster_points(self, prot, lig_scores):
+        T_new = self.T
+        while sum(lig_scores >= T_new) < 10 and T_new > 0.3001:  # at least 10 points with prob > P and P >= 0.3
+            T_new -= 0.1
+
+        # filtered_points = prot.surf_points[lig_scores > T_new]
+        filtered_points = prot.surf_points[lig_scores.flatten() > T_new]
+        filtered_scores = lig_scores[lig_scores > T_new]
+        if len(filtered_points) < 5:
+            return ()
+
+        clustering = self.ms.fit(filtered_points)
+        labels = clustering.labels_
+
+        unique_l, freq = np.unique(labels, return_counts=True)
+
+        if len(unique_l[freq >= 5]) != 0:
+            unique_l = unique_l[freq >= 5]  # keep clusters with 5 points and more
+        else:
+            return ()
+
+        if unique_l[0] == -1:  # discard the "unclustered" cluster
+            unique_l = unique_l[1:]
+
+        clusters = [(filtered_points[labels == l], filtered_scores[labels == l]) for l in unique_l]
+
+        return clusters
+
+    def extract_bsites(self, prot, lig_scores):
+        clusters = self._cluster_points(prot, lig_scores)
+        if len(clusters) == 0:
+            print('No binding site found!!!')
+            return
+        for cluster in clusters:
+            prot.add_bsite(cluster)
+        prot.sort_bsites()
+        prot.write_bsites()
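The threshold decay in `_cluster_points` is easiest to see in isolation: starting from `lig_thres`, the cutoff drops in 0.1 steps until at least 10 points survive or the 0.3 floor is reached. A self-contained sketch with invented scores:

    import numpy as np

    lig_scores = np.array([0.95, 0.92, 0.4, 0.35, 0.33] + [0.1] * 20)  # made-up predictions
    T = 0.9
    while np.sum(lig_scores >= T) < 10 and T > 0.3001:
        T -= 0.1
    print(round(T, 1))  # 0.3: the floor was hit because 10 points never scored high enough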
ParaSurf/train/distance_coords.py
ADDED
@@ -0,0 +1,173 @@
+import math
+import re
+import numpy as np
+
+def locate_surface_binding_site_atoms(receptor_surf_file, antigen_pdb_file, distance_cutoff=4):
+    rec_coordinates = []
+    with open(receptor_surf_file, 'r') as file:
+        for line in file:
+            parts = line.split()
+
+            # Check for the presence of a numeric value in the 3rd element of parts
+            match = re.search(r'([-+]?\d*\.\d+|\d+)(?=\.)', parts[2])
+            if match:
+                numeric_value = match.group(0)
+                non_numeric_value = parts[2].replace(numeric_value, "")
+
+                # Update the 'parts' list
+                parts[2:3] = [non_numeric_value, numeric_value]
+
+            if len(parts) >= 7:  # since we added an extra element to parts, its length increased by 1
+                x = float(parts[3])
+                y = float(parts[4])
+                z = float(parts[5])
+                rec_coordinates.append((x, y, z))
+
+    ant_coordinates = []
+    with open(antigen_pdb_file, 'r') as file:
+        for line in file:
+            if line.startswith("ATOM"):
+                x = float(line[30:38].strip())
+                y = float(line[38:46].strip())
+                z = float(line[46:54].strip())
+                ant_coordinates.append((x, y, z))
+
+    # Create a list to store the final coordinates
+    final_coordinates = []
+
+    # Compare each coordinate from rec_coordinates with each coordinate from ant_coordinates
+    for rec_coord in rec_coordinates:
+        for ant_coord in ant_coordinates:
+            if math.dist(rec_coord, ant_coord) < distance_cutoff:
+                final_coordinates.append(rec_coord)
+                break  # break the inner loop once a match is found, to avoid duplicate entries
+
+    # sanity check
+    for coor in final_coordinates:
+        if coor not in rec_coordinates:
+            print("BINDING SITE COORDINATE NOT IN THE RECEPTOR'S COORDINATES!")
+
+    return final_coordinates, rec_coordinates
+
+
+def locate_receptor_binding_site_atoms(receptor_pdb_file, antigen_pdb_file, distance_cutoff=4):
+    rec_coordinates = []
+    with open(receptor_pdb_file, 'r') as file:
+        for line in file:
+            if line.startswith("ATOM"):
+                x = float(line[30:38].strip())
+                y = float(line[38:46].strip())
+                z = float(line[46:54].strip())
+                rec_coordinates.append((x, y, z))
+
+    ant_coordinates = []
+    with open(antigen_pdb_file, 'r') as file:
+        for line in file:
+            if line.startswith("ATOM"):
+                x = float(line[30:38].strip())
+                y = float(line[38:46].strip())
+                z = float(line[46:54].strip())
+                ant_coordinates.append((x, y, z))
+
+    # Create a list to store the final coordinates
+    final_coordinates = []
+
+    # Compare each coordinate from rec_coordinates with each coordinate from ant_coordinates
+    for rec_coord in rec_coordinates:
+        for ant_coord in ant_coordinates:
+            if math.dist(rec_coord, ant_coord) < distance_cutoff:
+                final_coordinates.append(rec_coord)
+                break  # break the inner loop once a match is found, to avoid duplicate entries
+
+    # sanity check
+    for coor in final_coordinates:
+        if coor not in rec_coordinates:
+            print("BINDING SITE COORDINATE NOT IN THE RECEPTOR'S COORDINATES!")
+    return final_coordinates, rec_coordinates
+
+
+def coords2pdb(coordinates, tosavepath):
+    with open(tosavepath, 'w') as pdb_file:
+        atom_number = 1
+        for coord in coordinates:
+            x, y, z = coord
+            pdb_file.write(f"ATOM {atom_number:5} DUM DUM A{atom_number:4} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n")
+
+            atom_number += 1
+            if atom_number == 9999:
+                atom_number = 1
+        pdb_file.write("END")
+
+
+def locate_receptor_binding_site_atoms_residue_level(receptor_file, antigen_pdb_file, distance_cutoff=4):
+    rec_atoms = []
+    chain_elements = []
+    with open(receptor_file, 'r') as file:
+        for line in file:
+            if line.startswith("ATOM"):
+                atom_id = line[6:11].strip()
+                atom_type = line[12:16].strip()
+                res_id = line[22:26].strip()
+                # check if there is an insertion code for the residue
+                insertion_code = line[26].strip()
+                if insertion_code:
+                    res_id = res_id + insertion_code
+                res_name = line[17:20].strip()
+                chain_id = line[21].strip()
+                x = float(line[30:38].strip())
+                y = float(line[38:46].strip())
+                z = float(line[46:54].strip())
+                rec_atoms.append((atom_id, atom_type, res_id, res_name, chain_id, x, y, z))
+                chain_elements.append((atom_id, atom_type, res_id, chain_id))
+
+    ant_atoms = []
+    with open(antigen_pdb_file, 'r') as file:
+        for line in file:
+            if line.startswith("ATOM"):
+                atom_id = line[6:11].strip()
+                atom_type = line[12:16].strip()
+                res_id = line[22:26].strip()
+                res_name = line[17:20].strip()
+                chain_id = line[21].strip()
+                x = float(line[30:38].strip())
+                y = float(line[38:46].strip())
+                z = float(line[46:54].strip())
+                ant_atoms.append((atom_id, atom_type, res_id, res_name, chain_id, x, y, z))
+
+    final_atoms = []
+
+    for rec_atom in rec_atoms:
+        for ant_atom in ant_atoms:
+            if math.dist(rec_atom[5:], ant_atom[5:]) < distance_cutoff:
+                final_atoms.append(rec_atom)
+                break
+
+    rec_atoms = np.array([atom[5:] for atom in rec_atoms])
+    final_atoms_ = np.array([atom[5:] for atom in final_atoms])
+    final_elements = np.array([atom[:5] for atom in final_atoms])
+
+    return final_atoms_, rec_atoms, final_elements
+
+
+def coords2pdb_residue_level(coordinates, tosavepath, elements):
+    with open(tosavepath, 'w') as pdb_file:
+        for i, atom in enumerate(coordinates):
+            atom_id, atom_type, res_id, res_name, chain_id = elements[i]
+
+            # Separate the numeric part from the insertion code (if any)
+            if res_id[-1].isalpha():  # check if the last character is an insertion code
+                res_num = res_id[:-1]  # numeric part of the residue
+                insertion_code = res_id[-1]  # insertion code (e.g., 'A' in '30A')
+            else:
+                res_num = res_id
+                insertion_code = " "  # no insertion code
+
+            x, y, z = atom
+
+            # Write to the PDB file with the correct formatting
+            pdb_file.write(
+                f"ATOM {int(atom_id):5} {atom_type:<4} {res_name} {chain_id}{int(res_num):4}{insertion_code:1} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n")
+
+        pdb_file.write("END\n")
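The nested `math.dist` loops above are O(N*M) in pure Python; an equivalent vectorized test with `scipy.spatial.distance.cdist` produces the same "any antigen atom within the cutoff" mask much faster. A sketch, assuming `rec` and `ant` are (N, 3)/(M, 3) coordinate arrays (random stand-ins here, not parsed PDB atoms):

    import numpy as np
    from scipy.spatial.distance import cdist

    rec = np.random.rand(100, 3) * 50  # receptor atom coordinates (dummy)
    ant = np.random.rand(80, 3) * 50   # antigen atom coordinates (dummy)

    near = (cdist(rec, ant) < 4.0).any(axis=1)  # one distance matrix instead of two loops
    binding_coords = rec[near]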
ParaSurf/train/features.py
ADDED
@@ -0,0 +1,37 @@
+from ParaSurf.utils import bio_data_featurizer
+from ParaSurf.train.utils import rotation
+import numpy as np
+
+
+class KalasantyFeaturizer:
+    def __init__(self, gridSize, voxelSize):
+        grid_limit = (gridSize / 2 - 0.5) * voxelSize
+        grid_radius = grid_limit * np.sqrt(3)
+        self.neigh_radius = 4 + grid_radius  # 4 Å margin (> 2 * R_vdw)
+        # self.neigh_radius = 2*grid_radius
+        self.featurizer = bio_data_featurizer.Featurizer(save_molecule_codes=False)
+        self.grid_resolution = voxelSize
+        self.max_dist = (gridSize - 1) * voxelSize / 2
+
+    def get_channels(self, mol, add_forcefields, add_atom_radius_features=False):
+        if not add_forcefields:
+            self.coords, self.channels = self.featurizer.get_features(mol)  # returns only heavy atoms
+        else:
+            self.coords, self.channels = self.featurizer.get_features_with_force_fields(mol, add_atom_radius=add_atom_radius_features)  # returns only heavy atoms
+
+
+    def get_channels_with_forcefields(self, mol):
+        self.coords, self.channels = self.featurizer.get_features_with_force_fields(mol, add_atom_radius=True)  # returns only heavy atoms
+
+    def grid_feats(self, point, normal, mol_coords):
+        neigh_atoms = np.sqrt(np.sum((mol_coords - point) ** 2, axis=1)) < self.neigh_radius
+        Q = rotation(normal)
+        Q_inv = np.linalg.inv(Q)
+        transf_coords = np.transpose(mol_coords[neigh_atoms] - point)
+        rotated_mol_coords = np.matmul(Q_inv, transf_coords)
+        features = \
+            bio_data_featurizer.make_grid(np.transpose(rotated_mol_coords), self.channels[neigh_atoms], self.grid_resolution,
+                                          self.max_dist)[0]
+
+        return features
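With the defaults used elsewhere in this diff (gridSize=41, voxelSize=1), the geometry works out to grid_limit = (41/2 - 0.5) * 1 = 20 Å and grid_radius = 20 * sqrt(3) ≈ 34.64 Å, so neigh_radius ≈ 38.64 Å and max_dist = (41 - 1) / 2 = 20 Å. In other words, only atoms within roughly 38.6 Å of a surface point can possibly fall inside the rotated 41x41x41 grid centred on it, so everything farther away is skipped up front.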
ParaSurf/train/network.py
ADDED
@@ -0,0 +1,58 @@
+import numpy as np
+import os
+import torch
+from ParaSurf.train.features import KalasantyFeaturizer
+from ParaSurf.model import ParaSurf_model
+
+
+class Network:
+    def __init__(self, model_path, gridSize, feature_channels, voxelSize=1, device="cuda"):
+        self.gridSize = gridSize
+
+        if device == 'cuda' and torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        # load the model
+        self.model = ParaSurf_model.ResNet3D_Transformer(in_channels=feature_channels, block=ParaSurf_model.DilatedBottleneck,
+                                                         num_blocks=[3, 4, 6, 3], num_classes=1)
+
+        # load the weights
+        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+        # move the model to the device and set eval mode
+        self.model = self.model.to(self.device).eval()
+
+        self.featurizer = KalasantyFeaturizer(gridSize, voxelSize)  # defines how surface points are turned into grid features
+        self.feature_channels = feature_channels
+
+    def get_lig_scores(self, prot, batch_size, add_forcefields, add_atom_radius_features):
+        self.featurizer.get_channels(prot.mol, add_forcefields, add_atom_radius_features)
+
+        lig_scores = []
+        input_data = torch.zeros((batch_size, self.gridSize, self.gridSize, self.gridSize, self.feature_channels), device=self.device)
+
+        batch_cnt = 0
+        for p, n in zip(prot.surf_points, prot.surf_normals):
+            input_data[batch_cnt, :, :, :, :] = torch.tensor(self.featurizer.grid_feats(p, n, prot.heavy_atom_coords), device=self.device)
+            batch_cnt += 1
+            if batch_cnt == batch_size:
+                with torch.no_grad():
+                    output = self.model(input_data)
+                    output = torch.sigmoid(output)
+                lig_scores.extend(output.cpu().numpy())
+                batch_cnt = 0
+
+        # flush the last, partially filled batch
+        if batch_cnt > 0:
+            with torch.no_grad():
+                output = self.model(input_data[:batch_cnt])
+                output = torch.sigmoid(output)
+            lig_scores.extend(output.cpu().numpy())
+
+        print(np.array(lig_scores).shape)
+        return np.array(lig_scores)
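A minimal usage sketch of this class, assuming a `prot` object built by `Protein_pred` (next file); the weights path and channel count are invented here (22 matches the `eraseme_22` feats folder used by the preprocessing scripts, but verify it against your own features):

    net = Network(model_path='model_weights/epoch_10.pth', gridSize=41, feature_channels=22)
    scores = net.get_lig_scores(prot, batch_size=64, add_forcefields=True, add_atom_radius_features=True)
    # one sigmoid probability per surface point, in prot.surf_points order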
ParaSurf/train/protein.py
ADDED
@@ -0,0 +1,92 @@
+import os, numpy as np
+import shutil
+# import pybel
+from openbabel import pybel
+from ParaSurf.train.utils import simplify_dms
+from ParaSurf.utils.fix_surfpoints_format_issues import process_surfpoints_directory
+
+
+class Protein_pred:
+    def __init__(self, prot_file, save_path, seed=None, atom_points_threshold=5, locate_only_surface=False):
+
+        prot_id = prot_file.split('/')[-1].split('.')[0]
+        self.save_path = os.path.join(save_path, prot_id)
+
+        if not os.path.exists(self.save_path):
+            os.makedirs(self.save_path)
+
+        self.mol = next(pybel.readfile(prot_file.split('.')[-1], prot_file))
+        self.atom_points_thresh = atom_points_threshold
+
+        surfpoints_file = os.path.join(self.save_path, prot_id + '.surfpoints')
+
+        # we have all the surfpoints ready from the preprocessing step
+        if not os.path.exists(surfpoints_file):
+            os.system('dms ' + prot_file + ' -d 0.1 -n -o ' + surfpoints_file)  # DMS's default value for -d is 0.2
+            # fix any format issues
+            print('\nfixing surfpoints format ...')
+            process_surfpoints_directory(self.save_path)
+            # raise Exception('probably DMS not installed')
+
+        # locate surface: whether the final coordinates should include the receptor atoms or just the surface atoms
+        self.surf_points, self.surf_normals = simplify_dms(surfpoints_file, seed=seed,
+                                                           locate_surface=locate_only_surface)
+
+        self.heavy_atom_coords = np.array([atom.coords for atom in self.mol.atoms if atom.atomicnum > 1])
+
+        self.binding_sites = []
+        if prot_file.endswith('pdb'):
+            with open(prot_file, 'r') as f:
+                lines = f.readlines()
+            self.heavy_atom_lines = [line for line in lines if line[:4] == 'ATOM' and line.split()[2][0] != 'H']
+            if len(self.heavy_atom_lines) != len(self.heavy_atom_coords):
+                ligand_in_pdb = len([line for line in lines if line.startswith('HETATM')]) > 0
+                if ligand_in_pdb:
+                    raise Exception('Ligand found in PDB file. Please remove it to proceed.')
+                else:
+                    raise Exception('Inconsistency between Coords and PDBLines')
+        else:
+            raise IOError('Protein file should be .pdb')
+
+    def _surfpoints_to_atoms(self, surfpoints):
+        close_atoms = np.zeros(len(surfpoints), dtype=int)
+        for p, surf_coord in enumerate(surfpoints):
+            dist = np.sqrt(np.sum((self.heavy_atom_coords - surf_coord) ** 2, axis=1))
+            close_atoms[p] = np.argmin(dist)
+
+        return np.unique(close_atoms)
+
+    def add_bsite(self, cluster):  # cluster -> tuple: (surf_points, scores)
+        atom_idxs = self._surfpoints_to_atoms(cluster[0])
+        self.binding_sites.append(Bsite(self.heavy_atom_coords, atom_idxs, cluster[1]))
+
+    def sort_bsites(self):
+        avg_scores = np.array([bsite.score for bsite in self.binding_sites])
+        sorted_idxs = np.flip(np.argsort(avg_scores), axis=0)
+        self.binding_sites = [self.binding_sites[idx] for idx in sorted_idxs]
+
+    def write_bsites(self):
+        if not os.path.exists(self.save_path):
+            os.makedirs(self.save_path)
+
+        centers = np.array([bsite.center for bsite in self.binding_sites])
+        np.savetxt(os.path.join(self.save_path, 'centers.txt'), centers, delimiter=' ', fmt='%10.3f')
+
+        pocket_count = 0
+        for i, bsite in enumerate(self.binding_sites):
+            outlines = [self.heavy_atom_lines[idx] for idx in bsite.atom_idxs]
+            if len(outlines) > self.atom_points_thresh:
+                pocket_count += 1
+                with open(os.path.join(self.save_path, 'pocket' + str(pocket_count) + '.pdb'), 'w') as f:
+                    f.writelines(outlines)
+
+
+class Bsite:
+    def __init__(self, mol_coords, atom_idxs, scores):
+        self.coords = mol_coords[atom_idxs]
+        self.center = np.average(self.coords, axis=0)
+        self.score = np.average(scores)
+        self.atom_idxs = atom_idxs
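Put together with the previous two files, the prediction pipeline these classes support looks roughly like this (paths are placeholders; `net` is the `Network` instance sketched after network.py):

    prot = Protein_pred('input/1abc_receptor_1.pdb', save_path='results/')
    lig_scores = net.get_lig_scores(prot, batch_size=64, add_forcefields=True, add_atom_radius_features=True)
    Bsite_extractor(lig_thres=0.9).extract_bsites(prot, lig_scores)
    # results/1abc_receptor_1/ then holds centers.txt plus pocket1.pdb, pocket2.pdb, ... sorted by average score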
ParaSurf/train/train.py
ADDED
@@ -0,0 +1,172 @@
+import os
+import time, random
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import StepLR
+from ParaSurf.model import ParaSurf_model
+from ParaSurf.model.dataset import dataset
+from validation import validate_residue_level
+import wandb
+from tqdm import tqdm
+
+
+user = os.getenv('USER')
+base_dir = f'/home/{user}/PycharmProjects/github_projects/ParaSurf/test_data'
+CFG = {
+    'name': 'ParaSurf train dummy eraseme folder',
+    'initial_lr': 0.0001,
+    'epochs': 100,
+    'batch_size': 64,
+    'grid': 41,  # don't change
+    'seed': 42,
+    'wandb': False,
+    'debug': False,
+    'model_weights': None,  # if '' or None is given, training starts from scratch
+    'num_workers': 8,
+    'feat_type': ['kalasanty_with_force_fields'],
+    'feats_path': os.path.join(base_dir, 'feats'),
+    'TRAIN_samples': os.path.join(base_dir, 'datasets/eraseme_TRAIN.samples'),
+    'VAL_proteins_list': os.path.join(base_dir, 'datasets/eraseme_VAL.proteins'),
+    'VAL_proteins': os.path.join(base_dir, 'pdbs/eraseme/VAL'),
+    'save_dir': f'/home/{user}/PycharmProjects/github_projects/ParaSurf/ParaSurf/train/eraseme/model_weights'
+}
+
+if CFG['wandb']:
+    wandb.init(project='ParaSurf', entity='your_project_name', config=CFG, name=CFG['name'])
+
+
+# Set the random seed for repeatability
+def set_seed(seed_value):
+    """Set seed for reproducibility."""
+    random.seed(seed_value)
+    np.random.seed(seed_value)
+    torch.manual_seed(seed_value)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed_value)
+
+
+set_seed(CFG['seed'])
+
+with open(CFG['TRAIN_samples']) as f:
+    lines = f.readlines()
+feature_vector_length = int(lines[0].split()[1].split('/')[0].split('_')[-1])  # parsed from the feats folder suffix
+
+# model
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+model = ParaSurf_model.ResNet3D_Transformer(in_channels=feature_vector_length,
+                                            block=ParaSurf_model.DilatedBottleneck,
+                                            num_blocks=[3, 4, 6, 3], num_classes=1).to(device)
+print(model)
+criterion = nn.BCEWithLogitsLoss()
+optimizer = optim.Adam(model.parameters(), lr=CFG['initial_lr'])
+
+scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
+
+# Load the dataset
+train_set = dataset(CFG['TRAIN_samples'], CFG['batch_size'], CFG['feats_path'], CFG['grid'], True,
+                    feature_vector_length, CFG['feat_type'])
+train_loader = DataLoader(dataset=train_set, batch_size=CFG['batch_size'], shuffle=True,
+                          num_workers=CFG['num_workers'])
+
+# Training
+if not os.path.exists(CFG['save_dir']):
+    os.makedirs(CFG['save_dir'])
+
+# check whether pretrained weights are loaded and start the epoch from there
+if CFG['model_weights'] and os.path.exists(CFG['model_weights']):
+    model.load_state_dict(torch.load(CFG['model_weights']))
+    start_epoch = int(CFG['model_weights'].split('/')[-1].split('.')[0].split('_')[1]) + 1
+    print(f"\nLoading weights from epoch {start_epoch-1} ...\n")
+    print(f"Start training for epoch {start_epoch} ...")
+else:
+    print('\nStart training from scratch ...')
+    start_epoch = 0
+
+
+train_losses = []  # to keep track of the training losses
+
+for epoch in range(start_epoch, CFG['epochs']):
+    start = time.time()
+    model.train()
+    total_loss = 0.0
+
+    correct_train_predictions = 0  # reset for each epoch
+    total_train_samples = 0  # reset for each epoch
+
+    for i, (inputs, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
+        inputs, labels = inputs.float().to(device), labels.to(device).unsqueeze(1)
+        total_train_samples += labels.shape[0]
+        optimizer.zero_grad()
+
+        # scaler option
+        # with torch.cuda.amp.autocast():
+        outputs = model(inputs)
+        loss = criterion(outputs, labels.float())
+
+        loss.backward()
+
+        # Apply gradient clipping
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
+
+        optimizer.step()
+
+        total_loss += loss.item()
+
+        predicted_train = torch.sigmoid(outputs) > 0.5
+        correct_train_predictions += (predicted_train == labels).sum().item()
+
+        # Print the training loss every 100 batches
+        if (i + 1) % 100 == 0:
+            print(f"Epoch: {epoch + 1} Batch: {i + 1} Train Loss: {loss.item():.3f}")
+
+        # if CFG['wandb']:
+        #     wandb.log({'Mini Batch Train Loss': loss.item()})
+
+        if CFG['debug']:
+            break
+
+    avg_train_loss = total_loss / len(train_loader)
+    train_accuracy = correct_train_predictions / total_train_samples  # calculate the training accuracy
+
+    train_losses.append(avg_train_loss)
+
+    cur_model_weight_path = os.path.join(CFG['save_dir'], f'epoch_{epoch}.pth')
+    torch.save(model.state_dict(), cur_model_weight_path)
+
+    avg_auc_roc, avg_precision, avg_recall, avg_auc_pr, avg_f1 = validate_residue_level(valset=CFG['VAL_proteins_list'],
+                                                                                        modelweights=cur_model_weight_path,
+                                                                                        test_folder=CFG['VAL_proteins'],
+                                                                                        epoch=epoch + 1,
+                                                                                        feat_type=CFG['feat_type'],
+                                                                                        feature_vector_lentgh=feature_vector_length)  # keyword spelling follows the signature in validation.py
+
+    print(
+        f"Epoch {epoch + 1}/{CFG['epochs']} - Train Loss: {avg_train_loss:.3f}, Train Accuracy: {train_accuracy:.3f},"
+        f" Val_AUC-ROC: {avg_auc_roc:.3f}, Val_Precision: {avg_precision:.3f}, Val_Recall: {avg_recall:.3f},"
+        f" Val_AUC_pr: {avg_auc_pr:.3f}, Val_F1: {avg_f1}")
+
+    print(f"Total epoch time: {(time.time() - start) / 60:.3f} mins")
+
+    if CFG['wandb']:
+        wandb.log({'Epoch': epoch,
+                   'Train Loss': avg_train_loss,
+                   'Train Accuracy': train_accuracy,
+                   'Valid AUC-ROC': avg_auc_roc,
+                   'Valid Precision': avg_precision,
+                   'Valid Recall': avg_recall,
+                   'Valid AUC-pr': avg_auc_pr,
+                   'Valid F1': avg_f1
+                   })
+
+    # Step the scheduler
+    scheduler.step()
+
+# Finish the wandb run at the end of all epochs
+if CFG['wandb']:
+    wandb.finish()
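Note that the resume branch reads the starting epoch purely out of the checkpoint file name, so saved weights must keep the `epoch_<n>.pth` pattern used by `torch.save` above; e.g. (path invented):

    path = '/some/dir/epoch_12.pth'
    start_epoch = int(path.split('/')[-1].split('.')[0].split('_')[1]) + 1  # -> 13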
ParaSurf/train/utils.py
ADDED
@@ -0,0 +1,497 @@
1 |
+
import warnings, os
|
2 |
+
import numpy as np
|
3 |
+
from scipy.spatial.distance import euclidean
|
4 |
+
from sklearn.cluster import KMeans
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from sklearn.metrics import roc_auc_score, roc_curve
|
7 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, auc, precision_recall_curve, \
|
8 |
+
confusion_matrix, matthews_corrcoef
|
9 |
+
|
10 |
+
|
11 |
+
def mol2_reader(mol_file): # does not handle H2
|
12 |
+
if mol_file[-4:] != 'mol2':
|
13 |
+
raise Exception("File's extension is not .mol2")
|
14 |
+
|
15 |
+
with open(mol_file, 'r') as f:
|
16 |
+
lines = f.readlines()
|
17 |
+
|
18 |
+
for i, line in enumerate(lines):
|
19 |
+
if '@<TRIPOS>ATOM' in line:
|
20 |
+
first_atom_idx = i + 1
|
21 |
+
if '@<TRIPOS>BOND' in line:
|
22 |
+
last_atom_idx = i - 1
|
23 |
+
|
24 |
+
return lines[first_atom_idx:last_atom_idx + 1]
|
25 |
+
|
26 |
+
|
27 |
+
# maybe change to read_surfpoints_new ??
|
28 |
+
def readSurfPoints(surf_file):
|
29 |
+
with open(surf_file, 'r') as f:
|
30 |
+
lines = f.readlines()
|
31 |
+
|
32 |
+
lines = [l for l in lines if len(l.split()) > 7]
|
33 |
+
|
34 |
+
if len(lines) > 100000:
|
35 |
+
warnings.warn('{} has too many points'.format(surf_file))
|
36 |
+
return
|
37 |
+
if len(lines) == 0:
|
38 |
+
warnings.warn('{} is empty'.format(surf_file))
|
39 |
+
return
|
40 |
+
|
41 |
+
coords = np.zeros((len(lines), 3))
|
42 |
+
normals = np.zeros((len(lines), 3))
|
43 |
+
for i, l in enumerate(lines):
|
44 |
+
parts = l.split()
|
45 |
+
|
46 |
+
try:
|
47 |
+
coords[i, 0] = float(parts[3])
|
48 |
+
coords[i, 1] = float(parts[4])
|
49 |
+
coords[i, 2] = float(parts[5])
|
50 |
+
normals[i, 0] = float(parts[8])
|
51 |
+
normals[i, 1] = float(parts[9])
|
52 |
+
normals[i, 2] = float(parts[10])
|
53 |
+
except:
|
54 |
+
coords[i, 0] = float(parts[2][-8:])
|
55 |
+
coords[i, 1] = float(parts[3])
|
56 |
+
coords[i, 2] = float(parts[4])
|
57 |
+
normals[i, 0] = float(parts[7])
|
58 |
+
normals[i, 1] = float(parts[8])
|
59 |
+
normals[i, 2] = float(parts[9])
|
60 |
+
|
61 |
+
return coords, normals
|
62 |
+
|
63 |
+
|
64 |
+
def readSurfPoints_with_receptor_atoms(surf_file):
    with open(surf_file, 'r') as f:
        lines = f.readlines()

    # Keep every line (receptor atom records included), unlike readSurfPoints,
    # which drops lines with fewer than 8 whitespace-separated fields.
    lines = [l for l in lines]
    if len(lines) > 100000:
        warnings.warn('{} has too many points'.format(surf_file))
        return
    if len(lines) == 0:
        warnings.warn('{} is empty'.format(surf_file))
        return

    coords = np.zeros((len(lines), 3))
    normals = np.zeros((len(lines), 3))

    # First, ensure each line has at least 11 parts by padding with '0'
    for i in range(len(lines)):
        parts = lines[i].split()
        while len(parts) < 11:
            parts.append('0')
        lines[i] = ' '.join(parts)

    # Replace the '0' placeholders with the corresponding fields of a
    # neighbouring line, then parse coordinates and normals
    for i in range(len(lines)):
        parts = lines[i].split()
        # Check if there are zeros that need to be replaced
        if '0' in parts:
            if i > 0:  # Use previous line if not the first line
                prev_parts = lines[i - 1].split()
                parts = [prev_parts[j] if part == '0' else part for j, part in enumerate(parts)]
            elif i < len(lines) - 1:  # Use next line for the first line
                next_parts = lines[i + 1].split()
                parts = [next_parts[j] if part == '0' else part for j, part in enumerate(parts)]
            lines[i] = ' '.join(parts)

        try:
            coords[i, 0] = float(parts[3])
            coords[i, 1] = float(parts[4])
            coords[i, 2] = float(parts[5])
            normals[i, 0] = float(parts[8])
            normals[i, 1] = float(parts[9])
            normals[i, 2] = float(parts[10])
        except (IndexError, ValueError):
            # Same fused-column fallback as in readSurfPoints.
            coords[i, 0] = float(parts[2][-8:])
            coords[i, 1] = float(parts[3])
            coords[i, 2] = float(parts[4])
            normals[i, 0] = float(parts[7])
            normals[i, 1] = float(parts[8])
            normals[i, 2] = float(parts[9])

    return coords, normals

def simplify_dms(init_surf_file, seed=None, locate_surface=True):
    # Decide whether the final coordinates should contain only the surface
    # points or the receptor atom points as well.
    if locate_surface:
        coords, normals = readSurfPoints(init_surf_file)
    else:
        coords, normals = readSurfPoints_with_receptor_atoms(init_surf_file)  # to also get the receptor points

    return coords, normals

    # NOTE: the k-means simplification below is unreachable as written (the
    # function returns above); it is kept from an earlier version of the code.
    nCl = len(coords)

    kmeans = KMeans(n_clusters=nCl, max_iter=300, n_init=1, random_state=seed).fit(coords)
    point_labels = kmeans.labels_
    centers = kmeans.cluster_centers_
    cluster_idx, freq = np.unique(point_labels, return_counts=True)
    if len(cluster_idx) != nCl:
        raise Exception('Number of created clusters should be equal to nCl')

    idxs = []
    for cl in cluster_idx:
        cluster_points_idxs = np.where(point_labels == cl)[0]
        closest_idx_to_center = np.argmin([euclidean(centers[cl], coords[idx]) for idx in cluster_points_idxs])
        idxs.append(cluster_points_idxs[closest_idx_to_center])

    return coords[idxs], normals[idxs]

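# A minimal usage sketch (hypothetical path): with locate_surface=True only the
# surface points are returned; with False the receptor atom records are kept too.
#
#   coords, normals = simplify_dms('/path/to/receptor.surfpoints', seed=42)
#   coords_all, normals_all = simplify_dms('/path/to/receptor.surfpoints',
#                                          locate_surface=False)
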
def rotation(n):
    # Build the rotation matrix that aligns the z-axis with the unit normal n
    # (for the generic branch below, Q @ [0, 0, 1]^T = n).
    if n[0] == 0.0 and n[1] == 0.0:
        if n[2] == 1.0:
            return np.identity(3)
        elif n[2] == -1.0:
            Q = np.identity(3)
            Q[0, 0] = -1
            return Q
        else:
            raise ValueError('rotation expects a unit normal vector')

    # Rotation axis r = z x n (normalized) and rotation angle th between z and n
    rx = -n[1] / np.sqrt(n[0] * n[0] + n[1] * n[1])
    ry = n[0] / np.sqrt(n[0] * n[0] + n[1] * n[1])
    rz = 0
    th = np.arccos(n[2])

    # Quaternion components for a rotation of th around (rx, ry, rz)
    q0 = np.cos(th / 2)
    q1 = np.sin(th / 2) * rx
    q2 = np.sin(th / 2) * ry
    q3 = np.sin(th / 2) * rz

    # Standard quaternion-to-rotation-matrix conversion
    Q = np.zeros((3, 3))
    Q[0, 0] = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3
    Q[0, 1] = 2 * (q1 * q2 - q0 * q3)
    Q[0, 2] = 2 * (q1 * q3 + q0 * q2)
    Q[1, 0] = 2 * (q1 * q2 + q0 * q3)
    Q[1, 1] = q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3
    Q[1, 2] = 2 * (q3 * q2 - q0 * q1)
    Q[2, 0] = 2 * (q1 * q3 - q0 * q2)
    Q[2, 1] = 2 * (q3 * q2 + q0 * q1)
    Q[2, 2] = q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3

    return Q

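# A quick self-check sketch of the generic branch: for a unit normal n with a
# nonzero x or y component, Q @ [0, 0, 1]^T reproduces n up to floating-point error.
#
#   n = np.array([0.6, 0.0, 0.8])
#   Q = rotation(n)
#   assert np.allclose(Q @ np.array([0.0, 0.0, 1.0]), n)
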
def TP_TN_FP_FN_visualization2pdb(gt_binding_site_coordinates, lig_scores, to_save_path, gt_indexes):
    '''
    Create dummy PDB files to visualize the results (TP, TN, FP, FN) on the receptor PDB file.
    '''
    threshold = 0.5

    # Initialize lists
    TP_coords = []
    FP_coords = []
    TN_coords = []
    FN_coords = []

    for i, score in enumerate(lig_scores):
        # If the atom is a true binding site
        if i in gt_indexes:
            if score > threshold:
                TP_coords.append(gt_binding_site_coordinates[i])
            else:
                FN_coords.append(gt_binding_site_coordinates[i])
        # If the atom is not a binding site
        else:
            if score > threshold:
                FP_coords.append(gt_binding_site_coordinates[i])
            else:
                TN_coords.append(gt_binding_site_coordinates[i])

    def generate_pdb_file(coordinates, file_name):
        """Generate a dummy PDB file using the provided coordinates."""
        with open(os.path.join(to_save_path, file_name), 'w') as pdb_file:
            atom_number = 1
            for coord in coordinates:
                x, y, z = coord
                pdb_file.write(
                    f"ATOM {atom_number:5} DUM DUM A{atom_number:4} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n")
                atom_number += 1
                if atom_number == 9999:  # keep the serial within the fixed-width PDB field
                    atom_number = 1
            pdb_file.write("END")

    # Generate PDB files for TP, FP, TN, and FN (generate_pdb_file already joins
    # to_save_path internally, so only bare file names are passed here)
    generate_pdb_file(TP_coords, "TP_atoms.pdb")
    generate_pdb_file(FP_coords, "FP_atoms.pdb")
    generate_pdb_file(TN_coords, "TN_atoms.pdb")
    generate_pdb_file(FN_coords, "FN_atoms.pdb")

    print('TP:', len(TP_coords), 'FP:', len(FP_coords), 'FN:', len(FN_coords), 'TN:', len(TN_coords))

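# A minimal usage sketch (hypothetical values): coordinates are an (N, 3) array
# of receptor atom positions, lig_scores the per-atom predictions in [0, 1],
# and gt_indexes the indices of the ground-truth binding-site atoms.
#
#   TP_TN_FP_FN_visualization2pdb(coords, lig_scores, '/tmp/viz', gt_indexes)
#   # -> writes TP_atoms.pdb / FP_atoms.pdb / TN_atoms.pdb / FN_atoms.pdb
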
def visualize_TP_TN_FP_FN_residue_level(lig_scores, gt_indexes, residues, receptor_path, tosavepath):
    threshold = 0.5

    # Initialize lists
    tp_list = []
    fp_list = []
    tn_list = []
    fn_list = []

    with open(receptor_path, 'r') as f:
        lines = f.readlines()

    # Number of atom records per residue; consuming `lines` in these chunks
    # assumes the receptor PDB contains only ATOM records, in residue order.
    res_atoms = [len(i[1]['atoms']) for i in residues.items()]

    for i, score in enumerate(lig_scores):
        lines2add = res_atoms[i]
        # If the residue is part of the true binding site
        if i in gt_indexes:
            if score > threshold:
                tp_list.append(lines[:lines2add])
                del lines[:lines2add]
            else:
                fn_list.append(lines[:lines2add])
                del lines[:lines2add]
        # If the residue is not part of the binding site
        else:
            if score > threshold:
                fp_list.append(lines[:lines2add])
                del lines[:lines2add]
            else:
                tn_list.append(lines[:lines2add])
                del lines[:lines2add]

    # Generate PDB files for TP, FP, TN, and FN
    with open(os.path.join(tosavepath, 'TP_residues.pdb'), 'w') as f:
        for l in tp_list:
            for item in l:
                f.write(item)
    with open(os.path.join(tosavepath, 'FP_residues.pdb'), 'w') as f:
        for l in fp_list:
            for item in l:
                f.write(item)
    with open(os.path.join(tosavepath, 'FN_residues.pdb'), 'w') as f:
        for l in fn_list:
            for item in l:
                f.write(item)
    with open(os.path.join(tosavepath, 'TN_residues.pdb'), 'w') as f:
        for l in tn_list:
            for item in l:
                f.write(item)

    # print('TP:', len(tp_list), 'FP:', len(fp_list), 'FN:', len(fn_list), 'TN:', len(tn_list))

def calculate_TP_TN_FP_FN(lig_scores, gt_indexes):
    threshold = 0.5

    # Initialize counters
    TP = 0
    FP = 0
    TN = 0
    FN = 0

    for i, score in enumerate(lig_scores):
        # If the atom is a true binding site
        if i in gt_indexes:
            if score > threshold:
                TP += 1
            else:
                FN += 1
        # If the atom is not a binding site
        else:
            if score > threshold:
                FP += 1
            else:
                TN += 1

    print('TP:', TP, 'FP:', FP, 'FN:', FN, 'TN:', TN)

def show_roc_curve(true_labels, lig_scores, auc_roc):
    # Calculate ROC curve
    fpr, tpr, thresholds = roc_curve(true_labels, lig_scores)

    # Plot ROC curve
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc_roc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.show()

def calculate_metrics(true_labels, predicted_labels, lig_scores, to_save_metrics_path):
    auc_roc = roc_auc_score(true_labels, lig_scores)
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    recall = recall_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    pr, re, _ = precision_recall_curve(true_labels, lig_scores)
    auc_pr = auc(re, pr)
    conf_matrix = confusion_matrix(true_labels, predicted_labels)
    mcc = matthews_corrcoef(true_labels, predicted_labels)

    tn, fp, fn, tp = conf_matrix.ravel()
    fpr = fp / (fp + tn)
    tpr = tp / (tp + fn)  # True positive rate == sensitivity == recall
    npv = tn / (tn + fn)
    spc = tn / (fp + tn)  # Specificity or True Negative Rate

    with open(to_save_metrics_path, 'w') as f:
        print(f"AUC-ROC: {auc_roc:.4f}", file=f)
        print(f"Accuracy: {accuracy:.4f}", file=f)
        print(f"Precision: {precision:.4f}", file=f)
        print(f"Recall: {recall:.4f}", file=f)
        print(f"F1 Score: {f1:.4f}", file=f)
        print(f"AUC-PR: {auc_pr:.4f}", file=f)
        print(f"Confusion Matrix:\n {conf_matrix}", file=f)
        print(f"Matthews Correlation Coefficient: {mcc:.4f}", file=f)
        print(f"False Positive Rate (FPR): {fpr:.4f}", file=f)
        print(f"Negative Predictive Value (NPV): {npv:.4f}", file=f)
        print(f"Specificity (SPC): {spc:.4f}", file=f)

    return auc_roc, accuracy, precision, recall, f1, auc_pr, conf_matrix, mcc, fpr, npv, spc

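# A minimal usage sketch with synthetic labels (hypothetical output path);
# predicted_labels are typically obtained by thresholding lig_scores at 0.5.
#
#   true_labels = np.array([1, 0, 1, 1, 0, 0])
#   lig_scores = np.array([0.9, 0.2, 0.4, 0.8, 0.6, 0.1])
#   predicted_labels = (lig_scores > 0.5).astype(int)
#   metrics = calculate_metrics(true_labels, predicted_labels, lig_scores,
#                               '/tmp/metrics.txt')
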
def filter_out_HETATMs(pdb_file_path):
    with open(pdb_file_path, 'r') as infile:
        lines = infile.readlines()

    # Filter out lines starting with 'HETATM'
    filtered_lines = [line for line in lines if not line.startswith('HETATM')]

    # Write the filtered lines back to the file
    with open(pdb_file_path, 'w') as outfile:
        outfile.writelines(filtered_lines)

def write_residue_prediction_pdb(receptor, output_pdb_path, residues_best):
    """
    :param receptor: original receptor pdb file path
    :param output_pdb_path: path of the prediction pdb file to write
    :param residues_best: the residues dict with the best score per residue
    :return: Write the prediction PDB file with scores at residue level (replaces B-factor for residues).
    """
    # Open the original receptor PDB file and the output PDB file for writing the predictions
    with open(receptor, 'r') as original_pdb, open(output_pdb_path, 'w') as pred_pdb:
        for line in original_pdb:
            if line.startswith("ATOM") or line.startswith("HETATM"):  # Process only ATOM and HETATM records
                # Extract residue info (residue number, chain ID, and insertion code)
                chain_id = line[21]
                res_num = line[22:26].strip()
                insertion_code = line[26].strip()

                # Create the residue ID in the same format as in residues_best
                res_id = f'{res_num}_{chain_id}'
                if insertion_code:
                    res_id = f'{res_id}_{insertion_code}'

                # Check if the residue exists in residues_best
                if res_id in residues_best:
                    # Extract the prediction score
                    pred_score = residues_best[res_id]['scores']

                    # Replace the B-factor (columns 61-66) with the prediction score
                    new_b_factor = f'{pred_score:6.3f}'  # Format the score with 3 decimal places
                    new_line = f'{line[:60]}{new_b_factor:>6}{line[66:]}'

                    # Write the modified line to the new PDB file
                    pred_pdb.write(new_line)
                else:
                    # If no prediction score is found, write the original line
                    pred_pdb.write(line)
            else:
                # Write lines that do not start with ATOM or HETATM (like headers and footers) unchanged
                pred_pdb.write(line)

    print(f"Residue-level prediction PDB file saved as {output_pdb_path}")

def write_atom_prediction_pdb(receptor, results_save_path, lig_scores_only_receptor_atoms):
    """
    :param receptor: original receptor pdb file path
    :param results_save_path: directory where the prediction pdb file is saved
    :param lig_scores_only_receptor_atoms: per-atom prediction scores, in file order
    :return: Write the prediction PDB file with scores at atom level (replaces B-factor for each atom).
    """
    rec_name = receptor.split('/')[-1].split('_')[0]
    output_pdb_path = os.path.join(results_save_path, f'{rec_name}_pred_per_atom.pdb')

    # Make sure the length of lig_scores matches the number of atoms in the PDB
    assert len(lig_scores_only_receptor_atoms) == sum(1 for line in open(receptor) if line.startswith("ATOM") or line.startswith("HETATM")), \
        "Number of scores doesn't match the number of atoms in the PDB file"

    # Open the original receptor PDB file and the output PDB file for writing the predictions
    with open(receptor, 'r') as original_pdb, open(output_pdb_path, 'w') as pred_pdb2:
        atom_index = 0  # To track which score corresponds to which atom
        for line in original_pdb:
            if line.startswith("ATOM") or line.startswith("HETATM"):  # Process only ATOM and HETATM records
                # Extract the prediction score for the current atom
                pred_score = lig_scores_only_receptor_atoms[atom_index][0]

                # Replace the B-factor (columns 61-66) with the prediction score
                new_b_factor = f'{pred_score:6.3f}'  # Format the score with 3 decimal places
                new_line = f'{line[:60]}{new_b_factor:>6}{line[66:]}'

                # Write the modified line to the new PDB file
                pred_pdb2.write(new_line)

                # Increment the atom index
                atom_index += 1
            else:
                # Write lines that do not start with ATOM or HETATM (like headers and footers) unchanged
                pred_pdb2.write(line)

    print(f"Per-atom prediction PDB file saved as {output_pdb_path}")

def receptor_info(receptor, lig_scores_only_receptor_atoms):
    """
    Extract residue groups and compute the best scores for each residue in the receptor.

    Args:
        receptor (str): The path to the receptor PDB file.
        lig_scores_only_receptor_atoms (ndarray): List of ligandability scores for each atom.

    Returns:
        residues (dict): Dictionary containing atom information and ligand scores for each residue.
        residues_best (dict): Dictionary containing the best ligand score for each residue.
    """
    # Create the residue groups for the whole protein
    residues = {}
    with open(receptor, 'r') as file:
        for line in file:
            if line.startswith("ATOM"):
                chain_id = line[21]  # Extract chain identifier
                atom_id = line[6:11].strip()
                res_id = f'{line[22:26].strip()}_{chain_id}'  # Concatenate residue ID with chain ID
                insertion_code = line[26].strip()
                if insertion_code:
                    res_id = f'{res_id}_{insertion_code}'
                if res_id not in residues:
                    residues[res_id] = {"atoms": [], 'scores': []}
                residues[res_id]["atoms"].append(atom_id)

                # Atom serials are assumed to start at 1 and be contiguous, so
                # serial - 1 indexes into the per-atom score array.
                atom2check = int(atom_id) - 1
                residues[res_id]['scores'].append(lig_scores_only_receptor_atoms[atom2check][0])

    # Take the best score per residue for the whole protein
    residues_best = {}
    for res_id, res_data in residues.items():
        residues_best[res_id] = {'scores': []}
        check_best = res_data['scores']
        best_atom = check_best.index(max(check_best))  # We take the best atom score of the residue
        residues_best[res_id]['scores'] = check_best[best_atom]

    return residues, residues_best
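# A minimal end-to-end sketch (hypothetical paths): group per-atom scores into
# residues, then write a PDB whose B-factor column carries the residue scores.
#
#   residues, residues_best = receptor_info('/path/to/receptor.pdb', lig_scores)
#   write_residue_prediction_pdb('/path/to/receptor.pdb',
#                                '/tmp/receptor_pred.pdb', residues_best)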