|
import numpy as np |
|
import pandas as pd |
|
import re |
|
import selfies as sf |
|
import torch |
|
from rdkit import Chem |
|
from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw |
|
from rdkit.Chem.Crippen import MolLogP |
|
from rdkit.Contrib.SA_Score import sascorer |
|
from transformers import BartForConditionalGeneration, AutoTokenizer |
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
# The tokenizer and the seq2seq model are loaded from the same Hugging Face
# checkpoint, so the identifier is spelled out only once.
_CHECKPOINT = "ibm/materials.selfies-ted"

# Module-level singletons shared by the encode/generate helpers below.
gen_tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
gen_model = BartForConditionalGeneration.from_pretrained(_CHECKPOINT)
|
|
|
|
|
|
|
def smiles_to_image(smiles):
    """Render a SMILES string as an image.

    Args:
        smiles: SMILES string to draw.

    Returns:
        The result of ``Draw.MolToImage`` for the parsed molecule, or ``None``
        when RDKit does not produce a usable molecule from ``smiles``.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None
    return Draw.MolToImage(parsed)
|
|
|
|
|
def calculate_properties(smiles):
    """Compute a small set of descriptors for a molecule.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A 4-tuple ``(qed, sa, logp, mol_wt)`` holding the QED score, the
        ``sascorer`` synthetic-accessibility score, the Crippen LogP and the
        molecular weight. All four entries are ``None`` when the SMILES
        cannot be parsed.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None, None, None, None
    return (
        QED.qed(parsed),
        sascorer.calculateScore(parsed),
        MolLogP(parsed),
        Descriptors.MolWt(parsed),
    )
|
|
|
|
|
|
|
def calculate_tanimoto(smiles1, smiles2):
    """Return the fingerprint similarity of two molecules, rounded to 2 places.

    Both SMILES strings are parsed with RDKit, converted to Morgan bit-vector
    fingerprints (second argument ``2``, i.e. the radius) and compared with
    ``DataStructs.FingerprintSimilarity`` using its default metric.

    Args:
        smiles1: SMILES string of the first molecule.
        smiles2: SMILES string of the second molecule.

    Returns:
        The similarity rounded to two decimals, or ``None`` if either SMILES
        string cannot be parsed.
    """
    first = Chem.MolFromSmiles(smiles1)
    second = Chem.MolFromSmiles(smiles2)
    if not (first and second):
        return None
    fingerprints = [
        AllChem.GetMorganFingerprintAsBitVect(molecule, 2)
        for molecule in (first, second)
    ]
    similarity = DataStructs.FingerprintSimilarity(*fingerprints)
    return round(similarity, 2)
|
|
|
|
|
def _perturb_latent(latent_vecs, noise_scale=0.5):
    """Add scaled uniform random noise to a latent tensor.

    The noise is sampled with NumPy's global RNG from ``uniform(0, 1)`` and
    multiplied by ``noise_scale``, so seeding ``np.random`` makes the result
    reproducible.

    Args:
        latent_vecs: Tensor of latent vectors to perturb; it is not modified.
        noise_scale: Multiplier applied to the sampled noise.

    Returns:
        A new tensor with the same shape as ``latent_vecs`` equal to
        ``noise + latent_vecs``.
    """
    noise = torch.tensor(
        np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
        dtype=torch.float32,
        # Create the noise next to the latents; a CPU-only noise tensor cannot
        # be added to latents that live on another device.
        device=latent_vecs.device,
    )
    return noise + latent_vecs
|
|
|
|
|
def _encode(selfies):
    """Tokenize SELFIES text and run it through the model's encoder.

    Inputs are truncated and padded to a fixed length of 128 tokens before
    being passed to ``gen_model.model.encoder``.

    Args:
        selfies: Text (or list of texts) accepted by ``gen_tokenizer``; the
            caller in this file passes space-separated SELFIES tokens.

    Returns:
        A tuple ``(last_hidden_state, attention_mask)``: the encoder's final
        hidden states and the tokenizer's attention mask for the same input.
    """
    encoding = gen_tokenizer(
        selfies,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding='max_length',
    )
    attention_mask = encoding['attention_mask']
    # The latent vectors are only used for inference, so skip autograd
    # bookkeeping instead of building a graph that is never back-propagated.
    with torch.no_grad():
        encoder_outputs = gen_model.model.encoder(
            input_ids=encoding['input_ids'],
            attention_mask=attention_mask,
        )
    return encoder_outputs.last_hidden_state, attention_mask
