Spaces:
Runtime error
Runtime error
import os | |
import re | |
import logging | |
import pandas as pd | |
from omegaconf import OmegaConf | |
import gradio as gr | |
from illuminate import Illuminate | |
import matplotlib.pyplot as plt | |
from rdkit import Chem | |
from rdkit.Chem import Draw | |
def MolsMatrixToGridImage(mols, legends, filename):
    """Render molecules as a captioned grid image and write it to disk.

    The grid has five molecules per row, each drawn in a 400x400 cell.

    Args:
        mols: RDKit molecules to draw, one per grid cell.
        legends: Caption strings drawn beneath the corresponding molecules.
        filename: Path the rendered image is saved to.

    Returns:
        The image object produced by ``Draw.MolsToGridImage`` (also saved
        to ``filename``).
    """
    grid = Draw.MolsToGridImage(
        mols,
        legends=legends,
        molsPerRow=5,
        subImgSize=(400, 400),
    )
    grid.save(filename)
    return grid
def _build_illumination_config(target, representation, surrogate, acquisition, ranges, generations_max, function_calls_max, structural_filters):
    """Assemble the plain-dict configuration consumed by ``Illuminate``.

    Slider-driven counts are coerced to ``int`` because Gradio sliders pass
    floats. ``structural_filters`` is copied into a fresh list, and ``None``
    is accepted defensively as "no filters".
    """
    return {
        'controller': {
            'max_generations': int(generations_max),
            'max_fitness_calls': int(function_calls_max),
        },
        'archive': {
            # NOTE(review): the archive name (and the initial population in
            # 'generator' below) are Troglitazone-specific even though
            # `target` is user-selectable — confirm this is intended.
            'name': 'Troglitazone',
            'size': 150,
            'accuracy': 25000,
        },
        'descriptor': {
            # Assumed to pair positionally with `ranges` (one [min, max] per
            # property, in this order) — verify against Illuminate.
            'properties': [
                'Descriptors.ExactMolWt',
                'Descriptors.MolLogP',
                'Descriptors.TPSA',
                'Crippen.MolMR',
            ],
            'ranges': ranges,
        },
        'fitness': {
            'type': 'Fingerprint',
            'target': target,
            'representation': representation,
        },
        'arbiter': {
            'rules': list(structural_filters or []),
        },
        'generator': {
            'batch_size': 40,
            'initial_size': 40,
            'mutation_data': 'data/smarts/mutation_collection.tsv',
            'initial_data': 'data/smiles/guacamol_intitial_rediscovery_troglitazone.smi',
        },
        'surrogate': {
            'type': "Fingerprint",
            'representation': surrogate,
        },
        'acquisition': {
            'type': acquisition,
            'beta': 0.3,
        },
    }


def _render_top_molecules(molecules, count=10, filename='top_molecules_grid.png'):
    """Draw the ``count`` highest-fitness rows of ``molecules`` as a grid image.

    ``molecules`` must have ``smiles`` and ``fitness`` columns. Each cell is
    captioned with its fitness formatted as a similarity score.
    """
    top = molecules.nlargest(count, 'fitness')
    mols = [Chem.MolFromSmiles(smiles) for smiles in top['smiles'].tolist()]
    legends = [f'Similarity: {score:.5f}' for score in top['fitness'].tolist()]
    return MolsMatrixToGridImage(mols=mols, legends=legends, filename=filename)


def launch_illumination(target, representation, surrogate, acquisition, ranges, generations_max, function_calls_max, structural_filters):
    """Run one Illuminate optimisation and summarise its output.

    Args:
        target: SMILES string of the fitness target.
        representation: Fingerprint type used by the fitness function.
        surrogate: Fingerprint type used by the surrogate model.
        acquisition: Name of the acquisition function.
        ranges: One ``[min, max]`` pair per descriptor property.
        generations_max: Maximum number of generations (coerced to int).
        function_calls_max: Maximum number of fitness calls (coerced to int).
        structural_filters: Iterable of structural-filter rule-set names.

    Returns:
        ``(image, stats_file, molecules_file)``: a grid image of the top 10
        molecules, plus the DataFrames read from ``statistics.csv`` and
        ``molecules.csv``.
    """
    config = _build_illumination_config(
        target, representation, surrogate, acquisition, ranges,
        generations_max, function_calls_max, structural_filters,
    )
    log = logging.getLogger(__name__)
    log.info(OmegaConf.to_yaml(config))

    current_instance = Illuminate(OmegaConf.create(config))
    current_instance()

    # The run's results are read back from these fixed paths in the working
    # directory. NOTE(review): presumably Illuminate writes them there; if
    # several runs can overlap, they would share these files.
    stats_file = pd.read_csv("statistics.csv")
    molecules_file = pd.read_csv("molecules.csv")

    image = _render_top_molecules(molecules_file)
    return image, stats_file, molecules_file
