import numpy as np
import scipy.ndimage
import torch

import comfy.utils

from nodes import MAX_RESOLUTION


def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
    """Blend ``source`` into ``destination`` at offset (x, y), in place, and return it.

    ``destination`` and ``source`` are indexed as 4-D tensors with the spatial
    dimensions last (dim 2 = height, dim 3 = width). ``x`` / ``y`` are expressed in
    units ``multiplier`` times finer than that grid (the latent node passes 8, the
    image node passes 1) and are floored onto the grid. Offsets are clamped so the
    paste rectangle never lies wholly outside the destination; whatever part of the
    source overhangs any edge (including negative offsets) is clipped.

    ``mask`` (optional) is reshaped to one channel, resized bilinearly to the source's
    spatial size, and used as a per-pixel blend weight: 1 takes the source, 0 keeps
    the destination. Without a mask the source fully replaces the covered region.
    """
    source = source.to(destination.device)
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    # NOTE(review): presumably repeats/trims the source along dim 0 to the
    # destination batch size — behaviour defined in comfy.utils.
    source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])

    # Keep the offset within [-source_size, destination_size] (in pixel units).
    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))

    left, top = (x // multiplier, y // multiplier)
    right, bottom = (left + source.shape[3], top + source.shape[2])

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])

    # Intersect the paste rectangle [left, right) x [top, bottom) with the destination.
    # The source window is the same region shifted back by the paste offset, so a
    # source that overhangs the left/top edge is cropped from its start rather than
    # from its end, and nothing is ever written out of bounds.
    dst_left, dst_top = max(left, 0), max(top, 0)
    dst_right = min(right, destination.shape[3])
    dst_bottom = min(bottom, destination.shape[2])
    if dst_right <= dst_left or dst_bottom <= dst_top:
        # No overlap: nothing to blend.
        return destination

    src_left, src_top = dst_left - left, dst_top - top
    src_right, src_bottom = dst_right - left, dst_bottom - top

    mask = mask[:, :, src_top:src_bottom, src_left:src_right]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, src_top:src_bottom, src_left:src_right]
    destination_portion = inverse_mask * destination[:, :, dst_top:dst_bottom, dst_left:dst_right]

    destination[:, :, dst_top:dst_bottom, dst_left:dst_right] = source_portion + destination_portion
    return destination


class LatentCompositeMasked:
    """Node: paste one LATENT into another at an 8-pixel-aligned offset, optionally masked."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "destination": ("LATENT",),
                "source": ("LATENT",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "resize_source": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"

    CATEGORY = "latent"

    def composite(self, destination, source, x, y, resize_source, mask = None):
        """Return a copy of ``destination`` whose ``"samples"`` have ``source`` blended in.

        The input dicts are not mutated: the output dict is a shallow copy and the
        destination samples are cloned before the in-place blend.
        """
        output = destination.copy()
        destination = destination["samples"].clone()
        source = source["samples"]
        # x / y are pixel offsets; the module-level helper divides them by 8.
        output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
        return (output,)


class ImageCompositeMasked:
    """Node: paste one IMAGE into another at a pixel offset, optionally masked."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "destination": ("IMAGE",),
                "source": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "resize_source": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite"

    CATEGORY = "image"

    def composite(self, destination, source, x, y, resize_source, mask = None):
        """Blend ``source`` into a clone of ``destination`` and return the result.

        Images are treated as channel-last; they are moved to channel-first for the
        shared helper (which slices dims 2/3 as height/width) and moved back after.
        """
        destination = destination.clone().movedim(-1, 1)
        output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
        return (output,)


