Samuel Stevens committed
Commit c4ee5c3
1 Parent(s): 699b9c3

Include original predictions
Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +63 -68
  3. data.py +1 -1
  4. modeling.py +2 -2
.gitattributes ADDED
@@ -0,0 +1 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -125,7 +125,6 @@ def add_highlights(
     upper: int | None = None,
     opacity: float = 0.9,
 ) -> Image.Image:
-    breakpoint()
     if not len(patches):
         return img
 
@@ -198,18 +197,22 @@ class SaeActivation(typing.TypedDict):
 
 
 @beartype.beartype
-def get_image(i: int) -> tuple[str, str, int]:
-    img_sized = data.to_sized(data.get_image(i))
+def get_img(i: int) -> dict[str, object]:
+    img_sized = data.to_sized(data.get_img(i))
     seg_sized = data.to_sized(data.get_seg(i))
     seg_u8_sized = data.to_u8(seg_sized)
     seg_img_sized = data.u8_to_img(seg_u8_sized)
 
-    return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized), i
+    return {
+        "index": i,
+        "orig_url": data.img_to_base64(img_sized),
+        "seg_url": data.img_to_base64(seg_img_sized),
+    }
 
 
 @beartype.beartype
 @torch.inference_mode
-def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]:
+def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:
     """
     Given a particular cell, returns some highlighted images showing what feature fires most on this cell.
     """
@@ -219,7 +222,7 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
     split_vit, vit_transform = modeling.load_vit(DEVICE)
     sae = load_sae(DEVICE)
 
-    img = data.get_image(image_i)
+    img = data.get_img(img_i)
 
     x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
 
@@ -261,7 +264,7 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
     examples = []
     for i_im, values_p in pairs:
         seg_sized = data.to_sized(data.get_seg(i_im))
-        img_sized = data.to_sized(data.get_image(i_im))
+        img_sized = data.to_sized(data.get_img(i_im))
 
         seg_u8_sized = data.to_u8(seg_sized)
         seg_img_sized = data.u8_to_img(seg_u8_sized)
@@ -286,26 +289,27 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
 
 
 @torch.inference_mode
-def get_true_labels(image_i: int) -> Image.Image:
-    seg = human_dataset[image_i]["segmentation"]
-    image = seg_to_img(seg)
-    return image
+def get_preds(i: int) -> dict[str, object]:
+    img = data.get_img(i)
+    split_vit, vit_transform = modeling.load_vit(DEVICE)
 
+    x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
 
-@torch.inference_mode
-def get_pred_labels(i: int) -> list[Image.Image | list[int]]:
-    sample = vit_dataset[i]
-    x = sample["image"][None, ...].to(device)
-    x_BPD = rest_of_vit.forward_start(x)
-    x_BPD = rest_of_vit.forward_end(x_BPD)
+    x_BPD = split_vit.forward_start(x_BCWH)
+    x_BPD = split_vit.forward_end(x_BPD)
 
     x_WHD = einops.rearrange(x_BPD, "() (w h) dim -> w h dim", w=16, h=16)
 
-    logits_WHC = head(x_WHD)
+    clf = load_clf()
+    logits_WHC = clf(x_WHD)
 
     pred_WH = logits_WHC.argmax(axis=-1)
-    preds = einops.rearrange(pred_WH, "w h -> (w h)").tolist()
-    return [seg_to_img(upsample(pred_WH)), preds]
+    # preds = einops.rearrange(pred_WH, "w h -> (w h)").tolist()
+    return {
+        "index": i,
+        "orig_url": data.img_to_base64(data.to_sized(img)),
+        "seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
+    }
 
 
 @beartype.beartype
@@ -393,64 +397,55 @@ def upsample(
 
 
 with gr.Blocks() as demo:
-    image_number = gr.Number(label="Validation Example")
+    ###########
+    # get-img #
+    ###########
+
+    # Inputs
+    img_number = gr.Number(label="Example Index")
 
-    input_image_base64 = gr.Text(label="Image in Base64")
-    true_labels_base64 = gr.Text(label="Labels in Base64")
+    # Outputs
+    get_img_out = gr.JSON(label="get_img_out", value={})
 
-    get_input_image_btn = gr.Button(value="Get Input Image")
-    get_input_image_btn.click(
-        get_image,
-        inputs=[image_number],
-        outputs=[input_image_base64, true_labels_base64, image_number],
-        api_name="get-image",
+    get_input_img_btn = gr.Button(value="Get Input Image")
+    get_input_img_btn.click(
+        get_img, inputs=[img_number], outputs=[get_img_out], api_name="get-img"
     )
 
-    # input_image = gr.Image(
-    #     label="Input Image",
-    #     sources=["upload", "clipboard"],
-    #     type="pil",
-    #     interactive=True,
-    # )
-    # patch_numbers = gr.CheckboxGroup(label="Image Patch", choices=list(range(256)))
-    # top_latent_numbers = gr.CheckboxGroup(label="Top Latents")
-    # top_latent_numbers = [
-    #     gr.Number(label="Top Latents #{j+1}") for j in range(n_sae_latents)
-    # ]
-    # sae_example_images = [
-    #     gr.Image(label=f"Latent #{j}, Example #{i + 1}", format="png")
-    #     for i in range(n_sae_examples)
-    #     for j in range(n_sae_latents)
-    # ]
+    ###################
+    # get-sae-latents #
+    ###################
 
+    # Inputs
     patches_json = gr.JSON(label="Patches", value=[])
-    activations_json = gr.JSON(label="Activations", value=[])
-
-    get_sae_activations_btn = gr.Button(value="Get SAE Activations")
-    get_sae_activations_btn.click(
-        get_sae_activations,
-        inputs=[image_number, patches_json],
-        outputs=[activations_json],
-        api_name="get-sae-examples",
+    # Outputs
+    get_sae_latents_out = gr.JSON(label="get_sae_latents_out", value=[])
+
+    get_sae_latents_btn = gr.Button(value="Get SAE Latents")
+    get_sae_latents_btn.click(
+        get_sae_latents,
+        inputs=[img_number, patches_json],
+        outputs=[get_sae_latents_out],
+        api_name="get-sae-latents",
     )
-    # semseg_image = gr.Image(label="Semantic Segmentaions", format="png")
-    # semseg_colors = gr.CheckboxGroup(
-    #     label="Sem Seg Colors", choices=list(range(1, 151))
-    # )
 
-    # get_pred_labels_btn = gr.Button(value="Get Pred. Labels")
-    # get_pred_labels_btn.click(
-    #     get_pred_labels,
-    #     inputs=[image_number],
-    #     outputs=[semseg_image, semseg_colors],
-    #     api_name="get-pred-labels",
-    # )
+    #############
+    # get-preds #
+    #############
+
+    # Outputs
+    get_preds_out = gr.JSON(label="get_preds_out", value=[])
+
+    get_pred_labels_btn = gr.Button(value="Get Predictions")
+    get_pred_labels_btn.click(
+        get_preds, inputs=[img_number], outputs=[get_preds_out], api_name="get-preds"
+    )
 
     # get_true_labels_btn = gr.Button(value="Get True Label")
     # get_true_labels_btn.click(
    #     get_true_labels,
-    #     inputs=[image_number],
-    #     outputs=semseg_image,
+    #     inputs=[img_number],
+    #     outputs=semseg_img,
     #     api_name="get-true-labels",
     # )
 
@@ -462,8 +457,8 @@ with gr.Blocks() as demo:
     # get_modified_labels_btn = gr.Button(value="Get Modified Label")
     # get_modified_labels_btn.click(
     #     get_modified_labels,
-    #     inputs=[image_number] + latent_numbers + value_sliders,
-    #     outputs=[semseg_image, semseg_colors],
+    #     inputs=[img_number] + latent_numbers + value_sliders,
+    #     outputs=[semseg_img, semseg_colors],
    #     api_name="get-modified-labels",
     # )
 
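For context, a minimal sketch of how the renamed endpoints (api_name="get-img", "get-sae-latents", "get-preds") might be called from Python with gradio_client. The Space id below is a placeholder, not part of this commit, and the example index and patch list are arbitrary.

from gradio_client import Client

# Placeholder Space id; substitute the actual deployment of this app.
client = Client("samuelstevens/placeholder-space")

# /get-img returns {"index": ..., "orig_url": ..., "seg_url": ...}.
img_payload = client.predict(3, api_name="/get-img")

# /get-sae-latents takes the example index plus a JSON list of patch indices.
latents = client.predict(3, [0, 1, 17], api_name="/get-sae-latents")

# /get-preds returns the original image and predicted segmentation as base64 URLs.
preds = client.predict(3, api_name="/get-preds")

print(img_payload["index"], len(latents), sorted(preds))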
data.py CHANGED
@@ -20,7 +20,7 @@ R2_URL = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"
 
 @beartype.beartype
 @functools.lru_cache(maxsize=512)
-def get_image(i: int) -> Image.Image:
+def get_img(i: int) -> Image.Image:
     fpath = f"/images/ADE_val_{i + 1:08}.jpg"
     url = R2_URL + fpath
     logger.info("Getting image from '%s'.", url)
modeling.py CHANGED
@@ -21,7 +21,7 @@ class SplitDinov2(torch.nn.Module):
 
     def forward_start(
         self, x: Float[Tensor, "batch channels width height"]
-    ) -> Float[Tensor, "batch patches dim"]:
+    ) -> Float[Tensor, "batch total_patches dim"]:
         x_BPD = self.vit.prepare_tokens_with_masks(x)
         for blk in self.vit.blocks[: self.split_at]:
             x_BPD = blk(x_BPD)
@@ -29,7 +29,7 @@ class SplitDinov2(torch.nn.Module):
         return x_BPD
 
     def forward_end(
-        self, x_BPD: Float[Tensor, "batch n_patches dim"]
+        self, x_BPD: Float[Tensor, "batch total_patches dim"]
     ) -> Float[Tensor, "batch patches dim"]:
         for blk in self.vit.blocks[-self.split_at :]:
             x_BPD = blk(x_BPD)
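For context, a minimal sketch of how app.py drives this split forward pass (mirroring the get_preds and get_sae_latents calls above). The comment marking where an SAE could intervene is an assumption about intent, not code from this commit.

import torch

import data
import modeling

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# load_vit returns the SplitDinov2 wrapper plus its preprocessing transform (see app.py).
split_vit, vit_transform = modeling.load_vit(DEVICE)

img = data.get_img(0)  # PIL image fetched (and LRU-cached) from the R2 bucket
x_BCWH = vit_transform(img)[None, ...].to(DEVICE)

with torch.inference_mode():
    x_BPD = split_vit.forward_start(x_BCWH)  # blocks before self.split_at
    # A hypothetical intervention (e.g. an SAE reconstruction) could edit x_BPD here.
    x_BPD = split_vit.forward_end(x_BPD)  # remaining blocks after the split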