boris commited on
Commit
cb127c4
1 Parent(s): 4e4a30f

feat(log_inference_samples): cleanup

Browse files
tools/inference/log_inference_samples.ipynb CHANGED
@@ -100,11 +100,12 @@
100
  "outputs": [],
101
  "source": [
102
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
103
- "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
104
- "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
105
- "clip_params = replicate(clip.params)\n",
106
  "vqgan_params = replicate(vqgan.params)\n",
107
  "\n",
 
 
 
 
108
  "if add_clip_32:\n",
109
  " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
110
  " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
@@ -123,8 +124,8 @@
123
  " return vqgan.decode_code(indices, params=params)\n",
124
  "\n",
125
  "@partial(jax.pmap, axis_name=\"batch\")\n",
126
- "def p_clip(inputs, params):\n",
127
- " logits = clip(params=params, **inputs).logits_per_image\n",
128
  " return logits\n",
129
  "\n",
130
  "if add_clip_32:\n",
@@ -229,7 +230,7 @@
229
  "outputs": [],
230
  "source": [
231
  "run_id = run_ids[0]\n",
232
- "# TODO: turn everything into a class"
233
  ]
234
  },
235
  {
@@ -248,10 +249,8 @@
248
  "for artifact in artifact_versions:\n",
249
  " print(f'Processing artifact: {artifact.name}')\n",
250
  " version = int(artifact.version[1:])\n",
251
- " results = []\n",
252
- " if add_clip_32:\n",
253
- " results32 = []\n",
254
- " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
255
  " \n",
256
  " if latest_only:\n",
257
  " assert last_inference_version is None or version > last_inference_version\n",
@@ -307,34 +306,13 @@
307
  " for img in decoded_images:\n",
308
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
309
  "\n",
310
- " # get clip scores\n",
311
- " pbar.set_description('Calculating CLIP scores')\n",
312
- " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
313
- " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
314
- " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
315
- " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
316
- " clip_inputs = shard(clip_inputs)\n",
317
- " logits = p_clip(clip_inputs, clip_params)\n",
318
- " logits = logits.reshape(-1, num_images)\n",
319
- " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
320
- " logits = jax.device_get(logits)\n",
321
- " # add to results table\n",
322
- " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
323
- " if sample == padding_item: continue\n",
324
- " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
325
- " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
326
- " top_scores = [scores[x] for x in idx]\n",
327
- " results.append([sample] + top_images + top_scores)\n",
328
- " \n",
329
- " # get clip 32 scores - TODO: this should be refactored as it is same code as above\n",
330
- " if add_clip_32:\n",
331
- " print('Calculating CLIP 32 scores')\n",
332
- " clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
333
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
334
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
335
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
336
  " clip_inputs = shard(clip_inputs)\n",
337
- " logits = p_clip32(clip_inputs, clip32_params)\n",
338
  " logits = logits.reshape(-1, num_images)\n",
339
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
340
  " logits = jax.device_get(logits)\n",
@@ -342,13 +320,24 @@
342
  " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
343
  " if sample == padding_item: continue\n",
344
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
345
- " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
346
- " top_scores = [scores[x] for x in idx]\n",
347
- " results32.append([sample] + top_images + top_scores)\n",
 
 
 
 
 
 
 
 
 
348
  " pbar.close()\n",
349
  "\n",
 
 
350
  " # log results\n",
351
- " table = wandb.Table(columns=columns, data=results)\n",
352
  " run.log({'Samples': table, 'version': version})\n",
353
  " wandb.finish()\n",
354
  " \n",
@@ -359,19 +348,6 @@
359
  " wandb.finish()\n",
360
  " run = None # ensure we don't log on this run"
361
  ]
362
- },
363
- {
364
- "cell_type": "code",
365
- "execution_count": null,
366
- "id": "4e4c7d0c-2848-4f88-b967-82fd571534f1",
367
- "metadata": {},
368
- "outputs": [],
369
- "source": [
370
- "# TODO: not implemented\n",
371
- "def log_runs(runs):\n",
372
- " for run in tqdm(runs):\n",
373
- " log_run(run)"
374
- ]
375
  }
376
  ],
377
  "metadata": {
 
100
  "outputs": [],
101
  "source": [
102
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
 
 
 
103
  "vqgan_params = replicate(vqgan.params)\n",
104
  "\n",
105
+ "clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
106
+ "processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
107
+ "clip16_params = replicate(clip16.params)\n",
108
+ "\n",
109
  "if add_clip_32:\n",
110
  " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
111
  " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
 
124
  " return vqgan.decode_code(indices, params=params)\n",
125
  "\n",
126
  "@partial(jax.pmap, axis_name=\"batch\")\n",
127
+ "def p_clip16(inputs, params):\n",
128
+ " logits = clip16(params=params, **inputs).logits_per_image\n",
129
  " return logits\n",
130
  "\n",
131
  "if add_clip_32:\n",
 
230
  "outputs": [],
231
  "source": [
232
  "run_id = run_ids[0]\n",
233
+ "# TODO: turn everything into a class or loop over runs"
234
  ]
235
  },
236
  {
 
249
  "for artifact in artifact_versions:\n",
250
  " print(f'Processing artifact: {artifact.name}')\n",
251
  " version = int(artifact.version[1:])\n",
252
+ " results16, results32 = [], []\n",
253
+ " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)]\n",
 
 
254
  " \n",
255
  " if latest_only:\n",
256
  " assert last_inference_version is None or version > last_inference_version\n",
 
306
  " for img in decoded_images:\n",
307
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
308
  "\n",
309
+ " def add_clip_results(results, processor, p_clip, clip_params): \n",
310
+ " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
312
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
313
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
314
  " clip_inputs = shard(clip_inputs)\n",
315
+ " logits = p_clip(clip_inputs, clip32_params)\n",
316
  " logits = logits.reshape(-1, num_images)\n",
317
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
318
  " logits = jax.device_get(logits)\n",
 
320
  " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
321
  " if sample == padding_item: continue\n",
322
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
323
+ " top_images = [wandb.Image(cur_images[x], caption=f'Score: {scores[x]:.2f}') for x in idx]\n",
324
+ " results.append([sample] + top_images)\n",
325
+ " \n",
326
+ " # get clip scores\n",
327
+ " pbar.set_description('Calculating CLIP 16 scores')\n",
328
+ " add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
329
+ " \n",
330
+ " # get clip 32 scores\n",
331
+ " if add_clip_32:\n",
332
+ " pbar.set_description('Calculating CLIP 32 scores')\n",
333
+ " add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
334
+ "\n",
335
  " pbar.close()\n",
336
  "\n",
337
+ " \n",
338
+ "\n",
339
  " # log results\n",
340
+ " table = wandb.Table(columns=columns, data=results16)\n",
341
  " run.log({'Samples': table, 'version': version})\n",
342
  " wandb.finish()\n",
343
  " \n",
 
348
  " wandb.finish()\n",
349
  " run = None # ensure we don't log on this run"
350
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  }
352
  ],
353
  "metadata": {