flaviagiammarino commited on
Commit
a3833bb
1 Parent(s): 19cec72

Upload 2 files

Browse files
Files changed (2) hide show
  1. scripts/pt_example.py +3 -3
  2. scripts/tf_example.py +3 -2
scripts/pt_example.py CHANGED
@@ -1,9 +1,9 @@
1
  import requests
2
- import torch
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  from PIL import Image
6
  from transformers import SamModel, SamProcessor
 
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
@@ -16,7 +16,7 @@ input_boxes = [95., 255., 190., 350.]
16
 
17
  inputs = processor(raw_image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
18
  outputs = model(**inputs, multimask_output=False)
19
- masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
20
 
21
  def show_mask(mask, ax, random_color):
22
  if random_color:
@@ -38,7 +38,7 @@ show_box(input_boxes, ax[0])
38
  ax[0].set_title("Input Image and Bounding Box")
39
  ax[0].axis("off")
40
  ax[1].imshow(np.array(raw_image))
41
- show_mask(masks[0], ax=ax[1], random_color=False)
42
  show_box(input_boxes, ax[1])
43
  ax[1].set_title("MedSAM Segmentation")
44
  ax[1].axis("off")
 
1
  import requests
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from PIL import Image
5
  from transformers import SamModel, SamProcessor
6
+ import torch
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
 
16
 
17
  inputs = processor(raw_image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
18
  outputs = model(**inputs, multimask_output=False)
19
+ probs = processor.image_processor.post_process_masks(outputs.pred_masks.sigmoid().cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu(), binarize=False)
20
 
21
  def show_mask(mask, ax, random_color):
22
  if random_color:
 
38
  ax[0].set_title("Input Image and Bounding Box")
39
  ax[0].axis("off")
40
  ax[1].imshow(np.array(raw_image))
41
+ show_mask(mask=probs[0] > 0.5, ax=ax[1], random_color=False)
42
  show_box(input_boxes, ax[1])
43
  ax[1].set_title("MedSAM Segmentation")
44
  ax[1].axis("off")
scripts/tf_example.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import matplotlib.pyplot as plt
4
  from PIL import Image
5
  from transformers import TFSamModel, SamProcessor
 
6
 
7
  model = TFSamModel.from_pretrained("flaviagiammarino/medsam-vit-base")
8
  processor = SamProcessor.from_pretrained("flaviagiammarino/medsam-vit-base")
@@ -13,7 +14,7 @@ input_boxes = [95., 255., 190., 350.]
13
 
14
  inputs = processor(raw_image, input_boxes=[[input_boxes]], return_tensors="tf")
15
  outputs = model(**inputs, multimask_output=False)
16
- masks = processor.image_processor.post_process_masks([outputs.pred_masks.numpy()[0],], inputs["original_sizes"].numpy(), inputs["reshaped_input_sizes"].numpy())
17
 
18
  def show_mask(mask, ax, random_color):
19
  if random_color:
@@ -35,7 +36,7 @@ show_box(input_boxes, ax[0])
35
  ax[0].set_title("Input Image and Bounding Box")
36
  ax[0].axis("off")
37
  ax[1].imshow(np.array(raw_image))
38
- show_mask(masks[0], ax=ax[1], random_color=False)
39
  show_box(input_boxes, ax[1])
40
  ax[1].set_title("MedSAM Segmentation")
41
  ax[1].axis("off")
 
3
  import matplotlib.pyplot as plt
4
  from PIL import Image
5
  from transformers import TFSamModel, SamProcessor
6
+ import tensorflow as tf
7
 
8
  model = TFSamModel.from_pretrained("flaviagiammarino/medsam-vit-base")
9
  processor = SamProcessor.from_pretrained("flaviagiammarino/medsam-vit-base")
 
14
 
15
  inputs = processor(raw_image, input_boxes=[[input_boxes]], return_tensors="tf")
16
  outputs = model(**inputs, multimask_output=False)
17
+ probs = processor.image_processor.post_process_masks([tf.sigmoid(outputs.pred_masks).numpy()[0],], inputs["original_sizes"].numpy(), inputs["reshaped_input_sizes"].numpy(), binarize=False)
18
 
19
  def show_mask(mask, ax, random_color):
20
  if random_color:
 
36
  ax[0].set_title("Input Image and Bounding Box")
37
  ax[0].axis("off")
38
  ax[1].imshow(np.array(raw_image))
39
+ show_mask(mask=probs[0] > 0.5, ax=ax[1], random_color=False)
40
  show_box(input_boxes, ax[1])
41
  ax[1].set_title("MedSAM Segmentation")
42
  ax[1].axis("off")