# Exploring CLOOME with Amumo 
Humer et al. 2023

Github Repo: https://github.com/ginihumer/Amumo

Interactive Article: https://jku-vds-lab.at/amumo

In [None]:
! pip install git+https://github.com/ginihumer/Amumo.git

In [1]:
import numpy as np
import pandas as pd
import sys
import os
import torch
sys.path.insert(0, os.path.abspath("src/"))

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
import amumo
from amumo import model as am_model
from amumo import data as am_data
from amumo import widgets as am_widgets
from amumo import utils as am_utils

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [3]:
basepath = './'
datapath = os.path.join(basepath, "data")
batch_size = 100
seed = 31415

In [4]:
from PIL import Image

# Data Helpers
def get_data_helper(dataset, filters=[], method=any):
    all_images, all_prompts = dataset.get_filtered_data(filters, method=method)
    print(len(all_images))

    dataset_name = dataset.name
    if len(filters) > 0:
        dataset_name = dataset_name + '_filter-' + method.__name__ + '_' + '-'.join(filters)
    else:
        dataset_name = dataset_name + '_size-%i'%len(all_images)

    return all_images, all_prompts, dataset_name


def reshape_image(arr):
    c, h, w = arr.shape
    reshaped_image = np.empty((h, w, c))

    reshaped_image[:,:,0] = arr[0]
    reshaped_image[:,:,1] = arr[1]
    reshaped_image[:,:,2] = arr[2]

    reshaped_pil = Image.fromarray(reshaped_image.astype("uint8"))

    return reshaped_pil


from rdkit import Chem
from rdkit.Chem import rdFMCS
import io
class MoleculeType(am_data.TextType):
    name = "Molecule"

    def __init__(self, data) -> None:
        # data is a list of SMILES
        super().__init__(data)

    def getMinSummary(self, ids):
        # retrieve MCS of mols
        if len(ids) == 1:
            return self.data[ids[0]]

        mols = [Chem.MolFromSmiles(smiles) for smiles in self.data[ids]]
        mcs = rdFMCS.FindMCS(mols)
        mcs_smiles = Chem.MolToSmiles(Chem.MolFromSmarts(mcs.smartsString))
        return mcs_smiles
    
    def getVisItem(self, idx):
        output_img = io.BytesIO()
        img = Chem.Draw.MolToImage(Chem.MolFromSmiles(self.data[idx]))
        img.resize((300,300)).save(output_img, format='JPEG')
        return output_img
    

class BioImageType(am_data.ImageType):
    name = "Bio Image"

    def __init__(self, data) -> None:
        super().__init__(data)


class CLOOMDataset_Dataset(am_data.DatasetInterface):
    name='CLOOMDataset'

    def __init__(self, path, seed=31415, batch_size = 100):
        super().__init__(path, seed, batch_size)

        self.MODE1_Type = BioImageType
        self.MODE2_Type = MoleculeType

        mol_index_file = os.path.join(path, "cellpainting-unique-molecule.csv")
        img_index_file = os.path.join(path, "cellpainting-all-imgpermol.csv")
        images_arr = os.path.join(path, "subset_npzs_dict_.npz")


        # molecule smiles
        all_molecules = pd.read_csv(mol_index_file)
        all_molecules.rename(columns={"SAMPLE_KEY": "SAMPLE_KEY_mol"}, inplace=True)
        # microscopy images
        all_microscopies = pd.read_csv(img_index_file)
        all_microscopies.rename(columns={"SAMPLE_KEY": "SAMPLE_KEY_img"}, inplace=True)
        # join the two dataframes
        cloome_data = pd.merge(all_molecules[["SAMPLE_KEY_mol", "SMILES"]], all_microscopies[["SAMPLE_KEY_img", "SMILES"]], on="SMILES", how="inner")
        
        # subsample data
        self.subset_idcs = self._get_random_subsample(len(cloome_data))
        self.dataset = cloome_data.iloc[self.subset_idcs]
        
        self.all_prompts = self.dataset["SMILES"].values

        # microscopy images TODO... load images on demand with a custom image loader
        all_microscopies = pd.read_csv(img_index_file)
        images_dict = np.load(images_arr, allow_pickle = True)
        all_images = []
        for img_id in self.dataset["SAMPLE_KEY_img"]:
            image = images_dict[f"{img_id}.npz"]
            im = reshape_image(image)
            all_images.append(im)

        self.all_images = np.array(all_images)
        

