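"""CMEViewer: interactive cross-modal embedding (CME) visualization.

Loads per-modality embeddings (molecular graphs and cellular images), projects
them to 2-D with UMAP, and wraps the results in emblaze viewers for merged and
cross-modal inspection.
"""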
import warnings
warnings.filterwarnings('ignore')
import os
from utils.distribution import DeviceManager
DeviceManager('cpu')
# from config import args
import pandas as pd
from emblaze import ProjectionTechnique
from embedding_produce import EmbeddingProducer
from utils.mol_to_graph import mol_to_graph_data_obj_asAtomNum
os.environ['RANK'] = '-1'
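# RANK = -1 presumably tells the distribution utilities to run single-process;
# DeviceManager('cpu') above pins computation to the CPU.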
import cv2
import umap
from rdkit.Chem import AllChem
from rdkit.Chem import Draw as chem_draw
from torch_geometric.data import Batch
from utils.emblaze_utils import to_EmbeddingSet, to_Embedding
from utils.geometric_graph import to_dense, fetch_geometric_batch
from utils.tensor_operator import to_device, tensor2array
from utils.utils import print_something_first
import torch
import numpy as np
import emblaze as m_emb
from visualizer.visualizers import plt_show, vis_cellular_img
logger = print
class CMEViewer(EmbeddingProducer):
def __init__(self, *args, **kwargs):
super(CMEViewer, self).__init__(*args, **kwargs)
# Initialize the number of samples
self.samples_num = self.cfg.modal.samples_num
# Initialize the list of embedding label indices
self.embedding_label_idx_list = [self.modal_nums]
# Initialize the original color array
self.ori_color_array = None
# Initialize the number of neighbors
self.n_neighbors = self.cfg.modal.n_neighbors
# Initialize the viewer dictionary
self.viewer_dict = {'merge_modal_viewer': None, 'cross_modal_viewer': None}
# Initialize the original view
self.init_ori_view()
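    # Config keys consumed by this class (all under cfg.modal): samples_num,
    # n_neighbors, ori_n_neighbors, modal_name_list, embeddings_dir,
    # modal_imgs_dir, and the optional show_img flag.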
@print_something_first
def init_img_list(self):
self.img_list = []
for m_i, modal_name in enumerate(self.modal_name_list):
modal_imgs_dir = os.path.join(self.cfg.modal.modal_imgs_dir, f'omics_{modal_name}_vis')
for i in range(self.samples_num):
read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
                img = cv2.imread(read_path)[..., ::-1]  # BGR -> RGB
self.img_list.append(img)
def update_img_list(self, new_img_list):
if not isinstance(new_img_list, list):
new_img_list = [new_img_list]
self.graph_img_index.extend(list(range(len(self.img_list), len(self.img_list) + len(new_img_list))))
self.img_list.extend(new_img_list)
@print_something_first
def init_embeddings(self):
# Load the required modal names from the configuration
modal_name_list = self.cfg.modal.modal_name_list
# Get the directory of the embeddings
embeddings_dir = self.cfg.modal.embeddings_dir
# Create a list to store the loaded embeddings
embeddings_list = []
# Iterate through the modal names and load the corresponding embeddings
for name in modal_name_list:
embeddings_list.append(
np.load(os.path.join(embeddings_dir, f"{name}_embedding.npy"))[:self.cfg.modal.samples_num])
# Concatenate the embeddings into one array
self.embeddings = np.concatenate(embeddings_list, axis=0)
        # Build a per-sample color label list (one modal name per row)
        self.color_list = []
        for m_i, m_name in enumerate(self.modal_name_list):
            self.color_list.extend([m_name] * self.samples_num)
        # Keep an array view as well; plot_neighbors indexes into it
        self.color_array = np.array(self.color_list)
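    # Resulting layout: self.embeddings has shape (modal_nums * samples_num, dim),
    # rows grouped per modality in modal_name_list order; color_list holds one
    # modal-name label per row.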
def init_ori_view(self):
# Initialize embeddings
self.init_embeddings()
# Initialize transformation from original embeddings
self.init_transformation_from_ori_embeddings()
# Get reduced embeddings
self.reduced_embeddings = self.trans.embedding_
# Get modal name list
modal_name_list = self.cfg.modal.modal_name_list
        # Get the index of the 'graph' modality
        graph_cls_num = modal_name_list.index('graph')
        # Get start and end indices of the graph embeddings
        graph_index_start = graph_cls_num * self.cfg.modal.samples_num
        graph_index_len = self.cfg.modal.samples_num
        graph_index_end = graph_index_start + graph_index_len
        # Slice out only the graph modality's reduced embeddings
        self.reduced_graph_embeddings = self.reduced_embeddings[graph_index_start:graph_index_end]
        # Get graph image indices
        self.graph_img_index = list(range(graph_index_start, graph_index_end))
# Create reduced embedding set
self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
# Compute neighbors of reduced embedding set
self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
# Initialize image list
self.init_img_list()
@print_something_first
def get_new_smiles_embeddings(self, smiles_list):
        self.model.eval()  # Set model to evaluation mode
if isinstance(smiles_list, str): # Check if smiles_list is a string
smiles_list = [smiles_list] # Convert smiles_list to list if it is a string
mol_list = [] # Initialize empty list for molecules
mol_img_list = [] # Initialize empty list for molecule images
for s_idx, s in enumerate(smiles_list): # Iterate through smiles_list
rdkit_mol = AllChem.MolFromSmiles(s) # Create RDKit molecule from SMILES string
molecular_graph = mol_to_graph_data_obj_asAtomNum(rdkit_mol,
True) # Convert RDKit molecule to graph data object
mol_img = chem_draw.MolToImage(rdkit_mol, size=(224, 224), dpi=600) # Create image of molecule
            mol_img = np.array(mol_img)  # PIL image -> writable numpy array (copies)
cv2.putText(mol_img, f"{len(self.reduced_embeddings) + s_idx:05d}", (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(0, 0, 0), 1, cv2.LINE_AA) # Add text to image
mol_img_list.append(mol_img) # Append image to list
mol_list.append(molecular_graph) # Append graph data object to list
        geometric_batch = Batch.from_data_list(mol_list)  # Collate molecule graphs into one PyG batch
matrix_graphs, node_masks = to_dense(
*fetch_geometric_batch(geometric_batch, ['edge_attr', 'batch'])) # Fetch dense matrix graphs and node masks
matrix_graphs, node_masks = [to_device(x, self.device) for x in
[matrix_graphs, node_masks]] # Move matrix graphs and node masks to device
matrix_graphs.to(to_type='float') # Convert matrix graphs to float type
data_dict = {'graphs': matrix_graphs, 'node_masks': node_masks} # Create data dictionary
with torch.no_grad(): # Disable gradient calculation
molecule_graph_emb = self.model.get_graph_embedding(data_dict)['embedding'] # Get graph embedding
new_graph_embedding = tensor2array(molecule_graph_emb) # Convert graph embedding to array
self.update_img_list(mol_img_list) # Update image list
return new_graph_embedding # Return graph embedding
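    # Note: to_dense / fetch_geometric_batch (this repo's utils) are assumed to
    # pad the per-molecule graphs into batched dense tensors plus node masks,
    # which model.get_graph_embedding consumes.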
def add_new_smiles_embedding(self, smiles, construct_view=True, plot_neighbors=False):
new_embeddings = super(CMEViewer, self).add_new_smiles_embedding(smiles)
cur_embeddings_num = len(self.reduced_embeddings)
self.update_embeddings(new_embeddings, 'graph')
if construct_view:
self.construct_merge_modal_view()
if plot_neighbors:
self.plot_neighbors([i for i in range(len(self.reduced_embeddingSet.embeddings[0]) - len(smiles),
len(self.reduced_embeddingSet.embeddings[0]))])
        print(f'Finished adding SMILES: {smiles}.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
return new_embeddings
def add_new_img_embedding(self, in_img, construct_view=True, plot_neighbors=False, preprocess=True):
if preprocess:
import albumentations as A
from albumentations.pytorch import ToTensorV2
transform = A.Compose([
A.Resize(height=128, width=128, p=1.0),
ToTensorV2()],
p=1.0)
            processed_img = transform(image=in_img.transpose(1, 2, 0))['image'] / 15  # CHW -> HWC; albumentations expects HWC input
else:
processed_img = in_img.copy()
processed_img = np.expand_dims(processed_img, axis=0)
new_embeddings = super(CMEViewer, self).add_new_img_embedding(processed_img)
cur_embeddings_num = len(self.reduced_embeddings)
vis_img = vis_cellular_img(in_img, (224, 224))
self.update_img_list(vis_img)
self.update_embeddings(new_embeddings, 'img')
if construct_view:
self.construct_merge_modal_view()
if plot_neighbors:
self.plot_neighbors([i for i in range(len(self.reduced_embeddingSet.embeddings[0]) - len(processed_img),
len(self.reduced_embeddingSet.embeddings[0]))])
        print('Finished adding images.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
return new_embeddings
def plot_neighbors(self, idx, neighbors_num=1):
        # Accept either a single index or a list of indices
        idx_list = idx if isinstance(idx, (list, tuple)) else [idx]
        # Select the library slice matching the query's modality
        # (assumes the graph modality's rows come first in reduced_embeddings)
        idx_cls = self.color_array[idx_list[0]]
        delta = 0 if idx_cls == 'graph' else self.samples_num
emb_lib = self.reduced_embeddings[delta: delta + self.samples_num]
target_emb = self.reduced_embeddings[idx_list]
# Convert embeddings to tensors and move to device
emb_lib = torch.as_tensor(emb_lib).to(self.device)
target_emb = torch.as_tensor(target_emb).to(self.device)
# Calculate distances between target embeddings and embeddings in reduced graph
dis = torch.cdist(target_emb, emb_lib, p=2)
        # Get the neighbors_num nearest neighbors (smallest distances)
sorted_distances, indices = torch.topk(dis, k=neighbors_num, largest=False, dim=1)
# Iterate through each target index
for idx_i, idx in enumerate(idx_list):
# Get image for source index
idx_img = self.img_list[idx]
plt_show(idx_img, f"Source: {idx:05d}")
            # Iterate through each neighbor
            for ind_i, ind in enumerate(indices[idx_i]):
                # Map the in-slice neighbor index back to a global image id
                neighbor_id = delta + int(tensor2array(ind))
                neighbor_img = self.img_list[neighbor_id]
                plt_show(neighbor_img, f"Neighbor: {neighbor_id:05d}")
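    # Example (hypothetical ids): plot_neighbors(2000, neighbors_num=3) shows the
    # image for id 2000 followed by its 3 nearest same-modality neighbors.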
def update_embeddings(self, new_embeddings, cls):
# Reduce the dimension of new embeddings
new_reduced_embeddings = self.reduced_embedding_dim(new_embeddings)
# Update the color list with the length of new embeddings
self.update_color(len(new_embeddings), cls)
# Concatenate the reduced embeddings with the new reduced embeddings
self.reduced_embeddings = np.concatenate([self.reduced_embeddings, new_reduced_embeddings], axis=0)
# Create a new EmbeddingSet with the reduced embeddings and color list
self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
# Compute the neighbors of the reduced embedding set
self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.n_neighbors)
# Check if the class is graph
if cls == 'graph':
# Concatenate the reduced graph embeddings with the new reduced embeddings
self.reduced_graph_embeddings = np.concatenate([self.reduced_graph_embeddings, new_reduced_embeddings],
axis=0)
def reduced_embedding_dim(self, in_embedding):
"""Reduce the dimension of the input embedding
Args:
self (object): The object instance
in_embedding (np.array): The input embedding
Returns:
np.array: The reduced embedding
"""
return self.trans.transform(in_embedding)
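    # umap.UMAP supports out-of-sample projection: transform() maps new points
    # into the space learned by fit() in init_transformation_from_ori_embeddings.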
def update_cur_reduced_embedding(self, new_reduced_embedding):
"""Update the current reduced embedding
Args:
self (object): The object instance
new_reduced_embedding (np.array): The new reduced embedding
Returns:
m_emb.EmbeddingSet: The updated embedding set
"""
        self.cur_reduced_embedding = np.concatenate([self.cur_reduced_embedding, new_reduced_embedding], axis=0)
        self.cur_embeddingSet = to_EmbeddingSet(self.cur_reduced_embedding, self.cur_color_array)
        self.cur_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.ori_n_neighbors)
        return self.cur_embeddingSet
@print_something_first
def construct_merge_modal_view(self):
"""Construct a merge modal view
Args:
self (object): The object instance
"""
thum = m_emb.ImageThumbnails(self.img_list) if self.is_show_img() else None
viewer = m_emb.Viewer(embeddings=self.reduced_embeddingSet, thumbnails=thum)
self.viewer_dict['merge_modal_viewer'] = viewer
        print(f'Finished constructing merge modal view; use '
              f'{self.__class__.__name__}.viewer_dict["merge_modal_viewer"] to plot it.')
@print_something_first
def construct_cross_modal_view(self):
"""Construct a cross modal view
Args:
self (object): The object instance
"""
embeddingSet = m_emb.EmbeddingSet([to_Embedding(
self.embeddings[idx * self.samples_num: (idx + 1) * self.samples_num],
self.color_list[idx * self.samples_num: (idx + 1) * self.samples_num])
for idx in range(self.modal_nums)])
embeddingSet.compute_neighbors(metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
reduced_emb = embeddingSet.project(method=ProjectionTechnique.ALIGNED_UMAP,
metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
reduced_emb.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
w = m_emb.Viewer(embeddings=reduced_emb)
self.viewer_dict['cross_modal_viewer'] = w
        print(f'Finished constructing cross modal view; use '
              f'{self.__class__.__name__}.viewer_dict["cross_modal_viewer"] to plot it.')
def help(self):
"""Print out helpful information
Args:
self (object): The object instance
"""
        print(f'Use "{self.__class__.__name__}.add_new_smiles_embedding(SMILES_LIST)" to add new embeddings')
        print(f'Use "{self.__class__.__name__}.plot_neighbors(idx_list, neighbors_num=1)" to plot neighbor images')
        for k in self.viewer_dict:
            print(f'Use "{self.__class__.__name__}.viewer_dict["{k}"]" to get the '
                  f'{"_".join(k.split("_")[:-1])} embedding visualization results')
@print_something_first
def init_transformation_from_ori_embeddings(self):
"""Initialize the transformation from original embeddings
Args:
self (object): The object instance
"""
        # Note: n_neighbors is hard-coded to 100 here, independent of cfg.modal.n_neighbors
        self.trans = umap.UMAP(metric='cosine', n_neighbors=100).fit(self.embeddings)
def is_show_img(self):
"""Check if images are shown
Args:
self (object): The object instance
Returns:
bool: Whether images are shown
"""
return self.cfg.modal.get('show_img', False)
def init_showing_imgs(self):
"""Initialize the showing images
Args:
self (object): The object instance
"""
self.img_list = []
for m_i, modal_name in enumerate(self.modal_name_list):
            modal_imgs_dir = os.path.join(self.cfg.modal.modal_imgs_dir, f"omics_{modal_name}_vis")
for i in range(self.cfg.modal.samples_num):
read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
img = cv2.imread(read_path)[..., ::-1]
self.img_list.append(img)
def init_color_arr(self):
"""Initialize the color array
Args:
self (object): The object instance
Returns:
np.array: The initialized color array
"""
self.color_array = np.concatenate([i * np.ones((self.samples_num,))
for i in range(self.modal_nums)],
axis=0)
return self.color_array
def update_color(self, add_samples_num, sample_cls):
"""Update the color list
Args:
self (object): The object instance
            add_samples_num (int): The number of samples to add
            sample_cls (str): The modality name of the added samples (e.g. 'graph' or 'img')
"""
self.color_list.extend([sample_cls] * add_samples_num)
self.color_array = np.array(self.color_list)
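# A minimal usage sketch for the image path (hypothetical input; assumes a CHW
# float array, which add_new_img_embedding transposes to HWC before resizing,
# and a constructed viewer v as in the __main__ block below):
#
#   img = np.random.rand(3, 512, 512).astype(np.float32)  # placeholder data
#   v.add_new_img_embedding(img, construct_view=True, plot_neighbors=True)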
if __name__ == '__main__':
config = 'config/miga_vis/embedding_vis_cfg.yaml'
extra_para = {'modal': {'samples_num': 1000}}
v = CMEViewer(config, extra_para)
embeddings = v.add_new_smiles_embedding(['C[C@@H](NC(=O)C[C@@H]1O[C@H](CO)[C@H](NC(=O)c2cccnc2)C=C1)c1ccccc1'])
v.plot_neighbors([2000])