File size: 5,801 Bytes
8279c69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import streamlit as st
import streamlit_ext as ste

from inference import Inference
import random
from rdkit.Chem import Draw
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
import io
from PIL import Image

class DrugGENConfig:
    """Settings used to run inference with the DrugGEN model.

    Every setting is a plain class attribute, so a subclass only needs to
    redefine the values that differ. Assigning on an instance shadows the
    class default for that instance only.
    """

    # --- model ---
    submodel = "DrugGEN"
    act = "relu"
    max_atom = 45
    dim = 32
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.0
    features = False

    # --- inference run ---
    inference_sample_num = 1000
    inf_batch_size = 1
    inference_model = "experiments/models/DrugGEN"

    # --- input data paths / file names ---
    protein_data_dir = "data/akt"
    drug_index = "data/drug_smiles.index"
    drug_data_dir = "data/akt"
    mol_data_dir = "data"
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = "akt_test.pt"
    inf_raw_file = "data/chembl_test.smi"
    inf_drug_raw_file = "data/akt_test.smi"

    # --- output directories and logging ---
    log_dir = "experiments/logs"
    model_save_dir = "experiments/models"
    sample_dir = "experiments/samples"
    result_dir = "experiments/tboard_output"
    log_sample_step = 1000

    # --- seeding ---
    set_seed = False
    seed = 1

class NoTargetConfig(DrugGENConfig):
    """Settings for the DrugGEN-NoTarget variant.

    Inherits everything from ``DrugGENConfig`` and changes only the sub-model
    name, ``dim`` and the model path.
    """

    inference_model = "experiments/models/NoTarget"
    dim = 128
    submodel = "NoTarget"


# One config instance per model, looked up with the sidebar's `model_name`.
# NOTE(review): only the keys "DrugGEN" and "NoTarget" exist here, while the
# sidebar can also produce "Prot" and "CrossLoss".
model_configs = dict(
    DrugGEN=DrugGENConfig(),
    NoTarget=NoTargetConfig(),
)


with st.sidebar:
    # Header: paper title plus arXiv / GitHub badges.
    st.title("DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
    st.write("[![arXiv](https://img.shields.io/badge/arXiv-2302.07868-b31b1b.svg)](https://arxiv.org/abs/2302.07868) [![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/DrugGEN)")

    # Collapsible description of the selectable model variants.
    model_info = st.expander("Expand to display information about models")
    with model_info:
        st.write("""
### Model Variations
- **DrugGEN-Prot**: composed of two GANs, incorporates protein features to the transformer decoder module of GAN2 (together with the de novo molecules generated by GAN1) to direct the target centric molecule design.
- **DrugGEN-CrossLoss**: composed of one GAN, the input of the GAN1 generator is the real molecules dataset and the GAN1 discriminator compares the generated molecules with the real inhibitors of the given target.
- **DrugGEN-NoTarget**: composed of one GAN, focuses on learning the chemical properties from the ChEMBL training dataset, no target-specific generation.

        """)

    # Inputs are grouped in an st.form and submitted together by the button.
    with st.form("model_selection_from"):
        model_choices = ('DrugGEN-Prot', 'DrugGEN-CrossLoss', 'DrugGEN-NoTarget')
        chosen_label = st.radio(
            'Select a model to make inference (DrugGEN-Prot and DrugGEN-CrossLoss models design molecules to target the AKT1 protein)',
            model_choices,
        )
        # Drop the shared prefix, e.g. "DrugGEN-Prot" -> "Prot".
        model_name = chosen_label.replace("DrugGEN-", "")

        molecule_num_input = st.number_input(
            'Number of molecules to generate',
            min_value=1,
            max_value=100_000,
            value=1000,
            step=1,
        )
        seed_input = st.number_input(
            "RNG seed value (can be used for reproducibility):",
            min_value=0,
            value=42,
            step=1,
        )

        submitted = st.form_submit_button("Start Computing")



def _pick_molecules(smiles_text, k=12):
    """Return up to ``k`` randomly chosen, parseable molecules from ``smiles_text``.

    ``smiles_text`` holds one SMILES string per line. Blank lines, including
    the one a trailing newline produces, are skipped. Lines are visited in
    random order without replacement, so no line is picked twice. Strings for
    which ``Chem.MolFromSmiles`` returns ``None`` are skipped. Fewer than ``k``
    molecules (possibly none) are returned when not enough valid lines exist.
    """
    if k <= 0:
        return []
    candidates = [line.strip() for line in smiles_text.splitlines() if line.strip()]
    random.shuffle(candidates)
    picked = []
    # Parse lazily so a large inference set is not fully parsed just to show k.
    for smiles in candidates:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        picked.append(mol)
        if len(picked) == k:
            break
    return picked


if submitted:
    config = model_configs.get(model_name)
    if config is None:
        # model_configs may not define every option the sidebar radio offers;
        # report that clearly instead of crashing with a KeyError.
        st.error(f"No configuration is available for DrugGEN-{model_name}.")
        st.stop()

    config.inference_sample_num = molecule_num_input
    config.seed = seed_input
    # NOTE(review): DrugGENConfig sets set_seed=False; confirm the inference code
    # still applies `seed`, otherwise the sidebar seed has no effect.

    with st.spinner(f'Creating the trainer class instance for {model_name}...'):
        # `Inference` is the runner class imported from the local `inference`
        # module. It is assumed to expose `inference()` returning a dict with the
        # runtime/fraction_valid/uniqueness/novelty keys read below. TODO confirm.
        inferer = Inference(config)
    with st.spinner(f'Running inference function of {model_name} (this may take a while) ...'):
        results = inferer.inference()
    st.success(f"Inference of {model_name} took {results['runtime']:.2f} seconds.")

    with st.expander("Expand to see the generation performance scores"):
        st.write("### Generation performance scores (novelty is calculated in comparison to the training dataset)")
        st.success(f"Validity: {results['fraction_valid']}")
        st.success(f"Uniqueness: {results['uniqueness']}")
        st.success(f"Novelty: {results['novelty']}")

    # Generated molecules, one SMILES string per line (presumably written by
    # the inference run above).
    with open(f'experiments/inference/{model_name}/inference_drugs.txt', encoding="utf-8") as f:
        inference_drugs = f.read()
    ste.download_button(label="Click to download generated molecules", data=inference_drugs, file_name=f'DrugGEN-{model_name}_denovo_mols.smi', mime="text/plain")

    selected_molecules = _pick_molecules(inference_drugs, k=12)
    if not selected_molecules:
        st.warning("No valid molecules were found in the inference set to display.")
    else:
        st.write(f"Structures of randomly selected {len(selected_molecules)} de novo molecules from the inference set:")
        molecule_image = Draw.MolsToGridImage(
            selected_molecules,
            molsPerRow=3,
            subImgSize=(250, 250),
            maxMols=len(selected_molecules),
            returnPNG=False,
        )
        # Also written to result_grid.png in the current working directory.
        molecule_image.save("result_grid.png")
        st.image(molecule_image)

else:
    st.warning("Please select a model to make inference")