Spaces:
Sleeping
Sleeping
import base64
import html
import json
import pickle
import re
import uuid

import pandas as pd
import streamlit as st
from streamlit_ketcher import st_ketcher

from SynTool.interfaces.visualisation import to_table, extract_routes
from SynTool.mcts.expansion import PolicyFunction
from SynTool.mcts.search import extract_tree_stats
from SynTool.mcts.tree import Tree, TreeConfig
from SynTool.utils.config import PolicyNetworkConfig
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """Build an HTML snippet containing a styled link that downloads ``object_to_download``.

    The payload is base64-encoded into a ``data:`` URI, so the whole file is
    embedded in the returned HTML. Render the result with
    ``st.markdown(html, unsafe_allow_html=True)``.

    NOTE(review): the original docstring began with an unfinished "Issued from";
    add the attribution of the snippet this was adapted from.

    Parameters
    ----------
    object_to_download:
        The content to download. ``bytes`` are used as-is. A ``pandas.DataFrame``
        is serialised to UTF-8 CSV without the index. A ``str`` is encoded with
        ``str.encode()`` (UTF-8). With ``pickle_it=True`` any picklable object
        is pickled instead.
    download_filename (str):
        File name suggested to the browser, including extension
        (e.g. ``mydata.csv``). It is HTML-escaped before insertion.
    button_text (str):
        Label shown on the link. It is HTML-escaped before insertion.
    pickle_it (bool):
        If True, ``pickle.dumps`` the object before encoding it.

    Returns
    -------
    str or None
        A ``<style>`` block followed by the ``<a download=...>`` tag. Returns
        ``None`` if pickling fails; the error is then shown with ``st.write``.

    Examples
    --------
    >>> download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')  # doctest: +SKIP
    >>> download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')  # doctest: +SKIP
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            # The exception type depends on the object: locks raise TypeError,
            # and local functions/classes may raise AttributeError. Report all
            # of them on the page instead of crashing the whole script run.
            st.write(e)
            return None
    elif isinstance(object_to_download, pd.DataFrame):
        object_to_download = object_to_download.to_csv(index=False).encode('utf-8')

    try:
        # str payloads (e.g. HTML reports) must be turned into bytes first.
        payload = object_to_download.encode()
    except AttributeError:
        # bytes and bytes-like payloads have no .encode() and are used as-is.
        payload = object_to_download
    b64 = base64.b64encode(payload).decode()

    # A CSS id selector must not start with a digit. A fixed letter prefix
    # guarantees that without discarding any of the UUID's randomness.
    button_id = f"download-{uuid.uuid4().hex}"

    custom_css = f"""
