boris commited on
Commit
1c83da9
·
1 Parent(s): 27a3435

feat(inference_notebook): dalle-mini is installable

Browse files
dev/inference/inference_pipeline.ipynb CHANGED
@@ -6,7 +6,7 @@
6
  "name": "DALL·E mini - Inference pipeline.ipynb",
7
  "provenance": [],
8
  "collapsed_sections": [],
9
- "authorship_tag": "ABX9TyOmaisFwTAYRR7mJmMVxzdA",
10
  "include_colab_link": true
11
  },
12
  "kernelspec": {
@@ -22,6 +22,7 @@
22
  "49304912717a4995ae45d04a59d1f50e": {
23
  "model_module": "@jupyter-widgets/controls",
24
  "model_name": "HBoxModel",
 
25
  "state": {
26
  "_view_name": "HBoxView",
27
  "_dom_classes": [],
@@ -42,6 +43,7 @@
42
  "5fd9f97986024e8db560a6737ade9e2e": {
43
  "model_module": "@jupyter-widgets/base",
44
  "model_name": "LayoutModel",
 
45
  "state": {
46
  "_view_name": "LayoutView",
47
  "grid_template_rows": null,
@@ -93,6 +95,7 @@
93
  "caced43e3a4c493b98fb07cb41db045c": {
94
  "model_module": "@jupyter-widgets/controls",
95
  "model_name": "FloatProgressModel",
 
96
  "state": {
97
  "_view_name": "ProgressView",
98
  "style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
@@ -116,6 +119,7 @@
116
  "0acc161f2e9948b68b3fc4e57ef333c9": {
117
  "model_module": "@jupyter-widgets/controls",
118
  "model_name": "HTMLModel",
 
119
  "state": {
120
  "_view_name": "HTMLView",
121
  "style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
@@ -136,6 +140,7 @@
136
  "40c54b9454d346aabd197f2bcf189467": {
137
  "model_module": "@jupyter-widgets/controls",
138
  "model_name": "ProgressStyleModel",
 
139
  "state": {
140
  "_view_name": "StyleView",
141
  "_model_name": "ProgressStyleModel",
@@ -151,6 +156,7 @@
151
  "8b25334a48244a14aa9ba0176887e655": {
152
  "model_module": "@jupyter-widgets/base",
153
  "model_name": "LayoutModel",
 
154
  "state": {
155
  "_view_name": "LayoutView",
156
  "grid_template_rows": null,
@@ -202,6 +208,7 @@
202
  "7e7c488f57fc4acb8d261e2db81d61f0": {
203
  "model_module": "@jupyter-widgets/controls",
204
  "model_name": "DescriptionStyleModel",
 
205
  "state": {
206
  "_view_name": "StyleView",
207
  "_model_name": "DescriptionStyleModel",
@@ -216,6 +223,7 @@
216
  "72c401062a5348b1a366dffb5a403568": {
217
  "model_module": "@jupyter-widgets/base",
218
  "model_name": "LayoutModel",
 
219
  "state": {
220
  "_view_name": "LayoutView",
221
  "grid_template_rows": null,
@@ -267,6 +275,7 @@
267
  "022c124dfff348f285335732781b0887": {
268
  "model_module": "@jupyter-widgets/controls",
269
  "model_name": "HBoxModel",
 
270
  "state": {
271
  "_view_name": "HBoxView",
272
  "_dom_classes": [],
@@ -287,6 +296,7 @@
287
  "a44e47e9d26c4deb81a5a11a9db92a9f": {
288
  "model_module": "@jupyter-widgets/base",
289
  "model_name": "LayoutModel",
 
290
  "state": {
291
  "_view_name": "LayoutView",
292
  "grid_template_rows": null,
@@ -338,6 +348,7 @@
338
  "cd9c7016caae47c1b41fb2608c78b0bf": {
339
  "model_module": "@jupyter-widgets/controls",
340
  "model_name": "FloatProgressModel",
 
341
  "state": {
342
  "_view_name": "ProgressView",
343
  "style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
@@ -361,6 +372,7 @@
361
  "36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
362
  "model_module": "@jupyter-widgets/controls",
363
  "model_name": "HTMLModel",
 
364
  "state": {
365
  "_view_name": "HTMLView",
366
  "style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
@@ -381,6 +393,7 @@
381
  "c22f207311cf4fb69bd9328eabfd4ebb": {
382
  "model_module": "@jupyter-widgets/controls",
383
  "model_name": "ProgressStyleModel",
 
384
  "state": {
385
  "_view_name": "StyleView",
386
  "_model_name": "ProgressStyleModel",
@@ -396,6 +409,7 @@
396
  "5a38c6d83a264bedbf7efe6e97eba953": {
397
  "model_module": "@jupyter-widgets/base",
398
  "model_name": "LayoutModel",
 
399
  "state": {
400
  "_view_name": "LayoutView",
401
  "grid_template_rows": null,
@@ -447,6 +461,7 @@
447
  "037563a7eadd4ac5abb7249a2914d346": {
448
  "model_module": "@jupyter-widgets/controls",
449
  "model_name": "DescriptionStyleModel",
 
450
  "state": {
451
  "_view_name": "StyleView",
452
  "_model_name": "DescriptionStyleModel",
@@ -461,6 +476,7 @@
461
  "3975e7ed0b704990b1fa05909a9bb9b6": {
462
  "model_module": "@jupyter-widgets/base",
463
  "model_name": "LayoutModel",
 
464
  "state": {
465
  "_view_name": "LayoutView",
466
  "grid_template_rows": null,
@@ -512,6 +528,7 @@
512
  "f9f1fdc3819a4142b85304cd3c6358a2": {
513
  "model_module": "@jupyter-widgets/controls",
514
  "model_name": "HBoxModel",
 
515
  "state": {
516
  "_view_name": "HBoxView",
517
  "_dom_classes": [],
@@ -532,6 +549,7 @@
532
  "ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
533
  "model_module": "@jupyter-widgets/base",
534
  "model_name": "LayoutModel",
 
535
  "state": {
536
  "_view_name": "LayoutView",
537
  "grid_template_rows": null,
@@ -583,6 +601,7 @@
583
  "29d42e94b3b34c86a117b623da68faed": {
584
  "model_module": "@jupyter-widgets/controls",
585
  "model_name": "FloatProgressModel",
 
586
  "state": {
587
  "_view_name": "ProgressView",
588
  "style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
@@ -606,6 +625,7 @@
606
  "8b73de7dbdfe40dbbb39fb593520b984": {
607
  "model_module": "@jupyter-widgets/controls",
608
  "model_name": "HTMLModel",
 
609
  "state": {
610
  "_view_name": "HTMLView",
611
  "style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
@@ -626,6 +646,7 @@
626
  "8ce4d20d004a4382afa0abdd3b1f7191": {
627
  "model_module": "@jupyter-widgets/controls",
628
  "model_name": "ProgressStyleModel",
 
629
  "state": {
630
  "_view_name": "StyleView",
631
  "_model_name": "ProgressStyleModel",
@@ -641,6 +662,7 @@
641
  "efc4812245c8459c92e6436889b4f600": {
642
  "model_module": "@jupyter-widgets/base",
643
  "model_name": "LayoutModel",
 
644
  "state": {
645
  "_view_name": "LayoutView",
646
  "grid_template_rows": null,
@@ -692,6 +714,7 @@
692
  "717ccef4df1f477abb51814650eb47da": {
693
  "model_module": "@jupyter-widgets/controls",
694
  "model_name": "DescriptionStyleModel",
 
695
  "state": {
696
  "_view_name": "StyleView",
697
  "_model_name": "DescriptionStyleModel",
@@ -706,6 +729,7 @@
706
  "7dba58f0391c485a86e34e8039ec6189": {
707
  "model_module": "@jupyter-widgets/base",
708
  "model_name": "LayoutModel",
 
709
  "state": {
710
  "_view_name": "LayoutView",
711
  "grid_template_rows": null,
@@ -804,8 +828,7 @@
804
  "source": [
805
  "!pip install -q transformers flax\n",
806
  "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
807
- "!git clone https://github.com/borisdayma/dalle-mini # Model files\n",
808
- "%cd dalle-mini/"
809
  ],
810
  "execution_count": null,
811
  "outputs": []
@@ -833,7 +856,7 @@
833
  "import random\n",
834
  "from tqdm.notebook import tqdm, trange"
835
  ],
836
- "execution_count": 2,
837
  "outputs": []
838
  },
839
  {
@@ -846,7 +869,7 @@
846
  "DALLE_REPO = 'flax-community/dalle-mini'\n",
847
  "DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
848
  ],
849
- "execution_count": 3,
850
  "outputs": []
851
  },
852
  {
@@ -871,7 +894,7 @@
871
  "# set a prompt\n",
872
  "prompt = 'picture of a waterfall under the sunset'"
873
  ],
874
- "execution_count": 5,
875
  "outputs": []
876
  },
877
  {
@@ -888,7 +911,7 @@
888
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
889
  "tokenized_prompt"
890
  ],
891
- "execution_count": 6,
892
  "outputs": [
893
  {
894
  "output_type": "execute_result",
@@ -956,7 +979,7 @@
956
  "subkeys = jax.random.split(key, num=n_predictions)\n",
957
  "subkeys"
958
  ],
959
- "execution_count": 7,
960
  "outputs": [
961
  {
962
  "output_type": "execute_result",
@@ -1004,7 +1027,7 @@
1004
  "encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
1005
  "encoded_images[0]"
1006
  ],
1007
- "execution_count": 8,
1008
  "outputs": [
1009
  {
1010
  "output_type": "display_data",
@@ -1099,7 +1122,7 @@
1099
  "encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
1100
  "encoded_images[0]"
1101
  ],
1102
- "execution_count": 9,
1103
  "outputs": [
1104
  {
1105
  "output_type": "execute_result",
@@ -1167,7 +1190,7 @@
1167
  "source": [
1168
  "encoded_images[0].shape"
1169
  ],
1170
- "execution_count": 10,
1171
  "outputs": [
1172
  {
1173
  "output_type": "execute_result",
@@ -1204,7 +1227,7 @@
1204
  "import numpy as np\n",
1205
  "from PIL import Image"
1206
  ],
1207
- "execution_count": 11,
1208
  "outputs": []
1209
  },
1210
  {
@@ -1217,7 +1240,7 @@
1217
  "VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
1218
  "VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
1219
  ],
1220
- "execution_count": 12,
1221
  "outputs": []
1222
  },
1223
  {
@@ -1233,7 +1256,7 @@
1233
  "# set up VQGAN\n",
1234
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
1235
  ],
1236
- "execution_count": 13,
1237
  "outputs": [
1238
  {
1239
  "output_type": "stream",
@@ -1269,7 +1292,7 @@
1269
  "decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
1270
  "decoded_images[0]"
1271
  ],
1272
- "execution_count": 14,
1273
  "outputs": [
1274
  {
1275
  "output_type": "display_data",
@@ -1373,7 +1396,7 @@
1373
  "# normalize images\n",
1374
  "clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
1375
  ],
1376
- "execution_count": 15,
1377
  "outputs": []
1378
  },
1379
  {
@@ -1385,7 +1408,7 @@
1385
  "# convert to image\n",
1386
  "images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
1387
  ],
1388
- "execution_count": 16,
1389
  "outputs": []
1390
  },
1391
  {
@@ -1402,7 +1425,7 @@
1402
  "# display an image\n",
1403
  "images[0]"
1404
  ],
1405
- "execution_count": 17,
1406
  "outputs": [
1407
  {
1408
  "output_type": "execute_result",
@@ -1438,7 +1461,7 @@
1438
  "source": [
1439
  "from transformers import CLIPProcessor, FlaxCLIPModel"
1440
  ],
1441
- "execution_count": 18,
1442
  "outputs": []
1443
  },
1444
  {
@@ -1474,7 +1497,7 @@
1474
  "logits = clip(**inputs).logits_per_image\n",
1475
  "scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
1476
  ],
1477
- "execution_count": 20,
1478
  "outputs": []
1479
  },
1480
  {
@@ -1495,7 +1518,7 @@
1495
  " display(images[idx])\n",
1496
  " print()"
1497
  ],
1498
- "execution_count": 21,
1499
  "outputs": [
1500
  {
1501
  "output_type": "stream",
@@ -1690,7 +1713,7 @@
1690
  "from flax.training.common_utils import shard\n",
1691
  "from flax.jax_utils import replicate"
1692
  ],
1693
- "execution_count": 22,
1694
  "outputs": []
1695
  },
1696
  {
@@ -1706,7 +1729,7 @@
1706
  "# check we can access TPU's or GPU's\n",
1707
  "jax.devices()"
1708
  ],
1709
- "execution_count": 23,
1710
  "outputs": [
1711
  {
1712
  "output_type": "execute_result",
@@ -1744,7 +1767,7 @@
1744
  "# one set of inputs per device\n",
1745
  "prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
1746
  ],
1747
- "execution_count": 25,
1748
  "outputs": []
1749
  },
1750
  {
@@ -1757,7 +1780,7 @@
1757
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
1758
  "tokenized_prompt = shard(tokenized_prompt)"
1759
  ],
1760
- "execution_count": 26,
1761
  "outputs": []
1762
  },
1763
  {
@@ -1793,7 +1816,7 @@
1793
  "def p_decode(indices, params):\n",
1794
  " return vqgan.decode_code(indices, params=params)"
1795
  ],
1796
- "execution_count": 27,
1797
  "outputs": []
1798
  },
1799
  {
@@ -1834,7 +1857,7 @@
1834
  " for img in decoded_images:\n",
1835
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
1836
  ],
1837
- "execution_count": 28,
1838
  "outputs": [
1839
  {
1840
  "output_type": "display_data",
@@ -1877,7 +1900,7 @@
1877
  " display(img)\n",
1878
  " print()"
1879
  ],
1880
- "execution_count": 29,
1881
  "outputs": [
1882
  {
1883
  "output_type": "display_data",
 
6
  "name": "DALL·E mini - Inference pipeline.ipynb",
7
  "provenance": [],
8
  "collapsed_sections": [],
9
+ "authorship_tag": "ABX9TyMUjEt1XMLq+6/GhSnVFsSx",
10
  "include_colab_link": true
11
  },
12
  "kernelspec": {
 
22
  "49304912717a4995ae45d04a59d1f50e": {
23
  "model_module": "@jupyter-widgets/controls",
24
  "model_name": "HBoxModel",
25
+ "model_module_version": "1.5.0",
26
  "state": {
27
  "_view_name": "HBoxView",
28
  "_dom_classes": [],
 
43
  "5fd9f97986024e8db560a6737ade9e2e": {
44
  "model_module": "@jupyter-widgets/base",
45
  "model_name": "LayoutModel",
46
+ "model_module_version": "1.2.0",
47
  "state": {
48
  "_view_name": "LayoutView",
49
  "grid_template_rows": null,
 
95
  "caced43e3a4c493b98fb07cb41db045c": {
96
  "model_module": "@jupyter-widgets/controls",
97
  "model_name": "FloatProgressModel",
98
+ "model_module_version": "1.5.0",
99
  "state": {
100
  "_view_name": "ProgressView",
101
  "style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
 
119
  "0acc161f2e9948b68b3fc4e57ef333c9": {
120
  "model_module": "@jupyter-widgets/controls",
121
  "model_name": "HTMLModel",
122
+ "model_module_version": "1.5.0",
123
  "state": {
124
  "_view_name": "HTMLView",
125
  "style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
 
140
  "40c54b9454d346aabd197f2bcf189467": {
141
  "model_module": "@jupyter-widgets/controls",
142
  "model_name": "ProgressStyleModel",
143
+ "model_module_version": "1.5.0",
144
  "state": {
145
  "_view_name": "StyleView",
146
  "_model_name": "ProgressStyleModel",
 
156
  "8b25334a48244a14aa9ba0176887e655": {
157
  "model_module": "@jupyter-widgets/base",
158
  "model_name": "LayoutModel",
159
+ "model_module_version": "1.2.0",
160
  "state": {
161
  "_view_name": "LayoutView",
162
  "grid_template_rows": null,
 
208
  "7e7c488f57fc4acb8d261e2db81d61f0": {
209
  "model_module": "@jupyter-widgets/controls",
210
  "model_name": "DescriptionStyleModel",
211
+ "model_module_version": "1.5.0",
212
  "state": {
213
  "_view_name": "StyleView",
214
  "_model_name": "DescriptionStyleModel",
 
223
  "72c401062a5348b1a366dffb5a403568": {
224
  "model_module": "@jupyter-widgets/base",
225
  "model_name": "LayoutModel",
226
+ "model_module_version": "1.2.0",
227
  "state": {
228
  "_view_name": "LayoutView",
229
  "grid_template_rows": null,
 
275
  "022c124dfff348f285335732781b0887": {
276
  "model_module": "@jupyter-widgets/controls",
277
  "model_name": "HBoxModel",
278
+ "model_module_version": "1.5.0",
279
  "state": {
280
  "_view_name": "HBoxView",
281
  "_dom_classes": [],
 
296
  "a44e47e9d26c4deb81a5a11a9db92a9f": {
297
  "model_module": "@jupyter-widgets/base",
298
  "model_name": "LayoutModel",
299
+ "model_module_version": "1.2.0",
300
  "state": {
301
  "_view_name": "LayoutView",
302
  "grid_template_rows": null,
 
348
  "cd9c7016caae47c1b41fb2608c78b0bf": {
349
  "model_module": "@jupyter-widgets/controls",
350
  "model_name": "FloatProgressModel",
351
+ "model_module_version": "1.5.0",
352
  "state": {
353
  "_view_name": "ProgressView",
354
  "style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
 
372
  "36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
373
  "model_module": "@jupyter-widgets/controls",
374
  "model_name": "HTMLModel",
375
+ "model_module_version": "1.5.0",
376
  "state": {
377
  "_view_name": "HTMLView",
378
  "style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
 
393
  "c22f207311cf4fb69bd9328eabfd4ebb": {
394
  "model_module": "@jupyter-widgets/controls",
395
  "model_name": "ProgressStyleModel",
396
+ "model_module_version": "1.5.0",
397
  "state": {
398
  "_view_name": "StyleView",
399
  "_model_name": "ProgressStyleModel",
 
409
  "5a38c6d83a264bedbf7efe6e97eba953": {
410
  "model_module": "@jupyter-widgets/base",
411
  "model_name": "LayoutModel",
412
+ "model_module_version": "1.2.0",
413
  "state": {
414
  "_view_name": "LayoutView",
415
  "grid_template_rows": null,
 
461
  "037563a7eadd4ac5abb7249a2914d346": {
462
  "model_module": "@jupyter-widgets/controls",
463
  "model_name": "DescriptionStyleModel",
464
+ "model_module_version": "1.5.0",
465
  "state": {
466
  "_view_name": "StyleView",
467
  "_model_name": "DescriptionStyleModel",
 
476
  "3975e7ed0b704990b1fa05909a9bb9b6": {
477
  "model_module": "@jupyter-widgets/base",
478
  "model_name": "LayoutModel",
479
+ "model_module_version": "1.2.0",
480
  "state": {
481
  "_view_name": "LayoutView",
482
  "grid_template_rows": null,
 
528
  "f9f1fdc3819a4142b85304cd3c6358a2": {
529
  "model_module": "@jupyter-widgets/controls",
530
  "model_name": "HBoxModel",
531
+ "model_module_version": "1.5.0",
532
  "state": {
533
  "_view_name": "HBoxView",
534
  "_dom_classes": [],
 
549
  "ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
550
  "model_module": "@jupyter-widgets/base",
551
  "model_name": "LayoutModel",
552
+ "model_module_version": "1.2.0",
553
  "state": {
554
  "_view_name": "LayoutView",
555
  "grid_template_rows": null,
 
601
  "29d42e94b3b34c86a117b623da68faed": {
602
  "model_module": "@jupyter-widgets/controls",
603
  "model_name": "FloatProgressModel",
604
+ "model_module_version": "1.5.0",
605
  "state": {
606
  "_view_name": "ProgressView",
607
  "style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
 
625
  "8b73de7dbdfe40dbbb39fb593520b984": {
626
  "model_module": "@jupyter-widgets/controls",
627
  "model_name": "HTMLModel",
628
+ "model_module_version": "1.5.0",
629
  "state": {
630
  "_view_name": "HTMLView",
631
  "style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
 
646
  "8ce4d20d004a4382afa0abdd3b1f7191": {
647
  "model_module": "@jupyter-widgets/controls",
648
  "model_name": "ProgressStyleModel",
649
+ "model_module_version": "1.5.0",
650
  "state": {
651
  "_view_name": "StyleView",
652
  "_model_name": "ProgressStyleModel",
 
662
  "efc4812245c8459c92e6436889b4f600": {
663
  "model_module": "@jupyter-widgets/base",
664
  "model_name": "LayoutModel",
665
+ "model_module_version": "1.2.0",
666
  "state": {
667
  "_view_name": "LayoutView",
668
  "grid_template_rows": null,
 
714
  "717ccef4df1f477abb51814650eb47da": {
715
  "model_module": "@jupyter-widgets/controls",
716
  "model_name": "DescriptionStyleModel",
717
+ "model_module_version": "1.5.0",
718
  "state": {
719
  "_view_name": "StyleView",
720
  "_model_name": "DescriptionStyleModel",
 
729
  "7dba58f0391c485a86e34e8039ec6189": {
730
  "model_module": "@jupyter-widgets/base",
731
  "model_name": "LayoutModel",
732
+ "model_module_version": "1.2.0",
733
  "state": {
734
  "_view_name": "LayoutView",
735
  "grid_template_rows": null,
 
828
  "source": [
829
  "!pip install -q transformers flax\n",
830
  "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
831
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git # Model files"
 
832
  ],
833
  "execution_count": null,
834
  "outputs": []
 
856
  "import random\n",
857
  "from tqdm.notebook import tqdm, trange"
858
  ],
859
+ "execution_count": null,
860
  "outputs": []
861
  },
862
  {
 
869
  "DALLE_REPO = 'flax-community/dalle-mini'\n",
870
  "DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
871
  ],
872
+ "execution_count": null,
873
  "outputs": []
874
  },
875
  {
 
894
  "# set a prompt\n",
895
  "prompt = 'picture of a waterfall under the sunset'"
896
  ],
897
+ "execution_count": null,
898
  "outputs": []
899
  },
900
  {
 
911
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
912
  "tokenized_prompt"
913
  ],
914
+ "execution_count": null,
915
  "outputs": [
916
  {
917
  "output_type": "execute_result",
 
979
  "subkeys = jax.random.split(key, num=n_predictions)\n",
980
  "subkeys"
981
  ],
982
+ "execution_count": null,
983
  "outputs": [
984
  {
985
  "output_type": "execute_result",
 
1027
  "encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
1028
  "encoded_images[0]"
1029
  ],
1030
+ "execution_count": null,
1031
  "outputs": [
1032
  {
1033
  "output_type": "display_data",
 
1122
  "encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
1123
  "encoded_images[0]"
1124
  ],
1125
+ "execution_count": null,
1126
  "outputs": [
1127
  {
1128
  "output_type": "execute_result",
 
1190
  "source": [
1191
  "encoded_images[0].shape"
1192
  ],
1193
+ "execution_count": null,
1194
  "outputs": [
1195
  {
1196
  "output_type": "execute_result",
 
1227
  "import numpy as np\n",
1228
  "from PIL import Image"
1229
  ],
1230
+ "execution_count": null,
1231
  "outputs": []
1232
  },
1233
  {
 
1240
  "VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
1241
  "VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
1242
  ],
1243
+ "execution_count": null,
1244
  "outputs": []
1245
  },
1246
  {
 
1256
  "# set up VQGAN\n",
1257
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
1258
  ],
1259
+ "execution_count": null,
1260
  "outputs": [
1261
  {
1262
  "output_type": "stream",
 
1292
  "decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
1293
  "decoded_images[0]"
1294
  ],
1295
+ "execution_count": null,
1296
  "outputs": [
1297
  {
1298
  "output_type": "display_data",
 
1396
  "# normalize images\n",
1397
  "clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
1398
  ],
1399
+ "execution_count": null,
1400
  "outputs": []
1401
  },
1402
  {
 
1408
  "# convert to image\n",
1409
  "images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
1410
  ],
1411
+ "execution_count": null,
1412
  "outputs": []
1413
  },
1414
  {
 
1425
  "# display an image\n",
1426
  "images[0]"
1427
  ],
1428
+ "execution_count": null,
1429
  "outputs": [
1430
  {
1431
  "output_type": "execute_result",
 
1461
  "source": [
1462
  "from transformers import CLIPProcessor, FlaxCLIPModel"
1463
  ],
1464
+ "execution_count": null,
1465
  "outputs": []
1466
  },
1467
  {
 
1497
  "logits = clip(**inputs).logits_per_image\n",
1498
  "scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
1499
  ],
1500
+ "execution_count": null,
1501
  "outputs": []
1502
  },
1503
  {
 
1518
  " display(images[idx])\n",
1519
  " print()"
1520
  ],
1521
+ "execution_count": null,
1522
  "outputs": [
1523
  {
1524
  "output_type": "stream",
 
1713
  "from flax.training.common_utils import shard\n",
1714
  "from flax.jax_utils import replicate"
1715
  ],
1716
+ "execution_count": null,
1717
  "outputs": []
1718
  },
1719
  {
 
1729
  "# check we can access TPU's or GPU's\n",
1730
  "jax.devices()"
1731
  ],
1732
+ "execution_count": null,
1733
  "outputs": [
1734
  {
1735
  "output_type": "execute_result",
 
1767
  "# one set of inputs per device\n",
1768
  "prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
1769
  ],
1770
+ "execution_count": null,
1771
  "outputs": []
1772
  },
1773
  {
 
1780
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
1781
  "tokenized_prompt = shard(tokenized_prompt)"
1782
  ],
1783
+ "execution_count": null,
1784
  "outputs": []
1785
  },
1786
  {
 
1816
  "def p_decode(indices, params):\n",
1817
  " return vqgan.decode_code(indices, params=params)"
1818
  ],
1819
+ "execution_count": null,
1820
  "outputs": []
1821
  },
1822
  {
 
1857
  " for img in decoded_images:\n",
1858
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
1859
  ],
1860
+ "execution_count": null,
1861
  "outputs": [
1862
  {
1863
  "output_type": "display_data",
 
1900
  " display(img)\n",
1901
  " print()"
1902
  ],
1903
+ "execution_count": null,
1904
  "outputs": [
1905
  {
1906
  "output_type": "display_data",