Spaces: Running on Zero
| import os | |
| import numpy as np | |
| from segment_anything import SamPredictor, sam_model_registry | |
| import torch | |
| from modules.AutoDetailer import mask_util | |
| from modules.Device import Device | |
def sam_predict(
    predictor: "SamPredictor", points: list, plabs: list, bbox: list, threshold: float
) -> list:
    """#### Predict masks using SAM.

    #### Args:
        - `predictor` (SamPredictor): The SAM predictor (only `.predict()` is used).
        - `points` (list): List of (x, y) point hints; empty/None means no points.
        - `plabs` (list): Point labels matching `points` (1 = foreground, 0 = background).
        - `bbox` (list): Bounding box `[x1, y1, x2, y2]`, or None for no box prompt.
        - `threshold` (float): Minimum score for a mask to be kept.

    #### Returns:
        - `list`: Masks whose score is >= `threshold`; if none qualify, the
          single highest-scoring mask is returned as a fallback.
    """
    # SAM expects numpy arrays (or None when a prompt kind is absent).
    point_coords = np.array(points) if points else None
    point_labels = np.array(plabs) if plabs else None
    box = np.array([bbox]) if bbox is not None else None

    cur_masks, scores, _ = predictor.predict(
        point_coords=point_coords, point_labels=point_labels, box=box
    )

    total_masks = []
    best_score = 0
    best_mask = None
    # Track the best mask while collecting everything above the threshold.
    for mask, score in zip(cur_masks, scores):
        if score > best_score:
            best_score = score
            best_mask = mask
        if score >= threshold:
            total_masks.append(mask)

    # Fallback: never return empty-handed when SAM produced at least one mask.
    if not total_masks and best_mask is not None:
        total_masks.append(best_mask)

    return total_masks
def is_same_device(a: torch.device, b: torch.device) -> bool:
    """#### Check if two devices are the same.

    #### Args:
        - `a` (torch.device): The first device (a device string is also accepted).
        - `b` (torch.device): The second device (a device string is also accepted).

    #### Returns:
        - `bool`: Whether both devices have the same type and index.
    """

    def _normalize(dev):
        # Strings like "cuda:0" are promoted to torch.device objects.
        return torch.device(dev) if isinstance(dev, str) else dev

    dev_a = _normalize(a)
    dev_b = _normalize(b)
    return (dev_a.type, dev_a.index) == (dev_b.type, dev_b.index)
class SafeToGPU:
    """#### Class to safely move objects to GPU."""

    def __init__(self, size: int):
        # Expected model footprint in bytes (e.g. checkpoint file size);
        # used to decide whether a GPU move is safe.
        self.size = size

    def to_device(self, obj: torch.nn.Module, device: torch.device) -> None:
        """#### Move an object to a device.

        Best-effort: if the target device lacks roughly 1.3x the model's
        size in free memory, the object is left where it is and a warning
        is printed instead of raising.

        #### Args:
            - `obj` (torch.nn.Module): The object to move.
            - `device` (torch.device): The target device.
        """
        if is_same_device(device, "cpu"):
            # Moving to CPU never needs a free-memory check.
            obj.to(device)
        else:
            # NOTE(review): assumes `obj` exposes a `.device` attribute
            # (a plain nn.Module does not) — confirm with callers.
            if is_same_device(obj.device, "cpu"):  # cpu to gpu
                # Ask the device to free 1.3x headroom, then re-check.
                Device.free_memory(self.size * 1.3, device)
                if Device.get_free_memory(device) > self.size * 1.3:
                    try:
                        obj.to(device)
                    # Was a bare `except:` (also swallowed KeyboardInterrupt);
                    # narrowed while keeping the best-effort behavior.
                    except Exception:
                        print(
                            f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]"
                        )
                else:
                    print(
                        f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]"
                    )
