Orpheous1 commited on
Commit
8fd2935
1 Parent(s): 414d78a
app.py CHANGED
@@ -22,7 +22,8 @@ feature_extractor = UnNest(feature_extractor)
22
  # Load Vision DiffMask
23
  diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
24
  diffmask.set_vision_transformer(vit)
25
-
 
26
 
27
  # Define mask plotting functions
28
  def draw_mask(image, mask):
@@ -40,27 +41,31 @@ def draw_heatmap(image, mask):
40
 
41
 
42
  # Define callable method for the demo
43
- def get_mask(image):
44
  if image is None:
45
  return None, None
46
-
 
 
 
47
  image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
48
  dm_image = feature_extractor(image).unsqueeze(0)
49
- dm_out = diffmask.get_mask(dm_image)
50
  mask = dm_out["mask"][0].detach()
51
  pred = dm_out["pred_class"][0].detach()
52
- pred = diffmask.model.config.id2label[pred.item()]
53
 
54
  masked_img = draw_mask(image, mask)
55
  heatmap = draw_heatmap(image, mask)
56
  return np.hstack((masked_img, heatmap)), pred
57
 
58
-
59
  # Launch demo interface
60
  gr.Interface(
61
  get_mask,
62
- inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
 
63
  outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")],
64
  title="Vision DiffMask Demo",
65
  live=True,
66
  ).launch()
 
 
22
  # Load Vision DiffMask
23
  diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
24
  diffmask.set_vision_transformer(vit)
25
+ diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
26
+ diffmask_imagenet.set_vision_transformer(vit)
27
 
28
  # Define mask plotting functions
29
  def draw_mask(image, mask):
 
41
 
42
 
43
  # Define callable method for the demo
44
+ def get_mask(image, model_name: str):
45
  if image is None:
46
  return None, None
47
+ if model_name == 'DiffMask-CiFAR-10':
48
+ diffmask_model = diffmask
49
+ elif model_name == 'DiffMask-ImageNet':
50
+ diffmask_model = diffmask_imagenet
51
  image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
52
  dm_image = feature_extractor(image).unsqueeze(0)
53
+ dm_out = model.get_mask(dm_image)
54
  mask = dm_out["mask"][0].detach()
55
  pred = dm_out["pred_class"][0].detach()
56
+ pred = diffmask_model.model.config.id2label[pred.item()]
57
 
58
  masked_img = draw_mask(image, mask)
59
  heatmap = draw_heatmap(image, mask)
60
  return np.hstack((masked_img, heatmap)), pred
61
 
 
62
  # Launch demo interface
63
  gr.Interface(
64
  get_mask,
65
+ inputs=[gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
66
+ gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"])],
67
  outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")],
68
  title="Vision DiffMask Demo",
69
  live=True,
70
  ).launch()
71
+
checkpoints/diffmask_imagenet.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c21d3e65b049984911907d004747e2162cba8a77ac10008bb9b3d8612b745a4f
3
+ size 16607641
code/datamodules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (388 Bytes). View file
 
code/datamodules/__pycache__/base.cpython-39.pyc ADDED
Binary file (4.57 kB). View file
 
code/datamodules/__pycache__/transformations.cpython-39.pyc ADDED
Binary file (1.88 kB). View file