File size: 5,790 Bytes
15e4ea2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import mols2grid
import textwrap
import numpy as np
from transformers import EncoderDecoderModel, RobertaTokenizer

# @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_models():
    """Load both WarmMolGen encoder-decoder models from the Hugging Face Hub.

    Cached by Streamlit so the slow download/deserialization happens once per
    session instead of on every widget-triggered rerun.

    Returns:
        tuple: ``(model1, model2)`` — the WarmMolGenOne and WarmMolGenTwo
        ``EncoderDecoderModel`` instances, in that order.
    """
    # allow_output_mutation=True: torch models are mutable and unhashable, so
    # st.cache must not try to hash the cached return value on each access
    # (without it, Streamlit re-hashes the models per rerun and typically
    # fails with an UnhashableTypeError).
    # protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo")
    # mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model1 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenOne")
    model2 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenTwo")
    return model1, model2  # , protein_tokenizer, mol_tokenizer


def warmmolgen_demo():
    """Render the WarmMolGen Streamlit page.

    Builds the sidebar controls, generates molecules for the entered target
    protein sequence, shows them in a mols2grid widget, and prints a
    copy-pastable code snippet reproducing the generation.

    Runs top-to-bottom on every Streamlit rerun: any widget change triggers a
    fresh generation (there is no explicit "Generate" button).
    """
    # Tokenizers are reloaded on every rerun; only the models go through the
    # cached load_models() helper.
    # NOTE(review): the protein tokenizer is always loaded from the
    # WarmMolGenTwo repo, even when WarmMolGenOne is selected — presumably
    # both models share the same protein vocabulary; confirm.
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    #model1, model2, protein_tokenizer, mol_tokenizer = load_models()
    model1, model2 = load_models()

    st.sidebar.subheader("Configurable parameters")

    # Which of the two preloaded checkpoints to run generation with.
    model_name = st.sidebar.selectbox(
        "Model Selector",
        options=[
            "WarmMolGenOne",
            "WarmMolGenTwo",
        ],
        index=0,
    )

    # NOTE(review): min_value=0 allows num_mols == 0, which below makes
    # num_beams == 0 when sampling is off — invalid for model.generate();
    # consider min_value=1.
    num_mols = st.sidebar.number_input(
        "Number of generated molecules",
        min_value=0,
        max_value=20,
        value=10,
        help="The number of molecules to be generated.",
    )

    # Upper bound (in tokens) on each generated SMILES sequence.
    max_new_tokens = st.sidebar.number_input(
        "Maximum length",
        min_value=0,
        max_value=1024,
        value=128,
        help="The maximum length of the sequence to be generated.",
    )
    # temp = st.sidebar.slider(
    #     "Temperature",
    #     value=1.0,
    #     min_value=0.1,
    #     max_value=100.0,
    #     help="The value used to module the next token probabilities.",
    # )
    # top_k = st.sidebar.number_input(
    #     "Top k",
    #     value=10,
    #     help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
    # )
    # top_p = st.sidebar.number_input(
    #     "Top p",
    #     value=0.95,
    #     help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
    # )
    # Decoding strategy toggle: sampling when True, beam search when False.
    do_sample = st.sidebar.selectbox(
        "Sampling?",
        (True, False),
        help="Whether or not to use sampling; use beam decoding otherwise.",
    )
    # num_beams = st.sidebar.number_input(
    #     "Number of beams",
    #     min_value=0,
    #     max_value=20,
    #     value=0,
    #     help="The number of beams to use for beam search.",
    # )
    # Beam search needs at least num_return_sequences beams, so use one beam
    # per requested molecule; sampling ignores num_beams (None).
    num_beams = None if do_sample is True else int(num_mols)
    # repetition_penalty = st.sidebar.number_input(
    #     "Repetition Penalty",
    #     min_value=0.0,
    #     value=3.0,
    #     step=0.1,
    #     help="The parameter for repetition penalty. 1.0 means no penalty",
    # )
    # no_repeat_ngram_size = st.sidebar.number_input(
    #     "No Repeat N-Gram Size",
    #     min_value=0,
    #     value=3,
    #     help="If set to int > 0, all ngrams of that size can only occur once.",
    # )
    # Target protein sequence (amino-acid string); default is a caspase-like
    # fragment shipped as an example.
    target = st.text_area(
        "Target Sequence",
        "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
    )
    inputs = protein_tokenizer(target, return_tensors="pt")

    model = model1 if model_name == 'WarmMolGenOne' else model2
    # The decoder (SMILES) tokenizer has no dedicated pad token, so its EOS id
    # doubles as the pad id here.
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id, 
                             eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id, 
                             max_length=int(max_new_tokens), num_return_sequences=int(num_mols), 
                             do_sample=do_sample, num_beams=num_beams)
    output_smiles = mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    st.write("### Generated Molecules")
    #st.write(output_smiles)    
    df_smiles = pd.DataFrame({'SMILES': output_smiles})
    #st.write(df_smiles)    
    # mols2grid renders the SMILES as an interactive HTML grid, embedded into
    # the page via a raw components.html iframe.
    raw_html = mols2grid.display(df_smiles, mapping={"SMILES": "SMILES"})._repr_html_()
    components.html(raw_html, width=900, height=450, scrolling=True)
    st.markdown("## How to Generate")
    # Reproducible snippet mirroring the generate() call above, with the
    # current widget values interpolated.
    generation_code = f"""
    from transformers import EncoderDecoderModel, RobertaTokenizer

    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/{model_name}") 
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}") 

    inputs = protein_tokenizer("{target}", return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id, 
                             eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id, 
                             max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})

    mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    """
    st.code(textwrap.dedent(generation_code)) # textwrap.dedent("".join("Halletcez")))

# Page-level setup. set_page_config must be the first Streamlit call executed
# on the page, so it stays at the top of this script section.
st.set_page_config(page_title="WarmMolGen Demo", page_icon="🔥", layout='wide')

st.markdown("# WarmMolGen Demo")
st.sidebar.header("WarmMolGen Demo")

# Introductory blurb shown above the demo controls.
intro_text = """
    This demo illustrates WarmMolGen models' generation capabilities. 


    Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.


    Please enter an input sequence below 👇  and configure parameters from the sidebar 👈 to generate molecules!


    See below for saving the output molecules and the code snippet generating them!
"""
st.markdown(intro_text)

# Render the demo page.
warmmolgen_demo()