import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import mols2grid
from ipywidgets import interact, widgets
import textwrap
import moses
from transformers import EncoderDecoderModel, RobertaTokenizer
from moses.metrics.utils import QED, SA, logP, NP, weight, get_n_rings
from moses.utils import mapper, get_mol
from typing import List
from util import filter_dataframe

# Hugging Face hub ids of the tokenizers the demo uses for every model.
PROTEIN_TOKENIZER_NAME = "gokceuludogan/WarmMolGenTwo"
MOL_TOKENIZER_NAME = "seyonec/PubChem10M_SMILES_BPE_450k"


@st.cache_resource
def load_models():
    """Load the two WarmMolGen encoder-decoder checkpoints (cached by Streamlit).

    Returns:
        (model1, model2): the ``WarmMolGenOne`` and ``WarmMolGenTwo`` models.
    """
    model1 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenOne")
    model2 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenTwo")
    return model1, model2


@st.cache_resource
def load_tokenizers():
    """Load the protein (encoder) and SMILES (decoder) tokenizers (cached by Streamlit).

    Cached like ``load_models`` so a generation request does not reload the
    tokenizers from the hub/disk every time.

    Returns:
        (protein_tokenizer, mol_tokenizer)
    """
    protein_tokenizer = RobertaTokenizer.from_pretrained(PROTEIN_TOKENIZER_NAME)
    mol_tokenizer = RobertaTokenizer.from_pretrained(MOL_TOKENIZER_NAME)
    return protein_tokenizer, mol_tokenizer


def count(smiles_list: List[str]) -> List[int]:
    """Return the character length of each SMILES string in ``smiles_list``."""
    return [len(smiles) for smiles in smiles_list]


def remove_none_elements(mol_list, smiles_list):
    """Drop entries whose parsed molecule is ``None``.

    Args:
        mol_list: parsed molecules; ``None`` marks an unparsable SMILES.
        smiles_list: SMILES strings aligned index-by-index with ``mol_list``.

    Returns:
        (filtered_mol_list, filtered_smiles_list, removed_len), where
        ``removed_len`` is the number of ``None`` entries found in ``mol_list``.
    """
    # A set makes the per-index membership test O(1) (the old list made it O(n^2)).
    invalid = {i for i, mol in enumerate(mol_list) if mol is None}
    filtered_mol_list = [mol for mol in mol_list if mol is not None]
    filtered_smiles_list = [
        smiles for i, smiles in enumerate(smiles_list) if i not in invalid
    ]
    return filtered_mol_list, filtered_smiles_list, len(invalid)


def format_list_numbers(lst):
    """Round every number in ``lst`` to 3 decimal places, in place (returns None)."""
    lst[:] = [float(f"{value:.3f}") for value in lst]


def generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool):
    """Generate SMILES for a protein ``target`` and score them with MOSES metrics.

    Args:
        model_name: "WarmMolGenOne" selects the first model; any other value
            selects the second.
        num_mols: passed to ``generate`` as ``num_return_sequences``.
        max_new_tokens: passed to ``generate`` as ``max_length``.
        do_sample: sample if True, otherwise beam-decode.
        num_beams: beam width, or None to let ``generate`` use its default.
        target: protein sequence fed to the protein tokenizer.
        pool: argument for ``moses.utils.mapper`` (1 in this app).

    Returns:
        DataFrame with columns SMILES, QED, SA, logP, NP, Weight (scores
        rounded to 3 decimals), one row per generated SMILES that ``get_mol``
        could parse.
    """
    protein_tokenizer, mol_tokenizer = load_tokenizers()
    model1, model2 = load_models()
    model = model1 if model_name == 'WarmMolGenOne' else model2

    inputs = protein_tokenizer(target, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        decoder_start_token_id=mol_tokenizer.bos_token_id,
        eos_token_id=mol_tokenizer.eos_token_id,
        pad_token_id=mol_tokenizer.eos_token_id,
        max_length=int(max_new_tokens),
        num_return_sequences=int(num_mols),
        do_sample=do_sample,
        num_beams=num_beams,
    )
    output_smiles = mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    st.write("### Generated Molecules")

    mol_list = mapper(pool)(get_mol, output_smiles)
    mol_list, output_smiles, removed_len = remove_none_elements(mol_list, output_smiles)
    if removed_len != 0:
        st.write(f"#### Note that: {removed_len} numbers of generated invalid molecules are discarded.")

    columns = {"SMILES": output_smiles}
    metrics = (("QED", QED), ("SA", SA), ("logP", logP), ("NP", NP), ("Weight", weight))
    for column, metric in metrics:
        # NOTE(review): a fresh mapper is built per metric, as before; moses'
        # mapper may wrap a one-shot Pool when `pool` is an int > 1, so reusing
        # a single mapper object is not assumed safe.
        scores = mapper(pool)(metric, mol_list)
        format_list_numbers(scores)
        columns[column] = scores
    return pd.DataFrame(columns)


