boris committed on
Commit fcac23a
2 Parent(s): 00ed1ab 0497ad3

Merge branch 'main' into chore-cleanup2

Former-commit-id: fceea4c2ebba28dd9c0b43b92c2a0af41fc18bb3

README.md CHANGED
@@ -3,8 +3,8 @@ title: Dalle Mini
  emoji: 🎨
  colorFrom: red
  colorTo: blue
- sdk: streamlit
- app_file: app/app.py
+ sdk: gradio
+ app_file: app/app_gradio.py
  pinned: false
  ---

@@ -18,7 +18,7 @@ TODO: add some cool example

  ## How does it work?

- Refer to [our report](TODO).
+ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).

  ## Development

app/app_gradio.py CHANGED
@@ -163,7 +163,7 @@ def clip_top_k(prompt, images, k=8):
      scores = np.array(logits[0]).argsort()[-k:][::-1]
      return [images[score] for score in scores]

- def captioned_strip(images, caption):
+ def compose_predictions(images, caption=None):
      increased_h = 0 if caption is None else 48
      w, h = images[0].size[0], images[0].size[1]
      img = Image.new("RGB", (len(images)*w, h + increased_h))
@@ -176,19 +176,55 @@ def captioned_strip(images, caption):
          draw.text((20, 3), caption, (255,255,255), font=font)
      return img

- def run_inference(prompt, num_images=64, num_preds=8):
-     images = hallucinate(prompt, num_images=num_images)
-     images = clip_top_k(prompt, images, k=num_preds)
-     predictions_strip = captioned_strip(images, None)
-     return predictions_strip
-
+ def top_k_predictions(prompt, num_candidates=32, k=8):
+     images = hallucinate(prompt, num_images=num_candidates)
+     images = clip_top_k(prompt, images, k=k)
+     return images
+
+ def run_inference(prompt, num_images=32, num_preds=8):
+     images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
+     predictions = compose_predictions(images)
+     output_title = f"""
+     <p style="font-size:22px; font-style:bold">Best predictions</p>
+     <p>We asked our model to generate 32 candidates for your prompt:</p>
+
+     <pre>
+
+     <b>{prompt}</b>
+     </pre>
+     <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
+     similarity of the text and the image representations.</p>
+
+     <p>This is the result:</p>
+     """
+     output_description = """
+     <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
+     <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
+     """
+     return (output_title, predictions, output_description)
+
+ outputs = [
+     gr.outputs.HTML(label=""),    # To be used as title
+     gr.outputs.Image(label=''),
+     gr.outputs.HTML(label=""),    # Additional text that appears in the screenshot
+ ]
+
+ description = """
+ Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
+ It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
+
+ Please, write what you would like the model to generate, or select one of the examples below.
+ """
  gr.Interface(run_inference,
      inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
-     outputs=gr.outputs.Image(label='Generated image'),
-     title='DALLE-mini - HuggingFace Community Week',
-     description='This is a demo of the DALLE-mini model trained with Jax/Flax on TPU v3-8s during the HuggingFace Community Week',
-     article="<p style='text-align: center'> DALLE-mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
+     outputs=outputs,
+     title='DALL·E mini',
+     description=description,
+     article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
      layout='vertical',
      theme='huggingface',
-     examples=[['an armchair in the shape of an avocado']],
-     server_port=8999).launch(share=True)
+     examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
+     allow_flagging=False,
+     live=False,
+     # server_port=8999
+     ).launch()
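The output text added above describes the demo's two-stage pipeline: the model proposes 32 candidate images, and a pre-trained CLIP model then scores each candidate by how well it matches the prompt. Below is a minimal, self-contained sketch of that re-ranking step; it mirrors the `clip_top_k` helper used by `app_gradio.py`, with solid-colour placeholder images standing in for real model outputs.

```python
# Sketch of the CLIP re-ranking step (placeholder images instead of generated candidates).
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_top_k(prompt, images, k=8):
    # Score every image against the prompt and keep the k best matches.
    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
    logits = clip(**inputs).logits_per_text          # shape: (1, num_images)
    best = np.array(logits[0]).argsort()[-k:][::-1]  # indices sorted by descending score
    return [images[i] for i in best]

# Placeholder candidates; in the demo these come from hallucinate(prompt).
candidates = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "white")]
print(len(clip_top_k("a blue square", candidates, k=2)))  # -> 2
```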
app/sample_images/image_0.jpg ADDED
app/sample_images/image_1.jpg ADDED
app/sample_images/image_2.jpg ADDED
app/sample_images/image_3.jpg ADDED
app/sample_images/image_4.jpg ADDED
app/sample_images/image_5.jpg ADDED
app/sample_images/image_6.jpg ADDED
app/sample_images/image_7.jpg ADDED
app/sample_images/readme.txt ADDED
@@ -0,0 +1 @@
+ These images were generated by one of our checkpoints, as responses to the prompt "snowy mountains by the sea".
app/ui_gradio.py ADDED
@@ -0,0 +1,91 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ from PIL import Image, ImageDraw, ImageFont
+ import gradio as gr
+
+ def compose_predictions(images, caption=None):
+     increased_h = 0 if caption is None else 48
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images)*w, h + increased_h))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i*w, increased_h))
+
+     if caption is not None:
+         draw = ImageDraw.Draw(img)
+         font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
+         draw.text((20, 3), caption, (255,255,255), font=font)
+     return img
+
+ def compose_predictions_grid(images):
+     cols = 4
+     rows = len(images) // cols
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (w * cols, h * rows))
+     for i, img_ in enumerate(images):
+         row = i // cols
+         col = i % cols
+         img.paste(img_, (w * col, h * row))
+     return img
+
+ def top_k_predictions_real(prompt, num_candidates=32, k=8):
+     images = hallucinate(prompt, num_images=num_candidates)
+     images = clip_top_k(prompt, images, k=k)
+     return images
+
+ def top_k_predictions(prompt, num_candidates=32, k=8):
+     images = []
+     for i in range(k):
+         image = Image.open(f"sample_images/image_{i}.jpg")
+         images.append(image)
+     return images
+
+ def run_inference(prompt, num_images=32, num_preds=8):
+     images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
+     predictions = compose_predictions(images)
+     output_title = f"""
+     <p style="font-size:22px; font-style:bold">Best predictions</p>
+     <p>We asked our model to generate 32 candidates for your prompt:</p>
+
+     <pre>
+
+     <b>{prompt}</b>
+     </pre>
+     <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
+     similarity of the text and the image representations.</p>
+
+     <p>This is the result:</p>
+     """
+     output_description = """
+     <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
+     <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
+     """
+     return (output_title, predictions, output_description)
+
+ outputs = [
+     gr.outputs.HTML(label=""),    # To be used as title
+     gr.outputs.Image(label=''),
+     gr.outputs.HTML(label=""),    # Additional text that appears in the screenshot
+ ]
+
+ description = """
+ Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
+ It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
+
+ Please, write what you would like the model to generate, or select one of the examples below.
+ """
+ gr.Interface(run_inference,
+     inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
+     outputs=outputs,
+     title='DALL·E mini',
+     description=description,
+     article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
+     layout='vertical',
+     theme='huggingface',
+     examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
+     allow_flagging=False,
+     live=False,
+     server_port=8999
+ ).launch(
+     share=True  # Creates temporary public link if true
+ )
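In both `app_gradio.py` and the `ui_gradio.py` mock above, `run_inference` returns a `(title, image, description)` tuple, and Gradio maps it positionally onto the three-element `outputs` list. The sketch below isolates just that wiring with the same legacy gradio 2.x `gr.inputs`/`gr.outputs` API; `fake_inference` and its placeholder image are illustrative and not part of the commit.

```python
# Sketch: a 3-tuple return value is matched positionally to three output components.
import gradio as gr
from PIL import Image

def fake_inference(prompt):
    title = f'<p style="font-size:22px">Best predictions for <b>{prompt}</b></p>'
    image = Image.new("RGB", (256, 256), "gray")  # stands in for the real predictions strip
    footer = "<p style='text-align: center'>Created with DALLE·mini</p>"
    return title, image, footer                   # HTML, Image, HTML: same order as outputs

gr.Interface(
    fake_inference,
    inputs=[gr.inputs.Textbox(label="Prompt")],
    outputs=[gr.outputs.HTML(label=""), gr.outputs.Image(label=""), gr.outputs.HTML(label="")],
).launch()
```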
demo/model-sweep.py ADDED
@@ -0,0 +1,220 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ import random
+
+ import jax
+ import flax.linen as nn
+ from flax.training.common_utils import shard
+ from flax.jax_utils import replicate, unreplicate
+
+ from transformers.models.bart.modeling_flax_bart import *
+ from transformers import BartTokenizer, FlaxBartForConditionalGeneration
+
+ import io
+
+ import requests
+ from PIL import Image
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ import torch
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+ from torchvision.transforms import InterpolationMode
+
+ from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+
+ # TODO: set those args in a config file
+ OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
+ OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
+ BOS_TOKEN_ID = 16384
+ BASE_MODEL = 'facebook/bart-large-cnn'
+ WANDB_MODEL = '3iwhu4w6'
+
+ class CustomFlaxBartModule(FlaxBartModule):
+     def setup(self):
+         # we keep shared to easily load pre-trained weights
+         self.shared = nn.Embed(
+             self.config.vocab_size,
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+             dtype=self.dtype,
+         )
+         # a separate embedding is used for the decoder
+         self.decoder_embed = nn.Embed(
+             OUTPUT_VOCAB_SIZE,
+             self.config.d_model,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+             dtype=self.dtype,
+         )
+         self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+
+         # the decoder has a different config
+         decoder_config = BartConfig(self.config.to_dict())
+         decoder_config.max_position_embeddings = OUTPUT_LENGTH
+         decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
+         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
+
+ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
+     def setup(self):
+         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
+         self.lm_head = nn.Dense(
+             OUTPUT_VOCAB_SIZE,
+             use_bias=False,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+         )
+         self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
+
+ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
+     module_class = CustomFlaxBartForConditionalGenerationModule
+
+ tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
+ vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
+
+ def custom_to_pil(x):
+     x = np.clip(x, 0., 1.)
+     x = (255*x).astype(np.uint8)
+     x = Image.fromarray(x)
+     if not x.mode == "RGB":
+         x = x.convert("RGB")
+     return x
+
+ def generate(input, rng, params):
+     return model.generate(
+         **input,
+         max_length=257,
+         num_beams=1,
+         do_sample=True,
+         prng_key=rng,
+         eos_token_id=50000,
+         pad_token_id=50000,
+         params=params,
+     )
+
+ def get_images(indices, params):
+     return vqgan.decode_code(indices, params=params)
+
+ def plot_images(images):
+     fig = plt.figure(figsize=(40, 20))
+     columns = 4
+     rows = 2
+     plt.subplots_adjust(hspace=0, wspace=0)
+
+     for i in range(1, columns*rows +1):
+         fig.add_subplot(rows, columns, i)
+         plt.imshow(images[i-1])
+         plt.gca().axes.get_yaxis().set_visible(False)
+     plt.show()
+
+ def stack_reconstructions(images):
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images)*w, h))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i*w,0))
+     return img
+
+ p_generate = jax.pmap(generate, "batch")
+ p_get_images = jax.pmap(get_images, "batch")
+
+ # ## CLIP Scoring
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ def hallucinate(prompt, num_images=64):
+     prompt = [prompt] * jax.device_count()
+     inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
+     inputs = shard(inputs)
+
+     all_images = []
+     for i in range(num_images // jax.device_count()):
+         key = random.randint(0, 1e7)
+         rng = jax.random.PRNGKey(key)
+         rngs = jax.random.split(rng, jax.local_device_count())
+         indices = p_generate(inputs, rngs, bart_params).sequences
+         indices = indices[:, :, 1:]
+
+         images = p_get_images(indices, vqgan_params)
+         images = np.squeeze(np.asarray(images), 1)
+         for image in images:
+             all_images.append(custom_to_pil(image))
+     return all_images
+
+ def clip_top_k(prompt, images, k=8):
+     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
+     outputs = clip(**inputs)
+     logits = outputs.logits_per_text
+     scores = np.array(logits[0]).argsort()[-k:][::-1]
+     return [images[score] for score in scores]
+
+ from PIL import ImageDraw, ImageFont
+
+ def captioned_strip(images, caption):
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images)*w, h + 48))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i*w, 48))
+     draw = ImageDraw.Draw(img)
+     font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
+     draw.text((20, 3), caption, (255,255,255), font=font)
+     return img
+
+ def log_to_wandb(prompts):
+     strips = []
+     for prompt in prompts:
+         print(f"Generating candidates for: {prompt}")
+         images = hallucinate(prompt, num_images=32)
+         selected = clip_top_k(prompt, images, k=8)
+         strip = captioned_strip(selected, prompt)
+         strips.append(wandb.Image(strip))
+     wandb.log({"images": strips})
+
+ ## Artifact loop
+
+ import wandb
+ import os
+ os.environ["WANDB_SILENT"] = "true"
+ os.environ["WANDB_CONSOLE"] = "off"
+
+ id = wandb.util.generate_id()
+ print(f"Logging images to wandb run id: {id}")
+
+ run = wandb.init(id=id,
+     entity='wandb',
+     project="hf-flax-dalle-mini",
+     job_type="predictions",
+     resume="allow"
+ )
+
+ artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3iwhu4w6:v0', type='bart_model')
+ producer_run = artifact.logged_by()
+ logged_artifacts = producer_run.logged_artifacts()
+
+ for artifact in logged_artifacts:
+     print(f"Generating predictions with version {artifact.version}")
+     artifact_dir = artifact.download()
+
+     # create our model
+     model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+     model.config.force_bos_token_to_be_generated = False
+     model.config.forced_bos_token_id = None
+     model.config.forced_eos_token_id = None
+
+     bart_params = replicate(model.params)
+     vqgan_params = replicate(vqgan.params)
+
+     prompts = [
+         "white snow covered mountain under blue sky during daytime",
+         "aerial view of beach during daytime",
+         "aerial view of beach at night",
+         "an armchair in the shape of an avocado",
+         "young woman riding her bike trough a forest",
+         "rice fields by the mediterranean coast",
+         "white houses on the hill of a greek coastline",
+         "illustration of a shark with a baby shark",
+     ]
+
+     log_to_wandb(prompts)
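The artifact loop above is what makes this script a sweep: it asks wandb which run produced the checkpoint artifact and then iterates over every artifact version that run logged, regenerating predictions for each one. Below is a minimal sketch of just that loop, using the same artifact name and wandb's public API, without loading any model weights.

```python
# Sketch: walk every logged version of the checkpoint artifact (no model loading).
import wandb

run = wandb.init(entity="wandb", project="hf-flax-dalle-mini", job_type="predictions")
artifact = run.use_artifact("wandb/hf-flax-dalle-mini/model-3iwhu4w6:v0", type="bart_model")
producer_run = artifact.logged_by()              # the training run that logged the checkpoints
for version in producer_run.logged_artifacts():  # every artifact version from that run
    print(f"checkpoint {version.name} ({version.version})")
    local_dir = version.download()               # weights for this version, ready for from_pretrained
run.finish()
```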
demo/wandb-examples.py CHANGED
@@ -83,7 +83,7 @@ run = wandb.init(id=id,
      job_type="predictions",
      resume="allow"
  )
- artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3iwhu4w6:latest', type='bart_model')
+ artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest', type='bart_model')
  artifact_dir = artifact.download()

  # create our model