boris committed on
Commit
badd15c
1 Parent(s): 8845d77

fix(colab): use correct param name for CLIP

Files changed (1)
  1. tools/inference/inference_pipeline.ipynb +466 -849
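
The fix itself is a one-token change in the CLIP scoring cell: the notebook replicates the CLIP parameters across devices with flax.jax_utils.replicate, but the pmap-ed scoring function was being called with the unreplicated module attribute `clip.params`, which is not even materialized when the model is loaded with `_do_init=False`. A minimal sketch of the corrected pattern, reusing the names from the notebook below (not a verbatim excerpt):

    from functools import partial

    import jax
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard
    from transformers import CLIPProcessor, FlaxCLIPModel

    # _do_init=False returns the module and its parameters separately
    clip, clip_params = FlaxCLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32", _do_init=False
    )
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_params = replicate(clip_params)  # one copy of the params per device

    @partial(jax.pmap, axis_name="batch")
    def p_clip(inputs, params):
        return clip(params=params, **inputs).logits_per_image

    # before: logits = p_clip(shard(clip_inputs), clip.params)  # unreplicated attribute
    # after:  logits = p_clip(shard(clip_inputs), clip_params)  # replicated copy

The commit also deduplicates the CLIP-loading block in the scoring cell, adds an "Open in Colab" badge cell, and strips the saved Colab outputs and widget state, which accounts for most of the -849 lines.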
tools/inference/inference_pipeline.ipynb CHANGED
@@ -1,865 +1,482 @@
  {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "118UKH5bWCGa"
- },
- "source": [
- "# DALL·E mini - Inference pipeline\n",
- "\n",
- "*Generate images from a text prompt*\n",
- "\n",
- "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
- "\n",
- "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
- "\n",
- "Just want to play? Use directly [DALL·E mini app](https://huggingface.co/spaces/dalle-mini/dalle-mini).\n",
- "\n",
- "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dS8LbaonYm3a"
- },
- "source": [
- "## 🛠️ Installation and set-up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "uzjAM2GBYpZX"
- },
- "outputs": [],
- "source": [
- "# Install required libraries\n",
- "!pip install -q git+https://github.com/huggingface/transformers.git\n",
- "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
- "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ozHzTkyv8cqU"
- },
- "source": [
- "We load required models:\n",
- "* DALL·E mini for text to encoded images\n",
- "* VQGAN for decoding images\n",
- "* CLIP for scoring predictions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "K6CxW2o42f-w"
- },
- "outputs": [],
- "source": [
- "# Model references\n",
- "\n",
- "# dalle-mega\n",
- "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
- "DALLE_COMMIT_ID = None\n",
- "\n",
- "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
- "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
- "\n",
- "# VQGAN model\n",
- "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
- "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
  },
- "id": "Yv-aR3t4Oe5v",
- "outputId": "3097b2c7-5dac-475f-edde-898799dd7294"
- },
- "outputs": [],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "\n",
- "# check how many devices are available\n",
- "jax.local_device_count()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
  },
- "id": "92zYmvsQ38vL",
- "outputId": "d897dfdb-dae7-4026-da36-8b23dce066e8"
- },
- "outputs": [],
- "source": [
- "# Load models & tokenizer\n",
- "from dalle_mini import DalleBart, DalleBartProcessor\n",
- "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
- "from transformers import CLIPProcessor, FlaxCLIPModel\n",
- "\n",
- "# Load dalle-mini\n",
- "model, params = DalleBart.from_pretrained(\n",
- " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
- ")\n",
- "\n",
- "# Load VQGAN\n",
- "vqgan, vqgan_params = VQModel.from_pretrained(\n",
- " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "o_vH2X1tDtzA"
- },
- "source": [
- "Model parameters are replicated on each device for faster inference."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "wtvLoM48EeVw"
- },
- "outputs": [],
- "source": [
- "from flax.jax_utils import replicate\n",
- "\n",
- "params = replicate(params)\n",
- "vqgan_params = replicate(vqgan_params)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0A9AHQIgZ_qw"
- },
- "source": [
- "Model functions are compiled and parallelized to take advantage of multiple devices."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sOtoOmYsSYPz"
- },
- "outputs": [],
- "source": [
- "from functools import partial\n",
- "\n",
- "# model inference\n",
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
- "def p_generate(\n",
- " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
- "):\n",
- " return model.generate(\n",
- " **tokenized_prompt,\n",
- " prng_key=key,\n",
- " params=params,\n",
- " top_k=top_k,\n",
- " top_p=top_p,\n",
- " temperature=temperature,\n",
- " condition_scale=condition_scale,\n",
- " )\n",
- "\n",
- "\n",
- "# decode image\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_decode(indices, params):\n",
- " return vqgan.decode_code(indices, params=params)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HmVN6IBwapBA"
- },
- "source": [
- "Keys are passed to the model on each device to generate unique inference per device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4CTXmlUkThhX"
- },
- "outputs": [],
- "source": [
- "import random\n",
- "\n",
- "# create a random key\n",
- "seed = random.randint(0, 2**32 - 1)\n",
- "key = jax.random.PRNGKey(seed)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BrnVyCo81pij"
- },
- "source": [
- "## 🖍 Text Prompt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rsmj0Aj5OQox"
- },
- "source": [
- "Our model requires processing prompts."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
  },
- "id": "YjjhUychOVxm",
- "outputId": "a286f17a-a388-4754-ec4d-0464c0666c90"
- },
- "outputs": [],
- "source": [
- "from dalle_mini import DalleBartProcessor\n",
- "\n",
- "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BQ7fymSPyvF_"
- },
- "source": [
- "Let's define a text prompt."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "x_0vI9ge1oKr"
- },
- "outputs": [],
- "source": [
- "prompt = \"sunset over a lake in the mountains\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "VKjEZGjtO49k"
- },
- "outputs": [],
- "source": [
- "tokenized_prompt = processor([prompt])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-CEJBnuJOe5z"
- },
- "source": [
- "Finally we replicate it onto each device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "lQePgju5Oe5z"
- },
- "outputs": [],
- "source": [
- "tokenized_prompt = replicate(tokenized_prompt)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "phQ9bhjRkgAZ"
- },
- "source": [
- "## 🎨 Generate images\n",
- "\n",
- "We generate images using dalle-mini model and decode them with the VQGAN."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "d0wVkXpKqnHA"
- },
- "outputs": [],
- "source": [
- "# number of predictions\n",
- "n_predictions = 8\n",
- "\n",
- "# We can customize generation parameters\n",
- "gen_top_k = None\n",
- "gen_top_p = None\n",
- "temperature = None\n",
- "cond_scale = 3.0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000,
- "referenced_widgets": [
- "cef76449b8d74217ae36c56be3990eec",
- "7be07ba7cfe642a596509c756dcefddc",
- "2a02378499fc414299f17a2d5dcac867",
- "427d47d9423441d286ae80a637ae35a0",
- "cb157fd4e37041d1beae29eaa729c8ff",
- "73413668398b45dfa8484a2c2be778ec",
- "e7d108a4b168442fb2048f58ddeb0a18",
- "5e81a141422f432395055f5cafb07016",
- "5f476a929da84fa985b2e980459da7b9",
- "f3b643a0ca2444fd959fff9b45d79d27",
- "82b87345233549d699ce3fd8080fa988"
- ]
  },
- "id": "SDjEx9JxR3v8",
- "outputId": "8f4287a7-aff9-41ef-a026-02265de0c205"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard_prng_key\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "from tqdm.notebook import trange\n",
- "\n",
- "print(f\"Prompt: {prompt}\\n\")\n",
- "# generate images\n",
- "images = []\n",
- "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
- " # get a new key\n",
- " key, subkey = jax.random.split(key)\n",
- " # generate images\n",
- " encoded_images = p_generate(\n",
- " tokenized_prompt,\n",
- " shard_prng_key(subkey),\n",
- " params,\n",
- " gen_top_k,\n",
- " gen_top_p,\n",
- " temperature,\n",
- " cond_scale,\n",
- " )\n",
- " # remove BOS\n",
- " encoded_images = encoded_images.sequences[..., 1:]\n",
- " # decode images\n",
- " decoded_images = p_decode(encoded_images, vqgan_params)\n",
- " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
- " for decoded_img in decoded_images:\n",
- " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
- " images.append(img)\n",
- " display(img)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tw02wG9zGmyB"
- },
- "source": [
- "## 🏅 Optional: Rank images by CLIP score\n",
- "\n",
- "We can rank images according to CLIP.\n",
- "\n",
- "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "RGjlIW_f6GA0"
- },
- "outputs": [],
- "source": [
- "# CLIP model\n",
- "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
- "CLIP_COMMIT_ID = None\n",
- "\n",
- "# Load CLIP\n",
- "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
- " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
- ")\n",
- "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
- "clip_params = replicate(clip_params)\n",
- "\n",
- "# score images\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_clip(inputs, params):\n",
- " logits = clip(params=params, **inputs).logits_per_image\n",
- " return logits"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "FoLXpjCmGpju"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard\n",
- "\n",
- "# CLIP model\n",
- "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
- "CLIP_COMMIT_ID = None\n",
- "\n",
- "# Load CLIP\n",
- "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
- " CLIP_REPO, revision=CLIP_COMMIT_ID, _do_init=False\n",
- ")\n",
- "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
- "clip_params = replicate(clip_params)\n",
- "\n",
- "# score images\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_clip(inputs, params):\n",
- " logits = clip(params=params, **inputs).logits_per_image\n",
- " return logits\n",
- "\n",
- "\n",
- "# get clip scores\n",
- "clip_inputs = clip_processor(\n",
- " text=[prompt] * jax.device_count(),\n",
- " images=images,\n",
- " return_tensors=\"np\",\n",
- " padding=\"max_length\",\n",
- " max_length=77,\n",
- " truncation=True,\n",
- ").data\n",
- "logits = p_clip(shard(clip_inputs), clip.params)\n",
- "logits = logits.squeeze().flatten()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4AAWRm70LgED"
- },
- "source": [
- "Let's now display images ranked by CLIP score."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zsgxxubLLkIu"
- },
- "outputs": [],
- "source": [
- "print(f\"Prompt: {prompt}\\n\")\n",
- "for idx in logits.argsort()[::-1]:\n",
- " display(images[idx])\n",
- " print(f\"Score: {logits[idx]:.2f}\\n\")"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "collapsed_sections": [],
- "machine_shape": "hm",
- "name": "DALL·E mini - Inference pipeline.ipynb",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.7"
- },
- "widgets": { … ≈345 lines of saved ipywidgets state (the tqdm progress bar's HBox/HTML/FloatProgress models plus their Layout/Style entries), removed wholesale by this commit … }
- },
- "nbformat": 4,
- "nbformat_minor": 0
- }
 
  {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "118UKH5bWCGa"
+ },
+ "source": [
+ "# DALL·E mini - Inference pipeline\n",
+ "\n",
+ "*Generate images from a text prompt*\n",
+ "\n",
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
+ "\n",
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
+ "\n",
+ "Just want to play? Use directly [DALL·E mini app](https://huggingface.co/spaces/dalle-mini/dalle-mini).\n",
+ "\n",
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dS8LbaonYm3a"
+ },
+ "source": [
+ "## 🛠️ Installation and set-up"
+ ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uzjAM2GBYpZX"
+ },
+ "outputs": [],
+ "source": [
+ "# Install required libraries\n",
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n",
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ozHzTkyv8cqU"
+ },
+ "source": [
+ "We load required models:\n",
+ "* DALL·E mini for text to encoded images\n",
+ "* VQGAN for decoding images\n",
+ "* CLIP for scoring predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "K6CxW2o42f-w"
+ },
+ "outputs": [],
+ "source": [
+ "# Model references\n",
+ "\n",
+ "# dalle-mega\n",
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
+ "DALLE_COMMIT_ID = None\n",
+ "\n",
+ "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
+ "\n",
+ "# VQGAN model\n",
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Yv-aR3t4Oe5v"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "# check how many devices are available\n",
+ "jax.local_device_count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "92zYmvsQ38vL"
+ },
+ "outputs": [],
+ "source": [
+ "# Load models & tokenizer\n",
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
+ "\n",
+ "# Load dalle-mini\n",
+ "model, params = DalleBart.from_pretrained(\n",
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")\n",
+ "\n",
+ "# Load VQGAN\n",
+ "vqgan, vqgan_params = VQModel.from_pretrained(\n",
+ " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o_vH2X1tDtzA"
+ },
+ "source": [
+ "Model parameters are replicated on each device for faster inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wtvLoM48EeVw"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.jax_utils import replicate\n",
+ "\n",
+ "params = replicate(params)\n",
+ "vqgan_params = replicate(vqgan_params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0A9AHQIgZ_qw"
+ },
+ "source": [
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sOtoOmYsSYPz"
+ },
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "# model inference\n",
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
+ "def p_generate(\n",
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
+ "):\n",
+ " return model.generate(\n",
+ " **tokenized_prompt,\n",
+ " prng_key=key,\n",
+ " params=params,\n",
+ " top_k=top_k,\n",
+ " top_p=top_p,\n",
+ " temperature=temperature,\n",
+ " condition_scale=condition_scale,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "# decode image\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_decode(indices, params):\n",
+ " return vqgan.decode_code(indices, params=params)"
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HmVN6IBwapBA"
+ },
+ "source": [
+ "Keys are passed to the model on each device to generate unique inference per device."
+ ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4CTXmlUkThhX"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "# create a random key\n",
+ "seed = random.randint(0, 2**32 - 1)\n",
+ "key = jax.random.PRNGKey(seed)"
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BrnVyCo81pij"
+ },
+ "source": [
+ "## 🖍 Text Prompt"
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rsmj0Aj5OQox"
+ },
+ "source": [
+ "Our model requires processing prompts."
+ ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YjjhUychOVxm"
+ },
+ "outputs": [],
+ "source": [
+ "from dalle_mini import DalleBartProcessor\n",
+ "\n",
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BQ7fymSPyvF_"
+ },
+ "source": [
+ "Let's define a text prompt."
+ ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x_0vI9ge1oKr"
+ },
+ "outputs": [],
+ "source": [
+ "prompt = \"sunset over a lake in the mountains\""
+ ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VKjEZGjtO49k"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompt = processor([prompt])"
+ ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-CEJBnuJOe5z"
+ },
+ "source": [
+ "Finally we replicate it onto each device."
+ ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lQePgju5Oe5z"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompt = replicate(tokenized_prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "phQ9bhjRkgAZ"
+ },
+ "source": [
+ "## 🎨 Generate images\n",
+ "\n",
+ "We generate images using dalle-mini model and decode them with the VQGAN."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d0wVkXpKqnHA"
+ },
+ "outputs": [],
+ "source": [
+ "# number of predictions\n",
+ "n_predictions = 8\n",
+ "\n",
+ "# We can customize generation parameters\n",
+ "gen_top_k = None\n",
+ "gen_top_p = None\n",
+ "temperature = None\n",
+ "cond_scale = 3.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SDjEx9JxR3v8"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard_prng_key\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "from tqdm.notebook import trange\n",
+ "\n",
+ "print(f\"Prompt: {prompt}\\n\")\n",
+ "# generate images\n",
+ "images = []\n",
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
+ " # get a new key\n",
+ " key, subkey = jax.random.split(key)\n",
+ " # generate images\n",
+ " encoded_images = p_generate(\n",
+ " tokenized_prompt,\n",
+ " shard_prng_key(subkey),\n",
+ " params,\n",
+ " gen_top_k,\n",
+ " gen_top_p,\n",
+ " temperature,\n",
+ " cond_scale,\n",
+ " )\n",
+ " # remove BOS\n",
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
+ " # decode images\n",
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
+ " for decoded_img in decoded_images:\n",
+ " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
+ " images.append(img)\n",
+ " display(img)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tw02wG9zGmyB"
+ },
+ "source": [
+ "## 🏅 Optional: Rank images by CLIP score\n",
+ "\n",
+ "We can rank images according to CLIP.\n",
+ "\n",
+ "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RGjlIW_f6GA0"
+ },
+ "outputs": [],
+ "source": [
+ "# CLIP model\n",
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
+ "CLIP_COMMIT_ID = None\n",
+ "\n",
+ "# Load CLIP\n",
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
+ " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")\n",
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
+ "clip_params = replicate(clip_params)\n",
+ "\n",
+ "# score images\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_clip(inputs, params):\n",
+ " logits = clip(params=params, **inputs).logits_per_image\n",
+ " return logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FoLXpjCmGpju"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard\n",
+ "\n",
+ "# get clip scores\n",
+ "clip_inputs = clip_processor(\n",
+ " text=[prompt] * jax.device_count(),\n",
+ " images=images,\n",
+ " return_tensors=\"np\",\n",
+ " padding=\"max_length\",\n",
+ " max_length=77,\n",
+ " truncation=True,\n",
+ ").data\n",
+ "logits = p_clip(shard(clip_inputs), clip_params)\n",
+ "logits = logits.squeeze().flatten()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4AAWRm70LgED"
+ },
+ "source": [
+ "Let's now display images ranked by CLIP score."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zsgxxubLLkIu"
+ },
+ "outputs": [],
+ "source": [
+ "print(f\"Prompt: {prompt}\\n\")\n",
+ "for idx in logits.argsort()[::-1]:\n",
+ " display(images[idx])\n",
+ " print(f\"Score: {logits[idx]:.2f}\\n\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "DALL·E mini - Inference pipeline.ipynb",
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
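
Why the replicated copy matters, for readers skimming the diff: jax.pmap maps a function over a leading device axis, so each array argument must carry one copy per device (added to parameter pytrees by replicate) or one shard per device (added to inputs by shard). A self-contained toy illustration of that shape contract — this sketch assumes only that jax and flax are installed, and none of it is taken from the notebook:

    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate

    n_dev = jax.local_device_count()

    params = {"w": jnp.arange(3.0)}  # unreplicated pytree, leaf shape (3,)
    p_params = replicate(params)     # leaves gain a leading device axis: (n_dev, 3)

    @jax.pmap
    def apply(x, params):
        # runs once per device, on its slice of x and its copy of params
        return x * params["w"]

    x = jnp.ones((n_dev, 3))           # one input shard per device
    print(apply(x, p_params).shape)    # (n_dev, 3)
    # apply(x, params) would fail: the params leaves lack the device axis

Passing the bare `clip.params` pytree to `p_clip` breaks exactly this contract, which is what the corrected `clip_params` argument restores.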