valhalla committed on
Commit
4b5a542
1 Parent(s): 8f484d9

add tpu demo notebook

Browse files
demo/.ipynb_checkpoints/tpu-demo-checkpoint.ipynb ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "6eb74941-bb4d-4d7e-97f1-d5a3a07672bf",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# !pip install flax transformers\n",
11
+ "# !git clone https://github.com/patil-suraj/vqgan-jax.git"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 305,
17
+ "id": "41db7534-f589-4b63-9165-9c9799e1b06e",
18
+ "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "/home/surajpatil/vqgan-jax\n"
25
+ ]
26
+ },
27
+ {
28
+ "data": {
29
+ "text/plain": [
30
+ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
31
+ " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
32
+ " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
33
+ " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
34
+ " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
35
+ " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
36
+ " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
37
+ " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
38
+ ]
39
+ },
40
+ "execution_count": 305,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "%cd ~/vqgan-jax\n",
47
+ "\n",
48
+ "import random\n",
49
+ "\n",
50
+ "\n",
51
+ "import jax\n",
52
+ "import flax.linen as nn\n",
53
+ "from flax.training.common_utils import shard\n",
54
+ "from flax.jax_utils import replicate, unreplicate\n",
55
+ "\n",
56
+ "from transformers.models.bart.modeling_flax_bart import *\n",
57
+ "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
58
+ "\n",
59
+ "import io\n",
60
+ "\n",
61
+ "import requests\n",
62
+ "from PIL import Image\n",
63
+ "import numpy as np\n",
64
+ "import matplotlib.pyplot as plt\n",
65
+ "\n",
66
+ "import torch\n",
67
+ "import torchvision.transforms as T\n",
68
+ "import torchvision.transforms.functional as TF\n",
69
+ "from torchvision.transforms import InterpolationMode\n",
70
+ "\n",
71
+ "\n",
72
+ "from modeling_flax_vqgan import VQModel\n",
73
+ "\n",
74
+ "jax.devices()"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 2,
80
+ "id": "b6a3462a-9004-4121-b365-3ae3aaf94dd2",
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "# TODO: set those args in a config file\n",
85
+ "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
86
+ "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
87
+ "BOS_TOKEN_ID = 16384\n",
88
+ "BASE_MODEL = 'facebook/bart-large'"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 3,
94
+ "id": "bbef1afb-0b36-44a5-83f7-643d7e2c0e30",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "class CustomFlaxBartModule(FlaxBartModule):\n",
99
+ " def setup(self):\n",
100
+ " # we keep shared to easily load pre-trained weights\n",
101
+ " self.shared = nn.Embed(\n",
102
+ " self.config.vocab_size,\n",
103
+ " self.config.d_model,\n",
104
+ " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
105
+ " dtype=self.dtype,\n",
106
+ " )\n",
107
+ " # a separate embedding is used for the decoder\n",
108
+ " self.decoder_embed = nn.Embed(\n",
109
+ " OUTPUT_VOCAB_SIZE,\n",
110
+ " self.config.d_model,\n",
111
+ " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
112
+ " dtype=self.dtype,\n",
113
+ " )\n",
114
+ " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
115
+ "\n",
116
+ " # the decoder has a different config\n",
117
+ " decoder_config = BartConfig(self.config.to_dict())\n",
118
+ " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
119
+ " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
120
+ " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
121
+ "\n",
122
+ "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
123
+ " def setup(self):\n",
124
+ " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
125
+ " self.lm_head = nn.Dense(\n",
126
+ " OUTPUT_VOCAB_SIZE,\n",
127
+ " use_bias=False,\n",
128
+ " dtype=self.dtype,\n",
129
+ " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
130
+ " )\n",
131
+ " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
132
+ "\n",
133
+ "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
134
+ " module_class = CustomFlaxBartForConditionalGenerationModule"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "879320b7-eaa0-4dc9-bbf2-c81efc53301d",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "import wandb\n",
145
+ "run = wandb.init()\n",
146
+ "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:v7', type='bart_model')\n",
147
+ "artifact_dir = artifact.download()"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 164,
153
+ "id": "e8bcff33-e95b-4c01-b162-ee857a55c3e6",
154
+ "metadata": {},
155
+ "outputs": [
156
+ {
157
+ "name": "stderr",
158
+ "output_type": "stream",
159
+ "text": [
160
+ "/home/surajpatil/transformers/src/transformers/models/bart/configuration_bart.py:177: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
161
+ " warnings.warn(\n"
162
+ ]
163
+ },
164
+ {
165
+ "data": {
166
+ "text/plain": [
167
+ "(1, 16385)"
168
+ ]
169
+ },
170
+ "execution_count": 164,
171
+ "metadata": {},
172
+ "output_type": "execute_result"
173
+ }
174
+ ],
175
+ "source": [
176
+ "# load the tokenizer and the model weights from the downloaded artifact\n",
177
+ "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)\n",
178
+ "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)\n",
179
+ "model.config.force_bos_token_to_be_generated = False\n",
180
+ "model.config.forced_bos_token_id = None\n",
181
+ "model.config.forced_eos_token_id = None\n",
182
+ "\n",
183
+ "# we verify that the shape has not been modified\n",
184
+ "model.params['final_logits_bias'].shape"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 6,
190
+ "id": "8d5e0f14-2502-470e-9553-daee6748601f",
191
+ "metadata": {},
192
+ "outputs": [
193
+ {
194
+ "data": {
195
+ "application/vnd.jupyter.widget-view+json": {
196
+ "model_id": "9b979a72ab9e449387a89bf9b3012af5",
197
+ "version_major": 2,
198
+ "version_minor": 0
199
+ },
200
+ "text/plain": [
201
+ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…"
202
+ ]
203
+ },
204
+ "metadata": {},
205
+ "output_type": "display_data"
206
+ },
207
+ {
208
+ "name": "stdout",
209
+ "output_type": "stream",
210
+ "text": [
211
+ "\n"
212
+ ]
213
+ },
214
+ {
215
+ "data": {
216
+ "application/vnd.jupyter.widget-view+json": {
217
+ "model_id": "01730e0e9d02428ca9dad680f9fdda42",
218
+ "version_major": 2,
219
+ "version_minor": 0
220
+ },
221
+ "text/plain": [
222
+ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=304307206.0, style=ProgressStyle(descri…"
223
+ ]
224
+ },
225
+ "metadata": {},
226
+ "output_type": "display_data"
227
+ },
228
+ {
229
+ "name": "stdout",
230
+ "output_type": "stream",
231
+ "text": [
232
+ "\n",
233
+ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
234
+ ]
235
+ }
236
+ ],
237
+ "source": [
238
+ "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 295,
244
+ "id": "6cca395a-93c2-49bc-a3be-98287e4403d4",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "def custom_to_pil(x):\n",
249
+ " x = np.clip(x, 0., 1.)\n",
250
+ " x = (255*x).astype(np.uint8)\n",
251
+ " x = Image.fromarray(x)\n",
252
+ " if not x.mode == \"RGB\":\n",
253
+ " x = x.convert(\"RGB\")\n",
254
+ " return x\n",
255
+ "\n",
256
+ "def generate(input, rng, params):\n",
257
+ " return model.generate(\n",
258
+ " **input,\n",
259
+ " max_length=257,\n",
260
+ " num_beams=1,\n",
261
+ " do_sample=True,\n",
262
+ " prng_key=rng,\n",
263
+ " eos_token_id=50000,\n",
264
+ " pad_token_id=50000,\n",
265
+ " params=params\n",
266
+ " )\n",
267
+ "\n",
268
+ "def get_images(indices, params):\n",
269
+ " return vqgan.decode_code(indices, params=params)\n",
270
+ "\n",
271
+ "\n",
272
+ "def plot_images(images):\n",
273
+ " fig = plt.figure(figsize=(40, 20))\n",
274
+ " columns = 4\n",
275
+ " rows = 2\n",
276
+ " plt.subplots_adjust(hspace=0, wspace=0)\n",
277
+ "\n",
278
+ " for i in range(1, columns*rows +1):\n",
279
+ " fig.add_subplot(rows, columns, i)\n",
280
+ " plt.imshow(images[i-1])\n",
281
+ " plt.gca().axes.get_yaxis().set_visible(False)\n",
282
+ " plt.show()\n",
283
+ " \n",
284
+ "def stack_reconstructions(images):\n",
285
+ " w, h = images[0].size[0], images[0].size[1]\n",
286
+ " img = Image.new(\"RGB\", (len(images)*w, h))\n",
287
+ " for i, img_ in enumerate(images):\n",
288
+ " img.paste(img_, (i*w,0))\n",
289
+ " return img"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": 166,
295
+ "id": "b1bec3d2-ef17-4feb-aa0d-b51ed2fdcd3e",
296
+ "metadata": {},
297
+ "outputs": [],
298
+ "source": [
299
+ "p_generate = jax.pmap(generate, \"batch\")\n",
300
+ "p_get_images = jax.pmap(get_images, \"batch\")"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "id": "a539823a-a775-4d92-96a5-dc8b1eef69c5",
307
+ "metadata": {},
308
+ "outputs": [],
309
+ "source": [
310
+ "bart_params = replicate(model.params)\n",
311
+ "vqgan_params = replicate(vqgan.params)"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": 328,
317
+ "id": "e8b268d8-6992-422a-8373-95651474ae70",
318
+ "metadata": {},
319
+ "outputs": [],
320
+ "source": [
321
+ "prompts = [\n",
322
+ " \"man in blue jacket walking on pathway in between trees during daytime\",\n",
323
+ " 'white snow covered mountain under blue sky during daytime',\n",
324
+ " 'white snow covered mountain under blue sky during night',\n",
325
+ " \"orange tabby cat on persons hand\",\n",
326
+ " \"aerial view of beach during daytime\",\n",
327
+ " \"chess pieces on chess board\",\n",
328
+ " \"laptop on brown wooden table\",\n",
329
+ " \"white bus on road near high rise buildings\",\n",
330
+ "]\n",
331
+ "\n",
332
+ "\n",
333
+ "prompt = [prompts[-1]] * 8\n",
334
+ "inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
335
+ "inputs = shard(inputs)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "id": "68638cfa-9a4d-4e6a-8630-91aefb627bbd",
342
+ "metadata": {},
343
+ "outputs": [],
344
+ "source": [
345
+ "%%time\n",
346
+ "for i in range(8):\n",
347
+ " key = random.randint(0, 1e7)\n",
348
+ " rng = jax.random.PRNGKey(key)\n",
349
+ " rngs = jax.random.split(rng, jax.local_device_count())\n",
350
+ " indices = p_generate(inputs, rngs, bart_params).sequences\n",
351
+ " indices = indices[:, :, 1:]\n",
352
+ "\n",
353
+ " images = p_get_images(indices, vqgan_params)\n",
354
+ " images = np.squeeze(np.asarray(images), 1)\n",
355
+ " imges = [custom_to_pil(image) for image in images]\n",
356
+ "\n",
357
+ " plt.figure(figsize=(40, 20))\n",
358
+ " plt.imshow(stack_reconstructions(imges))"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "id": "681af54e-da10-4b8e-80d0-ebcbdf23f376",
365
+ "metadata": {},
366
+ "outputs": [],
367
+ "source": []
368
+ }
369
+ ],
370
+ "metadata": {
371
+ "kernelspec": {
372
+ "display_name": "Python 3",
373
+ "language": "python",
374
+ "name": "python3"
375
+ },
376
+ "language_info": {
377
+ "codemirror_mode": {
378
+ "name": "ipython",
379
+ "version": 3
380
+ },
381
+ "file_extension": ".py",
382
+ "mimetype": "text/x-python",
383
+ "name": "python",
384
+ "nbconvert_exporter": "python",
385
+ "pygments_lexer": "ipython3",
386
+ "version": "3.8.10"
387
+ }
388
+ },
389
+ "nbformat": 4,
390
+ "nbformat_minor": 5
391
+ }
demo/tpu-demo.ipynb ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "6eb74941-bb4d-4d7e-97f1-d5a3a07672bf",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# !pip install flax transformers\n",
11
+ "# !git clone https://github.com/patil-suraj/vqgan-jax.git"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 305,
17
+ "id": "41db7534-f589-4b63-9165-9c9799e1b06e",
18
+ "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "/home/surajpatil/vqgan-jax\n"
25
+ ]
26
+ },
27
+ {
28
+ "data": {
29
+ "text/plain": [
30
+ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
31
+ " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
32
+ " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
33
+ " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
34
+ " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
35
+ " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
36
+ " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
37
+ " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
38
+ ]
39
+ },
40
+ "execution_count": 305,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "%cd ~/vqgan-jax\n",
47
+ "\n",
48
+ "import random\n",
49
+ "\n",
50
+ "\n",
51
+ "import jax\n",
52
+ "import flax.linen as nn\n",
53
+ "from flax.training.common_utils import shard\n",
54
+ "from flax.jax_utils import replicate, unreplicate\n",
55
+ "\n",
56
+ "from transformers.models.bart.modeling_flax_bart import *\n",
57
+ "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
58
+ "\n",
59
+ "import io\n",
60
+ "\n",
61
+ "import requests\n",
62
+ "from PIL import Image\n",
63
+ "import numpy as np\n",
64
+ "import matplotlib.pyplot as plt\n",
65
+ "\n",
66
+ "import torch\n",
67
+ "import torchvision.transforms as T\n",
68
+ "import torchvision.transforms.functional as TF\n",
69
+ "from torchvision.transforms import InterpolationMode\n",
70
+ "\n",
71
+ "\n",
72
+ "from modeling_flax_vqgan import VQModel\n",
73
+ "\n",
74
+ "jax.devices()"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 2,
80
+ "id": "b6a3462a-9004-4121-b365-3ae3aaf94dd2",
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "# TODO: set those args in a config file\n",
85
+ "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
86
+ "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
87
+ "BOS_TOKEN_ID = 16384\n",
88
+ "BASE_MODEL = 'facebook/bart-large'"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 3,
94
+ "id": "bbef1afb-0b36-44a5-83f7-643d7e2c0e30",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "class CustomFlaxBartModule(FlaxBartModule):\n",
99
+ " def setup(self):\n",
100
+ " # we keep shared to easily load pre-trained weights\n",
101
+ " self.shared = nn.Embed(\n",
102
+ " self.config.vocab_size,\n",
103
+ " self.config.d_model,\n",
104
+ " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
105
+ " dtype=self.dtype,\n",
106
+ " )\n",
107
+ " # a separate embedding is used for the decoder\n",
108
+ " self.decoder_embed = nn.Embed(\n",
109
+ " OUTPUT_VOCAB_SIZE,\n",
110
+ " self.config.d_model,\n",
111
+ " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
112
+ " dtype=self.dtype,\n",
113
+ " )\n",
114
+ " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
115
+ "\n",
116
+ " # the decoder has a different config\n",
117
+ " decoder_config = BartConfig(self.config.to_dict())\n",
118
+ " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
119
+ " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
120
+ " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
121
+ "\n",
122
+ "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
123
+ " def setup(self):\n",
124
+ " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
125
+ " self.lm_head = nn.Dense(\n",
126
+ " OUTPUT_VOCAB_SIZE,\n",
127
+ " use_bias=False,\n",
128
+ " dtype=self.dtype,\n",
129
+ " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
130
+ " )\n",
131
+ " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
132
+ "\n",
133
+ "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
134
+ " module_class = CustomFlaxBartForConditionalGenerationModule"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "879320b7-eaa0-4dc9-bbf2-c81efc53301d",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "import wandb\n",
145
+ "run = wandb.init()\n",
146
+ "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:v7', type='bart_model')\n",
147
+ "artifact_dir = artifact.download()"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 164,
153
+ "id": "e8bcff33-e95b-4c01-b162-ee857a55c3e6",
154
+ "metadata": {},
155
+ "outputs": [
156
+ {
157
+ "name": "stderr",
158
+ "output_type": "stream",
159
+ "text": [
160
+ "/home/surajpatil/transformers/src/transformers/models/bart/configuration_bart.py:177: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
161
+ " warnings.warn(\n"
162
+ ]
163
+ },
164
+ {
165
+ "data": {
166
+ "text/plain": [
167
+ "(1, 16385)"
168
+ ]
169
+ },
170
+ "execution_count": 164,
171
+ "metadata": {},
172
+ "output_type": "execute_result"
173
+ }
174
+ ],
175
+ "source": [
176
+ "# load the tokenizer and the model weights from the downloaded artifact\n",
177
+ "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)\n",
178
+ "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)\n",
179
+ "model.config.force_bos_token_to_be_generated = False\n",
180
+ "model.config.forced_bos_token_id = None\n",
181
+ "model.config.forced_eos_token_id = None\n",
182
+ "\n",
183
+ "# we verify that the shape has not been modified\n",
184
+ "model.params['final_logits_bias'].shape"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 6,
190
+ "id": "8d5e0f14-2502-470e-9553-daee6748601f",
191
+ "metadata": {},
192
+ "outputs": [
193
+ {
194
+ "data": {
195
+ "application/vnd.jupyter.widget-view+json": {
196
+ "model_id": "9b979a72ab9e449387a89bf9b3012af5",
197
+ "version_major": 2,
198
+ "version_minor": 0
199
+ },
200
+ "text/plain": [
201
+ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…"
202
+ ]
203
+ },
204
+ "metadata": {},
205
+ "output_type": "display_data"
206
+ },
207
+ {
208
+ "name": "stdout",
209
+ "output_type": "stream",
210
+ "text": [
211
+ "\n"
212
+ ]
213
+ },
214
+ {
215
+ "data": {
216
+ "application/vnd.jupyter.widget-view+json": {
217
+ "model_id": "01730e0e9d02428ca9dad680f9fdda42",
218
+ "version_major": 2,
219
+ "version_minor": 0
220
+ },
221
+ "text/plain": [
222
+ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=304307206.0, style=ProgressStyle(descri…"
223
+ ]
224
+ },
225
+ "metadata": {},
226
+ "output_type": "display_data"
227
+ },
228
+ {
229
+ "name": "stdout",
230
+ "output_type": "stream",
231
+ "text": [
232
+ "\n",
233
+ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
234
+ ]
235
+ }
236
+ ],
237
+ "source": [
238
+ "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 295,
244
+ "id": "6cca395a-93c2-49bc-a3be-98287e4403d4",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "def custom_to_pil(x):\n",
249
+ " x = np.clip(x, 0., 1.)\n",
250
+ " x = (255*x).astype(np.uint8)\n",
251
+ " x = Image.fromarray(x)\n",
252
+ " if not x.mode == \"RGB\":\n",
253
+ " x = x.convert(\"RGB\")\n",
254
+ " return x\n",
255
+ "\n",
256
+ "def generate(input, rng, params):\n",
257
+ " return model.generate(\n",
258
+ " **input,\n",
259
+ " max_length=257,\n",
260
+ " num_beams=1,\n",
261
+ " do_sample=True,\n",
262
+ " prng_key=rng,\n",
263
+ " eos_token_id=50000,\n",
264
+ " pad_token_id=50000,\n",
265
+ " params=params\n",
266
+ " )\n",
267
+ "\n",
268
+ "def get_images(indices, params):\n",
269
+ " return vqgan.decode_code(indices, params=params)\n",
270
+ "\n",
271
+ "\n",
272
+ "def plot_images(images):\n",
273
+ " fig = plt.figure(figsize=(40, 20))\n",
274
+ " columns = 4\n",
275
+ " rows = 2\n",
276
+ " plt.subplots_adjust(hspace=0, wspace=0)\n",
277
+ "\n",
278
+ " for i in range(1, columns*rows +1):\n",
279
+ " fig.add_subplot(rows, columns, i)\n",
280
+ " plt.imshow(images[i-1])\n",
281
+ " plt.gca().axes.get_yaxis().set_visible(False)\n",
282
+ " plt.show()\n",
283
+ " \n",
284
+ "def stack_reconstructions(images):\n",
285
+ " w, h = images[0].size[0], images[0].size[1]\n",
286
+ " img = Image.new(\"RGB\", (len(images)*w, h))\n",
287
+ " for i, img_ in enumerate(images):\n",
288
+ " img.paste(img_, (i*w,0))\n",
289
+ " return img"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": 166,
295
+ "id": "b1bec3d2-ef17-4feb-aa0d-b51ed2fdcd3e",
296
+ "metadata": {},
297
+ "outputs": [],
298
+ "source": [
299
+ "p_generate = jax.pmap(generate, \"batch\")\n",
300
+ "p_get_images = jax.pmap(get_images, \"batch\")"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "id": "a539823a-a775-4d92-96a5-dc8b1eef69c5",
307
+ "metadata": {},
308
+ "outputs": [],
309
+ "source": [
310
+ "bart_params = replicate(model.params)\n",
311
+ "vqgan_params = replicate(vqgan.params)"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": 328,
317
+ "id": "e8b268d8-6992-422a-8373-95651474ae70",
318
+ "metadata": {},
319
+ "outputs": [],
320
+ "source": [
321
+ "prompts = [\n",
322
+ " \"man in blue jacket walking on pathway in between trees during daytime\",\n",
323
+ " 'white snow covered mountain under blue sky during daytime',\n",
324
+ " 'white snow covered mountain under blue sky during night',\n",
325
+ " \"orange tabby cat on persons hand\",\n",
326
+ " \"aerial view of beach during daytime\",\n",
327
+ " \"chess pieces on chess board\",\n",
328
+ " \"laptop on brown wooden table\",\n",
329
+ " \"white bus on road near high rise buildings\",\n",
330
+ "]\n",
331
+ "\n",
332
+ "\n",
333
+ "prompt = [prompts[-1]] * 8\n",
334
+ "inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
335
+ "inputs = shard(inputs)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "id": "68638cfa-9a4d-4e6a-8630-91aefb627bbd",
342
+ "metadata": {},
343
+ "outputs": [],
344
+ "source": [
345
+ "%%time\n",
346
+ "for i in range(8):\n",
347
+ " key = random.randint(0, 1e7)\n",
348
+ " rng = jax.random.PRNGKey(key)\n",
349
+ " rngs = jax.random.split(rng, jax.local_device_count())\n",
350
+ " indices = p_generate(inputs, rngs, bart_params).sequences\n",
351
+ " indices = indices[:, :, 1:]\n",
352
+ "\n",
353
+ " images = p_get_images(indices, vqgan_params)\n",
354
+ " images = np.squeeze(np.asarray(images), 1)\n",
355
+ " imges = [custom_to_pil(image) for image in images]\n",
356
+ "\n",
357
+ " plt.figure(figsize=(40, 20))\n",
358
+ " plt.imshow(stack_reconstructions(imges))"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "id": "681af54e-da10-4b8e-80d0-ebcbdf23f376",
365
+ "metadata": {},
366
+ "outputs": [],
367
+ "source": []
368
+ }
369
+ ],
370
+ "metadata": {
371
+ "kernelspec": {
372
+ "display_name": "Python 3",
373
+ "language": "python",
374
+ "name": "python3"
375
+ },
376
+ "language_info": {
377
+ "codemirror_mode": {
378
+ "name": "ipython",
379
+ "version": 3
380
+ },
381
+ "file_extension": ".py",
382
+ "mimetype": "text/x-python",
383
+ "name": "python",
384
+ "nbconvert_exporter": "python",
385
+ "pygments_lexer": "ipython3",
386
+ "version": "3.8.10"
387
+ }
388
+ },
389
+ "nbformat": 4,
390
+ "nbformat_minor": 5
391
+ }