def warm_molgen_demo():
    """Render the sidebar controls, run generation, and show the molecule grid and code snippet."""
    with st.form("my_form"):
        with st.sidebar:
            st.sidebar.subheader("Configurable parameters")
            model_name = st.sidebar.selectbox(
                "Model Selector",
                options=[
                    "WarmMolGenOne",
                    "WarmMolGenTwo",
                ],
                index=0,
            )
            # min_value=1: 0 would request zero sequences (and zero beams in
            # beam mode) from `generate`.
            num_mols = st.sidebar.number_input(
                "Number of generated molecules",
                min_value=1,
                max_value=20,
                value=10,
                help="The number of molecules to be generated.",
            )
            max_new_tokens = st.sidebar.number_input(
                "Maximum length",
                min_value=1,
                max_value=1024,
                value=128,
                help="The maximum length of the sequence to be generated.",
            )
            do_sample = st.sidebar.selectbox(
                "Sampling?",
                (True, False),
                help="Whether or not to use sampling; use beam decoding otherwise.",
            )
        target = st.text_area(
            "Target Sequence",
            "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
        )
        generate_new_molecules = st.form_submit_button("Generate Molecules")

    # Beam decoding uses one beam per requested molecule so that
    # num_return_sequences <= num_beams holds.
    num_beams = None if do_sample is True else int(num_mols)
    pool = 1

    if generate_new_molecules:
        st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool)
    # First page load: generate once with the default parameters.
    if 'df' not in st.session_state:
        st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool)

    df = st.session_state.df
    filtered_df = filter_dataframe(df)
    if filtered_df.empty:
        st.markdown(
            """
            No molecules were found with specified properties.
            """,
            unsafe_allow_html=True
        )
    else:
        raw_html = mols2grid.display(filtered_df, height=1000)._repr_html_()
        components.html(raw_html, width=900, height=450, scrolling=True)

    st.markdown("## How to Generate")
    # `target!r` emits a valid Python string literal even if the sequence
    # contains quotes or newlines; tokenizer ids match what the app itself uses.
    generation_code = f"""
    from transformers import EncoderDecoderModel, RobertaTokenizer
    protein_tokenizer = RobertaTokenizer.from_pretrained("{PROTEIN_TOKENIZER_NAME}")
    mol_tokenizer = RobertaTokenizer.from_pretrained("{MOL_TOKENIZER_NAME}")
    model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}")
    inputs = protein_tokenizer({target!r}, return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id, eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id, max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})
    mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    """
    st.code(textwrap.dedent(generation_code))


st.set_page_config(page_title="WarmMolGen Demo", page_icon="🔥", layout='wide')
st.markdown("# WarmMolGen Demo")
st.sidebar.header("WarmMolGen Demo")
st.markdown(
    """
    This demo illustrates WarmMolGen models' generation capabilities.
    Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.
    Please enter an input sequence below 👇 and configure parameters from the sidebar 👈 to generate molecules!
    See below for saving the output molecules and the code snippet generating them!
    """
)

warm_molgen_demo()