boris committed on
Commit
353365f
1 Parent(s): b8bbe68

feat: add scoring

Browse files
Files changed (1) hide show
  1. dev/inference/wandb-backend.ipynb +338 -51
dev/inference/wandb-backend.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 197,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
@@ -10,7 +10,13 @@
10
  "import csv\n",
11
  "import tempfile\n",
12
  "from functools import partial\n",
 
 
 
13
  "import jax\n",
 
 
 
14
  "import wandb\n",
15
  "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
16
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
@@ -30,6 +36,30 @@
30
  "normalize_text = True"
31
  ]
32
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  {
34
  "cell_type": "code",
35
  "execution_count": null,
@@ -44,7 +74,18 @@
44
  },
45
  {
46
  "cell_type": "code",
47
- "execution_count": 245,
 
 
 
 
 
 
 
 
 
 
 
48
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
49
  "metadata": {},
50
  "outputs": [],
@@ -53,51 +94,32 @@
53
  " reader = csv.DictReader(f)\n",
54
  " samples = []\n",
55
  " for row in reader:\n",
56
- " samples.append(row)"
 
 
 
 
 
57
  ]
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": 246,
62
  "id": "f75b2869-fc25-4f56-b937-e97bbb712ede",
63
  "metadata": {},
64
- "outputs": [
65
- {
66
- "data": {
67
- "text/plain": [
68
- "101"
69
- ]
70
- },
71
- "execution_count": 246,
72
- "metadata": {},
73
- "output_type": "execute_result"
74
- }
75
- ],
76
  "source": [
77
  "len(samples)"
78
  ]
79
  },
80
  {
81
  "cell_type": "code",
82
- "execution_count": 248,
83
- "id": "2ea0b166-a20c-4d78-bffb-b792ca512d17",
84
  "metadata": {},
85
- "outputs": [
86
- {
87
- "data": {
88
- "text/plain": [
89
- "104"
90
- ]
91
- },
92
- "execution_count": 248,
93
- "metadata": {},
94
- "output_type": "execute_result"
95
- }
96
- ],
97
  "source": [
98
- "samples_to_add = ['empty'] * (-len(samples) % 8)\n",
99
- "samples.extend(samples_to_add)\n",
100
- "len(samples)"
101
  ]
102
  },
103
  {
@@ -112,7 +134,7 @@
112
  },
113
  {
114
  "cell_type": "code",
115
- "execution_count": 204,
116
  "id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
117
  "metadata": {},
118
  "outputs": [],
@@ -148,21 +170,11 @@
148
  {
149
  "cell_type": "code",
150
  "execution_count": null,
151
- "id": "29613a9d-de7e-44e3-94f1-650085039204",
152
- "metadata": {},
153
- "outputs": [],
154
- "source": [
155
- "versions = sorted(versions, key=lambda x: int(x.version[1:]))"
156
- ]
157
- },
158
- {
159
- "cell_type": "code",
160
- "execution_count": null,
161
- "id": "d77159df-1a16-4996-aafd-1df82c5a3509",
162
  "metadata": {},
163
  "outputs": [],
164
  "source": [
165
- "versions"
166
  ]
167
  },
168
  {
@@ -253,6 +265,8 @@
253
  "source": [
254
  "if last_version_inference is None:\n",
255
  " assert version == 0\n",
 
 
256
  "else:\n",
257
  " assert version == last_version_inference + 1"
258
  ]
@@ -338,7 +352,17 @@
338
  },
339
  {
340
  "cell_type": "code",
341
- "execution_count": 207,
 
 
 
 
 
 
 
 
 
 
342
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
343
  "metadata": {},
344
  "outputs": [],
@@ -360,6 +384,12 @@
360
  " def p_decode(indices, params):\n",
361
  " return vqgan.decode_code(indices, params=params)\n",
362
  " \n",
 
 
 
 
 
 
363
  " functions_pmapped = False"
364
  ]
365
  },
@@ -369,25 +399,282 @@
369
  "id": "7a24b903-777b-4e3d-817c-00ed613a7021",
