boris commited on
Commit
38705a9
1 Parent(s): 353365f

feat: create a table

Browse files
Files changed (1) hide show
  1. dev/inference/wandb-backend.ipynb +18 -6
dev/inference/wandb-backend.ipynb CHANGED
@@ -46,7 +46,8 @@
46
  "batch_size = 8\n",
47
  "num_images = 128\n",
48
  "top_k = 8\n",
49
- "text_normalizer = TextNormalizer() if normalize_text else None"
 
50
  ]
51
  },
52
  {
@@ -95,8 +96,8 @@
95
  " samples = []\n",
96
  " for row in reader:\n",
97
  " samples.append(row)\n",
98
- " # make list multiple of batch_size by adding \"empty\"\n",
99
- " samples_to_add = [{'Caption':'empty', 'Theme':'empty'}] * (-len(samples) % batch_size)\n",
100
  " samples.extend(samples_to_add)\n",
101
  " # reshape\n",
102
  " samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
@@ -388,7 +389,6 @@
388
  " def p_clip(inputs):\n",
389
  " logits = clip(**inputs).logits_per_image\n",
390
  " return logits\n",
391
- " scores = jax.nn.softmax(logits, axis=0).squeeze() \n",
392
  " \n",
393
  " functions_pmapped = False"
394
  ]
@@ -649,7 +649,8 @@
649
  "outputs": [],
650
  "source": [
651
  "results = []\n",
652
- "columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]"
 
653
  ]
654
  },
655
  {
@@ -660,12 +661,23 @@
660
  "outputs": [],
661
  "source": [
662
  "for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
 
663
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
664
  " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
665
- " top_scores = [logits[x] for x in idx]\n",
666
  " results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
667
  ]
668
  },
 
 
 
 
 
 
 
 
 
 
669
  {
670
  "cell_type": "code",
671
  "execution_count": null,
 
46
  "batch_size = 8\n",
47
  "num_images = 128\n",
48
  "top_k = 8\n",
49
+ "text_normalizer = TextNormalizer() if normalize_text else None\n",
50
+ "padding_item = 'NONE'"
51
  ]
52
  },
53
  {
 
96
  " samples = []\n",
97
  " for row in reader:\n",
98
  " samples.append(row)\n",
99
+ " # make list multiple of batch_size by adding elements\n",
100
+ " samples_to_add = [{'Caption':padding_item, 'Theme':padding_item}] * (-len(samples) % batch_size)\n",
101
  " samples.extend(samples_to_add)\n",
102
  " # reshape\n",
103
  " samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
 
389
  " def p_clip(inputs):\n",
390
  " logits = clip(**inputs).logits_per_image\n",
391
  " return logits\n",
 
392
  " \n",
393
  " functions_pmapped = False"
394
  ]
 
649
  "outputs": [],
650
  "source": [
651
  "results = []\n",
652
+ "columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
653
+ "logits = jax.device_get(logits)"
654
  ]
655
  },
656
  {
 
661
  "outputs": [],
662
  "source": [
663
  "for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
664
+ " if sample['Caption'] == padding_item: continue\n",
665
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
666
  " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
667
+ " top_scores = [scores[x] for x in idx]\n",
668
  " results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
669
  ]
670
  },
671
+ {
672
+ "cell_type": "code",
673
+ "execution_count": null,
674
+ "id": "4bf40461-99d3-4d36-b7cc-e0129a3c9053",
675
+ "metadata": {},
676
+ "outputs": [],
677
+ "source": [
678
+ "table = wandb.Table(columns=columns, data=results)"
679
+ ]
680
+ },
681
  {
682
  "cell_type": "code",
683
  "execution_count": null,