In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd

# Load OpenPhenom

In [None]:
# Load model directly
# NOTE(review): `huggingface_mae` looks like a module shipped alongside this
# notebook rather than an installed package — confirm it is importable from
# the working directory.
from huggingface_mae import MAEModel
# "." is passed as the model location; presumably the checkpoint and config
# files are expected in the current working directory — verify.
open_phenom = MAEModel.from_pretrained(".")

In [None]:
# Detect the GPU once; `cuda_available` is reused later by the inference loop.
cuda_available = torch.cuda.is_available()

# Switch the model to evaluation mode, then move it to the GPU if one exists.
open_phenom.eval()
if cuda_available:
    open_phenom.cuda()

# Load RxRx3-core

In [None]:
# Load the "recursionpharma/rxrx3-core" dataset and keep only its 'train'
# split; `rxrx3_core` is what the DataLoader below iterates over.
from datasets import load_dataset
rxrx3_core = load_dataset("recursionpharma/rxrx3-core")['train']

# Inference loop

In [None]:
def convert_path_to_well_id(path_str):
    """Turn a sample key into a flat well identifier.

    Keeps only the text before the first underscore, then replaces every
    '/' with '_' and removes every occurrence of the literal 'Plate'.

    Args:
        path_str: key string of a sample (e.g. the '__key__' field).

    Returns:
        The derived well identifier string.
    """
    prefix, _, _ = path_str.partition('_')
    flattened = prefix.replace('/', '_')
    return flattened.replace('Plate', '')
 
def collate_rxrx3_core(batch):
    """Collate single-channel samples into patched image tensors plus well ids.

    Every sample's 'jp2' entry is converted to an array; the stacked arrays
    are regrouped as (-1, 6, 512, 512), each 6-channel image is cut into
    patches by `patch_image`, and all patches are concatenated along the
    first axis. One well id is derived from the '__key__' of every sixth
    sample.

    Args:
        batch: list of dataset samples providing 'jp2' and '__key__'.

    Returns:
        (images, well_ids): a torch tensor of patches and a list of well ids.
    """
    channel_arrays = [np.array(sample['jp2']) for sample in batch]
    # NOTE(review): assumes samples arrive as 6 consecutive 512x512 channel
    # images per well, in order — confirm against the dataset layout.
    grouped = np.stack(channel_arrays).reshape(-1, 6, 512, 512)

    # With the default patch size each image becomes four 256x256 patches.
    per_image_patches = [patch_image(image) for image in grouped]
    images = torch.from_numpy(np.vstack(per_image_patches))

    well_ids = []
    for first_channel_sample in batch[::6]:
        well_ids.append(convert_path_to_well_id(first_channel_sample['__key__']))

    return images, well_ids

def iter_border_patches(width, height, patch_size):
    """Yield the (x, y) top-left corner of every full, non-overlapping patch.

    Corners step by `patch_size` along both axes; positions where a whole
    patch would not fit are skipped. x is the outer loop, y the inner one.

    Args:
        width: extent along the x axis.
        height: extent along the y axis.
        patch_size: side length of a square patch.

    Yields:
        (x, y) tuples of patch origins.
    """
    x_start, x_end, y_start, y_end = (0, width, 0, height)

    for x in range(x_start, x_end - patch_size + 1, patch_size):
        for y in range(y_start, y_end - patch_size + 1, patch_size):
            yield x, y


def patch_image(image_array, patch_size=256):
    """Split a (channels, height, width) array into square patches.

    Args:
        image_array: 3-D array whose last two axes are height and width.
        patch_size: side length of each square patch (default 256).

    Returns:
        Array of shape (n_patches, channels, patch_size, patch_size), ordered
        as produced by `iter_border_patches` (x outer, y inner).

    Raises:
        ValueError: if no full patch fits inside the image.
    """
    # Axis 1 is sliced with y and axis 2 with x below, so axis 1 is the
    # height and axis 2 the width. Unpacking them the other way round only
    # happens to work for square images.
    _, height, width = image_array.shape

    output_patches = [
        image_array[:, y : y + patch_size, x : x + patch_size].copy()
        for x, y in iter_border_patches(width, height, patch_size)
    ]
    if not output_patches:
        raise ValueError(
            f"image of shape {image_array.shape} is smaller than patch_size={patch_size}"
        )

    return np.stack(output_patches)

In [None]:
# Wrap the dataset in a PyTorch DataLoader for batched inference.
batch_size = 128
num_workers = 4

# The collate function regroups rows in sets of 6 (one per channel), so the
# loader pulls 6 rows for every image in a batch. Shuffling stays off because
# that regrouping relies on the rows keeping their original order.
rxrx3_core_dataloader = DataLoader(
    dataset=rxrx3_core,
    batch_size=6 * batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_rxrx3_core,
)

In [None]:
# Inference loop
num_features = 384  # embedding width used to reshape the model output
n_crops = 4  # patches per image produced by the collate function
well_ids = []
emb_ind = 0  # next free row in `embeddings`
# One embedding row per well; the dataset holds 6 rows (channels) per well.
embeddings = np.zeros(
    (len(rxrx3_core_dataloader.dataset) // 6, num_features), dtype=np.float32
)
forward_pass_counter = 0

for imgs, batch_well_ids in rxrx3_core_dataloader:

    # Gradients are never needed for inference, so autograd is disabled on
    # both the GPU and the CPU path (previously only the GPU path did this).
    with torch.no_grad():
        if cuda_available:
            with torch.amp.autocast("cuda"):
                latent = open_phenom.predict(imgs.cuda())
        else:
            latent = open_phenom.predict(imgs)

    # Average over the 4 256x256 crops of each image.
    latent = latent.view(-1, n_crops, num_features).mean(dim=1)
    embeddings[emb_ind : (emb_ind + len(latent))] = latent.detach().cpu().numpy()
    well_ids.extend(batch_well_ids)

    emb_ind += len(latent)
    forward_pass_counter += 1
    if forward_pass_counter % 5 == 0:
        print(f"forward pass {forward_pass_counter} of {len(rxrx3_core_dataloader)} done, wells inferenced {emb_ind}")

# Build the output table in one step: 'well_id' first, then the feature
# columns. Only the rows that were actually filled are kept.
feature_columns = [f"feature_{i}" for i in range(num_features)]
embedding_df = pd.DataFrame(embeddings[:emb_ind], columns=feature_columns)
embedding_df.insert(0, 'well_id', well_ids)
embedding_df.to_parquet('OpenPhenom_rxrx3-core_embeddings.parquet')