370
  "metadata": {},
371
  "outputs": [],
372
- "source": []
 
 
 
 
 
373
  },
374
  {
375
  "cell_type": "code",
376
  "execution_count": null,
377
- "id": "e1c04761-1016-47e9-925c-3a9ec6fec95a",
378
  "metadata": {},
379
  "outputs": [],
380
  "source": [
381
- "wandb.finish()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  ]
383
  },
384
  {
385
  "cell_type": "code",
386
  "execution_count": null,
387
- "id": "e79ac8f2-adc2-4a16-970c-dadcceadd566",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  "metadata": {},
389
  "outputs": [],
390
  "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  }
392
  ],
393
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
 
10
  "import csv\n",
11
  "import tempfile\n",
12
  "from functools import partial\n",
13
+ "import random\n",
14
+ "import numpy as np\n",
15
+ "from PIL import Image\n",
16
  "import jax\n",
17
+ "import jax.numpy as jnp\n",
18
+ "from flax.training.common_utils import shard, shard_prng_key\n",
19
+ "from flax.jax_utils import replicate\n",
20
  "import wandb\n",
21
  "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
22
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
 
36
  "normalize_text = True"
37
  ]
38
  },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "batch_size = 8\n",
47
+ "num_images = 128\n",
48
+ "top_k = 8\n",
49
+ "text_normalizer = TextNormalizer() if normalize_text else None"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "6a045827-3461-4499-8959-38d173bc4e5e",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "seed = random.randint(0, 2**32-1)\n",
60
+ "key = jax.random.PRNGKey(seed)"
61
+ ]
62
+ },
63
  {
64
  "cell_type": "code",
65
  "execution_count": null,
 
74
  },
75
  {
76
  "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "4927529a-8828-4150-bc76-e1b60d8dee62",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "clip_params = replicate(clip.params)\n",
83
+ "vqgan_params = replicate(vqgan.params)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
90
  "metadata": {},
91
  "outputs": [],
 
94
  " reader = csv.DictReader(f)\n",
95
  " samples = []\n",
96
  " for row in reader:\n",
97
+ " samples.append(row)\n",
98
+ " # make list multiple of batch_size by adding \"empty\"\n",
99
+ " samples_to_add = [{'Caption':'empty', 'Theme':'empty'}] * (-len(samples) % batch_size)\n",
100
+ " samples.extend(samples_to_add)\n",
101
+ " # reshape\n",
102
+ " samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
103
  ]
104
  },
105
  {
106
  "cell_type": "code",
107
+ "execution_count": null,
108
  "id": "f75b2869-fc25-4f56-b937-e97bbb712ede",
109
  "metadata": {},
110
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
111
  "source": [
112
  "len(samples)"
113
  ]
114
  },
115
  {
116
  "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "c48525c9-447a-4430-81d7-4b699f545638",
119
  "metadata": {},
120
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
121
  "source": [
122
+ "samples[-1]"
 
 
123
  ]
124
  },
125
  {
 
134
  },
135
  {
136
  "cell_type": "code",
137
+ "execution_count": null,
138
  "id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
139
  "metadata": {},
140
  "outputs": [],
 
170
  {
171
  "cell_type": "code",
172
  "execution_count": null,
173
+ "id": "ead44aee-52d5-4ca2-8984-c4d267d9e72a",
 
 
 
 
 
 
 
 
 
 
174
  "metadata": {},
175
  "outputs": [],
176
  "source": [
177
+ "versions[0].version"
178
  ]
179
  },
180
  {
 
265
  "source": [
266
  "if last_version_inference is None:\n",
267
  " assert version == 0\n",
268
+ "elif last_version_inference >= version:\n",
269
+ " print(f'Version {version} has already been logged')\n",
270
  "else:\n",
271
  " assert version == last_version_inference + 1"
272
  ]
 
352
  },
353
  {
354
  "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "320823c9-124a-4fc3-a12c-8c015a128285",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "model_params = replicate(model.params)"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
367
  "metadata": {},
368
  "outputs": [],
 
384
  " def p_decode(indices, params):\n",
385
  " return vqgan.decode_code(indices, params=params)\n",
386
  " \n",
387
+ " @partial(jax.pmap, axis_name=\"batch\")\n",
388
+ " def p_clip(inputs):\n",
389
+ " logits = clip(**inputs).logits_per_image\n",
390
+ " return logits\n",
391
+ " scores = jax.nn.softmax(logits, axis=0).squeeze() \n",
392
+ " \n",
393
  " functions_pmapped = False"
394
  ]
395
  },
 
