import math
import os
import random
import string
import time

import numpy as np
import pubchempy as pcp
import streamlit as st
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw

from mod import check_img
|
|
|
@st.dialog("Select From Library")
def lib_modal():
    """Render the library picker dialog.

    Lists the ``lib`` directory and shows up to 50 distinct, randomly chosen
    entries as images laid out in two columns, each with a "Select" button.
    Clicking a button stores the chosen file name in
    ``st.session_state.uploaded_file``, sets ``st.session_state.show_modal``
    to ``False`` and calls ``st.rerun()``.
    """
    files = os.listdir("lib")
    if not files:
        # Nothing to sample from; drawing from an empty population raises.
        st.info("The library is empty.")
        return

    # Sample WITHOUT replacement so the same image never appears twice, and
    # never request more entries than the directory actually holds.
    files = random.sample(files, k=min(50, len(files)))
    cols = st.columns(2)
    for i, file in enumerate(files):
        # Alternate between the left and right column.
        with cols[i % 2]:
            st.image("lib/{}".format(file), use_container_width=True)
            # The loop index is unique per tile, so it is a valid widget key.
            if st.button("Select", key=i):
                st.session_state.uploaded_file = file
                st.session_state.show_modal = False
                st.rerun()
            st.markdown("---")
|
|
|
|
|
|
|
# First run of a session: create the flag that gates the library dialog.
if "show_modal" not in st.session_state:
    st.session_state["show_modal"] = True

# Page heading, followed by a short description and the project links.
st.title("3D2SMILES")
st.markdown("""
This app generates SMILES strings from images of molecules ball-and-stick models.

[Version 1 Paper](https://chemrxiv.org/engage/chemrxiv/article-details/673a9d62f9980725cf89abe1) |
[Version 2 Paper]() |
[Synthetic Dataset](https://huggingface.co/datasets/weathon/3d2smiles_synthetic) |
[Real Dataset](https://huggingface.co/datasets/weathon/3d2smiles_real) |
[Author Github](https://github.com/weathon) |
[Feedback](mailto:wguo6358@gmail.com) |
[Deploy](https://huggingface.co/spaces/weathon/3d2smiles?docker=true)
""")

# Two side-by-side columns: decoding strategy on the left, temperature on the
# right.
col1, col2 = st.columns(spec=2)
gen_strategy = col1.selectbox(
    label="Select a generative strategy",
    options=("Beam Search", "Sampling", "Greedy Search"),
)
temp = col2.slider(
    label="Temperature",
    min_value=0.0,
    max_value=2.0,
    value=1.0,
)
|
|
|
|
|
# Let the user either upload their own image or pick one from the shared
# library. After this section ``uploaded_file`` holds the uploaded file (or
# None) and ``contribute`` says whether the upload may be added to the library.
mode = st.radio("Select file source:", ["Upload File", "Choose from Our Demo Library"])
if mode == "Upload File":
    # Forget any earlier library pick so it cannot shadow the upload.
    st.session_state.pop("uploaded_file", None)
    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp", "heic"])
    contribute = col1.checkbox("Contribute To Public Library", value=True, help="If checked, images will be included in the PUBLIC library, and the image will be reviewed by our team and used for model training. When checked, do not upload any sensitive or personal data.")
else:
    contribute = False
    uploaded_file = None
    st.warning("You are accessing a user-uploaded library. These images are unverified and may contain inappropriate content or incorrectly assembled files, despite some moderation. Proceed with caution. For inquiries, contact us.")

    from_library = st.button("Select from Library")
    if from_library:
        # The dialog sets ``show_modal`` to False once an image is picked.
        # Re-arm it on every click; otherwise the library could only be
        # opened once per session.
        st.session_state.show_modal = True
        lib_modal()

    if "uploaded_file" in st.session_state:
        st.markdown("You have selected: {}".format(st.session_state.uploaded_file))
|
|
|
|
|
# Handle a click on "Submit": load the chosen image, render the candidate
# molecules, optionally store the upload in the library, and report timing.
button = st.button("Submit")
if button:
    start_time = time.time()

    # Resolve where the image comes from: a fresh upload wins, otherwise the
    # file previously picked from the library.
    if uploaded_file:
        image_source = uploaded_file
    elif "uploaded_file" in st.session_state:
        image_source = "lib/{}".format(st.session_state.uploaded_file)
    else:
        st.error("Please upload an image or select from the library.")
        st.stop()

    try:
        image = Image.open(image_source)
    except OSError:
        # Covers missing files and data Pillow cannot identify as an image.
        st.error("The selected file could not be opened as an image.")
        st.stop()

    options = ["CC(=O)OC1=CC=CC=C1C(=C)C(=O)O", "CC(=O)", "CC(=O)O", "CC(=O)C", "CC(=O)C1=CC=CC=C1"]

    # Build enough two-column rows to hold every candidate. Rounding up
    # matters: with too few cells, zip() below silently drops the tail.
    n_cols = 2
    n_rows = math.ceil(len(options) / n_cols)
    grid = [st.columns(n_cols) for _ in range(n_rows)]
    cols = [col for row in grid for col in row]

    for smiles, col in zip(options, cols):
        compounds = pcp.get_compounds(smiles, 'smiles')
        compound = compounds[0] if compounds else None
        # NOTE(review): assumes a SMILES unknown to PubChem can yield no
        # record or no synonyms; fall back to the SMILES string as the title.
        synonyms = compound.synonyms if compound is not None else None
        name = synonyms[0] if synonyms else smiles
        col.markdown(f"### {name}")

        m = Chem.MolFromSmiles(smiles)
        img = Draw.MolToImage(m)
        col.image(img, use_container_width=False)

        # Only link to PubChem when the lookup produced a compound id.
        if compound is not None and compound.cid:
            pubchem_url = "https://pubchem.ncbi.nlm.nih.gov/compound/{}".format(compound.cid)
            col.markdown("[PubChem]({})".format(pubchem_url))

    if contribute:
        # Run the check from ``mod`` before writing anything into the
        # public library directory.
        flagged = check_img(image)
        if flagged:
            st.warning("The image is flagged as inappropriate. Please upload a different image.")
            st.stop()
        else:
            # Random 10-letter suffix keeps stored uploads from overwriting
            # each other.
            filename = "user_upload" + "".join(random.choices(string.ascii_uppercase, k=10)) + ".png"
            image.save("lib/{}".format(filename))
            st.success("The image is stored in the library.")

    st.markdown("---")
    st.markdown("Taken {} seconds".format(round(time.time() - start_time, 2)))