CodeJackR committed on
Commit
233c56f
·
1 Parent(s): d87db6a

Manage image resizing

Browse files
Files changed (1) hide show
  1. handler.py +28 -39
handler.py CHANGED
@@ -64,50 +64,39 @@ class EndpointHandler():
64
 
65
  # 4. Process and select the best mask
66
  try:
67
- # Use the processor's post_process_masks method correctly
68
- post_processed_masks = self.processor.post_process_masks(
69
- outputs.pred_masks,
70
- inputs["original_sizes"],
71
- inputs["reshaped_input_sizes"]
72
- )
73
 
74
- # post_processed_masks is a list with one element (for batch size 1)
75
- masks = post_processed_masks[0] # Shape: (num_masks, H, W)
76
- iou_scores = outputs.iou_scores[0] # Shape: (num_masks,)
77
 
78
- print("Number of masks generated: {}".format(masks.shape[0]))
79
- print("IoU scores: {}".format(iou_scores.tolist()))
 
 
 
 
 
80
 
81
- # Ensure we have masks and select the best one safely
82
- if masks.shape[0] > 0:
83
- best_mask_idx = torch.argmax(iou_scores)
84
- # Ensure the index is within bounds
85
- best_mask_idx = min(best_mask_idx.item(), masks.shape[0] - 1)
86
- best_mask = masks[best_mask_idx] # Shape: (H, W)
87
-
88
- # Safely convert to 2D by squeezing all singleton dimensions
89
- best_mask = best_mask.squeeze()
90
-
91
- # If still not 2D, take the last 2 dimensions
92
- if best_mask.ndim > 2:
93
- # Take the last 2 dimensions which should be height and width
94
- best_mask = best_mask.view(-1, best_mask.shape[-1])
95
- elif best_mask.ndim == 1:
96
- # If somehow we got 1D, try to reshape to square
97
- size = int(best_mask.shape[0] ** 0.5)
98
- if size * size == best_mask.shape[0]:
99
- best_mask = best_mask.view(size, size)
100
- else:
101
- raise ValueError("Cannot reshape 1D mask to 2D")
102
-
103
- print("Final mask shape: {}".format(best_mask.shape))
104
- else:
105
- raise ValueError("No masks were generated")
106
 
107
- # Convert to numpy and create binary mask
108
- mask_binary = (best_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
 
 
 
 
 
109
 
110
- print("Final mask_binary shape: {}".format(mask_binary.shape))
 
111
 
112
  except Exception as e:
113
  print("Error processing masks: {}".format(e))
 
64
 
65
  # 4. Process and select the best mask
66
  try:
67
+ # Get original image dimensions
68
+ original_height, original_width = img.size[1], img.size[0]
69
+
70
+ # Get predicted masks and scores
71
+ predicted_masks = outputs.pred_masks.cpu()
72
+ iou_scores = outputs.iou_scores.cpu()[0]
73
 
74
+ # Handle different tensor dimensions
75
+ if predicted_masks.ndim == 5:
76
+ predicted_masks = predicted_masks.squeeze(1)
77
 
78
+ # Resize masks to standard size first
79
+ predicted_masks = torch.nn.functional.interpolate(
80
+ predicted_masks,
81
+ size=(1024, 1024),
82
+ mode='bilinear',
83
+ align_corners=False
84
+ )
85
 
86
+ # Select the best mask
87
+ best_mask_idx = torch.argmax(iou_scores)
88
+ best_mask = predicted_masks[0, best_mask_idx, :, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ # Resize to original image dimensions
91
+ final_mask = torch.nn.functional.interpolate(
92
+ best_mask.unsqueeze(0).unsqueeze(0),
93
+ size=(original_height, original_width),
94
+ mode='bilinear',
95
+ align_corners=False
96
+ ).squeeze()
97
 
98
+ # Convert to binary mask
99
+ mask_binary = (final_mask > 0.0).numpy().astype(np.uint8) * 255
100
 
101
  except Exception as e:
102
  print("Error processing masks: {}".format(e))