399
  "id": "7a24b903-777b-4e3d-817c-00ed613a7021",
400
  "metadata": {},
401
  "outputs": [],
402
+ "source": [
403
+ "# TODO: loop over samples\n",
404
+ "batch = samples[0]\n",
405
+ "prompts = [x['Caption'] for x in batch]\n",
406
+ "processed_prompts = [text_normalizer(x) for x in prompts] if normalize_text else prompts"
407
+ ]
408
  },
409
  {
410
  "cell_type": "code",
411
  "execution_count": null,
412
+ "id": "d77aa785-dc05-4070-aba2-aa007524d20b",
413
  "metadata": {},
414
  "outputs": [],
415
  "source": [
416
+ "processed_prompts"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "id": "95db38fb-8948-4814-98ae-c172ca7c6d0a",
423
+ "metadata": {},
424
+ "outputs": [],
425
+ "source": [
426
+ "repeated_prompts = processed_prompts * jax.device_count()"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": null,
432
+ "id": "e948ba9e-3700-4e87-926f-580a10f3e5cd",
433
+ "metadata": {},
434
+ "outputs": [],
435
+ "source": [
436
+ "tokenized_prompt = tokenizer(repeated_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
437
+ "tokenized_prompt = shard(tokenized_prompt)"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "code",
442
+ "execution_count": null,
443
+ "id": "30d96812-fc17-4acf-bb64-5fdb8d0cd313",
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "tokenized_prompt['input_ids'].shape"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "id": "92ea034b-2649-4d18-ab6d-877ed04ae5c4",
454
+ "metadata": {},
455
+ "outputs": [],
456
+ "source": [
457
+ "images = []\n",
458
+ "for i in range(num_images // jax.device_count()):\n",
459
+ " key, subkey = jax.random.split(key, 2)\n",
460
+ " \n",
461
+ " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
462
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
463
+ " \n",
464
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
465
+ " decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
466
+ " \n",
467
+ " for img in decoded_images:\n",
468
+ " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
469
+ " "
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "code",
474
+ "execution_count": null,
475
+ "id": "84d52f30-44c9-4a74-9992-fb2578f19b90",
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "len(images)"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": null,
485
+ "id": "beb594f9-5b91-47fe-98bd-41e68c6b1d73",
486
+ "metadata": {},
487
+ "outputs": [],
488
+ "source": [
489
+ "images[0]"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "id": "bb135190-64e5-44af-b416-e688b034da44",
496
+ "metadata": {},
497
+ "outputs": [],
498
+ "source": [
499
+ "images[1]"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": null,
505
+ "id": "d78a0d92-72c2-4f82-a6ab-b3f5865dd863",
506
+ "metadata": {},
507
+ "outputs": [],
508
+ "source": [
509
+ "clip_inputs = processor(text=prompts, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data"
510
  ]
511
  },
512
  {
513
  "cell_type": "code",
514
  "execution_count": null,
515
+ "id": "89ff78a6-bfa4-44d9-ad66-07a4a68b4352",
516
+ "metadata": {},
517
+ "outputs": [],
518
+ "source": [
519
+ "# each shard will have one prompt\n",
520
+ "clip_inputs['input_ids'].shape"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "id": "2cda8984-049c-4c87-96ad-7b0412750656",
527
+ "metadata": {},
528
+ "outputs": [],
529
+ "source": [
530
+ "# each shard needs to have the images corresponding to a specific prompt\n",
531
+ "clip_inputs['pixel_values'].shape"
532
+ ]
533
+ },
534
+ {
535
+ "cell_type": "code",
536
+ "execution_count": null,
537
+ "id": "0a044e8f-be29-404b-b6c7-8f2395c5efc6",
538
+ "metadata": {},
539
+ "outputs": [],
540
+ "source": [
541
+ "images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
542
+ "images_per_prompt_indices"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": null,
548
+ "id": "7a6c61b3-12e0-45d8-b39a-830288324d3d",
549
  "metadata": {},
550
  "outputs": [],
551
  "source": []
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": null,
556
+ "id": "7318e67e-4214-46f9-bf60-6d139d4bd00f",
557
+ "metadata": {},
558
+ "outputs": [],
559
+ "source": [
560
+ "# reorder so each shard will have correct images\n",
561
+ "clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))"
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "code",
566
+ "execution_count": null,
567
+ "id": "90c949a2-8e2a-4905-b6d4-92038f1704b8",
568
+ "metadata": {},
569
+ "outputs": [],
570
+ "source": [
571
+ "clip_inputs = shard(clip_inputs)"
572
+ ]
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "execution_count": null,
577
+ "id": "58fa836e-5ebb-45e7-af77-ab10646dfbfb",
578
+ "metadata": {},
579
+ "outputs": [],
580
+ "source": [
581
+ "logits = p_clip(clip_inputs)"
582
+ ]
583
+ },
584
+ {
585
+ "cell_type": "code",
586
+ "execution_count": null,
587
+ "id": "fd7a3f91-3a1f-4a0a-8b3e-3c926cd367fb",
588
+ "metadata": {},
589
+ "outputs": [],
590
+ "source": [
591
+ "logits.shape"
592
+ ]
593
+ },
594
+ {
595
+ "cell_type": "code",
596
+ "execution_count": null,
597
+ "id": "fa406db7-0a21-4e4b-9890-4c7aece4280c",
598
+ "metadata": {},
599
+ "outputs": [],
600
+ "source": [
601
+ "logits = logits.reshape(-1, num_images)"
602
+ ]
603
+ },
604
+ {
605
+ "cell_type": "code",
606
+ "execution_count": null,
607
+ "id": "9c359a8c-2c27-4e68-8775-371857397723",
608
+ "metadata": {},
609
+ "outputs": [],
610
+ "source": [
611
+ "logits.shape"
612
+ ]
613
+ },
614
+ {
615
+ "cell_type": "code",
616
+ "execution_count": null,
617
+ "id": "a56b9f28-dd91-4382-bc47-11e89fda1254",
618
+ "metadata": {},
619
+ "outputs": [],
620
+ "source": [
621
+ "logits"
622
+ ]
623
+ },
624
+ {
625
+ "cell_type": "code",
626
+ "execution_count": null,
627
+ "id": "0bed8167-0a6d-46c1-badf-8bdc20b93c31",
628
+ "metadata": {},
629
+ "outputs": [],
630
+ "source": [
631
+ "top_idx = logits.argsort()[:, -top_k:][..., ::-1]"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": null,
637
+ "id": "188c5333-6b8c-4a17-8cc8-15651c77ef99",
638
+ "metadata": {},
639
+ "outputs": [],
640
+ "source": [
641
+ "len(images)"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "id": "babd22b3-e773-467d-8bbb-f0323f57a44b",
648
+ "metadata": {},
649
+ "outputs": [],
650
+ "source": [
651
+ "results = []\n",
652
+ "columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": null,
658
+ "id": "75976c9f-dea5-48e3-8920-55a1bbfd91c2",
659
+ "metadata": {},
660
+ "outputs": [],
661
+ "source": [
662
+ "for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
663
+ " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
664
+ " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
665
+ " top_scores = [logits[x] for x in idx]\n",
666
+ " results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": null,
672
+ "id": "e1c04761-1016-47e9-925c-3a9ec6fec95a",
673
+ "metadata": {},
674
+ "outputs": [],
675
+ "source": [
676
+ "wandb.finish()"
677
+ ]
678
  }
679
  ],
680
  "metadata": {