mschuh committed on
Commit
84bfd88
•
1 Parent(s): 777321b

Upload 37 files

Browse files
Files changed (38) hide show
  1. .gitattributes +10 -0
  2. README.md +5 -8
  3. app.py +125 -0
  4. model/__init__.py +0 -0
  5. model/barlow_twins.py +525 -0
  6. model/base_model.py +75 -0
  7. model/model.py +169 -0
  8. model/preprocessor.py +180 -0
  9. model/stash/14062024_0910/history.json +0 -0
  10. model/stash/14062024_0910/log.txt +41 -0
  11. model/stash/14062024_0910/params.pkl +3 -0
  12. model/stash/14062024_0910/weights.pt +3 -0
  13. model/xgb_models/14062024_0910_barlowdti_xxl_model.json +0 -0
  14. model/xgb_models/xgb_model_BIOSNAP_full_data_14062024_0910_bt_optimized_0.json +3 -0
  15. model/xgb_models/xgb_model_BIOSNAP_missing_data_70_14062024_0910_bt_optimized_0.json +0 -0
  16. model/xgb_models/xgb_model_BIOSNAP_missing_data_80_14062024_0910_bt_optimized_0.json +3 -0
  17. model/xgb_models/xgb_model_BIOSNAP_missing_data_90_14062024_0910_bt_optimized_0.json +0 -0
  18. model/xgb_models/xgb_model_BIOSNAP_missing_data_95_14062024_0910_bt_optimized_0.json +0 -0
  19. model/xgb_models/xgb_model_BIOSNAP_unseen_drug_14062024_0910_bt_optimized_0.json +3 -0
  20. model/xgb_models/xgb_model_BIOSNAP_unseen_protein_14062024_0910_bt_optimized_0.json +3 -0
  21. model/xgb_models/xgb_model_BindingDB_14062024_0910_bt_optimized_0.json +3 -0
  22. model/xgb_models/xgb_model_DAVIS_14062024_0910_bt_optimized_0.json +0 -0
  23. model/xgb_models/xgb_model_nature_mach_intel_BindingDB_cluster_14062024_0910_bt_optimized_0.json +0 -0
  24. model/xgb_models/xgb_model_nature_mach_intel_BindingDB_protein_14062024_0910_bt_optimized_0.json +3 -0
  25. model/xgb_models/xgb_model_nature_mach_intel_BindingDB_random_14062024_0910_bt_optimized_0.json +3 -0
  26. model/xgb_models/xgb_model_nature_mach_intel_BindingDB_scaffold_14062024_0910_bt_optimized_0.json +3 -0
  27. model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_cluster_14062024_0910_bt_optimized_0.json +0 -0
  28. model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_protein_14062024_0910_bt_optimized_0.json +0 -0
  29. model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_random_14062024_0910_bt_optimized_0.json +3 -0
  30. model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_scaffold_14062024_0910_bt_optimized_0.json +0 -0
  31. model/xgb_models/xgb_model_nature_mach_intel_Human_protein_14062024_0910_bt_optimized_0.json +3 -0
  32. model/xgb_models/xgb_model_nature_mach_intel_Human_random_14062024_0910_bt_optimized_0.json +0 -0
  33. model/xgb_models/xgb_model_nature_mach_intel_Human_scaffold_14062024_0910_bt_optimized_0.json +0 -0
  34. requirements.txt +25 -0
  35. utils/__init__.py +0 -0
  36. utils/chem.py +64 -0
  37. utils/parallel.py +78 -0
  38. utils/sequence.py +339 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model/xgb_models/xgb_model_BindingDB_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
37
+ model/xgb_models/xgb_model_BIOSNAP_full_data_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
38
+ model/xgb_models/xgb_model_BIOSNAP_missing_data_80_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
39
+ model/xgb_models/xgb_model_BIOSNAP_unseen_drug_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
40
+ model/xgb_models/xgb_model_BIOSNAP_unseen_protein_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
41
+ model/xgb_models/xgb_model_nature_mach_intel_BindingDB_protein_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
42
+ model/xgb_models/xgb_model_nature_mach_intel_BindingDB_random_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
43
+ model/xgb_models/xgb_model_nature_mach_intel_BindingDB_scaffold_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
44
+ model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_random_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
45
+ model/xgb_models/xgb_model_nature_mach_intel_Human_protein_14062024_0910_bt_optimized_0.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,10 @@
1
  ---
2
  title: BarlowDTI
3
- emoji: 🔥
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.41.0
8
  app_file: app.py
9
- pinned: false
10
- license: cc-by-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: BarlowDTI
3
+ emoji: 💊 ↔️ 🎯
4
+ colorFrom: blue
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.41.0
8
  app_file: app.py
9
+ pinned: true
10
+ ---
 
 
 
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import plotly.graph_objects as go
3
+ import numpy as np
4
+ import pandas as pd
5
+ from model.model import DTIModel
6
+
7
+
8
+ dt_str = "14062024_0910"
9
+
10
+
11
def make_spider_plot(predictions, model_names, smiles_list):
    """Build a radar ("spider") chart with one filled trace per molecule.

    Args:
        predictions: One sequence of scores per molecule; each sequence is
            plotted against ``model_names`` in order.
        model_names: Labels used for the angular axis.
        smiles_list: SMILES strings, used as the legend name of each trace.
            Paired with ``predictions`` via ``zip``, so extra items on
            either side are ignored.

    Returns:
        The plotly ``go.Figure`` with the radial axis fixed to ``[0, 1]``.
    """
    fig = go.Figure()

    for prediction, smiles in zip(predictions, smiles_list):
        fig.add_trace(
            go.Scatterpolar(
                r=prediction,
                theta=model_names,
                fill='toself',
                name=smiles,
            )
        )

    # Fixed radial range keeps traces comparable across requests.
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
        showlegend=True,
    )

    return fig
32
+
33
+
34
def predict_and_plot(amino_acid_sequence, smiles_input, datasets):
    """Run the selected models on one protein and a list of molecules.

    Args:
        amino_acid_sequence: Protein sequence text from the UI.
        smiles_input: Text with one SMILES string per line.
        datasets: Names of the selected models; each must be a key of the
            model-path table defined below.

    Returns:
        ``(fig, df)``: the spider plot and a DataFrame with a ``SMILES``
        column followed by one column per selected model.

    Raises:
        gr.Error: If no model is selected, or the sequence or SMILES input
            is empty after whitespace is removed.
    """
    # Copy the selection so the list handed over by the UI is never mutated
    # and so `["SMILES"] + datasets` below always concatenates two lists.
    datasets = list(datasets or [])
    if not datasets:
        raise gr.Error("Please select at least one model for prediction.")

    amino_acid_sequence = (amino_acid_sequence or "").strip()
    if not amino_acid_sequence:
        raise gr.Error("Please provide a protein sequence.")

    # splitlines() copes with "\r\n" pastes; blank lines are dropped so they
    # are not sent to the models as (empty) molecules.
    smiles_list = [
        line.strip() for line in (smiles_input or "").splitlines() if line.strip()
    ]
    if not smiles_list:
        raise gr.Error("Please provide at least one SMILES string.")

    gbm_model_paths = {
        "BindingDB": f"model/xgb_models/xgb_model_BindingDB_{dt_str}_bt_optimized_0.json",
        "BioSNAP": f"model/xgb_models/xgb_model_BIOSNAP_full_data_{dt_str}_bt_optimized_0.json",
        "DAVIS": f"model/xgb_models/xgb_model_DAVIS_{dt_str}_bt_optimized_0.json",
        "BarlowDTI XXL": f"model/xgb_models/{dt_str}_barlowdti_xxl_model.json",
    }

    model_ensemble = {}
    for model in datasets:
        print(f"Loading model {model}")

        model_ensemble[model] = DTIModel(
            bt_model_path=f"model/stash/{dt_str}",
            gbm_model_path=gbm_model_paths[model],
        )

    # One prediction vector per model, in the same order as `datasets`.
    predictions = [
        model.predict(smiles_list, amino_acid_sequence)
        for model in model_ensemble.values()
    ]

    # Transpose (models x molecules) -> (molecules x models) for the table
    # and for one radar trace per molecule.
    predictions = np.array(predictions).transpose().tolist()

    df = pd.DataFrame(predictions, index=smiles_list, columns=datasets).reset_index()
    df.columns = ["SMILES"] + datasets

    fig = make_spider_plot(predictions, datasets, smiles_list)

    return fig, df
66
+
67
+
68
# Model names offered in the UI; each must match a key of the model-path
# table inside predict_and_plot.
dataset_names = [
    "BarlowDTI XXL",
    "BindingDB",
    "BioSNAP",
    "DAVIS",
]