class SAMWrapper:
    """#### Wrapper class for SAM model."""

    def __init__(
        self, model: torch.nn.Module, is_auto_mode: bool, safe_to_gpu: "SafeToGPU" = None
    ):
        # `model`: the raw SAM network; `is_auto_mode`: move to GPU only
        # around predictions, returning to CPU afterwards.
        self.model = model
        # BUG FIX: the original fallback called `SafeToGPU()` which raises
        # TypeError (missing required `size`). A zero-size guard lets the
        # move proceed whenever the target device reports any free memory.
        self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU(0)
        self.is_auto_mode = is_auto_mode

    def prepare_device(self) -> None:
        """#### Prepare the device for the model (GPU move in auto mode)."""
        if self.is_auto_mode:
            device = Device.get_torch_device()
            self.safe_to_gpu.to_device(self.model, device=device)

    def release_device(self) -> None:
        """#### Release the device from the model (back to CPU in auto mode)."""
        if self.is_auto_mode:
            self.model.to(device="cpu")

    def predict(
        self, image: np.ndarray, points: list, plabs: list, bbox: list, threshold: float
    ) -> list:
        """#### Predict masks using the SAM model.

        #### Args:
            - `image` (np.ndarray): The input image (RGB, uint8).
            - `points` (list): List of point hints.
            - `plabs` (list): List of point labels.
            - `bbox` (list): Bounding box.
            - `threshold` (float): Threshold for mask selection.

        #### Returns:
            - `list`: List of predicted masks.
        """
        predictor = SamPredictor(self.model)
        predictor.set_image(image, "RGB")
        return sam_predict(predictor, points, plabs, bbox, threshold)
class SAMLoader:
    """#### Class to load SAM models."""

    def load_model(self, model_name: str, device_mode: str = "auto") -> tuple:
        """#### Load a SAM model.

        #### Args:
            - `model_name` (str): The name of the model checkpoint file.
            - `device_mode` (str, optional): The device mode. Defaults to "auto".

        #### Returns:
            - `tuple`: One-element tuple holding the loaded SAM model.
        """
        checkpoint_path = "./_internal/yolos/" + model_name

        # Pick the registry entry from the filename; anything unrecognized
        # falls back to the vit_b variant.
        model_kind = next(
            (kind for kind in ("vit_h", "vit_l") if kind in model_name), "vit_b"
        )

        sam = sam_model_registry[model_kind](checkpoint=checkpoint_path)
        safe_to = SafeToGPU(os.path.getsize(checkpoint_path))

        # Unless user explicitly wants to use CPU, we use GPU
        device = Device.get_torch_device() if device_mode == "Prefer GPU" else "CPU"
        if device_mode == "Prefer GPU":
            safe_to.to_device(sam, device)

        sam_obj = SAMWrapper(
            sam, is_auto_mode=(device_mode == "AUTO"), safe_to_gpu=safe_to
        )
        # Attach the wrapper to the model so downstream code can reach it.
        sam.sam_wrapper = sam_obj

        print(f"Loads SAM model: {checkpoint_path} (device:{device_mode})")
        return (sam,)
