boris committed on
Commit
ff051c9
1 Parent(s): bf3640d

feat: allow latest version only

Files changed (1)
  1. dev/inference/wandb-backend.ipynb +154 -358
dev/inference/wandb-backend.ipynb CHANGED
@@ -32,9 +32,25 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "wandb_runs = ['rjf3rycy']\n",
+ "run_ids = ['rjf3rycy']\n",
+ "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
 "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
- "normalize_text = True"
+ "normalize_text = True\n",
+ "latest_only = False # log only latest or all versions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "run_ids = ['4oh3u7ca']\n",
+ "ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
+ "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
+ "normalize_text = False\n",
+ "latest_only = True # log only latest or all versions"
 ]
 },
 {
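Note: the commit replaces the single config cell with two variants, one pointing at the dalle-mini training runs (all versions logged), one at the wandb/hf-flax-dalle-mini project (latest version only). A minimal sketch of how these constants are composed into W&B paths later in the diff (values taken from the first cell above; the run id is one of the examples shown):

    # Values assumed from the config cells above
    ENTITY, PROJECT = 'dalle-mini', 'dalle-mini'
    run_id = 'rjf3rycy'

    training_run_path = f'{ENTITY}/{PROJECT}/{run_id}'             # used by get_training_config
    latest_artifact = f'{ENTITY}/{PROJECT}/model-{run_id}:latest'  # used when latest_only is True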
@@ -104,18 +120,6 @@
 " samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
 ]
 },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
- "metadata": {},
- "outputs": [],
- "source": [
- "# TODO: iterate on runs\n",
- "wandb_run = wandb_runs[0]\n",
- "model_pmapped = False"
- ]
- },
 {
 "cell_type": "code",
 "execution_count": null,
@@ -123,12 +127,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "def get_artifact_versions(run_id):\n",
+ "def get_artifact_versions(run_id, latest_only=False):\n",
 " try:\n",
- " versions = api.artifact_versions(type_name='bart_model', name=f'dalle-mini/dalle-mini/model-{run_id}', per_page=10000)\n",
+ " if latest_only:\n",
+ " return [api.artifact(type='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}:latest')]\n",
+ " else:\n",
+ " return api.artifact_versions(type_name='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}', per_page=10000)\n",
 " except:\n",
- " versions = []\n",
- " return versions"
+ " return []"
 ]
 },
 {
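Note: with latest_only=True the helper returns a one-element list holding the artifact pinned to the ':latest' alias, so callers can iterate over it exactly as over a full version history; on any API error it now returns [] directly instead of going through a local variable. A hedged usage sketch (assumes the notebook's global api = wandb.Api() handle, the ENTITY/PROJECT constants above, and that the model artifact exists):

    import wandb

    api = wandb.Api()
    # Fetch either the full version history or just the newest checkpoint
    for artifact in get_artifact_versions('rjf3rycy', latest_only=True):
        # e.g. prints 'model-rjf3rycy:v54' and 'v54'
        print(artifact.name, artifact.version)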
@@ -139,7 +145,7 @@
 "outputs": [],
 "source": [
 "def get_training_config(run_id):\n",
- " training_run = api.run(f'dalle-mini/dalle-mini/{run_id}')\n",
+ " training_run = api.run(f'{ENTITY}/{PROJECT}/{run_id}')\n",
 " config = training_run.config\n",
 " return config"
 ]
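Note: api.run() returns a run object whose .config is a plain dict of training hyperparameters; the consolidated cell below passes it unchanged to wandb.init(..., config=training_config) so the inference run records the same settings. A small sketch (hypothetical run id):

    training_config = get_training_config('rjf3rycy')  # hypothetical run id
    print(sorted(training_config)[:5])                 # peek at a few hyperparameter names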
@@ -155,7 +161,7 @@
 "def get_last_inference_version(run_id):\n",
 " try:\n",
 " inference_run = api.run(f'dalle-mini/dalle-mini/inference-{run_id}')\n",
- " return inference_run.summary.get('_step', None)\n",
+ " return inference_run.summary.get('version', None)\n",
 " except:\n",
 " return None"
 ]
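Note: the summary key switches from wandb's internal '_step' counter to an explicit 'version' value, which the consolidated cell below logs alongside each table (run.log({'Samples': table, 'version': version})), so resuming reads back the exact last model version scored. A sketch of the decision this enables (mirrors the assertions in the loop below; run id hypothetical):

    # None if no inference run has logged anything yet
    last = get_last_inference_version('rjf3rycy')
    next_version = 0 if last is None else last + 1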
@@ -186,68 +192,142 @@
 {
 "cell_type": "code",
 "execution_count": null,
- "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
+ "id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
 "metadata": {},
 "outputs": [],
 "source": [
- "def log_run(run_id):\n",
- " artifact_versions = get_artifact_versions(run_id)\n",
- " last_inference_version = get_last_inference_version(run_id)\n",
- " training_config = get_training_config(run_id)\n",
- " run = None\n",
- " p_generate = None\n",
- " model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
- " for artifact in artifact_versions:\n",
- " print(f'Processing artifact: {artifact.name}')\n",
- " version = int(artifact.version[1:])\n",
- " if last_version_inference is None:\n",
- " # we should start from v0\n",
- " assert version == 0\n",
- " elif version <= last_version_inference:\n",
- " print(f'v{version} has already been logged (versions logged up to v{last_version_inference}')\n",
- " else:\n",
- " # check we are logging the correct version\n",
- " assert version == last_version_inference + 1\n",
- " \n",
- " # start/resume corresponding run\n",
- " if run is None:\n",
- " run = wandb.init(job_type='inference', config=config, id=f'inference-{wandb_run}', resume='allow')\n",
- " \n",
- " # work in temporary directory\n",
- " with tempfile.TemporaryDirectory() as tmp:\n",
- " \n",
- " # download model files\n",
- " artifact = run.use_artifact(artifact)\n",
- " for f in model_files:\n",
- " artifact.get_path(f).download(tmp)\n",
- " \n",
- " # load tokenizer and model\n",
- " tokenizer = BartTokenizer.from_pretrained(tmp)\n",
- " model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n",
- " model_params = replicate(model.params)\n",
- " \n",
- " # pmap model function needs to happen only once per model config\n",
- " if p_generate is None:\n",
- " p_generate = pmap_model_function(model)\n",
- " \n",
- " for batch in tqdm(samples):\n",
- " prompts = [x['Caption'] for x in batch]\n",
- " processed_prompts = [text_normalizer(x) for x in prompts] if normalize_text else prompts\n",
- " \n",
- "\n",
- " \n",
- " \n",
- " "
+ "run_id = run_ids[0]\n",
+ "# TODO: turn everything into a class"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
- "id": "4d542342-3232-48a5-a0aa-3cb5c157aa8c",
+ "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
 "metadata": {},
- "outputs": [],
- "source": [
- "log_run(wandb_run)"
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing artifact: model-4oh3u7ca:v54\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mborisd13\u001b[0m (use `wandb login --relogin` to force relogin)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " Syncing run <strong><a href=\"https://wandb.ai/dalle-mini/dalle-mini/runs/inference-4oh3u7ca\" target=\"_blank\">inference-4oh3u7ca</a></strong> to <a href=\"https://wandb.ai/dalle-mini/dalle-mini\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
+ "\n",
+ " "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "artifact_versions = get_artifact_versions(run_id, latest_only)\n",
+ "last_inference_version = get_last_inference_version(run_id)\n",
+ "training_config = get_training_config(run_id)\n",
+ "run = None\n",
+ "p_generate = None\n",
+ "model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
+ "for artifact in artifact_versions:\n",
+ " print(f'Processing artifact: {artifact.name}')\n",
+ " version = int(artifact.version[1:])\n",
+ " results = []\n",
+ " columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
+ " \n",
+ " if latest_only:\n",
+ " assert last_inference_version is None or version > last_inference_version\n",
+ " else:\n",
+ " if last_inference_version is None:\n",
+ " # we should start from v0\n",
+ " assert version == 0\n",
+ " elif version <= last_inference_version:\n",
+ " print(f'v{version} has already been logged (versions logged up to v{last_inference_version}')\n",
+ " else:\n",
+ " # check we are logging the correct version\n",
+ " assert version == last_inference_version + 1\n",
+ "\n",
+ " # start/resume corresponding run\n",
+ " if run is None:\n",
+ " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'inference-{run_id}', resume='allow')\n",
+ "\n",
+ " # work in temporary directory\n",
+ " with tempfile.TemporaryDirectory() as tmp:\n",
+ "\n",
+ " # download model files\n",
+ " artifact = run.use_artifact(artifact)\n",
+ " for f in model_files:\n",
+ " artifact.get_path(f).download(tmp)\n",
+ "\n",
+ " # load tokenizer and model\n",
+ " tokenizer = BartTokenizer.from_pretrained(tmp)\n",
+ " model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n",
+ " model_params = replicate(model.params)\n",
+ "\n",
+ " # pmap model function needs to happen only once per model config\n",
+ " if p_generate is None:\n",
+ " p_generate = pmap_model_function(model)\n",
+ "\n",
+ " # process one batch of captions\n",
+ " for batch in tqdm(samples):\n",
+ " prompts = [x['Caption'] for x in batch]\n",
+ " processed_prompts = [text_normalizer(x) for x in prompts] if normalize_text else prompts\n",
+ "\n",
+ " # repeat the prompts to distribute over each device and tokenize\n",
+ " processed_prompts = processed_prompts * jax.device_count()\n",
+ " tokenized_prompt = tokenizer(processed_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
+ " tokenized_prompt = shard(tokenized_prompt)\n",
+ "\n",
+ " # generate images\n",
+ " print('Generating images')\n",
+ " images = []\n",
+ " for i in tqdm(range(num_images // jax.device_count())):\n",
+ " key, subkey = jax.random.split(key)\n",
+ " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
+ " decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
+ " for img in decoded_images:\n",
+ " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
+ "\n",
+ " # get clip scores\n",
+ " print('Calculating CLIP scores')\n",
+ " clip_inputs = processor(text=prompts, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+ " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
+ " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
+ " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
+ " clip_inputs = shard(clip_inputs)\n",
+ " logits = p_clip(clip_inputs)\n",
+ " logits = logits.reshape(-1, num_images)\n",
+ " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
+ " logits = jax.device_get(logits)\n",
+ "\n",
+ " # add to results table\n",
+ " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
+ " if sample['Caption'] == padding_item: continue\n",
+ " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
+ " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
+ " top_scores = [scores[x] for x in idx]\n",
+ " results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)\n",
+ "\n",
+ " # log results\n",
+ " table = wandb.Table(columns=columns, data=results)\n",
+ " run.log({'Samples': table, 'version': version})\n",
+ " wandb.finish()"
 ]
 },
 {
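Note: the subtlest step in the consolidated cell above is regrouping images for CLIP scoring: generation appends images pass by pass (prompt 0, prompt 1, ..., prompt 0, prompt 1, ...), while sharding wants all images of one prompt contiguous. A toy check of that index arithmetic in plain numpy (tiny sizes assumed: 2 prompts, 2 generation passes):

    import numpy as np

    batch_size = 2  # prompts per batch (assumed)
    images = np.array(['p0_a', 'p1_a', 'p0_b', 'p1_b'])  # pass a, then pass b
    images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))  # [0 2]
    regrouped = np.concatenate([images[images_per_prompt_indices + i] for i in range(batch_size)])
    print(regrouped)  # ['p0_a' 'p0_b' 'p1_a' 'p1_b'] -> grouped by prompt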
@@ -257,296 +337,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
+ "# TODO: not implemented\n",
 "def log_runs(runs):\n",
 " for run in tqdm(runs):\n",
 " log_run(run)"
 ]
 },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7a24b903-777b-4e3d-817c-00ed613a7021",
- "metadata": {},
- "outputs": [],
- "source": [
- "# TODO: loop over samples\n",
- "batch = samples[0]\n",
- "prompts = [x['Caption'] for x in batch]\n",
- "processed_prompts = [text_normalizer(x) for x in prompts] if normalize_text else prompts"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d77aa785-dc05-4070-aba2-aa007524d20b",
- "metadata": {},
- "outputs": [],
- "source": [
- "processed_prompts"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "95db38fb-8948-4814-98ae-c172ca7c6d0a",
- "metadata": {},
- "outputs": [],
- "source": [
- "repeated_prompts = processed_prompts * jax.device_count()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e948ba9e-3700-4e87-926f-580a10f3e5cd",
- "metadata": {},
- "outputs": [],
- "source": [
- "tokenized_prompt = tokenizer(repeated_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
- "tokenized_prompt = shard(tokenized_prompt)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "30d96812-fc17-4acf-bb64-5fdb8d0cd313",
- "metadata": {},
- "outputs": [],
- "source": [
- "tokenized_prompt['input_ids'].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "92ea034b-2649-4d18-ab6d-877ed04ae5c4",
- "metadata": {},
- "outputs": [],
- "source": [
- "images = []\n",
- "for i in range(num_images // jax.device_count()):\n",
- " key, subkey = jax.random.split(key, 2)\n",
- " \n",
- " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
- " encoded_images = encoded_images.sequences[..., 1:]\n",
- " \n",
- " decoded_images = p_decode(encoded_images, vqgan_params)\n",
- " decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
- " \n",
- " for img in decoded_images:\n",
- " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
- " "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "84d52f30-44c9-4a74-9992-fb2578f19b90",
- "metadata": {},
- "outputs": [],
- "source": [
- "len(images)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "beb594f9-5b91-47fe-98bd-41e68c6b1d73",
- "metadata": {},
- "outputs": [],
- "source": [
- "images[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bb135190-64e5-44af-b416-e688b034da44",
- "metadata": {},
- "outputs": [],
- "source": [
- "images[1]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d78a0d92-72c2-4f82-a6ab-b3f5865dd863",
- "metadata": {},
- "outputs": [],
- "source": [
- "clip_inputs = processor(text=prompts, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "89ff78a6-bfa4-44d9-ad66-07a4a68b4352",
- "metadata": {},
- "outputs": [],
- "source": [
- "# each shard will have one prompt\n",
- "clip_inputs['input_ids'].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2cda8984-049c-4c87-96ad-7b0412750656",
- "metadata": {},
- "outputs": [],
- "source": [
- "# each shard needs to have the images corresponding to a specific prompt\n",
- "clip_inputs['pixel_values'].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0a044e8f-be29-404b-b6c7-8f2395c5efc6",
- "metadata": {},
- "outputs": [],
- "source": [
- "images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
- "images_per_prompt_indices"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7a6c61b3-12e0-45d8-b39a-830288324d3d",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7318e67e-4214-46f9-bf60-6d139d4bd00f",
- "metadata": {},
- "outputs": [],
- "source": [
- "# reorder so each shard will have correct images\n",
- "clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "90c949a2-8e2a-4905-b6d4-92038f1704b8",
- "metadata": {},
- "outputs": [],
- "source": [
- "clip_inputs = shard(clip_inputs)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "58fa836e-5ebb-45e7-af77-ab10646dfbfb",
- "metadata": {},
- "outputs": [],
- "source": [
- "logits = p_clip(clip_inputs)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fd7a3f91-3a1f-4a0a-8b3e-3c926cd367fb",
- "metadata": {},
- "outputs": [],
- "source": [
- "logits.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fa406db7-0a21-4e4b-9890-4c7aece4280c",
- "metadata": {},
- "outputs": [],
- "source": [
- "logits = logits.reshape(-1, num_images)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9c359a8c-2c27-4e68-8775-371857397723",
- "metadata": {},
- "outputs": [],
- "source": [
- "logits.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a56b9f28-dd91-4382-bc47-11e89fda1254",
- "metadata": {},
- "outputs": [],
- "source": [
- "logits"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0bed8167-0a6d-46c1-badf-8bdc20b93c31",
- "metadata": {},
- "outputs": [],
- "source": [
- "top_idx = logits.argsort()[:, -top_k:][..., ::-1]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "188c5333-6b8c-4a17-8cc8-15651c77ef99",
- "metadata": {},
- "outputs": [],
- "source": [
- "len(images)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "babd22b3-e773-467d-8bbb-f0323f57a44b",
- "metadata": {},
- "outputs": [],
- "source": [
- "results = []\n",
- "columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
- "logits = jax.device_get(logits)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "75976c9f-dea5-48e3-8920-55a1bbfd91c2",
- "metadata": {},
- "outputs": [],
- "source": [
- "for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
- " if sample['Caption'] == padding_item: continue\n",
- " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
- " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
- " top_scores = [scores[x] for x in idx]\n",
- " results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4bf40461-99d3-4d36-b7cc-e0129a3c9053",
- "metadata": {},
- "outputs": [],
- "source": [
- "table = wandb.Table(columns=columns, data=results)"
- ]
- },
 {
 "cell_type": "code",
 "execution_count": null,
 