# MIGAVis / embedding_viewer.py
import warnings
warnings.filterwarnings('ignore')
import os
from utils.distribution import DeviceManager
DeviceManager('cpu')
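# The viewer only runs model inference (see the torch.no_grad() block below),
# so CPU execution suffices here.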
# from config import args
import pandas as pd
from emblaze import ProjectionTechnique
from embedding_produce import EmbeddingProducer
from utils.mol_to_graph import mol_to_graph_data_obj_asAtomNum
os.environ['RANK'] = '-1'
import cv2
import umap
from rdkit.Chem import AllChem
from rdkit.Chem import Draw as chem_draw
from torch_geometric.data import Batch
from utils.emblaze_utils import to_EmbeddingSet, to_Embedding
from utils.geometric_graph import to_dense, fetch_geometric_batch
from utils.tensor_operator import to_device, tensor2array
from utils.utils import print_something_first
import torch
import numpy as np
import emblaze as m_emb
from visualizer.visualizers import plt_show, vis_cellular_img
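# Interactive viewer for cross-modal embeddings (molecular graphs and cellular
# images): embeddings are reduced with UMAP and explored through emblaze widgets.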
logger = print
class CMEViewer(EmbeddingProducer):
def __init__(self, *args, **kwargs):
super(CMEViewer, self).__init__(*args, **kwargs)
# Initialize the number of samples
self.samples_num = self.cfg.modal.samples_num
# Initialize the list of embedding label indices
self.embedding_label_idx_list = [self.modal_nums]
# Initialize the original color array
self.ori_color_array = None
# Initialize the number of neighbors
self.n_neighbors = self.cfg.modal.n_neighbors
# Initialize the viewer dictionary
self.viewer_dict = {'merge_modal_viewer': None, 'cross_modal_viewer': None}
# Initialize the original view
self.init_ori_view()
@print_something_first
def init_img_list(self):
self.img_list = []
for m_i, modal_name in enumerate(self.modal_name_list):
modal_imgs_dir = os.path.join(self.cfg.modal.modal_imgs_dir, f'omics_{modal_name}_vis')
for i in range(self.samples_num):
read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
                img = cv2.imread(read_path)[..., ::-1]  # BGR -> RGB
self.img_list.append(img)
def update_img_list(self, new_img_list):
if not isinstance(new_img_list, list):
new_img_list = [new_img_list]
self.graph_img_index.extend(list(range(len(self.img_list), len(self.img_list) + len(new_img_list))))
self.img_list.extend(new_img_list)
@print_something_first
def init_embeddings(self):
# Load the required modal names from the configuration
modal_name_list = self.cfg.modal.modal_name_list
# Get the directory of the embeddings
embeddings_dir = self.cfg.modal.embeddings_dir
# Create a list to store the loaded embeddings
embeddings_list = []
# Iterate through the modal names and load the corresponding embeddings
for name in modal_name_list:
embeddings_list.append(
np.load(os.path.join(embeddings_dir, f"{name}_embedding.npy"))[:self.cfg.modal.samples_num])
# Concatenate the embeddings into one array
self.embeddings = np.concatenate(embeddings_list, axis=0)
        # Build per-sample color labels: each sample is tagged with its modal name
        self.color_list = []
        for m_i, m_name in enumerate(self.modal_name_list):
            self.color_list.extend([m_name] * self.samples_num)
def init_ori_view(self):
# Initialize embeddings
self.init_embeddings()
# Initialize transformation from original embeddings
self.init_transformation_from_ori_embeddings()
# Get reduced embeddings
self.reduced_embeddings = self.trans.embedding_
# Get modal name list
modal_name_list = self.cfg.modal.modal_name_list
        # Locate the graph modality and the index range of its embeddings
        graph_cls_num = modal_name_list.index('graph')
        graph_index_start = graph_cls_num * self.cfg.modal.samples_num
        graph_index_len = self.cfg.modal.samples_num
        graph_index_end = graph_index_start + graph_index_len
        # Slice out the reduced graph embeddings with the computed bounds, so the
        # slice stays correct even if 'graph' is not the last modality
        self.reduced_graph_embeddings = self.reduced_embeddings[graph_index_start:graph_index_end]
# Get graph image index
self.graph_img_index = list(range(graph_index_start, graph_index_end))
# Create reduced embedding set
self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
# Compute neighbors of reduced embedding set
self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
# Initialize image list
self.init_img_list()
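        # At this point self.reduced_embeddingSet holds the 2-D UMAP projection of
        # every modality, with Euclidean neighbor lists precomputed for emblaze.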
@print_something_first
def get_new_smiles_embeddings(self, smiles_list):
        self.model.eval()  # Inference only: disable dropout / batch-norm updates
        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]  # Accept a single SMILES string as well as a list
mol_list = [] # Initialize empty list for molecules
mol_img_list = [] # Initialize empty list for molecule images
for s_idx, s in enumerate(smiles_list): # Iterate through smiles_list
rdkit_mol = AllChem.MolFromSmiles(s) # Create RDKit molecule from SMILES string
molecular_graph = mol_to_graph_data_obj_asAtomNum(rdkit_mol,
True) # Convert RDKit molecule to graph data object
mol_img = chem_draw.MolToImage(rdkit_mol, size=(224, 224), dpi=600) # Create image of molecule
            mol_img = np.array(mol_img)  # PIL image -> writable ndarray copy
            # Stamp the embedding id the molecule will receive onto its thumbnail
            cv2.putText(mol_img, f"{len(self.reduced_embeddings) + s_idx:05d}", (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 0), 1, cv2.LINE_AA)
mol_img_list.append(mol_img) # Append image to list
mol_list.append(molecular_graph) # Append graph data object to list
        geometric_batch = Batch.from_data_list(mol_list)  # Collate molecules into one PyG batch
matrix_graphs, node_masks = to_dense(
*fetch_geometric_batch(geometric_batch, ['edge_attr', 'batch'])) # Fetch dense matrix graphs and node masks
matrix_graphs, node_masks = [to_device(x, self.device) for x in
[matrix_graphs, node_masks]] # Move matrix graphs and node masks to device
matrix_graphs.to(to_type='float') # Convert matrix graphs to float type
data_dict = {'graphs': matrix_graphs, 'node_masks': node_masks} # Create data dictionary
with torch.no_grad(): # Disable gradient calculation
molecule_graph_emb = self.model.get_graph_embedding(data_dict)['embedding'] # Get graph embedding
new_graph_embedding = tensor2array(molecule_graph_emb) # Convert graph embedding to array
self.update_img_list(mol_img_list) # Update image list
return new_graph_embedding # Return graph embedding
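    # Usage sketch (ethanol as a stand-in SMILES):
    #   emb = viewer.get_new_smiles_embeddings('CCO')  # -> np.ndarray of shape (1, dim)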
def add_new_smiles_embedding(self, smiles, construct_view=True, plot_neighbors=False):
new_embeddings = super(CMEViewer, self).add_new_smiles_embedding(smiles)
cur_embeddings_num = len(self.reduced_embeddings)
self.update_embeddings(new_embeddings, 'graph')
if construct_view:
self.construct_merge_modal_view()
        if plot_neighbors:
            # Plot neighbors for each newly added embedding
            self.plot_neighbors(list(range(len(self.reduced_embeddings) - len(new_embeddings),
                                           len(self.reduced_embeddings))))
        print(f'Finished adding SMILES: {smiles}.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
return new_embeddings
def add_new_img_embedding(self, in_img, construct_view=True, plot_neighbors=False, preprocess=True):
if preprocess:
import albumentations as A
from albumentations.pytorch import ToTensorV2
transform = A.Compose([
A.Resize(height=128, width=128, p=1.0),
ToTensorV2()],
p=1.0)
            # Albumentations expects HWC layout; the /15 scaling presumably mirrors
            # the training-time preprocessing of the cellular images
            processed_img = transform(image=in_img.transpose(1, 2, 0))['image'] / 15
else:
processed_img = in_img.copy()
processed_img = np.expand_dims(processed_img, axis=0)
new_embeddings = super(CMEViewer, self).add_new_img_embedding(processed_img)
cur_embeddings_num = len(self.reduced_embeddings)
vis_img = vis_cellular_img(in_img, (224, 224))
self.update_img_list(vis_img)
self.update_embeddings(new_embeddings, 'img')
if construct_view:
self.construct_merge_modal_view()
        if plot_neighbors:
            # Plot neighbors for each newly added embedding
            self.plot_neighbors(list(range(len(self.reduced_embeddings) - len(new_embeddings),
                                           len(self.reduced_embeddings))))
        print('Finished adding images.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
return new_embeddings
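    # Usage sketch (assumes a CHW float array, e.g. in_img.shape == (C, H, W)):
    #   emb = viewer.add_new_img_embedding(img, construct_view=False)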
    def plot_neighbors(self, idx, neighbors_num=1):
        # Accept a single index or a list of indices
        idx_list = idx if isinstance(idx, (list, tuple)) else [idx]
        # All queries are assumed to belong to the same modality
        idx_cls = self.color_array[idx_list[0]]
        # Cross-modal lookup: graph queries search the image block and vice versa
        # (assumes exactly two modalities laid out as [img, graph])
        delta = 0 if idx_cls == 'graph' else self.samples_num
        emb_lib = self.reduced_embeddings[delta: delta + self.samples_num]
        target_emb = self.reduced_embeddings[idx_list]
        # Convert embeddings to tensors and move them to the device
        emb_lib = torch.as_tensor(emb_lib).to(self.device)
        target_emb = torch.as_tensor(target_emb).to(self.device)
        # Pairwise Euclidean distances between the queries and the library
        dis = torch.cdist(target_emb, emb_lib, p=2)
        # Take the neighbors_num closest library entries per query
        sorted_distances, indices = torch.topk(dis, k=neighbors_num, largest=False, dim=1)
        for idx_i, idx in enumerate(idx_list):
            # Show the query image
            idx_img = self.img_list[idx]
            plt_show(idx_img, f"Source: {idx:05d}")
            for ind_i, ind in enumerate(indices[idx_i]):
                # Show each retrieved neighbor image
                ind = int(ind)
                neighbor_img = self.img_list[delta + ind]
                plt_show(neighbor_img, f"Neighbor: {ind:05d}")
def update_embeddings(self, new_embeddings, cls):
# Reduce the dimension of new embeddings
new_reduced_embeddings = self.reduced_embedding_dim(new_embeddings)
# Update the color list with the length of new embeddings
self.update_color(len(new_embeddings), cls)
# Concatenate the reduced embeddings with the new reduced embeddings
self.reduced_embeddings = np.concatenate([self.reduced_embeddings, new_reduced_embeddings], axis=0)
# Create a new EmbeddingSet with the reduced embeddings and color list
self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
# Compute the neighbors of the reduced embedding set
self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.n_neighbors)
# Check if the class is graph
if cls == 'graph':
# Concatenate the reduced graph embeddings with the new reduced embeddings
self.reduced_graph_embeddings = np.concatenate([self.reduced_graph_embeddings, new_reduced_embeddings],
axis=0)
def reduced_embedding_dim(self, in_embedding):
"""Reduce the dimension of the input embedding
Args:
self (object): The object instance
in_embedding (np.array): The input embedding
Returns:
np.array: The reduced embedding
"""
return self.trans.transform(in_embedding)
def update_cur_reduced_embedding(self, new_reduced_embedding):
"""Update the current reduced embedding
Args:
self (object): The object instance
new_reduced_embedding (np.array): The new reduced embedding
Returns:
m_emb.EmbeddingSet: The updated embedding set
"""
self.cur_reduced_embedding = np.concatenate([self.cur_reduced_embedding, new_reduced_embedding], axis=0)
self.cur_embeddingSet = to_EmbeddingSet(self.cur_reduced_embedding, self.cur_color_array)
        self.cur_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.ori_n_neighbors)
        return self.cur_embeddingSet
@print_something_first
def construct_merge_modal_view(self):
"""Construct a merge modal view
Args:
self (object): The object instance
"""
thum = m_emb.ImageThumbnails(self.img_list) if self.is_show_img() else None
viewer = m_emb.Viewer(embeddings=self.reduced_embeddingSet, thumbnails=thum)
self.viewer_dict['merge_modal_viewer'] = viewer
        print(f'Finished constructing the merge-modal view; display it via '
              f'{self.__class__.__name__}.viewer_dict["merge_modal_viewer"]')
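    # In a Jupyter notebook, evaluating viewer_dict['merge_modal_viewer'] in a
    # cell renders the interactive emblaze widget.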
@print_something_first
def construct_cross_modal_view(self):
"""Construct a cross modal view
Args:
self (object): The object instance
"""
embeddingSet = m_emb.EmbeddingSet([to_Embedding(
self.embeddings[idx * self.samples_num: (idx + 1) * self.samples_num],
self.color_list[idx * self.samples_num: (idx + 1) * self.samples_num])
for idx in range(self.modal_nums)])
embeddingSet.compute_neighbors(metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
reduced_emb = embeddingSet.project(method=ProjectionTechnique.ALIGNED_UMAP,
metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
reduced_emb.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
w = m_emb.Viewer(embeddings=reduced_emb)
self.viewer_dict['cross_modal_viewer'] = w
        print(f'Finished constructing the cross-modal view; display it via '
              f'{self.__class__.__name__}.viewer_dict["cross_modal_viewer"]')
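    # ProjectionTechnique.ALIGNED_UMAP projects each modality separately while
    # keeping the frames aligned, so the per-modality views stay comparable.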
def help(self):
"""Print out helpful information
Args:
self (object): The object instance
"""
print(f'Use "{self.__class__.__name__}.add_new_smiles_embedding(SMILES_LIST)" to add new embeddings')
print(f'Use "{self.__class__.__name__}.plot_neighbors(idx_list,neighbors_num=1)" to plot neighbors img')
for k in self.viewer_dict:
print(
f'Use "{self.__class__.__name__}.viewer_dict["{k}"] to get the {"_".join(k.split("_")[:-1])} embeddings visualization results')
@print_something_first
def init_transformation_from_ori_embeddings(self):
"""Initialize the transformation from original embeddings
Args:
self (object): The object instance
"""
        # Fit UMAP once on the original embeddings; the fitted reducer is kept so
        # new points can later be projected via self.trans.transform(...)
        trans = umap.UMAP(metric='cosine', n_neighbors=100).fit(self.embeddings)
        self.trans = trans
def is_show_img(self):
"""Check if images are shown
Args:
self (object): The object instance
Returns:
bool: Whether images are shown
"""
return self.cfg.modal.get('show_img', False)
def init_showing_imgs(self):
"""Initialize the showing images
Args:
self (object): The object instance
"""
self.img_list = []
for m_i, modal_name in enumerate(self.modal_name_list):
            modal_imgs_dir = os.path.join(self.cfg.modal.modal_imgs_dir, f"omics_{modal_name}_vis")
for i in range(self.cfg.modal.samples_num):
read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
                img = cv2.imread(read_path)[..., ::-1]  # BGR -> RGB
self.img_list.append(img)
def init_color_arr(self):
"""Initialize the color array
Args:
self (object): The object instance
Returns:
np.array: The initialized color array
"""
self.color_array = np.concatenate([i * np.ones((self.samples_num,))
for i in range(self.modal_nums)],
axis=0)
return self.color_array
def update_color(self, add_samples_num, sample_cls):
"""Update the color list
Args:
self (object): The object instance
add_samples_num (int): The number of samples to add
sample_cls (int): The class of the samples
"""
self.color_list.extend([sample_cls] * add_samples_num)
self.color_array = np.array(self.color_list)
if __name__ == '__main__':
config = 'config/miga_vis/embedding_vis_cfg.yaml'
extra_para = {'modal': {'samples_num': 1000}}
v = CMEViewer(config, extra_para)
embeddings = v.add_new_smiles_embedding(['C[C@@H](NC(=O)C[C@@H]1O[C@H](CO)[C@H](NC(=O)c2cccnc2)C=C1)c1ccccc1'])
v.plot_neighbors([2000])
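    # Other entry points (display the widgets by evaluating them in a notebook cell):
    #   v.construct_cross_modal_view()
    #   v.viewer_dict['merge_modal_viewer']
    # Adding a cellular image works analogously (hypothetical array, CHW layout):
    #   v.add_new_img_embedding(img)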