def make_sam_mask(
    sam: SAMWrapper,
    segs: tuple,
    image: torch.Tensor,
    detection_hint: bool,
    dilation: int,
    threshold: float,
    bbox_expansion: int,
    mask_hint_threshold: float,
    mask_hint_use_negative: bool,
) -> torch.Tensor:
    """#### Create a SAM mask by prompting SAM once per detected segment.

    #### Args:
        - `sam` (SAMWrapper): The SAM wrapper.
          NOTE(review): the body reads `sam.sam_wrapper`, so callers appear to
          pass the raw model with a `sam_wrapper` attribute attached (as done
          by `SAMLoader`) rather than a `SAMWrapper` itself — confirm.
        - `segs` (tuple): Segmentation information; `segs[1]` is the list of
          segments, each exposing a `.bbox` of `[x1, y1, x2, y2]`.
        - `image` (torch.Tensor): The input image, scaled to [0, 1];
          converted here to a uint8 HWC numpy array.
        - `detection_hint` (bool): Whether to use detection hint.
          (Unused in this implementation.)
        - `dilation` (int): Dilation value applied to the combined mask.
        - `threshold` (float): Threshold for mask selection.
        - `bbox_expansion` (int): Pixels to grow each bbox before prompting,
          clamped to the image bounds.
        - `mask_hint_threshold` (float): Mask hint threshold. (Unused here.)
        - `mask_hint_use_negative` (bool): Whether to use negative mask hint.
          (Unused here.)

    #### Returns:
        - `torch.Tensor`: The combined, dilated 3D mask, or None when
          `combine_masks2` yields no mask.
    """
    sam_obj = sam.sam_wrapper
    # In auto mode this moves the model to GPU; released in `finally` below.
    sam_obj.prepare_device()

    try:
        # Tensor [0,1] -> uint8 numpy image; squeeze drops the batch dim.
        image = np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)

        total_masks = []

        # seg_shape = segs[0]
        segs = segs[1]
        for i in range(len(segs)):
            bbox = segs[i].bbox
            center = mask_util.center_of_bbox(bbox)

            # Expand the bbox by `bbox_expansion`, clamped to image bounds
            # (shape[1] = width, shape[0] = height).
            x1 = max(bbox[0] - bbox_expansion, 0)
            y1 = max(bbox[1] - bbox_expansion, 0)
            x2 = min(bbox[2] + bbox_expansion, image.shape[1])
            y2 = min(bbox[3] + bbox_expansion, image.shape[0])

            dilated_bbox = [x1, y1, x2, y2]

            # Prompt SAM with the bbox center as a single foreground point.
            points = []
            plabs = []
            points.append(center)
            plabs = [1]  # 1 = foreground point, 0 = background point

            detected_masks = sam_obj.predict(
                image, points, plabs, dilated_bbox, threshold
            )
            total_masks += detected_masks

        # merge every collected masks
        # NOTE(review): with zero segments this calls combine_masks2([]) —
        # presumably it returns None in that case; verify in mask_util.
        mask = mask_util.combine_masks2(total_masks)
    finally:
        # Always hand the device back, even if prediction raised.
        sam_obj.release_device()

    if mask is not None:
        mask = mask.float()
        mask = mask_util.dilate_mask(mask.cpu().numpy(), dilation)
        mask = torch.from_numpy(mask)

        mask = mask_util.make_3d_mask(mask)
        return mask
    else:
        return None
class SAMDetectorCombined:
    """#### Class to combine SAM detection."""

    def doit(
        self,
        sam_model: SAMWrapper,
        segs: tuple,
        image: torch.Tensor,
        detection_hint: bool,
        dilation: int,
        threshold: float,
        bbox_expansion: int,
        mask_hint_threshold: float,
        mask_hint_use_negative: bool,
    ) -> tuple:
        """#### Combine SAM detection.

        Thin wrapper around `make_sam_mask` that packages its result.

        #### Args:
            - `sam_model` (SAMWrapper): The SAM wrapper.
            - `segs` (tuple): Segmentation information.
            - `image` (torch.Tensor): The input image.
            - `detection_hint` (bool): Whether to use detection hint.
            - `dilation` (int): Dilation value.
            - `threshold` (float): Threshold for mask selection.
            - `bbox_expansion` (int): Bounding box expansion value.
            - `mask_hint_threshold` (float): Mask hint threshold.
            - `mask_hint_use_negative` (bool): Whether to use negative mask hint.

        #### Returns:
            - `tuple`: One-element tuple with the combined mask, or None
              when no mask was produced.
        """
        combined = make_sam_mask(
            sam_model,
            segs,
            image,
            detection_hint,
            dilation,
            threshold,
            bbox_expansion,
            mask_hint_threshold,
            mask_hint_use_negative,
        )
        # Preserve the historical contract: tuple on success, bare None otherwise.
        return (combined,) if combined is not None else None