Spaces:
Running
Running
feat(inference_notebook): dalle-mini is installable
Browse files
dev/inference/inference_pipeline.ipynb
CHANGED
@@ -6,7 +6,7 @@
|
|
6 |
"name": "DALL·E mini - Inference pipeline.ipynb",
|
7 |
"provenance": [],
|
8 |
"collapsed_sections": [],
|
9 |
-
"authorship_tag": "
|
10 |
"include_colab_link": true
|
11 |
},
|
12 |
"kernelspec": {
|
@@ -22,6 +22,7 @@
|
|
22 |
"49304912717a4995ae45d04a59d1f50e": {
|
23 |
"model_module": "@jupyter-widgets/controls",
|
24 |
"model_name": "HBoxModel",
|
|
|
25 |
"state": {
|
26 |
"_view_name": "HBoxView",
|
27 |
"_dom_classes": [],
|
@@ -42,6 +43,7 @@
|
|
42 |
"5fd9f97986024e8db560a6737ade9e2e": {
|
43 |
"model_module": "@jupyter-widgets/base",
|
44 |
"model_name": "LayoutModel",
|
|
|
45 |
"state": {
|
46 |
"_view_name": "LayoutView",
|
47 |
"grid_template_rows": null,
|
@@ -93,6 +95,7 @@
|
|
93 |
"caced43e3a4c493b98fb07cb41db045c": {
|
94 |
"model_module": "@jupyter-widgets/controls",
|
95 |
"model_name": "FloatProgressModel",
|
|
|
96 |
"state": {
|
97 |
"_view_name": "ProgressView",
|
98 |
"style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
|
@@ -116,6 +119,7 @@
|
|
116 |
"0acc161f2e9948b68b3fc4e57ef333c9": {
|
117 |
"model_module": "@jupyter-widgets/controls",
|
118 |
"model_name": "HTMLModel",
|
|
|
119 |
"state": {
|
120 |
"_view_name": "HTMLView",
|
121 |
"style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
|
@@ -136,6 +140,7 @@
|
|
136 |
"40c54b9454d346aabd197f2bcf189467": {
|
137 |
"model_module": "@jupyter-widgets/controls",
|
138 |
"model_name": "ProgressStyleModel",
|
|
|
139 |
"state": {
|
140 |
"_view_name": "StyleView",
|
141 |
"_model_name": "ProgressStyleModel",
|
@@ -151,6 +156,7 @@
|
|
151 |
"8b25334a48244a14aa9ba0176887e655": {
|
152 |
"model_module": "@jupyter-widgets/base",
|
153 |
"model_name": "LayoutModel",
|
|
|
154 |
"state": {
|
155 |
"_view_name": "LayoutView",
|
156 |
"grid_template_rows": null,
|
@@ -202,6 +208,7 @@
|
|
202 |
"7e7c488f57fc4acb8d261e2db81d61f0": {
|
203 |
"model_module": "@jupyter-widgets/controls",
|
204 |
"model_name": "DescriptionStyleModel",
|
|
|
205 |
"state": {
|
206 |
"_view_name": "StyleView",
|
207 |
"_model_name": "DescriptionStyleModel",
|
@@ -216,6 +223,7 @@
|
|
216 |
"72c401062a5348b1a366dffb5a403568": {
|
217 |
"model_module": "@jupyter-widgets/base",
|
218 |
"model_name": "LayoutModel",
|
|
|
219 |
"state": {
|
220 |
"_view_name": "LayoutView",
|
221 |
"grid_template_rows": null,
|
@@ -267,6 +275,7 @@
|
|
267 |
"022c124dfff348f285335732781b0887": {
|
268 |
"model_module": "@jupyter-widgets/controls",
|
269 |
"model_name": "HBoxModel",
|
|
|
270 |
"state": {
|
271 |
"_view_name": "HBoxView",
|
272 |
"_dom_classes": [],
|
@@ -287,6 +296,7 @@
|
|
287 |
"a44e47e9d26c4deb81a5a11a9db92a9f": {
|
288 |
"model_module": "@jupyter-widgets/base",
|
289 |
"model_name": "LayoutModel",
|
|
|
290 |
"state": {
|
291 |
"_view_name": "LayoutView",
|
292 |
"grid_template_rows": null,
|
@@ -338,6 +348,7 @@
|
|
338 |
"cd9c7016caae47c1b41fb2608c78b0bf": {
|
339 |
"model_module": "@jupyter-widgets/controls",
|
340 |
"model_name": "FloatProgressModel",
|
|
|
341 |
"state": {
|
342 |
"_view_name": "ProgressView",
|
343 |
"style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
|
@@ -361,6 +372,7 @@
|
|
361 |
"36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
|
362 |
"model_module": "@jupyter-widgets/controls",
|
363 |
"model_name": "HTMLModel",
|
|
|
364 |
"state": {
|
365 |
"_view_name": "HTMLView",
|
366 |
"style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
|
@@ -381,6 +393,7 @@
|
|
381 |
"c22f207311cf4fb69bd9328eabfd4ebb": {
|
382 |
"model_module": "@jupyter-widgets/controls",
|
383 |
"model_name": "ProgressStyleModel",
|
|
|
384 |
"state": {
|
385 |
"_view_name": "StyleView",
|
386 |
"_model_name": "ProgressStyleModel",
|
@@ -396,6 +409,7 @@
|
|
396 |
"5a38c6d83a264bedbf7efe6e97eba953": {
|
397 |
"model_module": "@jupyter-widgets/base",
|
398 |
"model_name": "LayoutModel",
|
|
|
399 |
"state": {
|
400 |
"_view_name": "LayoutView",
|
401 |
"grid_template_rows": null,
|
@@ -447,6 +461,7 @@
|
|
447 |
"037563a7eadd4ac5abb7249a2914d346": {
|
448 |
"model_module": "@jupyter-widgets/controls",
|
449 |
"model_name": "DescriptionStyleModel",
|
|
|
450 |
"state": {
|
451 |
"_view_name": "StyleView",
|
452 |
"_model_name": "DescriptionStyleModel",
|
@@ -461,6 +476,7 @@
|
|
461 |
"3975e7ed0b704990b1fa05909a9bb9b6": {
|
462 |
"model_module": "@jupyter-widgets/base",
|
463 |
"model_name": "LayoutModel",
|
|
|
464 |
"state": {
|
465 |
"_view_name": "LayoutView",
|
466 |
"grid_template_rows": null,
|
@@ -512,6 +528,7 @@
|
|
512 |
"f9f1fdc3819a4142b85304cd3c6358a2": {
|
513 |
"model_module": "@jupyter-widgets/controls",
|
514 |
"model_name": "HBoxModel",
|
|
|
515 |
"state": {
|
516 |
"_view_name": "HBoxView",
|
517 |
"_dom_classes": [],
|
@@ -532,6 +549,7 @@
|
|
532 |
"ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
|
533 |
"model_module": "@jupyter-widgets/base",
|
534 |
"model_name": "LayoutModel",
|
|
|
535 |
"state": {
|
536 |
"_view_name": "LayoutView",
|
537 |
"grid_template_rows": null,
|
@@ -583,6 +601,7 @@
|
|
583 |
"29d42e94b3b34c86a117b623da68faed": {
|
584 |
"model_module": "@jupyter-widgets/controls",
|
585 |
"model_name": "FloatProgressModel",
|
|
|
586 |
"state": {
|
587 |
"_view_name": "ProgressView",
|
588 |
"style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
|
@@ -606,6 +625,7 @@
|
|
606 |
"8b73de7dbdfe40dbbb39fb593520b984": {
|
607 |
"model_module": "@jupyter-widgets/controls",
|
608 |
"model_name": "HTMLModel",
|
|
|
609 |
"state": {
|
610 |
"_view_name": "HTMLView",
|
611 |
"style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
|
@@ -626,6 +646,7 @@
|
|
626 |
"8ce4d20d004a4382afa0abdd3b1f7191": {
|
627 |
"model_module": "@jupyter-widgets/controls",
|
628 |
"model_name": "ProgressStyleModel",
|
|
|
629 |
"state": {
|
630 |
"_view_name": "StyleView",
|
631 |
"_model_name": "ProgressStyleModel",
|
@@ -641,6 +662,7 @@
|
|
641 |
"efc4812245c8459c92e6436889b4f600": {
|
642 |
"model_module": "@jupyter-widgets/base",
|
643 |
"model_name": "LayoutModel",
|
|
|
644 |
"state": {
|
645 |
"_view_name": "LayoutView",
|
646 |
"grid_template_rows": null,
|
@@ -692,6 +714,7 @@
|
|
692 |
"717ccef4df1f477abb51814650eb47da": {
|
693 |
"model_module": "@jupyter-widgets/controls",
|
694 |
"model_name": "DescriptionStyleModel",
|
|
|
695 |
"state": {
|
696 |
"_view_name": "StyleView",
|
697 |
"_model_name": "DescriptionStyleModel",
|
@@ -706,6 +729,7 @@
|
|
706 |
"7dba58f0391c485a86e34e8039ec6189": {
|
707 |
"model_module": "@jupyter-widgets/base",
|
708 |
"model_name": "LayoutModel",
|
|
|
709 |
"state": {
|
710 |
"_view_name": "LayoutView",
|
711 |
"grid_template_rows": null,
|
@@ -804,8 +828,7 @@
|
|
804 |
"source": [
|
805 |
"!pip install -q transformers flax\n",
|
806 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
|
807 |
-
"!
|
808 |
-
"%cd dalle-mini/"
|
809 |
],
|
810 |
"execution_count": null,
|
811 |
"outputs": []
|
@@ -833,7 +856,7 @@
|
|
833 |
"import random\n",
|
834 |
"from tqdm.notebook import tqdm, trange"
|
835 |
],
|
836 |
-
"execution_count":
|
837 |
"outputs": []
|
838 |
},
|
839 |
{
|
@@ -846,7 +869,7 @@
|
|
846 |
"DALLE_REPO = 'flax-community/dalle-mini'\n",
|
847 |
"DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
|
848 |
],
|
849 |
-
"execution_count":
|
850 |
"outputs": []
|
851 |
},
|
852 |
{
|
@@ -871,7 +894,7 @@
|
|
871 |
"# set a prompt\n",
|
872 |
"prompt = 'picture of a waterfall under the sunset'"
|
873 |
],
|
874 |
-
"execution_count":
|
875 |
"outputs": []
|
876 |
},
|
877 |
{
|
@@ -888,7 +911,7 @@
|
|
888 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
|
889 |
"tokenized_prompt"
|
890 |
],
|
891 |
-
"execution_count":
|
892 |
"outputs": [
|
893 |
{
|
894 |
"output_type": "execute_result",
|
@@ -956,7 +979,7 @@
|
|
956 |
"subkeys = jax.random.split(key, num=n_predictions)\n",
|
957 |
"subkeys"
|
958 |
],
|
959 |
-
"execution_count":
|
960 |
"outputs": [
|
961 |
{
|
962 |
"output_type": "execute_result",
|
@@ -1004,7 +1027,7 @@
|
|
1004 |
"encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
|
1005 |
"encoded_images[0]"
|
1006 |
],
|
1007 |
-
"execution_count":
|
1008 |
"outputs": [
|
1009 |
{
|
1010 |
"output_type": "display_data",
|
@@ -1099,7 +1122,7 @@
|
|
1099 |
"encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
|
1100 |
"encoded_images[0]"
|
1101 |
],
|
1102 |
-
"execution_count":
|
1103 |
"outputs": [
|
1104 |
{
|
1105 |
"output_type": "execute_result",
|
@@ -1167,7 +1190,7 @@
|
|
1167 |
"source": [
|
1168 |
"encoded_images[0].shape"
|
1169 |
],
|
1170 |
-
"execution_count":
|
1171 |
"outputs": [
|
1172 |
{
|
1173 |
"output_type": "execute_result",
|
@@ -1204,7 +1227,7 @@
|
|
1204 |
"import numpy as np\n",
|
1205 |
"from PIL import Image"
|
1206 |
],
|
1207 |
-
"execution_count":
|
1208 |
"outputs": []
|
1209 |
},
|
1210 |
{
|
@@ -1217,7 +1240,7 @@
|
|
1217 |
"VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
|
1218 |
"VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
|
1219 |
],
|
1220 |
-
"execution_count":
|
1221 |
"outputs": []
|
1222 |
},
|
1223 |
{
|
@@ -1233,7 +1256,7 @@
|
|
1233 |
"# set up VQGAN\n",
|
1234 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
|
1235 |
],
|
1236 |
-
"execution_count":
|
1237 |
"outputs": [
|
1238 |
{
|
1239 |
"output_type": "stream",
|
@@ -1269,7 +1292,7 @@
|
|
1269 |
"decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
|
1270 |
"decoded_images[0]"
|
1271 |
],
|
1272 |
-
"execution_count":
|
1273 |
"outputs": [
|
1274 |
{
|
1275 |
"output_type": "display_data",
|
@@ -1373,7 +1396,7 @@
|
|
1373 |
"# normalize images\n",
|
1374 |
"clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
|
1375 |
],
|
1376 |
-
"execution_count":
|
1377 |
"outputs": []
|
1378 |
},
|
1379 |
{
|
@@ -1385,7 +1408,7 @@
|
|
1385 |
"# convert to image\n",
|
1386 |
"images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
|
1387 |
],
|
1388 |
-
"execution_count":
|
1389 |
"outputs": []
|
1390 |
},
|
1391 |
{
|
@@ -1402,7 +1425,7 @@
|
|
1402 |
"# display an image\n",
|
1403 |
"images[0]"
|
1404 |
],
|
1405 |
-
"execution_count":
|
1406 |
"outputs": [
|
1407 |
{
|
1408 |
"output_type": "execute_result",
|
@@ -1438,7 +1461,7 @@
|
|
1438 |
"source": [
|
1439 |
"from transformers import CLIPProcessor, FlaxCLIPModel"
|
1440 |
],
|
1441 |
-
"execution_count":
|
1442 |
"outputs": []
|
1443 |
},
|
1444 |
{
|
@@ -1474,7 +1497,7 @@
|
|
1474 |
"logits = clip(**inputs).logits_per_image\n",
|
1475 |
"scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
|
1476 |
],
|
1477 |
-
"execution_count":
|
1478 |
"outputs": []
|
1479 |
},
|
1480 |
{
|
@@ -1495,7 +1518,7 @@
|
|
1495 |
" display(images[idx])\n",
|
1496 |
" print()"
|
1497 |
],
|
1498 |
-
"execution_count":
|
1499 |
"outputs": [
|
1500 |
{
|
1501 |
"output_type": "stream",
|
@@ -1690,7 +1713,7 @@
|
|
1690 |
"from flax.training.common_utils import shard\n",
|
1691 |
"from flax.jax_utils import replicate"
|
1692 |
],
|
1693 |
-
"execution_count":
|
1694 |
"outputs": []
|
1695 |
},
|
1696 |
{
|
@@ -1706,7 +1729,7 @@
|
|
1706 |
"# check we can access TPU's or GPU's\n",
|
1707 |
"jax.devices()"
|
1708 |
],
|
1709 |
-
"execution_count":
|
1710 |
"outputs": [
|
1711 |
{
|
1712 |
"output_type": "execute_result",
|
@@ -1744,7 +1767,7 @@
|
|
1744 |
"# one set of inputs per device\n",
|
1745 |
"prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
|
1746 |
],
|
1747 |
-
"execution_count":
|
1748 |
"outputs": []
|
1749 |
},
|
1750 |
{
|
@@ -1757,7 +1780,7 @@
|
|
1757 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
|
1758 |
"tokenized_prompt = shard(tokenized_prompt)"
|
1759 |
],
|
1760 |
-
"execution_count":
|
1761 |
"outputs": []
|
1762 |
},
|
1763 |
{
|
@@ -1793,7 +1816,7 @@
|
|
1793 |
"def p_decode(indices, params):\n",
|
1794 |
" return vqgan.decode_code(indices, params=params)"
|
1795 |
],
|
1796 |
-
"execution_count":
|
1797 |
"outputs": []
|
1798 |
},
|
1799 |
{
|
@@ -1834,7 +1857,7 @@
|
|
1834 |
" for img in decoded_images:\n",
|
1835 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
1836 |
],
|
1837 |
-
"execution_count":
|
1838 |
"outputs": [
|
1839 |
{
|
1840 |
"output_type": "display_data",
|
@@ -1877,7 +1900,7 @@
|
|
1877 |
" display(img)\n",
|
1878 |
" print()"
|
1879 |
],
|
1880 |
-
"execution_count":
|
1881 |
"outputs": [
|
1882 |
{
|
1883 |
"output_type": "display_data",
|
|
|
6 |
"name": "DALL·E mini - Inference pipeline.ipynb",
|
7 |
"provenance": [],
|
8 |
"collapsed_sections": [],
|
9 |
+
"authorship_tag": "ABX9TyMUjEt1XMLq+6/GhSnVFsSx",
|
10 |
"include_colab_link": true
|
11 |
},
|
12 |
"kernelspec": {
|
|
|
22 |
"49304912717a4995ae45d04a59d1f50e": {
|
23 |
"model_module": "@jupyter-widgets/controls",
|
24 |
"model_name": "HBoxModel",
|
25 |
+
"model_module_version": "1.5.0",
|
26 |
"state": {
|
27 |
"_view_name": "HBoxView",
|
28 |
"_dom_classes": [],
|
|
|
43 |
"5fd9f97986024e8db560a6737ade9e2e": {
|
44 |
"model_module": "@jupyter-widgets/base",
|
45 |
"model_name": "LayoutModel",
|
46 |
+
"model_module_version": "1.2.0",
|
47 |
"state": {
|
48 |
"_view_name": "LayoutView",
|
49 |
"grid_template_rows": null,
|
|
|
95 |
"caced43e3a4c493b98fb07cb41db045c": {
|
96 |
"model_module": "@jupyter-widgets/controls",
|
97 |
"model_name": "FloatProgressModel",
|
98 |
+
"model_module_version": "1.5.0",
|
99 |
"state": {
|
100 |
"_view_name": "ProgressView",
|
101 |
"style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
|
|
|
119 |
"0acc161f2e9948b68b3fc4e57ef333c9": {
|
120 |
"model_module": "@jupyter-widgets/controls",
|
121 |
"model_name": "HTMLModel",
|
122 |
+
"model_module_version": "1.5.0",
|
123 |
"state": {
|
124 |
"_view_name": "HTMLView",
|
125 |
"style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
|
|
|
140 |
"40c54b9454d346aabd197f2bcf189467": {
|
141 |
"model_module": "@jupyter-widgets/controls",
|
142 |
"model_name": "ProgressStyleModel",
|
143 |
+
"model_module_version": "1.5.0",
|
144 |
"state": {
|
145 |
"_view_name": "StyleView",
|
146 |
"_model_name": "ProgressStyleModel",
|
|
|
156 |
"8b25334a48244a14aa9ba0176887e655": {
|
157 |
"model_module": "@jupyter-widgets/base",
|
158 |
"model_name": "LayoutModel",
|
159 |
+
"model_module_version": "1.2.0",
|
160 |
"state": {
|
161 |
"_view_name": "LayoutView",
|
162 |
"grid_template_rows": null,
|
|
|
208 |
"7e7c488f57fc4acb8d261e2db81d61f0": {
|
209 |
"model_module": "@jupyter-widgets/controls",
|
210 |
"model_name": "DescriptionStyleModel",
|
211 |
+
"model_module_version": "1.5.0",
|
212 |
"state": {
|
213 |
"_view_name": "StyleView",
|
214 |
"_model_name": "DescriptionStyleModel",
|
|
|
223 |
"72c401062a5348b1a366dffb5a403568": {
|
224 |
"model_module": "@jupyter-widgets/base",
|
225 |
"model_name": "LayoutModel",
|
226 |
+
"model_module_version": "1.2.0",
|
227 |
"state": {
|
228 |
"_view_name": "LayoutView",
|
229 |
"grid_template_rows": null,
|
|
|
275 |
"022c124dfff348f285335732781b0887": {
|
276 |
"model_module": "@jupyter-widgets/controls",
|
277 |
"model_name": "HBoxModel",
|
278 |
+
"model_module_version": "1.5.0",
|
279 |
"state": {
|
280 |
"_view_name": "HBoxView",
|
281 |
"_dom_classes": [],
|
|
|
296 |
"a44e47e9d26c4deb81a5a11a9db92a9f": {
|
297 |
"model_module": "@jupyter-widgets/base",
|
298 |
"model_name": "LayoutModel",
|
299 |
+
"model_module_version": "1.2.0",
|
300 |
"state": {
|
301 |
"_view_name": "LayoutView",
|
302 |
"grid_template_rows": null,
|
|
|
348 |
"cd9c7016caae47c1b41fb2608c78b0bf": {
|
349 |
"model_module": "@jupyter-widgets/controls",
|
350 |
"model_name": "FloatProgressModel",
|
351 |
+
"model_module_version": "1.5.0",
|
352 |
"state": {
|
353 |
"_view_name": "ProgressView",
|
354 |
"style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
|
|
|
372 |
"36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
|
373 |
"model_module": "@jupyter-widgets/controls",
|
374 |
"model_name": "HTMLModel",
|
375 |
+
"model_module_version": "1.5.0",
|
376 |
"state": {
|
377 |
"_view_name": "HTMLView",
|
378 |
"style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
|
|
|
393 |
"c22f207311cf4fb69bd9328eabfd4ebb": {
|
394 |
"model_module": "@jupyter-widgets/controls",
|
395 |
"model_name": "ProgressStyleModel",
|
396 |
+
"model_module_version": "1.5.0",
|
397 |
"state": {
|
398 |
"_view_name": "StyleView",
|
399 |
"_model_name": "ProgressStyleModel",
|
|
|
409 |
"5a38c6d83a264bedbf7efe6e97eba953": {
|
410 |
"model_module": "@jupyter-widgets/base",
|
411 |
"model_name": "LayoutModel",
|
412 |
+
"model_module_version": "1.2.0",
|
413 |
"state": {
|
414 |
"_view_name": "LayoutView",
|
415 |
"grid_template_rows": null,
|
|
|
461 |
"037563a7eadd4ac5abb7249a2914d346": {
|
462 |
"model_module": "@jupyter-widgets/controls",
|
463 |
"model_name": "DescriptionStyleModel",
|
464 |
+
"model_module_version": "1.5.0",
|
465 |
"state": {
|
466 |
"_view_name": "StyleView",
|
467 |
"_model_name": "DescriptionStyleModel",
|
|
|
476 |
"3975e7ed0b704990b1fa05909a9bb9b6": {
|
477 |
"model_module": "@jupyter-widgets/base",
|
478 |
"model_name": "LayoutModel",
|
479 |
+
"model_module_version": "1.2.0",
|
480 |
"state": {
|
481 |
"_view_name": "LayoutView",
|
482 |
"grid_template_rows": null,
|
|
|
528 |
"f9f1fdc3819a4142b85304cd3c6358a2": {
|
529 |
"model_module": "@jupyter-widgets/controls",
|
530 |
"model_name": "HBoxModel",
|
531 |
+
"model_module_version": "1.5.0",
|
532 |
"state": {
|
533 |
"_view_name": "HBoxView",
|
534 |
"_dom_classes": [],
|
|
|
549 |
"ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
|
550 |
"model_module": "@jupyter-widgets/base",
|
551 |
"model_name": "LayoutModel",
|
552 |
+
"model_module_version": "1.2.0",
|
553 |
"state": {
|
554 |
"_view_name": "LayoutView",
|
555 |
"grid_template_rows": null,
|
|
|
601 |
"29d42e94b3b34c86a117b623da68faed": {
|
602 |
"model_module": "@jupyter-widgets/controls",
|
603 |
"model_name": "FloatProgressModel",
|
604 |
+
"model_module_version": "1.5.0",
|
605 |
"state": {
|
606 |
"_view_name": "ProgressView",
|
607 |
"style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
|
|
|
625 |
"8b73de7dbdfe40dbbb39fb593520b984": {
|
626 |
"model_module": "@jupyter-widgets/controls",
|
627 |
"model_name": "HTMLModel",
|
628 |
+
"model_module_version": "1.5.0",
|
629 |
"state": {
|
630 |
"_view_name": "HTMLView",
|
631 |
"style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
|
|
|
646 |
"8ce4d20d004a4382afa0abdd3b1f7191": {
|
647 |
"model_module": "@jupyter-widgets/controls",
|
648 |
"model_name": "ProgressStyleModel",
|
649 |
+
"model_module_version": "1.5.0",
|
650 |
"state": {
|
651 |
"_view_name": "StyleView",
|
652 |
"_model_name": "ProgressStyleModel",
|
|
|
662 |
"efc4812245c8459c92e6436889b4f600": {
|
663 |
"model_module": "@jupyter-widgets/base",
|
664 |
"model_name": "LayoutModel",
|
665 |
+
"model_module_version": "1.2.0",
|
666 |
"state": {
|
667 |
"_view_name": "LayoutView",
|
668 |
"grid_template_rows": null,
|
|
|
714 |
"717ccef4df1f477abb51814650eb47da": {
|
715 |
"model_module": "@jupyter-widgets/controls",
|
716 |
"model_name": "DescriptionStyleModel",
|
717 |
+
"model_module_version": "1.5.0",
|
718 |
"state": {
|
719 |
"_view_name": "StyleView",
|
720 |
"_model_name": "DescriptionStyleModel",
|
|
|
729 |
"7dba58f0391c485a86e34e8039ec6189": {
|
730 |
"model_module": "@jupyter-widgets/base",
|
731 |
"model_name": "LayoutModel",
|
732 |
+
"model_module_version": "1.2.0",
|
733 |
"state": {
|
734 |
"_view_name": "LayoutView",
|
735 |
"grid_template_rows": null,
|
|
|
828 |
"source": [
|
829 |
"!pip install -q transformers flax\n",
|
830 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
|
831 |
+
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git # Model files"
|
|
|
832 |
],
|
833 |
"execution_count": null,
|
834 |
"outputs": []
|
|
|
856 |
"import random\n",
|
857 |
"from tqdm.notebook import tqdm, trange"
|
858 |
],
|
859 |
+
"execution_count": null,
|
860 |
"outputs": []
|
861 |
},
|
862 |
{
|
|
|
869 |
"DALLE_REPO = 'flax-community/dalle-mini'\n",
|
870 |
"DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
|
871 |
],
|
872 |
+
"execution_count": null,
|
873 |
"outputs": []
|
874 |
},
|
875 |
{
|
|
|
894 |
"# set a prompt\n",
|
895 |
"prompt = 'picture of a waterfall under the sunset'"
|
896 |
],
|
897 |
+
"execution_count": null,
|
898 |
"outputs": []
|
899 |
},
|
900 |
{
|
|
|
911 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
|
912 |
"tokenized_prompt"
|
913 |
],
|
914 |
+
"execution_count": null,
|
915 |
"outputs": [
|
916 |
{
|
917 |
"output_type": "execute_result",
|
|
|
979 |
"subkeys = jax.random.split(key, num=n_predictions)\n",
|
980 |
"subkeys"
|
981 |
],
|
982 |
+
"execution_count": null,
|
983 |
"outputs": [
|
984 |
{
|
985 |
"output_type": "execute_result",
|
|
|
1027 |
"encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
|
1028 |
"encoded_images[0]"
|
1029 |
],
|
1030 |
+
"execution_count": null,
|
1031 |
"outputs": [
|
1032 |
{
|
1033 |
"output_type": "display_data",
|
|
|
1122 |
"encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
|
1123 |
"encoded_images[0]"
|
1124 |
],
|
1125 |
+
"execution_count": null,
|
1126 |
"outputs": [
|
1127 |
{
|
1128 |
"output_type": "execute_result",
|
|
|
1190 |
"source": [
|
1191 |
"encoded_images[0].shape"
|
1192 |
],
|
1193 |
+
"execution_count": null,
|
1194 |
"outputs": [
|
1195 |
{
|
1196 |
"output_type": "execute_result",
|
|
|
1227 |
"import numpy as np\n",
|
1228 |
"from PIL import Image"
|
1229 |
],
|
1230 |
+
"execution_count": null,
|
1231 |
"outputs": []
|
1232 |
},
|
1233 |
{
|
|
|
1240 |
"VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
|
1241 |
"VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
|
1242 |
],
|
1243 |
+
"execution_count": null,
|
1244 |
"outputs": []
|
1245 |
},
|
1246 |
{
|
|
|
1256 |
"# set up VQGAN\n",
|
1257 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
|
1258 |
],
|
1259 |
+
"execution_count": null,
|
1260 |
"outputs": [
|
1261 |
{
|
1262 |
"output_type": "stream",
|
|
|
1292 |
"decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
|
1293 |
"decoded_images[0]"
|
1294 |
],
|
1295 |
+
"execution_count": null,
|
1296 |
"outputs": [
|
1297 |
{
|
1298 |
"output_type": "display_data",
|
|
|
1396 |
"# normalize images\n",
|
1397 |
"clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
|
1398 |
],
|
1399 |
+
"execution_count": null,
|
1400 |
"outputs": []
|
1401 |
},
|
1402 |
{
|
|
|
1408 |
"# convert to image\n",
|
1409 |
"images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
|
1410 |
],
|
1411 |
+
"execution_count": null,
|
1412 |
"outputs": []
|
1413 |
},
|
1414 |
{
|
|
|
1425 |
"# display an image\n",
|
1426 |
"images[0]"
|
1427 |
],
|
1428 |
+
"execution_count": null,
|
1429 |
"outputs": [
|
1430 |
{
|
1431 |
"output_type": "execute_result",
|
|
|
1461 |
"source": [
|
1462 |
"from transformers import CLIPProcessor, FlaxCLIPModel"
|
1463 |
],
|
1464 |
+
"execution_count": null,
|
1465 |
"outputs": []
|
1466 |
},
|
1467 |
{
|
|
|
1497 |
"logits = clip(**inputs).logits_per_image\n",
|
1498 |
"scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
|
1499 |
],
|
1500 |
+
"execution_count": null,
|
1501 |
"outputs": []
|
1502 |
},
|
1503 |
{
|
|
|
1518 |
" display(images[idx])\n",
|
1519 |
" print()"
|
1520 |
],
|
1521 |
+
"execution_count": null,
|
1522 |
"outputs": [
|
1523 |
{
|
1524 |
"output_type": "stream",
|
|
|
1713 |
"from flax.training.common_utils import shard\n",
|
1714 |
"from flax.jax_utils import replicate"
|
1715 |
],
|
1716 |
+
"execution_count": null,
|
1717 |
"outputs": []
|
1718 |
},
|
1719 |
{
|
|
|
1729 |
"# check we can access TPU's or GPU's\n",
|
1730 |
"jax.devices()"
|
1731 |
],
|
1732 |
+
"execution_count": null,
|
1733 |
"outputs": [
|
1734 |
{
|
1735 |
"output_type": "execute_result",
|
|
|
1767 |
"# one set of inputs per device\n",
|
1768 |
"prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
|
1769 |
],
|
1770 |
+
"execution_count": null,
|
1771 |
"outputs": []
|
1772 |
},
|
1773 |
{
|
|
|
1780 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
|
1781 |
"tokenized_prompt = shard(tokenized_prompt)"
|
1782 |
],
|
1783 |
+
"execution_count": null,
|
1784 |
"outputs": []
|
1785 |
},
|
1786 |
{
|
|
|
1816 |
"def p_decode(indices, params):\n",
|
1817 |
" return vqgan.decode_code(indices, params=params)"
|
1818 |
],
|
1819 |
+
"execution_count": null,
|
1820 |
"outputs": []
|
1821 |
},
|
1822 |
{
|
|
|
1857 |
" for img in decoded_images:\n",
|
1858 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
1859 |
],
|
1860 |
+
"execution_count": null,
|
1861 |
"outputs": [
|
1862 |
{
|
1863 |
"output_type": "display_data",
|
|
|
1900 |
" display(img)\n",
|
1901 |
" print()"
|
1902 |
],
|
1903 |
+
"execution_count": null,
|
1904 |
"outputs": [
|
1905 |
{
|
1906 |
"output_type": "display_data",
|