import torch


class ImageResize:
    """ComfyUI node: rescale an image (and mask), optionally cropping or padding to a side ratio.

    Scaling is driven by at most one of ``smaller_side``, ``larger_side`` or
    ``scale_factor`` (a value of 0 disables a rule).  ``resize_mode`` can
    restrict scaling to downscaling or upscaling only.  With ``action`` set to
    crop/pad, the image is first brought to the ``side_ratio`` ("W:H") by
    removing or adding pixels, distributed according to ``crop_pad_position``
    (0 = all on the left/top, 1 = all on the right/bottom).  Padded areas are
    marked with 1 in the returned mask, optionally feathered inward.
    """

    def __init__(self):
        pass

    ACTION_TYPE_RESIZE = "resize only"
    ACTION_TYPE_CROP = "crop to ratio"
    ACTION_TYPE_PAD = "pad to ratio"

    RESIZE_MODE_DOWNSCALE = "reduce size only"
    RESIZE_MODE_UPSCALE = "increase size only"
    RESIZE_MODE_ANY = "any"

    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "resize"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "pixels": ("IMAGE",),
                "action": ([cls.ACTION_TYPE_RESIZE, cls.ACTION_TYPE_CROP, cls.ACTION_TYPE_PAD],),
                "smaller_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "larger_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "scale_factor": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "resize_mode": ([cls.RESIZE_MODE_DOWNSCALE, cls.RESIZE_MODE_UPSCALE, cls.RESIZE_MODE_ANY],),
                "side_ratio": ("STRING", {"default": "4:3"}),
                "crop_pad_position": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "pad_feathering": ("INT", {"default": 20, "min": 0, "max": 8192, "step": 1}),
            },
            "optional": {
                "mask_optional": ("MASK",),
            },
        }

    @classmethod
    def VALIDATE_INPUTS(cls, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, **_):
        """Return True if the inputs are consistent, otherwise an error message string.

        Any argument may be None (ComfyUI passes None for linked inputs), in
        which case the checks depending on it are skipped.
        """
        if side_ratio is not None:
            if action != cls.ACTION_TYPE_RESIZE and cls.parse_side_ratio(side_ratio) is None:
                return f"Invalid side ratio: {side_ratio}"

        if smaller_side is not None and larger_side is not None and scale_factor is not None:
            if int(smaller_side > 0) + int(larger_side > 0) + int(scale_factor > 0) > 1:
                return f"At most one scaling rule (smaller_side, larger_side, scale_factor) should be enabled by setting a non-zero value"

        if scale_factor is not None:
            if resize_mode == cls.RESIZE_MODE_DOWNSCALE and scale_factor > 1.0:
                return f"For resize_mode {cls.RESIZE_MODE_DOWNSCALE}, scale_factor should be less than one but got {scale_factor}"
            if resize_mode == cls.RESIZE_MODE_UPSCALE and scale_factor > 0.0 and scale_factor < 1.0:
                return f"For resize_mode {cls.RESIZE_MODE_UPSCALE}, scale_factor should be larger than one but got {scale_factor}"

        return True

    @classmethod
    def parse_side_ratio(cls, side_ratio):
        """Parse a ``"W:H"`` string into the float ``W / H``.

        Returns None if the value is not a string of two positive integers
        separated by a colon (or the ratio cannot be represented as a float).
        """
        try:
            x, y = map(int, side_ratio.split(":", 1))
            if x < 1 or y < 1:
                return None
            return float(x) / float(y)
        except (AttributeError, TypeError, ValueError, OverflowError):
            # AttributeError/TypeError: not a string; ValueError: not "int:int";
            # OverflowError: an integer too large to convert to float.
            return None

    @staticmethod
    def _prepare_mask(mask_optional, height, width, device):
        """Return a (B, height, width) mask: zeros if none given, else the input resized to match."""
        if mask_optional is None:
            return torch.zeros(1, height, width, dtype=torch.float32, device=device)
        mask = mask_optional
        if mask.dim() == 2:
            # Accept a single un-batched (H, W) mask.
            mask = mask.unsqueeze(0)
        if mask.shape[1] != height or mask.shape[2] != width:
            # Treat the batch axis as channels so all masks are resized in one call.
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(height, width), mode="bicubic").squeeze(0).clamp(0.0, 1.0)
        return mask

    @staticmethod
    def _split_amount(amount, position):
        """Split a fractional pixel count into integer (before, after) parts.

        The parts always sum to ``round(amount)``; rounding the two halves
        separately could add or remove one extra pixel (e.g. 1.5 + 1.5 -> 4).
        """
        if amount <= 0.0:
            return (0, 0)
        position = min(max(position, 0.0), 1.0)
        total = round(amount)
        before = round(amount * position)  # round() is monotonic, so before <= total
        return (before, total - before)

    @staticmethod
    def _feather_padding(mask, add_x, add_y, width, height, pad_feathering):
        """Raise mask values along the inner edges of padded sides, in place.

        ``width``/``height`` are the dimensions of the image region inside the
        padding.  The j-th row/column inward from a padded edge is raised to at
        least ``(1 - j / pad_feathering) ** 2``, for j < min(pad_feathering, side).
        """
        def ramp(count):
            # Computed in float64 on the CPU, then cast/moved, to match the
            # Python-float arithmetic and support devices lacking float64.
            j = torch.arange(count, dtype=torch.float64)
            falloff = 1 - j / pad_feathering
            return (falloff * falloff).to(mask.dtype).to(mask.device)

        n_x = min(pad_feathering, width)
        if n_x > 0 and (add_x[0] > 0 or add_x[1] > 0):
            strength = ramp(n_x)
            offsets = torch.arange(n_x, device=mask.device)
            if add_x[0] > 0:
                cols = add_x[0] + offsets
                mask[:, :, cols] = torch.maximum(mask[:, :, cols], strength)
            if add_x[1] > 0:
                cols = width + add_x[0] - 1 - offsets
                mask[:, :, cols] = torch.maximum(mask[:, :, cols], strength)

        n_y = min(pad_feathering, height)
        if n_y > 0 and (add_y[0] > 0 or add_y[1] > 0):
            strength = ramp(n_y).unsqueeze(1)  # one value per row
            offsets = torch.arange(n_y, device=mask.device)
            if add_y[0] > 0:
                rows = add_y[0] + offsets
                mask[:, rows, :] = torch.maximum(mask[:, rows, :], strength)
            if add_y[1] > 0:
                rows = height + add_y[0] - 1 - offsets
                mask[:, rows, :] = torch.maximum(mask[:, rows, :], strength)
        return mask

    def resize(self, pixels, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, crop_pad_position, pad_feathering, mask_optional=None):
        """Resize ``pixels`` (and the mask) according to the node inputs.

        ``pixels`` is indexed as (batch, height, width, channels), per the
        ``shape[1:3]`` / ``movedim(-1, 1)`` usage below.  Returns
        ``(pixels, mask)``; in the mask, padded regions are 1.

        Raises:
            ValueError: if VALIDATE_INPUTS rejects the inputs.
        """
        validity = self.VALIDATE_INPUTS(action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio)
        if validity is not True:
            raise ValueError(validity)

        height, width = pixels.shape[1:3]
        mask = self._prepare_mask(mask_optional, height, width, pixels.device)

        # Fractional number of source pixels to remove/add to hit the target ratio.
        crop_x, crop_y, pad_x, pad_y = (0.0, 0.0, 0.0, 0.0)
        if action == self.ACTION_TYPE_CROP:
            target_ratio = self.parse_side_ratio(side_ratio)
            if height * target_ratio < width:
                crop_x = width - height * target_ratio
            else:
                crop_y = height - width / target_ratio
        elif action == self.ACTION_TYPE_PAD:
            target_ratio = self.parse_side_ratio(side_ratio)
            if height * target_ratio > width:
                pad_x = height * target_ratio - width
            else:
                pad_y = width / target_ratio - height

        # Side-based rules are measured on the post-crop/pad size.
        final_w = width + pad_x - crop_x
        final_h = height + pad_y - crop_y
        if smaller_side > 0:
            scale_factor = float(smaller_side) / min(final_w, final_h)
        if larger_side > 0:
            scale_factor = float(larger_side) / max(final_w, final_h)

        if (resize_mode == self.RESIZE_MODE_DOWNSCALE and scale_factor >= 1.0) or (resize_mode == self.RESIZE_MODE_UPSCALE and scale_factor <= 1.0):
            scale_factor = 0.0  # scaling in the disallowed direction: leave size unchanged

        if scale_factor > 0.0:
            pixels = torch.nn.functional.interpolate(pixels.movedim(-1, 1), scale_factor=scale_factor, mode="bicubic", antialias=True).movedim(1, -1).clamp(0.0, 1.0)
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0), scale_factor=scale_factor, mode="bicubic", antialias=True).squeeze(0).clamp(0.0, 1.0)
            height, width = pixels.shape[1:3]
            crop_x *= scale_factor
            crop_y *= scale_factor
            pad_x *= scale_factor
            pad_y *= scale_factor

        if crop_x > 0.0 or crop_y > 0.0:
            left, right = self._split_amount(crop_x, crop_pad_position)
            top, bottom = self._split_amount(crop_y, crop_pad_position)
            pixels = pixels[:, top:height - bottom, left:width - right, :]
            mask = mask[:, top:height - bottom, left:width - right]
        elif pad_x > 0.0 or pad_y > 0.0:
            add_x = self._split_amount(pad_x, crop_pad_position)
            add_y = self._split_amount(pad_y, crop_pad_position)
            new_h = height + add_y[0] + add_y[1]
            new_w = width + add_x[0] + add_x[1]

            # new_zeros/new_ones keep the input's dtype and device.
            new_pixels = pixels.new_zeros((pixels.shape[0], new_h, new_w, pixels.shape[3]))
            new_pixels[:, add_y[0]:height + add_y[0], add_x[0]:width + add_x[0], :] = pixels
            pixels = new_pixels

            new_mask = mask.new_ones((mask.shape[0], new_h, new_w))
            new_mask[:, add_y[0]:height + add_y[0], add_x[0]:width + add_x[0]] = mask
            mask = new_mask

            if pad_feathering > 0:
                self._feather_padding(mask, add_x, add_y, width, height, pad_feathering)

        return (pixels, mask)


NODE_CLASS_MAPPINGS = {
    "ImageResize": ImageResize
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageResize": "Image Resize"
}