Pedro Cuenca commited on
Commit
830d7a2
1 Parent(s): 699e1d9

Script to log predictions grid to wandb.

Browse files

Former-commit-id: 70622a0fe7da6be960735403a3bb1397d622c97a

Files changed (1) hide show
  1. demo/wandb-examples.py +214 -0
demo/wandb-examples.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import random
5
+
6
+ import jax
7
+ import flax.linen as nn
8
+ from flax.training.common_utils import shard
9
+ from flax.jax_utils import replicate, unreplicate
10
+
11
+ from transformers.models.bart.modeling_flax_bart import *
12
+ from transformers import BartTokenizer, FlaxBartForConditionalGeneration
13
+
14
+ import io
15
+
16
+ import requests
17
+ from PIL import Image
18
+ import numpy as np
19
+ import matplotlib.pyplot as plt
20
+
21
+ import torch
22
+ import torchvision.transforms as T
23
+ import torchvision.transforms.functional as TF
24
+ from torchvision.transforms import InterpolationMode
25
+
26
+ from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
27
+
28
+ # TODO: set those args in a config file
29
+ OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
30
+ OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
31
+ BOS_TOKEN_ID = 16384
32
+ BASE_MODEL = 'facebook/bart-large-cnn'
33
+
34
+ class CustomFlaxBartModule(FlaxBartModule):
35
+ def setup(self):
36
+ # we keep shared to easily load pre-trained weights
37
+ self.shared = nn.Embed(
38
+ self.config.vocab_size,
39
+ self.config.d_model,
40
+ embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
41
+ dtype=self.dtype,
42
+ )
43
+ # a separate embedding is used for the decoder
44
+ self.decoder_embed = nn.Embed(
45
+ OUTPUT_VOCAB_SIZE,
46
+ self.config.d_model,
47
+ embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
48
+ dtype=self.dtype,
49
+ )
50
+ self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
51
+
52
+ # the decoder has a different config
53
+ decoder_config = BartConfig(self.config.to_dict())
54
+ decoder_config.max_position_embeddings = OUTPUT_LENGTH
55
+ decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
56
+ self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
57
+
58
+ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
59
+ def setup(self):
60
+ self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
61
+ self.lm_head = nn.Dense(
62
+ OUTPUT_VOCAB_SIZE,
63
+ use_bias=False,
64
+ dtype=self.dtype,
65
+ kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
66
+ )
67
+ self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
68
+
69
+ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
70
+ module_class = CustomFlaxBartForConditionalGenerationModule
71
+
72
+
73
+ import wandb
74
+ import os
75
+ os.environ["WANDB_SILENT"] = "true"
76
+ os.environ["WANDB_CONSOLE"] = "off"
77
+
78
+ id = '1i5e8rlj' #'1coradc5'
79
+ #run = wandb.init(id=id, project="dalle-mini-demo", resume="allow")
80
+ run = wandb.init(id=id,
81
+ entity='wandb',
82
+ project="hf-flax-dalle-mini",
83
+ job_type="predictions",
84
+ resume="allow"
85
+ )
86
+ artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3iwhu4w6:latest', type='bart_model')
87
+ artifact_dir = artifact.download()
88
+
89
+ # create our model
90
+ tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
91
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
92
+ model.config.force_bos_token_to_be_generated = False
93
+ model.config.forced_bos_token_id = None
94
+ model.config.forced_eos_token_id = None
95
+
96
+ vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
97
+
98
+ def custom_to_pil(x):
99
+ x = np.clip(x, 0., 1.)
100
+ x = (255*x).astype(np.uint8)
101
+ x = Image.fromarray(x)
102
+ if not x.mode == "RGB":
103
+ x = x.convert("RGB")
104
+ return x
105
+
106
+ def generate(input, rng, params):
107
+ return model.generate(
108
+ **input,
109
+ max_length=257,
110
+ num_beams=1,
111
+ do_sample=True,
112
+ prng_key=rng,
113
+ eos_token_id=50000,
114
+ pad_token_id=50000,
115
+ params=params,
116
+ )
117
+
118
+ def get_images(indices, params):
119
+ return vqgan.decode_code(indices, params=params)
120
+
121
+ def plot_images(images):
122
+ fig = plt.figure(figsize=(40, 20))
123
+ columns = 4
124
+ rows = 2
125
+ plt.subplots_adjust(hspace=0, wspace=0)
126
+
127
+ for i in range(1, columns*rows +1):
128
+ fig.add_subplot(rows, columns, i)
129
+ plt.imshow(images[i-1])
130
+ plt.gca().axes.get_yaxis().set_visible(False)
131
+ plt.show()
132
+
133
+ def stack_reconstructions(images):
134
+ w, h = images[0].size[0], images[0].size[1]
135
+ img = Image.new("RGB", (len(images)*w, h))
136
+ for i, img_ in enumerate(images):
137
+ img.paste(img_, (i*w,0))
138
+ return img
139
+
140
+ p_generate = jax.pmap(generate, "batch")
141
+ p_get_images = jax.pmap(get_images, "batch")
142
+
143
+ bart_params = replicate(model.params)
144
+ vqgan_params = replicate(vqgan.params)
145
+
146
+ # ## CLIP Scoring
147
+ from transformers import CLIPProcessor, FlaxCLIPModel
148
+
149
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
150
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
151
+
152
+ def hallucinate(prompt, num_images=64):
153
+ prompt = [prompt] * jax.device_count()
154
+ inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
155
+ inputs = shard(inputs)
156
+
157
+ all_images = []
158
+ for i in range(num_images // jax.device_count()):
159
+ key = random.randint(0, 1e7)
160
+ rng = jax.random.PRNGKey(key)
161
+ rngs = jax.random.split(rng, jax.local_device_count())
162
+ indices = p_generate(inputs, rngs, bart_params).sequences
163
+ indices = indices[:, :, 1:]
164
+
165
+ images = p_get_images(indices, vqgan_params)
166
+ images = np.squeeze(np.asarray(images), 1)
167
+ for image in images:
168
+ all_images.append(custom_to_pil(image))
169
+ return all_images
170
+
171
+ def clip_top_k(prompt, images, k=8):
172
+ inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
173
+ outputs = clip(**inputs)
174
+ logits = outputs.logits_per_text
175
+ scores = np.array(logits[0]).argsort()[-k:][::-1]
176
+ return [images[score] for score in scores]
177
+
178
+
179
+ # ## Log to wandb
180
+
181
+ from PIL import ImageDraw, ImageFont
182
+
183
+ def captioned_strip(images, caption):
184
+ w, h = images[0].size[0], images[0].size[1]
185
+ img = Image.new("RGB", (len(images)*w, h + 48))
186
+ for i, img_ in enumerate(images):
187
+ img.paste(img_, (i*w, 48))
188
+ draw = ImageDraw.Draw(img)
189
+ font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
190
+ draw.text((20, 3), caption, (255,255,255), font=font)
191
+ return img
192
+
193
+ def log_to_wandb(prompts):
194
+ strips = []
195
+ for prompt in prompts:
196
+ print(f"Generating candidates for: {prompt}")
197
+ images = hallucinate(prompt, num_images=32)
198
+ selected = clip_top_k(prompt, images, k=8)
199
+ strip = captioned_strip(selected, prompt)
200
+ strips.append(wandb.Image(strip))
201
+ wandb.log({"images": strips})
202
+
203
+ prompts = prompts = [
204
+ "white snow covered mountain under blue sky during daytime",
205
+ "aerial view of beach during daytime",
206
+ "aerial view of beach at night",
207
+ "an armchair in the shape of an avocado",
208
+ "young woman riding her bike trough a forest",
209
+ "rice fields by the mediterranean coast",
210
+ "white houses on the hill of a greek coastline",
211
+ "illustration of a shark with a baby shark",
212
+ ]
213
+
214
+ log_to_wandb(prompts)