boris commited on
Commit
c85fbb6
·
1 Parent(s): 9a553a4
Files changed (1) hide show
  1. dev/inference/wandb-backend.ipynb +74 -27
dev/inference/wandb-backend.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
@@ -12,7 +12,7 @@
12
  "import random\n",
13
  "import numpy as np\n",
14
  "from PIL import Image\n",
15
- "from tqdm import tqdm\n",
16
  "import jax\n",
17
  "import jax.numpy as jnp\n",
18
  "from flax.training.common_utils import shard, shard_prng_key\n",
@@ -26,7 +26,7 @@
26
  },
27
  {
28
  "cell_type": "code",
29
- "execution_count": null,
30
  "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
31
  "metadata": {},
32
  "outputs": [],
@@ -36,13 +36,13 @@
36
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
37
  "normalize_text = True\n",
38
  "latest_only = False # log only latest or all versions\n",
39
- "suffix = '_1' # mainly for duplicate inference runs with a deleted version\n",
40
  "add_clip_32 = False"
41
  ]
42
  },
43
  {
44
  "cell_type": "code",
45
- "execution_count": null,
46
  "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
47
  "metadata": {},
48
  "outputs": [],
@@ -52,16 +52,25 @@
52
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
53
  "normalize_text = False\n",
54
  "latest_only = True # log only latest or all versions\n",
55
- "suffix = '_2' # mainly for duplicate inference runs with a deleted version\n",
56
  "add_clip_32 = True"
57
  ]
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": null,
62
  "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
63
  "metadata": {},
64
- "outputs": [],
 
 
 
 
 
 
 
 
 
65
  "source": [
66
  "batch_size = 8\n",
67
  "num_images = 128\n",
@@ -75,10 +84,18 @@
75
  },
76
  {
77
  "cell_type": "code",
78
- "execution_count": null,
79
  "id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
80
  "metadata": {},
81
- "outputs": [],
 
 
 
 
 
 
 
 
82
  "source": [
83
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
84
  "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
@@ -94,7 +111,7 @@
94
  },
95
  {
96
  "cell_type": "code",
97
- "execution_count": null,
98
  "id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
99
  "metadata": {},
100
  "outputs": [],
@@ -104,20 +121,42 @@
104
  " return vqgan.decode_code(indices, params=params)\n",
105
  "\n",
106
  "@partial(jax.pmap, axis_name=\"batch\")\n",
107
- "def p_clip(inputs):\n",
108
- " logits = clip(params=clip_params, **inputs).logits_per_image\n",
109
  " return logits\n",
110
  "\n",
111
  "if add_clip_32:\n",
112
  " @partial(jax.pmap, axis_name=\"batch\")\n",
113
- " def p_clip32(inputs):\n",
114
- " logits = clip32(params=clip32_params, **inputs).logits_per_image\n",
115
  " return logits"
116
  ]
117
  },
118
  {
119
  "cell_type": "code",
120
- "execution_count": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
122
  "metadata": {},
123
  "outputs": [],
@@ -133,7 +172,7 @@
133
  },
134
  {
135
  "cell_type": "code",
136
- "execution_count": null,
137
  "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
138
  "metadata": {},
139
  "outputs": [],
@@ -150,7 +189,7 @@
150
  },
151
  {
152
  "cell_type": "code",
153
- "execution_count": null,
154
  "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
155
  "metadata": {},
156
  "outputs": [],
@@ -163,7 +202,7 @@
163
  },
164
  {
165
  "cell_type": "code",
166
- "execution_count": null,
167
  "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
168
  "metadata": {},
169
  "outputs": [],
@@ -179,7 +218,7 @@
179
  },
180
  {
181
  "cell_type": "code",
182
- "execution_count": null,
183
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
184
  "metadata": {},
185
  "outputs": [],
@@ -202,7 +241,7 @@
202
  },
