boris committed on
Commit e558000
Parent: 5f954fc

feat(demo): update reference

Files changed (1)
  1. tools/inference/inference_pipeline.ipynb +512 -514
tools/inference/inference_pipeline.ipynb CHANGED
@@ -1,515 +1,513 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "view-in-github",
7
- "colab_type": "text"
8
- },
9
- "source": [
10
- "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
- ]
12
- },
13
- {
14
- "cell_type": "markdown",
15
- "metadata": {
16
- "id": "118UKH5bWCGa"
17
- },
18
- "source": [
19
- "# DALL·E mini - Inference pipeline\n",
20
- "\n",
21
- "*Generate images from a text prompt*\n",
22
- "\n",
23
- "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
- "\n",
25
- "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
- "\n",
27
- "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
28
- "\n",
29
- "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
- ]
31
- },
32
- {
33
- "cell_type": "markdown",
34
- "metadata": {
35
- "id": "dS8LbaonYm3a"
36
- },
37
- "source": [
38
- "## 🛠️ Installation and set-up"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": null,
44
- "metadata": {
45
- "id": "uzjAM2GBYpZX"
46
- },
47
- "outputs": [],
48
- "source": [
49
- "# Install required libraries\n",
50
- "!pip install -q transformers\n",
51
- "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
- "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
- ]
54
- },
55
- {
56
- "cell_type": "markdown",
57
- "metadata": {
58
- "id": "ozHzTkyv8cqU"
59
- },
60
- "source": [
61
- "We load required models:\n",
62
- "* dalle·mini for text to encoded images\n",
63
- "* VQGAN for decoding images\n",
64
- "* CLIP for scoring predictions"
65
- ]
66
- },
67
- {
68
- "cell_type": "code",
69
- "execution_count": null,
70
- "metadata": {
71
- "id": "K6CxW2o42f-w"
72
- },
73
- "outputs": [],
74
- "source": [
75
- "# Model references\n",
76
- "\n",
77
- "# dalle-mini\n",
78
- "DALLE_MODEL = \"dalle-mini/dalle-mini/model-mehdx7dg:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
79
- "DALLE_COMMIT_ID = None\n",
80
- "\n",
81
- "# VQGAN model\n",
82
- "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
83
- "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
84
- "\n",
85
- "# CLIP model\n",
86
- "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
87
- "CLIP_COMMIT_ID = None"
88
- ]
89
- },
90
- {
91
- "cell_type": "code",
92
- "execution_count": null,
93
- "metadata": {
94
- "id": "Yv-aR3t4Oe5v"
95
- },
96
- "outputs": [],
97
- "source": [
98
- "import jax\n",
99
- "import jax.numpy as jnp\n",
100
- "\n",
101
- "# check how many devices are available\n",
102
- "jax.local_device_count()"
103
- ]
104
- },
105
- {
106
- "cell_type": "code",
107
- "execution_count": null,
108
- "metadata": {
109
- "id": "HWnQrQuXOe5w"
110
- },
111
- "outputs": [],
112
- "source": [
113
- "# type used for computation - use bfloat16 on TPU's\n",
114
- "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
115
- "\n",
116
- "# TODO: fix issue with bfloat16\n",
117
- "dtype = jnp.float32"
118
- ]
119
- },
120
- {
121
- "cell_type": "code",
122
- "execution_count": null,
123
- "metadata": {
124
- "id": "92zYmvsQ38vL"
125
- },
126
- "outputs": [],
127
- "source": [
128
- "# Load models & tokenizer\n",
129
- "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
130
- "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131
- "from transformers import CLIPProcessor, FlaxCLIPModel\n",
132
- "import wandb\n",
133
- "\n",
134
- "# Load dalle-mini\n",
135
- "model = DalleBart.from_pretrained(\n",
136
- " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
137
- ")\n",
138
- "tokenizer = DalleBartTokenizer.from_pretrained(\n",
139
- " DALLE_MODEL, revision=DALLE_COMMIT_ID\n",
140
- ")\n",
141
- "\n",
142
- "# Load VQGAN\n",
143
- "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
144
- "\n",
145
- "# Load CLIP\n",
146
- "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
147
- "processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
148
- ]
149
- },
150
- {
151
- "cell_type": "markdown",
152
- "metadata": {
153
- "id": "o_vH2X1tDtzA"
154
- },
155
- "source": [
156
- "Model parameters are replicated on each device for faster inference."
157
- ]
158
- },
159
- {
160
- "cell_type": "code",
161
- "execution_count": null,
162
- "metadata": {
163
- "id": "wtvLoM48EeVw"
164
- },
165
- "outputs": [],
166
- "source": [
167
- "from flax.jax_utils import replicate\n",
168
- "\n",
169
- "# convert model parameters for inference if requested\n",
170
- "if dtype == jnp.bfloat16:\n",
171
- " model.params = model.to_bf16(model.params)\n",
172
- "\n",
173
- "model_params = replicate(model.params)\n",
174
- "vqgan_params = replicate(vqgan.params)\n",
175
- "clip_params = replicate(clip.params)"
176
- ]
177
- },
178
- {
179
- "cell_type": "markdown",
180
- "metadata": {
181
- "id": "0A9AHQIgZ_qw"
182
- },
183
- "source": [
184
- "Model functions are compiled and parallelized to take advantage of multiple devices."
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": null,
190
- "metadata": {
191
- "id": "sOtoOmYsSYPz"
192
- },
193
- "outputs": [],
194
- "source": [
195
- "from functools import partial\n",
196
- "\n",
197
- "# model inference\n",
198
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
199
- "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
200
- " return model.generate(\n",
201
- " **tokenized_prompt,\n",
202
- " do_sample=True,\n",
203
- " num_beams=1,\n",
204
- " prng_key=key,\n",
205
- " params=params,\n",
206
- " top_k=top_k,\n",
207
- " top_p=top_p,\n",
208
- " max_length=257\n",
209
- " )\n",
210
- "\n",
211
- "\n",
212
- "# decode images\n",
213
- "@partial(jax.pmap, axis_name=\"batch\")\n",
214
- "def p_decode(indices, params):\n",
215
- " return vqgan.decode_code(indices, params=params)\n",
216
- "\n",
217
- "\n",
218
- "# score images\n",
219
- "@partial(jax.pmap, axis_name=\"batch\")\n",
220
- "def p_clip(inputs, params):\n",
221
- " logits = clip(params=params, **inputs).logits_per_image\n",
222
- " return logits"
223
- ]
224
- },
225
- {
226
- "cell_type": "markdown",
227
- "metadata": {
228
- "id": "HmVN6IBwapBA"
229
- },
230
- "source": [
231
- "Keys are passed to the model on each device to generate unique inference per device."
232
- ]
233
- },
234
- {
235
- "cell_type": "code",
236
- "execution_count": null,
237
- "metadata": {
238
- "id": "4CTXmlUkThhX"
239
- },
240
- "outputs": [],
241
- "source": [
242
- "import random\n",
243
- "\n",
244
- "# create a random key\n",
245
- "seed = random.randint(0, 2**32 - 1)\n",
246
- "key = jax.random.PRNGKey(seed)"
247
- ]
248
- },
249
- {
250
- "cell_type": "markdown",
251
- "metadata": {
252
- "id": "BrnVyCo81pij"
253
- },
254
- "source": [
255
- "## 🖍 Text Prompt"
256
- ]
257
- },
258
- {
259
- "cell_type": "markdown",
260
- "metadata": {
261
- "id": "rsmj0Aj5OQox"
262
- },
263
- "source": [
264
- "Our model may require to normalize the prompt."
265
- ]
266
- },
267
- {
268
- "cell_type": "code",
269
- "execution_count": null,
270
- "metadata": {
271
- "id": "YjjhUychOVxm"
272
- },
273
- "outputs": [],
274
- "source": [
275
- "from dalle_mini.text import TextNormalizer\n",
276
- "\n",
277
- "text_normalizer = TextNormalizer() if model.config.normalize_text else None"
278
- ]
279
- },
280
- {
281
- "cell_type": "markdown",
282
- "metadata": {
283
- "id": "BQ7fymSPyvF_"
284
- },
285
- "source": [
286
- "Let's define a text prompt."
287
- ]
288
- },
289
- {
290
- "cell_type": "code",
291
- "execution_count": null,
292
- "metadata": {
293
- "id": "x_0vI9ge1oKr"
294
- },
295
- "outputs": [],
296
- "source": [
297
- "prompt = \"a blue table\""
298
- ]
299
- },
300
- {
301
- "cell_type": "code",
302
- "execution_count": null,
303
- "metadata": {
304
- "id": "VKjEZGjtO49k"
305
- },
306
- "outputs": [],
307
- "source": [
308
- "processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
309
- "processed_prompt"
310
- ]
311
- },
312
- {
313
- "cell_type": "markdown",
314
- "metadata": {
315
- "id": "QUzYACWxOe5z"
316
- },
317
- "source": [
318
- "We tokenize the prompt."
319
- ]
320
- },
321
- {
322
- "cell_type": "code",
323
- "execution_count": null,
324
- "metadata": {
325
- "id": "n8e7MvGwOe5z"
326
- },
327
- "outputs": [],
328
- "source": [
329
- "tokenized_prompt = tokenizer(\n",
330
- " processed_prompt,\n",
331
- " return_tensors=\"jax\",\n",
332
- " padding=\"max_length\",\n",
333
- " truncation=True,\n",
334
- " max_length=128,\n",
335
- ").data\n",
336
- "tokenized_prompt"
337
- ]
338
- },
339
- {
340
- "cell_type": "markdown",
341
- "metadata": {
342
- "id": "_Y5dqFj7prMQ"
343
- },
344
- "source": [
345
- "Notes:\n",
346
- "\n",
347
- "* `0`: BOS, special token representing the beginning of a sequence\n",
348
- "* `2`: EOS, special token representing the end of a sequence\n",
349
- "* `1`: special token representing the padding of a sequence when requesting a specific length"
350
- ]
351
- },
352
- {
353
- "cell_type": "markdown",
354
- "metadata": {
355
- "id": "-CEJBnuJOe5z"
356
- },
357
- "source": [
358
- "Finally we replicate it onto each device."
359
- ]
360
- },
361
- {
362
- "cell_type": "code",
363
- "execution_count": null,
364
- "metadata": {
365
- "id": "lQePgju5Oe5z"
366
- },
367
- "outputs": [],
368
- "source": [
369
- "tokenized_prompt = replicate(tokenized_prompt)"
370
- ]
371
- },
372
- {
373
- "cell_type": "markdown",
374
- "metadata": {
375
- "id": "phQ9bhjRkgAZ"
376
- },
377
- "source": [
378
- "## 🎨 Generate images\n",
379
- "\n",
380
- "We generate images using dalle-mini model and decode them with the VQGAN."
381
- ]
382
- },
383
- {
384
- "cell_type": "code",
385
- "execution_count": null,
386
- "metadata": {
387
- "id": "d0wVkXpKqnHA"
388
- },
389
- "outputs": [],
390
- "source": [
391
- "# number of predictions\n",
392
- "n_predictions = 32\n",
393
- "\n",
394
- "# We can customize top_k/top_p used for generating samples\n",
395
- "gen_top_k = None\n",
396
- "gen_top_p = None"
397
- ]
398
- },
399
- {
400
- "cell_type": "code",
401
- "execution_count": null,
402
- "metadata": {
403
- "id": "SDjEx9JxR3v8"
404
- },
405
- "outputs": [],
406
- "source": [
407
- "from flax.training.common_utils import shard_prng_key\n",
408
- "import numpy as np\n",
409
- "from PIL import Image\n",
410
- "from tqdm.notebook import trange\n",
411
- "\n",
412
- "# generate images\n",
413
- "images = []\n",
414
- "for i in trange(n_predictions // jax.device_count()):\n",
415
- " # get a new key\n",
416
- " key, subkey = jax.random.split(key)\n",
417
- " # generate images\n",
418
- " encoded_images = p_generate(\n",
419
- " tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
420
- " )\n",
421
- " # remove BOS\n",
422
- " encoded_images = encoded_images.sequences[..., 1:]\n",
423
- " # decode images\n",
424
- " decoded_images = p_decode(encoded_images, vqgan_params)\n",
425
- " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
426
- " for img in decoded_images:\n",
427
- " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
428
- ]
429
- },
430
- {
431
- "cell_type": "markdown",
432
- "metadata": {
433
- "id": "tw02wG9zGmyB"
434
- },
435
- "source": [
436
- "Let's calculate their score with CLIP."
437
- ]
438
- },
439
- {
440
- "cell_type": "code",
441
- "execution_count": null,
442
- "metadata": {
443
- "id": "FoLXpjCmGpju"
444
- },
445
- "outputs": [],
446
- "source": [
447
- "from flax.training.common_utils import shard\n",
448
- "\n",
449
- "# get clip scores\n",
450
- "clip_inputs = processor(\n",
451
- " text=[prompt] * jax.device_count(),\n",
452
- " images=images,\n",
453
- " return_tensors=\"np\",\n",
454
- " padding=\"max_length\",\n",
455
- " max_length=77,\n",
456
- " truncation=True,\n",
457
- ").data\n",
458
- "logits = p_clip(shard(clip_inputs), clip_params)\n",
459
- "logits = logits.squeeze().flatten()"
460
- ]
461
- },
462
- {
463
- "cell_type": "markdown",
464
- "metadata": {
465
- "id": "4AAWRm70LgED"
466
- },
467
- "source": [
468
- "Let's display images ranked by CLIP score."
469
- ]
470
- },
471
- {
472
- "cell_type": "code",
473
- "execution_count": null,
474
- "metadata": {
475
- "id": "zsgxxubLLkIu"
476
- },
477
- "outputs": [],
478
- "source": [
479
- "print(f\"Prompt: {prompt}\\n\")\n",
480
- "for idx in logits.argsort()[::-1]:\n",
481
- " display(images[idx])\n",
482
- " print(f\"Score: {logits[idx]:.2f}\\n\")"
483
- ]
484
- }
485
- ],
486
- "metadata": {
487
- "accelerator": "GPU",
488
- "colab": {
489
- "collapsed_sections": [],
490
- "machine_shape": "hm",
491
- "name": "DALL·E mini - Inference pipeline.ipynb",
492
- "provenance": [],
493
- "include_colab_link": true
494
- },
495
- "kernelspec": {
496
- "display_name": "Python 3 (ipykernel)",
497
- "language": "python",
498
- "name": "python3"
499
- },
500
- "language_info": {
501
- "codemirror_mode": {
502
- "name": "ipython",
503
- "version": 3
504
- },
505
- "file_extension": ".py",
506
- "mimetype": "text/x-python",
507
- "name": "python",
508
- "nbconvert_exporter": "python",
509
- "pygments_lexer": "ipython3",
510
- "version": "3.9.7"
511
- }
512
- },
513
- "nbformat": 4,
514
- "nbformat_minor": 0
515
- }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "118UKH5bWCGa"
17
+ },
18
+ "source": [
19
+ "# DALL·E mini - Inference pipeline\n",
20
+ "\n",
21
+ "*Generate images from a text prompt*\n",
22
+ "\n",
23
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
+ "\n",
25
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
+ "\n",
27
+ "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
28
+ "\n",
29
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {
35
+ "id": "dS8LbaonYm3a"
36
+ },
37
+ "source": [
38
+ "## 🛠️ Installation and set-up"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "id": "uzjAM2GBYpZX"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "# Install required libraries\n",
50
+ "!pip install -q transformers\n",
51
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {
58
+ "id": "ozHzTkyv8cqU"
59
+ },
60
+ "source": [
61
+ "We load required models:\n",
62
+ "* dalle·mini for text to encoded images\n",
63
+ "* VQGAN for decoding images\n",
64
+ "* CLIP for scoring predictions"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {
71
+ "id": "K6CxW2o42f-w"
72
+ },
73
+ "outputs": [],
74
+ "source": [
75
+ "# Model references\n",
76
+ "\n",
77
+ "# dalle-mini\n",
78
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-1reghx5l:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
79
+ "DALLE_COMMIT_ID = None\n",
80
+ "\n",
81
+ "# VQGAN model\n",
82
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
83
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
84
+ "\n",
85
+ "# CLIP model\n",
86
+ "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
87
+ "CLIP_COMMIT_ID = None"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {
94
+ "id": "Yv-aR3t4Oe5v"
95
+ },
96
+ "outputs": [],
97
+ "source": [
98
+ "import jax\n",
99
+ "import jax.numpy as jnp\n",
100
+ "\n",
101
+ "# check how many devices are available\n",
102
+ "jax.local_device_count()"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {
109
+ "id": "HWnQrQuXOe5w"
110
+ },
111
+ "outputs": [],
112
+ "source": [
113
+ "# type used for computation - use bfloat16 on TPU's\n",
114
+ "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
115
+ "\n",
116
+ "# TODO: fix issue with bfloat16\n",
117
+ "dtype = jnp.float32"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {
124
+ "id": "92zYmvsQ38vL"
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
+ "# Load models & tokenizer\n",
129
+ "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
130
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
132
+ "import wandb\n",
133
+ "\n",
134
+ "# Load dalle-mini\n",
135
+ "model = DalleBart.from_pretrained(\n",
136
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
137
+ ")\n",
138
+ "tokenizer = DalleBartTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
139
+ "\n",
140
+ "# Load VQGAN\n",
141
+ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
142
+ "\n",
143
+ "# Load CLIP\n",
144
+ "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
145
+ "processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "metadata": {
151
+ "id": "o_vH2X1tDtzA"
152
+ },
153
+ "source": [
154
+ "Model parameters are replicated on each device for faster inference."
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {
161
+ "id": "wtvLoM48EeVw"
162
+ },
163
+ "outputs": [],
164
+ "source": [
165
+ "from flax.jax_utils import replicate\n",
166
+ "\n",
167
+ "# convert model parameters for inference if requested\n",
168
+ "if dtype == jnp.bfloat16:\n",
169
+ " model.params = model.to_bf16(model.params)\n",
170
+ "\n",
171
+ "model_params = replicate(model.params)\n",
172
+ "vqgan_params = replicate(vqgan.params)\n",
173
+ "clip_params = replicate(clip.params)"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "markdown",
178
+ "metadata": {
179
+ "id": "0A9AHQIgZ_qw"
180
+ },
181
+ "source": [
182
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {
189
+ "id": "sOtoOmYsSYPz"
190
+ },
191
+ "outputs": [],
192
+ "source": [
193
+ "from functools import partial\n",
194
+ "\n",
195
+ "# model inference\n",
196
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
197
+ "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
198
+ " return model.generate(\n",
199
+ " **tokenized_prompt,\n",
200
+ " do_sample=True,\n",
201
+ " num_beams=1,\n",
202
+ " prng_key=key,\n",
203
+ " params=params,\n",
204
+ " top_k=top_k,\n",
205
+ " top_p=top_p,\n",
206
+ " max_length=257\n",
207
+ " )\n",
208
+ "\n",
209
+ "\n",
210
+ "# decode images\n",
211
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
212
+ "def p_decode(indices, params):\n",
213
+ " return vqgan.decode_code(indices, params=params)\n",
214
+ "\n",
215
+ "\n",
216
+ "# score images\n",
217
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
218
+ "def p_clip(inputs, params):\n",
219
+ " logits = clip(params=params, **inputs).logits_per_image\n",
220
+ " return logits"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "markdown",
225
+ "metadata": {
226
+ "id": "HmVN6IBwapBA"
227
+ },
228
+ "source": [
229
+ "Keys are passed to the model on each device to generate unique inference per device."
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {
236
+ "id": "4CTXmlUkThhX"
237
+ },
238
+ "outputs": [],
239
+ "source": [
240
+ "import random\n",
241
+ "\n",
242
+ "# create a random key\n",
243
+ "seed = random.randint(0, 2 ** 32 - 1)\n",
244
+ "key = jax.random.PRNGKey(seed)"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {
250
+ "id": "BrnVyCo81pij"
251
+ },
252
+ "source": [
253
+ "## 🖍 Text Prompt"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "markdown",
258
+ "metadata": {
259
+ "id": "rsmj0Aj5OQox"
260
+ },
261
+ "source": [
262
+ "Our model may require to normalize the prompt."
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "metadata": {
269
+ "id": "YjjhUychOVxm"
270
+ },
271
+ "outputs": [],
272
+ "source": [
273
+ "from dalle_mini.text import TextNormalizer\n",
274
+ "\n",
275
+ "text_normalizer = TextNormalizer() if model.config.normalize_text else None"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "markdown",
280
+ "metadata": {
281
+ "id": "BQ7fymSPyvF_"
282
+ },
283
+ "source": [
284
+ "Let's define a text prompt."
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "metadata": {
291
+ "id": "x_0vI9ge1oKr"
292
+ },
293
+ "outputs": [],
294
+ "source": [
295
+ "prompt = \"a blue table\""
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "metadata": {
302
+ "id": "VKjEZGjtO49k"
303
+ },
304
+ "outputs": [],
305
+ "source": [
306
+ "processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
307
+ "processed_prompt"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "markdown",
312
+ "metadata": {
313
+ "id": "QUzYACWxOe5z"
314
+ },
315
+ "source": [
316
+ "We tokenize the prompt."
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": null,
322
+ "metadata": {
323
+ "id": "n8e7MvGwOe5z"
324
+ },
325
+ "outputs": [],
326
+ "source": [
327
+ "tokenized_prompt = tokenizer(\n",
328
+ " processed_prompt,\n",
329
+ " return_tensors=\"jax\",\n",
330
+ " padding=\"max_length\",\n",
331
+ " truncation=True,\n",
332
+ " max_length=128,\n",
333
+ ").data\n",
334
+ "tokenized_prompt"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "markdown",
339
+ "metadata": {
340
+ "id": "_Y5dqFj7prMQ"
341
+ },
342
+ "source": [
343
+ "Notes:\n",
344
+ "\n",
345
+ "* `0`: BOS, special token representing the beginning of a sequence\n",
346
+ "* `2`: EOS, special token representing the end of a sequence\n",
347
+ "* `1`: special token representing the padding of a sequence when requesting a specific length"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "markdown",
352
+ "metadata": {
353
+ "id": "-CEJBnuJOe5z"
354
+ },
355
+ "source": [
356
+ "Finally we replicate it onto each device."
357
+ ]
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "execution_count": null,
362
+ "metadata": {
363
+ "id": "lQePgju5Oe5z"
364
+ },
365
+ "outputs": [],
366
+ "source": [
367
+ "tokenized_prompt = replicate(tokenized_prompt)"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "markdown",
372
+ "metadata": {
373
+ "id": "phQ9bhjRkgAZ"
374
+ },
375
+ "source": [
376
+ "## 🎨 Generate images\n",
377
+ "\n",
378
+ "We generate images using dalle-mini model and decode them with the VQGAN."
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "metadata": {
385
+ "id": "d0wVkXpKqnHA"
386
+ },
387
+ "outputs": [],
388
+ "source": [
389
+ "# number of predictions\n",
390
+ "n_predictions = 32\n",
391
+ "\n",
392
+ "# We can customize top_k/top_p used for generating samples\n",
393
+ "gen_top_k = None\n",
394
+ "gen_top_p = None"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "metadata": {
401
+ "id": "SDjEx9JxR3v8"
402
+ },
403
+ "outputs": [],
404
+ "source": [
405
+ "from flax.training.common_utils import shard_prng_key\n",
406
+ "import numpy as np\n",
407
+ "from PIL import Image\n",
408
+ "from tqdm.notebook import trange\n",
409
+ "\n",
410
+ "# generate images\n",
411
+ "images = []\n",
412
+ "for i in trange(n_predictions // jax.device_count()):\n",
413
+ " # get a new key\n",
414
+ " key, subkey = jax.random.split(key)\n",
415
+ " # generate images\n",
416
+ " encoded_images = p_generate(\n",
417
+ " tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
418
+ " )\n",
419
+ " # remove BOS\n",
420
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
421
+ " # decode images\n",
422
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
423
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
424
+ " for img in decoded_images:\n",
425
+ " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "markdown",
430
+ "metadata": {
431
+ "id": "tw02wG9zGmyB"
432
+ },
433
+ "source": [
434
+ "Let's calculate their score with CLIP."
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": null,
440
+ "metadata": {
441
+ "id": "FoLXpjCmGpju"
442
+ },
443
+ "outputs": [],
444
+ "source": [
445
+ "from flax.training.common_utils import shard\n",
446
+ "\n",
447
+ "# get clip scores\n",
448
+ "clip_inputs = processor(\n",
449
+ " text=[prompt] * jax.device_count(),\n",
450
+ " images=images,\n",
451
+ " return_tensors=\"np\",\n",
452
+ " padding=\"max_length\",\n",
453
+ " max_length=77,\n",
454
+ " truncation=True,\n",
455
+ ").data\n",
456
+ "logits = p_clip(shard(clip_inputs), clip_params)\n",
457
+ "logits = logits.squeeze().flatten()"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "markdown",
462
+ "metadata": {
463
+ "id": "4AAWRm70LgED"
464
+ },
465
+ "source": [
466
+ "Let's display images ranked by CLIP score."
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "code",
471
+ "execution_count": null,
472
+ "metadata": {
473
+ "id": "zsgxxubLLkIu"
474
+ },
475
+ "outputs": [],
476
+ "source": [
477
+ "print(f\"Prompt: {prompt}\\n\")\n",
478
+ "for idx in logits.argsort()[::-1]:\n",
479
+ " display(images[idx])\n",
480
+ " print(f\"Score: {logits[idx]:.2f}\\n\")"
481
+ ]
482
+ }
483
+ ],
484
+ "metadata": {
485
+ "accelerator": "GPU",
486
+ "colab": {
487
+ "collapsed_sections": [],
488
+ "include_colab_link": true,
489
+ "machine_shape": "hm",
490
+ "name": "DALL·E mini - Inference pipeline.ipynb",
491
+ "provenance": []
492
+ },
493
+ "kernelspec": {
494
+ "display_name": "Python 3 (ipykernel)",
495
+ "language": "python",
496
+ "name": "python3"
497
+ },
498
+ "language_info": {
499
+ "codemirror_mode": {
500
+ "name": "ipython",
501
+ "version": 3
502
+ },
503
+ "file_extension": ".py",
504
+ "mimetype": "text/x-python",
505
+ "name": "python",
506
+ "nbconvert_exporter": "python",
507
+ "pygments_lexer": "ipython3",
508
+ "version": "3.9.7"
509
+ }
510
+ },
511
+ "nbformat": 4,
512
+ "nbformat_minor": 0
513
+ }