boris committed on
Commit
fc8c230
2 Parent(s): eb591ff c48da33

Merge pull request #26 from tmabraham/generation-training-demo


demo for generation, including during training, from a wandb artifact
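At a high level, the notebook loads the fine-tuned BART checkpoint from a wandb artifact, greedily generates a sequence of image tokens from a text prompt, and decodes those tokens into an image with a VQGAN. A minimal sketch of that flow, assuming the notebook's own CustomFlaxBartForConditionalGeneration class has been defined and the modeling_flax_vqgan module from the vqgan-jax repo is importable:

import wandb
from transformers import BartTokenizer
from modeling_flax_vqgan import VQModel  # from the vqgan-jax repo

# download the BART checkpoint that was logged to wandb during training
run = wandb.init()
artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:v1', type='bart_model')
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact.download())

# encode a prompt and greedily generate 256 image tokens (plus BOS)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
greedy_output = model.generate(input_ids, max_length=257)

# drop the BOS token and decode the image tokens with the pre-trained VQGAN
vqgan = VQModel.from_pretrained('valhalla/vqgan-imagenet-f16-1024')
images = vqgan.decode_code(greedy_output[0][:, 1:])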

Files changed (1)
  1. demo/demo_notebook.ipynb +495 -0
demo/demo_notebook.ipynb ADDED
@@ -0,0 +1,495 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ewer-Q-0w2xA"
+ },
+ "source": [
+ "# Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NpsF9ipLLl2s",
+ "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install git+https://github.com/huggingface/transformers/\n",
+ "#!pip install git+https://github.com/google/flax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "M1wVkrpjU6zO"
+ },
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/tmabraham/vqgan-jax\n"
+ ]
+ }
+ ],
+ "source": [
+ "%cd ../../vqgan-jax"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t47CH1H_IOT8"
+ },
+ "source": [
+ "# Custom BART Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "9jQnM6S2vCpn"
+ },
+ "outputs": [],
+ "source": [
+ "# TODO: set those args in a config file\n",
+ "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
+ "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
+ "BOS_TOKEN_ID = 16384\n",
+ "BASE_MODEL = 'facebook/bart-large'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "_eEaJVxAKpV5"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import flax.linen as nn\n",
+ "\n",
+ "from transformers.models.bart.modeling_flax_bart import *\n",
+ "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
+ "\n",
+ "class CustomFlaxBartModule(FlaxBartModule):\n",
+ "    def setup(self):\n",
+ "        # we keep shared to easily load pre-trained weights\n",
+ "        self.shared = nn.Embed(\n",
+ "            self.config.vocab_size,\n",
+ "            self.config.d_model,\n",
+ "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
+ "            dtype=self.dtype,\n",
+ "        )\n",
+ "        # a separate embedding is used for the decoder\n",
+ "        self.decoder_embed = nn.Embed(\n",
+ "            OUTPUT_VOCAB_SIZE,\n",
+ "            self.config.d_model,\n",
+ "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
+ "            dtype=self.dtype,\n",
+ "        )\n",
+ "        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
+ "\n",
+ "        # the decoder has a different config\n",
+ "        decoder_config = BartConfig(self.config.to_dict())\n",
+ "        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
+ "        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
+ "        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
+ "\n",
+ "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
+ "    def setup(self):\n",
+ "        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
+ "        self.lm_head = nn.Dense(\n",
+ "            OUTPUT_VOCAB_SIZE,\n",
+ "            use_bias=False,\n",
+ "            dtype=self.dtype,\n",
+ "            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
+ "        )\n",
+ "        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
+ "\n",
+ "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
+ "    module_class = CustomFlaxBartForConditionalGenerationModule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtmabraham\u001b[0m (use `wandb login --relogin` to force relogin)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " Tracking run with wandb version 0.10.33<br/>\n",
+ " Syncing run <strong style=\"color:#cdcd00\">serene-resonance-1</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
+ " Project page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax</a><br/>\n",
+ " Run page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax/runs/1cm35ims\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax/runs/1cm35ims</a><br/>\n",
+ " Run data is saved locally in <code>/home/tmabraham/vqgan-jax/wandb/run-20210715_030616-1cm35ims</code><br/><br/>\n",
+ " "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-1ef8yxby:v1, 1674.97MB. 2 files... Done. 0:0:0\n"
+ ]
+ }
+ ],
+ "source": [
+ "import wandb\n",
+ "run = wandb.init()\n",
+ "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:v1', type='bart_model')\n",
+ "artifact_dir = artifact.download()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "_6-XKK40oEfP",
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/tmabraham/dalle-mini/src/transformers/src/transformers/models/bart/configuration_bart.py:180: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
+ " warnings.warn(\n",
+ "INFO:absl:Starting the local TPU driver.\n",
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
+ ]
+ }
+ ],
+ "source": [
+ "# create our model and load the fine-tuned weights from the wandb artifact\n",
+ "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Jz032w73nHEf",
+ "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(1, 16385)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# we verify that the shape has not been modified\n",
+ "model.params['final_logits_bias'].shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zLl24Ez5t7x1"
+ },
+ "source": [
+ "## Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "XLLA2NK3uDQr"
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "P32mJJSbrU1F"
+ },
+ "outputs": [],
+ "source": [
+ "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DeviceArray([[ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
+ " 2]], dtype=int32)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input_ids_test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "id": "C7cHbIHruELT"
+ },
+ "outputs": [],
+ "source": [
+ "greedy_output = model.generate(input_ids_test, max_length=257)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jYugh9cOuwc9",
+ "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DeviceArray([[16384, 16384, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
+ " 10042]], dtype=int32)"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "greedy_output[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# VQGAN JAX"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import io\n",
+ "\n",
+ "import requests\n",
+ "from PIL import Image\n",
+ "import numpy as np\n",
+ "\n",
+ "import torch\n",
+ "import torchvision.transforms as T\n",
+ "import torchvision.transforms.functional as TF\n",
+ "from torchvision.transforms import InterpolationMode"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from modeling_flax_vqgan import VQModel"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def custom_to_pil(x):\n",
+ "    x = np.clip(x, 0., 1.)\n",
+ "    x = (255*x).astype(np.uint8)\n",
+ "    x = Image.fromarray(x)\n",
+ "    if not x.mode == \"RGB\":\n",
+ "        x = x.convert(\"RGB\")\n",
+ "    return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Jz032w73nHEf",
+ "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = VQModel.from_pretrained(\"valhalla/vqgan-imagenet-f16-1024\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_images(indices, model):\n",
+ "    # drop the BOS token, then decode the image token indices with the VQGAN\n",
+ "    indices = indices[:, 1:]\n",
+ "    return model.decode_code(indices)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAEACAIAAAD9XIvPAAAAF0lEQVR4nGP4//8/EwMDwygexaN45GEA7ucE/J1FRrMAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "<PIL.Image.Image image mode=RGB size=1x256 at 0x7FE6389B6280>"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "custom_to_pil(np.asarray(get_images(greedy_output[0], model)[0]))"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "TPU",
+ "colab": {
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "CustomBARTv4b-model-generate.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+ }