boris committed on
Commit
5a390e8
1 Parent(s): 38c2c4e

feat: cleanup notebook

Files changed (1)
  1. tools/inference/inference_pipeline.ipynb +511 -517
tools/inference/inference_pipeline.ipynb CHANGED
@@ -1,521 +1,515 @@
  {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "118UKH5bWCGa"
- },
- "source": [
- "# DALL·E mini - Inference pipeline\n",
- "\n",
- "*Generate images from a text prompt*\n",
- "\n",
- "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
- "\n",
- "This notebook illustrates the [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
- "\n",
- "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
- "\n",
- "For a deeper understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dS8LbaonYm3a"
- },
- "source": [
- "## 🛠️ Installation and set-up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "uzjAM2GBYpZX",
- "outputId": "70550075-5204-4c56-dce4-4fff061a096c"
- },
- "outputs": [],
- "source": [
- "# Install required libraries\n",
- "!pip install -q transformers\n",
- "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
- "!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n",
- "!pip install -q wandb"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ozHzTkyv8cqU"
- },
- "source": [
- "We load the required models:\n",
- "* dalle·mini for text to encoded images\n",
- "* VQGAN for decoding images\n",
- "* CLIP for scoring predictions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "K6CxW2o42f-w"
- },
- "outputs": [],
- "source": [
- "# Model references\n",
- "\n",
- "# dalle-mini\n",
- "DALLE_MODEL = \"dalle-mini/dalle-mini/model-mehdx7dg:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
- "DALLE_COMMIT_ID = None\n",
- "\n",
- "# VQGAN model\n",
- "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
- "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
- "\n",
- "# CLIP model\n",
- "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
- "CLIP_COMMIT_ID = None"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Input \u001b[0;32mIn [3]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mjnp\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# check how many devices are available\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlocal_device_count\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:330\u001b[0m, in \u001b[0;36mlocal_device_count\u001b[0;34m(backend)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlocal_device_count\u001b[39m(backend: Optional[Union[\u001b[38;5;28mstr\u001b[39m, XlaBackend]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mint\u001b[39m:\n\u001b[1;32m 329\u001b[0m \u001b[38;5;124;03m\"\"\"Returns the number of devices addressable by this process.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 330\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mint\u001b[39m(\u001b[43mget_backend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlocal_device_count())\n",
- "File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:298\u001b[0m, in \u001b[0;36mget_backend\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;129m@lru_cache\u001b[39m(maxsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# don't use util.memoize because there is no X64 dependence.\u001b[39;00m\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_backend\u001b[39m(platform\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_get_backend_uncached\u001b[49m\u001b[43m(\u001b[49m\u001b[43mplatform\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:281\u001b[0m, in \u001b[0;36m_get_backend_uncached\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(platform, (\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m), \u001b[38;5;28mstr\u001b[39m)):\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m platform\n\u001b[0;32m--> 281\u001b[0m bs \u001b[38;5;241m=\u001b[39m \u001b[43mbackends\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 282\u001b[0m platform \u001b[38;5;241m=\u001b[39m (platform \u001b[38;5;129;01mor\u001b[39;00m FLAGS\u001b[38;5;241m.\u001b[39mjax_xla_backend \u001b[38;5;129;01mor\u001b[39;00m FLAGS\u001b[38;5;241m.\u001b[39mjax_platform_name\n\u001b[1;32m 283\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 284\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m platform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
- "File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:231\u001b[0m, in \u001b[0;36mbackends\u001b[0;34m()\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m platform, priority \u001b[38;5;129;01min\u001b[39;00m platforms_and_priorites:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 231\u001b[0m backend \u001b[38;5;241m=\u001b[39m \u001b[43m_init_backend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mplatform\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 232\u001b[0m _backends[platform] \u001b[38;5;241m=\u001b[39m backend\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m priority \u001b[38;5;241m>\u001b[39m default_priority:\n",
- "File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:260\u001b[0m, in \u001b[0;36m_init_backend\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown backend \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mplatform\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 259\u001b[0m logging\u001b[38;5;241m.\u001b[39mvlog(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInitializing backend \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m platform)\n\u001b[0;32m--> 260\u001b[0m backend \u001b[38;5;241m=\u001b[39m \u001b[43mfactory\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# TODO(skye): consider raising more descriptive errors directly from backend\u001b[39;00m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;66;03m# factories instead of returning None.\u001b[39;00m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m backend \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
- "File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:170\u001b[0m, in \u001b[0;36mtpu_client_timer_callback\u001b[0;34m(timer_secs)\u001b[0m\n\u001b[1;32m 167\u001b[0m t\u001b[38;5;241m.\u001b[39mstart()\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 170\u001b[0m client \u001b[38;5;241m=\u001b[39m \u001b[43mxla_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake_tpu_client\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 172\u001b[0m t\u001b[38;5;241m.\u001b[39mcancel()\n",
- "File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jaxlib/xla_client.py:96\u001b[0m, in \u001b[0;36mmake_tpu_client\u001b[0;34m()\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmake_tpu_client\u001b[39m():\n\u001b[0;32m---> 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_xla\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_tpu_client\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_inflight_computations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m)\u001b[49m\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
  }
- ],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "\n",
- "# check how many devices are available\n",
- "jax.local_device_count()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# type used for computation - use bfloat16 on TPUs\n",
- "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
- "\n",
- "# TODO: fix issue with bfloat16\n",
- "dtype = jnp.float32"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
  "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 374
- },
- "id": "92zYmvsQ38vL",
- "outputId": "909b0a3c-14cb-4722-8eb2-f876ff50257c"
- },
- "outputs": [],
- "source": [
- "# Load models & tokenizer\n",
- "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
- "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
- "from transformers import CLIPProcessor, FlaxCLIPModel\n",
- "import wandb\n",
- "\n",
- "# Load dalle-mini\n",
- "model = DalleBart.from_pretrained(\n",
- " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
- ")\n",
- "tokenizer = DalleBartTokenizer.from_pretrained(\n",
- " DALLE_MODEL, revision=DALLE_COMMIT_ID\n",
- ")\n",
- "\n",
- "# Load VQGAN\n",
- "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
- "\n",
- "# Load CLIP\n",
- "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
- "processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "o_vH2X1tDtzA"
- },
- "source": [
- "Model parameters are replicated on each device for faster inference."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "wtvLoM48EeVw"
- },
- "outputs": [],
- "source": [
- "from flax.jax_utils import replicate\n",
- "\n",
- "# convert model parameters for inference if requested\n",
- "if dtype == jnp.bfloat16:\n",
- " model.params = model.to_bf16(model.params)\n",
- "\n",
- "model_params = replicate(model.params)\n",
- "vqgan_params = replicate(vqgan.params)\n",
- "clip_params = replicate(clip.params)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0A9AHQIgZ_qw"
- },
- "source": [
- "Model functions are compiled and parallelized to take advantage of multiple devices."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sOtoOmYsSYPz"
- },
- "outputs": [],
- "source": [
- "from functools import partial\n",
- "\n",
- "# model inference\n",
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
- "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
- " return model.generate(\n",
- " **tokenized_prompt,\n",
- " do_sample=True,\n",
- " num_beams=1,\n",
- " prng_key=key,\n",
- " params=params,\n",
- " top_k=top_k,\n",
- " top_p=top_p,\n",
- " max_length=257\n",
- " )\n",
- "\n",
- "\n",
- "# decode images\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_decode(indices, params):\n",
- " return vqgan.decode_code(indices, params=params)\n",
- "\n",
- "\n",
- "# score images\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_clip(inputs, params):\n",
- " logits = clip(params=params, **inputs).logits_per_image\n",
- " return logits"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HmVN6IBwapBA"
- },
- "source": [
- "Keys are passed to the model on each device to generate unique inferences per device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4CTXmlUkThhX"
- },
- "outputs": [],
- "source": [
- "import random\n",
- "\n",
- "# create a random key\n",
- "seed = random.randint(0, 2**32 - 1)\n",
- "key = jax.random.PRNGKey(seed)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BrnVyCo81pij"
- },
- "source": [
- "## 🖍 Text Prompt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rsmj0Aj5OQox"
- },
- "source": [
- "Our model may require the prompt to be normalized."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "YjjhUychOVxm"
- },
- "outputs": [],
- "source": [
- "from dalle_mini.text import TextNormalizer\n",
- "\n",
- "text_normalizer = TextNormalizer() if model.config.normalize_text else None"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BQ7fymSPyvF_"
- },
- "source": [
- "Let's define a text prompt."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "x_0vI9ge1oKr"
- },
- "outputs": [],
- "source": [
- "prompt = \"a waterfall under the sunset\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "VKjEZGjtO49k"
- },
- "outputs": [],
- "source": [
- "processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
- "processed_prompt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We tokenize the prompt."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "tokenized_prompt = tokenizer(\n",
- " processed_prompt,\n",
- " return_tensors=\"jax\",\n",
- " padding=\"max_length\",\n",
- " truncation=True,\n",
- " max_length=128,\n",
- ").data\n",
- "tokenized_prompt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_Y5dqFj7prMQ"
- },
- "source": [
- "Notes:\n",
- "\n",
- "* `0`: BOS, special token representing the beginning of a sequence\n",
- "* `2`: EOS, special token representing the end of a sequence\n",
- "* `1`: special token representing the padding of a sequence when requesting a specific length"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Finally we replicate it onto each device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "tokenized_prompt = replicate(tokenized_prompt)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "phQ9bhjRkgAZ"
- },
- "source": [
- "## 🎨 Generate images\n",
- "\n",
- "We generate images using the dalle-mini model and decode them with the VQGAN."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "d0wVkXpKqnHA"
- },
- "outputs": [],
- "source": [
- "# number of predictions\n",
- "n_predictions = 32\n",
- "\n",
- "# We can customize top_k/top_p used for generating samples\n",
- "gen_top_k = None\n",
- "gen_top_p = None"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "SDjEx9JxR3v8"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard_prng_key\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "from tqdm.notebook import trange\n",
- "\n",
- "# generate images\n",
- "images = []\n",
- "for i in trange(n_predictions // jax.device_count()):\n",
- " # get a new key\n",
- " key, subkey = jax.random.split(key)\n",
- " # generate images\n",
- " encoded_images = p_generate(\n",
- " tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
- " )\n",
- " # remove BOS\n",
- " encoded_images = encoded_images.sequences[..., 1:]\n",
- " # decode images\n",
- " decoded_images = p_decode(encoded_images, vqgan_params)\n",
- " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
- " for img in decoded_images:\n",
- " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tw02wG9zGmyB"
- },
- "source": [
- "Let's calculate their score with CLIP."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "FoLXpjCmGpju"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard\n",
- "\n",
- "# get clip scores\n",
- "clip_inputs = processor(\n",
- " text=[prompt] * jax.device_count(),\n",
- " images=images,\n",
- " return_tensors=\"np\",\n",
- " padding=\"max_length\",\n",
- " max_length=77,\n",
- " truncation=True,\n",
- ").data\n",
- "logits = p_clip(shard(clip_inputs), clip_params)\n",
- "logits = logits.squeeze().flatten()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4AAWRm70LgED"
- },
- "source": [
- "Let's display images ranked by CLIP score."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zsgxxubLLkIu"
- },
- "outputs": [],
- "source": [
- "print(f\"Prompt: {prompt}\\n\")\n",
- "for idx in logits.argsort()[::-1]:\n",
- " display(images[idx])\n",
- " print(f\"Score: {logits[idx]:.2f}\\n\")"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "collapsed_sections": [],
- "machine_shape": "hm",
- "name": "Copy of DALL·E mini - Inference pipeline.ipynb",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
  },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
- }
  {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "118UKH5bWCGa"
+ },
+ "source": [
+ "# DALL·E mini - Inference pipeline\n",
+ "\n",
+ "*Generate images from a text prompt*\n",
+ "\n",
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
+ "\n",
+ "This notebook illustrates the [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
+ "\n",
+ "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
+ "\n",
+ "For a deeper understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dS8LbaonYm3a"
+ },
+ "source": [
+ "## 🛠️ Installation and set-up"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uzjAM2GBYpZX"
+ },
+ "outputs": [],
+ "source": [
+ "# Install required libraries\n",
+ "!pip install -q transformers\n",
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ozHzTkyv8cqU"
+ },
+ "source": [
+ "We load the required models:\n",
+ "* dalle·mini for text to encoded images\n",
+ "* VQGAN for decoding images\n",
+ "* CLIP for scoring predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "K6CxW2o42f-w"
+ },
+ "outputs": [],
+ "source": [
+ "# Model references\n",
+ "\n",
+ "# dalle-mini\n",
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-mehdx7dg:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
+ "DALLE_COMMIT_ID = None\n",
+ "\n",
+ "# VQGAN model\n",
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
+ "\n",
+ "# CLIP model\n",
+ "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
+ "CLIP_COMMIT_ID = None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Yv-aR3t4Oe5v"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "# check how many devices are available\n",
+ "jax.local_device_count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HWnQrQuXOe5w"
+ },
+ "outputs": [],
+ "source": [
+ "# type used for computation - use bfloat16 on TPUs\n",
+ "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
+ "\n",
+ "# TODO: fix issue with bfloat16\n",
+ "dtype = jnp.float32"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "92zYmvsQ38vL"
+ },
+ "outputs": [],
+ "source": [
+ "# Load models & tokenizer\n",
+ "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
+ "import wandb\n",
+ "\n",
+ "# Load dalle-mini\n",
+ "model = DalleBart.from_pretrained(\n",
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
+ ")\n",
+ "tokenizer = DalleBartTokenizer.from_pretrained(\n",
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID\n",
+ ")\n",
+ "\n",
+ "# Load VQGAN\n",
+ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
+ "\n",
+ "# Load CLIP\n",
+ "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
+ "processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o_vH2X1tDtzA"
+ },
+ "source": [
+ "Model parameters are replicated on each device for faster inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wtvLoM48EeVw"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.jax_utils import replicate\n",
+ "\n",
+ "# convert model parameters for inference if requested\n",
+ "if dtype == jnp.bfloat16:\n",
+ " model.params = model.to_bf16(model.params)\n",
+ "\n",
+ "model_params = replicate(model.params)\n",
+ "vqgan_params = replicate(vqgan.params)\n",
+ "clip_params = replicate(clip.params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0A9AHQIgZ_qw"
+ },
+ "source": [
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sOtoOmYsSYPz"
+ },
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "# model inference\n",
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
+ "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
+ " return model.generate(\n",
+ " **tokenized_prompt,\n",
+ " do_sample=True,\n",
+ " num_beams=1,\n",
+ " prng_key=key,\n",
+ " params=params,\n",
+ " top_k=top_k,\n",
+ " top_p=top_p,\n",
+ " max_length=257\n",
+ " )\n",
+ "\n",
+ "\n",
+ "# decode images\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_decode(indices, params):\n",
+ " return vqgan.decode_code(indices, params=params)\n",
+ "\n",
+ "\n",
+ "# score images\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_clip(inputs, params):\n",
+ " logits = clip(params=params, **inputs).logits_per_image\n",
+ " return logits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HmVN6IBwapBA"
+ },
+ "source": [
+ "Keys are passed to the model on each device to generate unique inferences per device."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4CTXmlUkThhX"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "# create a random key\n",
+ "seed = random.randint(0, 2**32 - 1)\n",
+ "key = jax.random.PRNGKey(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BrnVyCo81pij"
+ },
+ "source": [
+ "## 🖍 Text Prompt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rsmj0Aj5OQox"
+ },
+ "source": [
+ "Our model may require the prompt to be normalized."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YjjhUychOVxm"
+ },
+ "outputs": [],
+ "source": [
+ "from dalle_mini.text import TextNormalizer\n",
+ "\n",
+ "text_normalizer = TextNormalizer() if model.config.normalize_text else None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BQ7fymSPyvF_"
+ },
+ "source": [
+ "Let's define a text prompt."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x_0vI9ge1oKr"
+ },
+ "outputs": [],
+ "source": [
+ "prompt = \"a blue table\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VKjEZGjtO49k"
+ },
+ "outputs": [],
+ "source": [
+ "processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
+ "processed_prompt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QUzYACWxOe5z"
+ },
+ "source": [
+ "We tokenize the prompt."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "n8e7MvGwOe5z"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompt = tokenizer(\n",
+ " processed_prompt,\n",
+ " return_tensors=\"jax\",\n",
+ " padding=\"max_length\",\n",
+ " truncation=True,\n",
+ " max_length=128,\n",
+ ").data\n",
+ "tokenized_prompt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_Y5dqFj7prMQ"
+ },
+ "source": [
+ "Notes:\n",
+ "\n",
+ "* `0`: BOS, special token representing the beginning of a sequence\n",
+ "* `2`: EOS, special token representing the end of a sequence\n",
+ "* `1`: special token representing the padding of a sequence when requesting a specific length"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-CEJBnuJOe5z"
+ },
+ "source": [
+ "Finally we replicate it onto each device."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lQePgju5Oe5z"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompt = replicate(tokenized_prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "phQ9bhjRkgAZ"
+ },
+ "source": [
+ "## 🎨 Generate images\n",
+ "\n",
+ "We generate images using the dalle-mini model and decode them with the VQGAN."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d0wVkXpKqnHA"
+ },
+ "outputs": [],
+ "source": [
+ "# number of predictions\n",
+ "n_predictions = 32\n",
+ "\n",
+ "# We can customize top_k/top_p used for generating samples\n",
+ "gen_top_k = None\n",
+ "gen_top_p = None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SDjEx9JxR3v8"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard_prng_key\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "from tqdm.notebook import trange\n",
+ "\n",
+ "# generate images\n",
+ "images = []\n",
+ "for i in trange(n_predictions // jax.device_count()):\n",
+ " # get a new key\n",
+ " key, subkey = jax.random.split(key)\n",
+ " # generate images\n",
+ " encoded_images = p_generate(\n",
+ " tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
+ " )\n",
+ " # remove BOS\n",
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
+ " # decode images\n",
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
+ " for img in decoded_images:\n",
+ " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tw02wG9zGmyB"
+ },
+ "source": [
+ "Let's calculate their score with CLIP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FoLXpjCmGpju"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard\n",
+ "\n",
+ "# get clip scores\n",
+ "clip_inputs = processor(\n",
+ " text=[prompt] * jax.device_count(),\n",
+ " images=images,\n",
+ " return_tensors=\"np\",\n",
+ " padding=\"max_length\",\n",
+ " max_length=77,\n",
+ " truncation=True,\n",
+ ").data\n",
+ "logits = p_clip(shard(clip_inputs), clip_params)\n",
+ "logits = logits.squeeze().flatten()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4AAWRm70LgED"
+ },
+ "source": [
+ "Let's display images ranked by CLIP score."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zsgxxubLLkIu"
+ },
+ "outputs": [],
+ "source": [
+ "print(f\"Prompt: {prompt}\\n\")\n",
+ "for idx in logits.argsort()[::-1]:\n",
+ " display(images[idx])\n",
+ " print(f\"Score: {logits[idx]:.2f}\\n\")"
+ ]
+ }
  }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
  "colab": {
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "DALL·E mini - Inference pipeline.ipynb",
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
  },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
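
Note: the diff above is mostly mechanical cleanup: stored cell outputs (including the deleted KeyboardInterrupt traceback) are cleared, execution counts are reset to null, a Colab badge cell and refreshed Colab metadata are added, and the `!pip install -q wandb` line is dropped. As a minimal sketch, the output-clearing part of such a cleanup can be scripted with nbformat; this snippet is illustrative only and not part of the commit:

import nbformat

# Load the notebook touched by this commit
path = "tools/inference/inference_pipeline.ipynb"
nb = nbformat.read(path, as_version=4)

for cell in nb.cells:
    if cell.cell_type == "code":
        cell.outputs = []            # drop stored outputs (e.g. the deleted traceback)
        cell.execution_count = None  # serializes as "execution_count": null

nbformat.write(nb, path)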