File size: 2,973 Bytes
6c9555d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import sys
import torch
import selfies as sf  # selfies>=2.1.1
import pickle
import pandas as pd
import numpy as np
from datasets import Dataset
from rdkit import Chem
from transformers import AutoTokenizer, AutoModel


class SELFIES(torch.nn.Module):
    """Embed SMILES strings via the ibm/materials.selfies-ted encoder.

    Pipeline: SMILES -> SELFIES token string -> mean-pooled encoder
    hidden states. Indices of SMILES that cannot be converted are kept in
    ``self.invalid`` and their embeddings are replaced by NaN in ``encode``.
    """

    def __init__(self):
        super().__init__()
        self.model = None       # transformers AutoModel; set by load()
        self.tokenizer = None   # transformers AutoTokenizer; set by load()
        self.invalid = []       # indices of unencodable SMILES from the last get_selfies() call

    def get_selfies(self, smiles_list):
        """Convert a list of SMILES strings to space-separated SELFIES strings.

        If direct encoding fails, retries after RDKit canonicalization;
        entries that still fail become the placeholder "[]" and their
        indices are recorded in ``self.invalid``.

        :param smiles_list: iterable of SMILES strings
        :return: list of SELFIES strings with tokens separated by spaces
        """
        self.invalid = []
        spaced_selfies_batch = []
        for i, smiles in enumerate(smiles_list):
            try:
                selfies = sf.encoder(smiles.rstrip())
            except Exception:
                # Retry after RDKit round-trip canonicalization.
                # MolFromSmiles returns None for invalid SMILES, which makes
                # MolToSmiles raise and lands us in the fallback below.
                try:
                    canonical = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.rstrip()))
                    selfies = sf.encoder(canonical)
                except Exception:
                    selfies = "[]"
                    self.invalid.append(i)

            # Insert spaces between SELFIES symbols so the tokenizer sees
            # one token per symbol.
            spaced_selfies_batch.append(selfies.replace('][', '] ['))

        return spaced_selfies_batch


    def get_embedding(self, selfies):
        """Compute mean-pooled encoder embeddings for a batch of SELFIES.

        :param selfies: dataset batch dict with a "selfies" column
        :return: the batch dict with an added "embedding" entry
        """
        encoding = self.tokenizer(selfies["selfies"], return_tensors='pt', max_length=128, truncation=True, padding='max_length')
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']

        # Inference only: no_grad avoids building an autograd graph,
        # saving memory and compute.
        with torch.no_grad():
            outputs = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        model_output = outputs.last_hidden_state

        # Mean pooling over non-padding tokens only; clamp guards against
        # division by zero for an all-padding row.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
        sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        model_output = sum_embeddings / sum_mask

        # Drop tensor columns the downstream dataset cannot hold.
        del encoding['input_ids']
        del encoding['attention_mask']

        encoding["embedding"] = model_output

        return encoding


    def load(self, checkpoint="bart-2908.pickle"):
        """Load tokenizer and model weights from the Hugging Face hub.

        :param checkpoint: unused; kept for backward compatibility with
            callers that pass a pickle path.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
        self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")


    # TODO: remove `use_gpu` argument in validation pipeline
    def encode(self, smiles_list=None, use_gpu=False, return_tensor=False):
        """Embed a list of SMILES strings.

        :param smiles_list: list of SMILES strings (default: empty list)
        :param use_gpu: unused; slated for removal (see TODO above)
        :param return_tensor: if True return a torch.Tensor,
            otherwise a pandas DataFrame
        :return: embeddings, one row per input SMILES; rows for
            unencodable SMILES are NaN
        """
        # Avoid the shared-mutable-default pitfall of `smiles_list=[]`.
        if smiles_list is None:
            smiles_list = []

        selfies = self.get_selfies(smiles_list)
        selfies_df = pd.DataFrame(selfies, columns=["selfies"])
        data = Dataset.from_pandas(selfies_df)
        embedding = data.map(self.get_embedding, batched=True, num_proc=1, batch_size=128)
        emb = np.asarray(embedding["embedding"].copy())

        for idx in self.invalid:
            emb[idx] = np.nan
            print("Cannot encode {0} to selfies and embedding replaced by NaN".format(smiles_list[idx]))

        if return_tensor:
            return torch.tensor(emb)
        return pd.DataFrame(emb)