boris committed
Commit 741bf32
1 Parent(s): ae754a3

style: reformat

tools/inference/inference_pipeline.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
tools/inference/log_inference_samples.ipynb CHANGED
@@ -31,11 +31,14 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "run_ids = ['63otg87g']\n",
- "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
- "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
- "latest_only = True # log only latest or all versions\n",
- "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
+ "run_ids = [\"63otg87g\"]\n",
+ "ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
+ " \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
+ ")\n",
+ "latest_only = True # log only latest or all versions\n",
+ "suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
  "add_clip_32 = False"
  ]
 },
@@ -63,8 +66,8 @@
  "num_images = 128\n",
  "top_k = 8\n",
  "text_normalizer = TextNormalizer()\n",
- "padding_item = 'NONE'\n",
- "seed = random.randint(0, 2**32-1)\n",
+ "padding_item = \"NONE\"\n",
+ "seed = random.randint(0, 2 ** 32 - 1)\n",
  "key = jax.random.PRNGKey(seed)\n",
  "api = wandb.Api()"
  ]
@@ -100,12 +103,15 @@
  "def p_decode(indices, params):\n",
  " return vqgan.decode_code(indices, params=params)\n",
  "\n",
+ "\n",
  "@partial(jax.pmap, axis_name=\"batch\")\n",
  "def p_clip16(inputs, params):\n",
  " logits = clip16(params=params, **inputs).logits_per_image\n",
  " return logits\n",
  "\n",
+ "\n",
  "if add_clip_32:\n",
+ "\n",
  " @partial(jax.pmap, axis_name=\"batch\")\n",
  " def p_clip32(inputs, params):\n",
  " logits = clip32(params=params, **inputs).logits_per_image\n",
@@ -119,13 +125,13 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "with open('samples.txt', encoding='utf8') as f:\n",
+ "with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
  " samples = [l.strip() for l in f.readlines()]\n",
  " # make list multiple of batch_size by adding elements\n",
  " samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
  " samples.extend(samples_to_add)\n",
  " # reshape\n",
- " samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
+ " samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
  ]
 },
 {
@@ -138,9 +144,17 @@
  "def get_artifact_versions(run_id, latest_only=False):\n",
  " try:\n",
  " if latest_only:\n",
- " return [api.artifact(type='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}:latest')]\n",
+ " return [\n",
+ " api.artifact(\n",
+ " type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
+ " )\n",
+ " ]\n",
  " else:\n",
- " return api.artifact_versions(type_name='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}', per_page=10000)\n",
+ " return api.artifact_versions(\n",
+ " type_name=\"bart_model\",\n",
+ " name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
+ " per_page=10000,\n",
+ " )\n",
  " except:\n",
  " return []"
  ]
@@ -153,7 +167,7 @@
  "outputs": [],
  "source": [
  "def get_training_config(run_id):\n",
- " training_run = api.run(f'{ENTITY}/{PROJECT}/{run_id}')\n",
+ " training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
  " config = training_run.config\n",
  " return config"
  ]
@@ -168,8 +182,8 @@
  "# retrieve inference run details\n",
  "def get_last_inference_version(run_id):\n",
  " try:\n",
- " inference_run = api.run(f'dalle-mini/dalle-mini/{run_id}-clip16{suffix}')\n",
- " return inference_run.summary.get('version', None)\n",
+ " inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
+ " return inference_run.summary.get(\"version\", None)\n",
  " except:\n",
  " return None"
  ]
@@ -183,7 +197,6 @@
  "source": [
  "# compile functions - needed only once per run\n",
  "def pmap_model_function(model):\n",
- " \n",
  " @partial(jax.pmap, axis_name=\"batch\")\n",
  " def _generate(tokenized_prompt, key, params):\n",
  " return model.generate(\n",
@@ -195,7 +208,7 @@
  " top_k=gen_top_k,\n",
  " top_p=gen_top_p\n",
  " )\n",
- " \n",
+ "\n",
  " return _generate"
  ]
 },
@@ -222,13 +235,21 @@
  "training_config = get_training_config(run_id)\n",
  "run = None\n",
  "p_generate = None\n",
- "model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
+ "model_files = [\n",
+ " \"config.json\",\n",
+ " \"flax_model.msgpack\",\n",
+ " \"merges.txt\",\n",
+ " \"special_tokens_map.json\",\n",
+ " \"tokenizer.json\",\n",
+ " \"tokenizer_config.json\",\n",
+ " \"vocab.json\",\n",
+ "]\n",
  "for artifact in artifact_versions:\n",
- " print(f'Processing artifact: {artifact.name}')\n",
+ " print(f\"Processing artifact: {artifact.name}\")\n",
  " version = int(artifact.version[1:])\n",
  " results16, results32 = [], []\n",
- " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)]\n",
- " \n",
+ " columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
+ "\n",
  " if latest_only:\n",
  " assert last_inference_version is None or version > last_inference_version\n",
  " else:\n",
@@ -236,14 +257,23 @@
  " # we should start from v0\n",
  " assert version == 0\n",
  " elif version <= last_inference_version:\n",
- " print(f'v{version} has already been logged (versions logged up to v{last_inference_version}')\n",
+ " print(\n",
+ " f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
+ " )\n",
  " else:\n",
  " # check we are logging the correct version\n",
  " assert version == last_inference_version + 1\n",
  "\n",
  " # start/resume corresponding run\n",
  " if run is None:\n",
- " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip16{suffix}', resume='allow')\n",
+ " run = wandb.init(\n",
+ " job_type=\"inference\",\n",
+ " entity=\"dalle-mini\",\n",
+ " project=\"dalle-mini\",\n",
+ " config=training_config,\n",
+ " id=f\"{run_id}-clip16{suffix}\",\n",
+ " resume=\"allow\",\n",
+ " )\n",
  "\n",
  " # work in temporary directory\n",
  " with tempfile.TemporaryDirectory() as tmp:\n",
@@ -264,64 +294,109 @@
  "\n",
  " # process one batch of captions\n",
  " for batch in tqdm(samples):\n",
- " processed_prompts = [text_normalizer(x) for x in batch] if model.config.normalize_text else list(batch)\n",
+ " processed_prompts = (\n",
+ " [text_normalizer(x) for x in batch]\n",
+ " if model.config.normalize_text\n",
+ " else list(batch)\n",
+ " )\n",
  "\n",
  " # repeat the prompts to distribute over each device and tokenize\n",
  " processed_prompts = processed_prompts * jax.device_count()\n",
- " tokenized_prompt = tokenizer(processed_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
+ " tokenized_prompt = tokenizer(\n",
+ " processed_prompts,\n",
+ " return_tensors=\"jax\",\n",
+ " padding=\"max_length\",\n",
+ " truncation=True,\n",
+ " max_length=128,\n",
+ " ).data\n",
  " tokenized_prompt = shard(tokenized_prompt)\n",
  "\n",
  " # generate images\n",
  " images = []\n",
- " pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=True)\n",
+ " pbar = tqdm(\n",
+ " range(num_images // jax.device_count()),\n",
+ " desc=\"Generating Images\",\n",
+ " leave=True,\n",
+ " )\n",
  " for i in pbar:\n",
  " key, subkey = jax.random.split(key)\n",
- " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
+ " encoded_images = p_generate(\n",
+ " tokenized_prompt, shard_prng_key(subkey), model_params\n",
+ " )\n",
  " encoded_images = encoded_images.sequences[..., 1:]\n",
  " decoded_images = p_decode(encoded_images, vqgan_params)\n",
- " decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
+ " (-1, 256, 256, 3)\n",
+ " )\n",
  " for img in decoded_images:\n",
- " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
+ " images.append(\n",
+ " Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
+ " )\n",
  "\n",
- " def add_clip_results(results, processor, p_clip, clip_params): \n",
- " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+ " def add_clip_results(results, processor, p_clip, clip_params):\n",
+ " clip_inputs = processor(\n",
+ " text=batch,\n",
+ " images=images,\n",
+ " return_tensors=\"np\",\n",
+ " padding=\"max_length\",\n",
+ " max_length=77,\n",
+ " truncation=True,\n",
+ " ).data\n",
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
- " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
- " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
+ " images_per_prompt_indices = np.asarray(\n",
+ " range(0, len(images), batch_size)\n",
+ " )\n",
+ " clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
+ " list(\n",
+ " clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
+ " for i in range(batch_size)\n",
+ " )\n",
+ " )\n",
  " clip_inputs = shard(clip_inputs)\n",
  " logits = p_clip(clip_inputs, clip_params)\n",
  " logits = logits.reshape(-1, num_images)\n",
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
  " logits = jax.device_get(logits)\n",
  " # add to results table\n",
- " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
- " if sample == padding_item: continue\n",
+ " for i, (idx, scores, sample) in enumerate(\n",
+ " zip(top_scores, logits, batch)\n",
+ " ):\n",
+ " if sample == padding_item:\n",
+ " continue\n",
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
- " top_images = [wandb.Image(cur_images[x], caption=f'Score: {scores[x]:.2f}') for x in idx]\n",
+ " top_images = [\n",
+ " wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
+ " for x in idx\n",
+ " ]\n",
  " results.append([sample] + top_images)\n",
- " \n",
+ "\n",
  " # get clip scores\n",
- " pbar.set_description('Calculating CLIP 16 scores')\n",
+ " pbar.set_description(\"Calculating CLIP 16 scores\")\n",
  " add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
- " \n",
+ "\n",
  " # get clip 32 scores\n",
  " if add_clip_32:\n",
- " pbar.set_description('Calculating CLIP 32 scores')\n",
+ " pbar.set_description(\"Calculating CLIP 32 scores\")\n",
  " add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
  "\n",
  " pbar.close()\n",
  "\n",
- " \n",
- "\n",
  " # log results\n",
  " table = wandb.Table(columns=columns, data=results16)\n",
- " run.log({'Samples': table, 'version': version})\n",
+ " run.log({\"Samples\": table, \"version\": version})\n",
  " wandb.finish()\n",
- " \n",
- " if add_clip_32: \n",
- " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip32{suffix}', resume='allow')\n",
+ "\n",
+ " if add_clip_32:\n",
+ " run = wandb.init(\n",
+ " job_type=\"inference\",\n",
+ " entity=\"dalle-mini\",\n",
+ " project=\"dalle-mini\",\n",
+ " config=training_config,\n",
+ " id=f\"{run_id}-clip32{suffix}\",\n",
+ " resume=\"allow\",\n",
+ " )\n",
  " table = wandb.Table(columns=columns, data=results32)\n",
- " run.log({'Samples': table, 'version': version})\n",
+ " run.log({\"Samples\": table, \"version\": version})\n",
  " wandb.finish()\n",
  " run = None # ensure we don't log on this run"
  ]
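
Note on the batching idiom in the samples.txt cell above: the notebook pads the prompt list to a multiple of batch_size with a sentinel item, then splits it into equal batches so every pmap'd step sees a full batch; padded rows are skipped later when results are logged. A minimal, self-contained sketch of the same pattern, using hypothetical names (prompts, BATCH_SIZE, PADDING_ITEM) rather than the notebook's variables:

    BATCH_SIZE = 8
    PADDING_ITEM = "NONE"  # sentinel; padded rows are skipped when logging

    prompts = ["a painting of a capybara", "an armchair shaped like an avocado"]

    # -len(prompts) % BATCH_SIZE is the number of items needed to reach the
    # next multiple of BATCH_SIZE (0 if the list is already a multiple)
    prompts += [PADDING_ITEM] * (-len(prompts) % BATCH_SIZE)

    # split the flat list into batches of exactly BATCH_SIZE
    batches = [prompts[i : i + BATCH_SIZE] for i in range(0, len(prompts), BATCH_SIZE)]

    assert all(len(b) == BATCH_SIZE for b in batches)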