tagirshin commited on
Commit
112efaf
1 Parent(s): faa0c0b

Initial commit (test)

Browse files
ga_app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import pickle
5
+ import pygad
6
+
7
+ from tqdm.auto import tqdm
8
+ from VQGAE.models import VQGAE, OrderingNetwork
9
+ from CGRtools.containers import QueryContainer
10
+ from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
11
+
12
+ # define groups to filter
13
+ allene = QueryContainer()
14
+ allene.add_atom("C")
15
+ allene.add_atom("A")
16
+ allene.add_atom("A")
17
+ allene.add_bond(1, 2, 2)
18
+ allene.add_bond(1, 3, 2)
19
+
20
+ peroxide_charge = QueryContainer()
21
+ peroxide_charge.add_atom("O", charge=-1)
22
+ peroxide_charge.add_atom("O")
23
+ peroxide_charge.add_bond(1, 2, 1)
24
+
25
+ peroxide = QueryContainer()
26
+ peroxide.add_atom("O")
27
+ peroxide.add_atom("O")
28
+ peroxide.add_bond(1, 2, 1)
29
+
30
+
31
+ def tanimoto_kernel(x, y):
32
+ """
33
+ "The Tanimoto coefficient is a measure of the similarity between two sets.
34
+ It is defined as the size of the intersection divided by the size of the union of the sample sets."
35
+
36
+ The Tanimoto coefficient is also known as the Jaccard index
37
+
38
+ Adoppted from https://github.com/cimm-kzn/CIMtools/blob/master/CIMtools/metrics/pairwise.py
39
+
40
+ :param x: 2D array of features.
41
+ :param y: 2D array of features.
42
+ :return: The Tanimoto coefficient between the two arrays.
43
+ """
44
+ x_dot = np.dot(x, y.T)
45
+
46
+ x2 = (x ** 2).sum(axis=1)
47
+ y2 = (y ** 2).sum(axis=1)
48
+
49
+ len_x2 = len(x2)
50
+ len_y2 = len(y2)
51
+
52
+ result = x_dot / (np.array([x2] * len_y2).T + np.array([y2] * len_x2) - x_dot)
53
+ result[np.isnan(result)] = 0
54
+
55
+ return result
56
+
57
+
58
+ def rescoring(vqgae_latents):
59
+ frag_counts = np.array(vqgae_latents)
60
+ rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
61
+ similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
62
+
63
+ frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
64
+ _, ordering_scores = restore_order(frag_inds, ordering_model)
65
+ return rf_scores.tolist(), similarity_scores.tolist(), ordering_scores
66
+
67
+
68
+ def fitness_func_batch(ga_instance, solutions, solutions_indices):
69
+ frag_counts = np.array(solutions)
70
+
71
+ # prediction of activity by random forest
72
+ rf_score = rf_model.predict_proba(frag_counts)[:, 1]
73
+
74
+ # size penalty if molecule too small
75
+ mol_size = frag_counts.sum(-1).astype(np.int64)
76
+ size_penalty = np.where(mol_size < 18, -1.0, 0.)
77
+
78
+ # adding dissimilarity so it generates different solutions
79
+ dissimilarity_score = 1 - tanimoto_kernel(frag_counts, X).max(-1)
80
+ dissimilarity_score += np.where(dissimilarity_score == 0, -5, 0)
81
+
82
+ # prediction of ordering score
83
+ frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
84
+ _, ordering_scores = restore_order(frag_inds, ordering_model)
85
+ ordering_scores = np.array(ordering_scores)
86
+
87
+ # full fitness function
88
+ fitness = 0.5 * rf_score + 0.3 * dissimilarity_score + size_penalty + 0.2 * ordering_scores
89
+ return fitness.tolist()
90
+
91
+
92
+ def on_generation_progress(ga):
93
+ pbar.update(1)
94
+
95
+
96
+ @st.cache_data
97
+ def load_data(batch_size):
98
+ X = np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz")["x"]
99
+ Y = np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz")["y"]
100
+ with open("saved_model/rf_class_train_tubulin.pickle", "rb") as inp:
101
+ rf_model = pickle.load(inp)
102
+
103
+ vqgae_model = VQGAE.load_from_checkpoint("saved_model/vqgae.ckpt", task="decode", batch_size=batch_size)
104
+ vqgae_model = vqgae_model.to("cpu").eval()
105
+
106
+ ordering_model = OrderingNetwork.load_from_checkpoint("saved_model/ordering_network.ckpt", batch_size=batch_size)
107
+ ordering_model = ordering_model.to("cpu").eval()
108
+ return X, Y, rf_model, vqgae_model, ordering_model
109
+
110
+
111
+ st.title('Inverse QSAR of Tubulin inhibitors in colchicine site with VQGAE')
112
+
113
+ data_load_state = st.text('Loading data...')
114
+ batch_size = 500
115
+ X, Y, rf_model, vqgae_model, ordering_model = load_data(batch_size)
116
+
117
+ data_load_state.text("Done! (using st.cache_data)")
118
+
119
+ # initial_pop = X
120
+ #
121
+ # num_parents_mating = int(initial_pop.shape[0] * 0.33 // 10 * 10)
122
+ # keep_parents = int(num_parents_mating * 0.66 // 10 * 10)
123
+ # print(num_parents_mating, keep_parents)
124
+ #
125
+ # num_generations = 30
126
+ # with tqdm(total=num_generations) as pbar:
127
+ # ga_instance = pygad.GA(
128
+ # fitness_func=fitness_func_batch,
129
+ # on_generation=on_generation_progress,
130
+ # initial_population=initial_pop,
131
+ # num_genes=initial_pop.shape[-1],
132
+ # fitness_batch_size=batch_size,
133
+ # num_generations=num_generations,
134
+ # num_parents_mating=num_parents_mating,
135
+ # parent_selection_type="rws",
136
+ # crossover_type="single_point",
137
+ # mutation_type="adaptive",
138
+ # mutation_percent_genes=[10, 5],
139
+ # # https://pygad.readthedocs.io/en/latest/pygad.html#use-adaptive-mutation-in-pygad
140
+ # save_best_solutions=False,
141
+ # save_solutions=True,
142
+ # keep_elitism=0, # turn it off to make keep_parents work
143
+ # keep_parents=keep_parents, # 2/3 of num_parents_mating
144
+ # # parallel_processing=['process', 5],
145
+ # suppress_warnings=True,
146
+ # random_seed=42,
147
+ # gene_type=int
148
+ # )
149
+ # ga_instance.run()
150
+ #
151
+ # solutions = ga_instance.solutions
152
+ # solutions = list(set(tuple(s) for s in solutions))
153
+ # print(len(solutions))
154
+ #
155
+ # scores = {"rf_score": [], "similarity_score": [], "ordering_score": []}
156
+ # for i in tqdm(range(len(solutions) // 100 + 1)):
157
+ # solution = solutions[i * 100: (i + 1) * 100]
158
+ # rf_score, similarity_score, ordering_score = rescoring(solution)
159
+ # scores["rf_score"].extend(rf_score)
160
+ # scores["similarity_score"].extend(similarity_score)
161
+ # scores["ordering_score"].extend(ordering_score)
162
+ #
163
+ # sc_df = pd.DataFrame(scores)
164
+ #
165
+ # chosen_gen = sc_df[(sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5) & (sc_df["ordering_score"] > 0.7)]
166
+ #
167
+ # chosen_ids = chosen_gen.index.to_list()
168
+ # chosen_solutions = np.array([solutions[ind] for ind in chosen_ids])
169
+ # gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
170
+ #
171
+ # gen_molecules = []
172
+ # results = {"score": [], "valid": []}
173
+ # for i in tqdm(range(gen_frag_inds.shape[0] // batch_size + 1)):
174
+ # inputs = gen_frag_inds[i * batch_size: (i + 1) * batch_size]
175
+ # canon_order_inds, scores = restore_order(
176
+ # frag_inds=inputs,
177
+ # ordering_model=ordering_model,
178
+ # )
179
+ # molecules, validity = decode_molecules(
180
+ # ordered_frag_inds=canon_order_inds,
181
+ # vqgae_model=vqgae_model
182
+ # )
183
+ # gen_molecules.extend(molecules)
184
+ # results["score"].extend(scores)
185
+ # results["valid"].extend([1 if i else 0 for i in validity])
186
+ #
187
+ # gen_stats = pd.DataFrame(results)
188
+ # full_stats = pd.concat([chosen_gen.reset_index(), gen_stats], axis=1, ignore_index=False)
189
+ # valid_gen_stats = full_stats[full_stats.valid == 1]
190
+ # valid_gen_mols = []
191
+ # for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
192
+ # mol = gen_molecules[i]
193
+ # mol.meta.update({
194
+ # "rf_score": record["rf_score"],
195
+ # "similarity_score": record["similarity_score"],
196
+ # "ordering_score": record["ordering_score"],
197
+ # })
198
+ # valid_gen_mols.append(mol)
199
+ #
200
+ # filtered_gen_mols = []
201
+ # for mol in valid_gen_mols:
202
+ # is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
203
+ # is_macro = False
204
+ # for ring in mol.sssr:
205
+ # if len(ring) > 8 or len(ring) < 4:
206
+ # is_macro = True
207
+ # break
208
+ # if not is_frag and not is_macro:
209
+ # filtered_gen_mols.append(mol)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.27
2
+ pygad==3.0.1
3
+ torch>2.0
4
+ pytorch-lightning==2.0.2
5
+ pyg==2.3.0
6
+ mendeleev==0.12.1
7
+ networkx>=3.0
8
+ omegaconf>2.0
9
+ cgrtools==4.1.35
10
+ scikit-learn>1.2.0
11
+ numpy>1.24
12
+ py-mini-racer
13
+ git+https://github.com/Laboratoire-de-Chemoinformatique/VQGAE.git
saved_model/ordering_network.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6549254a0ecc70b2b67dc29a30f432539cd05050e9065fc94451ceb61d5f9d5
3
+ size 328718215
saved_model/rf_class_train_tubulin.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec3a569708d585bc6f137503814f1a291d25e5c6687b8faa2ff9f1ce2ad7ff56
3
+ size 3264972
saved_model/tubulin_qsar_class_train_data_vqgae.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dafe72d49cc692eeca8b3b1b865d7e720aa0475b27532ebf3591539b9dbcddc9
3
+ size 19764418
saved_model/vqgae.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f40e3e76afbfef91aa0cc4e0cf72502169e79c0fc9cf44045902f915c489bf54
3
+ size 509183825