title = "Predict Drug-Target Interactions with <span style='font-variant:small-caps;'>BarlowDTI</span>"

# Shown above the inputs (markdown).
description = """
Input Amino Acid Sequence and SMILES to get interaction predictions visualized as a spider graph and in a table.
The values can be interpreted as the probability of interaction between the drug and target (0 = no interaction, 1 = interaction).

__Note: Inference may take a longer time, you can upgrade to a paid GPU-enabled plan for faster inference.__
"""

# Shown below the outputs (markdown).
article = """
This interface enables the use of <span style='font-variant:small-caps;'>BarlowDTI</span> to predict drug-target interactions.
The model ensemble consists of three models trained on different datasets: BindingDB, BIOSNAP, and DAVIS.

If you use this interface in your research, please cite our paper:
```
@misc{schuh2024barlowtwinsdeepneural,
      title={Barlow Twins Deep Neural Network for Advanced 1D Drug-Target Interaction Prediction},
      author={Maximilian G. Schuh and Davide Boldini and Stephan A. Sieber},
      year={2024},
      eprint={2408.00040},
      archivePrefix={arXiv},
      primaryClass={q-bio.BM},
      url={https://arxiv.org/abs/2408.00040},
}
```
"""

theme = gr.themes.Base(
    primary_hue="violet",
    font=[gr.themes.GoogleFont('IBM Plex Sans'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
)

iface = gr.Interface(
    fn=predict_and_plot,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Just one sequence is allowed. Remove FASTA syntax (e.g. >ABC)."),
        gr.Textbox(label="Molecule SMILES", info="One per line, multiple allowed."),
        gr.CheckboxGroup(choices=dataset_names, label="Select Models for Prediction", value="BarlowDTI XXL"),
    ],
    outputs=[
        gr.Plot(label="Predictions Visualization"),
        gr.DataFrame(label="Predictions DataFrame"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
)

iface.launch()
model/__init__.py ADDED
File without changes
model/barlow_twins.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.manual_seed(42)
3
+ torch.backends.cudnn.deterministic = True
4
+ from torch import nn
5
+ import numpy as np
6
+ from typing import *
7
+ from datetime import datetime
8
+ import os
9
+ import pickle
10
+ import inspect
11
+ from tqdm.auto import trange
12
+
13
+ from model.base_model import BaseModel
14
+
15
+
16
class BarlowTwins(BaseModel):
    """Barlow Twins network with a molecule encoder and an amino-acid encoder.

    Two MLP encoders (``enc_mol`` and ``enc_aa``) map their inputs to a common
    embedding size; one shared MLP projector (``proj``) maps both embeddings
    to the space in which the cross-correlation loss is computed.

    NOTE(review): ``train`` below shadows ``nn.Module.train(mode)``, so the
    inherited ``eval()`` / ``train(False)`` must not be called on this object;
    use ``switch_mode`` instead.
    """

    def __init__(
        self,
        n_bits: int = 1024,
        aa_emb_size: int = 1024,
        enc_n_neurons: int = 512,
        enc_n_layers: int = 2,
        proj_n_neurons: int = 2048,
        proj_n_layers: int = 2,
        embedding_dim: int = 512,
        act_function: str = "relu",
        loss_weight: float = 0.005,
        batch_size: int = 512,
        optimizer: str = "adamw",
        momentum: float = 0.9,
        learning_rate: float = 0.0001,
        betas: tuple = (0.9, 0.999),
        weight_decay: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
        verbose: bool = True,
    ):
        """Store the configuration and build networks, optimizer and scheduler.

        ``act_function`` and ``optimizer`` are keys of the lookup tables
        created by ``BaseModel.__init__``; an unknown key raises ``KeyError``.
        """
        super().__init__()

        self.enc_aa = None
        self.enc_mol = None
        self.proj = None

        self.scheduler = None
        self.optimizer = None

        # store input in dict
        self.param_dict = {
            # which activation function to use among dict options
            "act_function": self.activation_dict[act_function],
            "loss_weight": loss_weight,  # off-diagonal cross correlation loss weight
            "batch_size": batch_size,  # samples per gradient step
            "learning_rate": learning_rate,  # update step magnitude when training
            "betas": betas,  # momentum hyperparameter for adam-like optimizers
            "step_size": step_size,  # decay period for the learning rate
            "gamma": gamma,  # decay coefficient for the learning rate
            # which optimizer to use among dict options
            "optimizer": self.optimizer_dict[optimizer],
            "momentum": momentum,  # momentum hyperparameter for SGD
            "enc_n_neurons": enc_n_neurons,  # neurons to use for the mlp encoder
            "enc_n_layers": enc_n_layers,  # number of hidden layers in the mlp encoder
            "proj_n_neurons": proj_n_neurons,  # neurons to use for the mlp projector
            "proj_n_layers": proj_n_layers,  # number of hidden layers in the mlp projector
            "embedding_dim": embedding_dim,  # latent space dim for downstream tasks
            "weight_decay": weight_decay,  # l2 regularization for linear layers
            "verbose": verbose,  # whether to print feedback
            "radius": "Not defined yet",  # fingerprint radius
            "n_bits": n_bits,  # fingerprint bit size
            "aa_emb_size": aa_emb_size,  # aa embedding size
        }

        # create history dictionary
        self.history = {
            "train_loss": [],
            "on_diag_loss": [],
            "off_diag_loss": [],
            "validation_loss": [],
        }

        # run NN architecture construction method
        self.construct_model()

        # run scheduler construction method
        self.construct_scheduler()

        # print if necessary
        if self.param_dict["verbose"] is True:
            self.print_config()

    @staticmethod
    def __validate_inputs(locals_dict) -> None:
        """Check a name -> value mapping against the ``__init__`` annotations.

        Raises:
            TypeError: If a value is not an instance of the annotated type.
            ValueError: If a name is not a parameter of ``__init__``.
        """
        # get signature types from __init__
        init_signature = inspect.signature(BarlowTwins.__init__)

        # loop over all chosen arguments
        for param_name, param_value in locals_dict.items():
            # skip self
            if param_name == "self":
                continue
            # check that parameter exists
            if param_name not in init_signature.parameters:
                raise ValueError(f"[BT]: Unexpected parameter '{param_name}'")
            # check that param is correct type; raised explicitly because
            # `assert` statements are stripped when Python runs with -O
            expected_type = init_signature.parameters[param_name].annotation
            if not isinstance(param_value, expected_type):
                raise TypeError(f"[BT]: Type mismatch for parameter '{param_name}'")

    def construct_mlp(self, input_units, layer_units, n_layers, output_units) -> nn.Sequential:
        """Build an MLP: ``n_layers`` hidden blocks followed by a linear head.

        Each hidden block is linear -> batchnorm -> activation; the head is a
        plain ``nn.Linear`` to ``output_units``.
        """
        # layer sizes: input followed by n_layers hidden layers
        units = [input_units] + [layer_units] * n_layers

        mlp_list = []
        for in_units, out_units in zip(units[:-1], units[1:]):
            mlp_list.append(nn.Linear(in_units, out_units))
            mlp_list.append(nn.BatchNorm1d(out_units))
            mlp_list.append(self.param_dict["act_function"]())

        # add final linear layer
        mlp_list.append(nn.Linear(units[-1], output_units))

        return nn.Sequential(*mlp_list)

    def construct_model(self) -> None:
        """Create ``enc_mol``, ``enc_aa`` and ``proj`` from ``param_dict``."""
        # create fingerprint transformer
        self.enc_mol = self.construct_mlp(
            self.param_dict["n_bits"],
            self.param_dict["enc_n_neurons"],
            self.param_dict["enc_n_layers"],
            self.param_dict["embedding_dim"],
        )

        # create aa transformer
        self.enc_aa = self.construct_mlp(
            self.param_dict["aa_emb_size"],
            self.param_dict["enc_n_neurons"],
            self.param_dict["enc_n_layers"],
            self.param_dict["embedding_dim"],
        )

        # create mlp projector
        self.proj = self.construct_mlp(
            self.param_dict["embedding_dim"],
            self.param_dict["proj_n_neurons"],
            self.param_dict["proj_n_layers"],
            self.param_dict["proj_n_neurons"],
        )

        # print if necessary
        if self.param_dict["verbose"] is True:
            print("[BT]: Model constructed successfully")

    def construct_scheduler(self):
        """Create the optimizer over all three networks and a ``StepLR`` scheduler.

        NOTE(review): ``betas`` is always passed to the optimizer class, so
        an optimizer whose constructor has no ``betas`` argument (the "sgd"
        option presumably) would fail here - confirm before using it.
        """
        # make optimizer
        self.optimizer = self.param_dict["optimizer"](
            list(self.enc_mol.parameters())
            + list(self.enc_aa.parameters())
            + list(self.proj.parameters()),
            lr=self.param_dict["learning_rate"],
            betas=self.param_dict["betas"],
            weight_decay=self.param_dict["weight_decay"],
        )

        # wrap optimizer in scheduler
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.param_dict["step_size"],
            gamma=self.param_dict["gamma"],
        )

        # print if necessary
        if self.param_dict["verbose"] is True:
            print("[BT]: Optimizer constructed successfully")

    def switch_mode(self, is_training: bool):
        """Put the three sub-networks into training or evaluation mode."""
        for network in (self.enc_mol, self.enc_aa, self.proj):
            if is_training:
                network.train()
            else:
                network.eval()

    @staticmethod
    def normalize_projection(tensor: torch.tensor) -> torch.tensor:
        """Standardize every column (feature) of ``tensor`` over dimension 0.

        A column with zero standard deviation yields non-finite values
        because of the division.
        """
        means = torch.mean(tensor, dim=0)
        std = torch.std(tensor, dim=0)
        centered = torch.add(tensor, -means)
        scaled = torch.div(centered, std)

        return scaled

    def compute_loss(
        self,
        mol_embedding: torch.tensor,
        aa_embedding: torch.tensor,
    ) -> tuple:
        """Compute the two Barlow Twins loss terms for one batch.

        Returns:
            ``(on_diag, off_diag)``: the summed squared deviation of the
            cross-correlation diagonal from 1, and the summed squared
            off-diagonal entries scaled by ``loss_weight``.
        """
        # Normalize by the rows actually present rather than the configured
        # batch size, so partial or differently sized batches are scaled
        # correctly.
        n_samples = mol_embedding.shape[0]

        # empirical cross-correlation matrix
        mol_embedding = self.normalize_projection(mol_embedding).T
        aa_embedding = self.normalize_projection(aa_embedding)
        c = mol_embedding @ aa_embedding

        # normalize by number of samples
        c.div_(n_samples)

        # compute elements on diagonal
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()

        # compute elements off diagonal: dropping the last entry and viewing
        # the rest as (n - 1, n + 1) puts every diagonal entry in column 0
        n, m = c.shape
        off_diag = c.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
        off_diag = off_diag.pow_(2).sum() * self.param_dict["loss_weight"]

        return on_diag, off_diag

    def forward(
        self, mol_data: torch.tensor, aa_data: torch.tensor, is_training: bool = True
    ) -> tuple:
        """Encode and project both inputs and return ``(on_diag, off_diag)``."""
        # switch according to input
        self.switch_mode(is_training)

        # get embeddings
        mol_embeddings = self.enc_mol(mol_data)
        aa_embeddings = self.enc_aa(aa_data)

        # get projections
        mol_proj = self.proj(mol_embeddings)
        aa_proj = self.proj(aa_embeddings)

        # compute loss
        on_diag, off_diag = self.compute_loss(mol_proj, aa_proj)

        return on_diag, off_diag

    def train(
        self,
        train_data: torch.utils.data.DataLoader,
        val_data: torch.utils.data.DataLoader = None,
        num_epochs: int = 20,
        patience: int = None,
    ):
        """Train the model and record per-epoch mean losses in ``history``.

        Args:
            train_data: Loader yielding ``(mol_data, aa_data)`` batches.
            val_data: Optional loader with the same batch layout; enables
                the validation loss and early stopping.
            num_epochs: Maximum number of epochs.
            patience: Early-stopping patience; defaults to
                ``2 * step_size``.
        """
        if self.param_dict["verbose"] is True:
            print("[BT]: Training started")

        if patience is None:
            patience = 2 * self.param_dict["step_size"]

        pbar = trange(num_epochs, desc="[BT]: Epochs", leave=False, colour="blue")

        for epoch in pbar:
            # initialize loss containers
            train_loss = 0.0
            on_diag_loss = 0.0
            off_diag_loss = 0.0
            val_loss = 0.0

            # loop over training set
            for mol_data, aa_data in train_data:
                # reset grad
                self.optimizer.zero_grad()

                # compute train loss for batch
                on_diag, off_diag = self.forward(mol_data, aa_data, is_training=True)
                t_loss = on_diag + off_diag

                # backpropagation and optimization
                t_loss.backward()
                self.optimizer.step()

                # add batch loss to training containers
                train_loss += t_loss.item()
                on_diag_loss += on_diag.item()
                off_diag_loss += off_diag.item()

            # add mean epoch loss for train data to history dictionary
            self.history["train_loss"].append(train_loss / len(train_data))
            self.history["on_diag_loss"].append(on_diag_loss / len(train_data))
            self.history["off_diag_loss"].append(off_diag_loss / len(train_data))

            # define msg to be printed
            msg = (
                f"[BT]: Epoch [{epoch + 1}/{num_epochs}], "
                f"Train loss: {train_loss / len(train_data):.3f}, "
                f"On diagonal: {on_diag_loss / len(train_data):.3f}, "
                f"Off diagonal: {off_diag_loss / len(train_data):.3f} "
            )

            # loop over validation set (if present)
            if val_data is not None:
                # no gradients are needed for validation
                with torch.no_grad():
                    for mol_data, aa_data in val_data:
                        # compute val loss for batch
                        on_diag_v_loss, off_diag_v_loss = self.forward(
                            mol_data, aa_data, is_training=False
                        )

                        # add batch loss to val container
                        v_loss = on_diag_v_loss + off_diag_v_loss
                        val_loss += v_loss.item()

                # add mean epoch loss for val data to history dictionary
                self.history["validation_loss"].append(val_loss / len(val_data))

                # add val loss to msg
                msg += f", Val loss: {val_loss / len(val_data):.3f}"

                # early stopping
                if self.early_stopping(patience=patience):
                    break

                pbar.set_postfix(
                    {
                        "train loss": train_loss / len(train_data),
                        "val loss": val_loss / len(val_data),
                    }
                )

            else:
                pbar.set_postfix({"train loss": train_loss / len(train_data)})

            # update scheduler
            self.scheduler.step()

            if self.param_dict["verbose"] is True:
                print(msg)

        if self.param_dict["verbose"] is True:
            print("[BT]: Training finished")

    def encode(
        self, vector: np.ndarray, mode: str = "embedding", normalize: bool = True, encoder: str = "mol"
    ) -> np.ndarray:
        """
        Encodes a given vector using the Barlow Twins model.

        Args:
        - vector (np.ndarray): the input vector to encode (a torch tensor is
          accepted as well)
        - mode (str): "projection" additionally applies the projector; any
          other value returns the encoder embedding
        - normalize (bool): whether to L2 normalize the output vector
        - encoder (str): which encoder to use, either "mol" or "aa"

        Returns:
        - np.ndarray: the encoded vector

        Raises:
        - ValueError: if ``encoder`` is neither "mol" nor "aa"
        """

        # set networks to eval mode
        self.switch_mode(is_training=False)

        # convert from numpy to tensor
        if not isinstance(vector, torch.Tensor):
            vector = torch.from_numpy(vector)

        # if only one sample is passed, add a batch dimension
        if len(vector.shape) == 1:
            vector = vector.unsqueeze(0)

        # pick the encoder network
        if encoder == "mol":
            network = self.enc_mol
        elif encoder == "aa":
            network = self.enc_aa
        else:
            raise ValueError("[BT]: Encoder not recognized")

        # the result is detached and returned as numpy, so no graph is needed
        with torch.no_grad():
            embedding = network(vector)
            if mode == "projection":
                embedding = self.proj(embedding)

            # L2 normalize (optional)
            if normalize:
                embedding = torch.nn.functional.normalize(embedding)

        # convert back to numpy
        return embedding.cpu().detach().numpy()

    def zero_shot(
        self, mol_vector: np.ndarray, aa_vector: np.ndarray, l2_norm: bool = True, device: str = "cpu"
    ) -> np.ndarray:
        """Return the column-wise concatenation of molecule and aa embeddings.

        Both inputs are cast to float32 and moved to ``device`` before being
        encoded; the result is a NumPy array.
        """
        # disable training
        self.switch_mode(is_training=False)

        # force single precision for both inputs
        mol_vector = np.array(mol_vector, dtype=np.float32)
        aa_vector = np.array(aa_vector, dtype=np.float32)

        # convert to tensors
        mol_vector = torch.from_numpy(mol_vector).to(device)
        aa_vector = torch.from_numpy(aa_vector).to(device)

        # get embeddings
        mol_embedding = self.encode(mol_vector, normalize=l2_norm, encoder="mol")
        aa_embedding = self.encode(aa_vector, normalize=l2_norm, encoder="aa")

        # concat mol and aa embeddings
        concat = np.concatenate((mol_embedding, aa_embedding), axis=1)
        return concat

    def zero_shot_explain(
        self, mol_vector, aa_vector, l2_norm: bool = True, device: str = "cpu"
    ):
        """Tensor counterpart of ``zero_shot``.

        The encoders are applied directly instead of going through
        ``encode``, because ``encode`` returns NumPy arrays and
        ``torch.cat`` only accepts tensors. The autograd graph is kept -
        presumably what attribution methods need; confirm with the caller.

        Returns:
            ``torch.Tensor``: molecule and aa embeddings concatenated along
            dimension 1.
        """
        self.switch_mode(is_training=False)

        # same precision and device handling as zero_shot
        mol_tensor = torch.as_tensor(mol_vector, dtype=torch.float32, device=device)
        aa_tensor = torch.as_tensor(aa_vector, dtype=torch.float32, device=device)

        # add a batch dimension for single samples, as encode does
        if mol_tensor.dim() == 1:
            mol_tensor = mol_tensor.unsqueeze(0)
        if aa_tensor.dim() == 1:
            aa_tensor = aa_tensor.unsqueeze(0)

        mol_embedding = self.enc_mol(mol_tensor)
        aa_embedding = self.enc_aa(aa_tensor)

        if l2_norm:
            mol_embedding = torch.nn.functional.normalize(mol_embedding)
            aa_embedding = torch.nn.functional.normalize(aa_embedding)

        return torch.cat((mol_embedding, aa_embedding), dim=1)

    def consume_preprocessor(self, preprocessor) -> None:
        """Copy ``radius`` and ``n_bits`` from ``preprocessor`` into ``param_dict``."""
        self.param_dict["radius"] = preprocessor.radius
        self.param_dict["n_bits"] = preprocessor.n_bits

    def save_model(self, path: str) -> None:
        """Save weights, parameters and history to ``path/<ddmmYYYY>_<HHMM>/``.

        ``history.json`` is written with ``pickle`` despite its extension;
        the format is kept because ``load_model`` reads it the same way.
        """
        # get current date and time for the folder name
        now = datetime.now()
        formatted_date = now.strftime("%d%m%Y")
        formatted_time = now.strftime("%H%M")
        folder_name = f"{formatted_date}_{formatted_time}"

        # make full path string and folder; exist_ok so a second save within
        # the same minute overwrites instead of raising FileExistsError
        folder_path = path + "/" + folder_name
        os.makedirs(folder_path, exist_ok=True)

        # make paths for weights, config and history
        weight_path = folder_path + "/weights.pt"
        param_path = folder_path + "/params.pkl"
        history_path = folder_path + "/history.json"

        # save each Sequential state dict in one object to the path
        torch.save(
            {
                "enc_mol": self.enc_mol.state_dict(),
                "enc_aa": self.enc_aa.state_dict(),
                "proj": self.proj.state_dict(),
            },
            weight_path,
        )

        # dump params in pkl
        with open(param_path, "wb") as file:
            pickle.dump(self.param_dict, file)

        # dump history (pickle format, see docstring)
        with open(history_path, "wb") as file:
            pickle.dump(self.history, file)

        # print if verbose is True
        if self.param_dict["verbose"] is True:
            print(f"[BT]: Model saved at {folder_path}")

    def load_model(self, path: str) -> None:
        """Restore weights, parameters and history written by ``save_model``.

        SECURITY: ``pickle.load`` (and ``torch.load``) can execute arbitrary
        code while loading; only call this on trusted model folders.
        """
        # make weights, config and history paths
        weights_path = path + "/weights.pt"
        param_path = path + "/params.pkl"
        history_path = path + "/history.json"

        # load weights, history and params
        checkpoint = torch.load(weights_path, map_location=self.device)
        with open(param_path, "rb") as file:
            param_dict = pickle.load(file)
        with open(history_path, "rb") as file:
            history = pickle.load(file)

        # construct model again, keeping this instance's verbose setting
        verbose = self.param_dict["verbose"]
        self.param_dict = param_dict
        self.param_dict["verbose"] = verbose
        self.history = history
        self.construct_model()

        # set weights in Sequential models
        self.enc_mol.load_state_dict(checkpoint["enc_mol"])
        self.enc_aa.load_state_dict(checkpoint["enc_aa"])
        self.proj.load_state_dict(checkpoint["proj"])

        # recreate optimizer and scheduler so they reference the new weights
        self.construct_scheduler()

        # print if verbose is True
        if self.param_dict["verbose"] is True:
            print(f"[BT]: Model loaded from {path}")
            print("[BT]: Loaded parameters:")
            print(self.param_dict)

    def move_to_device(self, device) -> None:
        """Move the three sub-networks to ``device`` and remember it."""
        self.enc_mol.to(device)
        self.enc_aa.to(device)
        self.proj.to(device)
        self.device = device
model/base_model.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Any, Union
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+
6
+
7
class BaseModel(nn.Module):
    """Common state for the models: device, lookup tables, history, early stopping."""

    def __init__(self):
        """Create the lookup tables and empty bookkeeping containers."""
        super(BaseModel, self).__init__()
        # use the first CUDA device when available, otherwise the CPU
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # make empty param dict
        self.param_dict = {}

        # make optimizer options dict
        self.optimizer_dict = {
            "adam": torch.optim.Adam,
            "nadam": torch.optim.NAdam,
            "adamax": torch.optim.Adamax,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }

        # make loss options dict
        self.loss_dict = {
            "mse": nn.MSELoss,
            "l1": nn.L1Loss,
            "smoothl1": nn.SmoothL1Loss,
            "huber": nn.HuberLoss,
            "cel": nn.CrossEntropyLoss,  # Suitable for classification tasks
            "bcel": nn.BCELoss,  # Suitable for classification tasks
        }

        # make activation function options dictionary
        self.activation_dict = {
            "relu": nn.ReLU,
            "swish": nn.Hardswish,
            "leaky_relu": nn.LeakyReLU,
            "elu": nn.ELU,
            "selu": nn.SELU,
        }

        # make tokenizer placeholder
        self.tokenizer = None

        # create history dictionary
        self.history = {
            "train_loss": [],
            "on_diag_loss": [],
            "off_diag_loss": [],
            "validation_loss": [],
            "learning_rate": [],
        }

        # consecutive epochs whose validation loss was worse than the best
        self.count = 0

    def print_config(self) -> None:
        """Print the current parameter dictionary."""
        print("[BT]: Current parameter config:")
        print(self.param_dict)

    def early_stopping(self, patience: int) -> bool:
        """Report whether training should stop.

        Once more than ``patience`` validation losses are recorded, each call
        increments ``self.count`` if the latest loss is worse than the best
        recorded one and resets it to 0 otherwise.

        Returns:
            ``True`` when ``self.count`` has reached ``patience``.
        """
        losses = self.history["validation_loss"]
        if len(losses) <= patience:
            return False

        if losses[-1] > min(losses):
            self.count += 1
        else:
            self.count = 0

        if self.count >= patience:
            # .get: param_dict starts out empty in this base class, so a
            # missing "verbose" key must not turn the stop signal into a
            # KeyError
            if self.param_dict.get("verbose") is True:
                print("[BT]: Early stopping")
            return True
        return False
model/model.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from typing import List
3
+ from tqdm import tqdm
4
+ import pandas as pd
5
+ import numpy as np
6
+ import threading
7
+ from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
8
+ import time
9
+ import requests
10
+ import joblib
11
+ # from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder
12
+ from Bio import SeqIO
13
+ import rdkit
14
+ from rdkit import Chem, DataStructs
15
+ from rdkit.Chem import AllChem
16
+ import torch
17
+ from typing import *
18
+ from rdkit import RDLogger
19
+ RDLogger.DisableLog("rdApp.*")
20
+
21
+ from xgboost import XGBClassifier, DMatrix
22
+
23
+ from model.barlow_twins import BarlowTwins
24
+
25
+ # sys.path.append("../utils/")
26
+ from utils.sequence import uniprot2sequence, encode_sequences
27
+
28
+
29
+
30
+ class DTIModel:
31
def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
    """Load the Barlow Twins encoder and the XGBoost classifier.

    Args:
        bt_model_path: Folder passed to ``BarlowTwins.load_model``.
        gbm_model_path: File passed to ``XGBClassifier.load_model``.
        encoder: Encoder name forwarded to ``encode_sequences``.
    """
    # representation model
    barlow_twins = BarlowTwins()
    barlow_twins.load_model(bt_model_path)
    self.bt_model = barlow_twins

    # gradient boosting classifier on top of the embeddings
    classifier = XGBClassifier()
    classifier.load_model(gbm_model_path)
    self.gbm_model = classifier

    self.encoder = encoder

    # per-instance memo tables for already encoded inputs
    self.smiles_cache = dict()
    self.sequence_cache = dict()
42
+
43
+ def _encode_smiles(self, smiles: str, radius: int = 2, bits: int = 1024, features: bool = False):
44
+ if smiles is None:
45
+ return None
46
+ # Check if the SMILES is already in the cache
47
+ if smiles in self.smiles_cache:
48
+ return self.smiles_cache[smiles]
49
+ else:
50
+ # Encode the SMILES and store it in the cache
51
+ try:
52
+ mol = Chem.MolFromSmiles(smiles)
53
+ morgan = AllChem.GetMorganFingerprintAsBitVect(
54
+ mol,
55
+ radius=radius,
56
+ nBits=bits,
57
+ useFeatures=features,
58
+ )
59
+ morgan = np.array(morgan)
60
+ self.smiles_cache[smiles] = morgan
61
+ return morgan
62
+ except Exception as e:
63
+ print(f"Failed to encode SMILES: {smiles}")
64
+ print(e)
65
+ return None
66
+
67
+ def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False):
68
+ morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
69
+ return np.array(morgan)
70
+
71
+ def _encode_sequence(self, sequence: str):
72
+ # Clear torch cache
73
+ torch.cuda.empty_cache()
74
+ if sequence is None:
75
+ return None
76
+ # Check if the sequence is already in the cache
77
+ if sequence in self.sequence_cache:
78
+ return self.sequence_cache[sequence]
79
+ else:
80
+ # Encode the sequence and store it in the cache
81
+ try:
82
+ encoded_sequence = encode_sequences([sequence], encoder=self.encoder)
83
+ self.sequence_cache[sequence] = encoded_sequence
84
+ return encoded_sequence
85
+ except Exception as e:
86
+ print(f"Failed to encode sequence: {sequence}")
87
+ print(e)
88
+ return None
89
+
90
+ def _encode_sequence_mult(self, sequences: List[str]):
91
+ seq = [self._encode_sequence(sequence) for sequence in sequences]
92
+ return np.array(seq)
93
+
94
+ def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool):
95
+ if drug_emb.shape[0] < target_emb.shape[0]:
96
+ drug_emb = np.tile(drug_emb, (len(target_emb), 1))
97
+ elif len(drug_emb) > len(target_emb):
98
+ target_emb = np.tile(target_emb, (len(drug_emb), 1))
99
+ emb = self.bt_model.zero_shot(drug_emb, target_emb)
100
+
101
+ if pred_leaf:
102
+ d_emb = DMatrix(emb)
103
+ return self.gbm_model.get_booster().predict(d_emb, pred_leaf=True)
104
+ else:
105
+ return self.gbm_model.predict_proba(emb)[:, 1]
106
+
107
+ def predict(self, drug: List[str] or str, target: str, pred_leaf: bool = False):
108
+ if isinstance(drug, str):
109
+ drug_emb = self._encode_smiles(drug)
110
+ else:
111
+ drug_emb = self._encode_smiles_mult(drug)
112
+ target_emb = self._encode_sequence(target)
113
+ return self.__predict_pair(drug_emb, target_emb, pred_leaf)
114
+
115
+ def get_leaf_weights(self):
116
+ return self.gbm_model.get_booster().get_score(importance_type="weight")
117
+
118
+ def _predict_fasta(self, drug: str, fasta_path: str):
119
+ drug_emb = self._encode_smiles(drug)
120
+
121
+ results = []
122
+ # Extract targets from fasta
123
+ for target in tqdm(SeqIO.parse(fasta_path, "fasta"), desc="Predicting targets"):
124
+ target_emb = self._encode_sequence(str(target.seq))
125
+ pred = self.__predict_pair(drug_emb, target_emb)
126
+ results.append(
127
+ {
128
+ "drug": drug,
129
+ "target": target.id,
130
+ "name": target.name,
131
+ "description": target.description,
132
+ "prediction": pred[0]
133
+ }
134
+ )
135
+ return pd.DataFrame(results)
136
+
137
+ def predict_fasta(self, drug: str, fasta_path: str, timeout_seconds: int = 120):
138
+ def process_target(target, results):
139
+ target_emb = self._encode_sequence(str(target.seq))
140
+ pred = self.__predict_pair(drug_emb, target_emb)
141
+ results.append({
142
+ "drug": drug,
143
+ "target": target.id,
144
+ "name": target.name,
145
+ "description": target.description,
146
+ "prediction": pred[0]
147
+ })
148
+
149
+ drug_emb = self._encode_smiles(drug)
150
+ results = []
151
+
152
+ # First, count the total number of records for the progress bar
153
+ total_records = sum(1 for _ in SeqIO.parse(fasta_path, "fasta"))
154
+
155
+ # Extract targets from fasta with a properly initialized tqdm progress bar
156
+ for target in tqdm(SeqIO.parse(fasta_path, "fasta"), total=total_records, desc="Predicting targets"):
157
+ thread_results = []
158
+ thread = threading.Thread(target=process_target, args=(target, thread_results))
159
+ thread.start()
160
+ thread.join(timeout_seconds)
161
+ if thread.is_alive():
162
+ print(f"Skipping target {target.id} due to timeout")
163
+ continue
164
+ results.extend(thread_results)
165
+
166
+ return pd.DataFrame(results)
167
+
168
+ def predict_uniprot(self, drug: List[str] or str, uniprot_id: str):
169
+ return self.predict(drug, uniprot2sequence(uniprot_id))
model/preprocessor.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
3
+ import torch
4
+ from rdkit import Chem, DataStructs
5
+ import pandas as pd
6
+ import pickle as pkl
7
+ import numpy as np
8
+ from sklearn.preprocessing import StandardScaler
9
+ import sys
10
+
11
+ # sys.path.append("../utils/")
12
+ from utils.parallel import *
13
+ from utils.chem import *
14
+ from utils.sequence import *
15
+
16
+
17
class Preprocessor:
    """Load a CSV of drug/protein pairs and turn it into model inputs.

    The CSV must contain ``smiles`` and ``sequence`` columns; ``split`` and
    ``label`` columns are optional. On construction the file is read,
    fingerprints are stored in ``self.fp`` and sequence embeddings in
    ``self.aa``.
    """

    def __init__(
        self,
        path: str,
        radius: int = 2,
        n_bits: int = 1024,
        aa_embedding: str = "prottrans_t5_xl_u50",
        num_workers: int = 1,
    ):
        """Store the settings, then load and process the data immediately.

        Args:
            path: CSV file to read.
            radius: fingerprint radius passed to ``get_fp``.
            n_bits: fingerprint length passed to ``get_fp`` / ``store_fp``.
            aa_embedding: encoder name passed to ``encode_sequences``.
            num_workers: values above 1 compute molecules and fingerprints
                through ``parallel``.
        """
        self.path = path
        self.radius = radius
        self.n_bits = n_bits
        self.aa_embedding = aa_embedding
        self.num_workers = num_workers

        self.data = None
        self.fp = None
        self.aa = None
        self.split = None
        self.label = None

        self.load_data()
        self.process_data()

    def load_data(self):
        """Read the CSV at ``self.path`` into ``self.data``.

        Raises:
            ValueError: if ``self.path`` is not an existing file.
        """
        if not os.path.isfile(self.path):
            raise ValueError("No data file found in the specified path")
        self.data = pd.read_csv(self.path, low_memory=False)

    def process_data(self):
        """Compute fingerprints and sequence embeddings from ``self.data``.

        Raises:
            ValueError: if the ``smiles`` or ``sequence`` column is missing.
        """
        if "smiles" not in self.data.columns:
            raise ValueError("No smiles column found in the data")
        if "sequence" not in self.data.columns:
            raise ValueError("No sequence column found in the data")

        smiles = self.data.smiles.tolist()
        seq = self.data.sequence.tolist()

        if "split" in self.data.columns:
            self.split = self.data.split.tolist()
        if "label" in self.data.columns:
            self.label = self.data.label.tolist()

        if self.num_workers > 1:
            mols = parallel(get_mols, self.num_workers, smiles)
            fps = parallel(get_fp, self.num_workers, mols, self.radius, self.n_bits)
        else:
            mols = get_mols(smiles)
            fps = get_fp(mols, self.radius, self.n_bits)

        self.fp = store_fp(fps, self.n_bits)
        self.aa = encode_sequences(seq, self.aa_embedding)

    def _collect_split(self, split_name: str, use_labels: bool, include_negatives: bool):
        """Gather the fingerprint/embedding pairs that belong to one split.

        Without labels every row of the split is kept. With labels, rows
        labelled 1 are kept as they are, rows labelled 0 are kept with a
        negated embedding only when ``include_negatives`` is set, and rows
        with any other label are dropped.
        """
        fps, aas = [], []
        for i, name in enumerate(self.split):
            if name != split_name:
                continue
            if not use_labels or self.label[i] == 1:
                fps.append(self.fp[i])
                aas.append(self.aa[i])
            elif include_negatives and self.label[i] == 0:
                fps.append(self.fp[i])
                aas.append(self.aa[i] * -1)
        return fps, aas

    def return_generator(
        self,
        device,
        batch_size: int = 512,
        include_negatives: bool = False,
        shuffle: bool = True,
        validation_split: float = None,
    ) -> (DataLoader, DataLoader):
        """Build the training and validation data loaders.

        Three cases are distinguished:

        * ``split`` and ``label`` columns present: rows are assigned by their
          ``"train"`` / ``"val"`` split value and filtered by label (see
          ``include_negatives``).
        * only a ``split`` column present: rows are assigned by split value.
        * no ``split`` column: the whole dataset is used and, if
          ``validation_split`` is given, that fraction is held out at random.
          A ``label`` column is ignored in this case.

        Args:
            device: device the dataset moves each sample to.
            batch_size: batch size of both loaders.
            include_negatives: also use rows labelled 0, with negated
                embedding. Only relevant when split and label are present.
            shuffle: shuffle the loaders / the indices before the fractional
                split.
            validation_split: fraction held out when there is no split column.

        Returns:
            ``(train_loader, validation_loader)``; the second element is
            ``None`` when no validation data was requested.
        """
        if self.split is not None:
            use_labels = self.label is not None
            if use_labels:
                print("Splitting data into train and validation sets from the dataset")
            else:
                print("Splitting data into train and validation sets from the dataset without considering labels")

            train_fp, train_aa = self._collect_split("train", use_labels, include_negatives)
            val_fp, val_aa = self._collect_split("val", use_labels, include_negatives)

            train_dataset = MolAADataset(device, train_fp, train_aa)
            val_dataset = MolAADataset(device, val_fp, val_aa)

            print(f"Train: {len(train_fp)}, Validation: {len(val_fp)}")

            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
            validation_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
            return train_loader, validation_loader

        # No split column: labels alone cannot define train/validation sets,
        # so the full dataset is used.
        if self.label is None:
            print("No split or label columns found in the dataset")
        else:
            print("No split column found in the dataset; labels are ignored")
        dataset = MolAADataset(device, self.fp, self.aa)

        if validation_split is None:
            train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
            return train_loader, None

        print("Splitting data into train and validation by fractionation from the dataset")
        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(validation_split * dataset_size))
        if shuffle:
            np.random.shuffle(indices)
        train_indices, val_indices = indices[split:], indices[:split]

        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)

        train_loader = DataLoader(
            dataset, batch_size=batch_size, sampler=train_sampler
        )
        validation_loader = DataLoader(
            dataset, batch_size=batch_size, sampler=valid_sampler
        )
        return train_loader, validation_loader
