boris committed on
Commit
9a553a4
1 Parent(s): 2d169e3

fix: pmap clip32

Files changed (1)
  1. dev/inference/wandb-backend.ipynb +82 -13
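The change routes CLIP scoring through jax.pmap: ViT-B/16 (and, behind the new add_clip_32 flag, ViT-B/32) parameters are replicated across devices, processor outputs are sharded, and each pmapped scoring function returns logits_per_image for its shard. A minimal sketch of that replicate/shard/pmap pattern follows; it is not the committed code — the toy captions/images and the explicit params argument are assumptions made so the example is self-contained.

# Hedged sketch of the replicate/shard/pmap CLIP-scoring pattern; toy data,
# and params passed as an explicit pmap argument rather than closed over.
from functools import partial

import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_params = replicate(clip.params)  # one copy of the weights per device

@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
    # inputs: this device's shard of the processor output
    return clip(params=params, **inputs).logits_per_image

# toy data: one caption/image pair per device so shard() divides evenly
n = jax.local_device_count()
captions = ["a painting of a fox"] * n
images = [Image.new("RGB", (224, 224))] * n
inputs = processor(text=captions, images=images, return_tensors="np",
                   padding="max_length", max_length=77, truncation=True).data
logits = p_clip(shard(inputs), clip_params)  # (devices, images/device, captions/device)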
dev/inference/wandb-backend.ipynb CHANGED
@@ -36,7 +36,8 @@
36
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
37
  "normalize_text = True\n",
38
  "latest_only = False # log only latest or all versions\n",
39
- "suffix = '_1' # mainly for duplicate inference runs with a deleted version"
40
  ]
41
  },
42
  {
@@ -51,7 +52,8 @@
51
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
52
  "normalize_text = False\n",
53
  "latest_only = True # log only latest or all versions\n",
54
- "suffix = '_2' # mainly for duplicate inference runs with a deleted version"
55
  ]
56
  },
57
  {
@@ -82,7 +84,12 @@
82
  "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
83
  "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
84
  "clip_params = replicate(clip.params)\n",
85
- "vqgan_params = replicate(vqgan.params)"
86
  ]
87
  },
88
  {
@@ -98,8 +105,14 @@
98
  "\n",
99
  "@partial(jax.pmap, axis_name=\"batch\")\n",
100
  "def p_clip(inputs):\n",
101
- " logits = clip(**inputs).logits_per_image\n",
102
- " return logits"
103
  ]
104
  },
105
  {
@@ -158,7 +171,7 @@
158
  "# retrieve inference run details\n",
159
  "def get_last_inference_version(run_id):\n",
160
  " try:\n",
161
- " inference_run = api.run(f'dalle-mini/dalle-mini/inf-{run_id}{suffix}')\n",
162
  " return inference_run.summary.get('version', None)\n",
163
  " except:\n",
164
  " return None"
@@ -215,6 +228,8 @@
215
  " print(f'Processing artifact: {artifact.name}')\n",
216
  " version = int(artifact.version[1:])\n",
217
  " results = []\n",
218
  " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
219
  " \n",
220
  " if latest_only:\n",
@@ -232,7 +247,7 @@
232
  "\n",
233
  " # start/resume corresponding run\n",
234
  " if run is None:\n",
235
- " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'inf-{run_id}{suffix}', resume='allow')\n",
236
  "\n",
237
  " # work in temporary directory\n",
238
  " with tempfile.TemporaryDirectory() as tmp:\n",
@@ -283,7 +298,6 @@
283
  " logits = logits.reshape(-1, num_images)\n",
284
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
285
  " logits = jax.device_get(logits)\n",
286
- "\n",
287
  " # add to results table\n",
288
  " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
289
  " if sample == padding_item: continue\n",
@@ -291,11 +305,68 @@
291
  " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
292
  " top_scores = [scores[x] for x in idx]\n",
293
  " results.append([sample] + top_images + top_scores)\n",
294
  "\n",
295
  " # log results\n",
296
  " table = wandb.Table(columns=columns, data=results)\n",
297
  " run.log({'Samples': table, 'version': version})\n",
298
- " wandb.finish()"
299
  ]
300
  },
301
  {
@@ -314,12 +385,10 @@
314
  {
315
  "cell_type": "code",
316
  "execution_count": null,
317
- "id": "e1c04761-1016-47e9-925c-3a9ec6fec95a",
318
  "metadata": {},
319
  "outputs": [],
320
- "source": [
321
- "wandb.finish()"
322
- ]
323
  }
324
  ],
325
  "metadata": {
 
36
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
37
  "normalize_text = True\n",
38
  "latest_only = False # log only latest or all versions\n",
39
+ "suffix = '_1' # mainly for duplicate inference runs with a deleted version\n",
40
+ "add_clip_32 = False"
41
  ]
42
  },
43
  {
 
52
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
53
  "normalize_text = False\n",
54
  "latest_only = True # log only latest or all versions\n",
55
+ "suffix = '_2' # mainly for duplicate inference runs with a deleted version\n",
56
+ "add_clip_32 = True"
57
  ]
58
  },
59
  {
 
84
  "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
85
  "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
86
  "clip_params = replicate(clip.params)\n",
87
+ "vqgan_params = replicate(vqgan.params)\n",
88
+ "\n",
89
+ "if add_clip_32:\n",
90
+ " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
91
+ " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
92
+ " clip32_params = replicate(clip32.params)"
93
  ]
94
  },
95
  {
 
105
  "\n",
106
  "@partial(jax.pmap, axis_name=\"batch\")\n",
107
  "def p_clip(inputs):\n",
108
+ " logits = clip(params=clip_params, **inputs).logits_per_image\n",
109
+ " return logits\n",
110
+ "\n",
111
+ "if add_clip_32:\n",
112
+ " @partial(jax.pmap, axis_name=\"batch\")\n",
113
+ " def p_clip32(inputs):\n",
114
+ " logits = clip32(params=clip32_params, **inputs).logits_per_image\n",
115
+ " return logits"
116
  ]
117
  },
118
  {
 
171
  "# retrieve inference run details\n",
172
  "def get_last_inference_version(run_id):\n",
173
  " try:\n",
174
+ " inference_run = api.run(f'dalle-mini/dalle-mini/{run_id}-clip16{suffix}')\n",
175
  " return inference_run.summary.get('version', None)\n",
176
  " except:\n",
177
  " return None"
 
228
  " print(f'Processing artifact: {artifact.name}')\n",
229
  " version = int(artifact.version[1:])\n",
230
  " results = []\n",
231
+ " if add_clip_32:\n",
232
+ " results32 = []\n",
233
  " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
234
  " \n",
235
  " if latest_only:\n",
 
247
  "\n",
248
  " # start/resume corresponding run\n",
249
  " if run is None:\n",
250
+ " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip16{suffix}', resume='allow')\n",
251
  "\n",
252
  " # work in temporary directory\n",
253
  " with tempfile.TemporaryDirectory() as tmp:\n",
 
298
  " logits = logits.reshape(-1, num_images)\n",
299
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
300
  " logits = jax.device_get(logits)\n",
 
301
  " # add to results table\n",
302
  " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
303
  " if sample == padding_item: continue\n",
 
305
  " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
306
  " top_scores = [scores[x] for x in idx]\n",
307
  " results.append([sample] + top_images + top_scores)\n",
308
+ " \n",
309
+ " # get clip 32 scores - TODO: this should be refactored as it is same code as above\n",
310
+ " if add_clip_32:\n",
311
+ " print('Calculating CLIP 32 scores')\n",
312
+ " clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
313
+ " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
314
+ " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
315
+ " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
316
+ " clip_inputs = shard(clip_inputs)\n",
317
+ " logits = p_clip32(clip_inputs)\n",
318
+ " logits = logits.reshape(-1, num_images)\n",
319
+ " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
320
+ " logits = jax.device_get(logits)\n",
321
+ " # add to results table\n",
322
+ " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
323
+ " if sample == padding_item: continue\n",
324
+ " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
325
+ " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
326
+ " top_scores = [scores[x] for x in idx]\n",
327
+ " results32.append([sample] + top_images + top_scores)\n",
328
  "\n",
329
  " # log results\n",
330
  " table = wandb.Table(columns=columns, data=results)\n",
331
  " run.log({'Samples': table, 'version': version})\n",
332
+ " wandb.finish()\n",
333
+ " \n",
334
+ " if add_clip_32: \n",
335
+ " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip32{suffix}', resume='allow')\n",
336
+ " table = wandb.Table(columns=columns, data=results32)\n",
337
+ " run.log({'Samples': table, 'version': version})\n",
338
+ " wandb.finish()\n",
339
+ " run = None # ensure we don't log on this run"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "id": "fdcd09d6-079c-461a-a81a-d9e650d3b099",
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": [
349
+ "p_clip32"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "id": "7d86ceee-c9ac-4860-abad-410cadd16c3c",
356
+ "metadata": {},
357
+ "outputs": [],
358
+ "source": [
359
+ "clip_inputs['attention_mask'].shape, clip_inputs['pixel_values'].shape"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": null,
365
+ "id": "fbba4858-da2d-4dd5-97b7-ce3ab4746f96",
366
+ "metadata": {},
367
+ "outputs": [],
368
+ "source": [
369
+ "clip_inputs['input_ids'].shape"
370
  ]
371
  },
372
  {
 
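In the scoring hunk above, the generation loop emits images interleaved across prompts (one image per prompt per step, as the indexing implies), while each pmap shard must score all images for a single prompt; images_per_prompt_indices + i therefore gathers prompt i's images into a contiguous block before shard() splits them across devices. A toy NumPy illustration of that reindexing (sizes and the flat_images name are made up):

import numpy as np

batch_size, num_images = 2, 3  # 2 prompts, 3 images per prompt
# generation order is step-major: one image per prompt at each step
flat_images = np.array(["p0_s0", "p1_s0", "p0_s1", "p1_s1", "p0_s2", "p1_s2"])

images_per_prompt_indices = np.asarray(range(0, len(flat_images), batch_size))  # [0, 2, 4]
regrouped = np.concatenate(
    [flat_images[images_per_prompt_indices + i] for i in range(batch_size)]
)
print(regrouped)  # ['p0_s0' 'p0_s1' 'p0_s2' 'p1_s0' 'p1_s1' 'p1_s2'] -> one prompt per shard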
385
  {
386
  "cell_type": "code",
387
  "execution_count": null,
388
+ "id": "a7a5fdf5-3c6e-421b-96a8-5115f730328c",
389
  "metadata": {},
390
  "outputs": [],
391
+ "source": []
392
  }
393
  ],
394
  "metadata": {