# MyCustomNodes / Batch_Sprites_BBox_Cropper.py
# (uploaded by saliacoel, commit c7a0808)
import torch
class Batch_Sprite_BBox_Cropper:
    """
    ComfyUI custom node:
    - Takes a batch of RGBA images (or RGB+MASK).
    - Alpha clamp: alpha <= (alpha_cutoff / 255) -> 0
    - Computes one global bounding box of visible pixels across the entire batch
    - Crops every image to the same bbox (spritesheet-safe)
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                # User requirement: cutoff default = 10 (out of 255)
                "alpha_cutoff": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                "verbose": ("BOOLEAN", {"default": True}),
            },
            # Optional: if you only have RGB image + separate alpha mask (common in ComfyUI)
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT", "INT", "INT")
    RETURN_NAMES = ("cropped_images", "left", "top", "right", "bottom", "crop_width", "crop_height")
    FUNCTION = "process"
    CATEGORY = "image/alpha"

    def _as_rgba(self, images, mask, B, H, W, C):
        """Return a fresh [B,H,W,4] RGBA tensor built from `images` (+ optional `mask`).

        Raises ValueError for unsupported channel counts or a missing/mismatched mask.
        NOTE: when C == 4, any supplied mask is ignored (alpha channel wins) — same
        behavior as the original implementation.
        """
        if C == 4:
            # Clone so the alpha clamp below never mutates the caller's tensor.
            return images.clone()
        if C == 3:
            if mask is None:
                raise ValueError(
                    "Input images are RGB (C=3). Provide a MASK input or pass RGBA (C=4)."
                )
            # Normalize mask to [B,H,W]
            if mask.dim() == 2:
                mask_b = mask.unsqueeze(0).expand(B, -1, -1)
            elif mask.dim() == 3:
                mask_b = mask
            else:
                raise ValueError(f"mask must be [H,W] or [B,H,W], got shape {tuple(mask.shape)}")
            if mask_b.shape[0] != B or mask_b.shape[1] != H or mask_b.shape[2] != W:
                raise ValueError(
                    f"mask shape {tuple(mask_b.shape)} must match images batch/height/width {(B,H,W)}"
                )
            # Fix: match the images' device/dtype so torch.cat cannot fail on a
            # float64 mask or a CPU mask paired with GPU images.
            mask_b = mask_b.to(device=images.device, dtype=images.dtype)
            # torch.cat already allocates a new tensor — no extra clone needed.
            return torch.cat([images, mask_b.unsqueeze(-1)], dim=-1)
        raise ValueError(f"Expected images with 3 (RGB) or 4 (RGBA) channels, got C={C}")

    def process(self, images, alpha_cutoff=10, verbose=True, mask=None):
        """Clamp low alpha to 0, then crop the whole batch to one global bbox.

        images: torch tensor [B, H, W, C] typically float32 in [0,1]
        mask: torch tensor [B, H, W] or [H, W] in [0,1] (optional; required when C=3)

        Returns (cropped_images, left, top, right, bottom, crop_width, crop_height),
        with right/bottom inclusive. If nothing is visible after the clamp, the
        uncropped RGBA batch is returned with the full-frame rectangle.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")
        if images.dim() != 4:
            raise ValueError(f"images must be [B,H,W,C], got shape {tuple(images.shape)}")
        B, H, W, C = images.shape

        # Build RGBA tensor (fresh copy, safe to mutate in place below)
        rgba = self._as_rgba(images, mask, B, H, W, C)

        # 1) Alpha clamp: alpha <= (alpha_cutoff/255) -> 0
        threshold = float(alpha_cutoff) / 255.0
        alpha = rgba[..., 3]
        rgba[..., 3] = torch.where(alpha <= threshold, torch.zeros_like(alpha), alpha)

        # 2) Global bbox of visible pixels across batch (alpha > 0 after clamp)
        visible = rgba[..., 3] > 0  # [B,H,W] boolean
        if not torch.any(visible):
            # Nothing visible after clamp; skip crop and return the full frame.
            if verbose:
                print(
                    f"[RGBABatchAlphaClampGlobalCrop] No visible pixels after clamp "
                    f"(alpha_cutoff={alpha_cutoff}). Returning unchanged RGBA."
                )
            return (rgba, 0, 0, W - 1, H - 1, W, H)

        # Union visibility across batch -> [H,W]; collapse to row/column extents.
        union = torch.any(visible, dim=0)
        ys = torch.any(union, dim=1)  # [H]
        xs = torch.any(union, dim=0)  # [W]
        y_idx = torch.nonzero(ys, as_tuple=False).squeeze(1)
        x_idx = torch.nonzero(xs, as_tuple=False).squeeze(1)
        top = int(y_idx[0].item())
        bottom = int(y_idx[-1].item())
        left = int(x_idx[0].item())
        right = int(x_idx[-1].item())

        # 3) Crop all images to the same rect (inclusive right/bottom)
        cropped = rgba[:, top:bottom + 1, left:right + 1, :]
        crop_w = right - left + 1
        crop_h = bottom - top + 1
        if verbose:
            print(
                f"[RGBABatchAlphaClampGlobalCrop] alpha_cutoff={alpha_cutoff} "
                f"-> rect: left={left}, top={top}, right={right}, bottom={bottom} "
                f"(w={crop_w}, h={crop_h}), batch={B}"
            )
        return (cropped, left, top, right, bottom, crop_w, crop_h)
# Registration tables consumed by ComfyUI at load time.
NODE_CLASS_MAPPINGS = {"Batch_Sprite_BBox_Cropper": Batch_Sprite_BBox_Cropper}

# Display label shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {"Batch_Sprite_BBox_Cropper": "Batch_Sprite_BBox_Cropper"}