203
  {
204
  "cell_type": "code",
205
- "execution_count": null,
206
  "id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
207
  "metadata": {},
208
  "outputs": [],
@@ -213,10 +252,19 @@
213
  },
214
  {
215
  "cell_type": "code",
216
- "execution_count": null,
217
  "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
218
  "metadata": {},
219
- "outputs": [],
 
 
 
 
 
 
 
 
 
220
  "source": [
221
  "artifact_versions = get_artifact_versions(run_id, latest_only)\n",
222
  "last_inference_version = get_last_inference_version(run_id)\n",
@@ -276,9 +324,8 @@
276
  " tokenized_prompt = shard(tokenized_prompt)\n",
277
  "\n",
278
  " # generate images\n",
279
- " print('Generating images')\n",
280
  " images = []\n",
281
- " for i in tqdm(range(num_images // jax.device_count())):\n",
282
  " key, subkey = jax.random.split(key)\n",
283
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
284
  " encoded_images = encoded_images.sequences[..., 1:]\n",
@@ -294,7 +341,7 @@
294
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
295
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
296
  " clip_inputs = shard(clip_inputs)\n",
297
- " logits = p_clip(clip_inputs)\n",
298
  " logits = logits.reshape(-1, num_images)\n",
299
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
300
  " logits = jax.device_get(logits)\n",
@@ -314,7 +361,7 @@
314
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
315
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
316
  " clip_inputs = shard(clip_inputs)\n",
317
- " logits = p_clip32(clip_inputs)\n",
318
  " logits = logits.reshape(-1, num_images)\n",
319
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
320
  " logits = jax.device_get(logits)\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
 
12
  "import random\n",
13
  "import numpy as np\n",
14
  "from PIL import Image\n",
15
+ "from tqdm.notebook import tqdm\n",
16
  "import jax\n",
17
  "import jax.numpy as jnp\n",
18
  "from flax.training.common_utils import shard, shard_prng_key\n",
 
26
  },
27
  {
28
  "cell_type": "code",
29
+ "execution_count": 2,
30
  "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
31
  "metadata": {},
32
  "outputs": [],
 
36
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
37
  "normalize_text = True\n",
38
  "latest_only = False # log only latest or all versions\n",
39
+ "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
40
  "add_clip_32 = False"
41
  ]
42
  },
43
  {
44
  "cell_type": "code",
45
+ "execution_count": 3,
46
  "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
47
  "metadata": {},
48
  "outputs": [],
 
52
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
53
  "normalize_text = False\n",
54
  "latest_only = True # log only latest or all versions\n",
55
+ "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
56
  "add_clip_32 = True"
57
  ]
58
  },
