import re

import numpy as np
import pandas as pd
import selfies as sf
import torch
from rdkit import Chem
from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw
from rdkit.Chem.Crippen import MolLogP
from rdkit.Contrib.SA_Score import sascorer
from transformers import BartForConditionalGeneration, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput

# Tokenizer and BART model are loaded once at import time and shared by the
# helper functions below.
gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")


def smiles_to_image(smiles):
    """Render a SMILES string as a molecule image.

    Args:
        smiles: SMILES string to draw.

    Returns:
        The image produced by ``Draw.MolToImage``, or ``None`` when RDKit
        cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return Draw.MolToImage(mol)


def calculate_properties(smiles):
    """Compute QED, SA score, LogP and molecular weight for a SMILES string.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A tuple ``(qed, sa, logp, wt)``, or ``(None, None, None, None)`` when
        RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, None, None, None
    qed = QED.qed(mol)
    logp = MolLogP(mol)
    sa = sascorer.calculateScore(mol)
    wt = Descriptors.MolWt(mol)
    return qed, sa, logp, wt


def calculate_tanimoto(smiles1, smiles2):
    """Return the fingerprint similarity of two molecules, rounded to 2 places.

    Both molecules are converted to Morgan bit-vector fingerprints (radius 2)
    and compared with ``DataStructs.FingerprintSimilarity``.

    Args:
        smiles1: SMILES string of the first molecule.
        smiles2: SMILES string of the second molecule.

    Returns:
        The rounded similarity, or ``None`` when either SMILES cannot be
        parsed by RDKit.
    """
    mol1 = Chem.MolFromSmiles(smiles1)
    mol2 = Chem.MolFromSmiles(smiles2)
    if mol1 is None or mol2 is None:
        return None
    fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2)
    fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2)
    return round(DataStructs.FingerprintSimilarity(fp1, fp2), 2)


def _perturb_latent(latent_vecs, noise_scale=0.5):
    """Add uniform noise drawn from ``[0, noise_scale)`` to a latent tensor.

    The noise is sampled with ``np.random.uniform`` (so it follows NumPy's
    global random state) and is created on the same device as ``latent_vecs``
    so the addition also works when the latent does not live on the CPU.

    Args:
        latent_vecs: Tensor of encoder hidden states to perturb.
        noise_scale: Multiplier applied to the ``[0, 1)`` uniform samples.

    Returns:
        A new tensor ``latent_vecs + noise``.
    """
    noise = np.random.uniform(0, 1, latent_vecs.shape) * noise_scale
    noise_tensor = torch.tensor(
        noise,
        dtype=torch.float32,
        device=latent_vecs.device,
    )
    return noise_tensor + latent_vecs


def _encode(selfies):
    """Encode SELFIES strings into encoder hidden states.

    Inputs are tokenized with padding/truncation to 128 tokens and passed
    through the BART encoder of ``gen_model``.

    Args:
        selfies: A SELFIES string, or list of strings, in the tokenizer's
            expected format.

    Returns:
        A tuple ``(last_hidden_state, attention_mask)``.
    """
    encoding = gen_tokenizer(
        selfies,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding='max_length',
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    # Inference only: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        outputs = gen_model.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
    return outputs.last_hidden_state, attention_mask


def _generate(latent_vector, mask):
    """Decode latent vectors into SMILES strings by sampling from the model.

    Args:
        latent_vector: Encoder hidden states (possibly perturbed) to decode.
        mask: Attention mask matching ``latent_vector``.

    Returns:
        A list with one SMILES string per generated sequence, obtained by
        decoding the generated SELFIES with ``sf.decoder``.
    """
    encoder_outputs = BaseModelOutput(latent_vector)
    decoder_output = gen_model.generate(
        encoder_outputs=encoder_outputs,
        attention_mask=mask,
        max_new_tokens=64,
        do_sample=True,
        top_k=5,
        top_p=0.95,
        num_return_sequences=1,
    )
    decoded = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
    # Strip the whitespace the tokenizer leaves between "]" and "[" so the
    # text is a contiguous SELFIES string before converting it to SMILES.
    return [
        sf.decoder(re.sub(r'\]\s*(.*?)\s*\[', r']\1[', text))
        for text in decoded
    ]


def generate_canonical(smiles):
    """Generate a molecule near ``smiles`` in latent space and compare the two.

    The input is encoded, its latent representation is perturbed with
    increasing noise (scale 0.5 up to 5.0 in steps of 0.1) and decoded until
    a valid molecule whose canonical SMILES differs from the input's is
    found. If every valid candidate equals the input, the last valid
    candidate is still used.

    Args:
        smiles: SMILES string of the reference molecule.

    Returns:
        ``(df, gen_smiles, image)`` where ``df`` is a ``pandas.DataFrame``
        comparing QED, SA, LogP, molecular weight and similarity of the
        reference and generated molecules, ``gen_smiles`` is the canonical
        SMILES of the generated molecule and ``image`` is its rendering.
        ``("Invalid SMILES", None, None)`` is returned when the input cannot
        be parsed by RDKit (or is empty) or when no valid molecule could be
        generated.

    Note:
        Errors raised by ``sf.encoder`` for SMILES that RDKit accepts but
        selfies does not are not caught here.
    """
    reference = Chem.MolFromSmiles(smiles)
    # Reject unparseable or empty input up front instead of failing later
    # inside the search loop.
    if reference is None or reference.GetNumAtoms() == 0:
        return "Invalid SMILES", None, None
    # Loop-invariant: canonical form of the reference, computed once.
    reference_canonical = Chem.MolToSmiles(reference)

    spaced_selfies = sf.encoder(smiles).replace("][", "] [")
    latent_vec, mask = _encode([spaced_selfies])

    gen_mol = None
    for step in range(5, 51):
        print("Searching Latent space")
        perturbed_latent = _perturb_latent(latent_vec, noise_scale=step / 10)
        candidate = _generate(perturbed_latent, mask)[0]
        mol = Chem.MolFromSmiles(candidate)
        if mol is None:
            print('Abnormal molecule:', candidate)
            continue
        gen_mol = Chem.MolToSmiles(mol)
        if gen_mol != reference_canonical:
            break

    if not gen_mol:
        return "Invalid SMILES", None, None

    # Calculate properties for the reference and generated molecules.
    print("calculating properties")
    ref_properties = calculate_properties(smiles)
    gen_properties = calculate_properties(gen_mol)
    tanimoto_similarity = calculate_tanimoto(smiles, gen_mol)

    # One row per property; the similarity is reported once, in the
    # reference column, and left blank in the generated column.
    data = {
        "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
        "Reference Mol": [*ref_properties, tanimoto_similarity],
        "Generated Mol": [*gen_properties, ""],
    }
    df = pd.DataFrame(data)

    # Render the generated molecule from its canonical SMILES.
    print("Getting image")
    mol_image = smiles_to_image(gen_mol)
    return df, gen_mol, mol_image