def validate_and_process(target, representation, surrogate, acquisition, exact_mol_wt_min, exact_mol_wt_max, mol_log_p_min, mol_log_p_max, tpsa_min, tpsa_max, mol_mr_min, mol_mr_max, generations_max, function_calls_max, structural_filters):
    """Normalise the descriptor bounds from the UI and run the optimisation.

    Each (min, max) slider pair is sorted, so a "minimum" dragged above its
    "maximum" still yields an ascending interval. The intervals are passed on
    in this order: molecular weight, logP, TPSA, molar refractivity.

    Returns:
        Only the grid image of the top molecules. The statistics and molecule
        tables returned by ``launch_illumination`` are discarded.
    """
    bound_pairs = (
        (exact_mol_wt_min, exact_mol_wt_max),
        (mol_log_p_min, mol_log_p_max),
        (tpsa_min, tpsa_max),
        (mol_mr_min, mol_mr_max),
    )
    ranges = [sorted(pair) for pair in bound_pairs]

    image, _stats, _molecules = launch_illumination(
        target, representation, surrogate, acquisition, ranges,
        generations_max, function_calls_max, structural_filters,
    )
    return image
def gradio_interface():
    """Build the Gradio UI for configuring an illumination run and launch it.

    The UI collects the target SMILES, the fingerprint, surrogate and
    acquisition choices, the run budget, structural filters, and descriptor
    bounds. "Submit" runs ``validate_and_process`` and shows the top-molecule
    grid. After a successful run, the two download buttons point at the
    freshly written CSV files.
    """
    representation_options = ["ECFP4", "ECFP6", "FCFP4", "FCFP6"]
    surrogate_options = ["ECFP4", "ECFP6", "FCFP4", "FCFP6", "RDFP", "APFP", "TTFP"]
    acquisition_options = ["Mean", "UCB", "EI", "logEI"]

    with gr.Blocks() as demo:
        target = gr.Textbox(label="Target (SMILES)", value="O=C1NC(=O)SC1Cc4ccc(OCC3(Oc2c(c(c(O)c(c2CC3)C)C)C)C)cc4")
        with gr.Row():
            generations_max = gr.Slider(minimum=0, maximum=150, value=1, step=1, label="Generations")
            function_calls_max = gr.Slider(minimum=0, maximum=15000, value=5000, step=100, label="Function Calls")
        structural_filters = gr.CheckboxGroup(["BMS", "Dundee", "Glaxo", "Inpharmatica", "LINT", "MLSMR", "PAINS", "SureChEMBL"], label="Structural Filters")
        with gr.Row():
            representation = gr.Dropdown(choices=representation_options, value="ECFP4", label="Fitness Representation")
            surrogate = gr.Dropdown(choices=surrogate_options, value="ECFP4", label="Surrogate Representation")
            acquisition = gr.Dropdown(choices=acquisition_options, value="Mean", label="Acquisition Function")
        with gr.Accordion("Physicochemical Descriptors", open=False):
            with gr.Row():
                exact_mol_wt_min = gr.Slider(minimum=0, maximum=885, value=225, step=1, label="Minimum Molecular Weight")
                exact_mol_wt_max = gr.Slider(minimum=0, maximum=885, value=555, step=1, label="Maximum Molecular Weight")
            with gr.Row():
                mol_log_p_min = gr.Slider(minimum=-4, maximum=8, value=-0.5, step=0.1, label="Minimum Log(P)")
                mol_log_p_max = gr.Slider(minimum=-4, maximum=8, value=5.5, step=0.1, label="Maximum Log(P)")
            with gr.Row():
                tpsa_min = gr.Slider(minimum=0, maximum=250, value=0, step=1, label="Minimum TPSA")
                tpsa_max = gr.Slider(minimum=0, maximum=250, value=140, step=1, label="Maximum TPSA")
            with gr.Row():
                mol_mr_min = gr.Slider(minimum=0, maximum=250, value=40, step=1, label="Minimum Molecular Refractivity")
                mol_mr_max = gr.Slider(minimum=0, maximum=250, value=130, step=1, label="Maximum Molecular Refractivity")
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(label="Top Molecules")
        # Start empty. The CSVs only exist after a run, and Gradio copies a
        # file-valued default into its cache when the component is built, so
        # a path here fails at startup (or serves a stale copy forever).
        stats_download = gr.DownloadButton(label="Download Optimisation History", value=None, visible=True)
        molecules_download = gr.DownloadButton(label="Download Output Molecules", value=None, visible=True)
        submit_btn.click(
            validate_and_process,
            inputs=[
                target,
                representation,
                surrogate,
                acquisition,
                exact_mol_wt_min,
                exact_mol_wt_max,
                mol_log_p_min,
                mol_log_p_max,
                tpsa_min,
                tpsa_max,
                mol_mr_min,
                mol_mr_max,
                generations_max,
                function_calls_max,
                structural_filters,
            ],
            outputs=[output_image],
        ).success(
            # Only after a run completes without error: point the downloads
            # at the files that run just wrote.
            lambda: ("./statistics.csv", "./molecules.csv"),
            outputs=[stats_download, molecules_download],
        )
    demo.launch()
if __name__ == "__main__":
    # Build the UI and start the Gradio server only when executed as a script,
    # not when this module is imported.
    gradio_interface()