59
  {
60
  "cell_type": "code",
61
+ "execution_count": 4,
62
  "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
63
  "metadata": {},
64
+ "outputs": [
65
+ {
66
+ "name": "stderr",
67
+ "output_type": "stream",
68
+ "text": [
69
+ "INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
70
+ "INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
71
+ ]
72
+ }
73
+ ],
74
  "source": [
75
  "batch_size = 8\n",
76
  "num_images = 128\n",
 
84
  },
85
  {
86
  "cell_type": "code",
87
+ "execution_count": 5,
88
  "id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
89
  "metadata": {},
90
+ "outputs": [
91
+ {
92
+ "name": "stdout",
93
+ "output_type": "stream",
94
+ "text": [
95
+ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
96
+ ]
97
+ }
98
+ ],
99
  "source": [
100
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
101
  "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
 
111
  },
112
  {
113
  "cell_type": "code",
114
+ "execution_count": 6,
115
  "id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
116
  "metadata": {},
117
  "outputs": [],
 
121
  " return vqgan.decode_code(indices, params=params)\n",
122
  "\n",
123
  "@partial(jax.pmap, axis_name=\"batch\")\n",
124
+ "def p_clip(inputs, params):\n",
125
+ " logits = clip(params=params, **inputs).logits_per_image\n",
126
  " return logits\n",
127
  "\n",
128
  "if add_clip_32:\n",
129
  " @partial(jax.pmap, axis_name=\"batch\")\n",
130
+ " def p_clip32(inputs, params):\n",
131
+ " logits = clip32(params=params, **inputs).logits_per_image\n",
132
  " return logits"
133
  ]
134
  },
135
  {
136
  "cell_type": "code",
137
+ "execution_count": 7,
138
+ "id": "ebf4f7bf-2efa-46cc-b3f4-2d7a54f7b2cb",
139
+ "metadata": {},
140
+ "outputs": [
141
+ {
142
+ "data": {
143
+ "text/plain": [
144
+ "ShardedDeviceArray([4.6051702, 4.6051702, 4.6051702, 4.6051702, 4.6051702,\n",
145
+ " 4.6051702, 4.6051702, 4.6051702], dtype=float32)"
146
+ ]
147
+ },
148
+ "execution_count": 7,
149
+ "metadata": {},
150
+ "output_type": "execute_result"
151
+ }
152
+ ],
153
+ "source": [
154
+ "clip_params['logit_scale']"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": 8,
160
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
161
  "metadata": {},
162
  "outputs": [],
 
172
  },
173
  {
174
  "cell_type": "code",
175
+ "execution_count": 9,
176
  "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
177
  "metadata": {},
178
  "outputs": [],
 
189
  },
190
  {
191
  "cell_type": "code",
192
+ "execution_count": 10,
193
  "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
194
  "metadata": {},
195
  "outputs": [],
 
202
  },
203
  {
204
  "cell_type": "code",
205
+ "execution_count": 11,
206
  "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
207
  "metadata": {},
208
  "outputs": [],
 
218
  },
219
  {
220
  "cell_type": "code",
221
+ "execution_count": 12,
222
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
223
  "metadata": {},
224
  "outputs": [],
 
241
  },
242
  {
243
  "cell_type": "code",
244
+ "execution_count": 13,
245
  "id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
246
  "metadata": {},
247
  "outputs": [],
 
252
  },
253
  {
254
  "cell_type": "code",
255
+ "execution_count": 14,
256
  "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
257
  "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "ename": "SyntaxError",
261
+ "evalue": "EOL while scanning string literal (1745443972.py, line 60)",
262
+ "output_type": "error",
263
+ "traceback": [
264
+ "\u001b[0;36m File \u001b[0;32m\"/tmp/ipykernel_402605/1745443972.py\"\u001b[0;36m, line \u001b[0;32m60\u001b[0m\n\u001b[0;31m for i in tqdm(range(num_images // jax.device_count()), desc='Generating Images):\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
265
+ ]
266
+ }
267
+ ],
268
  "source": [
269
  "artifact_versions = get_artifact_versions(run_id, latest_only)\n",
270
  "last_inference_version = get_last_inference_version(run_id)\n",
 
324
  " tokenized_prompt = shard(tokenized_prompt)\n",
325
  "\n",
326
  " # generate images\n",
 
327
  " images = []\n",
328
+ " for i in tqdm(range(num_images // jax.device_count()), desc='Generating Images):\n",
329
  " key, subkey = jax.random.split(key)\n",
330
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
331
  " encoded_images = encoded_images.sequences[..., 1:]\n",
 
341
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
342
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
343
  " clip_inputs = shard(clip_inputs)\n",
344
+ " logits = p_clip(clip_inputs, clip_params)\n",
345
  " logits = logits.reshape(-1, num_images)\n",
346
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
347
  " logits = jax.device_get(logits)\n",
 
361
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
362
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
363
  " clip_inputs = shard(clip_inputs)\n",
364
+ " logits = p_clip32(clip_inputs, clip32_params)\n",
365
  " logits = logits.reshape(-1, num_images)\n",
366
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
367
  " logits = jax.device_get(logits)\n",