File size: 4,378 Bytes
79d2379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy as np
import pandas as pd
import re
import selfies as sf
import torch
from rdkit import Chem
from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw
from rdkit.Chem.Crippen import MolLogP
from rdkit.Contrib.SA_Score import sascorer
from transformers import BartForConditionalGeneration, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput

# Hugging Face checkpoint for the SELFIES-TED encoder-decoder used below.
# NOTE: both objects are loaded (and, if not cached, presumably downloaded)
# at import time of this module.
_SELFIES_TED_CHECKPOINT = "ibm/materials.selfies-ted"
gen_tokenizer = AutoTokenizer.from_pretrained(_SELFIES_TED_CHECKPOINT)
gen_model = BartForConditionalGeneration.from_pretrained(_SELFIES_TED_CHECKPOINT)


def smiles_to_image(smiles):
    """Draw the molecule described by *smiles* with RDKit.

    Returns the image produced by ``Draw.MolToImage``, or ``None`` when RDKit
    cannot parse the SMILES string.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    return Draw.MolToImage(molecule)


def calculate_properties(smiles):
    """Compute descriptor values for a SMILES string.

    Returns:
        A 4-tuple ``(qed, sa, logp, mol_wt)``: RDKit QED score, synthetic
        accessibility score (``sascorer``), Crippen LogP and molecular weight.
        When RDKit cannot parse *smiles*, all four entries are ``None``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return (None,) * 4
    return (
        QED.qed(mol),
        sascorer.calculateScore(mol),
        MolLogP(mol),
        Descriptors.MolWt(mol),
    )


def calculate_tanimoto(smiles1, smiles2):
    """Return the Tanimoto similarity of two molecules, rounded to 2 decimals.

    Both molecules are fingerprinted with a radius-2, 2048-bit Morgan
    fingerprint (the defaults of the legacy
    ``GetMorganFingerprintAsBitVect(mol, 2)`` call this replaces).

    Returns:
        The similarity as a float, or ``None`` if either SMILES string cannot
        be parsed by RDKit.
    """
    # The generator API replaces AllChem.GetMorganFingerprintAsBitVect, which
    # recent RDKit releases deprecate (emitting a warning on every call).
    from rdkit.Chem import rdFingerprintGenerator

    mol1 = Chem.MolFromSmiles(smiles1)
    mol2 = Chem.MolFromSmiles(smiles2)
    if not (mol1 and mol2):
        return None
    generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
    fp1 = generator.GetFingerprint(mol1)
    fp2 = generator.GetFingerprint(mol2)
    return round(DataStructs.FingerprintSimilarity(fp1, fp2), 2)


def _perturb_latent(latent_vecs, noise_scale=0.5):
    return (
        torch.tensor(
            np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
            dtype=torch.float32,
        )
        + latent_vecs
    )


def _encode(selfies, max_length=128):
    """Run SELFIES input through the SELFIES-TED encoder.

    Args:
        selfies: input passed straight to ``gen_tokenizer``. Callers in this
            file pass a one-element list of space-separated SELFIES tokens.
        max_length: pad/truncate length used for tokenization (default 128).

    Returns:
        ``(last_hidden_state, attention_mask)``. The encoder's final hidden
        states and the tokenizer's attention mask, both on ``gen_model``'s
        device.
    """
    encoding = gen_tokenizer(
        selfies,
        return_tensors='pt',
        max_length=max_length,
        truncation=True,
        padding='max_length',
    )
    # Keep inputs on the same device as the model (a no-op on CPU).
    device = gen_model.device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    # Inference only. Without no_grad the encoder records an autograd graph
    # that the returned tensor, and every perturbation made from it, would
    # keep alive.
    with torch.no_grad():
        outputs = gen_model.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return outputs.last_hidden_state, attention_mask


# Collapses whitespace between adjacent SELFIES tokens ("] [" -> "][").
_SELFIES_TOKEN_GAP = re.compile(r'\]\s*(.*?)\s*\[')


