File size: 4,378 Bytes
79d2379 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import numpy as np
import pandas as pd
import re
import selfies as sf
import torch
from rdkit import Chem
from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw
from rdkit.Chem.Crippen import MolLogP
from rdkit.Contrib.SA_Score import sascorer
from transformers import BartForConditionalGeneration, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput
# Tokenizer and BART conditional-generation model, both loaded from the same
# pretrained checkpoint id at import time.
_CHECKPOINT = "ibm/materials.selfies-ted"
gen_tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
gen_model = BartForConditionalGeneration.from_pretrained(_CHECKPOINT)
# Render a SMILES string as a molecule image
def smiles_to_image(smiles):
    """Return a drawn image of the molecule described by ``smiles``.

    The string is parsed with ``Chem.MolFromSmiles``; when that yields a
    falsy result (the string could not be parsed) ``None`` is returned
    instead of an image.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None
    return Draw.MolToImage(parsed)
def calculate_properties(smiles):
    """Compute four descriptors for the molecule given as a SMILES string.

    Returns:
        A 4-tuple ``(qed, sa, logp, wt)`` holding the QED value, the
        ``sascorer`` SA score, the Crippen LogP and the molecular weight.
        If ``smiles`` cannot be parsed, ``(None, None, None, None)``.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None, None, None, None

    # Evaluated in the same order as before: QED, LogP, SA score, weight.
    qed_value = QED.qed(parsed)
    log_p = MolLogP(parsed)
    sa_score = sascorer.calculateScore(parsed)
    mol_weight = Descriptors.MolWt(parsed)
    return qed_value, sa_score, log_p, mol_weight
# Tanimoto similarity between two molecules given as SMILES
def calculate_tanimoto(smiles1, smiles2):
    """Return the fingerprint similarity of two SMILES strings, rounded to 2 dp.

    Both strings are parsed with RDKit and compared through radius-2 Morgan
    bit-vector fingerprints using ``DataStructs.FingerprintSimilarity`` with
    its default metric. ``None`` is returned when either string fails to
    parse.
    """
    first = Chem.MolFromSmiles(smiles1)
    second = Chem.MolFromSmiles(smiles2)
    if not (first and second):
        return None

    fingerprints = [
        AllChem.GetMorganFingerprintAsBitVect(molecule, 2)
        for molecule in (first, second)
    ]
    similarity = DataStructs.FingerprintSimilarity(*fingerprints)
    return round(similarity, 2)
def _perturb_latent(latent_vecs, noise_scale=0.5):
    """Add scaled uniform noise to a tensor of latent vectors.

    Noise is drawn with NumPy's global RNG from ``uniform(0, 1)`` in the
    shape of ``latent_vecs``, multiplied by ``noise_scale``, converted to a
    float32 tensor and added to ``latent_vecs``.

    Args:
        latent_vecs: Tensor to perturb; only its ``shape`` and values are used.
        noise_scale: Multiplier applied to the uniform samples.

    Returns:
        A new tensor equal to ``noise + latent_vecs``.
    """
    jitter = np.random.uniform(0, 1, latent_vecs.shape) * noise_scale
    jitter_tensor = torch.tensor(jitter, dtype=torch.float32)
    return jitter_tensor + latent_vecs
