# test/molecule_generation_helpers.py
import numpy as np
import pandas as pd
import re
import selfies as sf
import torch
from rdkit import Chem
from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw
from rdkit.Chem.Crippen import MolLogP
from rdkit.Contrib.SA_Score import sascorer
from transformers import BartForConditionalGeneration, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput

# Load the SELFIES-TED tokenizer and BART-based generator once at import time
gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")
# Function to display molecule image from SMILES
def smiles_to_image(smiles):
mol = Chem.MolFromSmiles(smiles)
return Draw.MolToImage(mol) if mol else None

# Function to calculate QED, synthetic accessibility (SA), LogP, and molecular weight
def calculate_properties(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol:
qed = QED.qed(mol)
logp = MolLogP(mol)
sa = sascorer.calculateScore(mol)
wt = Descriptors.MolWt(mol)
return qed, sa, logp, wt
return None, None, None, None
# Function to calculate Tanimoto similarity
def calculate_tanimoto(smiles1, smiles2):
mol1 = Chem.MolFromSmiles(smiles1)
mol2 = Chem.MolFromSmiles(smiles2)
if mol1 and mol2:
fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2)
fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2)
return round(DataStructs.FingerprintSimilarity(fp1, fp2), 2)
return None
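
# Illustrative usage (assumption, not from the original file): both helpers degrade
# gracefully when RDKit cannot parse the input, e.g.
#   calculate_properties("CCO")        -> (qed, sa, logp, mol_wt) floats for ethanol
#   calculate_properties("not-smiles") -> (None, None, None, None)
#   calculate_tanimoto("CCO", "CCCO")  -> Morgan-fingerprint Tanimoto, rounded to 2 decimals
#   calculate_tanimoto("CCO", "x")     -> None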

# Add scaled uniform random noise to the latent vectors to explore nearby molecules
def _perturb_latent(latent_vecs, noise_scale=0.5):
return (
torch.tensor(
np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
dtype=torch.float32,
)
+ latent_vecs
)

# Encode a batch of SELFIES strings into encoder hidden states and an attention mask
def _encode(selfies):
encoding = gen_tokenizer(
selfies,
return_tensors='pt',
max_length=128,
truncation=True,
padding='max_length',
)
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']
outputs = gen_model.model.encoder(
input_ids=input_ids, attention_mask=attention_mask
)
model_output = outputs.last_hidden_state
return model_output, attention_mask

# Sample SMILES from the decoder conditioned on (possibly perturbed) encoder states
def _generate(latent_vector, mask):
encoder_outputs = BaseModelOutput(latent_vector)
decoder_output = gen_model.generate(
encoder_outputs=encoder_outputs,
attention_mask=mask,
max_new_tokens=64,
do_sample=True,
top_k=5,
top_p=0.95,
num_return_sequences=1,
)
    selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
    # Remove the spaces between SELFIES tokens, then decode each string back to SMILES
    return [sf.decoder(re.sub(r'\]\s*(.*?)\s*\[', r']\1[', i)) for i in selfies]

# Function to generate a new canonical SMILES and molecule image from a reference SMILES
def generate_canonical(smiles):
    ref_mol = Chem.MolFromSmiles(smiles)
    if ref_mol is None:
        return "Invalid SMILES", None, None
    ref_canonical = Chem.MolToSmiles(ref_mol)
    # The tokenizer expects space-separated SELFIES tokens
    selfie = sf.encoder(smiles).replace("][", "] [")
    latent_vec, mask = _encode([selfie])
    gen_mol = None
    # Increase the noise scale step by step until a valid, distinct molecule is decoded
    for i in range(5, 51):
        print("Searching latent space")
        noise = i / 10
        perturbed_latent = _perturb_latent(latent_vec, noise_scale=noise)
        gen = _generate(perturbed_latent, mask)
        mol = Chem.MolFromSmiles(gen[0])
        if mol:
            gen_mol = Chem.MolToSmiles(mol)
            if gen_mol != ref_canonical:
                break
        else:
            print("Abnormal molecule:", gen[0])
if gen_mol:
        # Calculate properties of the reference and generated molecules
        print("Calculating properties")
ref_properties = calculate_properties(smiles)
gen_properties = calculate_properties(gen_mol)
tanimoto_similarity = calculate_tanimoto(smiles, gen_mol)
# Prepare the table with ref mol and gen mol
data = {
"Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
"Reference Mol": [
ref_properties[0],
ref_properties[1],
ref_properties[2],
ref_properties[3],
tanimoto_similarity,
],
"Generated Mol": [
gen_properties[0],
gen_properties[1],
gen_properties[2],
gen_properties[3],
"",
],
}
df = pd.DataFrame(data)
        # Render the molecule image for the generated canonical SMILES
        print("Getting image")
mol_image = smiles_to_image(gen_mol)
return df, gen_mol, mol_image
return "Invalid SMILES", None, None