Ayesha-Majeed committed on
Commit
20be3dc
·
verified ·
1 Parent(s): 3fc284c

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +124 -40
binary_segmentation.py CHANGED
@@ -566,81 +566,165 @@ class BinarySegmenter:
566
  except ImportError:
567
  raise ImportError("RMBG requires: pip install transformers")
568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  def segment(
570
  self,
571
  image: np.ndarray,
572
  threshold: float = 0.5,
573
  return_type: Literal["mask", "rgba", "both"] = "mask"
574
  ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
575
- """
576
- Segment foreground object from image.
577
-
578
- Args:
579
- image: Input image as numpy array (H, W, 3) in RGB or BGR
580
- threshold: Threshold for binary mask (0-1)
581
- return_type: What to return - "mask", "rgba", or "both"
582
-
583
- Returns:
584
- Tuple of (binary_mask, rgba_image) based on return_type
585
- """
586
- # Convert BGR to RGB if needed
587
  if len(image.shape) == 3 and image.shape[2] == 3:
588
- if image[0, 0, 0] != image[0, 0, 2]: # Simple heuristic
589
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
590
- else:
591
- image_rgb = image
592
  else:
593
  raise ValueError("Input must be a color image (H, W, 3)")
594
-
595
- # Convert to PIL
 
 
 
596
  image_pil = Image.fromarray(image_rgb)
597
- original_size = image_pil.size
598
-
599
- # Transform
600
  input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
601
  if DEVICE == "cpu":
602
  input_tensor = input_tensor.float()
603
 
604
-
605
  # Inference
606
  with torch.no_grad():
607
  if self.model_type == "u2netp":
608
  outputs = self.model(input_tensor)
609
- pred = outputs[0] # Main output
610
  else: # birefnet or rmbg
611
  pred = self.model(input_tensor)[-1].sigmoid()
612
-
613
- # Post-process
614
  pred = pred.squeeze().cpu().numpy()
615
-
616
- # Resize to original
617
- pred_resized = cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)
618
-
 
 
 
 
 
 
 
 
 
619
  # Normalize to 0-255
620
- pred_normalized = ((pred_resized - pred_resized.min()) /
621
- (pred_resized.max() - pred_resized.min() + 1e-8) * 255)
622
-
623
- # Create binary mask
 
 
624
  binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
625
-
626
- # Optional: Morphological operations for cleaner mask
627
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
628
  binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
629
  binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
630
-
 
 
 
 
 
 
631
  # Create RGBA if needed
632
  rgba_image = None
633
  if return_type in ["rgba", "both"]:
634
- # Create 4-channel image
635
  rgba = np.dstack([image_rgb, binary_mask])
636
  rgba_image = Image.fromarray(rgba, mode='RGBA')
637
-
638
- # Return based on type
 
 
 
639
  if return_type == "mask":
640
  return binary_mask, None
641
  elif return_type == "rgba":
642
  return None, rgba_image
643
- else: # both
644
  return binary_mask, rgba_image
645
 
646
  def batch_segment(
 
566
  except ImportError:
567
  raise ImportError("RMBG requires: pip install transformers")
568
 
569
569
+ # NOTE: The previous segment() implementation (single-pixel BGR heuristic,
570
+ # PIL-size-based resize) was removed here rather than kept as commented-out
571
+ # code; it is preserved in git history (parent commit 3fc284c) and is
572
+ # superseded by the segment() implementation that follows.
646
  def segment(
647
  self,
648
  image: np.ndarray,
649
  threshold: float = 0.5,
650
  return_type: Literal["mask", "rgba", "both"] = "mask"
651
  ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
652
+
653
+ # Convert BGR to RGB
 
 
 
 
 
 
 
 
 
 
654
  if len(image.shape) == 3 and image.shape[2] == 3:
655
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
 
656
  else:
657
  raise ValueError("Input must be a color image (H, W, 3)")
658
+
659
+ # Store ORIGINAL dimensions (H, W) from numpy
660
+ orig_h, orig_w = image.shape[:2]
661
+
662
+ # Convert to PIL for transforms
663
  image_pil = Image.fromarray(image_rgb)
664
+
665
+ # Transform (model resizes internally e.g. 320x320 / 512x512)
 
666
  input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
667
  if DEVICE == "cpu":
668
  input_tensor = input_tensor.float()
669
 
 
670
  # Inference
671
  with torch.no_grad():
672
  if self.model_type == "u2netp":
673
  outputs = self.model(input_tensor)
674
+ pred = outputs[0]
675
  else: # birefnet or rmbg
676
  pred = self.model(input_tensor)[-1].sigmoid()
677
+
678
+ # Post-process - squeeze to 2D
679
  pred = pred.squeeze().cpu().numpy()
680
+
681
+ # ✅ FIX: Resize back to ORIGINAL (width, height) for cv2
682
+ # cv2.resize takes (width, height) = (orig_w, orig_h)
683
+ pred_resized = cv2.resize(
684
+ pred,
685
+ (orig_w, orig_h), # ← correct order for cv2
686
+ interpolation=cv2.INTER_LINEAR
687
+ )
688
+
689
+ # Verify shape matches original
690
+ assert pred_resized.shape == (orig_h, orig_w), \
691
+ f"Shape mismatch! Got {pred_resized.shape}, expected ({orig_h}, {orig_w})"
692
+
693
  # Normalize to 0-255
694
+ pred_normalized = (
695
+ (pred_resized - pred_resized.min()) /
696
+ (pred_resized.max() - pred_resized.min() + 1e-8) * 255
697
+ )
698
+
699
+ # Binary mask
700
  binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
701
+
702
+ # Morphological cleanup
703
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
704
  binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
705
  binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
706
+
707
+ # ✅ Verify final mask dimensions match input
708
+ assert binary_mask.shape == (orig_h, orig_w), \
709
+ f"Final mask mismatch! Got {binary_mask.shape}, expected ({orig_h}, {orig_w})"
710
+
711
+ logger.info(f"Input shape: ({orig_h}, {orig_w}) | Output mask shape: {binary_mask.shape} ✅")
712
+
713
  # Create RGBA if needed
714
  rgba_image = None
715
  if return_type in ["rgba", "both"]:
 
716
  rgba = np.dstack([image_rgb, binary_mask])
717
  rgba_image = Image.fromarray(rgba, mode='RGBA')
718
+
719
+ # Verify RGBA dimensions
720
+ assert rgba_image.size == (orig_w, orig_h), \
721
+ f"RGBA size mismatch! Got {rgba_image.size}, expected ({orig_w}, {orig_h})"
722
+
723
  if return_type == "mask":
724
  return binary_mask, None
725
  elif return_type == "rgba":
726
  return None, rgba_image
727
+ else:
728
  return binary_mask, rgba_image
729
 
730
  def batch_segment(