class MaskToImage:
    """Node: expand a MASK into a 3-channel, channel-last IMAGE."""

    @classmethod
    def INPUT_TYPES(s):
        return {
                "required": {
                    "mask": ("MASK",),
                }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mask_to_image"

    def mask_to_image(self, mask):
        """Return ``mask`` as a ``(batch, height, width, 3)`` tensor.

        ``expand`` returns a broadcast view, so all three channels share storage
        with the input mask.
        """
        result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        return (result,)


class ImageToMask:
    """Node: take one channel of a channel-last IMAGE as a MASK."""

    @classmethod
    def INPUT_TYPES(s):
        return {
                "required": {
                    "image": ("IMAGE",),
                    "channel": (["red", "green", "blue", "alpha"],),
                }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, channel):
        """Return ``image[..., idx]`` where idx is 0/1/2/3 for red/green/blue/alpha.

        Selecting "alpha" requires the image to actually have a 4th channel.
        """
        channels = ["red", "green", "blue", "alpha"]
        mask = image[:, :, :, channels.index(channel)]
        return (mask,)


class ImageColorToMask:
    """Node: build a MASK that is 1.0 where an IMAGE pixel exactly matches an RGB color."""

    @classmethod
    def INPUT_TYPES(s):
        return {
                "required": {
                    "image": ("IMAGE",),
                    "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
                }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, color):
        """Return a float mask: 1.0 where the pixel's 0xRRGGBB value equals ``color``, else 0.0.

        Pixel values are clamped to [0, 1] and quantised to 8 bits per channel
        before packing, so the comparison is against the 24-bit integer colour.
        """
        temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
        temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
        # Masks elsewhere in this module live in [0, 1] (SolidMask range, InvertMask's
        # 1.0 - mask, MaskComposite's clamp), so a match is 1.0 rather than 255.
        mask = torch.where(temp == color, 1.0, 0.0).float()
        return (mask,)


class SolidMask:
    """Node: create a single ``(1, height, width)`` MASK filled with a constant value."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "solid"

    def solid(self, value, width, height):
        """Return a float32 CPU tensor of shape ``(1, height, width)`` filled with ``value``."""
        out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
        return (out,)


class InvertMask:
    """Node: invert a MASK as ``1.0 - mask``."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "invert"

    def invert(self, mask):
        """Return ``1.0 - mask`` (a new tensor; the input is not modified)."""
        out = 1.0 - mask
        return (out,)


class CropMask:
    """Node: crop a rectangular region out of a MASK."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "crop"

    def crop(self, mask, x, y, width, height):
        """Return ``mask[:, y:y+height, x:x+width]`` after reshaping to 3-D.

        Standard slicing semantics apply, so a window past the mask's edge is
        silently truncated (possibly to zero size).
        """
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = mask[:, y:y + height, x:x + width]
        return (out,)


class MaskComposite:
    """Node: combine a source MASK into a destination MASK at an offset with an operation."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "destination": ("MASK",),
                "source": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "combine"

    def combine(self, destination, source, x, y, operation):
        """Apply ``operation`` between ``source`` and the destination region at (x, y).

        Arithmetic ops (multiply/add/subtract) work on the float values; logical ops
        (and/or/xor) round both sides to booleans first. Only the overlapping region
        is changed, and the whole result is clamped to [0, 1].
        """
        output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
        source = source.reshape((-1, source.shape[-2], source.shape[-1]))

        # Clamp the offset into the destination so an out-of-range x/y gives an
        # empty (no-op) region instead of a negative width that mis-slices the source.
        left, top = (min(x, output.shape[-1]), min(y, output.shape[-2]))
        right, bottom = (min(left + source.shape[-1], output.shape[-1]), min(top + source.shape[-2], output.shape[-2]))
        visible_width, visible_height = (right - left, bottom - top)

        source_portion = source[:, :visible_height, :visible_width]
        # Read from the reshaped 3-D copy: three-index slicing of the raw input
        # would fail when it arrives as a 2-D (height, width) mask.
        destination_portion = output[:, top:bottom, left:right]

        if operation == "multiply":
            output[:, top:bottom, left:right] = destination_portion * source_portion
        elif operation == "add":
            output[:, top:bottom, left:right] = destination_portion + source_portion
        elif operation == "subtract":
            output[:, top:bottom, left:right] = destination_portion - source_portion
        elif operation == "and":
            output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "or":
            output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "xor":
            output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()

        output = torch.clamp(output, 0.0, 1.0)

        return (output,)


class FeatherMask:
    """Node: linearly fade a MASK towards 0 along each of its four edges."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "feather"

    def feather(self, mask, left, top, right, bottom):
        """Return a copy of ``mask`` with linear ramps applied over the given edge widths.

        For an edge width n, the outermost row/column is scaled by 1/n, the next by
        2/n, ..., and the n-th from the edge by n/n. Widths are capped at the mask
        size, and a width of 0 leaves that edge untouched.
        """
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()

        left = min(left, output.shape[-1])
        right = min(right, output.shape[-1])
        top = min(top, output.shape[-2])
        bottom = min(bottom, output.shape[-2])

        for x in range(left):
            feather_rate = (x + 1.0) / left
            output[:, :, x] *= feather_rate

        for x in range(right):
            feather_rate = (x + 1.0) / right
            # -(x + 1), not -x: with -x, x == 0 would hit column 0 (the left edge).
            output[:, :, -(x + 1)] *= feather_rate

        for y in range(top):
            feather_rate = (y + 1.0) / top
            output[:, y, :] *= feather_rate

        for y in range(bottom):
            feather_rate = (y + 1.0) / bottom
            # -(y + 1), not -y: with -y, y == 0 would hit row 0 (the top edge).
            output[:, -(y + 1), :] *= feather_rate

        return (output,)


class GrowMask:
    """Node: grow (dilate) or shrink (erode) a MASK by a number of pixels."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
                "tapered_corners": ("BOOLEAN", {"default": True}),
            },
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "expand_mask"

    def expand_mask(self, mask, expand, tapered_corners):
        """Apply ``abs(expand)`` grey dilations (expand > 0) or erosions (expand < 0).

        Each step uses a 3x3 footprint: a plus shape when ``tapered_corners`` is
        true, otherwise the full 3x3 square. Processing runs per batch item in
        NumPy on the CPU; the stacked result is moved back to ``mask``'s device.
        """
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = []
        for m in mask:
            # .numpy() only works on CPU tensors, so copy off any other device first.
            output = m.cpu().numpy()
            for _ in range(abs(expand)):
                if expand < 0:
                    output = scipy.ndimage.grey_erosion(output, footprint=kernel)
                else:
                    output = scipy.ndimage.grey_dilation(output, footprint=kernel)
            out.append(torch.from_numpy(output))
        return (torch.stack(out, dim=0).to(mask.device),)


class ThresholdMask:
    """Node: binarise a MASK: 1.0 where it is strictly greater than ``value``, else 0.0."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, mask, value):
        """Return ``(mask > value)`` as a float tensor."""
        mask = (mask > value).float()
        return (mask,)


NODE_CLASS_MAPPINGS = {
    "LatentCompositeMasked": LatentCompositeMasked,
    "ImageCompositeMasked": ImageCompositeMasked,
    "MaskToImage": MaskToImage,
    "ImageToMask": ImageToMask,
    "ImageColorToMask": ImageColorToMask,
    "SolidMask": SolidMask,
    "InvertMask": InvertMask,
    "CropMask": CropMask,
    "MaskComposite": MaskComposite,
    "FeatherMask": FeatherMask,
    "GrowMask": GrowMask,
    "ThresholdMask": ThresholdMask,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageToMask": "Convert Image to Mask",
    "MaskToImage": "Convert Mask to Image",
}