156
+
157
+
158
class MolAADataset(Dataset):
    """Dataset of (molecule fingerprint, protein embedding) pairs.

    Items are converted to float32 tensors and moved to ``device`` when they
    are fetched.
    """

    def __init__(self, device, mol, aa):
        self.device = device
        self.mol = mol
        self.aa = aa

    def __len__(self):
        """Number of pairs, taken from the molecule container."""
        return len(self.mol)

    def _as_device_tensor(self, values):
        """Convert one sample to a float32 tensor on the target device."""
        return torch.tensor(values, dtype=torch.float32).to(self.device)

    def __getitem__(self, idx):
        """Return the ``idx``-th (molecule, protein) tensor pair."""
        return self._as_device_tensor(self.mol[idx]), self._as_device_tensor(self.aa[idx])
model/stash/14062024_0910/history.json ADDED
Binary file (3.33 kB). View file
 
model/stash/14062024_0910/log.txt ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ----------------
2
+ Run description: Manual param optim
3
+ ----------------
4
+ message: yes
5
+ path: all_drugbank_smiles_sequence_prost_preprocessor.pkl
6
+ load_preprocessor: True
7
+ radius: 2
8
+ n_bits: 1024
9
+ num_workers: 64
10
+ enc_n_neurons: 4096
11
+ enc_n_layers: 3
12
+ proj_n_neurons: 2048
13
+ proj_n_layers: 1
14
+ embedding_dim: 512
15
+ act_function: relu
16
+ aa_emb_size: 1024
17
+ loss_weight: 0.005
18
+ batch_size: 4096
19
+ epochs: 250
20
+ optimizer: adamw
21
+ learning_rate: 0.0003
22
+ beta_1: 0.9
23
+ beta_2: 0.999
24
+ weight_decay: 5e-05
25
+ step_size: 10
26
+ gamma: 0.1
27
+ include_negatives: False
28
+ hyperparameter_tuning: False
29
+ val_split: 0.1
30
+ aa_embedding: prost_t5
31
+ model_type: barlow_twins
32
+ device: cuda:0
33
+ msg: Manual param optim
34
+ start: 1718356109.3235965
35
+ data: <preprocessor.Preprocessor object at 0x72f2d495eb10>
36
+ train: <torch.utils.data.dataloader.DataLoader object at 0x72f2d3a66d50>
37
+ val: <torch.utils.data.dataloader.DataLoader object at 0x72f2d480e7b0>
38
+ file: <_io.BufferedReader name='all_drugbank_smiles_sequence_prost_preprocessor.pkl'>
39
+ t_preprocessing: 0
40
+ model: <barlow_twins.BarlowTwins object at 0x72f2d7652540>
41
+ t_model: 1
model/stash/14062024_0910/params.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:065b380d18b2c40bfe031b14480665e5603fcaf06a731f8bc0ec92d829bb2169
3
+ size 423
model/stash/14062024_0910/weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55014d6bc054a1aefc22e9c893deaf25939a639efa63f46e2083ff602a5961f1
3
+ size 340300017
model/xgb_models/14062024_0910_barlowdti_xxl_model.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_BIOSNAP_full_data_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1481be9c69558a91c41d65970ba60ace4cb685a4c90b03be37a813b9f1abc96
3
+ size 27471157
model/xgb_models/xgb_model_BIOSNAP_missing_data_70_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_BIOSNAP_missing_data_80_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ab4553ac67b4d75b85eae69c6a19daaad8c6575c3d01252dc8b58682656551b
3
+ size 12831515
model/xgb_models/xgb_model_BIOSNAP_missing_data_90_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_BIOSNAP_missing_data_95_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_BIOSNAP_unseen_drug_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3992b670436e6e2c728eade63581c1962f7cc546b81fb61cd43b6f9eb426f17
3
+ size 40338690
model/xgb_models/xgb_model_BIOSNAP_unseen_protein_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:557742dd11578818bbe6454c946ab2d5a5846556457d22c89cdbf5b47bd34831
3
+ size 18191873
model/xgb_models/xgb_model_BindingDB_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84e911499ec13f38e1edc4b006faf2ef3e827d1d7d0fd53f481e0e41c82d59c1
3
+ size 24742914
model/xgb_models/xgb_model_DAVIS_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_nature_mach_intel_BindingDB_cluster_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_nature_mach_intel_BindingDB_protein_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4a4b08241bf5779e9ef688b6c5a452ac13f4a67480ec6c17cc203ddd35ab7f7
3
+ size 16983875
model/xgb_models/xgb_model_nature_mach_intel_BindingDB_random_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef54574bb754850ec34c0769df1222c6087541fc5e5bb3e17653982e079fb440
3
+ size 64523467
model/xgb_models/xgb_model_nature_mach_intel_BindingDB_scaffold_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:716944cd5a88e6b7dd062a3c9cc331980908541d1b8039f321bdda0112c6668d
3
+ size 25668977
model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_cluster_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_protein_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_random_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:faaa3a7fcb8efd23876b23b9a07620bd4ca007d05c354e6bfd2c413f3244402b
3
+ size 18444715
model/xgb_models/xgb_model_nature_mach_intel_BioSNAP_scaffold_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_nature_mach_intel_Human_protein_14062024_0910_bt_optimized_0.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e97db0190ff7a15820982d35191f0092319801ea2992c2ef545b9028a8d2ca1
3
+ size 12630195
model/xgb_models/xgb_model_nature_mach_intel_Human_random_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
model/xgb_models/xgb_model_nature_mach_intel_Human_scaffold_14062024_0910_bt_optimized_0.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Babel==2.14.0
2
+ biopython==1.83
3
+ chembl-structure-pipeline==1.2.2
4
+ ConfigSpace==0.7.1
5
+ cycler==0.12.1
6
+ dask==2024.5.1
7
+ joblib==1.4.0
8
+ keras==3.4.1
9
+ numpy==1.26.4
10
+ optuna==3.6.1
11
+ pandas==2.2.2
12
+ plotly
13
+ rdkit==2023.9.5
14
+ scikit-learn==1.4.2
15
+ scipy==1.13.0
16
+ seaborn==0.13.2
17
+ sentencepiece==0.2.0
18
+ shap==0.46.0
19
+ smac==2.1.0
20
+ tensorflow==2.17.0
21
+ torch==2.4.0
22
+ tqdm==4.66.2
23
+ transformers==4.41.0
24
+ umap==0.1.1
25
+ xgboost==2.0.3
utils/__init__.py ADDED
File without changes
utils/chem.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import rdkit
2
+ from rdkit import Chem, DataStructs
3
+ from rdkit.Chem import AllChem
4
+ from typing import *
5
+ import numpy as np
6
+ from rdkit import RDLogger
7
+
8
+ RDLogger.DisableLog("rdApp.*")
9
+
10
+
11
def try_or_none(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return its result, or ``None`` on error.

    Only ``Exception`` subclasses are suppressed; ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate so the process can be interrupted.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        return None
16
+
17
+
18
def get_smiles(mols: List[rdkit.Chem.rdchem.Mol]) -> List[str]:
    """
    Converts each RDKit molecule in the list to its SMILES string
    """
    return list(map(Chem.MolToSmiles, mols))
23
+
24
+
25
def get_mols(smiles: List[str]) -> List[rdkit.Chem.rdchem.Mol]:
    """
    Parses each SMILES string in the list with Chem.MolFromSmiles
    """
    return list(map(Chem.MolFromSmiles, smiles))
30
+
31
+
32
def get_fp(
    mols: List[rdkit.Chem.rdchem.Mol],
    radius: int = 2,
    nBits: int = 1024,
    useFeatures: bool = False,
):
    """
    Builds one Morgan bit-vector fingerprint per RDKit molecule and returns
    them in a numpy object array of the same length as the input
    """
    settings = {"radius": radius, "nBits": nBits, "useFeatures": useFeatures}

    # Object dtype: each slot holds the RDKit fingerprint object itself.
    fingerprints = np.empty(len(mols), dtype=object)
    for position, molecule in enumerate(mols):
        fingerprints[position] = AllChem.GetMorganFingerprintAsBitVect(molecule, **settings)

    return fingerprints
53
+
54
+
55
def store_fp(fps: List, nBits: int = 1024):
    """
    Copies a sequence of RDKit fingerprints into a (len(fps), nBits) float32
    numpy array, one fingerprint per row
    """
    dense = np.empty((len(fps), nBits), dtype=np.float32)
    # Each row is a view into `dense`, so the conversion fills it in place.
    for fingerprint, row in zip(fps, dense):
        DataStructs.ConvertToNumpyArray(fingerprint, row)

    return dense
utils/parallel.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ import numpy as np
3
+ import psutil
4
+ from typing import *
5
+
6
+
7
def parallel(function: Callable, n_jobs: int, x: List, *args) -> List:
    """Higher order function to run other functions on multiple processes

    Simple parallelization utility, slices the input list x in chunks and
    executes the function on each chunk in different processes. Not suited
    for functions that have already multithreading/processing implemented.

    Args:
        function:   callable to run on different processes
        n_jobs:     how many cores to use; values of 1 or less run the
                    function directly in the current process
        x:          list (M,) to use as input for function
        *args:      optional arguments for function

    Returns:
        Object (M,) containing the output of function. Content and type depend
        on function. If function returns list, then parallel will also return
        a list. If function returns a numpy array, then parallel will return an
        array.

    Raises:
        TypeError: if the function, run in parallel, returns something other
        than a list or a numpy array.
    """
    if n_jobs <= 1:
        # run function normally
        return function(x, *args)

    # split list in chunks and pair every chunk with the optional arguments
    chunks = split_list(x, n_jobs)
    starmap_args = stitch_args(chunks, args)

    # the context manager tears the pool down even if a worker raises
    with multiprocessing.Pool(n_jobs) as pool:
        output = pool.starmap(function, starmap_args)

    # unroll output (list of function outputs) into a single object of size M
    first = output[0]
    if isinstance(first, list):
        return [item for chunk_output in output for item in chunk_output]
    if isinstance(first, np.ndarray):
        return np.concatenate(output, axis=0)
    raise TypeError(
        f"parallel() can only merge list or numpy array outputs, got {type(first).__name__}"
    )
54
+
55
+
56
def stitch_args(chunks: List[List], args: Tuple) -> List[Tuple]:
    """
    Pairs every chunk with the optional function arguments: each returned
    tuple is (chunk, *args), ready to be unpacked by starmap
    """
    return [(chunk, *args) for chunk in chunks]
67
+
68
+
69
def split_list(x: List, n_jobs: int) -> List[List]:
    """
    Splits a list into n_jobs consecutive chunks of near-equal length
    (np.array_split semantics, so chunks may be empty).
    """
    index_groups = np.array_split(range(len(x)), n_jobs)
    return [[x[k] for k in group] for group in index_groups]
utils/sequence.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import numpy as np
3
+ # from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder
4
+ from transformers import T5Tokenizer, T5EncoderModel
5
+ import torch
6
+ import re
7
+ import concurrent.futures
8
+ from tqdm.auto import tqdm
9
+ import multiprocessing
10
+ from multiprocessing import Pool
11
+
12
+
13
# Registry of the available protein language models, keyed by encoder name.
# Each entry holds the Hugging Face tokenizer and encoder model that
# encode_sequences and SequenceEncoder look up by that name.
# NOTE: all from_pretrained() calls below run when this module is imported,
# so every listed tokenizer and model is loaded at import time.
ENCODERS = {
    # "seqvec": SeqVecEmbedder(),
    # "prottrans_bert_bfd": ProtTransBertBFDEmbedder(),
    # "prottrans_t5_xl_u50": ProtTransT5XLU50Embedder(),
    "prot_t5": {
        "tokenizer": T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False),
        "model": T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc')
    },
    "prost_t5": {
        "tokenizer": T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False),
        "model": T5EncoderModel.from_pretrained("Rostlab/ProstT5")
    }
}
26
+
27
+
28
def drugbank2smiles(drugbank_id, timeout: float = 30.0):
    """Download the SMILES string of a DrugBank entry.

    Args:
        drugbank_id: DrugBank identifier inserted into the download URL.
        timeout: seconds to wait for the server before requests raises a
            timeout error, so a stalled connection cannot block forever.

    Returns:
        The response body as text, or ``None`` when the server does not
        answer with HTTP 200.
    """
    url = f"https://go.drugbank.com/drugs/{drugbank_id}.smiles"
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        return None
    return response.text
