import warnings

warnings.filterwarnings('ignore')

import os

from utils.distribution import DeviceManager

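# Pin all computation to the CPU; the viewer does not assume a GPU is available.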
DeviceManager('cpu')

# from config import args
import pandas as pd

from emblaze import ProjectionTechnique
from embedding_produce import EmbeddingProducer
from utils.mol_to_graph import mol_to_graph_data_obj_asAtomNum

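# RANK = -1 conventionally marks a single-process (non-distributed) run.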
os.environ['RANK'] = '-1'

import cv2
import umap
from rdkit.Chem import AllChem
from rdkit.Chem import Draw as chem_draw
from torch_geometric.data import Batch

from utils.emblaze_utils import to_EmbeddingSet, to_Embedding
from utils.geometric_graph import to_dense, fetch_geometric_batch
from utils.tensor_operator import to_device, tensor2array
from utils.utils import print_something_first

import torch
import numpy as np
import emblaze as m_emb
from visualizer.visualizers import plt_show, vis_cellular_img

logger = print


class CMEViewer(EmbeddingProducer):
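    """Cross-modal embedding viewer built on EmbeddingProducer and emblaze.

    Loads per-modality embeddings (molecular graphs, cellular images),
    projects them with UMAP, and supports adding new SMILES / image samples
    and inspecting their nearest neighbors. See the __main__ block for usage.
    """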
    def __init__(self, *args, **kwargs):
        super(CMEViewer, self).__init__(*args, **kwargs)
        # Initialize the number of samples
        self.samples_num = self.cfg.modal.samples_num
        # Initialize the list of embedding label indices
        self.embedding_label_idx_list = [self.modal_nums]
        # Initialize the original color array
        self.ori_color_array = None
        # Initialize the number of neighbors
        self.n_neighbors = self.cfg.modal.n_neighbors
        # Initialize the viewer dictionary
        self.viewer_dict = {'merge_modal_viewer': None, 'cross_modal_viewer': None}
        # Initialize the original view
        self.init_ori_view()

    @print_something_first
    def init_img_list(self):
        self.img_list = []
        for m_i, modal_name in enumerate(self.modal_name_list):
            modal_imgs_dir = os.path.join(self.cfg.modal.modal_imgs_dir, f'omics_{modal_name}_vis')
            for i in range(self.samples_num):
                read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
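                # cv2.imread returns BGR; reversing the channel axis yields RGB for display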
                img = cv2.imread(read_path)[..., ::-1]
                self.img_list.append(img)

    def update_img_list(self, new_img_list):
        if not isinstance(new_img_list, list):
            new_img_list = [new_img_list]
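        # Record the indices the new images will occupy so they can be fetched by id later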
        self.graph_img_index.extend(list(range(len(self.img_list), len(self.img_list) + len(new_img_list))))
        self.img_list.extend(new_img_list)

    @print_something_first
    def init_embeddings(self):
        # Load the required modal names from the configuration
        modal_name_list = self.cfg.modal.modal_name_list
        # Get the directory of the embeddings
        embeddings_dir = self.cfg.modal.embeddings_dir
        # Create a list to store the loaded embeddings
        embeddings_list = []
        # Iterate through the modal names and load the corresponding embeddings
        for name in modal_name_list:
            embeddings_list.append(
                np.load(os.path.join(embeddings_dir, f"{name}_embedding.npy"))[:self.cfg.modal.samples_num])
        # Concatenate the embeddings into one array
        self.embeddings = np.concatenate(embeddings_list, axis=0)
        # Build the color list: one modal-name label per sample
        self.color_list = []
        for m_i, m_name in enumerate(self.modal_name_list):
            self.color_list.extend([m_name] * self.samples_num)

    def init_ori_view(self):
        # Initialize embeddings
        self.init_embeddings()
        # Initialize transformation from original embeddings
        self.init_transformation_from_ori_embeddings()
        # Get reduced embeddings
        self.reduced_embeddings = self.trans.embedding_
        # Get modal name list
        modal_name_list = self.cfg.modal.modal_name_list
        # Get the index of the graph modality
        graph_cls_num = modal_name_list.index('graph')
        # Get the start and end indices of the graph embeddings
        graph_index_start = graph_cls_num * self.cfg.modal.samples_num
        graph_index_len = self.cfg.modal.samples_num
        graph_index_end = graph_index_start + graph_index_len
        # Slice out the reduced graph embeddings
        self.reduced_graph_embeddings = self.reduced_embeddings[graph_index_start:graph_index_end]
        # Get graph image index
        self.graph_img_index = list(range(graph_index_start, graph_index_end))
        # Create reduced embedding set
        self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
        # Compute neighbors of reduced embedding set
        self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
        # Initialize image list
        self.init_img_list()

    @print_something_first
    def get_new_smiles_embeddings(self, smiles_list):
        self.model.eval()  # Set model to evaluation mode
        if isinstance(smiles_list, str):  # Check if smiles_list is a string
            smiles_list = [smiles_list]  # Convert smiles_list to list if it is a string
        mol_list = []  # Initialize empty list for molecules
        mol_img_list = []  # Initialize empty list for molecule images
        for s_idx, s in enumerate(smiles_list):  # Iterate through smiles_list
            rdkit_mol = AllChem.MolFromSmiles(s)  # Create RDKit molecule from SMILES string
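            # NOTE: assumes a valid SMILES string; MolFromSmiles returns None otherwise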
            molecular_graph = mol_to_graph_data_obj_asAtomNum(rdkit_mol,
                                                              True)  # Convert RDKit molecule to graph data object
            mol_img = chem_draw.MolToImage(rdkit_mol, size=(224, 224), dpi=600)  # Create image of molecule
            mol_img = np.array(mol_img)  # Convert the PIL image to a writable numpy array
            cv2.putText(mol_img, f"{len(self.reduced_embeddings) + s_idx:05d}", (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (0, 0, 0), 1, cv2.LINE_AA)  # Add text to image
            mol_img_list.append(mol_img)  # Append image to list
            mol_list.append(molecular_graph)  # Append graph data object to list
        geometric_batch = Batch.from_data_list(mol_list)  # Build a geometric batch from the molecule list
        matrix_graphs, node_masks = to_dense(
            *fetch_geometric_batch(geometric_batch, ['edge_attr', 'batch']))  # Fetch dense matrix graphs and node masks
        matrix_graphs, node_masks = [to_device(x, self.device) for x in
                                     [matrix_graphs, node_masks]]  # Move matrix graphs and node masks to device
        matrix_graphs.to(to_type='float')  # Convert matrix graphs to float type
        data_dict = {'graphs': matrix_graphs, 'node_masks': node_masks}  # Create data dictionary
        with torch.no_grad():  # Disable gradient calculation
            molecule_graph_emb = self.model.get_graph_embedding(data_dict)['embedding']  # Get graph embedding
            new_graph_embedding = tensor2array(molecule_graph_emb)  # Convert graph embedding to array
            self.update_img_list(mol_img_list)  # Update image list
            return new_graph_embedding  # Return graph embedding

    def add_new_smiles_embedding(self, smiles, construct_view=True, plot_neighbors=False):
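        """Embed new SMILES strings, fold them into the reduced view, and
        optionally rebuild the merged viewer and plot their neighbors."""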
        new_embeddings = super(CMEViewer, self).add_new_smiles_embedding(smiles)
        cur_embeddings_num = len(self.reduced_embeddings)
        self.update_embeddings(new_embeddings, 'graph')
        if construct_view:
            self.construct_merge_modal_view()
        if plot_neighbors:
            # Plot neighbors for each of the newly added samples
            self.plot_neighbors(list(range(len(self.reduced_embeddingSet.embeddings[0]) - len(new_embeddings),
                                           len(self.reduced_embeddingSet.embeddings[0]))))
        print(f'Finished adding SMILES: {smiles}.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
        return new_embeddings

    def add_new_img_embedding(self, in_img, construct_view=True, plot_neighbors=False, preprocess=True):
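        """Embed a new cellular image, fold it into the reduced view, and
        optionally rebuild the merged viewer and plot its neighbors."""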
        if preprocess:
            import albumentations as A
            from albumentations.pytorch import ToTensorV2
            transform = A.Compose([
                A.Resize(height=128, width=128, p=1.0),
                ToTensorV2()],
                p=1.0)
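            # albumentations expects HWC input; the /15 rescaling mirrors the training
            # pipeline's normalization (assumption; adjust to match your data)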
            processed_img = transform(image=in_img.transpose(1, 2, 0))['image'] / 15
        else:
            processed_img = in_img.copy()
        processed_img = np.expand_dims(processed_img, axis=0)
        new_embeddings = super(CMEViewer, self).add_new_img_embedding(processed_img)
        cur_embeddings_num = len(self.reduced_embeddings)
        vis_img = vis_cellular_img(in_img, (224, 224))
        self.update_img_list(vis_img)
        self.update_embeddings(new_embeddings, 'img')
        if construct_view:
            self.construct_merge_modal_view()
        if plot_neighbors:
            self.plot_neighbors([i for i in range(len(self.reduced_embeddingSet.embeddings[0]) - len(processed_img),
                                                  len(self.reduced_embeddingSet.embeddings[0]))])
        print('Finished adding images.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
        return new_embeddings

    def plot_neighbors(self, idx, neighbors_num=1):
        # Accept either a single index or a list of indices
        idx_list = list(idx) if isinstance(idx, (list, tuple)) else [idx]
        # Pick the embedding library to search based on the query's modality
        # (assumes all queries in idx_list share the same modality)
        idx_cls = self.color_array[idx_list[0]]
        delta = 0 if idx_cls == 'graph' else self.samples_num
        emb_lib = self.reduced_embeddings[delta: delta + self.samples_num]
        target_emb = self.reduced_embeddings[idx_list]
        # Convert embeddings to tensors and move them to the device
        emb_lib = torch.as_tensor(emb_lib).to(self.device)
        target_emb = torch.as_tensor(target_emb).to(self.device)
        # Pairwise Euclidean distances between the queries and the library
        dis = torch.cdist(target_emb, emb_lib, p=2)
        # Take the neighbors_num nearest neighbors of each query
        sorted_distances, indices = torch.topk(dis, k=neighbors_num, largest=False, dim=1)
        # Iterate through each target index
        for idx_i, idx in enumerate(idx_list):
            # Get image for source index
            idx_img = self.img_list[idx]
            plt_show(idx_img, f"Source: {idx:05d}")
            # Iterate through each neighbor
            for ind_i, ind in enumerate(indices[idx_i]):
                # Get image for neighbor index
                neighbor_img = self.img_list[delta + tensor2array(ind)]
                plt_show(neighbor_img, f"Neighbor: {ind:05d}")

    def update_embeddings(self, new_embeddings, cls):
        # Reduce the dimension of new embeddings
        new_reduced_embeddings = self.reduced_embedding_dim(new_embeddings)
        # Update the color list with the length of new embeddings
        self.update_color(len(new_embeddings), cls)
        # Concatenate the reduced embeddings with the new reduced embeddings
        self.reduced_embeddings = np.concatenate([self.reduced_embeddings, new_reduced_embeddings], axis=0)
        # Create a new EmbeddingSet with the reduced embeddings and color list
        self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
        # Compute the neighbors of the reduced embedding set
        self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.n_neighbors)
        # Check if the class is graph
        if cls == 'graph':
            # Concatenate the reduced graph embeddings with the new reduced embeddings
            self.reduced_graph_embeddings = np.concatenate([self.reduced_graph_embeddings, new_reduced_embeddings],
                                                           axis=0)

    def reduced_embedding_dim(self, in_embedding):
        """Reduce the dimension of the input embedding
         Args:
            self (object): The object instance
            in_embedding (np.array): The input embedding
         Returns:
            np.array: The reduced embedding
        """
        return self.trans.transform(in_embedding)

    def update_cur_reduced_embedding(self, new_reduced_embedding):
        """Update the current reduced embedding
         Args:
            self (object): The object instance
            new_reduced_embedding (np.array): The new reduced embedding
         Returns:
            m_emb.EmbeddingSet: The updated embedding set
        """
        self.cur_reduced_embedding = np.concatenate([self.cur_reduced_embedding, new_reduced_embedding], axis=0)
        self.cur_embeddingSet = to_EmbeddingSet(self.cur_reduced_embedding, self.cur_color_array)
        self.cur_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.ori_n_neighbors)
        return self.cur_embeddingSet

    @print_something_first
    def construct_merge_modal_view(self):
        """Construct a merge modal view
         Args:
            self (object): The object instance
        """
        thum = m_emb.ImageThumbnails(self.img_list) if self.is_show_img() else None
        viewer = m_emb.Viewer(embeddings=self.reduced_embeddingSet, thumbnails=thum)
        self.viewer_dict['merge_modal_viewer'] = viewer
        print(f'Finished constructing the merge modal view; use '
              f'{self.__class__.__name__}.viewer_dict["merge_modal_viewer"] to plot it.')

    @print_something_first
    def construct_cross_modal_view(self):
        """Construct a cross modal view
         Args:
            self (object): The object instance
        """
        embeddingSet = m_emb.EmbeddingSet([to_Embedding(
            self.embeddings[idx * self.samples_num: (idx + 1) * self.samples_num],
            self.color_list[idx * self.samples_num: (idx + 1) * self.samples_num])
            for idx in range(self.modal_nums)])
        embeddingSet.compute_neighbors(metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
        reduced_emb = embeddingSet.project(method=ProjectionTechnique.ALIGNED_UMAP,
                                           metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
        reduced_emb.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
        w = m_emb.Viewer(embeddings=reduced_emb)
        self.viewer_dict['cross_modal_viewer'] = w
        print(f'Finished constructing the cross modal view; use '
              f'{self.__class__.__name__}.viewer_dict["cross_modal_viewer"] to plot it.')

    def help(self):
        """Print out helpful information
         Args:
            self (object): The object instance
        """
        print(f'Use "{self.__class__.__name__}.add_new_smiles_embedding(SMILES_LIST)" to add new embeddings')
        print(f'Use "{self.__class__.__name__}.plot_neighbors(idx_list,neighbors_num=1)" to plot neighbors img')
        for k in self.viewer_dict:
            print(
                f'Use "{self.__class__.__name__}.viewer_dict["{k}"] to get the {"_".join(k.split("_")[:-1])} embeddings visualization results')

    @print_something_first
    def init_transformation_from_ori_embeddings(self):
        """Initialize the transformation from original embeddings
        Args:
            self (object): The object instance
        """
        trans = umap.UMAP(metric='cosine', n_neighbors=100).fit(self.embeddings)
        self.trans = trans


    def is_show_img(self):
        """Check if images are shown
         Args:
            self (object): The object instance
         Returns:
            bool: Whether images are shown
        """
        return self.cfg.modal.get('show_img', False)

    def init_showing_imgs(self):
        """Initialize the showing images
         Args:
            self (object): The object instance
        """
        self.img_list = []
        for m_i, modal_name in enumerate(self.modal_name_list):
            modal_imgs_dir = os.path.join(self.cfg.moda.modal_imgs_dir, f"omics_{modal_name}_vis")
            for i in range(self.cfg.modal.samples_num):
                read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
                img = cv2.imread(read_path)[..., ::-1]
                self.img_list.append(img)

    def init_color_arr(self):
        """Initialize the color array
         Args:
            self (object): The object instance
         Returns:
            np.array: The initialized color array
        """
        self.color_array = np.concatenate([i * np.ones((self.samples_num,))
                                           for i in range(self.modal_nums)],
                                          axis=0)
        return self.color_array

    def update_color(self, add_samples_num, sample_cls):
        """Update the color list
         Args:
            self (object): The object instance
            add_samples_num (int): The number of samples to add
            sample_cls (int): The class of the samples
        """
        self.color_list.extend([sample_cls] * add_samples_num)
        self.color_array = np.array(self.color_list)


if __name__ == '__main__':
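    # Minimal demo: build the viewer, embed one new SMILES string, and plot its nearest neighbor.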
    config = 'config/miga_vis/embedding_vis_cfg.yaml'
    extra_para = {'modal': {'samples_num': 1000}}
    v = CMEViewer(config, extra_para)
    embeddings = v.add_new_smiles_embedding(['C[C@@H](NC(=O)C[C@@H]1O[C@H](CO)[C@H](NC(=O)c2cccnc2)C=C1)c1ccccc1'])
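    # In a Jupyter notebook, the interactive emblaze widget can then be displayed with:
    #   v.viewer_dict['merge_modal_viewer']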
    v.plot_neighbors([2000])
    x = 1