| | import torch
|
| |
|
class Batch_Sprite_BBox_Cropper:
    """
    ComfyUI custom node:

    - Takes a batch of RGBA images (or RGB + MASK).
    - Alpha clamp: alpha <= (alpha_cutoff / 255) -> 0
    - Computes one global bounding box of visible pixels across the entire batch.
    - Crops every image to the same bbox (spritesheet-safe: all frames stay aligned).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "alpha_cutoff": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                "verbose": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT", "INT", "INT")
    RETURN_NAMES = ("cropped_images", "left", "top", "right", "bottom", "crop_width", "crop_height")
    FUNCTION = "process"
    CATEGORY = "image/alpha"

    @staticmethod
    def _as_rgba(images, mask, B, H, W, C):
        """Return a cloned [B,H,W,4] tensor; for RGB input, attach `mask` as the alpha channel.

        Raises:
            ValueError: if C is not 3/4, if RGB input comes without a mask, or if
                the mask shape cannot be matched to the image batch.
        """
        if C == 4:
            # Clone so the alpha clamp below never mutates the caller's tensor.
            return images.clone()
        if C != 3:
            raise ValueError(f"Expected images with 3 (RGB) or 4 (RGBA) channels, got C={C}")
        if mask is None:
            raise ValueError(
                "Input images are RGB (C=3). Provide a MASK input or pass RGBA (C=4)."
            )

        if mask.dim() == 2:
            # Single [H,W] mask shared by every frame in the batch.
            mask_b = mask.unsqueeze(0).expand(B, -1, -1)
        elif mask.dim() == 3:
            mask_b = mask
            # Generalization: ComfyUI masks frequently arrive as [1,H,W]; broadcast
            # a single-mask batch across all B frames instead of erroring out.
            if mask_b.shape[0] == 1 and B > 1:
                mask_b = mask_b.expand(B, -1, -1)
        else:
            raise ValueError(f"mask must be [H,W] or [B,H,W], got shape {tuple(mask.shape)}")

        if mask_b.shape[0] != B or mask_b.shape[1] != H or mask_b.shape[2] != W:
            raise ValueError(
                f"mask shape {tuple(mask_b.shape)} must match images batch/height/width {(B,H,W)}"
            )

        # cat materializes the expanded view; clone keeps the result independent.
        return torch.cat([images, mask_b.unsqueeze(-1)], dim=-1).clone()

    @staticmethod
    def _global_bbox(visible):
        """Return (left, top, right, bottom) of True pixels in [B,H,W] `visible`.

        The box is the union over the whole batch, inclusive on both ends.
        Caller guarantees at least one True pixel exists.
        """
        union = torch.any(visible, dim=0)          # [H,W] union across the batch
        ys = torch.any(union, dim=1)               # rows containing any visible pixel
        xs = torch.any(union, dim=0)               # columns containing any visible pixel
        y_idx = torch.nonzero(ys, as_tuple=False).squeeze(1)
        x_idx = torch.nonzero(xs, as_tuple=False).squeeze(1)
        return (
            int(x_idx[0].item()),
            int(y_idx[0].item()),
            int(x_idx[-1].item()),
            int(y_idx[-1].item()),
        )

    def process(self, images, alpha_cutoff=10, verbose=True, mask=None):
        """Clamp low alpha to zero and crop the whole batch to one shared bbox.

        Args:
            images: torch tensor [B, H, W, C], typically float32 in [0,1]; C is 3 or 4.
            alpha_cutoff: 0-255; alpha values <= alpha_cutoff/255 are zeroed.
            verbose: print the computed rectangle (or a no-visible-pixels notice).
            mask: optional torch tensor [H,W], [1,H,W] or [B,H,W] in [0,1];
                required (as the alpha source) when images are RGB.

        Returns:
            (cropped_images, left, top, right, bottom, crop_width, crop_height).
            If nothing is visible after the clamp, the uncropped RGBA batch is
            returned with the full-frame rectangle.

        Raises:
            TypeError: images is not a torch.Tensor.
            ValueError: bad image rank/channels or incompatible mask shape.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")
        if images.dim() != 4:
            raise ValueError(f"images must be [B,H,W,C], got shape {tuple(images.shape)}")

        B, H, W, C = images.shape
        rgba = self._as_rgba(images, mask, B, H, W, C)

        # Alpha clamp: anything at or below the cutoff becomes fully transparent.
        threshold = float(alpha_cutoff) / 255.0
        alpha = rgba[..., 3]
        rgba[..., 3] = torch.where(alpha <= threshold, torch.zeros_like(alpha), alpha)

        visible = rgba[..., 3] > 0
        if not torch.any(visible):
            if verbose:
                print(
                    f"[RGBABatchAlphaClampGlobalCrop] No visible pixels after clamp "
                    f"(alpha_cutoff={alpha_cutoff}). Returning unchanged RGBA."
                )
            # Full-frame rectangle so downstream nodes still get valid coordinates.
            return (rgba, 0, 0, W - 1, H - 1, W, H)

        left, top, right, bottom = self._global_bbox(visible)
        cropped = rgba[:, top:bottom + 1, left:right + 1, :]
        crop_w = right - left + 1
        crop_h = bottom - top + 1

        if verbose:
            print(
                f"[RGBABatchAlphaClampGlobalCrop] alpha_cutoff={alpha_cutoff} "
                f"-> rect: left={left}, top={top}, right={right}, bottom={bottom} "
                f"(w={crop_w}, h={crop_h}), batch={B}"
            )

        return (cropped, left, top, right, bottom, crop_w, crop_h)
|
| |
|
| |
|
# Registration table read by ComfyUI at load time: node key -> node class.
NODE_CLASS_MAPPINGS = {
    "Batch_Sprite_BBox_Cropper": Batch_Sprite_BBox_Cropper,
}
|
| |
|
# Registration table read by ComfyUI at load time: node key -> UI display label.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Batch_Sprite_BBox_Cropper": "Batch_Sprite_BBox_Cropper",
}
|
| |
|