import warnings

warnings.filterwarnings('ignore')

import os

from utils.distribution import DeviceManager

DeviceManager('cpu')
# from config import args
import pandas as pd
from emblaze import ProjectionTechnique
from embedding_produce import EmbeddingProducer
from utils.mol_to_graph import mol_to_graph_data_obj_asAtomNum

os.environ['RANK'] = '-1'

import cv2
import umap
from rdkit.Chem import AllChem
from rdkit.Chem import Draw as chem_draw
from torch_geometric.data import Batch
from utils.emblaze_utils import to_EmbeddingSet, to_Embedding
from utils.geometric_graph import to_dense, fetch_geometric_batch
from utils.tensor_operator import to_device, tensor2array
from utils.utils import print_something_first
import torch
import numpy as np
import emblaze as m_emb
from visualizer.visualizers import plt_show, vis_cellular_img

logger = print
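

# CMEViewer renders cross-modal embeddings (molecular graphs and cellular images)
# interactively: high-dimensional embeddings are reduced with UMAP and displayed
# with emblaze, and new SMILES strings or images can be embedded by the model and
# projected into the fitted space on the fly.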
class CMEViewer(EmbeddingProducer):
    def __init__(self, *args, **kwargs):
        super(CMEViewer, self).__init__(*args, **kwargs)
        # Initialize the number of samples per modality
        self.samples_num = self.cfg.modal.samples_num
        # Initialize the list of embedding label indices
        self.embedding_label_idx_list = [self.modal_nums]
        # Initialize the original color array
        self.ori_color_array = None
        # Initialize the number of neighbors
        self.n_neighbors = self.cfg.modal.n_neighbors
        # Initialize the viewer dictionary
        self.viewer_dict = {'merge_modal_viewer': None, 'cross_modal_viewer': None}
        # Build the original view
        self.init_ori_view()

    def init_img_list(self):
        self.img_list = []
        for m_i, modal_name in enumerate(self.modal_name_list):
            modal_imgs_dir = os.path.join(self.cfg.modal.modal_imgs_dir, f'omics_{modal_name}_vis')
            for i in range(self.samples_num):
                read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
                # cv2 reads BGR; flip the channel axis to get RGB
                img = cv2.imread(read_path)[..., ::-1]
                self.img_list.append(img)

    def update_img_list(self, new_img_list):
        if not isinstance(new_img_list, list):
            new_img_list = [new_img_list]
        # Record the indices the new images will occupy, then append them
        self.graph_img_index.extend(range(len(self.img_list), len(self.img_list) + len(new_img_list)))
        self.img_list.extend(new_img_list)

    def init_embeddings(self):
        # Load the modal names from the configuration
        modal_name_list = self.cfg.modal.modal_name_list
        # Directory holding the pre-computed embeddings
        embeddings_dir = self.cfg.modal.embeddings_dir
        # Load the embeddings of each modality
        embeddings_list = []
        for name in modal_name_list:
            embeddings_list.append(
                np.load(os.path.join(embeddings_dir, f"{name}_embedding.npy"))[:self.cfg.modal.samples_num])
        # Concatenate the embeddings into one array
        self.embeddings = np.concatenate(embeddings_list, axis=0)
        # Label every sample with its modal name (used for coloring)
        self.color_list = []
        for m_i, m_name in enumerate(self.modal_name_list):
            self.color_list.extend([m_name] * self.samples_num)

    def init_ori_view(self):
        # Initialize embeddings
        self.init_embeddings()
        # Fit the dimensionality-reduction transform on the original embeddings
        self.init_transformation_from_ori_embeddings()
        # Reduced (2-D) embeddings produced by the fitted transform
        self.reduced_embeddings = self.trans.embedding_
        modal_name_list = self.cfg.modal.modal_name_list
        # Index of the graph modality and the span of its embeddings
        graph_cls_num = modal_name_list.index('graph')
        graph_index_start = graph_cls_num * self.cfg.modal.samples_num
        graph_index_len = self.cfg.modal.samples_num
        graph_index_end = graph_index_start + graph_index_len
        # Reduced embeddings of the graph modality
        self.reduced_graph_embeddings = self.reduced_embeddings[graph_index_start:graph_index_end]
        # Image indices of the graph samples
        self.graph_img_index = list(range(graph_index_start, graph_index_end))
        # Build the embedding set and its neighbor structure
        self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
        self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
        # Load the thumbnail images
        self.init_img_list()

    def get_new_smiles_embeddings(self, smiles_list):
        self.model.eval()  # Set model to evaluation mode
        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]
        mol_list = []  # Molecular graph data objects
        mol_img_list = []  # Rendered molecule images
        for s_idx, s in enumerate(smiles_list):
            # Build an RDKit molecule from the SMILES string
            rdkit_mol = AllChem.MolFromSmiles(s)
            # Convert the RDKit molecule to a graph data object
            molecular_graph = mol_to_graph_data_obj_asAtomNum(rdkit_mol, True)
            # Render the molecule and stamp its future embedding id on the image
            mol_img = chem_draw.MolToImage(rdkit_mol, size=(224, 224), dpi=600)
            mol_img = np.array(mol_img)  # PIL image -> writable array
            cv2.putText(mol_img, f"{len(self.reduced_embeddings) + s_idx:05d}", (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 0), 1, cv2.LINE_AA)
            mol_img_list.append(mol_img)
            mol_list.append(molecular_graph)
        # Batch the molecular graphs and convert them to dense form
        geometric_batch = Batch.from_data_list(mol_list)
        matrix_graphs, node_masks = to_dense(
            *fetch_geometric_batch(geometric_batch, ['edge_attr', 'batch']))
        # Move matrix graphs and node masks to the device
        matrix_graphs, node_masks = [to_device(x, self.device) for x in [matrix_graphs, node_masks]]
        matrix_graphs.to(to_type='float')
        data_dict = {'graphs': matrix_graphs, 'node_masks': node_masks}
        with torch.no_grad():  # No gradients needed for inference
            molecule_graph_emb = self.model.get_graph_embedding(data_dict)['embedding']
        new_graph_embedding = tensor2array(molecule_graph_emb)
        self.update_img_list(mol_img_list)
        return new_graph_embedding
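
    # Usage sketch (hypothetical instance name `viewer`): embed one molecule without
    # changing the viewer state; the result is the raw graph embedding as a numpy
    # array, one row per input SMILES:
    #   emb = viewer.get_new_smiles_embeddings('CCO')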

    def add_new_smiles_embedding(self, smiles, construct_view=True, plot_neighbors=False):
        new_embeddings = super(CMEViewer, self).add_new_smiles_embedding(smiles)
        cur_embeddings_num = len(self.reduced_embeddings)
        self.update_embeddings(new_embeddings, 'graph')
        if construct_view:
            self.construct_merge_modal_view()
        if plot_neighbors:
            self.plot_neighbors(list(range(len(self.reduced_embeddingSet.embeddings[0]) - len(smiles),
                                           len(self.reduced_embeddingSet.embeddings[0]))))
        print(f'Finished adding SMILES: {smiles}.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
        return new_embeddings

    def add_new_img_embedding(self, in_img, construct_view=True, plot_neighbors=False, preprocess=True):
        if preprocess:
            import albumentations as A
            from albumentations.pytorch import ToTensorV2
            transform = A.Compose([
                A.Resize(height=128, width=128, p=1.0),
                ToTensorV2()],
                p=1.0)
            # Transpose (C, H, W) -> (H, W, C) for albumentations, then scale intensities
            processed_img = transform(image=in_img.transpose(1, 2, 0))['image'] / 15
        else:
            processed_img = in_img.copy()
        processed_img = np.expand_dims(processed_img, axis=0)
        new_embeddings = super(CMEViewer, self).add_new_img_embedding(processed_img)
        cur_embeddings_num = len(self.reduced_embeddings)
        vis_img = vis_cellular_img(in_img, (224, 224))
        self.update_img_list(vis_img)
        self.update_embeddings(new_embeddings, 'img')
        if construct_view:
            self.construct_merge_modal_view()
        if plot_neighbors:
            self.plot_neighbors(list(range(len(self.reduced_embeddingSet.embeddings[0]) - len(processed_img),
                                           len(self.reduced_embeddingSet.embeddings[0]))))
        print('Finished adding images.')
        print(f'Their ids: {list(range(cur_embeddings_num, len(self.reduced_embeddings)))}')
        print(f'Use {self.__class__.__name__}.plot_neighbors(idx) to get their neighbors.')
        return new_embeddings
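
    # Usage sketch (hypothetical shapes; the channel count is an assumption): with
    # preprocess=True the input is expected in channel-first (C, H, W) layout, e.g.
    #   img = np.random.randint(0, 256, (5, 512, 512)).astype(np.uint8)
    #   v.add_new_img_embedding(img)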

    def plot_neighbors(self, idx, neighbors_num=1):
        # Accept a single index or a list of indices
        idx_list = idx if isinstance(idx, (list, tuple)) else [idx]
        # Locate the library slice belonging to the query's modality
        idx_cls = self.color_array[idx_list[0]]
        delta = 0 if idx_cls == 'graph' else self.samples_num
        emb_lib = self.reduced_embeddings[delta: delta + self.samples_num]
        target_emb = self.reduced_embeddings[idx_list]
        # Convert embeddings to tensors and move them to the device
        emb_lib = torch.as_tensor(emb_lib).to(self.device)
        target_emb = torch.as_tensor(target_emb).to(self.device)
        # Pairwise Euclidean distances between the queries and the library
        dis = torch.cdist(target_emb, emb_lib, p=2)
        # Get the neighbors_num neighbors with the smallest distances
        sorted_distances, indices = torch.topk(dis, k=neighbors_num, largest=False, dim=1)
        for idx_i, src_idx in enumerate(idx_list):
            # Show the query image
            idx_img = self.img_list[src_idx]
            plt_show(idx_img, f"Source: {src_idx:05d}")
            # Show each neighbor image
            for ind_i, ind in enumerate(indices[idx_i]):
                neighbor_img = self.img_list[delta + int(tensor2array(ind))]
                plt_show(neighbor_img, f"Neighbor: {int(ind):05d}")

    def update_embeddings(self, new_embeddings, cls):
        # Project the new embeddings into the reduced space
        new_reduced_embeddings = self.reduced_embedding_dim(new_embeddings)
        # Extend the color labels for the new samples
        self.update_color(len(new_embeddings), cls)
        # Append the new reduced embeddings and rebuild the embedding set
        self.reduced_embeddings = np.concatenate([self.reduced_embeddings, new_reduced_embeddings], axis=0)
        self.reduced_embeddingSet = to_EmbeddingSet(self.reduced_embeddings, self.color_list)
        self.reduced_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.n_neighbors)
        # Keep the graph-only embeddings in sync
        if cls == 'graph':
            self.reduced_graph_embeddings = np.concatenate([self.reduced_graph_embeddings, new_reduced_embeddings],
                                                           axis=0)

    def reduced_embedding_dim(self, in_embedding):
        """Reduce the dimension of the input embedding.

        Args:
            in_embedding (np.ndarray): The input embedding.

        Returns:
            np.ndarray: The reduced embedding.
        """
        return self.trans.transform(in_embedding)

    def update_cur_reduced_embedding(self, new_reduced_embedding):
        """Update the current reduced embedding.

        Args:
            new_reduced_embedding (np.ndarray): The new reduced embedding.

        Returns:
            m_emb.EmbeddingSet: The updated embedding set.
        """
        self.cur_reduced_embedding = np.concatenate([self.cur_reduced_embedding, new_reduced_embedding], axis=0)
        self.cur_embeddingSet = to_EmbeddingSet(self.cur_reduced_embedding, self.cur_color_array)
        self.cur_embeddingSet.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.ori_n_neighbors)
        return self.cur_embeddingSet

    def construct_merge_modal_view(self):
        """Construct the merge-modal view."""
        thum = m_emb.ImageThumbnails(self.img_list) if self.is_show_img() else None
        viewer = m_emb.Viewer(embeddings=self.reduced_embeddingSet, thumbnails=thum)
        self.viewer_dict['merge_modal_viewer'] = viewer
        print(f'Finished constructing the merge-modal view; use '
              f'{self.__class__.__name__}.viewer_dict["merge_modal_viewer"] to plot it.')

    def construct_cross_modal_view(self):
        """Construct the cross-modal view."""
        embeddingSet = m_emb.EmbeddingSet([to_Embedding(
            self.embeddings[idx * self.samples_num: (idx + 1) * self.samples_num],
            self.color_list[idx * self.samples_num: (idx + 1) * self.samples_num])
            for idx in range(self.modal_nums)])
        embeddingSet.compute_neighbors(metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
        # Project with AlignedUMAP so the per-modality frames share coordinates
        reduced_emb = embeddingSet.project(method=ProjectionTechnique.ALIGNED_UMAP,
                                           metric='cosine', n_neighbors=self.cfg.modal.n_neighbors)
        reduced_emb.compute_neighbors(metric='euclidean', n_neighbors=self.cfg.modal.n_neighbors)
        w = m_emb.Viewer(embeddings=reduced_emb)
        self.viewer_dict['cross_modal_viewer'] = w
        print(f'Finished constructing the cross-modal view; use '
              f'{self.__class__.__name__}.viewer_dict["cross_modal_viewer"] to plot it.')

    def help(self):
        """Print usage hints."""
        print(f'Use "{self.__class__.__name__}.add_new_smiles_embedding(SMILES_LIST)" to add new embeddings')
        print(f'Use "{self.__class__.__name__}.plot_neighbors(idx_list, neighbors_num=1)" to plot neighbor images')
        for k in self.viewer_dict:
            print(f'Use "{self.__class__.__name__}.viewer_dict["{k}"]" to get the '
                  f'{"_".join(k.split("_")[:-1])} embedding visualization results')

    def init_transformation_from_ori_embeddings(self):
        """Fit the UMAP transformation on the original embeddings."""
        # NOTE: n_neighbors is fixed at 100 here rather than read from cfg.modal.n_neighbors
        trans = umap.UMAP(metric='cosine', n_neighbors=100).fit(self.embeddings)
        self.trans = trans
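
    # Fitting once and reusing the fitted instance matters here: umap.UMAP.transform
    # embeds new points into the already-learned 2-D space (see reduced_embedding_dim),
    # so samples added later remain directly comparable with the initial view.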

    def is_show_img(self):
        """Return whether thumbnail images should be shown.

        Returns:
            bool: Whether images are shown.
        """
        return self.cfg.modal.get('show_img', False)

    def init_showing_imgs(self):
        """Initialize the images used for display (near-duplicate of init_img_list)."""
        self.img_list = []
        for m_i, modal_name in enumerate(self.modal_name_list):
            modal_imgs_dir = os.path.join(self.cfg.modal.modal_imgs_dir, f"omics_{modal_name}_vis")
            for i in range(self.cfg.modal.samples_num):
                read_path = os.path.join(modal_imgs_dir, f"{i:05d}.png")
                img = cv2.imread(read_path)[..., ::-1]  # BGR -> RGB
                self.img_list.append(img)

    def init_color_arr(self):
        """Initialize the color array.

        Returns:
            np.ndarray: The initialized color array.
        """
        self.color_array = np.concatenate([i * np.ones((self.samples_num,))
                                           for i in range(self.modal_nums)],
                                          axis=0)
        return self.color_array

    def update_color(self, add_samples_num, sample_cls):
        """Update the color list and array.

        Args:
            add_samples_num (int): The number of samples to add.
            sample_cls (str): The modal name of the samples (e.g. 'graph' or 'img').
        """
        self.color_list.extend([sample_cls] * add_samples_num)
        self.color_array = np.array(self.color_list)


if __name__ == '__main__':
    config = 'config/miga_vis/embedding_vis_cfg.yaml'
    extra_para = {'modal': {'samples_num': 1000}}
    v = CMEViewer(config, extra_para)
    embeddings = v.add_new_smiles_embedding(['C[C@@H](NC(=O)C[C@@H]1O[C@H](CO)[C@H](NC(=O)c2cccnc2)C=C1)c1ccccc1'])
    v.plot_neighbors([2000])
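
    # A minimal sketch of the remaining workflow (assuming the config above exists
    # and emblaze is running inside a Jupyter environment):
    # v.construct_merge_modal_view()
    # v.viewer_dict['merge_modal_viewer']   # display the interactive merge-modal viewer
    # v.construct_cross_modal_view()
    # v.viewer_dict['cross_modal_viewer']   # display the aligned cross-modal viewer
    # v.help()                              # list the available commands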