import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd

# Page Configuration
st.set_page_config(
    page_title='Beta-Lactam Molecule Generator',
    layout='wide'
)

# Prompt fed to the generator on every run.
# NOTE(review): this SMILES is a five-membered N-methyl cyclic imide
# (N-methylsuccinimide), not the four-membered azetidin-2-one ring that
# defines a beta-lactam. It is kept unchanged because it must match the prompt
# the model was fine-tuned with -- confirm against the model's training setup.
CORE_SMILES = "C1C(=O)N(C)C(=O)C1"

# ADMET-AI output columns displayed for each molecule.
ADMET_PROPERTIES = [
    'molecular_weight',
    'logP',
    'hydrogen_bond_acceptors',
    'hydrogen_bond_donors',
    'QED',
    'ClinTox',
    'hERG',
    'BBB_Martins',
]

# transformers rejects temperature <= 0 when do_sample=True, but the
# creativity slider allows 0.0; clamp to a tiny positive value instead
# (which behaves close to greedy decoding).
MIN_TEMPERATURE = 0.01


# Load Models
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """
    Load the molecule generation model and the ADMET-AI model.

    Cached with ``st.cache_resource`` so the models are shared across script
    reruns; with ``ttl=600`` the cached entry expires after 600 seconds and is
    reloaded on the next call.

    Returns:
        Tuple ``(model, tokenizer, admet_model)``.

    Shows an error and stops the Streamlit script if the
    ``HUGGING_FACE_TOKEN`` environment variable is not set.
    """
    # Load the molecule generation model (private Hugging Face repo).
    model_name = "bcadkins01/beta_lactam_generator"
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    if access_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)

    # Load ADMET-AI model
    admet_model = ADMETModel()

    return model, tokenizer, admet_model


# Load models once and reuse
model, tokenizer, admet_model = load_models()

# Set Generation Parameters in Sidebar
st.sidebar.header('Generation Parameters')

# Creativity Slider (Temperature) -- passed to model.generate as `temperature`.
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.0,
    max_value=2.4,
    value=1.0,
    step=0.2,
    help="Higher values lead to more diverse (or wild) outputs."
)

# Number of Molecules to Generate
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,  # Adjust as needed
    value=3,
    help="Select the number of molecules you want to generate (up to 3)."
)


def generate_molecule_image(input_string, use_safe=False):
    """
    Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or a SAFE string when ``use_safe`` is True.
        use_safe: Render via ``safe.to_image`` (SVG -> PNG) instead of RDKit.

    Returns:
        A ``PIL.Image.Image``, or ``None`` if ``input_string`` is empty, cannot
        be parsed, or rendering fails (failures are also reported via
        ``st.error``).
    """
    if not input_string:
        return None
    try:
        if use_safe:
            rendered = safe.to_image(input_string)
            # NOTE(review): safe.to_image is expected to return SVG markup, but
            # depending on the safe/datamol version it may return a wrapper
            # object exposing `.data` or a PIL image; handle those defensively.
            if isinstance(rendered, Image.Image):
                return rendered
            svg = getattr(rendered, 'data', rendered)
            svg_bytes = svg if isinstance(svg, bytes) else str(svg).encode('utf-8')
            # Convert SVG to PNG bytes, then load them as a PIL image.
            png_bytes = cairosvg.svg2png(bytestring=svg_bytes)
            return Image.open(io.BytesIO(png_bytes))

        # Standard RDKit depiction from SMILES.
        mol = Chem.MolFromSmiles(input_string)
        if mol is None:
            return None
        return Draw.MolToImage(mol, size=(250, 250))
    except Exception as e:
        st.error(f"Error generating molecule image: {e}")
        return None


def is_valid_smiles(smiles):
    """
    Return True if RDKit parses ``smiles`` into a molecule with at least one atom.

    ``Chem.MolFromSmiles("")`` returns an empty, zero-atom molecule rather than
    ``None``, so empty generations must be rejected explicitly.
    """
    if not smiles:
        return False
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None and mol.GetNumAtoms() > 0


