din0s commited on
Commit
d85fbeb
1 Parent(s): 4fef05a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -42,22 +42,25 @@ def draw_heatmap(image, mask):
42
  # Define callable method for the demo
43
  def get_mask(image):
44
  if image is None:
45
- return None
46
 
47
  image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
48
  dm_image = feature_extractor(image).unsqueeze(0)
49
- mask = diffmask.get_mask(dm_image)["mask"][0].detach()
 
 
 
50
 
51
  masked_img = draw_mask(image, mask)
52
  heatmap = draw_heatmap(image, mask)
53
- return np.hstack((masked_img, heatmap))
54
 
55
 
56
  # Launch demo interface
57
  gr.Interface(
58
  get_mask,
59
  inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
60
- outputs=[gr.outputs.Image(label="Output")],
61
  title="Vision DiffMask Demo",
62
  live=True,
63
  ).launch()
 
42
  # Define callable method for the demo
43
  def get_mask(image):
44
  if image is None:
45
+ return None, None
46
 
47
  image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
48
  dm_image = feature_extractor(image).unsqueeze(0)
49
+ dm_out = diffmask.get_mask(dm_image)
50
+ mask = dm_out["mask"][0].detach()
51
+ pred = dm_out["pred_class"][0].detach()
52
+ pred = diffmask.model.config.id2label[pred.item()]
53
 
54
  masked_img = draw_mask(image, mask)
55
  heatmap = draw_heatmap(image, mask)
56
+ return np.hstack((masked_img, heatmap)), pred
57
 
58
 
59
  # Launch demo interface
60
  gr.Interface(
61
  get_mask,
62
  inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
63
+ outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")],
64
  title="Vision DiffMask Demo",
65
  live=True,
66
  ).launch()