CodeJackR commited on
Commit
38a30a4
·
1 Parent(s): f8836df

Manage image resizing

Browse files
Files changed (1) hide show
  1. handler.py +46 -20
handler.py CHANGED
@@ -64,28 +64,54 @@ class EndpointHandler():
64
 
65
  # 4. Process and select the best mask
66
  try:
67
- pred_masks_raw = outputs.pred_masks.cpu()
68
-
69
- # The model may output 5-dim tensors, but the post-processor expects 4-dim.
70
- # We squeeze the extra dimension to fix this.
71
- if pred_masks_raw.ndim == 5:
72
- pred_masks_raw = pred_masks_raw.squeeze(1)
73
-
74
- # Use the processor's post-processing utility to resize masks and remove padding
75
- masks = self.processor.post_process_masks(
76
- pred_masks_raw,
77
- inputs["original_sizes"].cpu(),
78
- inputs["reshaped_input_sizes"].cpu()
79
- )[0]
80
 
81
- # The output of post_process_masks is a tensor of shape (num_masks, H, W)
82
- # where H and W are the original image dimensions.
83
- iou_scores = outputs.iou_scores.cpu()[0]
 
 
84
  best_mask_idx = torch.argmax(iou_scores)
85
- best_mask_tensor = masks[best_mask_idx, :, :]
86
-
87
- # Convert to binary mask (float tensor to uint8 numpy array)
88
- mask_binary = (best_mask_tensor > 0.0).numpy().astype(np.uint8) * 255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  except Exception as e:
91
  print("Error processing masks: {}".format(e))
 
64
 
65
  # 4. Process and select the best mask
66
  try:
67
+ # Get the original and reshaped sizes
68
+ original_sizes = inputs["original_sizes"][0].tolist() # [H, W]
69
+ reshaped_input_sizes = inputs["reshaped_input_sizes"][0].tolist() # [H, W]
70
+
71
+ # Get predicted masks and scores
72
+ pred_masks = outputs.pred_masks.cpu() # Shape: (batch, num_masks, H, W)
73
+ iou_scores = outputs.iou_scores.cpu()[0] # Shape: (num_masks,)
 
 
 
 
 
 
74
 
75
+ # Handle different tensor dimensions
76
+ if pred_masks.ndim == 5:
77
+ pred_masks = pred_masks.squeeze(1) # Remove extra dimension if present
78
+
79
+ # Select the best mask
80
  best_mask_idx = torch.argmax(iou_scores)
81
+ best_mask = pred_masks[0, best_mask_idx, :, :] # Shape: (H, W)
82
+
83
+ # The mask is currently at the model's internal resolution
84
+ # We need to resize it to the reshaped input size first, then crop/pad to original size
85
+
86
+ # Step 1: Resize to reshaped input size
87
+ resized_mask = F.interpolate(
88
+ best_mask.unsqueeze(0).unsqueeze(0).float(),
89
+ size=reshaped_input_sizes,
90
+ mode='bilinear',
91
+ align_corners=False
92
+ ).squeeze()
93
+
94
+ # Step 2: Handle padding/cropping to get back to original size
95
+ original_h, original_w = original_sizes
96
+ reshaped_h, reshaped_w = reshaped_input_sizes
97
+
98
+ # Calculate padding that was added during preprocessing
99
+ if reshaped_h > original_h or reshaped_w > original_w:
100
+ # There was padding, we need to crop
101
+ start_h = (reshaped_h - original_h) // 2
102
+ start_w = (reshaped_w - original_w) // 2
103
+ final_mask = resized_mask[start_h:start_h + original_h, start_w:start_w + original_w]
104
+ else:
105
+ # No padding or different scaling, just resize directly
106
+ final_mask = F.interpolate(
107
+ resized_mask.unsqueeze(0).unsqueeze(0),
108
+ size=original_sizes,
109
+ mode='bilinear',
110
+ align_corners=False
111
+ ).squeeze()
112
+
113
+ # Convert to binary mask
114
+ mask_binary = (final_mask > 0.0).numpy().astype(np.uint8) * 255
115
 
116
  except Exception as e:
117
  print("Error processing masks: {}".format(e))