|
import torch |
|
|
|
class ImageResize:
    """Node that rescales an image batch and optionally crops or pads it to a side ratio.

    The class follows the node layout (``INPUT_TYPES`` / ``RETURN_TYPES`` /
    ``FUNCTION`` / ``CATEGORY``) that a ComfyUI-style host presumably inspects;
    the host itself is not visible from this file.

    ``pixels`` is indexed as ``(batch, height, width, channels)`` and masks as
    ``(batch, height, width)`` throughout.
    """

    # Values offered for the "action" input.
    ACTION_TYPE_RESIZE = "resize only"
    ACTION_TYPE_CROP = "crop to ratio"
    ACTION_TYPE_PAD = "pad to ratio"

    # Values offered for the "resize_mode" input.
    RESIZE_MODE_DOWNSCALE = "reduce size only"
    RESIZE_MODE_UPSCALE = "increase size only"
    RESIZE_MODE_ANY = "any"

    # Node metadata: output socket types, entry-point method name, menu category.
    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "resize"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of required and optional inputs for this node."""
        return {
            "required": {
                "pixels": ("IMAGE",),
                "action": ([cls.ACTION_TYPE_RESIZE, cls.ACTION_TYPE_CROP, cls.ACTION_TYPE_PAD],),
                "smaller_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "larger_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "scale_factor": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "resize_mode": ([cls.RESIZE_MODE_DOWNSCALE, cls.RESIZE_MODE_UPSCALE, cls.RESIZE_MODE_ANY],),
                "side_ratio": ("STRING", {"default": "4:3"}),
                "crop_pad_position": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "pad_feathering": ("INT", {"default": 20, "min": 0, "max": 8192, "step": 1}),
            },
            "optional": {
                "mask_optional": ("MASK",),
            },
        }

    @classmethod
    def VALIDATE_INPUTS(cls, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, **_):
        """Check the input combination.

        Any argument may be ``None``; checks that involve a ``None`` value are
        skipped.

        Returns:
            ``True`` when the inputs are acceptable, otherwise a string
            describing the first problem found.
        """
        if side_ratio is not None:
            # The ratio only matters when cropping or padding.
            if action != cls.ACTION_TYPE_RESIZE and cls.parse_side_ratio(side_ratio) is None:
                return f"Invalid side ratio: {side_ratio}"

        if smaller_side is not None and larger_side is not None and scale_factor is not None:
            enabled_rules = sum(value > 0 for value in (smaller_side, larger_side, scale_factor))
            if enabled_rules > 1:
                return "At most one scaling rule (smaller_side, larger_side, scale_factor) should be enabled by setting a non-zero value"

        if scale_factor is not None:
            if resize_mode == cls.RESIZE_MODE_DOWNSCALE and scale_factor > 1.0:
                return f"For resize_mode {cls.RESIZE_MODE_DOWNSCALE}, scale_factor should be less than one but got {scale_factor}"
            if resize_mode == cls.RESIZE_MODE_UPSCALE and 0.0 < scale_factor < 1.0:
                return f"For resize_mode {cls.RESIZE_MODE_UPSCALE}, scale_factor should be larger than one but got {scale_factor}"

        return True

    @classmethod
    def parse_side_ratio(cls, side_ratio):
        """Parse a ``"W:H"`` string of positive integers.

        Returns:
            ``W / H`` as a float, or ``None`` when the value is not a string of
            that form or either factor is smaller than one.
        """
        try:
            x, y = map(int, side_ratio.split(":", 1))
            if x < 1 or y < 1:
                return None
            return float(x) / float(y)
        except (AttributeError, TypeError, ValueError, OverflowError):
            # Non-string input, wrong number of parts, non-integer parts, or
            # integers too large for a float.
            return None

    @staticmethod
    def _split_amount(amount, position):
        """Split a pixel amount into integer (before, after) parts that sum to round(amount)."""
        total = round(amount)
        before = round(total * position)
        return before, total - before

    @staticmethod
    def _feather_pad_edges(mask, height, width, add_x, add_y, pad_feathering):
        """Raise mask values in place along image edges that border padding.

        The strength falls off quadratically from 1.0 at the edge over
        ``pad_feathering`` pixels; existing mask values are only ever increased.
        """
        rows = slice(add_y[0], add_y[0] + height)
        cols = slice(add_x[0], add_x[0] + width)
        # Offsets beyond the image extent would touch nothing, so stop early.
        for j in range(min(pad_feathering, max(height, width))):
            falloff = 1 - j / pad_feathering
            strength = falloff * falloff
            if j < width:
                if add_x[0] > 0:
                    column = add_x[0] + j
                    mask[:, rows, column] = mask[:, rows, column].clamp(min=strength)
                if add_x[1] > 0:
                    column = add_x[0] + width - j - 1
                    mask[:, rows, column] = mask[:, rows, column].clamp(min=strength)
            if j < height:
                if add_y[0] > 0:
                    row = add_y[0] + j
                    mask[:, row, cols] = mask[:, row, cols].clamp(min=strength)
                if add_y[1] > 0:
                    row = add_y[0] + height - j - 1
                    mask[:, row, cols] = mask[:, row, cols].clamp(min=strength)

    def resize(self, pixels, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, crop_pad_position, pad_feathering, mask_optional=None):
        """Rescale ``pixels`` and crop or pad it to ``side_ratio`` as requested.

        Args:
            pixels: image tensor indexed as (batch, height, width, channels).
            action: one of the ``ACTION_TYPE_*`` values.
            smaller_side: target length of the shorter side, 0 to disable.
            larger_side: target length of the longer side, 0 to disable.
            scale_factor: explicit scale, 0 to disable.
            resize_mode: one of the ``RESIZE_MODE_*`` values; a scale that
                contradicts the mode is skipped rather than applied.
            side_ratio: ``"W:H"`` string, used when cropping or padding.
            crop_pad_position: fraction (0..1) of the cropped/padded amount
                placed on the left/top side.
            pad_feathering: width in pixels of the mask falloff inside the
                image next to padded borders; 0 disables it.
            mask_optional: optional mask, (batch, height, width) or
                (height, width); resampled to the image size if it differs.

        Returns:
            ``(pixels, mask)``. Padded areas are zero in the image and one in
            the mask.

        Raises:
            ValueError: if the inputs fail ``VALIDATE_INPUTS`` or the side
                ratio cannot be parsed for a crop/pad action.
        """
        validity = self.VALIDATE_INPUTS(action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio)
        if validity is not True:
            raise ValueError(validity)

        height, width = pixels.shape[1:3]
        if mask_optional is None:
            mask = torch.zeros(1, height, width, dtype=torch.float32, device=pixels.device)
        else:
            mask = mask_optional
            if mask.dim() == 2:
                # Accept a single mask without a batch axis.
                mask = mask.unsqueeze(0)
            if mask.shape[1] != height or mask.shape[2] != width:
                mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(height, width), mode="bicubic").squeeze(0).clamp(0.0, 1.0)

        # Amounts (in source pixels) to remove or add along each axis; at most
        # one of the four ends up non-zero.
        crop_x, crop_y, pad_x, pad_y = (0.0, 0.0, 0.0, 0.0)
        if action in (self.ACTION_TYPE_CROP, self.ACTION_TYPE_PAD):
            target_ratio = self.parse_side_ratio(side_ratio)
            if target_ratio is None:
                # VALIDATE_INPUTS skips the ratio check when side_ratio is None.
                raise ValueError(f"Invalid side ratio: {side_ratio}")
            if action == self.ACTION_TYPE_CROP:
                if height * target_ratio < width:
                    crop_x = width - height * target_ratio
                else:
                    crop_y = height - width / target_ratio
            else:
                if height * target_ratio > width:
                    pad_x = height * target_ratio - width
                else:
                    pad_y = width / target_ratio - height

        # Side-length rules are expressed relative to the final (cropped or
        # padded) dimensions.
        final_width = width + pad_x - crop_x
        final_height = height + pad_y - crop_y
        if smaller_side > 0:
            scale_factor = float(smaller_side) / min(final_width, final_height)
        if larger_side > 0:
            scale_factor = float(larger_side) / max(final_width, final_height)

        if (resize_mode == self.RESIZE_MODE_DOWNSCALE and scale_factor >= 1.0) or (resize_mode == self.RESIZE_MODE_UPSCALE and scale_factor <= 1.0):
            # The requested scale goes the wrong way for this mode: skip scaling.
            scale_factor = 0.0

        if scale_factor > 0.0:
            pixels = torch.nn.functional.interpolate(pixels.movedim(-1, 1), scale_factor=scale_factor, mode="bicubic", antialias=True).movedim(1, -1).clamp(0.0, 1.0)
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0), scale_factor=scale_factor, mode="bicubic", antialias=True).squeeze(0).clamp(0.0, 1.0)
            height, width = pixels.shape[1:3]

            # Crop/pad amounts were computed in source pixels; rescale them too.
            crop_x *= scale_factor
            crop_y *= scale_factor
            pad_x *= scale_factor
            pad_y *= scale_factor

        if crop_x > 0.0 or crop_y > 0.0:
            left, right = self._split_amount(crop_x, crop_pad_position)
            top, bottom = self._split_amount(crop_y, crop_pad_position)
            pixels = pixels[:, top:height - bottom, left:width - right, :]
            mask = mask[:, top:height - bottom, left:width - right]
        elif pad_x > 0.0 or pad_y > 0.0:
            add_x = self._split_amount(pad_x, crop_pad_position)
            add_y = self._split_amount(pad_y, crop_pad_position)
            padded_height = height + add_y[0] + add_y[1]
            padded_width = width + add_x[0] + add_x[1]

            new_pixels = torch.zeros(pixels.shape[0], padded_height, padded_width, pixels.shape[3], dtype=torch.float32, device=pixels.device)
            new_pixels[:, add_y[0]:height + add_y[0], add_x[0]:width + add_x[0], :] = pixels
            pixels = new_pixels

            # Padding is marked with ones; the original mask fills the image area.
            new_mask = torch.ones(mask.shape[0], padded_height, padded_width, dtype=torch.float32, device=mask.device)
            new_mask[:, add_y[0]:height + add_y[0], add_x[0]:width + add_x[0]] = mask
            mask = new_mask

            if pad_feathering > 0:
                self._feather_pad_edges(mask, height, width, add_x, add_y, pad_feathering)

        return (pixels, mask)
|
|
|
|
|
# Registry tables keyed by the node's internal identifier; presumably read by
# the host application's node loader (not visible from this file).

# Internal identifier -> class implementing the node.
NODE_CLASS_MAPPINGS = {"ImageResize": ImageResize}

# Internal identifier -> human-readable label.
NODE_DISPLAY_NAME_MAPPINGS = {"ImageResize": "Image Resize"}
|
|