Orpheous1 commited on
Commit
5dc90b6
1 Parent(s): 9d9aad0
app.py CHANGED
@@ -9,6 +9,8 @@ from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image
9
  import gradio as gr
10
  import numpy as np
11
  import torch
 
 
12
 
13
  # Load Vision Transformer
14
  hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
@@ -59,19 +61,36 @@ def get_mask(image, model_name: str):
59
  dm_image = feature_extractor(image).unsqueeze(0)
60
  dm_out = diffmask_model.get_mask(dm_image)
61
  mask = dm_out["mask"][0].detach()
 
 
 
 
 
 
 
 
 
 
62
  pred = dm_out["pred_class"][0].detach()
63
  pred = diffmask_model.model.config.id2label[pred.item()]
64
 
65
  masked_img = draw_mask(image, mask)
66
  heatmap = draw_heatmap(image, mask)
67
- return np.hstack((masked_img, heatmap)), pred
 
 
 
 
 
 
68
 
69
  # Launch demo interface
70
  gr.Interface(
71
  get_mask,
72
  inputs=[gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
73
  gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"])],
74
- outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")],
 
75
  title="Vision DiffMask Demo",
76
  live=True,
77
  ).launch()
 
9
  import gradio as gr
10
  import numpy as np
11
  import torch
12
+ import seaborn as sns
13
+ import matplotlib.pyplot as plt
14
 
15
  # Load Vision Transformer
16
  hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
 
61
  dm_image = feature_extractor(image).unsqueeze(0)
62
  dm_out = diffmask_model.get_mask(dm_image)
63
  mask = dm_out["mask"][0].detach()
64
+ logits = dm_out["logits"][0].detach().softmax(dim=-1)
65
+ logits_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
66
+ # fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
67
+ # sns.displot(logits_orig.cpu().numpy().flatten(), kind="kde", label="Original", ax=ax)
68
+ top5logits_orig = logits_orig.topk(5, dim=-1)
69
+ idx = top5logits_orig.indices
70
+ # keep the top 5 classes from the indices of the top 5 logits
71
+ top5logits_orig = top5logits_orig.values
72
+ top5logits = logits[idx]
73
+
74
  pred = dm_out["pred_class"][0].detach()
75
  pred = diffmask_model.model.config.id2label[pred.item()]
76
 
77
  masked_img = draw_mask(image, mask)
78
  heatmap = draw_heatmap(image, mask)
79
+ orig_probs = {diffmask_model.model.config.id2label[i]: top5logits_orig[i].item() for i in range(5)}
80
+ pred_probs = {diffmask_model.model.config.id2label[i]: top5logits[i].item() for i in range(5)}
81
+
82
+ return np.hstack((masked_img, heatmap)), pred, orig_probs, pred_probs
83
+
84
+
85
+
86
 
87
  # Launch demo interface
88
  gr.Interface(
89
  get_mask,
90
  inputs=[gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
91
  gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"])],
92
+ outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction"),
93
+ gr.Label(label="Original Probabilities"), gr.Label(label="Predicted Probabilities")],
94
  title="Vision DiffMask Demo",
95
  live=True,
96
  ).launch()
code/datamodules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (388 Bytes). View file
 
code/datamodules/__pycache__/base.cpython-38.pyc ADDED
Binary file (4.57 kB). View file
 
code/datamodules/__pycache__/transformations.cpython-38.pyc ADDED
Binary file (1.88 kB). View file
 
code/models/__pycache__/interpretation.cpython-39.pyc CHANGED
Binary files a/code/models/__pycache__/interpretation.cpython-39.pyc and b/code/models/__pycache__/interpretation.cpython-39.pyc differ
 
code/models/interpretation.py CHANGED
@@ -277,7 +277,8 @@ class ImageInterpretationNet(pl.LightningModule):
277
  mask = F.interpolate(mask, scale_factor=S)
278
  mask = mask.reshape(B, H, W)
279
 
280
- return {"mask": mask, "kl_div": kl_div, "pred_class": pred_class}
 
281
 
282
  def forward(self, x: Tensor) -> Tensor:
283
  return self.model(x).logits
 
277
  mask = F.interpolate(mask, scale_factor=S)
278
  mask = mask.reshape(B, H, W)
279
 
280
+ return {"mask": mask, "kl_div": kl_div, "pred_class": pred_class,
281
+ "logits": logits, "logits_orig": logits_orig}
282
 
283
  def forward(self, x: Tensor) -> Tensor:
284
  return self.model(x).logits
requirements.txt CHANGED
@@ -4,4 +4,5 @@ pytorch_lightning
4
  torch
5
  torchvision
6
  transformers
7
-
 
 
4
  torch
5
  torchvision
6
  transformers
7
+ seaborn
8
+ matplotlib