|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import pickle |
|
import pygad |
|
|
|
from VQGAE.models import VQGAE, OrderingNetwork |
|
from CGRtools.containers import QueryContainer |
|
from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules |
|
|
|
|
|
# Three substructure query patterns built with CGRtools' QueryContainer.
# None of them is referenced elsewhere in the visible part of this file.
# NOTE(review): the add_bond calls assume atoms are numbered 1, 2, 3, ... in
# insertion order; "A" is presumably the CGRtools any-element wildcard and the
# third add_bond argument the bond order -- confirm against CGRtools docs.

# Carbon (atom 1) bonded with order 2 to each of two "A" atoms (atoms 2, 3).
allene = QueryContainer()
for _symbol in ("C", "A", "A"):
    allene.add_atom(_symbol)
for _neighbour in (2, 3):
    allene.add_bond(1, _neighbour, 2)

# O-O pair joined by an order-1 bond, first oxygen carrying charge -1.
peroxide_charge = QueryContainer()
peroxide_charge.add_atom("O", charge=-1)
peroxide_charge.add_atom("O")
peroxide_charge.add_bond(1, 2, 1)

# O-O pair joined by an order-1 bond, no explicit charges.
peroxide = QueryContainer()
for _ in range(2):
    peroxide.add_atom("O")
