fix(colab): use correct param name for CLIP
Browse files- tools/inference/inference_pipeline.ipynb +466 -849
tools/inference/inference_pipeline.ipynb
CHANGED
@@ -1,865 +1,482 @@
|
|
1 |
{
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
"\n",
|
13 |
-
"<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
|
14 |
-
"\n",
|
15 |
-
"This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
|
16 |
-
"\n",
|
17 |
-
"Just want to play? Use directly [DALL·E mini app](https://huggingface.co/spaces/dalle-mini/dalle-mini).\n",
|
18 |
-
"\n",
|
19 |
-
"For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
|
20 |
-
]
|
21 |
-
},
|
22 |
-
{
|
23 |
-
"cell_type": "markdown",
|
24 |
-
"metadata": {
|
25 |
-
"id": "dS8LbaonYm3a"
|
26 |
-
},
|
27 |
-
"source": [
|
28 |
-
"## 🛠️ Installation and set-up"
|
29 |
-
]
|
30 |
-
},
|
31 |
-
{
|
32 |
-
"cell_type": "code",
|
33 |
-
"execution_count": null,
|
34 |
-
"metadata": {
|
35 |
-
"id": "uzjAM2GBYpZX"
|
36 |
-
},
|
37 |
-
"outputs": [],
|
38 |
-
"source": [
|
39 |
-
"# Install required libraries\n",
|
40 |
-
"!pip install -q git+https://github.com/huggingface/transformers.git\n",
|
41 |
-
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
|
42 |
-
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "markdown",
|
47 |
-
"metadata": {
|
48 |
-
"id": "ozHzTkyv8cqU"
|
49 |
-
},
|
50 |
-
"source": [
|
51 |
-
"We load required models:\n",
|
52 |
-
"* DALL·E mini for text to encoded images\n",
|
53 |
-
"* VQGAN for decoding images\n",
|
54 |
-
"* CLIP for scoring predictions"
|
55 |
-
]
|
56 |
-
},
|
57 |
-
{
|
58 |
-
"cell_type": "code",
|
59 |
-
"execution_count": null,
|
60 |
-
"metadata": {
|
61 |
-
"id": "K6CxW2o42f-w"
|
62 |
-
},
|
63 |
-
"outputs": [],
|
64 |
-
"source": [
|
65 |
-
"# Model references\n",
|
66 |
-
"\n",
|
67 |
-
"# dalle-mega\n",
|
68 |
-
"DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
|
69 |
-
"DALLE_COMMIT_ID = None\n",
|
70 |
-
"\n",
|
71 |
-
"# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
|
72 |
-
"# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
|
73 |
-
"\n",
|
74 |
-
"# VQGAN model\n",
|
75 |
-
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
|
76 |
-
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
|
77 |
-
]
|
78 |
-
},
|
79 |
-
{
|
80 |
-
"cell_type": "code",
|
81 |
-
"execution_count": null,
|
82 |
-
"metadata": {
|
83 |
-
"colab": {
|
84 |
-
"base_uri": "https://localhost:8080/"
|
85 |
},
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
},
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
|
114 |
-
"\n",
|
115 |
-
"# Load dalle-mini\n",
|
116 |
-
"model, params = DalleBart.from_pretrained(\n",
|
117 |
-
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
|
118 |
-
")\n",
|
119 |
-
"\n",
|
120 |
-
"# Load VQGAN\n",
|
121 |
-
"vqgan, vqgan_params = VQModel.from_pretrained(\n",
|
122 |
-
" VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
|
123 |
-
")"
|
124 |
-
]
|
125 |
-
},
|
126 |
-
{
|
127 |
-
"cell_type": "markdown",
|
128 |
-
"metadata": {
|
129 |
-
"id": "o_vH2X1tDtzA"
|
130 |
-
},
|
131 |
-
"source": [
|
132 |
-
"Model parameters are replicated on each device for faster inference."
|
133 |
-
]
|
134 |
-
},
|
135 |
-
{
|
136 |
-
"cell_type": "code",
|
137 |
-
"execution_count": null,
|
138 |
-
"metadata": {
|
139 |
-
"id": "wtvLoM48EeVw"
|
140 |
-
},
|
141 |
-
"outputs": [],
|
142 |
-
"source": [
|
143 |
-
"from flax.jax_utils import replicate\n",
|
144 |
-
"\n",
|
145 |
-
"params = replicate(params)\n",
|
146 |
-
"vqgan_params = replicate(vqgan_params)"
|
147 |
-
]
|
148 |
-
},
|
149 |
-
{
|
150 |
-
"cell_type": "markdown",
|
151 |
-
"metadata": {
|
152 |
-
"id": "0A9AHQIgZ_qw"
|
153 |
-
},
|
154 |
-
"source": [
|
155 |
-
"Model functions are compiled and parallelized to take advantage of multiple devices."
|
156 |
-
]
|
157 |
-
},
|
158 |
-
{
|
159 |
-
"cell_type": "code",
|
160 |
-
"execution_count": null,
|
161 |
-
"metadata": {
|
162 |
-
"id": "sOtoOmYsSYPz"
|
163 |
-
},
|
164 |
-
"outputs": [],
|
165 |
-
"source": [
|
166 |
-
"from functools import partial\n",
|
167 |
-
"\n",
|
168 |
-
"# model inference\n",
|
169 |
-
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
|
170 |
-
"def p_generate(\n",
|
171 |
-
" tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
|
172 |
-
"):\n",
|
173 |
-
" return model.generate(\n",
|
174 |
-
" **tokenized_prompt,\n",
|
175 |
-
" prng_key=key,\n",
|
176 |
-
" params=params,\n",
|
177 |
-
" top_k=top_k,\n",
|
178 |
-
" top_p=top_p,\n",
|
179 |
-
" temperature=temperature,\n",
|
180 |
-
" condition_scale=condition_scale,\n",
|
181 |
-
" )\n",
|
182 |
-
"\n",
|
183 |
-
"\n",
|
184 |
-
"# decode image\n",
|
185 |
-
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
186 |
-
"def p_decode(indices, params):\n",
|
187 |
-
" return vqgan.decode_code(indices, params=params)"
|
188 |
-
]
|
189 |
-
},
|
190 |
-
{
|
191 |
-
"cell_type": "markdown",
|
192 |
-
"metadata": {
|
193 |
-
"id": "HmVN6IBwapBA"
|
194 |
-
},
|
195 |
-
"source": [
|
196 |
-
"Keys are passed to the model on each device to generate unique inference per device."
|
197 |
-
]
|
198 |
-
},
|
199 |
-
{
|
200 |
-
"cell_type": "code",
|
201 |
-
"execution_count": null,
|
202 |
-
"metadata": {
|
203 |
-
"id": "4CTXmlUkThhX"
|
204 |
-
},
|
205 |
-
"outputs": [],
|
206 |
-
"source": [
|
207 |
-
"import random\n",
|
208 |
-
"\n",
|
209 |
-
"# create a random key\n",
|
210 |
-
"seed = random.randint(0, 2**32 - 1)\n",
|
211 |
-
"key = jax.random.PRNGKey(seed)"
|
212 |
-
]
|
213 |
-
},
|
214 |
-
{
|
215 |
-
"cell_type": "markdown",
|
216 |
-
"metadata": {
|
217 |
-
"id": "BrnVyCo81pij"
|
218 |
-
},
|
219 |
-
"source": [
|
220 |
-
"## 🖍 Text Prompt"
|
221 |
-
]
|
222 |
-
},
|
223 |
-
{
|
224 |
-
"cell_type": "markdown",
|
225 |
-
"metadata": {
|
226 |
-
"id": "rsmj0Aj5OQox"
|
227 |
-
},
|
228 |
-
"source": [
|
229 |
-
"Our model requires processing prompts."
|
230 |
-
]
|
231 |
-
},
|
232 |
-
{
|
233 |
-
"cell_type": "code",
|
234 |
-
"execution_count": null,
|
235 |
-
"metadata": {
|
236 |
-
"colab": {
|
237 |
-
"base_uri": "https://localhost:8080/"
|
238 |
},
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
"id": "BQ7fymSPyvF_"
|
253 |
-
},
|
254 |
-
"source": [
|
255 |
-
"Let's define a text prompt."
|
256 |
-
]
|
257 |
-
},
|
258 |
-
{
|
259 |
-
"cell_type": "code",
|
260 |
-
"execution_count": null,
|
261 |
-
"metadata": {
|
262 |
-
"id": "x_0vI9ge1oKr"
|
263 |
-
},
|
264 |
-
"outputs": [],
|
265 |
-
"source": [
|
266 |
-
"prompt = \"sunset over a lake in the mountains\""
|
267 |
-
]
|
268 |
-
},
|
269 |
-
{
|
270 |
-
"cell_type": "code",
|
271 |
-
"execution_count": null,
|
272 |
-
"metadata": {
|
273 |
-
"id": "VKjEZGjtO49k"
|
274 |
-
},
|
275 |
-
"outputs": [],
|
276 |
-
"source": [
|
277 |
-
"tokenized_prompt = processor([prompt])"
|
278 |
-
]
|
279 |
-
},
|
280 |
-
{
|
281 |
-
"cell_type": "markdown",
|
282 |
-
"metadata": {
|
283 |
-
"id": "-CEJBnuJOe5z"
|
284 |
-
},
|
285 |
-
"source": [
|
286 |
-
"Finally we replicate it onto each device."
|
287 |
-
]
|
288 |
-
},
|
289 |
-
{
|
290 |
-
"cell_type": "code",
|
291 |
-
"execution_count": null,
|
292 |
-
"metadata": {
|
293 |
-
"id": "lQePgju5Oe5z"
|
294 |
-
},
|
295 |
-
"outputs": [],
|
296 |
-
"source": [
|
297 |
-
"tokenized_prompt = replicate(tokenized_prompt)"
|
298 |
-
]
|
299 |
-
},
|
300 |
-
{
|
301 |
-
"cell_type": "markdown",
|
302 |
-
"metadata": {
|
303 |
-
"id": "phQ9bhjRkgAZ"
|
304 |
-
},
|
305 |
-
"source": [
|
306 |
-
"## 🎨 Generate images\n",
|
307 |
-
"\n",
|
308 |
-
"We generate images using dalle-mini model and decode them with the VQGAN."
|
309 |
-
]
|
310 |
-
},
|
311 |
-
{
|
312 |
-
"cell_type": "code",
|
313 |
-
"execution_count": null,
|
314 |
-
"metadata": {
|
315 |
-
"id": "d0wVkXpKqnHA"
|
316 |
-
},
|
317 |
-
"outputs": [],
|
318 |
-
"source": [
|
319 |
-
"# number of predictions\n",
|
320 |
-
"n_predictions = 8\n",
|
321 |
-
"\n",
|
322 |
-
"# We can customize generation parameters\n",
|
323 |
-
"gen_top_k = None\n",
|
324 |
-
"gen_top_p = None\n",
|
325 |
-
"temperature = None\n",
|
326 |
-
"cond_scale = 3.0"
|
327 |
-
]
|
328 |
-
},
|
329 |
-
{
|
330 |
-
"cell_type": "code",
|
331 |
-
"execution_count": null,
|
332 |
-
"metadata": {
|
333 |
-
"colab": {
|
334 |
-
"base_uri": "https://localhost:8080/",
|
335 |
-
"height": 1000,
|
336 |
-
"referenced_widgets": [
|
337 |
-
"cef76449b8d74217ae36c56be3990eec",
|
338 |
-
"7be07ba7cfe642a596509c756dcefddc",
|
339 |
-
"2a02378499fc414299f17a2d5dcac867",
|
340 |
-
"427d47d9423441d286ae80a637ae35a0",
|
341 |
-
"cb157fd4e37041d1beae29eaa729c8ff",
|
342 |
-
"73413668398b45dfa8484a2c2be778ec",
|
343 |
-
"e7d108a4b168442fb2048f58ddeb0a18",
|
344 |
-
"5e81a141422f432395055f5cafb07016",
|
345 |
-
"5f476a929da84fa985b2e980459da7b9",
|
346 |
-
"f3b643a0ca2444fd959fff9b45d79d27",
|
347 |
-
"82b87345233549d699ce3fd8080fa988"
|
348 |
-
]
|
349 |
},
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
" display(images[idx])\n",
|
487 |
-
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
488 |
-
]
|
489 |
-
}
|
490 |
-
],
|
491 |
-
"metadata": {
|
492 |
-
"accelerator": "GPU",
|
493 |
-
"colab": {
|
494 |
-
"collapsed_sections": [],
|
495 |
-
"machine_shape": "hm",
|
496 |
-
"name": "DALL·E mini - Inference pipeline.ipynb",
|
497 |
-
"provenance": []
|
498 |
-
},
|
499 |
-
"kernelspec": {
|
500 |
-
"display_name": "Python 3 (ipykernel)",
|
501 |
-
"language": "python",
|
502 |
-
"name": "python3"
|
503 |
-
},
|
504 |
-
"language_info": {
|
505 |
-
"codemirror_mode": {
|
506 |
-
"name": "ipython",
|
507 |
-
"version": 3
|
508 |
-
},
|
509 |
-
"file_extension": ".py",
|
510 |
-
"mimetype": "text/x-python",
|
511 |
-
"name": "python",
|
512 |
-
"nbconvert_exporter": "python",
|
513 |
-
"pygments_lexer": "ipython3",
|
514 |
-
"version": "3.9.7"
|
515 |
-
},
|
516 |
-
"widgets": {
|
517 |
-
"application/vnd.jupyter.widget-state+json": {
|
518 |
-
"2a02378499fc414299f17a2d5dcac867": {
|
519 |
-
"model_module": "@jupyter-widgets/controls",
|
520 |
-
"model_module_version": "1.5.0",
|
521 |
-
"model_name": "FloatProgressModel",
|
522 |
-
"state": {
|
523 |
-
"_dom_classes": [],
|
524 |
-
"_model_module": "@jupyter-widgets/controls",
|
525 |
-
"_model_module_version": "1.5.0",
|
526 |
-
"_model_name": "FloatProgressModel",
|
527 |
-
"_view_count": null,
|
528 |
-
"_view_module": "@jupyter-widgets/controls",
|
529 |
-
"_view_module_version": "1.5.0",
|
530 |
-
"_view_name": "ProgressView",
|
531 |
-
"bar_style": "",
|
532 |
-
"description": "",
|
533 |
-
"description_tooltip": null,
|
534 |
-
"layout": "IPY_MODEL_5e81a141422f432395055f5cafb07016",
|
535 |
-
"max": 8,
|
536 |
-
"min": 0,
|
537 |
-
"orientation": "horizontal",
|
538 |
-
"style": "IPY_MODEL_5f476a929da84fa985b2e980459da7b9",
|
539 |
-
"value": 5
|
540 |
-
}
|
541 |
},
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
"
|
548 |
-
|
549 |
-
|
550 |
-
"_model_name": "HTMLModel",
|
551 |
-
"_view_count": null,
|
552 |
-
"_view_module": "@jupyter-widgets/controls",
|
553 |
-
"_view_module_version": "1.5.0",
|
554 |
-
"_view_name": "HTMLView",
|
555 |
-
"description": "",
|
556 |
-
"description_tooltip": null,
|
557 |
-
"layout": "IPY_MODEL_f3b643a0ca2444fd959fff9b45d79d27",
|
558 |
-
"placeholder": "",
|
559 |
-
"style": "IPY_MODEL_82b87345233549d699ce3fd8080fa988",
|
560 |
-
"value": " 5/8 [04:25<02:39, 53.09s/it]"
|
561 |
-
}
|
562 |
},
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
"
|
570 |
-
"
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
"align_self": null,
|
578 |
-
"border": null,
|
579 |
-
"bottom": null,
|
580 |
-
"display": null,
|
581 |
-
"flex": null,
|
582 |
-
"flex_flow": null,
|
583 |
-
"grid_area": null,
|
584 |
-
"grid_auto_columns": null,
|
585 |
-
"grid_auto_flow": null,
|
586 |
-
"grid_auto_rows": null,
|
587 |
-
"grid_column": null,
|
588 |
-
"grid_gap": null,
|
589 |
-
"grid_row": null,
|
590 |
-
"grid_template_areas": null,
|
591 |
-
"grid_template_columns": null,
|
592 |
-
"grid_template_rows": null,
|
593 |
-
"height": null,
|
594 |
-
"justify_content": null,
|
595 |
-
"justify_items": null,
|
596 |
-
"left": null,
|
597 |
-
"margin": null,
|
598 |
-
"max_height": null,
|
599 |
-
"max_width": null,
|
600 |
-
"min_height": null,
|
601 |
-
"min_width": null,
|
602 |
-
"object_fit": null,
|
603 |
-
"object_position": null,
|
604 |
-
"order": null,
|
605 |
-
"overflow": null,
|
606 |
-
"overflow_x": null,
|
607 |
-
"overflow_y": null,
|
608 |
-
"padding": null,
|
609 |
-
"right": null,
|
610 |
-
"top": null,
|
611 |
-
"visibility": null,
|
612 |
-
"width": null
|
613 |
-
}
|
614 |
},
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
"
|
621 |
-
|
622 |
-
|
623 |
-
"_view_count": null,
|
624 |
-
"_view_module": "@jupyter-widgets/base",
|
625 |
-
"_view_module_version": "1.2.0",
|
626 |
-
"_view_name": "StyleView",
|
627 |
-
"bar_color": null,
|
628 |
-
"description_width": ""
|
629 |
-
}
|
630 |
},
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
"
|
637 |
-
|
638 |
-
|
639 |
-
"_view_count": null,
|
640 |
-
"_view_module": "@jupyter-widgets/base",
|
641 |
-
"_view_module_version": "1.2.0",
|
642 |
-
"_view_name": "LayoutView",
|
643 |
-
"align_content": null,
|
644 |
-
"align_items": null,
|
645 |
-
"align_self": null,
|
646 |
-
"border": null,
|
647 |
-
"bottom": null,
|
648 |
-
"display": null,
|
649 |
-
"flex": null,
|
650 |
-
"flex_flow": null,
|
651 |
-
"grid_area": null,
|
652 |
-
"grid_auto_columns": null,
|
653 |
-
"grid_auto_flow": null,
|
654 |
-
"grid_auto_rows": null,
|
655 |
-
"grid_column": null,
|
656 |
-
"grid_gap": null,
|
657 |
-
"grid_row": null,
|
658 |
-
"grid_template_areas": null,
|
659 |
-
"grid_template_columns": null,
|
660 |
-
"grid_template_rows": null,
|
661 |
-
"height": null,
|
662 |
-
"justify_content": null,
|
663 |
-
"justify_items": null,
|
664 |
-
"left": null,
|
665 |
-
"margin": null,
|
666 |
-
"max_height": null,
|
667 |
-
"max_width": null,
|
668 |
-
"min_height": null,
|
669 |
-
"min_width": null,
|
670 |
-
"object_fit": null,
|
671 |
-
"object_position": null,
|
672 |
-
"order": null,
|
673 |
-
"overflow": null,
|
674 |
-
"overflow_x": null,
|
675 |
-
"overflow_y": null,
|
676 |
-
"padding": null,
|
677 |
-
"right": null,
|
678 |
-
"top": null,
|
679 |
-
"visibility": null,
|
680 |
-
"width": null
|
681 |
-
}
|
682 |
},
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
"
|
690 |
-
"
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
"_view_name": "HTMLView",
|
696 |
-
"description": "",
|
697 |
-
"description_tooltip": null,
|
698 |
-
"layout": "IPY_MODEL_73413668398b45dfa8484a2c2be778ec",
|
699 |
-
"placeholder": "",
|
700 |
-
"style": "IPY_MODEL_e7d108a4b168442fb2048f58ddeb0a18",
|
701 |
-
"value": " 62%"
|
702 |
-
}
|
703 |
},
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
"
|
710 |
-
|
711 |
-
|
712 |
-
"_view_count": null,
|
713 |
-
"_view_module": "@jupyter-widgets/base",
|
714 |
-
"_view_module_version": "1.2.0",
|
715 |
-
"_view_name": "StyleView",
|
716 |
-
"description_width": ""
|
717 |
-
}
|
718 |
},
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
"
|
726 |
-
"
|
727 |
-
|
728 |
-
|
729 |
-
"_view_module_version": "1.2.0",
|
730 |
-
"_view_name": "LayoutView",
|
731 |
-
"align_content": null,
|
732 |
-
"align_items": null,
|
733 |
-
"align_self": null,
|
734 |
-
"border": null,
|
735 |
-
"bottom": null,
|
736 |
-
"display": null,
|
737 |
-
"flex": null,
|
738 |
-
"flex_flow": null,
|
739 |
-
"grid_area": null,
|
740 |
-
"grid_auto_columns": null,
|
741 |
-
"grid_auto_flow": null,
|
742 |
-
"grid_auto_rows": null,
|
743 |
-
"grid_column": null,
|
744 |
-
"grid_gap": null,
|
745 |
-
"grid_row": null,
|
746 |
-
"grid_template_areas": null,
|
747 |
-
"grid_template_columns": null,
|
748 |
-
"grid_template_rows": null,
|
749 |
-
"height": null,
|
750 |
-
"justify_content": null,
|
751 |
-
"justify_items": null,
|
752 |
-
"left": null,
|
753 |
-
"margin": null,
|
754 |
-
"max_height": null,
|
755 |
-
"max_width": null,
|
756 |
-
"min_height": null,
|
757 |
-
"min_width": null,
|
758 |
-
"object_fit": null,
|
759 |
-
"object_position": null,
|
760 |
-
"order": null,
|
761 |
-
"overflow": null,
|
762 |
-
"overflow_x": null,
|
763 |
-
"overflow_y": null,
|
764 |
-
"padding": null,
|
765 |
-
"right": null,
|
766 |
-
"top": null,
|
767 |
-
"visibility": null,
|
768 |
-
"width": null
|
769 |
-
}
|
770 |
},
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
"
|
778 |
-
"
|
779 |
-
|
780 |
-
|
781 |
-
"_view_module": "@jupyter-widgets/controls",
|
782 |
-
"_view_module_version": "1.5.0",
|
783 |
-
"_view_name": "HBoxView",
|
784 |
-
"box_style": "",
|
785 |
-
"children": [
|
786 |
-
"IPY_MODEL_7be07ba7cfe642a596509c756dcefddc",
|
787 |
-
"IPY_MODEL_2a02378499fc414299f17a2d5dcac867",
|
788 |
-
"IPY_MODEL_427d47d9423441d286ae80a637ae35a0"
|
789 |
-
],
|
790 |
-
"layout": "IPY_MODEL_cb157fd4e37041d1beae29eaa729c8ff"
|
791 |
-
}
|
792 |
},
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
"
|
799 |
-
|
800 |
-
|
801 |
-
"_view_count": null,
|
802 |
-
"_view_module": "@jupyter-widgets/base",
|
803 |
-
"_view_module_version": "1.2.0",
|
804 |
-
"_view_name": "StyleView",
|
805 |
-
"description_width": ""
|
806 |
-
}
|
807 |
},
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
"
|
815 |
-
"
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
"
|
821 |
-
"
|
822 |
-
|
823 |
-
|
824 |
-
"
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
"
|
832 |
-
"
|
833 |
-
"
|
834 |
-
|
835 |
-
|
836 |
-
"
|
837 |
-
"
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
"
|
850 |
-
"
|
851 |
-
"
|
852 |
-
|
853 |
-
|
854 |
-
"
|
855 |
-
"
|
856 |
-
|
857 |
-
|
858 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
859 |
}
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
"nbformat_minor": 0
|
865 |
-
}
|
|
|
1 |
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "view-in-github",
|
7 |
+
"colab_type": "text"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"id": "118UKH5bWCGa"
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"# DALL·E mini - Inference pipeline\n",
|
20 |
+
"\n",
|
21 |
+
"*Generate images from a text prompt*\n",
|
22 |
+
"\n",
|
23 |
+
"<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
|
24 |
+
"\n",
|
25 |
+
"This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
|
26 |
+
"\n",
|
27 |
+
"Just want to play? Use directly [DALL·E mini app](https://huggingface.co/spaces/dalle-mini/dalle-mini).\n",
|
28 |
+
"\n",
|
29 |
+
"For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
|
30 |
+
]
|
31 |
},
|
32 |
+
{
|
33 |
+
"cell_type": "markdown",
|
34 |
+
"metadata": {
|
35 |
+
"id": "dS8LbaonYm3a"
|
36 |
+
},
|
37 |
+
"source": [
|
38 |
+
"## 🛠️ Installation and set-up"
|
39 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"metadata": {
|
45 |
+
"id": "uzjAM2GBYpZX"
|
46 |
+
},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"# Install required libraries\n",
|
50 |
+
"!pip install -q git+https://github.com/huggingface/transformers.git\n",
|
51 |
+
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
|
52 |
+
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
|
53 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
},
|
55 |
+
{
|
56 |
+
"cell_type": "markdown",
|
57 |
+
"metadata": {
|
58 |
+
"id": "ozHzTkyv8cqU"
|
59 |
+
},
|
60 |
+
"source": [
|
61 |
+
"We load required models:\n",
|
62 |
+
"* DALL·E mini for text to encoded images\n",
|
63 |
+
"* VQGAN for decoding images\n",
|
64 |
+
"* CLIP for scoring predictions"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": null,
|
70 |
+
"metadata": {
|
71 |
+
"id": "K6CxW2o42f-w"
|
72 |
+
},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"# Model references\n",
|
76 |
+
"\n",
|
77 |
+
"# dalle-mega\n",
|
78 |
+
"DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
|
79 |
+
"DALLE_COMMIT_ID = None\n",
|
80 |
+
"\n",
|
81 |
+
"# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
|
82 |
+
"# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
|
83 |
+
"\n",
|
84 |
+
"# VQGAN model\n",
|
85 |
+
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
|
86 |
+
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": null,
|
92 |
+
"metadata": {
|
93 |
+
"id": "Yv-aR3t4Oe5v"
|
94 |
+
},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"import jax\n",
|
98 |
+
"import jax.numpy as jnp\n",
|
99 |
+
"\n",
|
100 |
+
"# check how many devices are available\n",
|
101 |
+
"jax.local_device_count()"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": null,
|
107 |
+
"metadata": {
|
108 |
+
"id": "92zYmvsQ38vL"
|
109 |
+
},
|
110 |
+
"outputs": [],
|
111 |
+
"source": [
|
112 |
+
"# Load models & tokenizer\n",
|
113 |
+
"from dalle_mini import DalleBart, DalleBartProcessor\n",
|
114 |
+
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
115 |
+
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
|
116 |
+
"\n",
|
117 |
+
"# Load dalle-mini\n",
|
118 |
+
"model, params = DalleBart.from_pretrained(\n",
|
119 |
+
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
|
120 |
+
")\n",
|
121 |
+
"\n",
|
122 |
+
"# Load VQGAN\n",
|
123 |
+
"vqgan, vqgan_params = VQModel.from_pretrained(\n",
|
124 |
+
" VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
|
125 |
+
")"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "markdown",
|
130 |
+
"metadata": {
|
131 |
+
"id": "o_vH2X1tDtzA"
|
132 |
+
},
|
133 |
+
"source": [
|
134 |
+
"Model parameters are replicated on each device for faster inference."
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"cell_type": "code",
|
139 |
+
"execution_count": null,
|
140 |
+
"metadata": {
|
141 |
+
"id": "wtvLoM48EeVw"
|
142 |
+
},
|
143 |
+
"outputs": [],
|
144 |
+
"source": [
|
145 |
+
"from flax.jax_utils import replicate\n",
|
146 |
+
"\n",
|
147 |
+
"params = replicate(params)\n",
|
148 |
+
"vqgan_params = replicate(vqgan_params)"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "markdown",
|
153 |
+
"metadata": {
|
154 |
+
"id": "0A9AHQIgZ_qw"
|
155 |
+
},
|
156 |
+
"source": [
|
157 |
+
"Model functions are compiled and parallelized to take advantage of multiple devices."
|
158 |
+
]
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"cell_type": "code",
|
162 |
+
"execution_count": null,
|
163 |
+
"metadata": {
|
164 |
+
"id": "sOtoOmYsSYPz"
|
165 |
+
},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"from functools import partial\n",
|
169 |
+
"\n",
|
170 |
+
"# model inference\n",
|
171 |
+
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
|
172 |
+
"def p_generate(\n",
|
173 |
+
" tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
|
174 |
+
"):\n",
|
175 |
+
" return model.generate(\n",
|
176 |
+
" **tokenized_prompt,\n",
|
177 |
+
" prng_key=key,\n",
|
178 |
+
" params=params,\n",
|
179 |
+
" top_k=top_k,\n",
|
180 |
+
" top_p=top_p,\n",
|
181 |
+
" temperature=temperature,\n",
|
182 |
+
" condition_scale=condition_scale,\n",
|
183 |
+
" )\n",
|
184 |
+
"\n",
|
185 |
+
"\n",
|
186 |
+
"# decode image\n",
|
187 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
188 |
+
"def p_decode(indices, params):\n",
|
189 |
+
" return vqgan.decode_code(indices, params=params)"
|
190 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
},
|
192 |
+
{
|
193 |
+
"cell_type": "markdown",
|
194 |
+
"metadata": {
|
195 |
+
"id": "HmVN6IBwapBA"
|
196 |
+
},
|
197 |
+
"source": [
|
198 |
+
"Keys are passed to the model on each device to generate unique inference per device."
|
199 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": null,
|
204 |
+
"metadata": {
|
205 |
+
"id": "4CTXmlUkThhX"
|
206 |
+
},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"import random\n",
|
210 |
+
"\n",
|
211 |
+
"# create a random key\n",
|
212 |
+
"seed = random.randint(0, 2**32 - 1)\n",
|
213 |
+
"key = jax.random.PRNGKey(seed)"
|
214 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
},
|
216 |
+
{
|
217 |
+
"cell_type": "markdown",
|
218 |
+
"metadata": {
|
219 |
+
"id": "BrnVyCo81pij"
|
220 |
+
},
|
221 |
+
"source": [
|
222 |
+
"## 🖍 Text Prompt"
|
223 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
},
|
225 |
+
{
|
226 |
+
"cell_type": "markdown",
|
227 |
+
"metadata": {
|
228 |
+
"id": "rsmj0Aj5OQox"
|
229 |
+
},
|
230 |
+
"source": [
|
231 |
+
"Our model requires processing prompts."
|
232 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"execution_count": null,
|
237 |
+
"metadata": {
|
238 |
+
"id": "YjjhUychOVxm"
|
239 |
+
},
|
240 |
+
"outputs": [],
|
241 |
+
"source": [
|
242 |
+
"from dalle_mini import DalleBartProcessor\n",
|
243 |
+
"\n",
|
244 |
+
"processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
|
245 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
},
|
247 |
+
{
|
248 |
+
"cell_type": "markdown",
|
249 |
+
"metadata": {
|
250 |
+
"id": "BQ7fymSPyvF_"
|
251 |
+
},
|
252 |
+
"source": [
|
253 |
+
"Let's define a text prompt."
|
254 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
},
|
256 |
+
{
|
257 |
+
"cell_type": "code",
|
258 |
+
"execution_count": null,
|
259 |
+
"metadata": {
|
260 |
+
"id": "x_0vI9ge1oKr"
|
261 |
+
},
|
262 |
+
"outputs": [],
|
263 |
+
"source": [
|
264 |
+
"prompt = \"sunset over a lake in the mountains\""
|
265 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"execution_count": null,
|
270 |
+
"metadata": {
|
271 |
+
"id": "VKjEZGjtO49k"
|
272 |
+
},
|
273 |
+
"outputs": [],
|
274 |
+
"source": [
|
275 |
+
"tokenized_prompt = processor([prompt])"
|
276 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
},
|
278 |
+
{
|
279 |
+
"cell_type": "markdown",
|
280 |
+
"metadata": {
|
281 |
+
"id": "-CEJBnuJOe5z"
|
282 |
+
},
|
283 |
+
"source": [
|
284 |
+
"Finally we replicate it onto each device."
|
285 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
},
|
287 |
+
{
|
288 |
+
"cell_type": "code",
|
289 |
+
"execution_count": null,
|
290 |
+
"metadata": {
|
291 |
+
"id": "lQePgju5Oe5z"
|
292 |
+
},
|
293 |
+
"outputs": [],
|
294 |
+
"source": [
|
295 |
+
"tokenized_prompt = replicate(tokenized_prompt)"
|
296 |
+
]
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"cell_type": "markdown",
|
300 |
+
"metadata": {
|
301 |
+
"id": "phQ9bhjRkgAZ"
|
302 |
+
},
|
303 |
+
"source": [
|
304 |
+
"## 🎨 Generate images\n",
|
305 |
+
"\n",
|
306 |
+
"We generate images using dalle-mini model and decode them with the VQGAN."
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": null,
|
312 |
+
"metadata": {
|
313 |
+
"id": "d0wVkXpKqnHA"
|
314 |
+
},
|
315 |
+
"outputs": [],
|
316 |
+
"source": [
|
317 |
+
"# number of predictions\n",
|
318 |
+
"n_predictions = 8\n",
|
319 |
+
"\n",
|
320 |
+
"# We can customize generation parameters\n",
|
321 |
+
"gen_top_k = None\n",
|
322 |
+
"gen_top_p = None\n",
|
323 |
+
"temperature = None\n",
|
324 |
+
"cond_scale = 3.0"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
{
|
328 |
+
"cell_type": "code",
|
329 |
+
"execution_count": null,
|
330 |
+
"metadata": {
|
331 |
+
"id": "SDjEx9JxR3v8"
|
332 |
+
},
|
333 |
+
"outputs": [],
|
334 |
+
"source": [
|
335 |
+
"from flax.training.common_utils import shard_prng_key\n",
|
336 |
+
"import numpy as np\n",
|
337 |
+
"from PIL import Image\n",
|
338 |
+
"from tqdm.notebook import trange\n",
|
339 |
+
"\n",
|
340 |
+
"print(f\"Prompt: {prompt}\\n\")\n",
|
341 |
+
"# generate images\n",
|
342 |
+
"images = []\n",
|
343 |
+
"for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
|
344 |
+
" # get a new key\n",
|
345 |
+
" key, subkey = jax.random.split(key)\n",
|
346 |
+
" # generate images\n",
|
347 |
+
" encoded_images = p_generate(\n",
|
348 |
+
" tokenized_prompt,\n",
|
349 |
+
" shard_prng_key(subkey),\n",
|
350 |
+
" params,\n",
|
351 |
+
" gen_top_k,\n",
|
352 |
+
" gen_top_p,\n",
|
353 |
+
" temperature,\n",
|
354 |
+
" cond_scale,\n",
|
355 |
+
" )\n",
|
356 |
+
" # remove BOS\n",
|
357 |
+
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
358 |
+
" # decode images\n",
|
359 |
+
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
360 |
+
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
|
361 |
+
" for decoded_img in decoded_images:\n",
|
362 |
+
" img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
|
363 |
+
" images.append(img)\n",
|
364 |
+
" display(img)"
|
365 |
+
]
|
366 |
+
},
|
367 |
+
{
|
368 |
+
"cell_type": "markdown",
|
369 |
+
"metadata": {
|
370 |
+
"id": "tw02wG9zGmyB"
|
371 |
+
},
|
372 |
+
"source": [
|
373 |
+
"## 🏅 Optional: Rank images by CLIP score\n",
|
374 |
+
"\n",
|
375 |
+
"We can rank images according to CLIP.\n",
|
376 |
+
"\n",
|
377 |
+
"**Note: your session may crash if you don't have a subscription to Colab Pro.**"
|
378 |
+
]
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "code",
|
382 |
+
"execution_count": null,
|
383 |
+
"metadata": {
|
384 |
+
"id": "RGjlIW_f6GA0"
|
385 |
+
},
|
386 |
+
"outputs": [],
|
387 |
+
"source": [
|
388 |
+
"# CLIP model\n",
|
389 |
+
"CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
|
390 |
+
"CLIP_COMMIT_ID = None\n",
|
391 |
+
"\n",
|
392 |
+
"# Load CLIP\n",
|
393 |
+
"clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
|
394 |
+
" CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
|
395 |
+
")\n",
|
396 |
+
"clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
|
397 |
+
"clip_params = replicate(clip_params)\n",
|
398 |
+
"\n",
|
399 |
+
"# score images\n",
|
400 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
401 |
+
"def p_clip(inputs, params):\n",
|
402 |
+
" logits = clip(params=params, **inputs).logits_per_image\n",
|
403 |
+
" return logits"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "code",
|
408 |
+
"execution_count": null,
|
409 |
+
"metadata": {
|
410 |
+
"id": "FoLXpjCmGpju"
|
411 |
+
},
|
412 |
+
"outputs": [],
|
413 |
+
"source": [
|
414 |
+
"from flax.training.common_utils import shard\n",
|
415 |
+
"\n",
|
416 |
+
"# get clip scores\n",
|
417 |
+
"clip_inputs = clip_processor(\n",
|
418 |
+
" text=[prompt] * jax.device_count(),\n",
|
419 |
+
" images=images,\n",
|
420 |
+
" return_tensors=\"np\",\n",
|
421 |
+
" padding=\"max_length\",\n",
|
422 |
+
" max_length=77,\n",
|
423 |
+
" truncation=True,\n",
|
424 |
+
").data\n",
|
425 |
+
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
426 |
+
"logits = logits.squeeze().flatten()"
|
427 |
+
]
|
428 |
+
},
|
429 |
+
{
|
430 |
+
"cell_type": "markdown",
|
431 |
+
"metadata": {
|
432 |
+
"id": "4AAWRm70LgED"
|
433 |
+
},
|
434 |
+
"source": [
|
435 |
+
"Let's now display images ranked by CLIP score."
|
436 |
+
]
|
437 |
+
},
|
438 |
+
{
|
439 |
+
"cell_type": "code",
|
440 |
+
"execution_count": null,
|
441 |
+
"metadata": {
|
442 |
+
"id": "zsgxxubLLkIu"
|
443 |
+
},
|
444 |
+
"outputs": [],
|
445 |
+
"source": [
|
446 |
+
"print(f\"Prompt: {prompt}\\n\")\n",
|
447 |
+
"for idx in logits.argsort()[::-1]:\n",
|
448 |
+
" display(images[idx])\n",
|
449 |
+
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
450 |
+
]
|
451 |
+
}
|
452 |
+
],
|
453 |
+
"metadata": {
|
454 |
+
"accelerator": "GPU",
|
455 |
+
"colab": {
|
456 |
+
"collapsed_sections": [],
|
457 |
+
"machine_shape": "hm",
|
458 |
+
"name": "DALL·E mini - Inference pipeline.ipynb",
|
459 |
+
"provenance": [],
|
460 |
+
"include_colab_link": true
|
461 |
+
},
|
462 |
+
"kernelspec": {
|
463 |
+
"display_name": "Python 3 (ipykernel)",
|
464 |
+
"language": "python",
|
465 |
+
"name": "python3"
|
466 |
+
},
|
467 |
+
"language_info": {
|
468 |
+
"codemirror_mode": {
|
469 |
+
"name": "ipython",
|
470 |
+
"version": 3
|
471 |
+
},
|
472 |
+
"file_extension": ".py",
|
473 |
+
"mimetype": "text/x-python",
|
474 |
+
"name": "python",
|
475 |
+
"nbconvert_exporter": "python",
|
476 |
+
"pygments_lexer": "ipython3",
|
477 |
+
"version": "3.9.7"
|
478 |
}
|
479 |
+
},
|
480 |
+
"nbformat": 4,
|
481 |
+
"nbformat_minor": 0
|
482 |
+
}
|
|
|
|