Pedro Cuenca committed on
Commit
d7be08c
·
1 Parent(s): 0e8338d

Script that predicts using all saved versions of a model.

Browse files

Former-commit-id: 8425de3fcab74d1bcc7aeb04e9d6b36a098acc70

Files changed (1) hide show
  1. demo/model-sweep.py +220 -0
demo/model-sweep.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import random
5
+
6
+ import jax
7
+ import flax.linen as nn
8
+ from flax.training.common_utils import shard
9
+ from flax.jax_utils import replicate, unreplicate
10
+
11
+ from transformers.models.bart.modeling_flax_bart import *
12
+ from transformers import BartTokenizer, FlaxBartForConditionalGeneration
13
+
14
+ import io
15
+
16
+ import requests
17
+ from PIL import Image
18
+ import numpy as np
19
+ import matplotlib.pyplot as plt
20
+
21
+ import torch
22
+ import torchvision.transforms as T
23
+ import torchvision.transforms.functional as TF
24
+ from torchvision.transforms import InterpolationMode
25
+
26
+ from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
27
+
28
# TODO: set those args in a config file
# The bos id sits right after the encoded image token space, so the output
# vocabulary is one entry larger than that space.
BOS_TOKEN_ID = 16384
OUTPUT_VOCAB_SIZE = BOS_TOKEN_ID + 1  # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
BASE_MODEL = "facebook/bart-large-cnn"
WANDB_MODEL = "3iwhu4w6"
34
+
35
class CustomFlaxBartModule(FlaxBartModule):
    """BART module whose decoder uses its own embedding and vocabulary.

    The encoder keeps the ``shared`` text embedding sized by
    ``config.vocab_size``. The decoder gets a separate ``decoder_embed`` with
    ``OUTPUT_VOCAB_SIZE`` entries and a config limited to ``OUTPUT_LENGTH``
    positions.
    """

    def setup(self):
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            OUTPUT_VOCAB_SIZE,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # the decoder has a different config
        # Build the copy with from_dict: calling BartConfig(<dict>) would bind
        # the whole dict to the first positional parameter (vocab_size) and
        # leave every other setting at its library default.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = OUTPUT_LENGTH
        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
58
+
59
class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation head on top of ``CustomFlaxBartModule``.

    Both the LM head and the final logits bias are sized by
    ``OUTPUT_VOCAB_SIZE`` rather than by the text vocabulary.
    """

    def setup(self):
        n_out = OUTPUT_VOCAB_SIZE
        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        head_init = jax.nn.initializers.normal(self.config.init_std, self.dtype)
        self.lm_head = nn.Dense(n_out, use_bias=False, dtype=self.dtype, kernel_init=head_init)
        bias_shape = (1, n_out)
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, bias_shape)
69
+
70
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """Model wrapper that swaps in the custom generation module defined above."""

    module_class = CustomFlaxBartForConditionalGenerationModule
72
+
73
# Text tokenizer loaded from the checkpoint named by BASE_MODEL.
tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)

# Pretrained VQGAN; its decode_code method and params are used further down
# to turn generated code indices into images.
_VQGAN_REPO = "flax-community/vqgan_f16_16384"
vqgan = VQModel.from_pretrained(_VQGAN_REPO)
75
+
76
def custom_to_pil(x):
    """Turn a float image array into an RGB PIL image.

    Values are clipped to [0, 1], scaled by 255 and cast to uint8 before the
    PIL image is built; the result is converted to RGB if it is in any other
    mode.
    """
    pixels = (255 * np.clip(x, 0., 1.)).astype(np.uint8)
    picture = Image.fromarray(pixels)
    if picture.mode == "RGB":
        return picture
    return picture.convert("RGB")
83
+
84
def generate(input, rng, params):
    """Sample token sequences from the module-level ``model``.

    Args:
        input: mapping of keyword arguments forwarded to ``model.generate``
            (the caller passes the sharded tokenizer output).
        rng: PRNG key used for sampling.
        params: model parameters to generate with.

    Returns:
        Whatever ``model.generate`` returns; the caller reads ``.sequences``.
    """
    return model.generate(
        **input,
        # Keep the generation length tied to the decoder's configured length
        # instead of repeating the literal 257.
        max_length=OUTPUT_LENGTH,
        num_beams=1,
        do_sample=True,
        prng_key=rng,
        # NOTE(review): 50000 is larger than OUTPUT_VOCAB_SIZE, so presumably
        # it can never be sampled and every sequence runs to max_length -
        # confirm this is intended.
        eos_token_id=50000,
        pad_token_id=50000,
        params=params,
    )
95
+
96
def get_images(indices, params):
    """Decode code indices with the module-level VQGAN using ``params``."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded
98
+
99
def plot_images(images):
    """Display the first eight entries of ``images`` in a 2 x 4 grid."""
    rows, columns = 2, 4
    fig = plt.figure(figsize=(40, 20))
    plt.subplots_adjust(hspace=0, wspace=0)

    for idx in range(columns * rows):
        # Subplot positions are 1-based, image indices are 0-based.
        fig.add_subplot(rows, columns, idx + 1)
        plt.imshow(images[idx])
    # NOTE(review): this hides the y axis of the last subplot only; confirm
    # it was not meant to run once per subplot inside the loop.
    plt.gca().axes.get_yaxis().set_visible(False)
    plt.show()
110
+
111
def stack_reconstructions(images):
    """Paste the images side by side into one new RGB image.

    The tile size is taken from the first image; each image is pasted at a
    horizontal offset of one tile width per position.
    """
    tile_w, tile_h = images[0].size[0], images[0].size[1]
    canvas = Image.new("RGB", (len(images) * tile_w, tile_h))
    x_offset = 0
    for tile in images:
        canvas.paste(tile, (x_offset, 0))
        x_offset += tile_w
    return canvas
