import logging
import os
from typing import Literal, Optional, Union

import cv2
import numpy as np
import rembg
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from segment_anything import (
    SamAutomaticMaskGenerator,
    SamPredictor,
    sam_model_registry,
)
from transformers import pipeline

from embodied_gen.data.utils import resize_pil, trellis_preprocess
from embodied_gen.utils.process_media import filter_small_connected_components
from embodied_gen.validators.quality_checkers import ImageSegChecker

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

__all__ = [
    "SAMRemover",
    "SAMPredictor",
    "RembgRemover",
    "BMGG14Remover",
    "get_segmented_image_by_agent",
]


class SAMRemover(object):
    """Loads SAM models and performs background removal on images.

    Attributes:
        checkpoint (str): Path to the model checkpoint.
        model_type (str): Type of the SAM model to load.
        area_ratio (float): Area ratio for filtering small connected components.

    Example:
        ```py
        from embodied_gen.models.segment_model import SAMRemover

        remover = SAMRemover(model_type="vit_h")
        result = remover("input.jpg", "output.png")
        ```
    """

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        model_type: str = "vit_h",
        area_ratio: float = 15,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_type = model_type
        self.area_ratio = area_ratio

        if checkpoint is None:
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.mask_generator = self._load_sam_model(checkpoint)

    def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
        """Loads the SAM model and returns a mask generator.

        Args:
            checkpoint (str): Path to model checkpoint.

        Returns:
            SamAutomaticMaskGenerator: Mask generator instance.
        """
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)

        return SamAutomaticMaskGenerator(sam)

    def __call__(
        self,
        image: Union[str, Image.Image, np.ndarray],
        save_path: Optional[str] = None,
    ) -> Image.Image:
        """Removes the background from an image using the SAM model.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image.
            save_path (str, optional): Path to save the output image.

        Returns:
            Image.Image: Image with background removed (RGBA).
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        masks = self.mask_generator.generate(image)
        masks = sorted(masks, key=lambda x: x["area"], reverse=True)

        if not masks:
            logger.warning(
                "Segmentation failed: no mask generated, returning raw image."
            )
            output_image = Image.fromarray(image, mode="RGB")
        else:
            # Keep the largest mask and drop small connected components.
            best_mask = masks[0]["segmentation"]
            mask = (best_mask * 255).astype(np.uint8)
            mask = filter_small_connected_components(
                mask, area_ratio=self.area_ratio
            )

            background_removed = cv2.bitwise_and(image, image, mask=mask)
            output_image = np.dstack((background_removed, mask))
            output_image = Image.fromarray(output_image, mode="RGBA")

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image


class SAMPredictor(object):
    """Loads SAM models and predicts segmentation masks from user points.

    Args:
        checkpoint (str, optional): Path to model checkpoint.
        model_type (str, optional): SAM model type.
        binary_thresh (float, optional): Threshold for binary mask.
        device (str, optional): Device for inference.
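
    Example:
        ```py
        # Illustrative sketch; "input.jpg" and the click coordinates are
        # placeholders. Each entry in selected_points is an ([x, y], label)
        # pair, with label 1 marking foreground and 0 background.
        from embodied_gen.models.segment_model import SAMPredictor

        predictor = SAMPredictor(model_type="vit_h")
        result = predictor("input.jpg", [[[320, 240], 1]])
        result.save("output.png")
        ```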
    """

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        model_type: str = "vit_h",
        binary_thresh: float = 0.1,
        device: str = "cuda",
    ):
        self.device = device
        self.model_type = model_type

        if checkpoint is None:
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.predictor = self._load_sam_model(checkpoint)
        self.binary_thresh = binary_thresh

    def _load_sam_model(self, checkpoint: str) -> SamPredictor:
        """Loads the SAM model and returns a predictor.

        Args:
            checkpoint (str): Path to model checkpoint.

        Returns:
            SamPredictor: Predictor instance.
        """
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)

        return SamPredictor(sam)

    def preprocess_image(
        self, image: Union[str, Image.Image, np.ndarray]
    ) -> np.ndarray:
        """Preprocesses input image for SAM prediction.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image.

        Returns:
            np.ndarray: Preprocessed image array.
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")

        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        return image

    def generate_masks(
        self,
        image: np.ndarray,
        selected_points: list[list[int]],
    ) -> list[tuple[np.ndarray, str]]:
        """Generates segmentation masks from selected points.

        Args:
            image (np.ndarray): Input image array.
            selected_points (list[list[int]]): List of ([x, y], label)
                pairs, where label 1 marks foreground and 0 background.

        Returns:
            list[tuple[np.ndarray, str]]: List of (mask, name) tuples.
        """
        if len(selected_points) == 0:
            return []

        points = (
            torch.Tensor([p for p, _ in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )
        labels = (
            torch.Tensor([int(l) for _, l in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )

        transformed_points = self.predictor.transform.apply_coords_torch(
            points, image.shape[:2]
        )

        masks, scores, _ = self.predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=labels,
            multimask_output=True,
        )

        # Select the highest-scoring of the multimask outputs, then merge
        # positive (foreground) masks and subtract negative (background) ones.
        valid_mask = masks[:, torch.argmax(scores, dim=1)]
        masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
        masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
        if len(masks_neg) == 0:
            masks_neg = np.zeros_like(masks_pos)
        if len(masks_pos) == 0:
            masks_pos = np.zeros_like(masks_neg)
        masks_neg = masks_neg.max(axis=0, keepdims=True)
        masks_pos = masks_pos.max(axis=0, keepdims=True)
        valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)

        binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)

        return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]

    def get_segmented_image(
        self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
    ) -> Image.Image:
        """Combines masks and returns segmented image with alpha channel.

        Args:
            image (np.ndarray): Input image array.
            masks (list[tuple[np.ndarray, str]]): List of (mask, name) tuples.

        Returns:
            Image.Image: Segmented RGBA image.
        """
        seg_image = Image.fromarray(image, mode="RGB")
        alpha_channel = np.zeros(
            (seg_image.height, seg_image.width), dtype=np.uint8
        )
        # Union all masks into a single alpha channel.
        for mask, _ in masks:
            alpha_channel = np.maximum(alpha_channel, mask)

        alpha_channel = np.clip(alpha_channel, 0, 1)
        alpha_channel = (alpha_channel * 255).astype(np.uint8)
        alpha_image = Image.fromarray(alpha_channel, mode="L")
        r, g, b = seg_image.split()
        seg_image = Image.merge("RGBA", (r, g, b, alpha_image))

        return seg_image

    def __call__(
        self,
        image: Union[str, Image.Image, np.ndarray],
        selected_points: list[list[int]],
    ) -> Image.Image:
        """Segments image using selected points.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image.
            selected_points (list[list[int]]): List of ([x, y], label) pairs.

        Returns:
            Image.Image: Segmented RGBA image.
        """
        image = self.preprocess_image(image)
        self.predictor.set_image(image)
        masks = self.generate_masks(image, selected_points)

        return self.get_segmented_image(image, masks)


class RembgRemover(object):
    """Removes background from images using the rembg library.

    Example:
        ```py
        from embodied_gen.models.segment_model import RembgRemover

        remover = RembgRemover()
        result = remover("input.jpg", "output.png")
        ```
    """

    def __init__(self):
        """Initializes the RembgRemover."""
        self.rembg_session = rembg.new_session("u2net")

    def __call__(
        self,
        image: Union[str, Image.Image, np.ndarray],
        save_path: Optional[str] = None,
    ) -> Image.Image:
        """Removes background from an image.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image.
            save_path (str, optional): Path to save the output image.

        Returns:
            Image.Image: Image with background removed (RGBA).
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = resize_pil(image)
        output_image = rembg.remove(image, session=self.rembg_session)

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image


class BMGG14Remover(object):
    """Removes background using the RMBG-1.4 segmentation model.

    Example:
        ```py
        from embodied_gen.models.segment_model import BMGG14Remover

        remover = BMGG14Remover()
        result = remover("input.jpg", "output.png")
        ```
    """

    def __init__(self) -> None:
        """Initializes the BMGG14Remover."""
        self.model = pipeline(
            "image-segmentation",
            model="briaai/RMBG-1.4",
            trust_remote_code=True,
        )

    def __call__(
        self,
        image: Union[str, Image.Image, np.ndarray],
        save_path: Optional[str] = None,
    ) -> Image.Image:
        """Removes background from an image.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image.
            save_path (str, optional): Path to save the output image.

        Returns:
            Image.Image: Image with background removed.
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = resize_pil(image)
        output_image = self.model(image)

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image


def invert_rgba_pil(
    image: Image.Image, mask: Image.Image, save_path: Optional[str] = None
) -> Image.Image:
    """Inverts the alpha channel of an RGBA image using a mask.

    Args:
        image (Image.Image): Input RGB image.
        mask (Image.Image): Mask image whose values are inverted to form
            the new alpha channel.
        save_path (str, optional): Path to save the output image.

    Returns:
        Image.Image: RGBA image with inverted alpha.
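
    Example:
        ```py
        # Illustrative sketch; the paths are placeholders. The output
        # alpha is 255 where `mask` is 0, i.e. the complement of a
        # previous segmentation mask.
        from PIL import Image

        from embodied_gen.models.segment_model import invert_rgba_pil

        image = Image.open("input.jpg").convert("RGB")
        mask = Image.open("mask.png").convert("L")
        inverted = invert_rgba_pil(image, mask, "inverted.png")
        ```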
    """
    mask = (255 - np.array(mask))[..., None]
    image_array = np.concatenate([np.array(image), mask], axis=-1)
    inverted_image = Image.fromarray(image_array, "RGBA")

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        inverted_image.save(save_path)

    return inverted_image


def get_segmented_image_by_agent(
    image: Image.Image,
    sam_remover: SAMRemover,
    rbg_remover: RembgRemover,
    seg_checker: Optional[ImageSegChecker] = None,
    save_path: Optional[str] = None,
    mode: Literal["loose", "strict"] = "loose",
) -> Image.Image:
    """Segments an image using SAM and rembg, with quality checking.

    Args:
        image (Image.Image): Input image.
        sam_remover (SAMRemover): SAM-based remover.
        rbg_remover (RembgRemover): rembg-based remover.
        seg_checker (ImageSegChecker, optional): Quality checker.
        save_path (str, optional): Path to save the output image.
        mode (Literal["loose", "strict"], optional): In "strict" mode,
            raises if all candidates fail the check; in "loose" mode,
            falls back to the raw image.

    Returns:
        Image.Image: Segmented RGBA image.
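
    Example:
        ```py
        # Illustrative sketch; the paths are placeholders and no
        # ImageSegChecker is attached, so the first candidate is kept.
        from PIL import Image

        from embodied_gen.models.segment_model import (
            RembgRemover,
            SAMRemover,
            get_segmented_image_by_agent,
        )

        image = Image.open("input.jpg")
        result = get_segmented_image_by_agent(
            image, SAMRemover(), RembgRemover(), save_path="output.png"
        )
        ```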
    """

    def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
        if seg_checker is None:
            return True
        return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]

    out_sam = f"{save_path}_sam.png" if save_path else None
    out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
    out_rbg = f"{save_path}_rbg.png" if save_path else None

    # Candidate segmentations: SAM, SAM with inverted alpha, and rembg.
    seg_image = sam_remover(image, out_sam)
    seg_image = seg_image.convert("RGBA")
    _, _, _, alpha = seg_image.split()
    seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
    seg_image_rbg = rbg_remover(image, out_rbg)

    final_image = None
    if _is_valid_seg(image, seg_image):
        final_image = seg_image
    elif _is_valid_seg(image, seg_image_inv):
        final_image = seg_image_inv
    elif _is_valid_seg(image, seg_image_rbg):
        logger.warning("Failed to segment by `SAM`, retrying with `rembg`.")
        final_image = seg_image_rbg
    else:
        if mode == "strict":
            raise RuntimeError(
                "Failed to segment by `SAM` or `rembg`, aborting."
            )
        logger.warning("Failed to segment by SAM or rembg, using raw image.")
        final_image = image.convert("RGBA")

    if save_path:
        final_image.save(save_path)

    final_image = trellis_preprocess(final_image)

    return final_image


if __name__ == "__main__":
    input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
    output_image = "sample_0_seg2.png"

    sam_remover = SAMRemover(model_type="vit_h")
    rembg_remover = RembgRemover()

    clean_image = rembg_remover(input_image)
    clean_image.save(output_image)
    get_segmented_image_by_agent(
        Image.open(input_image),
        sam_remover,
        rembg_remover,
        None,
        "./test_seg.png",
    )

    remover = BMGG14Remover()
    remover("embodied_gen/models/test_seg.jpg", "./seg.png")