# File-viewer page residue (not program code): jonas-verhellen's picture /
# Output Reformat / 8da2b6e / raw / history blame / 7.06 kB
import os
import re
import logging
import pandas as pd
from omegaconf import OmegaConf
import gradio as gr
from illuminate import Illuminate
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw
def MolsMatrixToGridImage(mols, legends, filename):
    """Render molecules as a grid image, save it to disk and return it.

    Args:
        mols: Molecules forwarded to ``Draw.MolsToGridImage``.
        legends: Legends forwarded to ``Draw.MolsToGridImage``.
        filename: Path the rendered image is saved to.

    Returns:
        The image object returned by ``Draw.MolsToGridImage``.
    """
    # Five molecules per row, each drawn in a 400x400 cell.
    grid = Draw.MolsToGridImage(
        mols,
        legends=legends,
        molsPerRow=5,
        subImgSize=(400, 400),
    )
    grid.save(filename)
    return grid
def launch_illumination(target, representation, surrogate, acquisition, ranges, generations_max, function_calls_max, structural_filters):
    """Configure and run one Illuminate optimisation, then summarise its output.

    The run configuration is assembled as a plain dict, logged as YAML and
    handed to ``Illuminate``. After the run, ``statistics.csv`` and
    ``molecules.csv`` are read from the current working directory
    (presumably written there by the Illuminate run -- confirm), and the ten
    highest-fitness molecules are drawn into ``top_molecules_grid.png``.

    Args:
        target: Value stored under ``fitness.target`` in the configuration.
        representation: Value stored under ``fitness.representation``.
        surrogate: Value stored under ``surrogate.representation``.
        acquisition: Value stored under ``acquisition.type``.
        ranges: Value stored under ``descriptor.ranges``; the configuration
            lists four descriptor properties alongside it.
        generations_max: Value stored under ``controller.max_generations``.
        function_calls_max: Value stored under
            ``controller.max_fitness_calls``.
        structural_filters: Iterable copied into ``arbiter.rules``.

    Returns:
        A tuple ``(image, stats_file, molecules_file)``: the grid image of
        the top molecules and the two DataFrames read from the CSV files.
    """
    controller_section = {
        'max_generations': generations_max,
        'max_fitness_calls': function_calls_max,
    }
    archive_section = {
        'name': 'Troglitazone',
        'size': 150,
        'accuracy': 25000,
    }
    descriptor_section = {
        'properties': [
            'Descriptors.ExactMolWt',
            'Descriptors.MolLogP',
            'Descriptors.TPSA',
            'Crippen.MolMR',
        ],
        'ranges': ranges,
    }
    fitness_section = {
        'type': 'Fingerprint',
        'target': target,
        'representation': representation,
    }
    arbiter_section = {
        # Copy the selection into a fresh list.
        'rules': list(structural_filters),
    }
    generator_section = {
        'batch_size': 40,
        'initial_size': 40,
        'mutation_data': 'data/smarts/mutation_collection.tsv',
        'initial_data': 'data/smiles/guacamol_intitial_rediscovery_troglitazone.smi',
    }
    surrogate_section = {
        'type': "Fingerprint",
        'representation': surrogate,
    }
    acquisition_section = {
        'type': acquisition,
        'beta': 0.3,
    }

    # Section order is kept stable so the logged YAML reads the same.
    settings = {
        'controller': controller_section,
        'archive': archive_section,
        'descriptor': descriptor_section,
        'fitness': fitness_section,
        'arbiter': arbiter_section,
        'generator': generator_section,
        'surrogate': surrogate_section,
        'acquisition': acquisition_section,
    }

    logger = logging.getLogger(__name__)
    logger.info(OmegaConf.to_yaml(settings))

    # Build the optimiser from the configuration and execute it.
    optimiser = Illuminate(OmegaConf.create(settings))
    optimiser()

    stats_file = pd.read_csv("statistics.csv")
    molecules_file = pd.read_csv("molecules.csv")

    # Disabled: load the newest archive_<n>.csv and then delete every CSV
    # file in the working directory.
    # files_in_directory = os.listdir('.')
    # pattern = re.compile(r'archive_(\d+)\.csv')
    # archive_files = [f for f in files_in_directory if pattern.match(f)]
    # archive_numbers = [int(pattern.search(f).group(1)) for f in archive_files]
    # archive_file = pd.read_csv(f'archive_{max(archive_numbers)}.csv')
    # csv_files = [file for file in files_in_directory if file.endswith('.csv')]
    # for csv_file in csv_files:
    #     if os.path.isfile(csv_file):
    #         os.remove(csv_file)

    # Keep the ten rows with the highest 'fitness' value.
    best_rows = molecules_file.nlargest(10, 'fitness')
    best_smiles = best_rows['smiles'].tolist()
    best_scores = best_rows['fitness'].tolist()

    best_mols = [Chem.MolFromSmiles(smiles) for smiles in best_smiles]
    captions = [f'Similarity: {value:.5f}' for value in best_scores]

    image = MolsMatrixToGridImage(
        mols=best_mols,
        legends=captions,
        filename='top_molecules_grid.png',
    )
    return image, stats_file, molecules_file