|
|
|
|
|
def _generate(latent_vector, mask):
    """Sample sequences from encoder states and convert them with ``sf.decoder``.

    Args:
        latent_vector: Encoder hidden states handed to the decoder in place of
            freshly computed encoder outputs.
        mask: Attention mask that belongs to ``latent_vector``.

    Returns:
        A list with one ``sf.decoder`` result per generated sequence.
    """
    sampled_ids = gen_model.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=latent_vector),
        attention_mask=mask,
        max_new_tokens=64,
        do_sample=True,
        top_k=5,
        top_p=0.95,
        num_return_sequences=1,
    )
    decoded_texts = gen_tokenizer.batch_decode(
        sampled_ids, skip_special_tokens=True
    )

    results = []
    for text in decoded_texts:
        # Strip the whitespace that surrounds whatever sits between a closing
        # "]" and the next opening "[" before handing the text to selfies.
        compact = re.sub(r'\]\s*(.*?)\s*\[', r']\1[', text)
        results.append(sf.decoder(compact))
    return results
|
|
|
|
|
|
|
def _build_comparison_table(ref_properties, gen_properties, tanimoto_similarity):
    """Assemble the reference-vs-generated property table as a DataFrame."""
    return pd.DataFrame(
        {
            "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
            # The similarity is a single value for the pair, so it is shown in
            # the reference column and the generated column is left blank.
            "Reference Mol": [*ref_properties[:4], tanimoto_similarity],
            "Generated Mol": [*gen_properties[:4], ""],
        }
    )


def generate_canonical(smiles):
    """Generate a molecule near ``smiles`` by perturbing its latent encoding.

    The reference is encoded once. Its latent vectors are then perturbed with
    noise scales 0.5, 0.6, ..., 5.0, decoding one candidate per scale. The
    search stops at the first candidate that RDKit can parse and whose
    canonical SMILES differs from the reference. If every parsable candidate
    equals the reference, the last parsable one is used.

    Args:
        smiles: SMILES string of the reference molecule.

    Returns:
        ``(df, gen_mol, mol_image)`` on success, where ``df`` compares the
        properties of the reference and generated molecules, ``gen_mol`` is
        the canonical SMILES of the generated molecule and ``mol_image`` is
        its rendering. ``("Invalid SMILES", None, None)`` is returned when
        RDKit cannot parse the reference or no usable candidate was produced.
    """
    # Called first so that any error selfies raises for the input surfaces
    # exactly as it did before.
    reference_selfies = sf.encoder(smiles)

    # Canonicalise the reference once instead of on every loop iteration, and
    # bail out before any model work if RDKit cannot parse it.
    ref_mol = Chem.MolFromSmiles(smiles)
    if ref_mol is None:
        return "Invalid SMILES", None, None
    ref_canonical = Chem.MolToSmiles(ref_mol)

    # Separate adjacent SELFIES tokens with spaces before tokenizing.
    selfie = reference_selfies.replace("][", "] [")
    latent_vec, mask = _encode([selfie])

    gen_mol = None
    for step in range(5, 51):
        print("Searching Latent space")
        perturbed_latent = _perturb_latent(latent_vec, noise_scale=step / 10)
        candidate = _generate(perturbed_latent, mask)[0]
        mol = Chem.MolFromSmiles(candidate)
        if mol:
            gen_mol = Chem.MolToSmiles(mol)
            if gen_mol != ref_canonical:
                break
        else:
            print('Abnormal molecule:', candidate)

    # An empty SMILES string is treated the same as "nothing generated".
    if not gen_mol:
        return "Invalid SMILES", None, None

    print("calculating properties")
    df = _build_comparison_table(
        calculate_properties(smiles),
        calculate_properties(gen_mol),
        calculate_tanimoto(smiles, gen_mol),
    )

    print("Getting image")
    mol_image = smiles_to_image(gen_mol)

    return df, gen_mol, mol_image
|
|