In [5]:
# Load Data
dataset_cloome = CLOOMDataset_Dataset(datapath,seed,batch_size)
cloome_images, cloome_molecules, cloome_dataset_name = get_data_helper(dataset_cloome, filters=[], method=any)
cloome_dataset_name


100


  self.all_images = np.array(all_images)
  self.all_images = np.array(all_images)


'CLOOMDataset_size-100'

In [6]:
cloome_molecules.getVisItem(0)

<_io.BytesIO at 0x20775295630>

In [7]:
class PrecalculatedModel(am_model.CLIPModelInterface):
    model_name = 'precalculated'

    def __init__(self, name, dataset_name, modality1_features, modality2_features, logit_scale=torch.tensor(0)) -> None:
        # this class is a workaround for precalculated features
        # it just saves the features as cached files so that the "encode_image" and "encode_text" methods are not called
        self.available_models = [name]
        super().__init__(name, device='cpu')
        self.logit_scale = logit_scale
        self.modality1_features = modality1_features
        self.modality2_features = modality2_features
        self.process_precalculated_features(dataset_name)

    def process_precalculated_features(self, dataset_name):
        data_prefix = dataset_name + '_' + self.model_name + '_' + self.name
        data_prefix = data_prefix.replace('/','-')
        np.savetxt(am_utils.data_checkpoint_dir + data_prefix + '_image-embedding.csv', self.modality1_features.cpu(), delimiter = ',')
        np.savetxt(am_utils.data_checkpoint_dir + data_prefix + '_text-embedding.csv', self.modality2_features.cpu(), delimiter = ',') 

    def encode_image(self, images):
        raise NotImplementedError("this cannot be done for precalculated features -> use cached features")
    
    def encode_text(self, texts):
        raise NotImplementedError("this cannot be done for precalculated features -> use cached features")

In [8]:
molecule_features = os.path.join(datapath, "all_molecule_cellpainting_features.pkl")
image_features = os.path.join(datapath, "subset_image_cellpainting_features.pkl")

# molecule features
mol_features_torch = torch.load(molecule_features, map_location=device)
mol_features = mol_features_torch["mol_features"]
mol_ids = mol_features_torch["mol_ids"]

# microscopy features
img_features_torch = torch.load(image_features, map_location=device)
img_features = img_features_torch["img_features"]
img_ids = img_features_torch["img_ids"]

# extract subsets of features
img_feature_idcs = [np.where(np.array(img_features_torch["img_ids"])==i)[0][0] for i in dataset_cloome.dataset["SAMPLE_KEY_img"].values]
mol_feature_idcs = [np.where(np.array(mol_features_torch["mol_ids"])==i)[0][0] for i in dataset_cloome.dataset["SAMPLE_KEY_mol"].values]

mol_features_sample = mol_features_torch['mol_features'][mol_feature_idcs]
mol_features_sample = am_utils.l2_norm(mol_features_sample)

img_features_sample = img_features_torch['img_features'][img_feature_idcs]
img_features_sample = am_utils.l2_norm(img_features_sample)

# cache features
model = PrecalculatedModel('seed-%i'%seed, cloome_dataset_name, img_features_sample, mol_features_sample)

In [9]:
cloome_widget = am_widgets.CLIPExplorerWidget(cloome_dataset_name, cloome_images, cloome_molecules, models=[model])
cloome_widget.hover_widget.width = 200
cloome_widget

found cached embeddings for CLOOMDataset_size-100_precalculated_seed-31415


CLIPExplorerWidget(children=(VBox(children=(HBox(children=(Dropdown(description='Model: ', options=('precalculâ€¦