37
+
38
+
39
def uniprot2sequence(uniprot_id, timeout: float = 30.0):
    """Download the amino-acid sequence of a UniProt entry.

    Args:
        uniprot_id: UniProt accession inserted into the REST URL.
        timeout: seconds to wait for the server before requests raises a
            timeout error, so a stalled connection cannot block forever.

    Returns:
        The sequence as one string, or ``None`` when the server does not
        answer with HTTP 200.
    """
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        return None
    # Drop the FASTA header (first line) and join the remaining lines.
    return "".join(response.text.split("\n")[1:])
50
+
51
+
52
def encode_sequences(sequences: list, encoder: str):
    """Embed protein sequences with one of the registered encoders.

    Every sequence is embedded on its own and mean-pooled over its token
    positions into a 1024-dimensional vector. Identical sequences within one
    call are embedded only once.

    Args:
        sequences: amino-acid sequences to embed.
        encoder: key of ``ENCODERS`` (``"prost_t5"`` or ``"prot_t5"``).

    Returns:
        ``np.array`` built from the per-sequence results; an entry is ``None``
        for a sequence that could not be embedded.

    Raises:
        ValueError: if ``encoder`` is not a key of ``ENCODERS``.
        NotImplementedError: if the encoder is registered but has no
            embedding pipeline here.
    """
    if encoder not in ENCODERS:
        raise ValueError(f"Invalid encoder: {encoder}")

    model = ENCODERS[encoder]["model"]
    tokenizer = ENCODERS[encoder]["tokenizer"]

    # Per-call cache, keyed by the pre-processed sequence string.
    cache = {}

    def embed_mean_pooled(sequence, device, first_token):
        """Run the model on one sequence and mean-pool its hidden states.

        ``first_token`` is the first position included in the pooling; the
        final position is always left out.
        """
        # Full precision on CPU, half precision on any other device.
        if torch.device(device).type == "cpu":
            model.float()
        else:
            model.half()
        model.to(device)

        ids = tokenizer.batch_encode_plus(
            [sequence],
            add_special_tokens=True,
            padding="longest",
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            embedding = model(
                ids.input_ids,
                attention_mask=ids.attention_mask
            )

        assert embedding.last_hidden_state.shape[0] == 1

        pooled = embedding.last_hidden_state[0, first_token:-1, :]
        pooled = pooled.mean(dim=0).cpu().numpy().flatten()

        assert pooled.shape[0] == 1024
        return pooled

    def encode_cached(sequence, first_token):
        """Embed one sequence on the preferred device, using the cache."""
        if sequence is None:
            return None

        if sequence in cache:
            return cache[sequence]

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        torch.cuda.empty_cache()

        try:
            result = embed_mean_pooled(sequence, device, first_token)
        except RuntimeError as e:
            # Runtime errors (e.g. out of memory) are reported and the
            # sequence is skipped; it is not cached, so a later call retries.
            print(e)
            return None
        except Exception:
            print(f"Failed to encode sequence: {sequence}")
            cache[sequence] = None
            return None

        cache[sequence] = result
        return result

    if encoder == "prost_t5":
        sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences]
        # The direction of the translation is indicated by two special tokens:
        # if you go from AAs to 3Di (or if you want to embed AAs), you need to prepend "<AA2fold>"
        # if you go from 3Di to AAs (or if you want to embed 3Di), you need to prepend "<fold2AA>"
        sequences = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in sequences]
        # Position 0 is the direction token, so pooling starts at position 1.
        seq = [encode_cached(sequence, 1) for sequence in tqdm(sequences, desc="Encoding sequences")]

    elif encoder == "prot_t5":
        sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences]
        seq = [encode_cached(sequence, 0) for sequence in tqdm(sequences, desc="Encoding sequences")]

    else:
        raise NotImplementedError("SeqVec is not supported")

    return np.array(seq)