117
+
118
# Device-parallel versions of the two helpers; "batch" names the mapped axis.
p_generate = jax.pmap(generate, axis_name="batch")
p_get_images = jax.pmap(get_images, axis_name="batch")
120
+
121
+ # ## CLIP Scoring
122
+ from transformers import CLIPProcessor, FlaxCLIPModel
123
+
124
# The CLIP model and its processor are loaded from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip = FlaxCLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
126
+
127
def hallucinate(prompt, num_images=64):
    """Generate candidate images for one text prompt.

    The prompt is repeated once per local device, tokenized and sharded.
    Each round samples on every local device and decodes the result with the
    VQGAN. Reads the module-level ``bart_params`` and ``vqgan_params``.

    Args:
        prompt: text prompt to generate images for.
        num_images: requested number of images. The function runs
            ``num_images // local_device_count`` rounds, so the result can be
            shorter than requested (and is empty when ``num_images`` is
            smaller than the device count).

    Returns:
        List of PIL images produced by ``custom_to_pil``.
    """
    # shard() and pmap both work on the local devices, so use the local count
    # everywhere rather than mixing it with the global device count.
    n_devices = jax.local_device_count()
    prompt = [prompt] * n_devices
    inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
    inputs = shard(inputs)

    all_images = []
    for _ in range(num_images // n_devices):
        # randint needs integer bounds: a float such as 1e7 raises TypeError
        # on Python 3.12+.
        seed = random.randint(0, 10**7)
        rngs = jax.random.split(jax.random.PRNGKey(seed), n_devices)
        indices = p_generate(inputs, rngs, bart_params).sequences
        # Drop the first token of every sequence (presumably the bos token).
        indices = indices[:, :, 1:]

        images = p_get_images(indices, vqgan_params)
        # Remove the singleton axis 1 (one prompt per device).
        images = np.squeeze(np.asarray(images), 1)
        all_images.extend(custom_to_pil(image) for image in images)
    return all_images
145
+
146
def clip_top_k(prompt, images, k=8):
    """Return the ``k`` images with the highest CLIP text logits, best first.

    Args:
        prompt: text the images are compared against.
        images: candidate images, indexable by position.
        k: how many images to keep.
    """
    clip_inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
    text_logits = clip(**clip_inputs).logits_per_text
    # argsort is ascending: take the last k positions and reverse them.
    ranking = np.array(text_logits[0]).argsort()
    best_first = ranking[-k:][::-1]
    return [images[position] for position in best_first]
152
+
153
+ from PIL import ImageDraw, ImageFont
154
+
155
def captioned_strip(images, caption):
    """Lay the images out in one row with the caption drawn above them.

    Args:
        images: PIL images; the tile size is taken from the first one.
        caption: text drawn in white in the 48-pixel band at the top.

    Returns:
        A new RGB image, ``len(images)`` tiles wide and 48 pixels taller than
        one tile.
    """
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images)*w, h + 48))
    for i, img_ in enumerate(images):
        img.paste(img_, (i*w, 48))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
    except OSError:
        # The hard-coded font is not installed everywhere; fall back to
        # Pillow's built-in font instead of aborting the whole run.
        font = ImageFont.load_default()
    draw.text((20, 3), caption, (255,255,255), font=font)
    return img
164
+
165
def log_to_wandb(prompts):
    """Generate one captioned image strip per prompt and log them to wandb.

    For every prompt, 32 candidates are generated, the 8 best according to
    CLIP are kept and assembled into a captioned strip. All strips are logged
    in a single ``wandb.log`` call under the key "images".
    """

    def _strip_for(prompt):
        # Build the wandb image for a single prompt.
        print(f"Generating candidates for: {prompt}")
        candidates = hallucinate(prompt, num_images=32)
        best = clip_top_k(prompt, candidates, k=8)
        return wandb.Image(captioned_strip(best, prompt))

    strips = [_strip_for(prompt) for prompt in prompts]
    wandb.log({"images": strips})
174
+
175
## Artifact loop
# Runs the prediction pipeline once for every model version logged by the
# training run that produced the reference artifact.

import wandb
import os
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"

# Named run_id rather than id so the builtin id() is not shadowed.
run_id = wandb.util.generate_id()
print(f"Logging images to wandb run id: {run_id}")

run = wandb.init(id=run_id,
                 entity='wandb',
                 project="hf-flax-dalle-mini",
                 job_type="predictions",
                 resume="allow"
                 )

# WANDB_MODEL is the run id embedded in the artifact name.
artifact = run.use_artifact(f'wandb/hf-flax-dalle-mini/model-{WANDB_MODEL}:v0', type='bart_model')
producer_run = artifact.logged_by()
logged_artifacts = producer_run.logged_artifacts()

# The prompts and the VQGAN parameters are identical for every model version,
# so they are prepared once, before the loop.
prompts = [
    "white snow covered mountain under blue sky during daytime",
    "aerial view of beach during daytime",
    "aerial view of beach at night",
    "an armchair in the shape of an avocado",
    "young woman riding her bike trough a forest",
    "rice fields by the mediterranean coast",
    "white houses on the hill of a greek coastline",
    "illustration of a shark with a baby shark",
]

vqgan_params = replicate(vqgan.params)

for artifact in logged_artifacts:
    print(f"Generating predictions with version {artifact.version}")
    artifact_dir = artifact.download()

    # create our model
    # `model` and `bart_params` are module-level names read by generate()
    # and hallucinate().
    model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
    model.config.force_bos_token_to_be_generated = False
    model.config.forced_bos_token_id = None
    model.config.forced_eos_token_id = None

    bart_params = replicate(model.params)

    log_to_wandb(prompts)