{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b6f4a52-1c3e-4d8a-9f21-5a7e2c9d3b10",
   "metadata": {},
   "source": [
    "# OpenPhenom embeddings for RxRx3-core\n",
    "\n",
    "Embeds every well of the RxRx3-core dataset with the OpenPhenom model and writes one 384-dimensional embedding per well to `OpenPhenom_rxrx3-core_embeddings.parquet`.\n",
    "\n",
    "- Each dataset row is one channel image; the 6 rows of a well are assumed to be consecutive (this is checked in the collate function).\n",
    "- The 6 channels are stacked into one 6-channel image, cut into non-overlapping 256x256 patches and embedded; the patch embeddings are then averaged per well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edcb7d2-53dc-4170-9f2f-619c0da0ae4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from huggingface_mae import MAEModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d9c7e1a-6b2f-4f0e-8a55-c1e4b7d92f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each well is stored as N_CHANNELS consecutive single-channel image rows.\n",
    "N_CHANNELS = 6\n",
    "PATCH_SIZE = 256  # images are cut into non-overlapping PATCH_SIZE x PATCH_SIZE patches\n",
    "NUM_FEATURES = 384  # width of one OpenPhenom embedding\n",
    "BATCH_SIZE = 128  # wells per forward pass\n",
    "NUM_WORKERS = 4\n",
    "OUTPUT_PATH = 'OpenPhenom_rxrx3-core_embeddings.parquet'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f839c8fb-b018-4ab6-86a9-7d5bf7883b45",
   "metadata": {},
   "source": [
    "# Load OpenPhenom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b9324d-fde9-4c43-bc5a-eb66cdb4f891",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model directly\n",
    "open_phenom = MAEModel.from_pretrained(\"recursionpharma/OpenPhenom\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d918c5-78de-4b36-9f46-4652c5da93f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "open_phenom.eval()\n",
    "cuda_available = torch.cuda.is_available()\n",
    "if cuda_available:\n",
    "    open_phenom.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c89d82d-5365-4492-b496-adb3bbd71b32",
   "metadata": {},
   "source": [
    "# Load RxRx3-core"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deeff3a8-db67-4905-a7e9-c43aad614a84",
   "metadata": {},
   "outputs": [],
   "source": [
    "rxrx3_core = load_dataset(\"recursionpharma/rxrx3-core\")['train']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f2226ce-9415-4dd8-932e-54e4e1bd8c1a",
   "metadata": {},
   "source": [
    "# Inference loop\n",
    "\n",
    "Helpers that turn each group of `N_CHANNELS` consecutive rows into one multi-channel image, split it into patches, and recover the well ID."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1218ab-f9cd-413b-9228-c1146df978be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_path_to_well_id(path_str):\n",
    "    \"\"\"Map a dataset row key to its well ID.\n",
    "\n",
    "    Keeps the text before the first '_', replaces '/' with '_' and drops the\n",
    "    literal 'Plate' (illustrative key: 'a/Plate1/B02_s1_3' -> 'a_1_B02').\n",
    "    \"\"\"\n",
    "    return path_str.split('_')[0].replace('/', '_').replace('Plate', '')\n",
    "\n",
    "\n",
    "def iter_border_patches(width, height, patch_size):\n",
    "    \"\"\"Yield (x, y) top-left corners of a non-overlapping patch grid.\n",
    "\n",
    "    x steps along the width and y along the height in strides of patch_size;\n",
    "    a trailing strip narrower than patch_size is skipped.\n",
    "    \"\"\"\n",
    "    for x in range(0, width - patch_size + 1, patch_size):\n",
    "        for y in range(0, height - patch_size + 1, patch_size):\n",
    "            yield x, y\n",
    "\n",
    "\n",
    "def patch_image(image_array, patch_size=256):\n",
    "    \"\"\"Cut a (C, H, W) image into non-overlapping square patches.\n",
    "\n",
    "    Returns an array of shape (n_patches, C, patch_size, patch_size).\n",
    "    Raises ValueError if the image is smaller than one patch.\n",
    "    \"\"\"\n",
    "    # Axes are (channels, rows, cols): rows index the height, cols the width.\n",
    "    _, height, width = image_array.shape\n",
    "    patches = [\n",
    "        image_array[:, y:y + patch_size, x:x + patch_size].copy()\n",
    "        for x, y in iter_border_patches(width, height, patch_size)\n",
    "    ]\n",
    "    if not patches:\n",
    "        raise ValueError(f'image of shape {image_array.shape} is smaller than patch_size={patch_size}')\n",
    "    return np.stack(patches)\n",
    "\n",
    "\n",
    "def collate_rxrx3_core(batch, n_channels=N_CHANNELS, patch_size=PATCH_SIZE):\n",
    "    \"\"\"Collate consecutive channel rows into patched per-well images.\n",
    "\n",
    "    batch: dataset rows with a 'jp2' image and a '__key__' string; every run of\n",
    "        n_channels consecutive rows must hold the channels of one well.\n",
    "    Returns (patches, well_ids): a tensor of shape\n",
    "        (n_wells * n_patches, n_channels, patch_size, patch_size) with each\n",
    "        well's patches adjacent, and one well ID per well in batch order.\n",
    "    Raises ValueError if the batch is not a whole number of wells or if a group\n",
    "        of n_channels rows maps to more than one well ID.\n",
    "    \"\"\"\n",
    "    if len(batch) % n_channels != 0:\n",
    "        raise ValueError(f'batch of {len(batch)} rows is not a multiple of {n_channels} channels')\n",
    "    stacked = np.stack([np.array(row['jp2']) for row in batch])\n",
    "    # Use the real image size rather than assuming 512x512.\n",
    "    height, width = stacked.shape[1:3]\n",
    "    images = stacked.reshape(-1, n_channels, height, width)\n",
    "    patches = np.concatenate([patch_image(image, patch_size) for image in images])\n",
    "\n",
    "    well_ids = []\n",
    "    for start in range(0, len(batch), n_channels):\n",
    "        group_ids = {convert_path_to_well_id(row['__key__']) for row in batch[start:start + n_channels]}\n",
    "        # Guards against channel rows of different wells being mixed into one image.\n",
    "        if len(group_ids) != 1:\n",
    "            raise ValueError(f'rows {start}-{start + n_channels - 1} of the batch span several wells: {sorted(group_ids)}')\n",
    "        well_ids.append(group_ids.pop())\n",
    "    return torch.from_numpy(patches), well_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de308003-bcfc-4b59-9715-dd884b9b2536",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every dataset row is one channel image, so each batch holds BATCH_SIZE wells.\n",
    "# shuffle must stay False: a well's channel rows have to remain consecutive.\n",
    "assert len(rxrx3_core) % N_CHANNELS == 0, 'dataset length is not a whole number of wells'\n",
    "rxrx3_core_dataloader = DataLoader(\n",
    "    rxrx3_core,\n",
    "    batch_size=BATCH_SIZE * N_CHANNELS,\n",
    "    shuffle=False,\n",
    "    collate_fn=collate_rxrx3_core,\n",
    "    num_workers=NUM_WORKERS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e3ea6c2-d1aa-4e20-a175-d72ea636153e",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_wells = len(rxrx3_core) // N_CHANNELS\n",
    "embeddings = np.zeros((n_wells, NUM_FEATURES), dtype=np.float32)\n",
    "well_ids = []\n",
    "emb_ind = 0\n",
    "\n",
    "for forward_pass_counter, (imgs, batch_well_ids) in enumerate(rxrx3_core_dataloader, start=1):\n",
    "    # no_grad on both paths; previously the CPU branch built an autograd graph.\n",
    "    with torch.no_grad():\n",
    "        if cuda_available:\n",
    "            with torch.amp.autocast('cuda'):\n",
    "                latent = open_phenom.predict(imgs.cuda())\n",
    "        else:\n",
    "            latent = open_phenom.predict(imgs)\n",
    "\n",
    "    # Average the patch embeddings of each well (a well's patches are adjacent).\n",
    "    # float() before mean: under autocast the output may be half precision.\n",
    "    n_patches = len(latent) // len(batch_well_ids)\n",
    "    latent = latent.float().reshape(-1, n_patches, NUM_FEATURES).mean(dim=1)\n",
    "    embeddings[emb_ind:emb_ind + len(latent)] = latent.cpu().numpy()\n",
    "    well_ids.extend(batch_well_ids)\n",
    "    emb_ind += len(latent)\n",
    "\n",
    "    if forward_pass_counter % 5 == 0:\n",
    "        print(f'forward pass {forward_pass_counter} of {len(rxrx3_core_dataloader)} done, wells inferenced {emb_ind}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a2e8f14-9c7b-4e3d-b0a1-2f5d8c6e7b93",
   "metadata": {},
   "source": [
    "# Save embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f1b9e7-2a6d-4c8e-9b3f-7d0e5a1c2b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(well_ids) == emb_ind, 'well IDs and embeddings are out of sync'\n",
    "feature_columns = [f'feature_{i}' for i in range(NUM_FEATURES)]\n",
    "embedding_df = pd.DataFrame(embeddings[:emb_ind], columns=feature_columns)\n",
    "embedding_df.insert(0, 'well_id', well_ids)\n",
    "embedding_df.to_parquet(OUTPUT_PATH)\n",
    "embedding_df.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "photo2",
   "language": "python",
   "name": "photo2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}