def generate_smiles(num_sequences, temperature):
    """
    Sample SMILES strings from the generator, prompted with ``CORE_SMILES``.

    Args:
        num_sequences: Number of sequences to return.
        temperature: Sampling temperature; values below ``MIN_TEMPERATURE``
            are clamped up to it.

    Returns:
        List of decoded, whitespace-stripped strings (not yet validated).
    """
    input_ids = tokenizer(CORE_SMILES, return_tensors='pt').input_ids

    # Nucleus (top-p) sampling; this is not beam search (num_beams=1).
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=128,
        do_sample=True,
        temperature=max(float(temperature), MIN_TEMPERATURE),
        top_k=0,  # Disable top-k so only top-p filtering applies
        top_p=0.9,
        num_return_sequences=int(num_sequences),
        num_beams=1,
    )
    return [
        tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in output_ids
    ]


def predict_admet(df_valid):
    """
    Run ADMET-AI on ``df_valid`` and return the displayable result table.

    Args:
        df_valid: DataFrame with ``'Molecule Name'``, ``'SMILES'`` (valid and
            unique) and ``'Valid'`` columns.

    Returns:
        DataFrame indexed by ``'Molecule Name'`` containing ``'SMILES'``,
        ``'Valid'`` and whichever ``ADMET_PROPERTIES`` ADMET-AI returned.
    """
    preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

    # Ensure 'SMILES' is a column in preds. NOTE(review): this fallback assumes
    # predictions come back in input order, one row per input SMILES.
    if 'SMILES' not in preds.columns:
        preds['SMILES'] = df_valid['SMILES'].values

    # SMILES in df_valid are unique, so the merge cannot fan out rows on that side.
    df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
    df_results.set_index('Molecule Name', inplace=True)

    # Keep only the requested properties that ADMET-AI actually produced.
    missing = [p for p in ADMET_PROPERTIES if p not in df_results.columns]
    if missing:
        st.warning(f"ADMET-AI did not return these properties: {', '.join(missing)}")
    available = [p for p in ADMET_PROPERTIES if p in df_results.columns]
    return df_results[['SMILES', 'Valid'] + available]


def render_molecule(mol_name, row):
    """Display image, SMILES, SAFE encoding and ADMET properties for one result row."""
    smiles = row['SMILES']

    # Attempt to encode to SAFE
    try:
        safe_string = safe.encode(smiles)
    except Exception as e:
        safe_string = None
        st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

    img = generate_molecule_image(smiles)
    if img is not None and isinstance(img, Image.Image):
        st.image(img, caption=mol_name)
    else:
        st.error(f"Could not generate image for {mol_name}")

    # Display SMILES string
    st.write("**SMILES:**")
    st.text(smiles)

    # Display SAFE encoding (and its visualization) if available
    if safe_string:
        st.write("**SAFE Encoding:**")
        st.text(safe_string)
        safe_img = generate_molecule_image(safe_string, use_safe=True)
        if safe_img is not None:
            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

    # Display selected ADMET properties
    st.write("**ADMET Properties:**")
    st.write(row.drop(['SMILES', 'Valid']))


# Generate Molecules Button
if st.button('Generate Molecules'):
    # Spinner disappears once generation finishes (a st.info would linger).
    with st.spinner("Generating molecules... Please wait."):
        generated_smiles = generate_smiles(num_molecules, creativity)

    # Create generic molecule names for demo (Mol01, Mol02, ...)
    molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]

    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles,
    })

    # Invalid SMILES Check
    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smiles)

    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    df_valid = df_molecules[df_molecules['Valid']]
    # Identical samples would be duplicated by the SMILES merge; keep the first.
    duplicate_mask = df_valid.duplicated(subset='SMILES')
    if duplicate_mask.any():
        st.info(f"{int(duplicate_mask.sum())} duplicate molecules were removed.")
    df_valid = df_valid[~duplicate_mask].copy()

    # Check if there are valid molecules to proceed
    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        with st.spinner("Predicting ADMET properties..."):
            df_results_filtered = predict_admet(df_valid)

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            # Display Molecules in a grid of at most 3 columns
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))
            cols = st.columns(cols_per_row)
            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                with cols[idx % cols_per_row]:
                    render_molecule(mol_name, row)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")