def _encode(selfies):
    """Tokenize SELFIES input and run it through the model's encoder.

    Args:
        selfies: Text (or list of texts) handed to ``gen_tokenizer`` as is.
            The caller in this file passes a one-element list of SELFIES
            with a space inserted between adjacent tokens.

    Returns:
        A tuple ``(last_hidden_state, attention_mask)``: the encoder's final
        hidden states and the tokenizer's attention mask for the same input.
    """
    # Every input is truncated/padded to exactly 128 tokens.
    encoding = gen_tokenizer(
        selfies,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding='max_length',
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    # The hidden states are only used for generation, never for training, so
    # skip autograd bookkeeping instead of building an unused graph over the
    # whole encoder.
    with torch.no_grad():
        outputs = gen_model.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return outputs.last_hidden_state, attention_mask
def _generate(latent_vector, mask):
    """Sample sequences from the model for given encoder states and decode them.

    ``latent_vector`` is wrapped in a ``BaseModelOutput`` and passed to
    ``gen_model.generate`` as precomputed encoder outputs, together with
    ``mask`` as the attention mask. One sequence is sampled per input
    (top-k 5, top-p 0.95, at most 64 new tokens).

    Returns:
        A list with the result of ``sf.decoder`` for each decoded sequence.
    """
    sampled_ids = gen_model.generate(
        encoder_outputs=BaseModelOutput(latent_vector),
        attention_mask=mask,
        max_new_tokens=64,
        do_sample=True,
        top_k=5,
        top_p=0.95,
        num_return_sequences=1,
    )
    decoded_texts = gen_tokenizer.batch_decode(
        sampled_ids, skip_special_tokens=True
    )

    # Strips the whitespace around whatever sits between a closing "]" and
    # the next opening "[" before handing the text to the SELFIES decoder.
    bracket_gap = re.compile(r'\]\s*(.*?)\s*\[')
    decoded_molecules = []
    for text in decoded_texts:
        compact = bracket_gap.sub(r']\1[', text)
        decoded_molecules.append(sf.decoder(compact))
    return decoded_molecules
# Function to generate canonical SMILES and molecule image
def generate_canonical(smiles):
    """Generate a molecule near ``smiles`` in latent space and compare the two.

    The input is converted to SELFIES, encoded, and the encoding is perturbed
    with noise scales from 0.5 to 5.0 in steps of 0.1. The search stops at
    the first generated molecule whose canonical SMILES differs from the
    canonical form of the input. If every valid candidate equals the input,
    the last valid candidate is still reported.

    Args:
        smiles: SMILES string of the reference molecule.

    Returns:
        ``(df, gen_mol, mol_image)`` on success, where ``df`` is a pandas
        DataFrame with QED, SA, LogP, molecular weight and Tanimoto
        similarity rows, ``gen_mol`` is the canonical SMILES of the generated
        molecule and ``mol_image`` its rendered image.
        ``("Invalid SMILES", None, None)`` when the input is empty, not a
        string, cannot be parsed or encoded, or when no valid molecule was
        generated.
    """
    invalid = ("Invalid SMILES", None, None)

    # Reject unusable input up front; previously this crashed inside the
    # SELFIES encoder or in MolToSmiles(None) instead of returning the
    # sentinel.
    if not isinstance(smiles, str) or not smiles.strip():
        return invalid
    ref_mol = Chem.MolFromSmiles(smiles)
    if ref_mol is None:
        return invalid
    # Loop-invariant: the canonical reference is computed once, not per try.
    ref_canonical = Chem.MolToSmiles(ref_mol)

    try:
        s = sf.encoder(smiles)
    except sf.EncoderError:
        # Parsable by RDKit but not representable as SELFIES.
        return invalid
    if not s:
        return invalid

    # The encoder input uses a space between adjacent SELFIES tokens.
    selfie = s.replace("][", "] [")
    latent_vec, mask = _encode([selfie])

    gen_mol = None
    for i in range(5, 51):
        print("Searching Latent space")
        noise = i / 10
        perturbed_latent = _perturb_latent(latent_vec, noise_scale=noise)
        gen = _generate(perturbed_latent, mask)
        candidate = gen[0] if gen else None
        # An empty decode would parse to an atom-less molecule and used to end
        # the search with an empty result; treat it as a failed attempt.
        mol = Chem.MolFromSmiles(candidate) if candidate else None
        if not mol:
            print('Abnormal molecule:', candidate)
            continue
        gen_mol = Chem.MolToSmiles(mol)
        if gen_mol != ref_canonical:
            break

    if not gen_mol:
        return invalid

    # Calculate properties for ref and gen molecules
    print("calculating properties")
    ref_properties = calculate_properties(smiles)
    gen_properties = calculate_properties(gen_mol)
    tanimoto_similarity = calculate_tanimoto(smiles, gen_mol)

    # Prepare the table with ref mol and gen mol. The similarity is a single
    # value for the pair; it is shown in the reference column and the
    # generated column is left blank for that row.
    data = {
        "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
        "Reference Mol": [*ref_properties, tanimoto_similarity],
        "Generated Mol": [*gen_properties, ""],
    }
    df = pd.DataFrame(data)

    # Display molecule image of canonical smiles
    print("Getting image")
    mol_image = smiles_to_image(gen_mol)
    return df, gen_mol, mol_image
|