CodeJackR
commited on
Commit
·
f9b3f94
1
Parent(s):
16a5f8c
Update handler to handle inputs
Browse files- handler.py +57 -14
handler.py
CHANGED
|
@@ -7,6 +7,7 @@ from PIL import Image
|
|
| 7 |
import torch
|
| 8 |
from transformers import SamModel, SamProcessor
|
| 9 |
from typing import Dict, List, Any
|
|
|
|
| 10 |
|
| 11 |
class EndpointHandler():
|
| 12 |
def __init__(self, path=""):
|
|
@@ -76,20 +77,38 @@ class EndpointHandler():
|
|
| 76 |
with torch.no_grad():
|
| 77 |
outputs = self.model(**inputs)
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
mask_binary = (mask > 0.5).astype(np.uint8) * 255
|
| 91 |
-
else:
|
| 92 |
-
# Fallback: create a simple center mask
|
| 93 |
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
| 94 |
center_x, center_y = width // 2, height // 2
|
| 95 |
size = min(width, height) // 8
|
|
@@ -102,4 +121,28 @@ class EndpointHandler():
|
|
| 102 |
mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
|
| 103 |
|
| 104 |
# Return in the expected format
|
| 105 |
-
return [{"mask_png_base64": mask_base64, "num_masks":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
from transformers import SamModel, SamProcessor
|
| 9 |
from typing import Dict, List, Any
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
|
| 12 |
class EndpointHandler():
|
| 13 |
def __init__(self, path=""):
|
|
|
|
| 77 |
with torch.no_grad():
|
| 78 |
outputs = self.model(**inputs)
|
| 79 |
|
| 80 |
+
try:
|
| 81 |
+
# Get original image size
|
| 82 |
+
original_height, original_width = inputs["original_sizes"][0].tolist()
|
| 83 |
+
|
| 84 |
+
# Get predicted masks and scores
|
| 85 |
+
pred_masks = outputs.pred_masks.cpu() # (batch, num_masks, H, W)
|
| 86 |
+
iou_scores = outputs.iou_scores.cpu()[0] # (num_masks,)
|
| 87 |
+
|
| 88 |
+
# The model might return 4D or 5D tensors. Squeeze if 5D.
|
| 89 |
+
if pred_masks.ndim == 5:
|
| 90 |
+
pred_masks = pred_masks.squeeze(1)
|
| 91 |
+
|
| 92 |
+
# Select the best mask
|
| 93 |
+
best_mask_idx = torch.argmax(iou_scores)
|
| 94 |
+
best_mask_tensor = pred_masks[0, best_mask_idx, :, :] # (H, W)
|
| 95 |
+
|
| 96 |
+
# Upscale the mask to original image size
|
| 97 |
+
# Add batch and channel dims for interpolate
|
| 98 |
+
upscaled_mask = F.interpolate(
|
| 99 |
+
best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
|
| 100 |
+
size=(original_height, original_width),
|
| 101 |
+
mode='bilinear',
|
| 102 |
+
align_corners=False
|
| 103 |
+
).squeeze() # remove batch/channel dims
|
| 104 |
+
|
| 105 |
+
# Convert to binary mask
|
| 106 |
+
mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
|
| 107 |
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f"Error processing masks: {e}")
|
| 110 |
+
# Fallback
|
| 111 |
+
height, width = img.size[1], img.size[0]
|
|
|
|
|
|
|
|
|
|
| 112 |
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
| 113 |
center_x, center_y = width // 2, height // 2
|
| 114 |
size = min(width, height) // 8
|
|
|
|
| 121 |
mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
|
| 122 |
|
| 123 |
# Return in the expected format
|
| 124 |
+
return [{"mask_png_base64": mask_base64, "num_masks": 1}]
|
| 125 |
+
|
| 126 |
+
def main():
|
| 127 |
+
# Hardcoded input and output paths
|
| 128 |
+
input_path = "/Users/rp7/Downloads/test.jpeg"
|
| 129 |
+
output_path = "output.jpg"
|
| 130 |
+
|
| 131 |
+
# Read and base64-encode the input image
|
| 132 |
+
with open(input_path, "rb") as f:
|
| 133 |
+
img_bytes = f.read()
|
| 134 |
+
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
|
| 135 |
+
data_url = f"data:image/jpeg;base64,{img_b64}"
|
| 136 |
+
|
| 137 |
+
handler = EndpointHandler(path=".")
|
| 138 |
+
result = handler({"inputs": data_url})[0]
|
| 139 |
+
|
| 140 |
+
# Decode the returned mask and save
|
| 141 |
+
mask_bytes = base64.b64decode(result["mask_png_base64"])
|
| 142 |
+
mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
|
| 143 |
+
mask_img.save(output_path, format="JPEG")
|
| 144 |
+
print(f"Wrote mask to {output_path}")
|
| 145 |
+
|
| 146 |
+
if __name__ == "__main__":
|
| 147 |
+
main()
|
| 148 |
+
|