JingShiang Yang commited on
Commit
d5f470e
·
1 Parent(s): decc98d

Fix mask dimension handling with squeeze

Browse files
Files changed (1) hide show
  1. handler.py +9 -4
handler.py CHANGED
@@ -48,14 +48,19 @@ class EndpointHandler:
48
  labels = np.array(params.get("point_labels", [1]), dtype=np.float32)
49
 
50
  # Decode
51
- masks = self.decoder.run(None, {
52
  'image_embeddings': embeddings,
53
  'point_coords': coords.reshape(1, -1, 2),
54
  'point_labels': labels.reshape(1, -1)
55
- })[0]
56
 
57
- # Postprocess
58
- mask = (masks[0, 0] > 0.0).astype(np.uint8) * 255
 
 
 
 
 
59
 
60
  # Return result
61
  result = {"mask_shape": list(mask.shape), "has_object": bool(mask.max() > 0)}
 
48
  labels = np.array(params.get("point_labels", [1]), dtype=np.float32)
49
 
50
  # Decode
51
+ decoder_outputs = self.decoder.run(None, {
52
  'image_embeddings': embeddings,
53
  'point_coords': coords.reshape(1, -1, 2),
54
  'point_labels': labels.reshape(1, -1)
55
+ })
56
 
57
+ masks = decoder_outputs[0]
58
+
59
+ # Postprocess - squeeze to get 2D mask
60
+ mask = masks.squeeze() # Remove all dimensions of size 1
61
+ if len(mask.shape) > 2:
62
+ mask = mask[0] # Take first mask if multiple
63
+ mask = (mask > 0.0).astype(np.uint8) * 255
64
 
65
  # Return result
66
  result = {"mask_shape": list(mask.shape), "has_object": bool(mask.max() > 0)}