def validate_and_process(target, representation, surrogate, acquisition, exact_mol_wt_min, exact_mol_wt_max, mol_log_p_min, mol_log_p_max, tpsa_min, tpsa_max, mol_mr_min, mol_mr_max, generations_max, function_calls_max, structural_filters):
    """Normalise the descriptor bounds and run the illumination.

    Each (min, max) pair is sorted so the smaller value always comes first,
    even when the two values were supplied the wrong way round. The four
    pairs are passed on, in the order molecular weight, log P, TPSA and
    molar refractivity, as the ``ranges`` argument of
    ``launch_illumination``.

    Returns:
        Only the image from ``launch_illumination``; the two DataFrames it
        also returns are discarded here.
    """
    bound_pairs = (
        (exact_mol_wt_min, exact_mol_wt_max),
        (mol_log_p_min, mol_log_p_max),
        (tpsa_min, tpsa_max),
        (mol_mr_min, mol_mr_max),
    )
    # sorted() on a two-element pair yields [lower, upper].
    ranges = [sorted(pair) for pair in bound_pairs]

    image, _stats_file, _molecules_file = launch_illumination(
        target,
        representation,
        surrogate,
        acquisition,
        ranges,
        generations_max,
        function_calls_max,
        structural_filters,
    )
    return image
def gradio_interface():
    """Build the Gradio Blocks UI, wire up the submit button and launch it.

    The form collects a target SMILES string, run limits, structural
    filters, fingerprint/acquisition choices and min/max bounds for four
    physicochemical descriptors. Clicking "Submit" calls
    ``validate_and_process`` with those values and shows the returned image.
    """
    fitness_choices = ["ECFP4", "ECFP6", "FCFP4", "FCFP6"]
    surrogate_choices = ["ECFP4", "ECFP6", "FCFP4", "FCFP6", "RDFP", "APFP", "TTFP"]
    acquisition_choices = ["Mean", "UCB", "EI", "logEI"]
    filter_choices = ["BMS", "Dundee", "Glaxo", "Inpharmatica", "LINT", "MLSMR", "PAINS", "SureChEMBL"]

    def bounds_row(quantity, lowest, highest, start_min, start_max, step):
        """Create one row with a "Minimum"/"Maximum" slider pair."""
        with gr.Row():
            lower = gr.Slider(
                label=f"Minimum {quantity}",
                minimum=lowest,
                maximum=highest,
                value=start_min,
                step=step,
            )
            upper = gr.Slider(
                label=f"Maximum {quantity}",
                minimum=lowest,
                maximum=highest,
                value=start_max,
                step=step,
            )
        return lower, upper

    with gr.Blocks() as demo:
        target = gr.Textbox(
            label="Target (SMILES)",
            value="O=C1NC(=O)SC1Cc4ccc(OCC3(Oc2c(c(c(O)c(c2CC3)C)C)C)C)cc4",
        )

        with gr.Row():
            generations_max = gr.Slider(label="Generations", minimum=0, maximum=150, value=1, step=1)
            function_calls_max = gr.Slider(label="Function Calls", minimum=0, maximum=15000, value=5000, step=100)

        structural_filters = gr.CheckboxGroup(filter_choices, label="Structural Filters")

        with gr.Row():
            representation = gr.Dropdown(label="Fitness Representation", choices=fitness_choices, value="ECFP4")
            surrogate = gr.Dropdown(label="Surrogate Representation", choices=surrogate_choices, value="ECFP4")
            acquisition = gr.Dropdown(label="Acquisition Function", choices=acquisition_choices, value="Mean")

        # Descriptor bounds sit in an accordion that starts collapsed.
        with gr.Accordion("Physicochemical Descriptors", open=False):
            exact_mol_wt_min, exact_mol_wt_max = bounds_row("Molecular Weight", 0, 885, 225, 555, 1)
            mol_log_p_min, mol_log_p_max = bounds_row("Log(P)", -4, 8, -0.5, 5.5, 0.1)
            tpsa_min, tpsa_max = bounds_row("TPSA", 0, 250, 0, 140, 1)
            mol_mr_min, mol_mr_max = bounds_row("Molecular Refractivity", 0, 250, 40, 130, 1)

        submit_btn = gr.Button("Submit")
        output_image = gr.Image(label="Top Molecules")

        # The download buttons point at fixed CSV paths in the working
        # directory; they are not listed among the click handler's outputs.
        gr.DownloadButton(label="Download Optimisation History", value="./statistics.csv", visible=True)
        gr.DownloadButton(label="Download Output Molecules", value="./molecules.csv", visible=True)

        # Input order must match the positional parameters of
        # validate_and_process.
        form_inputs = [
            target,
            representation,
            surrogate,
            acquisition,
            exact_mol_wt_min,
            exact_mol_wt_max,
            mol_log_p_min,
            mol_log_p_max,
            tpsa_min,
            tpsa_max,
            mol_mr_min,
            mol_mr_max,
            generations_max,
            function_calls_max,
            structural_filters,
        ]
        submit_btn.click(validate_and_process, inputs=form_inputs, outputs=[output_image])

    demo.launch()
# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    gradio_interface()