196
+
197
+
198
+ class SequenceEncoder:
199
+ def __init__(self, encoder: str):
200
+ if encoder not in ENCODERS:
201
+ raise ValueError(f"Invalid encoder: {encoder}")
202
+ self.encoder = encoder
203
+ self.model = ENCODERS[encoder]["model"]
204
+ self.tokenizer = ENCODERS[encoder]["tokenizer"]
205
+ self.cache = {}
206
+
207
+ def encode_sequence(self, sequence: str):
208
+ if sequence is None:
209
+ return None
210
+ if len(sequence) <= 3:
211
+ raise ValueError(f"Invalid sequence: {sequence}")
212
+
213
+ if sequence in self.cache:
214
+ return self.cache[sequence]
215
+
216
+ try:
217
+ encoded_sequence = self.model.embed(sequence)
218
+ encoded_sequence = np.mean(encoded_sequence, axis=0)
219
+ self.cache[sequence] = encoded_sequence
220
+ return encoded_sequence
221
+ except Exception as e:
222
+ print(f"Failed to encode sequence: {sequence}")
223
+ print(e)
224
+ return None
225
+
226
+ def encode_sequence_device_failover(self, sequence: str, function, timeout: int = 5):
227
+ if sequence is None:
228
+ return None
229
+
230
+ if sequence in self.cache:
231
+ return self.cache[sequence]
232
+
233
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
234
+ torch.cuda.empty_cache()
235
+
236
+ try:
237
+ result = function(sequence, device)
238
+ except RuntimeError as e:
239
+ return None
240
+ print(e)
241
+ if "CUDA out of memory." in str(e):
242
+ print("Trying on CPU instead.")
243
+ device = torch.device("cpu")
244
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
245
+ future = executor.submit(function, sequence, device)
246
+ try:
247
+ result = future.result(timeout=timeout)
248
+ except:
249
+ print(f"CPU encoding timed out.")
250
+ self.cache[sequence] = None
251
+ return None
252
+ finally:
253
+ executor.shutdown(wait=False)
254
+ else:
255
+ self.cache[sequence] = None
256
+ return None
257
+ except Exception as e:
258
+ print(f"Failed to encode sequence: {sequence}")
259
+ self.cache[sequence] = None
260
+ return None
261
+
262
+ self.cache[sequence] = result
263
+ return result
264
+
265
def encode_sequence_hf_3d(self, sequence, device):
    """Embed one sequence with ``self.model`` and mean-pool it to a 1024-d vector.

    The first and last positions of the model output are dropped before
    pooling (presumably the direction prefix token and the end-of-sequence
    token -- confirm against the tokenizer in use).

    Args:
        sequence: A single tokenizer-ready sequence string.
        device: Target device, either a ``torch.device`` or a device string.

    Returns:
        A flat NumPy array of length 1024.
    """
    # ``torch.device("cpu") == "cpu"`` is False, so normalise before testing;
    # this works for both device objects and plain strings.
    if torch.device(device).type == "cpu":
        # Full precision on CPU.
        self.model.float()
    else:
        self.model.half()
    self.model.to(device)

    ids = self.tokenizer.batch_encode_plus(
        [sequence],
        add_special_tokens=True,
        padding="longest",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        embedding = self.model(
            ids.input_ids,
            attention_mask=ids.attention_mask,
        )

    # Exactly one sequence went in, so exactly one row must come out.
    assert embedding.last_hidden_state.shape[0] == 1

    encoded_sequence = embedding.last_hidden_state[0, 1:-1, :]
    encoded_sequence = encoded_sequence.mean(dim=0).cpu().numpy().flatten()

    assert encoded_sequence.shape[0] == 1024
    return encoded_sequence
290
+
291
def encode_sequence_hf(self, sequence, device):
    """Embed one sequence with ``self.model`` and mean-pool it to a 1024-d vector.

    The last position of the model output is dropped before pooling
    (presumably the end-of-sequence token added by the tokenizer -- confirm
    against the tokenizer in use).

    Args:
        sequence: A single tokenizer-ready sequence string.
        device: Target device, either a ``torch.device`` or a device string.

    Returns:
        A flat NumPy array of length 1024.
    """
    # ``torch.device("cpu") == "cpu"`` is False, so normalise before testing;
    # this works for both device objects and plain strings.
    if torch.device(device).type == "cpu":
        # Full precision on CPU.
        self.model.float()
    else:
        self.model.half()
    self.model.to(device)

    ids = self.tokenizer.batch_encode_plus(
        [sequence],
        add_special_tokens=True,
        padding="longest",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        embedding = self.model(
            ids.input_ids,
            attention_mask=ids.attention_mask,
        )

    # Exactly one sequence went in, so exactly one row must come out.
    assert embedding.last_hidden_state.shape[0] == 1

    encoded_sequence = embedding.last_hidden_state[0, :-1, :]
    encoded_sequence = encoded_sequence.mean(dim=0).cpu().numpy().flatten()

    assert encoded_sequence.shape[0] == 1024
    return encoded_sequence
316
+
317
def encode_sequences(self, sequences: list):
    """Encode every sequence with the encoder selected by ``self.encoder``.

    Supported encoders are ``"prost_t5"`` and ``"prot_t5"``. For both, the
    residues U, Z, O and B are replaced by X and the residues are separated
    by spaces before encoding. ``"prost_t5"`` additionally prepends a
    ``<AA2fold>`` or ``<fold2AA>`` prefix depending on the sequence's case.

    Args:
        sequences: The sequence strings to encode.

    Returns:
        A NumPy array stacking all encodings, or a plain list when at least
        one sequence could not be encoded (its entry is ``None``).

    Raises:
        NotImplementedError: If ``self.encoder`` is ``"seqvec"`` or any other
            unsupported value.
    """
    if self.encoder == "prost_t5":
        sequences = [" ".join(re.sub(r"[UZOB]", "X", sequence)) for sequence in sequences]
        sequences = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in sequences]
        seq = [
            self.encode_sequence_device_failover(sequence, self.encode_sequence_hf_3d)
            for sequence in tqdm(sequences, desc="Encoding sequences")
        ]
    elif self.encoder == "prot_t5":
        sequences = [" ".join(re.sub(r"[UZOB]", "X", sequence)) for sequence in sequences]
        seq = [
            self.encode_sequence_device_failover(sequence, self.encode_sequence_hf)
            for sequence in tqdm(sequences, desc="Encoding sequences")
        ]
    elif self.encoder == "seqvec":
        raise NotImplementedError("SeqVec is not supported")
    else:
        # Name the offending encoder instead of blaming SeqVec for every
        # unknown value.
        raise NotImplementedError(f"Encoder '{self.encoder}' is not supported")

    # Failed encodings are None and cannot be stacked into a regular array.
    if any(x is None for x in seq):
        return seq
    return np.array(seq)