CodeJackR commited on
Commit
f9b3f94
·
1 Parent(s): 16a5f8c

Update handler to handle inputs

Browse files
Files changed (1) hide show
  1. handler.py +57 -14
handler.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
7
  import torch
8
  from transformers import SamModel, SamProcessor
9
  from typing import Dict, List, Any
 
10
 
11
  class EndpointHandler():
12
  def __init__(self, path=""):
@@ -76,20 +77,38 @@ class EndpointHandler():
76
  with torch.no_grad():
77
  outputs = self.model(**inputs)
78
 
79
- # Process the outputs to get masks
80
- masks = self.processor.image_processor.post_process_masks(
81
- outputs.pred_masks.cpu(),
82
- inputs["original_sizes"].cpu(),
83
- inputs["reshaped_input_sizes"].cpu()
84
- )[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- # Convert the best mask to a binary mask
87
- # SAM returns multiple masks, take the first one
88
- if len(masks) > 0:
89
- mask = masks[0].squeeze().numpy()
90
- mask_binary = (mask > 0.5).astype(np.uint8) * 255
91
- else:
92
- # Fallback: create a simple center mask
93
  mask_binary = np.zeros((height, width), dtype=np.uint8)
94
  center_x, center_y = width // 2, height // 2
95
  size = min(width, height) // 8
@@ -102,4 +121,28 @@ class EndpointHandler():
102
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
103
 
104
  # Return in the expected format
105
- return [{"mask_png_base64": mask_base64, "num_masks": len(masks)}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch
8
  from transformers import SamModel, SamProcessor
9
  from typing import Dict, List, Any
10
+ import torch.nn.functional as F
11
 
12
  class EndpointHandler():
13
  def __init__(self, path=""):
 
77
  with torch.no_grad():
78
  outputs = self.model(**inputs)
79
 
80
+ try:
81
+ # Get original image size
82
+ original_height, original_width = inputs["original_sizes"][0].tolist()
83
+
84
+ # Get predicted masks and scores
85
+ pred_masks = outputs.pred_masks.cpu() # (batch, num_masks, H, W)
86
+ iou_scores = outputs.iou_scores.cpu()[0] # (num_masks,)
87
+
88
+ # The model might return 4D or 5D tensors. Squeeze if 5D.
89
+ if pred_masks.ndim == 5:
90
+ pred_masks = pred_masks.squeeze(1)
91
+
92
+ # Select the best mask
93
+ best_mask_idx = torch.argmax(iou_scores)
94
+ best_mask_tensor = pred_masks[0, best_mask_idx, :, :] # (H, W)
95
+
96
+ # Upscale the mask to original image size
97
+ # Add batch and channel dims for interpolate
98
+ upscaled_mask = F.interpolate(
99
+ best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
100
+ size=(original_height, original_width),
101
+ mode='bilinear',
102
+ align_corners=False
103
+ ).squeeze() # remove batch/channel dims
104
+
105
+ # Convert to binary mask
106
+ mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
107
 
108
+ except Exception as e:
109
+ print(f"Error processing masks: {e}")
110
+ # Fallback
111
+ height, width = img.size[1], img.size[0]
 
 
 
112
  mask_binary = np.zeros((height, width), dtype=np.uint8)
113
  center_x, center_y = width // 2, height // 2
114
  size = min(width, height) // 8
 
121
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
122
 
123
  # Return in the expected format
124
+ return [{"mask_png_base64": mask_base64, "num_masks": 1}]
125
+
126
+ def main():
127
+ # Hardcoded input and output paths
128
+ input_path = "/Users/rp7/Downloads/test.jpeg"
129
+ output_path = "output.jpg"
130
+
131
+ # Read and base64-encode the input image
132
+ with open(input_path, "rb") as f:
133
+ img_bytes = f.read()
134
+ img_b64 = base64.b64encode(img_bytes).decode("utf-8")
135
+ data_url = f"data:image/jpeg;base64,{img_b64}"
136
+
137
+ handler = EndpointHandler(path=".")
138
+ result = handler({"inputs": data_url})[0]
139
+
140
+ # Decode the returned mask and save
141
+ mask_bytes = base64.b64decode(result["mask_png_base64"])
142
+ mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
143
+ mask_img.save(output_path, format="JPEG")
144
+ print(f"Wrote mask to {output_path}")
145
+
146
+ if __name__ == "__main__":
147
+ main()
148
+