okraus88 committed on
Commit
7de425e
·
verified ·
1 Parent(s): 3b3f6e0

Upload RxRx3-core_inference.ipynb

Browse files
Files changed (1) hide show
  1. RxRx3-core_inference.ipynb +195 -0
RxRx3-core_inference.ipynb ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "5edcb7d2-53dc-4170-9f2f-619c0da0ae4c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "import numpy as np\n",
12
+ "from torch.utils.data import DataLoader\n",
13
+ "import pandas as pd"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "markdown",
18
+ "id": "f839c8fb-b018-4ab6-86a9-7d5bf7883b45",
19
+ "metadata": {},
20
+ "source": [
21
+ "# Load OpenPhenom"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "id": "84b9324d-fde9-4c43-bc5a-eb66cdb4f891",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "# Load model directly\n",
32
+ "from huggingface_mae import MAEModel\n",
33
+ "open_phenom = MAEModel.from_pretrained(\".\")"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "57d918c5-78de-4b36-9f46-4652c5da93f2",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
# Probe for CUDA once; the flag is reused later to pick the inference path.
cuda_available = torch.cuda.is_available()

# Put the model in evaluation mode and, when a GPU is usable, move it there.
open_phenom.eval()
if cuda_available:
    open_phenom.cuda()
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "7c89d82d-5365-4492-b496-adb3bbd71b32",
52
+ "metadata": {},
53
+ "source": [
54
+ "# Load RxRx3-core"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "deeff3a8-db67-4905-a7e9-c43aad614a84",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "from datasets import load_dataset\n",
65
+ "rxrx3_core = load_dataset(\"recursionpharma/rxrx3-core\")['train']"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "id": "8f2226ce-9415-4dd8-932e-54e4e1bd8c1a",
71
+ "metadata": {},
72
+ "source": [
73
+ "# Inference loop"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "aa1218ab-f9cd-413b-9228-c1146df978be",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
def convert_path_to_well_id(path_str):
    """Turn a sample key into a well identifier string.

    Everything from the first underscore onward is dropped, the remaining
    ``/`` separators become underscores, and the literal text ``Plate`` is
    removed.
    """
    well_path, _, _ = path_str.partition('_')
    flattened = well_path.replace('/', '_')
    return flattened.replace('Plate', '')
86
+ " \n",
def collate_rxrx3_core(batch):
    """Collate single-plane records into a patch tensor plus well ids.

    Each record's ``jp2`` entry is converted to an array; the stack is
    regrouped into 6-channel 512x512 images, every image is cut into patches
    by ``patch_image``, and the patches are concatenated along the first axis.
    One well id is derived from the ``__key__`` of every sixth record.

    Returns:
        tuple: ``(images, well_ids)`` - a torch tensor of patches and a list
        of well id strings.
    """
    channel_planes = [np.array(record['jp2']) for record in batch]
    site_images = np.stack(channel_planes).reshape(-1, 6, 512, 512)

    # convert to 4 256x256 patches
    patches = np.vstack([patch_image(site) for site in site_images])
    images = torch.from_numpy(patches)

    well_ids = []
    for record in batch[::6]:
        well_ids.append(convert_path_to_well_id(record['__key__']))

    return images, well_ids
94
+ "\n",
def iter_border_patches(width, height, patch_size):
    """Yield the ``(x, y)`` origin of every full patch on a regular grid.

    Origins start at 0 and advance in steps of ``patch_size`` for as long as a
    whole patch still fits inside ``width`` by ``height``. All ``y`` origins
    are produced for one ``x`` before moving on to the next ``x``.
    """
    for x in range(0, width - patch_size + 1, patch_size):
        yield from ((x, y) for y in range(0, height - patch_size + 1, patch_size))
102
+ "\n",
def patch_image(image_array, patch_size=256):
    """Cut a channel-first image into non-overlapping square patches.

    Args:
        image_array: numpy array with three axes. The first axis is kept whole;
            the second is sliced with ``y`` and the third with ``x``.
        patch_size: edge length of each square patch.

    Returns:
        numpy array of shape ``(n_patches, channels, patch_size, patch_size)``.
        Patch origins follow ``iter_border_patches`` order (all ``y`` for one
        ``x`` before the next ``x``).

    Raises:
        ValueError: if the image is too small to hold a single full patch.
    """
    # Axis 1 is indexed with y and axis 2 with x below, so axis 1 is the
    # height. Unpacking in that order keeps non-square images correct; for
    # square images the result is unchanged.
    _, height, width = image_array.shape

    output_patches = [
        image_array[:, y : y + patch_size, x : x + patch_size].copy()
        for x, y in iter_border_patches(width, height, patch_size)
    ]
    if not output_patches:
        raise ValueError(
            f"image of shape {image_array.shape} is smaller than one "
            f"{patch_size}x{patch_size} patch"
        )

    return np.stack(output_patches)
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "de308003-bcfc-4b59-9715-dd884b9b2536",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
# Wrap the dataset in a PyTorch DataLoader.
# The loader is asked for batch_size * 6 records per batch; the collate
# function regroups records in sixes.
num_workers = 4
batch_size = 128
rxrx3_core_dataloader = DataLoader(
    rxrx3_core,
    batch_size=batch_size * 6,
    shuffle=False,
    collate_fn=collate_rxrx3_core,
    num_workers=num_workers,
)
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "9e3ea6c2-d1aa-4e20-a175-d72ea636153e",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
# Inference loop
num_features = 384
n_crops = 4
well_ids = []
emb_ind = 0
# Preallocate one float32 row per six dataset records; rows are filled batch
# by batch and only the first emb_ind rows are kept at the end.
embeddings = np.zeros(
    ((len(rxrx3_core_dataloader.dataset)//6), num_features), dtype=np.float32
)
forward_pass_counter = 0

for imgs, batch_well_ids in rxrx3_core_dataloader:

    # Gradients are never needed here, so autograd is disabled on both the
    # GPU and the CPU path (the CPU path previously tracked gradients).
    with torch.no_grad():
        if cuda_available:
            with torch.amp.autocast("cuda"):
                latent = open_phenom.predict(imgs.cuda())
        else:
            latent = open_phenom.predict(imgs)

    latent = latent.view(-1, n_crops, num_features).mean(dim=1)  # average over 4 256x256 crops per image
    embeddings[emb_ind : (emb_ind + len(latent))] = latent.detach().cpu().numpy()
    well_ids.extend(batch_well_ids)

    emb_ind += len(latent)
    forward_pass_counter += 1
    if forward_pass_counter % 5 == 0:
        print(f"forward pass {forward_pass_counter} of {len(rxrx3_core_dataloader)} done, wells inferenced {emb_ind}")

# Build the output table in one step: well_id first, then the feature columns.
feature_columns = [f"feature_{i}" for i in range(num_features)]
embedding_df = pd.DataFrame(embeddings[:emb_ind], columns=feature_columns)
embedding_df.insert(0, 'well_id', well_ids)
embedding_df.to_parquet('OpenPhenom_rxrx3-core_embeddings.parquet')
171
+ ]
172
+ }
173
+ ],
174
+ "metadata": {
175
+ "kernelspec": {
176
+ "display_name": "photo2",
177
+ "language": "python",
178
+ "name": "photo2"
179
+ },
180
+ "language_info": {
181
+ "codemirror_mode": {
182
+ "name": "ipython",
183
+ "version": 3
184
+ },
185
+ "file_extension": ".py",
186
+ "mimetype": "text/x-python",
187
+ "name": "python",
188
+ "nbconvert_exporter": "python",
189
+ "pygments_lexer": "ipython3",
190
+ "version": "3.10.14"
191
+ }
192
+ },
193
+ "nbformat": 4,
194
+ "nbformat_minor": 5
195
+ }