mgyigit committed
Commit 8279c69 (1 parent: b68823e)

Upload 11 files

Files changed (11):
  1. app.py +148 -0
  2. gradio_app.py +164 -0
  3. inference.py +268 -0
  4. layers.py +106 -0
  5. loss.py +36 -0
  6. models.py +93 -0
  7. new_dataloader.py +311 -0
  8. packages.txt +1 -0
  9. requirements.txt +12 -0
  10. training_data.py +31 -0
  11. utils.py +421 -0
app.py ADDED
@@ -0,0 +1,148 @@
import random

import streamlit as st
import streamlit_ext as ste
from rdkit import Chem
from rdkit.Chem import Draw

from inference import Inference


class DrugGENConfig:
    submodel = 'DrugGEN'
    act = 'relu'
    max_atom = 45
    dim = 32
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.
    features = False
    sample_num = 1000          # number of molecules to generate (read as config.sample_num by Inference)
    inf_batch_size = 1
    protein_data_dir = 'data/akt'
    drug_index = 'data/drug_smiles.index'
    drug_data_dir = 'data/akt'
    mol_data_dir = 'data'
    log_dir = 'experiments/logs'
    model_save_dir = 'experiments/models'
    sample_dir = 'experiments/samples'
    result_dir = 'experiments/tboard_output'
    inf_dataset_file = 'chembl45_test.pt'
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'data/chembl_test.smi'
    inf_drug_raw_file = 'data/akt_test.smi'
    inference_model = 'experiments/models/DrugGEN'
    log_sample_step = 1000
    set_seed = False
    seed = 1


class NoTargetConfig(DrugGENConfig):
    submodel = 'NoTarget'
    dim = 128
    inference_model = 'experiments/models/NoTarget'


model_configs = {
    "DrugGEN": DrugGENConfig(),
    "NoTarget": NoTargetConfig(),
}


with st.sidebar:
    st.title("DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
    st.write("[![arXiv](https://img.shields.io/badge/arXiv-2302.07868-b31b1b.svg)](https://arxiv.org/abs/2302.07868) [![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/DrugGEN)")

    with st.expander("Expand to display information about models"):
        st.write("""
### Model Variations
- **DrugGEN** is the default model. The input of the generator is the real molecules (ChEMBL) dataset (to ease the learning process) and the discriminator compares the generated molecules with the real inhibitors of the given target protein.
- **NoTarget** is the non-target-specific version of DrugGEN. This model only focuses on learning the chemical properties from the ChEMBL training dataset.
""")

    with st.form("model_selection_form"):
        # Only the models present in model_configs are offered here
        # (the DrugGEN model designs molecules to target the AKT1 protein).
        model_name = st.radio(
            'Select a model to make inference (the DrugGEN model designs molecules to target the AKT1 protein)',
            ('DrugGEN', 'NoTarget')
        )

        molecule_num_input = st.number_input('Number of molecules to generate', min_value=1, max_value=100_000, value=1000, step=1)

        seed_input = st.number_input("RNG seed value (can be used for reproducibility):", min_value=0, value=42, step=1)

        submitted = st.form_submit_button("Start Computing")


if submitted:
    config = model_configs[model_name]

    # Instance attributes shadow the class-level defaults above.
    config.sample_num = molecule_num_input
    config.seed = seed_input

    with st.spinner(f'Creating the Inference class instance for {model_name}...'):
        inferer = Inference(config)
    with st.spinner(f'Running inference function of {model_name} (this may take a while) ...'):
        results = inferer.inference()
    st.success(f"Inference of {model_name} took {results['Runtime (seconds)']:.2f} seconds.")

    with st.expander("Expand to see the generation performance scores"):
        st.write("### Generation performance scores (novelty is calculated in comparison to the training dataset)")
        st.success(f"Validity: {results['Validity']}")
        st.success(f"Uniqueness: {results['Uniqueness']}")
        st.success(f"Novelty: {results['Novelty (Train)']}")

    with open(f'experiments/inference/{model_name}/inference_drugs.txt') as f:
        inference_drugs = f.read()

    ste.download_button(label="Click to download generated molecules", data=inference_drugs, file_name=f'DrugGEN-{model_name}_denovo_mols.smi', mime="text/plain")

    st.write("Structures of 12 randomly selected de novo molecules from the inference set:")
    # Drop empty entries (the file ends with a trailing newline).
    generated_molecule_list = [s for s in inference_drugs.split("\n") if s]

    selected_molecules = random.choices(generated_molecule_list, k=12)
    selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules]

    drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
    drawOptions.prepareMolsBeforeDrawing = False
    drawOptions.bondLineWidth = 1.

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(250, 250),
        maxMols=len(selected_molecules),
        returnPNG=False,
        highlightAtomLists=None,
        highlightBondLists=None,
    )
    molecule_image.save("result_grid.png")
    st.image(molecule_image)

else:
    st.warning("Please select a model to make inference")
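A note on the config pattern used above: the defaults live as class attributes, and each form submission shadows them with instance attributes on a long-lived config object held in `model_configs`. A minimal, self-contained sketch of that Python behavior (the names here are illustrative, not part of the app):

```python
class DemoConfig:
    sample_num = 1000  # class-level default, shared by all instances

config = DemoConfig()
config.sample_num = 50        # creates an instance attribute that shadows the default
print(DemoConfig.sample_num)  # 1000 -- the class default is untouched
print(config.sample_num)      # 50
```

Because the instances persist across requests, an override from one submission carries into the next unless it is reset.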
gradio_app.py ADDED
@@ -0,0 +1,164 @@
import os
import random

import gradio as gr
import pandas as pd
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw

from inference import Inference


class DrugGENConfig:
    submodel = 'DrugGEN'
    act = 'relu'
    max_atom = 45
    dim = 32
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.
    features = False
    sample_num = 1000          # number of molecules to generate (read as config.sample_num by Inference)
    inf_batch_size = 1
    protein_data_dir = 'data/akt'
    drug_index = 'data/drug_smiles.index'
    drug_data_dir = 'data/akt'
    mol_data_dir = 'data'
    log_dir = 'experiments/logs'
    model_save_dir = 'experiments/models'
    inference_model = 'experiments/models/DrugGEN'
    sample_dir = 'experiments/samples'
    result_dir = 'experiments/tboard_output'
    dataset_file = 'chembl45_train.pt'
    drug_dataset_file = 'akt_train.pt'
    raw_file = 'data/chembl_train.smi'
    drug_raw_file = 'data/akt_train.smi'
    inf_dataset_file = 'chembl45_test.pt'
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'data/chembl_test.smi'
    inf_drug_raw_file = 'data/akt_test.smi'
    log_sample_step = 1000
    set_seed = False
    seed = 1


class NoTargetConfig(DrugGENConfig):
    submodel = 'NoTarget'
    dim = 128
    inference_model = 'experiments/models/NoTarget'


model_configs = {
    "DrugGEN": DrugGENConfig(),
    "NoTarget": NoTargetConfig(),
}


def function(model_name: str, mol_num: int, seed: int) -> tuple[Image.Image, pd.DataFrame, str]:
    '''
    Returns:
        image, score_df, file path
    '''
    config = model_configs[model_name]
    config.sample_num = mol_num
    config.seed = seed

    inferer = Inference(config)
    scores = inferer.inference()  # dict of metric name -> value

    score_df = pd.DataFrame(scores, index=[0])

    output_file_path = f'experiments/inference/{model_name}/inference_drugs.txt'
    new_path = f'{model_name}_denovo_mols.smi'
    os.rename(output_file_path, new_path)

    with open(new_path) as f:
        inference_drugs = f.read()

    # Drop empty entries (the file ends with a trailing newline).
    generated_molecule_list = [s for s in inference_drugs.split("\n") if s]

    rng = random.Random(seed)
    selected_molecules = rng.choices(generated_molecule_list, k=12)
    selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules]

    drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
    drawOptions.prepareMolsBeforeDrawing = False
    drawOptions.bondLineWidth = 0.5

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(selected_molecules),
        returnPNG=False,
        drawOptions=drawOptions,
        highlightAtomLists=None,
        highlightBondLists=None,
    )

    return molecule_image, score_df, new_path


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            with gr.Row():
                gr.Markdown("[![arXiv](https://img.shields.io/badge/arXiv-2302.07868-b31b1b.svg)](https://arxiv.org/abs/2302.07868)")
                gr.Markdown("[![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/DrugGEN)")

            with gr.Accordion("Expand to display information about models", open=False):
                gr.Markdown("""
### Model Variations
- **DrugGEN** is the default model. The input of the generator is the real molecules (ChEMBL) dataset (to ease the learning process) and the discriminator compares the generated molecules with the real inhibitors of the given target protein.
- **NoTarget** is the non-target-specific version of DrugGEN. This model only focuses on learning the chemical properties from the ChEMBL training dataset.
""")
            model_name = gr.Radio(
                choices=("DrugGEN", "NoTarget"),
                value="DrugGEN",
                label="Select a model to make inference",
                info="The DrugGEN model designs molecules to target the AKT1 protein"
            )

            num_molecules = gr.Number(
                label="Number of molecules to generate",
                precision=0,  # integer input
                minimum=1,
                value=1000,
                maximum=10_000,
            )
            seed_num = gr.Number(
                label="RNG seed value (can be used for reproducibility):",
                precision=0,  # integer input
                minimum=0,
                value=42,
            )

            submit_button = gr.Button(
                value="Start Generating"
            )

        with gr.Column(scale=2):
            scores_df = gr.Dataframe(
                label="Scores",
                headers=["Runtime (seconds)", "Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (AKT)", "MaxLen", "MeanAtomType", "SNN (ChEMBL)", "SNN (AKT)"],
            )
            file_download = gr.File(
                label="Click to download generated molecules",
            )
            image_output = gr.Image(
                label="Structures of 12 randomly selected de novo molecules from the inference set:"
            )

    submit_button.click(function, inputs=[model_name, num_molecules, seed_num], outputs=[image_output, scores_df, file_download], api_name="inference")

demo.queue(concurrency_count=1)
demo.launch()
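Since the click handler is registered with `api_name="inference"`, the demo can also be driven programmatically. A sketch using the `gradio_client` package (assumptions: the demo is already running locally on Gradio's default port, and `gradio_client` is installed separately; it is not listed in requirements.txt):

```python
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
# Positional inputs mirror [model_name, num_molecules, seed_num] above.
image, scores, smi_path = client.predict("NoTarget", 100, 42, api_name="/inference")
print(scores)    # generation performance metrics
print(smi_path)  # local path of the downloaded .smi file
```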
inference.py ADDED
@@ -0,0 +1,268 @@
import os
import time
import pickle
import random
import argparse

from tqdm import tqdm
import torch
from torch_geometric.loader import DataLoader
import torch.utils.data
from rdkit import RDLogger

torch.set_num_threads(5)
RDLogger.DisableLog('rdApp.*')

from utils import *
from models import Generator
from new_dataloader import DruggenDataset
from loss import generator_loss
from training_data import load_molecules


class Inference(object):
    """Inference class for DrugGEN."""

    def __init__(self, config):
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations
        self.submodel = config.submodel
        self.inference_model = config.inference_model
        self.sample_num = config.sample_num

        # Data loader.
        self.inf_raw_file = config.inf_raw_file          # SMILES text file for the first dataset (full path).
        self.inf_dataset_file = config.inf_dataset_file  # Dataset file name for the first GAN; contains a large number of molecules.
        self.inf_batch_size = config.inf_batch_size
        self.mol_data_dir = config.mol_data_dir          # Directory where the dataset files are stored.
        self.dataset_name = self.inf_dataset_file.split(".")[0]
        self.max_atom = config.max_atom                  # The model is based on one-shot generation, so the max atom number must be specified.
        self.features = config.features                  # False uses atom types as the only node feature; additional features can be added (see new_dataloader.py).

        # Dataset for the first GAN. Custom dataset class built on the PyG InMemoryDataset parent class.
        # Can create a molecular graph dataset from any SMILES file; nonisomeric SMILES are suggested but not required.
        # Uses a sparse matrix representation for graphs, for computational and speed efficiency.
        self.inf_dataset = DruggenDataset(self.mol_data_dir,
                                          self.inf_dataset_file,
                                          self.inf_raw_file,
                                          self.max_atom,
                                          self.features)

        # PyG dataloader for the first GAN.
        self.inf_loader = DataLoader(self.inf_dataset,
                                     shuffle=True,
                                     batch_size=self.inf_batch_size,
                                     drop_last=True)

        # Atom and bond type dimensions for the construction of the model.
        self.atom_decoders = self.decoder_load("atom")  # Atom type decoders, e.g. 0: 0 (PAD), 1: 6 (C), 2: 7 (N), 3: 8 (O), 4: 9 (F)
        self.bond_decoders = self.decoder_load("bond")  # Bond type decoders, e.g. 0: no-bond, 1: single, 2: double, 3: triple, 4: aromatic
        self.m_dim = len(self.atom_decoders) if not self.features else int(self.inf_loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoders)  # Bond type dimension.
        self.vertexes = int(self.inf_loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

        # Transformer and convolution configurations.
        self.act = config.act
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout

        self.build_model()

    def build_model(self):
        """Create generators and discriminators."""
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio,
                           submodel=self.submodel)

        self.print_network(self.G, 'G')
        self.G.to(self.device)

    def decoder_load(self, dictionary_name):
        '''Load the atom and bond decoders (paths are relative to the repo root).'''
        with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
            return pickle.load(f)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, submodel, model_directory):
        """Restore the trained generator and discriminator."""
        print('Loading the model...')
        G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

    def inference(self):
        # Load the trained generator.
        self.restore_model(self.submodel, self.inference_model)

        # SMILES data for metrics calculation.
        chembl_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
        chembl_test = [line for line in open("data/chembl_test.smi", 'r').read().splitlines()]
        drug_smiles = [line for line in open("data/akt_inhibitors.smi", 'r').read().splitlines()]
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]

        # Make directories if they do not exist.
        if not os.path.exists("experiments/inference/{}".format(self.submodel)):
            os.makedirs("experiments/inference/{}".format(self.submodel))

        self.G.eval()

        start_time = time.time()
        metric_calc_dr = []
        uniqueness_calc = []
        real_smiles_snn = []
        nodes_sample = torch.Tensor(size=[1, 45, 1]).to(self.device)

        val_counter = 0
        none_counter = 0
        # Inference mode
        with torch.inference_mode():
            pbar = tqdm(range(self.sample_num))
            pbar.set_description('Inference mode for {} model started'.format(self.submodel))
            for i, data in enumerate(self.inf_loader):
                val_counter += 1
                # Preprocess dataset
                _, a_tensor, x_tensor = load_molecules(
                    data=data,
                    batch_size=self.inf_batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)

                g_edges_hat_sample = torch.max(edge_sample, -1)[1]
                g_nodes_hat_sample = torch.max(node_sample, -1)[1]

                fake_mol_g = [self.inf_dataset.matrices2mol_drugs(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
                              for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

                a_tensor_sample = torch.max(a_tensor, -1)[1]
                x_tensor_sample = torch.max(x_tensor, -1)[1]
                real_mols = [self.inf_dataset.matrices2mol_drugs(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
                             for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

                # Keep the largest fragment of each generated SMILES.
                inference_drugs = [None if line is None else Chem.MolToSmiles(line) for line in fake_mol_g]
                inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]

                for molecules in inference_drugs:
                    if molecules is None:
                        none_counter += 1

                with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
                    for molecules in inference_drugs:
                        if molecules is not None:
                            # Post-process: replace wildcard (*) atoms with carbon in valid molecules.
                            molecules = molecules.replace("*", "C")
                            f.write(molecules)
                            f.write("\n")
                            uniqueness_calc.append(molecules)
                            nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, 45, 1)), 0)
                            pbar.update(1)
                        metric_calc_dr.append(molecules)

                generation_number = len([x for x in metric_calc_dr if x is not None])
                if generation_number == self.sample_num or none_counter == self.sample_num:
                    break
                real_smiles_snn.append(real_mols[0])

        et = time.time() - start_time
        gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
        real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
        print("Inference took {:.2f} seconds".format(et))

        print("Metrics calculation started using MOSES.")

        return {
            "Runtime (seconds)": round(et, 2),
            "Validity": fraction_valid(metric_calc_dr),
            "Uniqueness": fraction_unique(uniqueness_calc),
            "Novelty (Train)": novelty(metric_calc_dr, chembl_smiles),
            "Novelty (Inference)": novelty(metric_calc_dr, chembl_test),
            "Novelty (AKT)": novelty(metric_calc_dr, drug_smiles),
            "MaxLen": Metrics.max_component(uniqueness_calc, self.vertexes),
            "MeanAtomType": Metrics.mean_atom_type(nodes_sample),
            "SNN (ChEMBL)": average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)),
            "SNN (AKT)": average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)),
        }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Inference configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose the model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--inference_model', type=str, help="Path to the model for inference")
    parser.add_argument('--sample_num', type=int, default=10000, help='Number of inference samples')

    # Data configuration.
    parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
    parser.add_argument('--inf_raw_file', type=str, default='data/chembl_test.smi')
    parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', type=str2bool, default=False, help='Whether to use additional node features')

    # Model configuration.
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='Dropout rate')

    # Seed configuration.
    parser.add_argument('--set_seed', type=str2bool, default=False, help='Set a seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='Seed for reproducibility')

    config = parser.parse_args()
    inference = Inference(config)
    inference.inference()
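Beyond the CLI above, the `Inference` class only needs an object exposing the expected attributes, so it can be scripted directly. A minimal sketch using `types.SimpleNamespace` (assumes trained checkpoints under `experiments/models` and the data files under `data/` are in place):

```python
from types import SimpleNamespace

from inference import Inference

config = SimpleNamespace(
    submodel="DrugGEN", inference_model="experiments/models/DrugGEN",
    sample_num=100, inf_dataset_file="chembl45_test.pt",
    inf_raw_file="data/chembl_test.smi", inf_batch_size=1, mol_data_dir="data",
    features=False, act="relu", max_atom=45, dim=32, depth=1, heads=8,
    mlp_ratio=3, dropout=0., set_seed=True, seed=42,
)
scores = Inference(config).inference()
print(scores["Validity"], scores["Uniqueness"], scores["Novelty (Train)"])
```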
layers.py ADDED
@@ -0,0 +1,106 @@
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class MLP(nn.Module):
    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
        super().__init__()

        if not hid_feat:
            hid_feat = in_feat
        if not out_feat:
            out_feat = in_feat

        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = torch.nn.ReLU()
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return self.droprateout(x)


class Attention_new(nn.Module):
    def __init__(self, dim, heads, attention_dropout=0.):
        super().__init__()

        assert dim % heads == 0

        self.heads = heads
        self.scale = 1. / dim**0.5
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.e = nn.Linear(dim, dim)
        self.d_k = dim // heads
        self.out_e = nn.Linear(dim, dim)
        self.out_n = nn.Linear(dim, dim)

    def forward(self, node, edge):
        b, n, c = node.shape

        # Project nodes and edges into per-head embeddings.
        q_embed = self.q(node).view(-1, n, self.heads, c // self.heads)
        k_embed = self.k(node).view(-1, n, self.heads, c // self.heads)
        v_embed = self.v(node).view(-1, n, self.heads, c // self.heads)
        e_embed = self.e(edge).view(-1, n, n, self.heads, c // self.heads)

        q_embed = q_embed.unsqueeze(2)
        k_embed = k_embed.unsqueeze(1)

        # Pairwise node-node attention scores, modulated by the edge embeddings.
        attn = q_embed * k_embed
        attn = attn / math.sqrt(self.d_k)
        attn = attn * (e_embed + 1) * e_embed

        edge = self.out_e(attn.flatten(3))

        attn = F.softmax(attn, dim=2)

        v_embed = v_embed.unsqueeze(1)
        v_embed = attn * v_embed
        v_embed = v_embed.sum(dim=2).flatten(2)

        node = self.out_n(v_embed)
        return node, edge


class Encoder_Block(nn.Module):
    def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
        super().__init__()

        self.ln1 = nn.LayerNorm(dim)
        self.attn = Attention_new(dim, heads, drop_rate)
        self.ln3 = nn.LayerNorm(dim)
        self.ln4 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
        self.mlp2 = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
        self.ln5 = nn.LayerNorm(dim)
        self.ln6 = nn.LayerNorm(dim)

    def forward(self, x, y):
        # Pre-norm attention with residual connections on both the node (x) and edge (y) streams.
        x1 = self.ln1(x)
        x2, y1 = self.attn(x1, y)
        x2 = x1 + x2
        y2 = y1 + y
        x2 = self.ln3(x2)
        y2 = self.ln4(y2)
        x = self.ln5(x2 + self.mlp(x2))
        y = self.ln6(y2 + self.mlp2(y2))
        return x, y


class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
        super().__init__()

        self.Encoder_Blocks = nn.ModuleList([
            Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
            for i in range(depth)])

    def forward(self, x, y):
        for Encoder_Block in self.Encoder_Blocks:
            x, y = Encoder_Block(x, y)
        return x, y
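The encoder keeps separate node and edge streams with the same embedding width, so its input/output shapes are easy to check in isolation. A quick smoke test (note that the `act` argument is accepted but unused; `MLP` hardcodes ReLU):

```python
import torch

from layers import TransformerEncoder

enc = TransformerEncoder(dim=32, depth=2, heads=8, act="relu", mlp_ratio=3, drop_rate=0.)
nodes = torch.randn(4, 45, 32)      # (batch, vertexes, dim)
edges = torch.randn(4, 45, 45, 32)  # (batch, vertexes, vertexes, dim)
nodes_out, edges_out = enc(nodes, edges)
print(nodes_out.shape, edges_out.shape)  # shapes are preserved by every block
```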
loss.py ADDED
@@ -0,0 +1,36 @@
import torch


def discriminator_loss(generator, discriminator, mol_graph, batch_size, device, grad_pen, lambda_gp, z_edge, z_node):
    # Compute loss with real molecules.
    logits_real_disc = discriminator(mol_graph)
    prediction_real = -torch.mean(logits_real_disc)

    # Compute loss with fake molecules.
    node, edge, node_sample, edge_sample = generator(z_edge, z_node)
    graph = torch.cat((node_sample.view(batch_size, -1), edge_sample.view(batch_size, -1)), dim=-1)
    logits_fake_disc = discriminator(graph.detach())
    prediction_fake = torch.mean(logits_fake_disc)

    # Compute the gradient penalty on an interpolation of real and fake graphs.
    eps = torch.rand(mol_graph.size(0), 1).to(device)
    x_int0 = (eps * mol_graph + (1. - eps) * graph).requires_grad_(True)
    grad0 = discriminator(x_int0)
    d_loss_gp = grad_pen(grad0, x_int0)

    # Calculate the total (WGAN-GP style) loss.
    d_loss = prediction_fake + prediction_real + d_loss_gp * lambda_gp
    return node, edge, d_loss


def generator_loss(generator, discriminator, adj, annot, batch_size):
    # Compute loss with fake molecules.
    node, edge, node_sample, edge_sample = generator(adj, annot)

    graph = torch.cat((node_sample.view(batch_size, -1), edge_sample.view(batch_size, -1)), dim=-1)

    logits_fake_disc = discriminator(graph)
    prediction_fake = -torch.mean(logits_fake_disc)
    g_loss = prediction_fake

    return g_loss, node, edge, node_sample, edge_sample
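The generator loss is the usual Wasserstein generator objective on the flattened graph; the discriminator loss additionally needs a real-graph batch and a `grad_pen` callable, both supplied by the training loop (not part of this commit). A sketch wiring `generator_loss` to the models from this upload:

```python
import torch

from models import Generator, simple_disc
from loss import generator_loss

batch, vertexes, atom_types, bond_types = 8, 45, 10, 5
G = Generator(act="relu", vertexes=vertexes, edges=bond_types, nodes=atom_types,
              dropout=0., dim=32, depth=1, heads=8, mlp_ratio=3, submodel="DrugGEN")
D = simple_disc(act="tanh", m_dim=atom_types, vertexes=vertexes, b_dim=bond_types)

z_edge = torch.randn(batch, vertexes, vertexes, bond_types)  # edge noise
z_node = torch.randn(batch, vertexes, atom_types)            # node noise
g_loss, node, edge, node_sample, edge_sample = generator_loss(G, D, z_edge, z_node, batch)
g_loss.backward()  # gradients flow into the generator through the frozen-in-step critic
```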
models.py ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn

from layers import TransformerEncoder


class Generator(nn.Module):
    """Generator network."""

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio, submodel):
        super(Generator, self).__init__()
        self.submodel = submodel
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
        self.pos_enc_dim = 5

        self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
        self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
        self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act=act,
                                                     mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def laplacian_positional_enc(self, adj):
        A = adj
        D = torch.diag(torch.count_nonzero(A, dim=-1))
        L = torch.eye(A.shape[0], device=A.device) - D * A * D

        EigVal, EigVec = torch.linalg.eig(L)
        idx = torch.argsort(torch.real(EigVal))
        EigVal, EigVec = EigVal[idx], torch.real(EigVec[:, idx])
        pos_enc = EigVec[:, 1:self.pos_enc_dim + 1]
        return pos_enc

    def forward(self, z_e, z_n):
        b, n, c = z_n.shape
        _, _, _, d = z_e.shape

        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize the edge embeddings so the adjacency stays undirected.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        node, edge = self.TransformerEncoder(node, edge)

        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)
        return node, edge, node_sample, edge_sample


class simple_disc(nn.Module):
    def __init__(self, act, m_dim, vertexes, b_dim):
        super().__init__()

        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()

        features = vertexes * m_dim + vertexes * vertexes * b_dim
        self.predictor = nn.Sequential(nn.Linear(features, 256), act, nn.Linear(256, 128), act, nn.Linear(128, 64), act,
                                       nn.Linear(64, 32), act, nn.Linear(32, 16), act,
                                       nn.Linear(16, 1))

    def forward(self, x):
        prediction = self.predictor(x)
        return prediction
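A shape check for the generator: the readout heads map the transformer width back to atom- and bond-type logits, so the outputs line up with the one-hot inputs:

```python
import torch

from models import Generator

G = Generator(act="relu", vertexes=45, edges=5, nodes=10, dropout=0.,
              dim=32, depth=1, heads=8, mlp_ratio=3, submodel="DrugGEN")
z_node = torch.randn(2, 45, 10)     # (batch, vertexes, atom types)
z_edge = torch.randn(2, 45, 45, 5)  # (batch, vertexes, vertexes, bond types)
node, edge, node_sample, edge_sample = G(z_edge, z_node)
print(node_sample.shape)  # torch.Size([2, 45, 10])
print(edge_sample.shape)  # torch.Size([2, 45, 45, 5])
```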
new_dataloader.py ADDED
@@ -0,0 +1,311 @@
import pickle
import os.path as osp
import re

import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit import RDLogger
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

RDLogger.DisableLog('rdApp.*')


class DruggenDataset(InMemoryDataset):

    def __init__(self, root, dataset_file, raw_files, max_atom, features, transform=None, pre_transform=None, pre_filter=None):
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features
        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, dataset_file)
        self.data, self.slices = torch.load(path)
        self.root = root

    @property
    def processed_dir(self):
        return self.root

    @property
    def raw_file_names(self):
        return self.raw_files

    @property
    def processed_file_names(self):
        return self.dataset_file

    def _generate_encoders_decoders(self, data):
        self.data = data
        print('Creating atoms and bonds encoder and decoder..')

        atom_labels = set()
        bond_labels = set()
        max_length = 0
        smiles_list = []
        for smiles in tqdm(data):
            mol = Chem.MolFromSmiles(smiles)
            molecule_size = mol.GetNumAtoms()
            if molecule_size > self.max_atom:
                continue
            smiles_list.append(smiles)
            atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
            max_length = max(max_length, molecule_size)
            bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])

        atom_labels.update([0])  # add PAD symbol (for unknown atoms)
        atom_labels = sorted(atom_labels)  # turn set into list and sort it

        bond_labels = sorted(bond_labels)
        bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels

        self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
        self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
        self.atom_num_types = len(atom_labels)
        print('Created atoms encoder and decoder with {} atom types and 1 PAD symbol!'.format(
            self.atom_num_types - 1))
        print("atom_labels", atom_labels)

        print("bond_labels", bond_labels)
        self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
        self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
        self.bond_num_types = len(bond_labels)
        print('Created bonds encoder and decoder with {} bond types and 1 PAD symbol!'.format(
            self.bond_num_types - 1))

        # Encoder/decoder paths are relative to the repo root.
        with open("data/encoders/" + "atom_" + self.dataset_name + ".pkl", "wb") as atom_encoders:
            pickle.dump(self.atom_encoder_m, atom_encoders)

        with open("data/decoders/" + "atom_" + self.dataset_name + ".pkl", "wb") as atom_decoders:
            pickle.dump(self.atom_decoder_m, atom_decoders)

        with open("data/encoders/" + "bond_" + self.dataset_name + ".pkl", "wb") as bond_encoders:
            pickle.dump(self.bond_encoder_m, bond_encoders)

        with open("data/decoders/" + "bond_" + self.dataset_name + ".pkl", "wb") as bond_decoders:
            pickle.dump(self.bond_decoder_m, bond_decoders)

        return max_length, smiles_list  # data is filtered now

    def _genA(self, mol, connected=True, max_length=None):
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        A = np.zeros(shape=(max_length, max_length))

        begin, end = [b.GetBeginAtomIdx() for b in mol.GetBonds()], [b.GetEndAtomIdx() for b in mol.GetBonds()]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]

        A[begin, end] = bond_type
        A[end, begin] = bond_type

        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)

        return A if connected and (degree > 0).all() else None

    def _genX(self, mol, max_length=None):
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
            max_length - mol.GetNumAtoms()))

    def _genF(self, mol, max_length=None):
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                              *[a.GetExplicitValence() == i for i in range(9)],
                              *[int(a.GetHybridization()) == i for i in range(1, 7)],
                              *[a.GetImplicitValence() == i for i in range(9)],
                              a.GetIsAromatic(),
                              a.GetNoImplicit(),
                              *[a.GetNumExplicitHs() == i for i in range(5)],
                              *[a.GetNumImplicitHs() == i for i in range(5)],
                              *[a.GetNumRadicalElectrons() == i for i in range(5)],
                              a.IsInRing(),
                              *[a.IsInRingSize(i) for i in range(2, 9)]] for a in mol.GetAtoms()], dtype=np.int32)

        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name, file):
        with open("data/decoders/" + dictionary_name + "_" + file + '.pkl', 'rb') as f:
            return pickle.load(f)

    def drugs_decoder_load(self, dictionary_name):
        with open("data/decoders/" + dictionary_name + '.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')
        atom_decoders = self.decoder_load("atom", file_name)
        bond_decoders = self.decoder_load("bond", file_name)

        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))

        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])

        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                mol = None

        return mol

    def drug_decoder_load(self, dictionary_name, file):
        '''Load the atom and bond decoders.'''
        with open("data/decoders/" + dictionary_name + "_" + file + '.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol_drugs(self, node_labels, edge_labels, strict=True, file_name=None):
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')
        atom_decoders = self.drug_decoder_load("atom", file_name)
        bond_decoders = self.drug_decoder_load("bond", file_name)

        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))

        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])

        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                mol = None

        return mol

    def check_valency(self, mol):
        """
        Checks that no atoms in the mol have exceeded their possible valency.
        :return: True if no valency issues, False otherwise
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            e = str(e)
            p = e.find('#')
            e_sub = e[p:]
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self, x):
        mol = x
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            else:
                assert len(atomid_valence) == 2
                idx = atomid_valence[0]
                queue = []
                # Remove the highest-order bond on the offending atom.
                for b in mol.GetAtomWithIdx(idx).GetBonds():
                    queue.append(
                        (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                    )
                queue.sort(key=lambda tup: tup[1], reverse=True)
                if len(queue) > 0:
                    start = queue[0][2]
                    end = queue[0][3]
                    mol.RemoveBond(start, end)

        return mol

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        out = torch.zeros(list(labels.size()) + [dim])
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out.float()

    def process(self, size=None):
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
        max_length, smiles_list = self._generate_encoders_decoders(smiles_list)

        data_list = []

        self.m_dim = len(self.atom_decoder_m)
        for smiles in tqdm(smiles_list, desc='Processing chembl dataset', total=len(smiles_list)):
            mol = Chem.MolFromSmiles(smiles)
            A = self._genA(mol, connected=True, max_length=max_length)
            if A is not None:
                x = torch.from_numpy(self._genX(mol, max_length=max_length)).to(torch.long).view(1, -1)
                x = self.label2onehot(x, self.m_dim).squeeze()
                if self.features:
                    f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                    x = torch.concat((x, f), dim=-1)

                adjacency = torch.from_numpy(A)

                edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
                edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)

                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))


if __name__ == '__main__':
    # Example invocation; the constructor needs the root, dataset name, raw SMILES file,
    # max atom count, and the features flag (values here match the configs in gradio_app.py).
    data = DruggenDataset("data", "chembl45_train.pt", "data/chembl_train.smi", 45, False)
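Building the dataset end to end needs only the raw SMILES file: `process()` is triggered by the `InMemoryDataset` parent when the processed `.pt` file is missing, and it also writes the encoder/decoder pickles. A sketch (paths are assumptions matching the configs above; the raw file and the `data/encoders`, `data/decoders` directories must exist):

```python
from torch_geometric.loader import DataLoader

from new_dataloader import DruggenDataset

dataset = DruggenDataset(root="data",
                         dataset_file="chembl45_train.pt",
                         raw_files="data/chembl_train.smi",
                         max_atom=45,
                         features=False)
loader = DataLoader(dataset, batch_size=128, shuffle=True)
batch = next(iter(loader))
print(batch.x.shape, batch.edge_index.shape, batch.edge_attr.shape)
```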
packages.txt ADDED
@@ -0,0 +1 @@
libcairo2-dev
requirements.txt ADDED
@@ -0,0 +1,12 @@
torch
rdkit-pypi
tqdm
numpy
seaborn
matplotlib
pandas
torch_geometric
wandb  # imported by utils.py
# demo related installs
streamlit
ipython
streamlit-ext
gradio  # imported by gradio_app.py
training_data.py ADDED
@@ -0,0 +1,31 @@
import torch
import torch_geometric.utils as geoutils

from utils import label2onehot


def generate_z_values(batch_size=32, z_dim=32, vertexes=32, b_dim=32, m_dim=32, device=None):
    z = torch.normal(mean=0, std=1, size=(batch_size, z_dim), device=device)                          # (batch, z_dim)
    z_edge = torch.normal(mean=0, std=1, size=(batch_size, vertexes, vertexes, b_dim), device=device)  # edge noise: (batch, max_len, max_len, b_dim)
    z_node = torch.normal(mean=0, std=1, size=(batch_size, vertexes, m_dim), device=device)            # node noise: (batch, max_len, m_dim)

    z = z.float().requires_grad_(True)
    z_edge = z_edge.float().requires_grad_(True)
    z_node = z_node.float().requires_grad_(True)
    return z, z_edge, z_node


def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    data = data.to(device)
    a = geoutils.to_dense_adj(
        edge_index=data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=int(data.batch.shape[0] / batch_size)
    )
    x_tensor = data.x.view(batch_size, int(data.batch.shape[0] / batch_size), -1)
    a_tensor = label2onehot(a, b_dim, device)

    a_tensor_vec = a_tensor.reshape(batch_size, -1)
    x_tensor_vec = x_tensor.reshape(batch_size, -1)
    real_graphs = torch.concat((x_tensor_vec, a_tensor_vec), dim=-1)

    return real_graphs, a_tensor, x_tensor
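`generate_z_values` is self-contained, which makes the noise shapes easy to verify (importing `training_data` pulls in `utils`, so the full requirements must be installed):

```python
import torch

from training_data import generate_z_values

z, z_edge, z_node = generate_z_values(batch_size=8, z_dim=16, vertexes=45,
                                      b_dim=5, m_dim=10, device=torch.device("cpu"))
print(z.shape)       # torch.Size([8, 16])
print(z_edge.shape)  # torch.Size([8, 45, 45, 5])
print(z_node.shape)  # torch.Size([8, 45, 10])
```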
utils.py ADDED
@@ -0,0 +1,421 @@
from statistics import mean
import os
import math
import time
import datetime
import warnings
from multiprocessing import Pool

from rdkit import DataStructs
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Scaffolds import MurckoScaffold
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import torch
import wandb

RDLogger.DisableLog('rdApp.*')


class Metrics(object):

    @staticmethod
    def valid(x):
        return x is not None and Chem.MolToSmiles(x) != ''

    @staticmethod
    def tanimoto_sim_1v2(data1, data2):
        min_len = min(data1.size, data2.size)  # compare up to the shorter fingerprint list
        sims = []
        for i in range(min_len):
            sim = DataStructs.FingerprintSimilarity(data1[i], data2[i])
            sims.append(sim)
        mean_sim = mean(sims)
        return mean_sim

    @staticmethod
    def mol_length(x):
        if x is not None:
            # Count the atoms (alphabetic characters) of the largest fragment.
            return len([char for char in max(x.split("."), key=len).upper() if char.isalpha()])
        else:
            return 0

    @staticmethod
    def max_component(data, max_len):
        return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32) / max_len).mean()

    @staticmethod
    def mean_atom_type(data):
        atom_types_used = []
        for i in data:
            atom_types_used.append(len(i.unique().tolist()))
        av_type = np.mean(atom_types_used) - 1
        return av_type


def sim_reward(mol_gen, fps_r):
    gen_scaf = []

    for x in mol_gen:
        if x is not None:
            try:
                gen_scaf.append(MurckoScaffold.GetScaffoldForMol(x))
            except Exception:
                pass

    if len(gen_scaf) == 0:
        rew = 1
    else:
        fps = [Chem.RDKFingerprint(x) for x in gen_scaf]

        fps = np.array(fps)
        fps_r = np.array(fps_r)

        rew = average_agg_tanimoto(fps_r, fps)
        if math.isnan(rew):
            rew = 1

    return rew  # TODO: change this to a penalty


##########################################


def mols2grid_image(mols, path):
    mols = [e if e is not None else Chem.RWMol() for e in mols]

    for i in range(len(mols)):
        if Metrics.valid(mols[i]):
            AllChem.Compute2DCoords(mols[i])
            Draw.MolToFile(mols[i], os.path.join(path, "{}.png".format(i + 1)), size=(1200, 1200))
        else:
            continue


def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
    mols = [e if e is not None else Chem.RWMol() for e in mols]

    for i in range(len(mols)):
        if Metrics.valid(mols[i]):
            save_path = os.path.join(path, "{}.txt".format(i + 1))
            with open(save_path, "a") as f:
                np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f')
                f.write("\n")
                np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
                f.write("\n")
            print(Chem.MolToSmiles(mols[i]), file=open(save_path, "a"))
        else:
            continue


##########################################


def dense_to_sparse_with_attr(adj):
    assert adj.dim() >= 2 and adj.dim() <= 3
    assert adj.size(-1) == adj.size(-2)

    index = adj.nonzero(as_tuple=True)
    edge_attr = adj[index]

    if len(index) == 3:
        batch = index[0] * adj.size(-1)
        index = (batch + index[1], batch + index[2])
    return index, edge_attr


def label2onehot(labels, dim, device):
    """Convert label indices to one-hot vectors."""
    out = torch.zeros(list(labels.size()) + [dim]).to(device)
    out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
    return out.float()


def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name):
    sample_path = os.path.join(sample_directory, "{}_{}-epoch_iteration".format(idx + 1, i + 1))
    g_edges_hat_sample = torch.max(edges, -1)[1]
    g_nodes_hat_sample = torch.max(nodes, -1)[1]
    mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name)
           for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

    if not os.path.exists(sample_path):
        os.makedirs(sample_path)

    mols2grid_image(mol, sample_path)
    save_smiles_matrices(mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), sample_path)

    if len(os.listdir(sample_path)) == 0:
        os.rmdir(sample_path)

    print("Valid molecules are saved.")
    print("Valid matrices and SMILES are saved.")


def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node,
            matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):

    g_edges_hat_sample = torch.max(edge, -1)[1]
    g_nodes_hat_sample = torch.max(node, -1)[1]

    a_tensor_sample = torch.max(real_adj, -1)[1].float()
    x_tensor_sample = torch.max(real_annot, -1)[1].float()

    mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name)
            for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

    real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name)
                for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

    atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample)
    real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None]
    gen_smiles = []
    uniq_smiles = []
    for line in mols:
        if line is not None:
            gen_smiles.append(Chem.MolToSmiles(line))
            uniq_smiles.append(Chem.MolToSmiles(line))
        elif line is None:
            gen_smiles.append(None)

    gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
    uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles]

    sample_save_dir = os.path.join(save_path, "samples.txt")
    with open(sample_save_dir, "a") as f:
        for idxs in range(len(gen_smiles_saves)):
            if gen_smiles_saves[idxs] is not None:
                f.write(gen_smiles_saves[idxs])
                f.write("\n")

    k = len(set(uniq_smiles_saves) - {None})
    et = time.time() - start_time
    et = str(datetime.timedelta(seconds=et))[:-7]
    log = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i + 1)
    gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None]
    chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None]

    # Log update
    valid = fraction_valid(gen_smiles_saves)
    unique = fraction_unique(uniq_smiles_saves, k, check_validity=False)
    novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
    novel_akt = novelty(gen_smiles_saves, drug_smiles)
    if len(uniq_smiles_saves) == 0:
        snn_chembl = 0
        snn_akt = 0
        maxlen = 0
    else:
        snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs))
        snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs))
        maxlen = Metrics.max_component(uniq_smiles_saves, 45)

    loss.update({'Validity': valid})
    loss.update({'Uniqueness': unique})
    loss.update({'Novelty': novel_starting_mol})
    loss.update({'Novelty_akt': novel_akt})
    loss.update({'SNN_chembl': snn_chembl})
    loss.update({'SNN_akt': snn_akt})
    loss.update({'MaxLen': maxlen})
    loss.update({'Atom_types': atom_types_average})

    wandb.log({"Validity": valid, "Uniqueness": unique, "Novelty": novel_starting_mol,
               "Novelty_akt": novel_akt, "SNN_chembl": snn_chembl, "SNN_akt": snn_akt,
               "MaxLen": maxlen, "Atom_types": atom_types_average})

    for tag, value in loss.items():
        log += ", {}: {:.4f}".format(tag, value)
    with open(log_path, "a") as f:
        f.write(log)
        f.write("\n")
    print(log)
    print("\n")


def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
    # Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: Plug this function in the Trainer class after loss.backwards() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow.'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().cpu())
            max_grads.append(p.grad.abs().max().cpu())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=1)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    pltsavedir = grad_flow_directory
    plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight')


def get_mol(smiles_or_mol):
    '''
    Loads a SMILES string / molecule into an RDKit object.
    '''
    if isinstance(smiles_or_mol, str):
        if len(smiles_or_mol) == 0:
            return None
        mol = Chem.MolFromSmiles(smiles_or_mol)
        if mol is None:
            return None
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return mol
    return smiles_or_mol


def mapper(n_jobs):
    '''
    Returns a function for a map call.
    If n_jobs == 1, will use the standard map.
    If n_jobs > 1, will use a multiprocessing pool.
    If n_jobs is a pool object, will return its map function.
    '''
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _mapper
    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)

        def _mapper(*args, **kwargs):
            try:
                result = pool.map(*args, **kwargs)
            finally:
                pool.terminate()
            return result

        return _mapper
    return n_jobs.map


def remove_invalid(gen, canonize=True, n_jobs=1):
    """
    Removes invalid molecules from the dataset.
    """
    if not canonize:
        mols = mapper(n_jobs)(get_mol, gen)
        return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
    return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None]


def fraction_valid(gen, n_jobs=1):
    """
    Computes the fraction of valid molecules.
    Parameters:
        gen: list of SMILES
        n_jobs: number of threads for calculation
    """
    gen = mapper(n_jobs)(get_mol, gen)
    return 1 - gen.count(None) / len(gen)


def canonic_smiles(smiles_or_mol):
    mol = get_mol(smiles_or_mol)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol)


def fraction_unique(gen, k=None, n_jobs=1, check_validity=False):
    """
    Computes the fraction of unique molecules.
    Parameters:
        gen: list of SMILES
        k: compute unique@k
        n_jobs: number of threads for calculation
        check_validity: raises ValueError if invalid molecules are present
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}. ".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return 0 if len(gen) == 0 else len(canonic) / len(gen)


def novelty(gen, train, n_jobs=1):
    gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
    gen_smiles_set = set(gen_smiles) - {None}
    train_set = set(train)
    return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)


def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    For each molecule in gen_vecs finds the closest molecule in stock_vecs.
    Returns the average Tanimoto score between these molecules.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        agg: max or mean
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            tp = torch.mm(x_stock, y_gen)
            jac = (tp / (x_stock.sum(1, keepdim=True) +
                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac**p
            if agg == 'max':
                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
            elif agg == 'mean':
                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
                total[i:i + y_gen.shape[1]] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = (agg_tanimoto)**(1 / p)
    return np.mean(agg_tanimoto)


def str2bool(v):
    return v.lower() == 'true'
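The MOSES-style metrics operate on plain SMILES lists, so they can be sanity-checked on toy data (importing `utils` requires the dependencies in requirements.txt, including wandb):

```python
from utils import fraction_valid, fraction_unique, novelty

gen = ["CCO", "OCC", "c1ccccc1", "not_a_smiles"]  # "OCC" canonicalizes to "CCO"
train = ["CCO"]

print(fraction_valid(gen))   # 0.75 -- one unparsable string
print(fraction_unique(gen))  # 0.75 -- canonical duplicates collapse; None counts once
print(novelty(gen, train))   # 0.5  -- benzene is novel, ethanol is in the training set
```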