def _generate(latent_vector, mask):
    """Decode encoder hidden states back into SMILES strings.

    The hidden states are wrapped as encoder outputs for ``gen_model.generate``.
    Decoding uses top-k/top-p sampling (one sequence per input, at most 64 new
    tokens). The generated SELFIES are joined back into contiguous tokens and
    converted to SMILES with ``sf.decoder``.

    Returns:
        A list with one SMILES string per generated sequence.
    """
    generated_ids = gen_model.generate(
        encoder_outputs=BaseModelOutput(latent_vector),
        attention_mask=mask,
        max_new_tokens=64,
        do_sample=True,
        top_k=5,
        top_p=0.95,
        num_return_sequences=1,
    )
    decoded_texts = gen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    smiles_list = []
    for text in decoded_texts:
        compact_selfies = _SELFIES_TOKEN_GAP.sub(r']\1[', text)
        smiles_list.append(sf.decoder(compact_selfies))
    return smiles_list


def _search_novel_molecule(selfie, ref_canonical):
    """Sample perturbed latents of *selfie* until a valid, novel molecule appears.

    The noise scale grows from 0.5 to 5.0 in steps of 0.1 (46 attempts).

    Returns:
        The canonical SMILES of the first valid decoded molecule that differs
        from *ref_canonical*. If none differs, returns the last valid one
        decoded, which may equal the reference. Returns ``None`` if no decoded
        string was a valid, non-empty molecule.
    """
    latent_vec, mask = _encode([selfie])
    gen_mol = None
    for step in range(5, 51):
        print("Searching Latent space")
        perturbed_latent = _perturb_latent(latent_vec, noise_scale=step / 10)
        candidate = _generate(perturbed_latent, mask)[0]
        mol = Chem.MolFromSmiles(candidate)
        # An empty decode parses to an empty molecule. Treat it as abnormal
        # instead of letting it end the search as a "novel" (empty) result.
        if not mol or mol.GetNumAtoms() == 0:
            print('Abnormal molecule:', candidate)
            continue
        gen_mol = Chem.MolToSmiles(mol)
        if gen_mol != ref_canonical:
            break
    return gen_mol


def _property_table(ref_smiles, gen_smiles):
    """Build the reference-vs-generated property comparison DataFrame."""
    ref_properties = calculate_properties(ref_smiles)
    gen_properties = calculate_properties(gen_smiles)
    tanimoto_similarity = calculate_tanimoto(ref_smiles, gen_smiles)
    return pd.DataFrame(
        {
            "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
            # The pairwise similarity is shown once, in the reference column.
            "Reference Mol": [*ref_properties, tanimoto_similarity],
            "Generated Mol": [*gen_properties, ""],
        }
    )


def generate_canonical(smiles):
    """Generate a molecule near *smiles* in SELFIES-TED latent space.

    Despite its name, this does not canonicalize the input. It encodes the
    molecule, samples noisy latents, and returns a decoded molecule (in
    canonical SMILES form) together with a property comparison.

    Returns:
        ``(df, gen_smiles, image)``. ``df`` compares QED, SA, LogP, molecular
        weight and Tanimoto similarity for the reference and generated
        molecules. ``image`` is the generated molecule's drawing.
        ``("Invalid SMILES", None, None)`` is returned when the input cannot
        be parsed or encoded as SELFIES, or when no valid molecule could be
        generated.
    """
    # Validate first. sf.encoder raises on unparsable input, so without this
    # check the "Invalid SMILES" result was unreachable for bad input.
    ref_mol = Chem.MolFromSmiles(smiles) if smiles else None
    if not ref_mol or ref_mol.GetNumAtoms() == 0:
        return "Invalid SMILES", None, None
    try:
        selfie = sf.encoder(smiles).replace("][", "] [")
    except sf.EncoderError:
        return "Invalid SMILES", None, None

    # Loop-invariant: canonical form of the reference molecule.
    ref_canonical = Chem.MolToSmiles(ref_mol)
    gen_mol = _search_novel_molecule(selfie, ref_canonical)
    if not gen_mol:
        return "Invalid SMILES", None, None

    print("calculating properties")
    df = _property_table(smiles, gen_mol)

    print("Getting image")
    mol_image = smiles_to_image(gen_mol)

    return df, gen_mol, mol_image