peroxide.add_bond(1, 2, 1)
|
|
|
|
|
def tanimoto_kernel(x, y):
    """
    Compute the pairwise Tanimoto coefficient between the rows of two arrays.

    "The Tanimoto coefficient is a measure of the similarity between two sets.
    It is defined as the size of the intersection divided by the size of the
    union of the sample sets."

    The Tanimoto coefficient is also known as the Jaccard index.

    For every row pair ``(a, b)`` the value is
    ``a.b / (|a|^2 + |b|^2 - a.b)``. Pairs for which this evaluates to NaN
    (0 / 0, i.e. both rows are all zeros) are reported as 0.

    Adopted from
    https://github.com/cimm-kzn/CIMtools/blob/master/CIMtools/metrics/pairwise.py

    :param x: 2D array (or array-like) of features, one sample per row.
    :param y: 2D array (or array-like) of features, one sample per row.
    :return: 2D float array of shape ``(len(x), len(y))`` holding the
        Tanimoto coefficient of every row of ``x`` against every row of ``y``.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    x_dot = np.dot(x, y.T)

    x2 = (x ** 2).sum(axis=1)
    y2 = (y ** 2).sum(axis=1)

    # Broadcasting builds the (len(x), len(y)) grid of |a|^2 + |b|^2 directly,
    # without materialising Python lists of repeated arrays.
    denominator = x2[:, np.newaxis] + y2[np.newaxis, :] - x_dot

    # 0 / 0 is expected for pairs of all-zero rows and is mapped to 0 below,
    # so keep the division quiet regardless of the caller's error settings.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = x_dot / denominator
    result[np.isnan(result)] = 0

    return result
|
|
|
|
|
def fitness_func_batch(ga_instance, solutions, solutions_indices):
    """
    Batch fitness callback handed to ``pygad.GA``.

    Reads the file-level globals ``rf_model``, ``X``, ``use_ordering_score``
    and ``ordering_model``. ``ga_instance`` and ``solutions_indices`` are
    accepted to match the callback signature but are not used.

    The fitness of each solution is
    ``0.5 * rf_score + 0.3 * dissimilarity + size_penalty``, plus
    ``0.2 * ordering_score`` when ``use_ordering_score`` is truthy.

    :param ga_instance: the running GA object (unused).
    :param solutions: one solution (1D) or a batch of solutions (2D).
    :param solutions_indices: indices of the solutions (unused).
    :return: list with one fitness value per solution.
    """
    counts = np.array(solutions)
    if counts.ndim == 1:
        # A single solution was passed: promote it to a batch of one.
        counts = counts[np.newaxis, :]

    # Second column of predict_proba, i.e. the probability of class index 1.
    activity = rf_model.predict_proba(counts)[:, 1]

    # Solutions whose genes sum to fewer than 18 get a flat -1 penalty.
    gene_totals = counts.sum(-1).astype(np.int64)
    small_penalty = np.where(gene_totals < 18, -1.0, 0.)

    # Distance to the closest row of X; an exact match with a row of X
    # (dissimilarity of exactly 0) is pushed down by a further -5.
    novelty = 1 - tanimoto_kernel(counts, X).max(-1)
    novelty += np.where(novelty == 0, -5, 0)

    total = 0.5 * activity + 0.3 * novelty + small_penalty

    if use_ordering_score:
        inds = frag_counts_to_inds(counts, max_atoms=51)
        _, order_scores = restore_order(inds, ordering_model)
        order_scores = np.array(order_scores)
        total += 0.2 * order_scores

    return total.tolist()
|
|
|
|
|
def on_generation_progress(ga):
    """
    ``on_generation`` callback for ``pygad.GA``: advance the GA progress bar.

    Increments the file-level counter ``ga_progress`` and redraws ``ga_bar``
    with the share of ``num_generations`` completed so far.

    :param ga: the running GA object (unused).
    """
    global ga_progress
    ga_progress += 1
    # Multiply before the floor division: the previous
    # ``ga_progress // num_generations * 100`` evaluated to 0 for every
    # generation but the last, so the bar never moved. Clamp to 100, the
    # largest integer value the progress bar accepts.
    percent = min(100, ga_progress * 100 // num_generations)
    ga_bar.progress(percent, text=ga_progress_text)
|
|
|
|
|
# NOTE(review): st.cache_data caches by serialising the return value;
# Streamlit suggests st.cache_resource for model objects -- confirm which
# behaviour is intended before changing the decorator.
@st.cache_data
def load_data(batch_size):
    """
    Load the training data and the three saved models from ``saved_model/``.

    :param batch_size: forwarded as ``batch_size`` to both
        ``load_from_checkpoint`` calls.
    :return: tuple ``(X, Y, rf_model, vqgae_model, ordering_model)`` where
        ``X`` and ``Y`` are the ``"x"`` and ``"y"`` arrays of the training
        ``.npz`` archive, ``rf_model`` is the unpickled classifier, and the
        two checkpoint models are loaded on CPU and put through ``.eval()``.
    """
    # Open the archive once and close it deterministically; the previous code
    # opened it twice and never closed either handle.
    with np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz") as train_data:
        X = train_data["x"]
        Y = train_data["y"]

    # SECURITY: pickle.load executes arbitrary code from the file. This is
    # only acceptable because the pickle ships with the application; never
    # point this at user-supplied data.
    with open("saved_model/rf_class_train_tubulin.pickle", "rb") as inp:
        rf_model = pickle.load(inp)

    vqgae_model = VQGAE.load_from_checkpoint(
        "saved_model/vqgae.ckpt",
        task="decode",
        batch_size=batch_size,
        map_location="cpu"
    )
    vqgae_model = vqgae_model.eval()

    ordering_model = OrderingNetwork.load_from_checkpoint(
        "saved_model/ordering_network.ckpt",
        batch_size=batch_size,
        map_location="cpu"
    )
    ordering_model = ordering_model.eval()
    return X, Y, rf_model, vqgae_model, ordering_model
|
|
|
# ---- Page header and cached artefacts ---------------------------------------
st.title('Inverse QSAR of Tubulin inhibitors in colchicine site with VQGAE')

# Passed to load_data and reused further down the script as the GA fitness
# batch size and as the chunk size for rescoring / decoding.
batch_size = 500
X, Y, rf_model, vqgae_model, ordering_model = load_data(batch_size)

# Sanity check on the loaded training matrix.
assert X.shape == (603, 4096)

# ---- Sidebar form with the GA settings --------------------------------------
with st.sidebar, st.form("my_form"):
    num_generations = st.slider(
        'Number of generations for GA',
        min_value=3,
        max_value=40,
        value=5
    )

    # Human-readable label -> pygad parent_selection_type code. The dict's
    # insertion order doubles as the order of the select box options.
    parent_selection_translator = {
        "Steady-state selection": "sss",
        "Roulette wheel selection": "rws",
        "Stochastic universal selection": "sus",
        "Rank selection": "rank",
        "Random selection": "random",
        "Tournament selection": "tournament",
    }
    parent_selection_type = parent_selection_translator[
        st.selectbox(
            label='Parent selection type',
            options=tuple(parent_selection_translator),
            index=1
        )
    ]

    # Human-readable label -> pygad crossover_type code.
    crossover_translator = {
        "Single point": "single_point",
        "Two points": "two_points",
    }
    crossover_type = crossover_translator[
        st.selectbox(
            label='Crossover type',
            options=tuple(crossover_translator),
            index=0
        )
    ]

    # Percentage of the initial population (rows of X) used as mating parents.
    # NOTE(review): at 0 % this gives num_parents_mating == 0, which pygad.GA
    # presumably rejects -- confirm and consider raising min_value.
    parents_mating_percent = st.slider(
        'Pecentage of parents mating taken from initial population',
        min_value=0,
        max_value=100,
        step=1,
        value=33,
    )
    num_parents_mating = int(parents_mating_percent * X.shape[0] // 100)

    # Percentage of the mating parents carried over to the next generation.
    parents_kept_percent = st.slider(
        'Percentage of parents kept taken from number of parents mating',
        min_value=0,
        max_value=100,
        step=1,
        value=66
    )
    keep_parents = int(parents_kept_percent * num_parents_mating // 100)

    use_ordering_score = st.toggle('Use ordering score', value=True)

    seed_input = st.number_input("Random seed", value=42, placeholder="Type a number...")
    random_seed = int(seed_input)

    st.form_submit_button('Start optimisation')
|
|
|
# Collect the GA configuration in one place, then build the optimiser.
ga_settings = dict(
    # Scoring and progress callbacks defined earlier in this file.
    fitness_func=fitness_func_batch,
    fitness_batch_size=batch_size,
    on_generation=on_generation_progress,
    # The training matrix itself is the starting population.
    initial_population=X,
    num_genes=X.shape[-1],
    gene_type=int,
    # Values chosen in the sidebar form.
    num_generations=num_generations,
    num_parents_mating=num_parents_mating,
    parent_selection_type=parent_selection_type,
    crossover_type=crossover_type,
    keep_parents=keep_parents,
    random_seed=random_seed,
    # Fixed settings.
    mutation_type="adaptive",
    mutation_percent_genes=[10, 5],
    keep_elitism=0,
    save_best_solutions=False,
    # ga_instance.solutions is read after the run, so it must be recorded.
    save_solutions=True,
    suppress_warnings=True,
)
ga_instance = pygad.GA(**ga_settings)

# State shared with on_generation_progress through module globals.
ga_progress = 0
ga_progress_text = "Genetic optimisation in progress. Please wait."
ga_bar = st.progress(0, text=ga_progress_text)
ga_instance.run()
|
|
|
# De-duplicate every solution the GA recorded (rows become hashable tuples).
with st.spinner('Getting unique solutions'):
    unique_solutions = list({tuple(s) for s in ga_instance.solutions})
st.success(f'{len(unique_solutions)} solutions were obtained')

# Per-solution scores, filled batch by batch in the order of unique_solutions.
scores = {
    "rf_score": [],
    "similarity_score": []
}
if use_ordering_score:
    scores["ordering_score"] = []

rescoring_progress_text = "Rescoring obtained solutions"
rescoring_bar = st.progress(0, text=rescoring_progress_text)

# Ceiling division: the previous ``len // batch_size + 1`` scheduled an extra,
# empty batch whenever the count was an exact multiple of batch_size.
num_unique = len(unique_solutions)
total_rescoring_steps = (num_unique + batch_size - 1) // batch_size
for step, start in enumerate(range(0, num_unique, batch_size), start=1):
    vqgae_latents = unique_solutions[start:start + batch_size]
    frag_counts = np.array(vqgae_latents)

    # Probability of class index 1 and similarity to the closest row of X.
    rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
    similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
    scores["rf_score"].extend(rf_scores.tolist())
    scores["similarity_score"].extend(similarity_scores.tolist())

    if use_ordering_score:
        frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
        _, ordering_scores = restore_order(frag_inds, ordering_model)
        scores["ordering_score"].extend(ordering_scores)

    # Multiply before dividing: ``i // total * 100`` was always 0, so the bar
    # never advanced. After the last batch this reaches exactly 100.
    rescoring_bar.progress(
        step * 100 // total_rescoring_steps,
        text=rescoring_progress_text,
    )

sc_df = pd.DataFrame(scores)
|
|
|
# Keep solutions that are not near-copies of a row of X (similarity < 0.95)
# and that the classifier scores above 0.5; when the ordering score is in
# use it must additionally exceed 0.7.
selection_mask = (sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5)
if use_ordering_score:
    selection_mask = selection_mask & (sc_df["ordering_score"] > 0.7)
chosen_gen = sc_df[selection_mask]

# sc_df rows are in the same order as unique_solutions, so the surviving index
# labels are positions in that list.
chosen_ids = chosen_gen.index.to_list()
chosen_solutions = np.array([unique_solutions[position] for position in chosen_ids])
gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
st.info(f'The number of chosen solutions is {gen_frag_inds.shape[0]}', icon="ℹ️")
|
|
|
# Decode the chosen solutions into molecules, batch by batch.
gen_molecules = []
results = {"smiles": [], "ordering_score": [], "validity": []}

decoding_progress_text = "Decoding chosen solutions"
decoding_bar = st.progress(0, text=decoding_progress_text)

# Ceiling division: no empty trailing batch, and no batch at all when nothing
# was chosen (the previous ``n // batch_size + 1`` always ran at least once).
num_chosen = gen_frag_inds.shape[0]
total_decoding_steps = (num_chosen + batch_size - 1) // batch_size
for step, start in enumerate(range(0, num_chosen, batch_size), start=1):
    inputs = gen_frag_inds[start:start + batch_size]

    # Named batch_ordering_scores so the ``scores`` dict built during
    # rescoring is no longer overwritten by this loop.
    canon_order_inds, batch_ordering_scores = restore_order(
        frag_inds=inputs,
        ordering_model=ordering_model,
    )
    molecules, validity = decode_molecules(
        ordered_frag_inds=canon_order_inds,
        vqgae_model=vqgae_model
    )

    gen_molecules.extend(molecules)
    results["smiles"].extend(str(molecule) for molecule in molecules)
    results["ordering_score"].extend(batch_ordering_scores)
    results["validity"].extend(1 if is_valid else 0 for is_valid in validity)

    # Multiply before dividing (``i // total * 100`` was always 0) and show
    # the decoding caption rather than the rescoring one.
    decoding_bar.progress(
        step * 100 // total_decoding_steps,
        text=decoding_progress_text,
    )

gen_stats = pd.DataFrame(results)

# reset_index() re-aligns the filtered score rows positionally with gen_stats
# and keeps their former index labels as an "index" column.
full_stats = pd.concat(
    [gen_stats, chosen_gen[["similarity_score", "rf_score"]].reset_index()],
    axis=1,
    ignore_index=False,
)

st.dataframe(full_stats)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|