<style>
#{button_id} {{
background-color: rgb(255, 255, 255);
color: rgb(38, 39, 48);
text-decoration: none;
border-radius: 4px;
border-width: 1px;
border-style: solid;
border-color: rgb(230, 234, 241);
border-image: initial;
}}
#{button_id}:hover {{
border-color: rgb(246, 51, 102);
color: rgb(246, 51, 102);
}}
#{button_id}:active {{
box-shadow: none;
background-color: rgb(246, 51, 102);
color: white;
}}
</style> """

    # Escape caller-supplied text so quotes or angle brackets cannot break the
    # attribute or inject markup into the page.
    safe_filename = html.escape(str(download_filename), quote=True)
    safe_text = html.escape(str(button_text))
    dl_link = custom_css + (
        f'<a download="{safe_filename}" id="{button_id}" '
        f'href="data:file/txt;base64,{b64}">{safe_text}</a><br></br>'
    )
    return dl_link
# Page chrome. set_page_config is the first Streamlit call the script makes.
st.set_page_config(  # layout="wide",
    page_title="SynTool GUI",
    page_icon="🧪",
)
st.title("`SynTool GUI`")
st.write("*{Introduction text to be inserted here}*")

# --- Molecule input -------------------------------------------------------
st.header('Molecule input')
st.write(
    "You can provide a molecular structure either by typing its SMILES string "
    "and pressing Enter, or by drawing it and clicking Apply."
)
DEFAULT_MOL = 'NC(CCCCB(O)O)(CCN1CCC(CO)C1)C(=O)O'
molecule = st.text_input("Molecule", DEFAULT_MOL)
# The Ketcher editor is initialised with `molecule`. Its return value is used
# below as the search target (presumably the SMILES after the user clicks
# 'Apply' — see the hint shown in the next section).
smile_code = st_ketcher(molecule)

# --- Search parameters ----------------------------------------------------
st.header('Launch calculation')
st.write("If you modified the structure, please ensure you clicked on 'Apply' (bottom right of the molecular editor).")
st.markdown(f"The current molecule SMILES is: ``{smile_code}``")
max_depth = st.slider('Maximal number of reaction steps', min_value=2, max_value=9, value=9)
run_default = st.button('Launch and search a reaction path')
# Data files used by the search, relative to the app's working directory.
ranking_policy_weights_path = 'data/policy_network.ckpt'
reaction_rules_path = 'data/reaction_rules.pickle'
building_blocks_path = 'data/building_blocks.smi'


@st.cache_resource
def _load_policy_function(weights_path, _policy_config):
    """Build the PolicyFunction once per weights file and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, the policy network checkpoint would be reloaded on each
    click. The cache is keyed on ``weights_path`` only. The leading underscore
    keeps ``_policy_config`` out of Streamlit's argument hashing.

    NOTE(review): the cached instance is shared by all sessions. This assumes
    PolicyFunction keeps no per-search state — confirm.
    """
    return PolicyFunction(policy_config=_policy_config)


policy_config = PolicyNetworkConfig(weights_path=ranking_policy_weights_path)
policy_function = _load_policy_function(ranking_policy_weights_path, policy_config)
def _render_results(tree, res):
    """Show the outcome of a finished search.

    ``res`` is the dict returned by ``extract_tree_stats``. Only its
    ``'found_paths'`` entry is inspected; the whole dict is shown (and
    offered as CSV) as a one-row statistics table.
    """
    st.header('Results')
    if not res['found_paths']:
        st.write("Found no reaction path.")
        return

    st.write("Success!")
    st.subheader("Retrosynthetic Routes Report")
    st.markdown(to_table(tree, None, extended=True, integration=True), unsafe_allow_html=True)

    # Build the statistics table once: it is both displayed and exported.
    stats_df = pd.DataFrame(res, index=[0])
    st.subheader("Statistics")
    st.write(stats_df)

    st.subheader("Downloads")
    # NOTE(review): integration=False presumably renders a standalone HTML
    # document for the file, versus the in-page variant above — confirm.
    dl_html = download_button(to_table(tree, None, extended=True, integration=False),
                              'results_syntool.html',
                              'Download results as a HTML file')
    dl_csv = download_button(stats_df,
                             'results_syntool.csv',
                             'Download statistics as an Excel csv file')
    st.markdown(dl_html + dl_csv, unsafe_allow_html=True)


if run_default:
    if not smile_code:
        # The user can clear the input; report it instead of handing an empty
        # target to Tree.
        st.error("No molecule provided. Enter a SMILES string or draw a structure and click 'Apply'.")
    else:
        st.toast('Optimisation is started. The progress will be printed below')
        spinner = st.spinner(text="Running with default parameters...")
        tree_config = TreeConfig(
            search_strategy="expansion_first",
            evaluation_type="rollout",
            max_iterations=100,
            max_depth=max_depth,
            min_mol_size=0,
            init_node_value=0.5,
            ucb_type="uct",
            c_ucb=0.1,
            silent=True,
        )
        with spinner:
            tree = Tree(
                target=smile_code,
                tree_config=tree_config,
                reaction_rules_path=reaction_rules_path,
                building_blocks_path=building_blocks_path,
                policy_function=policy_function,
                value_function=None,
            )
            # Exhausting the Tree iterator runs the search. The yielded values
            # are not used, so don't keep them all in a list.
            for _ in tree:
                pass
            res = extract_tree_stats(tree, smile_code)
        _render_results(tree, res)
# --- Footer: let the user start over --------------------------------------
st.divider()
st.header('Restart from the beginning?')
restart_requested = st.button("Restart")
if restart_requested:
    # Re-execute the script from the top right away.
    # NOTE(review): a rerun likely keeps widget values such as the molecule
    # input; confirm whether a full state reset is intended here.
    st.rerun()