pcuenca committed on
Commit b599136
Parents: 440f966 5da4af0

Merge pull request #63 from borisdayma/chore-mv

README.md CHANGED
@@ -22,10 +22,6 @@ You can create your own pictures with [the demo](https://huggingface.co/spaces/f
 
 Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
 
-## Where does the logo come from?
-
-The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
-
 ## Development
 
 This section is for the adventurous people wanting to look into the code.
@@ -58,13 +54,19 @@ Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
 
 ### Training of Seq2Seq
 
-Refer to `dev/seq2seq` folder.
+Refer to [`dev/seq2seq`](dev/seq2seq) folder.
 
 You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
 
-### Inference
+### Inference Pipeline
+
+To generate sample predictions and understand the inference pipeline step by step, refer to [`dev/inference/inference_pipeline.ipynb`](dev/inference/inference_pipeline.ipynb).
 
-Refer to `dev/notebooks/demo`.
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
+
+## Where does the logo come from?
+
+The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
 
 ## Authors
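The README hunk above points readers to the wandb sweep configuration for hyperparameter search. As a minimal, hypothetical sketch of what such a configuration looks like (not part of this commit; the metric name, parameter values, project name, and the body of `train()` are placeholders for whatever the `dev/seq2seq` training script actually exposes):

```python
import wandb

# Hypothetical sweep definition; the keys follow the wandb sweeps schema,
# but the metric and parameter values are illustrative placeholders.
sweep_config = {
    "method": "random",  # could also be "grid" or "bayes"
    "metric": {"name": "eval/loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"values": [1e-5, 3e-5, 1e-4]},
        "warmup_steps": {"values": [500, 1000, 2000]},
    },
}

def train():
    # Placeholder for the actual seq2seq training entry point in dev/seq2seq;
    # the agent injects the sampled hyperparameters into run.config.
    with wandb.init() as run:
        print(run.config)

sweep_id = wandb.sweep(sweep_config, project="dalle-mini")  # project name is illustrative
wandb.agent(sweep_id, function=train, count=3)
```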
dev/{notebooks/encoding → encoding}/vqgan-jax-encoding-with-captions.ipynb RENAMED
File without changes
dev/{notebooks/encoding → encoding}/vqgan-jax-encoding-yfcc100m.ipynb RENAMED
File without changes
dev/{notebooks/encoding → encoding}/vqgan-jax-encoding.ipynb RENAMED
File without changes
dev/{predictions → inference}/README.md RENAMED
File without changes
dev/{predictions → inference}/dalle_mini RENAMED
File without changes
DALL·E_mini_Inference_pipeline.ipynb → dev/inference/inference_pipeline.ipynb RENAMED
File without changes
dev/inference/wandb-examples-from-backend.py ADDED
@@ -0,0 +1,76 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+from PIL import Image, ImageDraw, ImageFont
+import wandb
+import os
+
+from dalle_mini.backend import ServiceError, get_images_from_backend
+from dalle_mini.helpers import captioned_strip
+
+os.environ["WANDB_SILENT"] = "true"
+os.environ["WANDB_CONSOLE"] = "off"
+
+def log_to_wandb(prompts):
+    try:
+        backend_url = os.environ["BACKEND_SERVER"]
+        for _ in range(1):
+            for prompt in prompts:
+                print(f"Getting selections for: {prompt}")
+                # make a separate run per prompt
+                with wandb.init(
+                    entity='wandb',
+                    project='hf-flax-dalle-mini',
+                    job_type='predictions',# tags=['openai'],
+                    config={'prompt': prompt}
+                ):
+                    imgs = []
+                    selected = get_images_from_backend(prompt, backend_url)
+                    strip = captioned_strip(selected, prompt)
+                    imgs.append(wandb.Image(strip))
+                    wandb.log({"images": imgs})
+    except ServiceError as error:
+        print(f"Service unavailable, status: {error.status_code}")
+    except KeyError:
+        print("Error: BACKEND_SERVER unset")
+
+prompts = [
+    # "white snow covered mountain under blue sky during daytime",
+    # "aerial view of beach during daytime",
+    # "aerial view of beach at night",
+    # "a farmhouse surrounded by beautiful flowers",
+    # "an armchair in the shape of an avocado",
+    # "young woman riding her bike trough a forest",
+    # "a unicorn is passing by a rainbow in a field of flowers",
+    # "illustration of a baby shark swimming around corals",
+    # "painting of an oniric forest glade surrounded by tall trees",
+    # "sunset over green mountains",
+    # "a forest glade surrounded by tall trees in a sunny Spring morning",
+    # "fishing village under the moonlight in a serene sunset",
+    # "cartoon of a carrot with big eyes",
+    # "still life in the style of Kandinsky",
+    # "still life in the style of Picasso",
+    # "a graphite sketch of a gothic cathedral",
+    # "a graphite sketch of Elon Musk",
+    # "a watercolor pond with green leaves and yellow flowers",
+    # "a logo of a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps",
+    # "happy celebration in a small village in Africa",
+    # "a logo of an armchair in the shape of an avocado"
+    # "Pele and Maradona in a hypothetical match",
+    # "Mohammed Ali and Mike Tyson in a hypothetical match",
+    # "a storefront that has the word 'openai' written on it",
+    # "a pentagonal green clock",
+    # "a collection of glasses is sitting on a table",
+    # "a small red block sitting on a large green block",
+    # "an extreme close-up view of a capybara sitting in a field",
+    # "a cross-section view of a walnut",
+    # "a professional high-quality emoji of a lovestruck cup of boba",
+    # "a photo of san francisco's golden gate bridge",
+    # "an illustration of a baby daikon radish in a tutu walking a dog",
+    # "a picture of the Eiffel tower on the Moon",
+    # "a colorful stairway to heaven",
+    "this is a detailed high-resolution scan of a human brain"
+]
+
+for _ in range(1):
+    log_to_wandb(prompts)
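For reference, the new script above talks to a remote generation backend instead of loading the models locally: it reads the service address from the `BACKEND_SERVER` environment variable and logs one wandb run per prompt. A hypothetical driver session (not part of the commit), assuming `log_to_wandb` from the script is available in the current session and the URL is a placeholder:

```python
import os

# Hypothetical backend address; the script prints an error if this variable is
# unset and reports ServiceError status codes if the service is unreachable.
os.environ["BACKEND_SERVER"] = "http://localhost:8000"

# With the environment set, one wandb run is created per prompt.
log_to_wandb(["an armchair in the shape of an avocado"])
```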
dev/{predictions → inference}/wandb-examples.py RENAMED
@@ -4,16 +4,14 @@
 import random
 
 import jax
-import flax.linen as nn
 from flax.training.common_utils import shard
 from flax.jax_utils import replicate, unreplicate
 
 from transformers.models.bart.modeling_flax_bart import *
 from transformers import BartTokenizer, FlaxBartForConditionalGeneration
 
-import io
+import os
 
-import requests
 from PIL import Image
 import numpy as np
 import matplotlib.pyplot as plt
@@ -23,58 +21,24 @@ import torchvision.transforms as T
 import torchvision.transforms.functional as TF
 from torchvision.transforms import InterpolationMode
 
+from dalle_mini.model import CustomFlaxBartForConditionalGeneration
 from vqgan_jax.modeling_flax_vqgan import VQModel
 
-# TODO: set those args in a config file
-OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
-OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
-BOS_TOKEN_ID = 16384
-BASE_MODEL = 'facebook/bart-large-cnn'
-
-class CustomFlaxBartModule(FlaxBartModule):
-    def setup(self):
-        # we keep shared to easily load pre-trained weights
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
-
-        # the decoder has a different config
-        decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
-        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
-
-class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
-    def setup(self):
-        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
-        self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
-            use_bias=False,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-        )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
-
-class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
-    module_class = CustomFlaxBartForConditionalGenerationModule
-
+# ## CLIP Scoring
+from transformers import CLIPProcessor, FlaxCLIPModel
 
 import wandb
 import os
+
+from dalle_mini.helpers import captioned_strip
+
+
 os.environ["WANDB_SILENT"] = "true"
 os.environ["WANDB_CONSOLE"] = "off"
 
+# TODO: used for legacy support
+BASE_MODEL = 'facebook/bart-large-cnn'
+
 # set id to None so our latest images don't get overwritten
 id = None
 run = wandb.init(id=id,
@@ -87,8 +51,10 @@ artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest', ty
 artifact_dir = artifact.download()
 
 # create our model
-tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
 model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+
+# TODO: legacy support (earlier models)
+tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
 model.config.force_bos_token_to_be_generated = False
 model.config.forced_bos_token_id = None
 model.config.forced_eos_token_id = None
@@ -143,9 +109,6 @@ p_get_images = jax.pmap(get_images, "batch")
 bart_params = replicate(model.params)
 vqgan_params = replicate(vqgan.params)
 
-# ## CLIP Scoring
-from transformers import CLIPProcessor, FlaxCLIPModel
-
 clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
@@ -170,16 +133,12 @@ def hallucinate(prompt, num_images=64):
 
 def clip_top_k(prompt, images, k=8):
     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
+    # FIXME: image should be resized and normalized prior to being processed by CLIP
     outputs = clip(**inputs)
     logits = outputs.logits_per_text
     scores = np.array(logits[0]).argsort()[-k:][::-1]
     return [images[score] for score in scores]
 
-
-# ## Log to wandb
-
-from dalle_mini.helpers import captioned_strip
-
 def log_to_wandb(prompts):
     strips = []
     for prompt in prompts:
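The added FIXME in `clip_top_k` flags that candidate images should be resized and normalized before CLIP scoring. A minimal sketch of one way to do that (not part of the commit), assuming `images` is a list of PIL images and `clip`/`processor` are the CLIP model and processor defined earlier in this script:

```python
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode

def clip_top_k_resized(prompt, images, k=8):
    # Resize each PIL image to CLIP's expected 224x224 input; CLIPProcessor then
    # handles tensor conversion and per-channel normalization.
    resized = [TF.resize(img, [224, 224], interpolation=InterpolationMode.BICUBIC) for img in images]
    inputs = processor(text=prompt, images=resized, return_tensors="np", padding=True)
    logits = clip(**inputs).logits_per_text
    scores = np.array(logits[0]).argsort()[-k:][::-1]  # top-k indices, best first
    return [images[score] for score in scores]
```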
dev/notebooks/README.md DELETED
@@ -1,5 +0,0 @@
-# Notebooks
-
-These notebooks were used during development.
-
-TODO: This section requires some refactor and clean up.
dev/notebooks/demo/CustomBARTv4b_model-generate.ipynb DELETED
@@ -1,394 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "ewer-Q-0w2xA"
7
- },
8
- "source": [
9
- "# Installation"
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": null,
15
- "metadata": {
16
- "colab": {
17
- "base_uri": "https://localhost:8080/"
18
- },
19
- "id": "NpsF9ipLLl2s",
20
- "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
21
- },
22
- "outputs": [],
23
- "source": [
24
- "!pip install git+https://github.com/huggingface/transformers/\n",
25
- "!pip install git+https://github.com/google/flax"
26
- ]
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": null,
31
- "metadata": {
32
- "id": "M1wVkrpjU6zO"
33
- },
34
- "outputs": [],
35
- "source": [
36
- "%load_ext autoreload\n",
37
- "%autoreload 2"
38
- ]
39
- },
40
- {
41
- "cell_type": "markdown",
42
- "metadata": {
43
- "id": "t47CH1H_IOT8"
44
- },
45
- "source": [
46
- "# Custom BART Model"
47
- ]
48
- },
49
- {
50
- "cell_type": "code",
51
- "execution_count": null,
52
- "metadata": {
53
- "id": "9jQnM6S2vCpn"
54
- },
55
- "outputs": [],
56
- "source": [
57
- "# TODO: set those args in a config file\n",
58
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
59
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
60
- "BOS_TOKEN_ID = 16384\n",
61
- "BASE_MODEL = 'facebook/bart-large'"
62
- ]
63
- },
64
- {
65
- "cell_type": "code",
66
- "execution_count": null,
67
- "metadata": {
68
- "id": "_eEaJVxAKpV5"
69
- },
70
- "outputs": [],
71
- "source": [
72
- "import jax\n",
73
- "import flax.linen as nn\n",
74
- "\n",
75
- "from transformers.models.bart.modeling_flax_bart import *\n",
76
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
77
- "\n",
78
- "class CustomFlaxBartModule(FlaxBartModule):\n",
79
- " def setup(self):\n",
80
- " # we keep shared to easily load pre-trained weights\n",
81
- " self.shared = nn.Embed(\n",
82
- " self.config.vocab_size,\n",
83
- " self.config.d_model,\n",
84
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
85
- " dtype=self.dtype,\n",
86
- " )\n",
87
- " # a separate embedding is used for the decoder\n",
88
- " self.decoder_embed = nn.Embed(\n",
89
- " OUTPUT_VOCAB_SIZE,\n",
90
- " self.config.d_model,\n",
91
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
92
- " dtype=self.dtype,\n",
93
- " )\n",
94
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
95
- "\n",
96
- " # the decoder has a different config\n",
97
- " decoder_config = BartConfig(self.config.to_dict())\n",
98
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
99
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
100
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
101
- "\n",
102
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
103
- " def setup(self):\n",
104
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
105
- " self.lm_head = nn.Dense(\n",
106
- " OUTPUT_VOCAB_SIZE,\n",
107
- " use_bias=False,\n",
108
- " dtype=self.dtype,\n",
109
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
110
- " )\n",
111
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
112
- "\n",
113
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
114
- " module_class = CustomFlaxBartForConditionalGenerationModule"
115
- ]
116
- },
117
- {
118
- "cell_type": "code",
119
- "execution_count": null,
120
- "metadata": {
121
- "colab": {
122
- "base_uri": "https://localhost:8080/"
123
- },
124
- "id": "S7CP9Td9m2ge",
125
- "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
126
- },
127
- "outputs": [],
128
- "source": [
129
- "# load pre-trained model for encoder weights\n",
130
- "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
131
- ]
132
- },
133
- {
134
- "cell_type": "code",
135
- "execution_count": null,
136
- "metadata": {
137
- "id": "6lmynR-poceH"
138
- },
139
- "outputs": [],
140
- "source": [
141
- "# set up our new model config\n",
142
- "config = BartConfig.from_pretrained(BASE_MODEL)\n",
143
- "config.tie_word_embeddings = False\n",
144
- "config.decoder_start_token_id = BOS_TOKEN_ID\n",
145
- "config.bos_token_id = BOS_TOKEN_ID # should not be used\n",
146
- "config.pos_token_id = BOS_TOKEN_ID # should not be used\n",
147
- "#config.eos_token_id = None # prevents generation from stopping until we reach max_length"
148
- ]
149
- },
150
- {
151
- "cell_type": "code",
152
- "execution_count": null,
153
- "metadata": {
154
- "id": "_6-XKK40oEfP"
155
- },
156
- "outputs": [],
157
- "source": [
158
- "# create our model and initialize it randomly\n",
159
- "model = CustomFlaxBartForConditionalGeneration(config)"
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "metadata": {
166
- "id": "-r_hZestr-NR"
167
- },
168
- "outputs": [],
169
- "source": [
170
- "# use pretrained weights\n",
171
- "model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
172
- "model.params['model']['shared'] = base_model.params['model']['shared']"
173
- ]
174
- },
175
- {
176
- "cell_type": "code",
177
- "execution_count": null,
178
- "metadata": {
179
- "id": "5NEX8f62sVjx"
180
- },
181
- "outputs": [],
182
- "source": [
183
- "# no need for base_model anymore\n",
184
- "del base_model"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": null,
190
- "metadata": {
191
- "colab": {
192
- "base_uri": "https://localhost:8080/"
193
- },
194
- "id": "Jz032w73nHEf",
195
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
196
- },
197
- "outputs": [],
198
- "source": [
199
- "# we verify that the shape has not been modified\n",
200
- "model.params['final_logits_bias'].shape"
201
- ]
202
- },
203
- {
204
- "cell_type": "markdown",
205
- "metadata": {
206
- "id": "zLl24Ez5t7x1"
207
- },
208
- "source": [
209
- "## Inference"
210
- ]
211
- },
212
- {
213
- "cell_type": "code",
214
- "execution_count": null,
215
- "metadata": {
216
- "id": "XLLA2NK3uDQr"
217
- },
218
- "outputs": [],
219
- "source": [
220
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
221
- ]
222
- },
223
- {
224
- "cell_type": "code",
225
- "execution_count": null,
226
- "metadata": {
227
- "colab": {
228
- "base_uri": "https://localhost:8080/"
229
- },
230
- "id": "Ntow53I_t81D",
231
- "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
232
- },
233
- "outputs": [],
234
- "source": [
235
- "text = \"My friends are cool but they eat too many carbs.\"\n",
236
- "inputs = tokenizer(text, max_length=1024, return_tensors='jax')\n",
237
- "encoder_outputs = model.encode(**inputs)"
238
- ]
239
- },
240
- {
241
- "cell_type": "code",
242
- "execution_count": null,
243
- "metadata": {
244
- "colab": {
245
- "base_uri": "https://localhost:8080/"
246
- },
247
- "id": "vcRNJnJ_uJOJ",
248
- "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
249
- },
250
- "outputs": [],
251
- "source": [
252
- "decoder_start_token_id = model.config.decoder_start_token_id\n",
253
- "decoder_start_token_id"
254
- ]
255
- },
256
- {
257
- "cell_type": "code",
258
- "execution_count": null,
259
- "metadata": {
260
- "id": "6QWmEwL_uMld"
261
- },
262
- "outputs": [],
263
- "source": [
264
- "decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
265
- "outputs = model.decode(decoder_input_ids, encoder_outputs)"
266
- ]
267
- },
268
- {
269
- "cell_type": "code",
270
- "execution_count": null,
271
- "metadata": {
272
- "colab": {
273
- "base_uri": "https://localhost:8080/"
274
- },
275
- "id": "c_ys3yWBothF",
276
- "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
277
- },
278
- "outputs": [],
279
- "source": [
280
- "outputs"
281
- ]
282
- },
283
- {
284
- "cell_type": "code",
285
- "execution_count": null,
286
- "metadata": {
287
- "colab": {
288
- "base_uri": "https://localhost:8080/"
289
- },
290
- "id": "O6s0wtB_uTC_",
291
- "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
292
- },
293
- "outputs": [],
294
- "source": [
295
- "outputs.logits.shape"
296
- ]
297
- },
298
- {
299
- "cell_type": "code",
300
- "execution_count": null,
301
- "metadata": {
302
- "colab": {
303
- "base_uri": "https://localhost:8080/"
304
- },
305
- "id": "ELzemGP3uBzy",
306
- "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
307
- },
308
- "outputs": [],
309
- "source": [
310
- "outputs.logits.argmax(axis=-1)"
311
- ]
312
- },
313
- {
314
- "cell_type": "code",
315
- "execution_count": null,
316
- "metadata": {
317
- "colab": {
318
- "base_uri": "https://localhost:8080/"
319
- },
320
- "id": "fQjikkGEunpx",
321
- "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
322
- },
323
- "outputs": [],
324
- "source": [
325
- "model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
326
- ]
327
- },
328
- {
329
- "cell_type": "code",
330
- "execution_count": null,
331
- "metadata": {
332
- "id": "P32mJJSbrU1F"
333
- },
334
- "outputs": [],
335
- "source": [
336
- "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
337
- ]
338
- },
339
- {
340
- "cell_type": "code",
341
- "execution_count": null,
342
- "metadata": {
343
- "id": "C7cHbIHruELT"
344
- },
345
- "outputs": [],
346
- "source": [
347
- "greedy_output = model.generate(input_ids_test, max_length=50)"
348
- ]
349
- },
350
- {
351
- "cell_type": "code",
352
- "execution_count": null,
353
- "metadata": {
354
- "colab": {
355
- "base_uri": "https://localhost:8080/"
356
- },
357
- "id": "jYugh9cOuwc9",
358
- "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
359
- },
360
- "outputs": [],
361
- "source": [
362
- "greedy_output[0]"
363
- ]
364
- }
365
- ],
366
- "metadata": {
367
- "accelerator": "TPU",
368
- "colab": {
369
- "collapsed_sections": [],
370
- "machine_shape": "hm",
371
- "name": "CustomBARTv4b-model-generate.ipynb",
372
- "provenance": []
373
- },
374
- "kernelspec": {
375
- "display_name": "Python 3 (ipykernel)",
376
- "language": "python",
377
- "name": "python3"
378
- },
379
- "language_info": {
380
- "codemirror_mode": {
381
- "name": "ipython",
382
- "version": 3
383
- },
384
- "file_extension": ".py",
385
- "mimetype": "text/x-python",
386
- "name": "python",
387
- "nbconvert_exporter": "python",
388
- "pygments_lexer": "ipython3",
389
- "version": "3.8.5"
390
- }
391
- },
392
- "nbformat": 4,
393
- "nbformat_minor": 4
394
- }
dev/notebooks/demo/demo_notebook.ipynb DELETED
@@ -1,387 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "ewer-Q-0w2xA"
7
- },
8
- "source": [
9
- "# Installation"
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": null,
15
- "metadata": {
16
- "colab": {
17
- "base_uri": "https://localhost:8080/"
18
- },
19
- "id": "NpsF9ipLLl2s",
20
- "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
21
- },
22
- "outputs": [],
23
- "source": [
24
- "#!pip install git+https://github.com/huggingface/transformers/\n",
25
- "#!pip install git+https://github.com/google/flax"
26
- ]
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": null,
31
- "metadata": {
32
- "id": "M1wVkrpjU6zO"
33
- },
34
- "outputs": [],
35
- "source": [
36
- "%load_ext autoreload\n",
37
- "%autoreload 2"
38
- ]
39
- },
40
- {
41
- "cell_type": "code",
42
- "execution_count": null,
43
- "metadata": {},
44
- "outputs": [],
45
- "source": [
46
- "%cd ../../vqgan-jax"
47
- ]
48
- },
49
- {
50
- "cell_type": "markdown",
51
- "metadata": {
52
- "id": "t47CH1H_IOT8"
53
- },
54
- "source": [
55
- "# Custom BART Model"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": null,
61
- "metadata": {
62
- "id": "9jQnM6S2vCpn"
63
- },
64
- "outputs": [],
65
- "source": [
66
- "# TODO: set those args in a config file\n",
67
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
68
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
69
- "BOS_TOKEN_ID = 16384\n",
70
- "BASE_MODEL = 'facebook/bart-large'"
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": null,
76
- "metadata": {
77
- "id": "_eEaJVxAKpV5"
78
- },
79
- "outputs": [],
80
- "source": [
81
- "import jax\n",
82
- "import flax.linen as nn\n",
83
- "\n",
84
- "from transformers.models.bart.modeling_flax_bart import *\n",
85
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
86
- "\n",
87
- "class CustomFlaxBartModule(FlaxBartModule):\n",
88
- " def setup(self):\n",
89
- " # we keep shared to easily load pre-trained weights\n",
90
- " self.shared = nn.Embed(\n",
91
- " self.config.vocab_size,\n",
92
- " self.config.d_model,\n",
93
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
94
- " dtype=self.dtype,\n",
95
- " )\n",
96
- " # a separate embedding is used for the decoder\n",
97
- " self.decoder_embed = nn.Embed(\n",
98
- " OUTPUT_VOCAB_SIZE,\n",
99
- " self.config.d_model,\n",
100
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
101
- " dtype=self.dtype,\n",
102
- " )\n",
103
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
104
- "\n",
105
- " # the decoder has a different config\n",
106
- " decoder_config = BartConfig(self.config.to_dict())\n",
107
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
108
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
109
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
110
- "\n",
111
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
112
- " def setup(self):\n",
113
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
114
- " self.lm_head = nn.Dense(\n",
115
- " OUTPUT_VOCAB_SIZE,\n",
116
- " use_bias=False,\n",
117
- " dtype=self.dtype,\n",
118
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
119
- " )\n",
120
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
121
- "\n",
122
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
123
- " module_class = CustomFlaxBartForConditionalGenerationModule"
124
- ]
125
- },
126
- {
127
- "cell_type": "code",
128
- "execution_count": null,
129
- "metadata": {
130
- "scrolled": true
131
- },
132
- "outputs": [],
133
- "source": [
134
- "import wandb\n",
135
- "run = wandb.init()\n",
136
- "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:latest', type='bart_model')\n",
137
- "artifact_dir = artifact.download()"
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": null,
143
- "metadata": {
144
- "id": "_6-XKK40oEfP",
145
- "scrolled": true
146
- },
147
- "outputs": [],
148
- "source": [
149
- "# create our model and initialize it randomly\n",
150
- "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
151
- ]
152
- },
153
- {
154
- "cell_type": "code",
155
- "execution_count": null,
156
- "metadata": {},
157
- "outputs": [],
158
- "source": [
159
- "model.config.forced_bos_token_id = None"
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "metadata": {
166
- "colab": {
167
- "base_uri": "https://localhost:8080/"
168
- },
169
- "id": "Jz032w73nHEf",
170
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
171
- },
172
- "outputs": [],
173
- "source": [
174
- "# we verify that the shape has not been modified\n",
175
- "model.params['final_logits_bias'].shape"
176
- ]
177
- },
178
- {
179
- "cell_type": "markdown",
180
- "metadata": {
181
- "id": "zLl24Ez5t7x1"
182
- },
183
- "source": [
184
- "## Inference"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": null,
190
- "metadata": {
191
- "id": "XLLA2NK3uDQr"
192
- },
193
- "outputs": [],
194
- "source": [
195
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
196
- ]
197
- },
198
- {
199
- "cell_type": "code",
200
- "execution_count": null,
201
- "metadata": {},
202
- "outputs": [],
203
- "source": [
204
- "input_text = ['I enjoy walking with my cute dog']*8"
205
- ]
206
- },
207
- {
208
- "cell_type": "code",
209
- "execution_count": null,
210
- "metadata": {
211
- "id": "P32mJJSbrU1F"
212
- },
213
- "outputs": [],
214
- "source": [
215
- "input_ids_test = tokenizer(input_text, return_tensors='jax')"
216
- ]
217
- },
218
- {
219
- "cell_type": "code",
220
- "execution_count": null,
221
- "metadata": {},
222
- "outputs": [],
223
- "source": [
224
- "input_ids_test"
225
- ]
226
- },
227
- {
228
- "cell_type": "code",
229
- "execution_count": null,
230
- "metadata": {
231
- "id": "C7cHbIHruELT"
232
- },
233
- "outputs": [],
234
- "source": [
235
- "greedy_output = model.generate(input_ids_test['input_ids'], max_length=257)"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": null,
241
- "metadata": {},
242
- "outputs": [],
243
- "source": [
244
- "greedy_output[0].shape"
245
- ]
246
- },
247
- {
248
- "cell_type": "code",
249
- "execution_count": null,
250
- "metadata": {
251
- "colab": {
252
- "base_uri": "https://localhost:8080/"
253
- },
254
- "id": "jYugh9cOuwc9",
255
- "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
256
- },
257
- "outputs": [],
258
- "source": [
259
- "greedy_output[0]"
260
- ]
261
- },
262
- {
263
- "cell_type": "code",
264
- "execution_count": null,
265
- "metadata": {},
266
- "outputs": [],
267
- "source": [
268
- "greedy_output[0][0]"
269
- ]
270
- },
271
- {
272
- "cell_type": "markdown",
273
- "metadata": {},
274
- "source": [
275
- "# VGAN Jax"
276
- ]
277
- },
278
- {
279
- "cell_type": "code",
280
- "execution_count": null,
281
- "metadata": {},
282
- "outputs": [],
283
- "source": [
284
- "import io\n",
285
- "\n",
286
- "import requests\n",
287
- "from PIL import Image\n",
288
- "import numpy as np\n",
289
- "\n",
290
- "import torch\n",
291
- "import torchvision.transforms as T\n",
292
- "import torchvision.transforms.functional as TF\n",
293
- "from torchvision.transforms import InterpolationMode"
294
- ]
295
- },
296
- {
297
- "cell_type": "code",
298
- "execution_count": null,
299
- "metadata": {},
300
- "outputs": [],
301
- "source": [
302
- "from modeling_flax_vqgan import VQModel"
303
- ]
304
- },
305
- {
306
- "cell_type": "code",
307
- "execution_count": null,
308
- "metadata": {},
309
- "outputs": [],
310
- "source": [
311
- "def custom_to_pil(x):\n",
312
- " x = np.clip(x, 0., 1.)\n",
313
- " x = (255*x).astype(np.uint8)\n",
314
- " x = Image.fromarray(x)\n",
315
- " if not x.mode == \"RGB\":\n",
316
- " x = x.convert(\"RGB\")\n",
317
- " return x"
318
- ]
319
- },
320
- {
321
- "cell_type": "code",
322
- "execution_count": null,
323
- "metadata": {
324
- "colab": {
325
- "base_uri": "https://localhost:8080/"
326
- },
327
- "id": "Jz032w73nHEf",
328
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49",
329
- "scrolled": true
330
- },
331
- "outputs": [],
332
- "source": [
333
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
334
- ]
335
- },
336
- {
337
- "cell_type": "code",
338
- "execution_count": null,
339
- "metadata": {},
340
- "outputs": [],
341
- "source": [
342
- "def get_images(indices, model):\n",
343
- " indices = indices[:, 1:]\n",
344
- " print(indices.shape)\n",
345
- " img = model.decode_code(indices)\n",
346
- " return img"
347
- ]
348
- },
349
- {
350
- "cell_type": "code",
351
- "execution_count": null,
352
- "metadata": {},
353
- "outputs": [],
354
- "source": [
355
- "custom_to_pil(np.asarray(get_images(jnp.expand_dims(greedy_output[0][0],0), model)[0]))"
356
- ]
357
- }
358
- ],
359
- "metadata": {
360
- "accelerator": "TPU",
361
- "colab": {
362
- "collapsed_sections": [],
363
- "machine_shape": "hm",
364
- "name": "CustomBARTv4b-model-generate.ipynb",
365
- "provenance": []
366
- },
367
- "kernelspec": {
368
- "display_name": "Python 3 (ipykernel)",
369
- "language": "python",
370
- "name": "python3"
371
- },
372
- "language_info": {
373
- "codemirror_mode": {
374
- "name": "ipython",
375
- "version": 3
376
- },
377
- "file_extension": ".py",
378
- "mimetype": "text/x-python",
379
- "name": "python",
380
- "nbconvert_exporter": "python",
381
- "pygments_lexer": "ipython3",
382
- "version": "3.8.5"
383
- }
384
- },
385
- "nbformat": 4,
386
- "nbformat_minor": 4
387
- }
dev/notebooks/demo/model-sweep.py DELETED
@@ -1,216 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- import random
5
-
6
- import jax
7
- import flax.linen as nn
8
- from flax.training.common_utils import shard
9
- from flax.jax_utils import replicate, unreplicate
10
-
11
- from transformers.models.bart.modeling_flax_bart import *
12
- from transformers import BartTokenizer, FlaxBartForConditionalGeneration
13
-
14
- from PIL import Image
15
- import numpy as np
16
- import matplotlib.pyplot as plt
17
-
18
- import torchvision.transforms as T
19
- import torchvision.transforms.functional as TF
20
- from torchvision.transforms import InterpolationMode
21
-
22
- from vqgan_jax.modeling_flax_vqgan import VQModel
23
-
24
- # TODO: set those args in a config file
25
- OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
26
- OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
27
- BOS_TOKEN_ID = 16384
28
- BASE_MODEL = 'facebook/bart-large-cnn'
29
- WANDB_MODEL = '3iwhu4w6'
30
-
31
- class CustomFlaxBartModule(FlaxBartModule):
32
- def setup(self):
33
- # we keep shared to easily load pre-trained weights
34
- self.shared = nn.Embed(
35
- self.config.vocab_size,
36
- self.config.d_model,
37
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
38
- dtype=self.dtype,
39
- )
40
- # a separate embedding is used for the decoder
41
- self.decoder_embed = nn.Embed(
42
- OUTPUT_VOCAB_SIZE,
43
- self.config.d_model,
44
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
45
- dtype=self.dtype,
46
- )
47
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
48
-
49
- # the decoder has a different config
50
- decoder_config = BartConfig(self.config.to_dict())
51
- decoder_config.max_position_embeddings = OUTPUT_LENGTH
52
- decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
53
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
54
-
55
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
56
- def setup(self):
57
- self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
58
- self.lm_head = nn.Dense(
59
- OUTPUT_VOCAB_SIZE,
60
- use_bias=False,
61
- dtype=self.dtype,
62
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
63
- )
64
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
65
-
66
- class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
67
- module_class = CustomFlaxBartForConditionalGenerationModule
68
-
69
- tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
70
- vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
71
-
72
- def custom_to_pil(x):
73
- x = np.clip(x, 0., 1.)
74
- x = (255*x).astype(np.uint8)
75
- x = Image.fromarray(x)
76
- if not x.mode == "RGB":
77
- x = x.convert("RGB")
78
- return x
79
-
80
- def generate(input, rng, params):
81
- return model.generate(
82
- **input,
83
- max_length=257,
84
- num_beams=1,
85
- do_sample=True,
86
- prng_key=rng,
87
- eos_token_id=50000,
88
- pad_token_id=50000,
89
- params=params,
90
- )
91
-
92
- def get_images(indices, params):
93
- return vqgan.decode_code(indices, params=params)
94
-
95
- def plot_images(images):
96
- fig = plt.figure(figsize=(40, 20))
97
- columns = 4
98
- rows = 2
99
- plt.subplots_adjust(hspace=0, wspace=0)
100
-
101
- for i in range(1, columns*rows +1):
102
- fig.add_subplot(rows, columns, i)
103
- plt.imshow(images[i-1])
104
- plt.gca().axes.get_yaxis().set_visible(False)
105
- plt.show()
106
-
107
- def stack_reconstructions(images):
108
- w, h = images[0].size[0], images[0].size[1]
109
- img = Image.new("RGB", (len(images)*w, h))
110
- for i, img_ in enumerate(images):
111
- img.paste(img_, (i*w,0))
112
- return img
113
-
114
- p_generate = jax.pmap(generate, "batch")
115
- p_get_images = jax.pmap(get_images, "batch")
116
-
117
- # ## CLIP Scoring
118
- from transformers import CLIPProcessor, FlaxCLIPModel
119
-
120
- clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
121
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
122
-
123
- def hallucinate(prompt, num_images=64):
124
- prompt = [prompt] * jax.device_count()
125
- inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
126
- inputs = shard(inputs)
127
-
128
- all_images = []
129
- for i in range(num_images // jax.device_count()):
130
- key = random.randint(0, 1e7)
131
- rng = jax.random.PRNGKey(key)
132
- rngs = jax.random.split(rng, jax.local_device_count())
133
- indices = p_generate(inputs, rngs, bart_params).sequences
134
- indices = indices[:, :, 1:]
135
-
136
- images = p_get_images(indices, vqgan_params)
137
- images = np.squeeze(np.asarray(images), 1)
138
- for image in images:
139
- all_images.append(custom_to_pil(image))
140
- return all_images
141
-
142
- def clip_top_k(prompt, images, k=8):
143
- inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
144
- outputs = clip(**inputs)
145
- logits = outputs.logits_per_text
146
- scores = np.array(logits[0]).argsort()[-k:][::-1]
147
- return [images[score] for score in scores]
148
-
149
- from PIL import ImageDraw, ImageFont
150
-
151
- def captioned_strip(images, caption):
152
- w, h = images[0].size[0], images[0].size[1]
153
- img = Image.new("RGB", (len(images)*w, h + 48))
154
- for i, img_ in enumerate(images):
155
- img.paste(img_, (i*w, 48))
156
- draw = ImageDraw.Draw(img)
157
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
158
- draw.text((20, 3), caption, (255,255,255), font=font)
159
- return img
160
-
161
- def log_to_wandb(prompts):
162
- strips = []
163
- for prompt in prompts:
164
- print(f"Generating candidates for: {prompt}")
165
- images = hallucinate(prompt, num_images=32)
166
- selected = clip_top_k(prompt, images, k=8)
167
- strip = captioned_strip(selected, prompt)
168
- strips.append(wandb.Image(strip))
169
- wandb.log({"images": strips})
170
-
171
- ## Artifact loop
172
-
173
- import wandb
174
- import os
175
- os.environ["WANDB_SILENT"] = "true"
176
- os.environ["WANDB_CONSOLE"] = "off"
177
-
178
- id = wandb.util.generate_id()
179
- print(f"Logging images to wandb run id: {id}")
180
-
181
- run = wandb.init(id=id,
182
- entity='wandb',
183
- project="hf-flax-dalle-mini",
184
- job_type="predictions",
185
- resume="allow"
186
- )
187
-
188
- artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3iwhu4w6:v0', type='bart_model')
189
- producer_run = artifact.logged_by()
190
- logged_artifacts = producer_run.logged_artifacts()
191
-
192
- for artifact in logged_artifacts:
193
- print(f"Generating predictions with version {artifact.version}")
194
- artifact_dir = artifact.download()
195
-
196
- # create our model
197
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
198
- model.config.force_bos_token_to_be_generated = False
199
- model.config.forced_bos_token_id = None
200
- model.config.forced_eos_token_id = None
201
-
202
- bart_params = replicate(model.params)
203
- vqgan_params = replicate(vqgan.params)
204
-
205
- prompts = prompts = [
206
- "white snow covered mountain under blue sky during daytime",
207
- "aerial view of beach during daytime",
208
- "aerial view of beach at night",
209
- "an armchair in the shape of an avocado",
210
- "young woman riding her bike trough a forest",
211
- "rice fields by the mediterranean coast",
212
- "white houses on the hill of a greek coastline",
213
- "illustration of a shark with a baby shark",
214
- ]
215
-
216
- log_to_wandb(prompts)
dev/notebooks/demo/tpu-demo.ipynb DELETED
@@ -1,446 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "f6d33374",
6
- "metadata": {},
7
- "source": [
8
- "# Test notebook with CLIP scoring"
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": null,
14
- "id": "6eb74941-bb4d-4d7e-97f1-d5a3a07672bf",
15
- "metadata": {},
16
- "outputs": [],
17
- "source": [
18
- "# !pip install flax transformers\n",
19
- "# !git clone https://github.com/patil-suraj/vqgan-jax.git"
20
- ]
21
- },
22
- {
23
- "cell_type": "code",
24
- "execution_count": null,
25
- "id": "41db7534-f589-4b63-9165-9c9799e1b06e",
26
- "metadata": {},
27
- "outputs": [],
28
- "source": [
29
- "import random\n",
30
- "\n",
31
- "import jax\n",
32
- "import flax.linen as nn\n",
33
- "from flax.training.common_utils import shard\n",
34
- "from flax.jax_utils import replicate, unreplicate\n",
35
- "\n",
36
- "from transformers.models.bart.modeling_flax_bart import *\n",
37
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
38
- "\n",
39
- "import io\n",
40
- "\n",
41
- "import requests\n",
42
- "from PIL import Image\n",
43
- "import numpy as np\n",
44
- "import matplotlib.pyplot as plt\n",
45
- "\n",
46
- "import torch\n",
47
- "import torchvision.transforms as T\n",
48
- "import torchvision.transforms.functional as TF\n",
49
- "from torchvision.transforms import InterpolationMode\n",
50
- "\n",
51
- "jax.devices()"
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": null,
57
- "id": "09295910",
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
62
- ]
63
- },
64
- {
65
- "cell_type": "code",
66
- "execution_count": null,
67
- "id": "b6a3462a-9004-4121-b365-3ae3aaf94dd2",
68
- "metadata": {},
69
- "outputs": [],
70
- "source": [
71
- "# TODO: set those args in a config file\n",
72
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
73
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
74
- "BOS_TOKEN_ID = 16384\n",
75
- "BASE_MODEL = 'facebook/bart-large-cnn'"
76
- ]
77
- },
78
- {
79
- "cell_type": "code",
80
- "execution_count": null,
81
- "id": "bbef1afb-0b36-44a5-83f7-643d7e2c0e30",
82
- "metadata": {},
83
- "outputs": [],
84
- "source": [
85
- "class CustomFlaxBartModule(FlaxBartModule):\n",
86
- " def setup(self):\n",
87
- " # we keep shared to easily load pre-trained weights\n",
88
- " self.shared = nn.Embed(\n",
89
- " self.config.vocab_size,\n",
90
- " self.config.d_model,\n",
91
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
92
- " dtype=self.dtype,\n",
93
- " )\n",
94
- " # a separate embedding is used for the decoder\n",
95
- " self.decoder_embed = nn.Embed(\n",
96
- " OUTPUT_VOCAB_SIZE,\n",
97
- " self.config.d_model,\n",
98
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
99
- " dtype=self.dtype,\n",
100
- " )\n",
101
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
102
- "\n",
103
- " # the decoder has a different config\n",
104
- " decoder_config = BartConfig(self.config.to_dict())\n",
105
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
106
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
107
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
108
- "\n",
109
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
110
- " def setup(self):\n",
111
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
112
- " self.lm_head = nn.Dense(\n",
113
- " OUTPUT_VOCAB_SIZE,\n",
114
- " use_bias=False,\n",
115
- " dtype=self.dtype,\n",
116
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
117
- " )\n",
118
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
119
- "\n",
120
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
121
- " module_class = CustomFlaxBartForConditionalGenerationModule"
122
- ]
123
- },
124
- {
125
- "cell_type": "code",
126
- "execution_count": null,
127
- "id": "879320b7-eaa0-4dc9-bbf2-c81efc53301d",
128
- "metadata": {},
129
- "outputs": [],
130
- "source": [
131
- "import wandb\n",
132
- "run = wandb.init()\n",
133
- "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:latest', type='bart_model')\n",
134
- "artifact_dir = artifact.download()"
135
- ]
136
- },
137
- {
138
- "cell_type": "code",
139
- "execution_count": null,
140
- "id": "e8bcff33-e95b-4c01-b162-ee857a55c3e6",
141
- "metadata": {},
142
- "outputs": [],
143
- "source": [
144
- "# create our model\n",
145
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)\n",
146
- "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)\n",
147
- "model.config.force_bos_token_to_be_generated = False\n",
148
- "model.config.forced_bos_token_id = None\n",
149
- "model.config.forced_eos_token_id = None\n",
150
- "\n",
151
- "# we verify that the shape has not been modified\n",
152
- "model.params['final_logits_bias'].shape"
153
- ]
154
- },
155
- {
156
- "cell_type": "code",
157
- "execution_count": null,
158
- "id": "8d5e0f14-2502-470e-9553-daee6748601f",
159
- "metadata": {},
160
- "outputs": [],
161
- "source": [
162
- "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
163
- ]
164
- },
165
- {
166
- "cell_type": "code",
167
- "execution_count": null,
168
- "id": "6cca395a-93c2-49bc-a3be-98287e4403d4",
169
- "metadata": {},
170
- "outputs": [],
171
- "source": [
172
- "def custom_to_pil(x):\n",
173
- " x = np.clip(x, 0., 1.)\n",
174
- " x = (255*x).astype(np.uint8)\n",
175
- " x = Image.fromarray(x)\n",
176
- " if not x.mode == \"RGB\":\n",
177
- " x = x.convert(\"RGB\")\n",
178
- " return x\n",
179
- "\n",
180
- "def generate(input, rng, params):\n",
181
- " return model.generate(\n",
182
- " **input,\n",
183
- " max_length=257,\n",
184
- " num_beams=1,\n",
185
- " do_sample=True,\n",
186
- " prng_key=rng,\n",
187
- " eos_token_id=50000,\n",
188
- " pad_token_id=50000,\n",
189
- " params=params\n",
190
- " )\n",
191
- "\n",
192
- "def get_images(indices, params):\n",
193
- " return vqgan.decode_code(indices, params=params)\n",
194
- "\n",
195
- "\n",
196
- "def plot_images(images):\n",
197
- " fig = plt.figure(figsize=(40, 20))\n",
198
- " columns = 4\n",
199
- " rows = 2\n",
200
- " plt.subplots_adjust(hspace=0, wspace=0)\n",
201
- "\n",
202
- " for i in range(1, columns*rows +1):\n",
203
- " fig.add_subplot(rows, columns, i)\n",
204
- " plt.imshow(images[i-1])\n",
205
- " plt.gca().axes.get_yaxis().set_visible(False)\n",
206
- " plt.show()\n",
207
- " \n",
208
- "def stack_reconstructions(images):\n",
209
- " w, h = images[0].size[0], images[0].size[1]\n",
210
- " img = Image.new(\"RGB\", (len(images)*w, h))\n",
211
- " for i, img_ in enumerate(images):\n",
212
- " img.paste(img_, (i*w,0))\n",
213
- " return img"
214
- ]
215
- },
216
- {
217
- "cell_type": "code",
218
- "execution_count": null,
219
- "id": "b1bec3d2-ef17-4feb-aa0d-b51ed2fdcd3e",
220
- "metadata": {},
221
- "outputs": [],
222
- "source": [
223
- "p_generate = jax.pmap(generate, \"batch\")\n",
224
- "p_get_images = jax.pmap(get_images, \"batch\")"
225
- ]
226
- },
227
- {
228
- "cell_type": "code",
229
- "execution_count": null,
230
- "id": "a539823a-a775-4d92-96a5-dc8b1eef69c5",
231
- "metadata": {},
232
- "outputs": [],
233
- "source": [
234
- "bart_params = replicate(model.params)\n",
235
- "vqgan_params = replicate(vqgan.params)"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": null,
241
- "id": "e8b268d8-6992-422a-8373-95651474ae70",
242
- "metadata": {},
243
- "outputs": [],
244
- "source": [
245
- "prompts = [\n",
246
- " \"man in blue jacket walking on pathway in between trees during daytime\",\n",
247
- " 'white snow covered mountain under blue sky during daytime',\n",
248
- " 'white snow covered mountain under blue sky during night',\n",
249
- " \"orange tabby cat on persons hand\",\n",
250
- " \"aerial view of beach during daytime\",\n",
251
- " \"chess pieces on chess board\",\n",
252
- " \"laptop on brown wooden table\",\n",
253
- " \"white bus on road near high rise buildings\",\n",
254
- "]\n",
255
- "\n",
256
- "\n",
257
- "prompt = [prompts[1]] * jax.device_count()\n",
258
- "inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
259
- "inputs = shard(inputs)"
260
- ]
261
- },
262
- {
263
- "cell_type": "code",
264
- "execution_count": null,
265
- "id": "68638cfa-9a4d-4e6a-8630-91aefb627bbd",
266
- "metadata": {},
267
- "outputs": [],
268
- "source": [
269
- "%%time\n",
270
- "for i in range(8):\n",
271
- " key = random.randint(0, 1e7)\n",
272
- " rng = jax.random.PRNGKey(key)\n",
273
- " rngs = jax.random.split(rng, jax.local_device_count())\n",
274
- " indices = p_generate(inputs, rngs, bart_params).sequences\n",
275
- " indices = indices[:, :, 1:]\n",
276
- "\n",
277
- " images = p_get_images(indices, vqgan_params)\n",
278
- " images = np.squeeze(np.asarray(images), 1)\n",
279
- " imges = [custom_to_pil(image) for image in images]\n",
280
- "\n",
281
- " plt.figure(figsize=(40, 20))\n",
282
- " plt.imshow(stack_reconstructions(imges))"
283
- ]
284
- },
285
- {
286
- "cell_type": "markdown",
287
- "id": "b6e1060f",
288
- "metadata": {},
289
- "source": [
290
- "## CLIP Scoring"
291
- ]
292
- },
293
- {
294
- "cell_type": "code",
295
- "execution_count": null,
296
- "id": "c68724bc",
297
- "metadata": {},
298
- "outputs": [],
299
- "source": [
300
- "from transformers import CLIPProcessor, FlaxCLIPModel"
301
- ]
302
- },
303
- {
304
- "cell_type": "code",
305
- "execution_count": null,
306
- "id": "17158e5b",
307
- "metadata": {},
308
- "outputs": [],
309
- "source": [
310
- "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
311
- "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
312
- ]
313
- },
314
- {
315
- "cell_type": "code",
316
- "execution_count": null,
317
- "id": "f1b37b6d",
318
- "metadata": {},
319
- "outputs": [],
320
- "source": [
321
- "def hallucinate(prompt, num_images=64):\n",
322
- " prompt = [prompt] * jax.device_count()\n",
323
- " inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
324
- " inputs = shard(inputs)\n",
325
- "\n",
326
- " all_images = []\n",
327
- " for i in range(num_images // jax.device_count()):\n",
328
- " key = random.randint(0, 1e7)\n",
329
- " rng = jax.random.PRNGKey(key)\n",
330
- " rngs = jax.random.split(rng, jax.local_device_count())\n",
331
- " indices = p_generate(inputs, rngs, bart_params).sequences\n",
332
- " indices = indices[:, :, 1:]\n",
333
- "\n",
334
- " images = p_get_images(indices, vqgan_params)\n",
335
- " images = np.squeeze(np.asarray(images), 1)\n",
336
- " for image in images:\n",
337
- " all_images.append(custom_to_pil(image))\n",
338
- " return all_images"
339
- ]
340
- },
341
- {
342
- "cell_type": "code",
343
- "execution_count": null,
344
- "id": "831c715f",
345
- "metadata": {},
346
- "outputs": [],
347
- "source": [
348
- "def clip_top_k(prompt, images, k=8):\n",
349
- " inputs = processor(text=prompt, images=images, return_tensors=\"np\", padding=True)\n",
350
- " outputs = clip(**inputs)\n",
351
- " logits = outputs.logits_per_text\n",
352
- " scores = np.array(logits[0]).argsort()[-k:][::-1]\n",
353
- " return [images[score] for score in scores]"
354
- ]
355
- },
356
- {
357
- "cell_type": "code",
358
- "execution_count": null,
359
- "id": "00605e13",
360
- "metadata": {},
361
- "outputs": [],
362
- "source": [
363
- "prompt = \"white snow covered mountain under blue sky during daytime\"\n",
364
- "images = hallucinate(prompt)\n",
365
- "selected = clip_top_k(prompt, images, k=8)\n",
366
- "stack_reconstructions(selected)"
367
- ]
368
- },
369
- {
370
- "cell_type": "code",
371
- "execution_count": null,
372
- "id": "cc745da2",
373
- "metadata": {},
374
- "outputs": [],
375
- "source": [
376
- "prompt = \"aerial view of beach at night\"\n",
377
- "images = hallucinate(prompt)\n",
378
- "selected = clip_top_k(prompt, images, k=8)\n",
379
- "stack_reconstructions(selected)"
380
- ]
381
- },
382
- {
383
- "cell_type": "code",
384
- "execution_count": null,
385
- "id": "c9cc0b1d",
386
- "metadata": {},
387
- "outputs": [],
388
- "source": [
389
- "prompt = \"an armchair in the shape of an avocado\"\n",
390
- "images = hallucinate(prompt)\n",
391
- "selected = clip_top_k(prompt, images, k=8)\n",
392
- "stack_reconstructions(selected)"
393
- ]
394
- },
395
- {
396
- "cell_type": "code",
397
- "execution_count": null,
398
- "id": "574e9433",
399
- "metadata": {},
400
- "outputs": [],
401
- "source": [
402
- "prompt = \"young woman riding her bike into a forest\"\n",
403
- "images = hallucinate(prompt)\n",
404
- "selected = clip_top_k(prompt, images, k=8)\n",
405
- "stack_reconstructions(selected)"
406
- ]
407
- },
408
- {
409
- "cell_type": "markdown",
410
- "id": "4762c91e",
411
- "metadata": {},
412
- "source": [
413
- "`Forest` seems to dominate. Interesting cubist interpretation in the fourth image."
414
- ]
415
- },
416
- {
417
- "cell_type": "code",
418
- "execution_count": null,
419
- "id": "af30608a",
420
- "metadata": {},
421
- "outputs": [],
422
- "source": []
423
- }
424
- ],
425
- "metadata": {
426
- "kernelspec": {
427
- "display_name": "Python 3 (ipykernel)",
428
- "language": "python",
429
- "name": "python3"
430
- },
431
- "language_info": {
432
- "codemirror_mode": {
433
- "name": "ipython",
434
- "version": 3
435
- },
436
- "file_extension": ".py",
437
- "mimetype": "text/x-python",
438
- "name": "python",
439
- "nbconvert_exporter": "python",
440
- "pygments_lexer": "ipython3",
441
- "version": "3.8.10"
442
- }
443
- },
444
- "nbformat": 4,
445
- "nbformat_minor": 5
446
- }
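
The notebook deleted above ends by generating a pool of candidate images per prompt and keeping only the ones CLIP scores highest against the prompt. For reference, a self-contained sketch of that re-ranking step, assuming `candidates` is a list of PIL images coming out of the generator; the checkpoint name and the `logits_per_text` scoring follow the removed cells, the rest is illustrative rather than the project's exact code:

```python
import numpy as np
from transformers import CLIPProcessor, FlaxCLIPModel

# Same checkpoint as in the removed notebook; load once and reuse.
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_top_k(prompt, images, k=8):
    """Return the k images whose CLIP text-image score for `prompt` is highest."""
    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
    logits = clip(**inputs).logits_per_text          # shape (1, num_images)
    best = np.array(logits[0]).argsort()[-k:][::-1]  # indices of the k best, best first
    return [images[i] for i in best]

# e.g. selected = clip_top_k("an armchair in the shape of an avocado", candidates, k=8)
```

Generating a larger pool and letting CLIP pick the winners is why the notebook samples 64 candidates per prompt but only displays the top 8.
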
dev/notebooks/model/data-pipeline.ipynb DELETED
@@ -1,385 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "bf8fb38a",
6
- "metadata": {},
7
- "source": [
8
- "# Data Pipeline"
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": 1,
14
- "id": "9b83dcb9",
15
- "metadata": {},
16
- "outputs": [],
17
- "source": [
18
- "from dataclasses import dataclass, field\n",
19
- "from pathlib import Path\n",
20
- "\n",
21
- "import datasets\n",
22
- "from datasets import Dataset, load_dataset\n",
23
- "import numpy as np\n",
24
- "\n",
25
- "from transformers import BartTokenizer\n",
26
- "\n",
27
- "from tqdm import tqdm\n",
28
- "\n",
29
- "import jax\n",
30
- "import jax.numpy as jnp\n",
31
- "\n",
32
- "from flax.training.common_utils import shard"
33
- ]
34
- },
35
- {
36
- "cell_type": "markdown",
37
- "id": "a661a89e",
38
- "metadata": {},
39
- "source": [
40
- "File containing image paths, captions and VQGAN-encoded indices."
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": 2,
46
- "id": "0e84e889",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M"
51
- ]
52
- },
53
- {
54
- "cell_type": "markdown",
55
- "id": "7fdc640b",
56
- "metadata": {},
57
- "source": [
58
- "TODO: generate train/test splits if necessary."
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": 3,
64
- "id": "cc6789b4",
65
- "metadata": {},
66
- "outputs": [
67
- {
68
- "name": "stderr",
69
- "output_type": "stream",
70
- "text": [
71
- "Using custom data configuration default-91833df78e844785\n",
72
- "Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
73
- ]
74
- }
75
- ],
76
- "source": [
77
- "dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
78
- ]
79
- },
80
- {
81
- "cell_type": "code",
82
- "execution_count": 4,
83
- "id": "f3ed4919",
84
- "metadata": {},
85
- "outputs": [
86
- {
87
- "data": {
88
- "text/plain": [
89
- "DatasetDict({\n",
90
- " train: Dataset({\n",
91
- " features: ['image_file', 'caption', 'encoding'],\n",
92
- " num_rows: 9999\n",
93
- " })\n",
94
- "})"
95
- ]
96
- },
97
- "execution_count": 4,
98
- "metadata": {},
99
- "output_type": "execute_result"
100
- }
101
- ],
102
- "source": [
103
- "dataset"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 5,
109
- "id": "a70c7354",
110
- "metadata": {},
111
- "outputs": [
112
- {
113
- "data": {
114
- "text/plain": [
115
- "Dataset({\n",
116
- " features: ['image_file', 'caption', 'encoding'],\n",
117
- " num_rows: 9999\n",
118
- "})"
119
- ]
120
- },
121
- "execution_count": 5,
122
- "metadata": {},
123
- "output_type": "execute_result"
124
- }
125
- ],
126
- "source": [
127
- "dataset = dataset[\"train\"]\n",
128
- "dataset"
129
- ]
130
- },
131
- {
132
- "cell_type": "markdown",
133
- "id": "a73454cf",
134
- "metadata": {},
135
- "source": [
136
- "We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
137
- ]
138
- },
139
- {
140
- "cell_type": "markdown",
141
- "id": "7c0fa992",
142
- "metadata": {},
143
- "source": [
144
- "## Preprocessing"
145
- ]
146
- },
147
- {
148
- "cell_type": "markdown",
149
- "id": "a0e36582",
150
- "metadata": {},
151
- "source": [
152
- "The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
153
- ]
154
- },
155
- {
156
- "cell_type": "code",
157
- "execution_count": 6,
158
- "id": "d46f6ac5",
159
- "metadata": {},
160
- "outputs": [],
161
- "source": [
162
- "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
163
- "max_length = 256 # Read from data_args.max_source_length\n",
164
- "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
165
- "image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
166
- ]
167
- },
168
- {
169
- "cell_type": "code",
170
- "execution_count": 7,
171
- "id": "4cac6643",
172
- "metadata": {},
173
- "outputs": [],
174
- "source": [
175
- "def preprocess_function(examples):\n",
176
- " inputs = examples[\"caption\"]\n",
177
- "# inputs = [prefix + inp for inp in inputs] # Do we need this?\n",
178
- " model_inputs = tokenizer(\n",
179
- " inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
180
- " )\n",
181
- "\n",
182
- " model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
183
- "\n",
184
- " return model_inputs"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": 8,
190
- "id": "e6a4cb91",
191
- "metadata": {},
192
- "outputs": [],
193
- "source": [
194
- "num_workers = 48 # We have 96 processors in the TPU\n",
195
- "column_names = dataset.column_names\n",
196
- "input_dataset = dataset.map(preprocess_function,\n",
197
- " remove_columns=column_names,\n",
198
- " batched=True,\n",
199
- " num_proc=48\n",
200
- ")"
201
- ]
202
- },
203
- {
204
- "cell_type": "code",
205
- "execution_count": 9,
206
- "id": "a9b1b467",
207
- "metadata": {},
208
- "outputs": [],
209
- "source": [
210
- "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
211
- " \"\"\"\n",
212
- " Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
213
- " Shuffle batches if `shuffle` is `True`.\n",
214
- " \"\"\"\n",
215
- " steps_per_epoch = len(dataset) // batch_size\n",
216
- "\n",
217
- " if shuffle:\n",
218
- " batch_idx = jax.random.permutation(rng, len(dataset))\n",
219
- " else:\n",
220
- " batch_idx = jnp.arange(len(dataset))\n",
221
- "\n",
222
- " batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
223
- " batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
224
- "\n",
225
- " for idx in batch_idx:\n",
226
- " batch = dataset[idx] \n",
227
- " batch = {k: jnp.array(v) for k, v in batch.items()}\n",
228
- " batch = shard(batch)\n",
229
- " yield batch"
230
- ]
231
- },
232
- {
233
- "cell_type": "code",
234
- "execution_count": 10,
235
- "id": "0a628505",
236
- "metadata": {},
237
- "outputs": [
238
- {
239
- "name": "stderr",
240
- "output_type": "stream",
241
- "text": [
242
- "INFO:absl:Starting the local TPU driver.\n",
243
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
244
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
245
- ]
246
- }
247
- ],
248
- "source": [
249
- "rng = jax.random.PRNGKey(23) # Use training_args.seed\n",
250
- "batch_size = 64 # Per device\n",
251
- "super_batch_size = batch_size * jax.device_count()"
252
- ]
253
- },
254
- {
255
- "cell_type": "code",
256
- "execution_count": 11,
257
- "id": "b3a5ce7d",
258
- "metadata": {},
259
- "outputs": [],
260
- "source": [
261
- "loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
262
- ]
263
- },
264
- {
265
- "cell_type": "code",
266
- "execution_count": 12,
267
- "id": "67aa8f9c",
268
- "metadata": {},
269
- "outputs": [],
270
- "source": [
271
- "superbatch = next(iter(loader))"
272
- ]
273
- },
274
- {
275
- "cell_type": "code",
276
- "execution_count": 13,
277
- "id": "7cd99402",
278
- "metadata": {},
279
- "outputs": [
280
- {
281
- "data": {
282
- "text/plain": [
283
- "dict_keys(['attention_mask', 'input_ids', 'labels'])"
284
- ]
285
- },
286
- "execution_count": 13,
287
- "metadata": {},
288
- "output_type": "execute_result"
289
- }
290
- ],
291
- "source": [
292
- "superbatch.keys()"
293
- ]
294
- },
295
- {
296
- "cell_type": "code",
297
- "execution_count": 14,
298
- "id": "652a4a9e",
299
- "metadata": {},
300
- "outputs": [
301
- {
302
- "data": {
303
- "text/plain": [
304
- "8"
305
- ]
306
- },
307
- "execution_count": 14,
308
- "metadata": {},
309
- "output_type": "execute_result"
310
- }
311
- ],
312
- "source": [
313
- "len(superbatch[\"labels\"])"
314
- ]
315
- },
316
- {
317
- "cell_type": "code",
318
- "execution_count": 15,
319
- "id": "de7de4e8",
320
- "metadata": {},
321
- "outputs": [
322
- {
323
- "data": {
324
- "text/plain": [
325
- "(8, 64, 257)"
326
- ]
327
- },
328
- "execution_count": 15,
329
- "metadata": {},
330
- "output_type": "execute_result"
331
- }
332
- ],
333
- "source": [
334
- "superbatch[\"labels\"].shape"
335
- ]
336
- },
337
- {
338
- "cell_type": "markdown",
339
- "id": "6800153b",
340
- "metadata": {},
341
- "source": [
342
- "Any image sequence should begin with `image_bos`:"
343
- ]
344
- },
345
- {
346
- "cell_type": "code",
347
- "execution_count": 16,
348
- "id": "cfe23a71",
349
- "metadata": {},
350
- "outputs": [],
351
- "source": [
352
- "assert superbatch[\"labels\"][1][5][0].item() == image_bos"
353
- ]
354
- },
355
- {
356
- "cell_type": "code",
357
- "execution_count": null,
358
- "id": "0fb899b4",
359
- "metadata": {},
360
- "outputs": [],
361
- "source": []
362
- }
363
- ],
364
- "metadata": {
365
- "kernelspec": {
366
- "display_name": "Python 3 (ipykernel)",
367
- "language": "python",
368
- "name": "python3"
369
- },
370
- "language_info": {
371
- "codemirror_mode": {
372
- "name": "ipython",
373
- "version": 3
374
- },
375
- "file_extension": ".py",
376
- "mimetype": "text/x-python",
377
- "name": "python",
378
- "nbconvert_exporter": "python",
379
- "pygments_lexer": "ipython3",
380
- "version": "3.8.10"
381
- }
382
- },
383
- "nbformat": 4,
384
- "nbformat_minor": 5
385
- }
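
The data-pipeline notebook deleted above covers the full path from a TSV of captions and VQGAN codes to device-sharded JAX batches: captions are tokenized to a fixed length, the string-encoded VQGAN indices are turned back into integers with an image-BOS token prepended, and each batch is split across local devices. A condensed sketch of those steps, keeping the constants and column names from the notebook; the explicit integer parsing replaces the notebook's `eval`, and the rest is illustrative:

```python
import json

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard
from transformers import BartTokenizer

max_length = 256   # fixed caption length so jitted functions see static shapes
image_bos = 16384  # one past the largest VQGAN code (16383); marks "start of image"
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

def preprocess_function(examples):
    # Tokenize captions to a fixed length and prepend image_bos to every code sequence.
    model_inputs = tokenizer(
        examples["caption"], max_length=max_length,
        padding="max_length", truncation=True, return_tensors="np",
    )
    model_inputs["labels"] = [
        [image_bos] + json.loads(indices) for indices in examples["encoding"]
    ]
    return model_inputs

def data_loader(rng, dataset, batch_size, shuffle=False):
    # Yield dicts of jnp arrays with a leading device axis; the last partial batch is dropped.
    steps_per_epoch = len(dataset) // batch_size
    order = jax.random.permutation(rng, len(dataset)) if shuffle else jnp.arange(len(dataset))
    order = order[: steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))
    for idx in order:
        batch = dataset[[int(i) for i in idx]]           # datasets.Dataset batch indexing
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield shard(batch)                               # (devices, per-device batch, ...)
```

`batch_size` here is the per-host total (per-device batch times `jax.local_device_count()`), which is why the notebook's superbatch of 64 per device on 8 devices ends up with labels of shape (8, 64, 257): 256 VQGAN codes plus the prepended `image_bos`.
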
dev/predictions/wandb-examples-from-backend.py DELETED
@@ -1,52 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- from PIL import Image, ImageDraw, ImageFont
5
- import wandb
6
- import os
7
-
8
- from dalle_mini.backend import ServiceError, get_images_from_backend
9
- from dalle_mini.helpers import captioned_strip
10
-
11
- os.environ["WANDB_SILENT"] = "true"
12
- os.environ["WANDB_CONSOLE"] = "off"
13
-
14
- # set id to None so our latest images don't get overwritten
15
- id = None
16
- run = wandb.init(id=id,
17
- entity='wandb',
18
- project="hf-flax-dalle-mini",
19
- job_type="predictions",
20
- resume="allow"
21
- )
22
-
23
- def log_to_wandb(prompts):
24
- try:
25
- backend_url = os.environ["BACKEND_SERVER"]
26
-
27
- strips = []
28
- for prompt in prompts:
29
- print(f"Getting selections for: {prompt}")
30
- selected = get_images_from_backend(prompt, backend_url)
31
- strip = captioned_strip(selected, prompt)
32
- strips.append(wandb.Image(strip))
33
- wandb.log({"images": strips})
34
- except ServiceError as error:
35
- print(f"Service unavailable, status: {error.status_code}")
36
- except KeyError:
37
- print("Error: BACKEND_SERVER unset")
38
-
39
- prompts = [
40
- "white snow covered mountain under blue sky during daytime",
41
- "aerial view of beach during daytime",
42
- "aerial view of beach at night",
43
- "an armchair in the shape of an avocado",
44
- "a logo of an avocado armchair playing music",
45
- "young woman riding her bike trough a forest",
46
- "rice fields by the mediterranean coast",
47
- "white houses on the hill of a greek coastline",
48
- "illustration of a shark with a baby shark",
49
- "painting of an oniric forest glade surrounded by tall trees",
50
- ]
51
-
52
- log_to_wandb(prompts)
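
The removed script attached to a W&B run and logged one captioned strip per prompt. A minimal sketch of that logging pattern, with a hypothetical `render_strip(prompt)` standing in for the backend call plus strip composition; the entity and project names are the ones in the removed file:

```python
import wandb

def log_strips(prompts, render_strip):
    # render_strip is a stand-in: it should return a PIL image (the captioned strip) for a prompt.
    run = wandb.init(entity="wandb", project="hf-flax-dalle-mini",
                     job_type="predictions", resume="allow")
    strips = [wandb.Image(render_strip(p), caption=p) for p in prompts]
    wandb.log({"images": strips})  # one panel with all strips, logged side by side
    run.finish()
```

Backend error handling (the `ServiceError` branch in the removed script) is omitted here for brevity.
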
dev/{notebooks/vqgan → vqgan}/JAX_VQGAN_f16_16384_Reconstruction.ipynb RENAMED
File without changes