|
import torch |
|
import torch.nn.functional as F |
|
from torchvision.transforms import functional as TF |
|
from PIL import Image, ImageDraw, ImageFilter, ImageFont |
|
import scipy.ndimage |
|
import numpy as np |
|
from contextlib import nullcontext |
|
import os |
|
|
|
import model_management |
|
from comfy.utils import ProgressBar |
|
from comfy.utils import common_upscale |
|
from nodes import MAX_RESOLUTION |
|
|
|
import folder_paths |
|
|
|
from ..utility.utility import tensor2pil, pil2tensor |
|
|
|
# Parent of the directory containing this file; CreateAudioMask resolves its
# default "audio.wav" relative to this path.
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
class BatchCLIPSeg:
    """ComfyUI node: segment a batch of images from a text prompt with CLIPSeg.

    Returns a mask batch and a preview image batch in which the unmasked area is
    replaced by a flat ``image_bg_level`` grey.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):

        return {"required":
                    {
                        "images": ("IMAGE",),
                        "text": ("STRING", {"multiline": False}),
                        "threshold": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 10.0, "step": 0.001}),
                        "binary_mask": ("BOOLEAN", {"default": True}),
                        "combine_mask": ("BOOLEAN", {"default": False}),
                        "use_cuda": ("BOOLEAN", {"default": True}),
                     },
                     "optional":
                     {
                        "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
                        "opt_model": ("CLIPSEGMODEL", ),
                        "prev_mask": ("MASK", {"default": None}),
                        "image_bg_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                        "invert": ("BOOLEAN", {"default": False}),
                     }
                }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("MASK", "IMAGE", )
    RETURN_NAMES = ("Mask", "Image", )
    FUNCTION = "segment_image"
    DESCRIPTION = """
Segments an image or batch of images using CLIPSeg.
"""

    @staticmethod
    def _minmax_normalize(t):
        """Rescale ``t`` to [0, 1].

        A constant tensor yields zeros instead of the NaN produced by 0/0.
        """
        t_min = t.min()
        value_range = t.max() - t_min
        if value_range <= 0:
            return torch.zeros_like(t)
        return (t - t_min) / value_range

    @staticmethod
    def _merge_prev_mask(mask_tensor, prev_mask):
        """Add ``prev_mask`` to ``mask_tensor`` and clamp the sum to [0, 1].

        ``prev_mask`` is moved to ``mask_tensor``'s device. It is resized with
        nearest-neighbour sampling when its spatial size differs. A 2-D
        ``prev_mask`` is treated as a one-frame batch.
        """
        H, W = mask_tensor.shape[-2:]
        prev_mask = prev_mask.to(mask_tensor.device).float()
        if prev_mask.dim() == 2:
            prev_mask = prev_mask.unsqueeze(0)
        if tuple(prev_mask.shape[-2:]) != (H, W):
            # interpolate needs (N, C, H, W); drop the channel axis again afterwards
            prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
        return torch.clamp(mask_tensor + prev_mask, min=0.0, max=1.0)

    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None, invert= False, image_bg_level=0.5):
        """Segment ``images`` (B, H, W, C) by ``text``.

        Returns:
            (mask, image): float32 CPU tensors of shape (B, H, W) and (B, H, W, C).
        """
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
        import torchvision.transforms as transforms

        offload_device = model_management.unet_offload_device()
        device = model_management.get_torch_device()
        if not use_cuda:
            device = torch.device("cpu")
        dtype = model_management.unet_dtype()

        if opt_model is None:
            # Cache the default model together with its matching processor. If the
            # first call fell back to the hub checkpoint, later calls reuse that pair
            # instead of loading a processor from a possibly missing local path.
            if getattr(self, "_default_clipseg", None) is None:
                checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', 'clipseg-rd64-refined-fp16')
                try:
                    if not os.path.exists(checkpoint_path):
                        from huggingface_hub import snapshot_download
                        snapshot_download(repo_id="Kijai/clipseg-rd64-refined-fp16", local_dir=checkpoint_path, local_dir_use_symlinks=False)
                    default_model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
                except Exception:
                    checkpoint_path = "CIDAS/clipseg-rd64-refined"
                    default_model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
                self._default_clipseg = (default_model, CLIPSegProcessor.from_pretrained(checkpoint_path))
            self.model, processor = self._default_clipseg
        else:
            self.model = opt_model['model']
            processor = opt_model['processor']

        self.model.to(dtype).to(device)

        B, H, W, C = images.shape
        images = images.to(device)

        autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
        with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():

            PIL_images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images ]
            prompt = [text] * len(images)
            input_prc = processor(text=prompt, images=PIL_images, return_tensors="pt")

            for key in input_prc:
                input_prc[key] = input_prc[key].to(device)
            outputs = self.model(**input_prc)

            mask_tensor = torch.sigmoid(outputs.logits)
            mask_tensor = self._minmax_normalize(mask_tensor)
            mask_tensor = torch.where(mask_tensor > threshold, mask_tensor, torch.zeros_like(mask_tensor))
            # With a single image the logits come back as (h, w); restore the batch axis.
            if mask_tensor.dim() == 2:
                mask_tensor = mask_tensor.unsqueeze(0)
            mask_tensor = F.interpolate(mask_tensor.unsqueeze(1), size=(H, W), mode='nearest')
            mask_tensor = mask_tensor.squeeze(1)

        self.model.to(offload_device)

        if binary_mask:
            mask_tensor = (mask_tensor > 0).float()
        if blur_sigma > 0:
            # Odd kernel covering +-3 sigma. It is computed from the float sigma so
            # sigma < 1 still blurs (int(sigma) used to give a 1x1 kernel).
            kernel_size = 2 * int(np.ceil(3 * blur_sigma)) + 1
            blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
            mask_tensor = blur(mask_tensor)

        if combine_mask:
            mask_tensor = torch.max(mask_tensor, dim=0)[0]
            mask_tensor = mask_tensor.unsqueeze(0).repeat(len(images),1,1)

        del outputs
        model_management.soft_empty_cache()

        if prev_mask is not None:
            mask_tensor = self._merge_prev_mask(mask_tensor, prev_mask)

        if invert:
            mask_tensor = 1 - mask_tensor

        image_tensor = images * mask_tensor.unsqueeze(-1) + (1 - mask_tensor.unsqueeze(-1)) * image_bg_level
        image_tensor = torch.clamp(image_tensor, min=0.0, max=1.0).cpu().float()

        mask_tensor = mask_tensor.cpu().float()

        return mask_tensor, image_tensor,
|
|
|
class DownloadAndLoadCLIPSeg:
    """ComfyUI node: download (if missing) and load a CLIPSeg model and processor."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):

        return {"required":
                    {
                        "model": (
                            [   'Kijai/clipseg-rd64-refined-fp16',
                                'CIDAS/clipseg-rd64-refined',
                            ],
                        ),
                     },
                }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("CLIPSEGMODEL",)
    RETURN_NAMES = ("clipseg_model",)
    FUNCTION = "segment_image"
    DESCRIPTION = """
Downloads and loads CLIPSeg model with huggingface_hub,
to ComfyUI/models/clip_seg
"""

    def segment_image(self, model):
        """Return ``({'model': ..., 'processor': ...},)`` for the selected repo id.

        The checkpoint is stored under ``models/clip_seg/<repo basename>``. The
        loaded model is cached on the instance, but only reused while the same
        checkpoint stays selected. Previously a switch of ``model`` kept the old
        weights while loading the new processor.
        """
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', os.path.basename(model))
        if not hasattr(self, "model") or getattr(self, "_checkpoint_path", None) != checkpoint_path:
            if not os.path.exists(checkpoint_path):
                from huggingface_hub import snapshot_download
                snapshot_download(repo_id=model, local_dir=checkpoint_path, local_dir_use_symlinks=False)
            self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
            self._checkpoint_path = checkpoint_path

        processor = CLIPSegProcessor.from_pretrained(checkpoint_path)

        clipseg_model = {'model': self.model, 'processor': processor}
        return clipseg_model,
|
|
|
class CreateTextMask:
    """ComfyUI node: render word-wrapped text (optionally rotating) into images and masks."""

    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "createtextmask"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
Creates a text image and mask.
Looks for fonts from this folder:
ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts

If start_rotation and/or end_rotation are different values,
creates animation between them.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "invert": ("BOOLEAN", {"default": False}),
                 "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                 "text_x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                 "text_y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                 "font_size": ("INT", {"default": 32,"min": 8, "max": 4096, "step": 1}),
                 "font_color": ("STRING", {"default": "white"}),
                 "text": ("STRING", {"default": "HELLO!", "multiline": True}),
                 "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
                 "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                 "start_rotation": ("INT", {"default": 0,"min": 0, "max": 359, "step": 1}),
                 "end_rotation": ("INT", {"default": 0,"min": -359, "max": 359, "step": 1}),
        },
    }

    @staticmethod
    def _wrap_words(words, max_width, measure):
        """Greedily pack ``words`` into lines whose measured width fits ``max_width``.

        ``measure(s)`` returns the rendered width of string ``s``. A word wider than
        ``max_width`` goes on its own line, with no empty line before it.
        """
        space_width = measure(" ")
        lines = []
        current_line = []
        current_width = 0
        for word in words:
            word_width = measure(word)
            if current_width + word_width <= max_width:
                current_line.append(word)
                current_width += word_width + space_width
            else:
                if current_line:
                    lines.append(" ".join(current_line))
                current_line = [word]
                # count the trailing space too, as the fitting branch does
                current_width = word_width + space_width
        if current_line:
            lines.append(" ".join(current_line))
        return lines

    def createtextmask(self, frames, width, height, invert, text_x, text_y, text, font_size, font_color, font, start_rotation, end_rotation):
        """Render ``text`` into ``frames`` images.

        The mask is the red channel of each image. When the two rotations differ,
        rotation is interpolated linearly from ``start_rotation`` to
        ``end_rotation`` across the frames.
        """
        batch_size = frames
        animate = start_rotation != end_rotation
        # One frame has no interval to interpolate over (was a ZeroDivisionError).
        rotation_increment = (end_rotation - start_rotation) / (batch_size - 1) if animate and batch_size > 1 else 0
        rotation = start_rotation

        font_path = folder_paths.get_full_path("kjnodes_fonts", font)
        # Font loading and word wrapping are frame-independent, so they are done once.
        pil_font = ImageFont.truetype(font_path, font_size)
        if hasattr(pil_font, "getbbox"):
            def measure(s):
                return pil_font.getbbox(s)[2]
        else:  # older Pillow without getbbox
            def measure(s):
                return pil_font.getsize(s)[0]

        lines = self._wrap_words(text.split(), width - 2 * text_x, measure)

        # Rotation pivots on the centre of the last rendered line (as before). With
        # no text, fall back to the text origin so the centre is always defined.
        if lines:
            last_line_width = pil_font.getlength(lines[-1])
            center = (text_x + last_line_width / 2,
                      text_y + (len(lines) - 1) * font_size + font_size / 2)
        else:
            center = (text_x, text_y)

        out = []
        masks = []
        for _ in range(batch_size):
            image = Image.new("RGB", (width, height), "black")
            draw = ImageDraw.Draw(image)
            y_offset = text_y
            for line in lines:
                try:
                    # '-liga' turns off ligatures; Pillow only accepts `features`
                    # with the raqm layout engine, so retry without it.
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color, features=['-liga'])
                except Exception:
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color)
                y_offset += font_size

            if animate:
                image = image.rotate(rotation, center=center)
                rotation += rotation_increment

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            masks.append(image[:, :, :, 0])
            out.append(image)

        if invert:
            return (1.0 - torch.cat(out, dim=0), 1.0 - torch.cat(masks, dim=0),)
        return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
|
|
|
class ColorToMask:
    """ComfyUI node: mask pixels whose RGB value lies within ``threshold`` of a colour."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "clip"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Converts chosen RGB value to a mask.
With batch inputs, the **per_batch**
controls the number of images processed at once.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "images": ("IMAGE",),
                 "invert": ("BOOLEAN", {"default": False}),
                 "red": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                 "green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                 "blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                 "threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
                 "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
        },
    }

    @staticmethod
    def _match_color(batch, color, threshold, invert):
        """Return a float 0/1 mask of pixels in ``batch`` near ``color``.

        ``batch`` is (..., 3) in [0, 1] and ``color`` is a 3-vector in 0..255.
        "Near" means a Euclidean RGB distance (0..255 scale) of at most
        ``threshold``. ``invert`` swaps 0 and 1.
        """
        distances = torch.norm(batch * 255 - color, dim=-1)
        mask = (distances <= threshold).float()
        return 1.0 - mask if invert else mask

    def clip(self, images, red, green, blue, threshold, invert, per_batch):
        """Convert ``images`` to masks in chunks of ``per_batch`` frames; result is on CPU."""
        # Build the reference colour on the images' device; a CPU vector breaks
        # broadcasting against CUDA image batches.
        color = torch.tensor([red, green, blue], dtype=torch.float32, device=images.device)

        pbar = ProgressBar(images.shape[0])
        tensors_out = []
        for start_idx in range(0, images.shape[0], per_batch):
            chunk = images[start_idx:start_idx + per_batch]
            mask_out = self._match_color(chunk, color, threshold, invert)
            tensors_out.append(mask_out.cpu())
            pbar.update(mask_out.shape[0])

        return torch.cat(tensors_out, dim=0),
|
|
|
class CreateFluidMask:
    """ComfyUI node: simulate dye injected into a 2-D fluid and emit frames plus masks."""

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "createfluidmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "invert": ("BOOLEAN", {"default": False}),
                 "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                 "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                 "inflow_count": ("INT", {"default": 3,"min": 0, "max": 255, "step": 1}),
                 "inflow_velocity": ("INT", {"default": 1,"min": 0, "max": 255, "step": 1}),
                 "inflow_radius": ("INT", {"default": 8,"min": 0, "max": 255, "step": 1}),
                 "inflow_padding": ("INT", {"default": 50,"min": 0, "max": 255, "step": 1}),
                 "inflow_duration": ("INT", {"default": 60,"min": 0, "max": 255, "step": 1}),
        },
    }

    def createfluidmask(self, frames, width, height, invert, inflow_count, inflow_velocity, inflow_radius, inflow_padding, inflow_duration):
        """Run the fluid solver for ``frames`` steps.

        ``inflow_count`` circular sources sit on a ring ``inflow_padding`` inside the
        frame edge and inject velocity toward the centre, plus dye, while
        ``f <= inflow_duration``. Each frame's image stacks (scaled curl, ones,
        dye) as channels; the mask is channel 0.
        """
        from ..utility.fluid import Fluid
        # erf lives in scipy.special; the old `scipy.spatial` fallback never existed.
        from scipy.special import erf

        out = []
        masks = []
        resolution = width, height

        print('Generating fluid solver, this may take some time.')
        fluid = Fluid(resolution, 'dye')

        center = np.floor_divide(resolution, 2)
        r = np.min(center) - inflow_padding

        angles = np.linspace(-np.pi, np.pi, inflow_count, endpoint=False)
        directions = tuple(np.array((np.cos(a), np.sin(a))) for a in angles)
        normals = tuple(-d for d in directions)  # point inward, toward the centre
        points = tuple(r * d + center for d in directions)

        # Separate name so the `inflow_velocity` parameter (a scalar) is not shadowed.
        inflow_velocity_field = np.zeros_like(fluid.velocity)
        inflow_dye = np.zeros(fluid.shape)
        for p, n in zip(points, normals):
            source = np.linalg.norm(fluid.indices - p[:, None, None], axis=0) <= inflow_radius
            inflow_velocity_field[:, source] += n[:, None] * inflow_velocity
            inflow_dye[source] = 1

        for f in range(frames):
            print(f'Computing frame {f + 1} of {frames}.')
            if f <= inflow_duration:
                fluid.velocity += inflow_velocity_field
                fluid.dye += inflow_dye

            curl = fluid.step()[1]
            # Squash the unbounded curl into [0, 0.5] for display.
            curl = (erf(curl * 2) + 1) / 4

            color = np.dstack((curl, np.ones(fluid.shape), fluid.dye))
            color = (np.clip(color, 0, 1) * 255).astype('uint8')
            image = torch.from_numpy(color.astype(np.float32) / 255.0)[None,]
            masks.append(image[:, :, :, 0])
            out.append(image)

        images = torch.cat(out, dim=0)
        mask_batch = torch.cat(masks, dim=0)
        if invert:
            return (1.0 - images, 1.0 - mask_batch,)
        return (images, mask_batch,)
|
|
|
class CreateAudioMask:
    """Deprecated ComfyUI node: one circle per frame, radius driven by the audio spectrum."""

    # The node always produced a mask alongside the image; declare it so both
    # branches return the same, typed outputs.
    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "createaudiomask"
    CATEGORY = "KJNodes/deprecated"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "invert": ("BOOLEAN", {"default": False}),
                 "frames": ("INT", {"default": 16,"min": 1, "max": 255, "step": 1}),
                 "scale": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 2.0, "step": 0.01}),
                 "audio_path": ("STRING", {"default": "audio.wav"}),
                 "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
        },
    }

    def createaudiomask(self, frames, width, height, invert, audio_path, scale):
        """Draw a centred white circle per STFT column of ``audio_path``.

        The radius is ``int(height * mean(|STFT column|)) * scale``. The default
        ``"audio.wav"`` resolves against ``script_directory``.

        Raises:
            Exception: librosa is not installed.
            ValueError: ``frames`` exceeds the number of STFT columns.
        """
        try:
            import librosa
        except ImportError as err:
            raise Exception("Can not import librosa. Install it with 'pip install librosa'") from err

        if audio_path == "audio.wav":
            audio_path = os.path.join(script_directory, audio_path)
        audio, _ = librosa.load(audio_path)
        spectrogram = np.abs(librosa.stft(audio))

        available_frames = spectrogram.shape[1]
        if frames > available_frames:
            raise ValueError(f"Requested {frames} frames, but {audio_path} only yields {available_frames} spectrogram frames.")

        out = []
        masks = []
        circle_center = (width // 2, height // 2)
        for i in range(frames):
            image = Image.new("RGB", (width, height), "black")
            draw = ImageDraw.Draw(image)
            frame = spectrogram[:, i]
            circle_radius = int(height * np.mean(frame))
            circle_radius *= scale

            draw.ellipse([(circle_center[0] - circle_radius, circle_center[1] - circle_radius),
                          (circle_center[0] + circle_radius, circle_center[1] + circle_radius)],
                          fill='white')

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            masks.append(image[:, :, :, 0])
            out.append(image)

        images = torch.cat(out, dim=0)
        mask_batch = torch.cat(masks, dim=0)
        if invert:
            return (1.0 - images, 1.0 - mask_batch,)
        return (images, mask_batch,)
|
|
|
class CreateGradientMask:
    """ComfyUI node: a horizontal 1 -> 0 ramp that slides down by ``i / frames`` per frame."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "invert": ("BOOLEAN", {"default": False}),
                 "frames": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                 "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
        },
    }

    def createmask(self, frames, width, height, invert):
        """Return a (frames, height, width) float32 mask batch.

        Frame ``i`` holds ``linspace(1, 0, width) - i / frames`` on every row. Values
        are not clipped, so later frames go negative.
        """
        ramp = np.linspace(1.0, 0.0, width, dtype=np.float32)
        # Per-frame offsets i / frames, rounded to float32 so the subtraction happens
        # in float32 exactly as a float32-array-minus-Python-float would.
        offsets = (np.arange(frames, dtype=np.float64) / max(frames, 1)).astype(np.float32)
        rows = ramp[np.newaxis, :] - offsets[:, np.newaxis]             # (frames, width)
        batch = np.broadcast_to(rows[:, np.newaxis, :], (frames, height, width)).astype(np.float32)

        result = torch.from_numpy(batch)
        return (1.0 - result,) if invert else (result,)
|
|
|
class CreateFadeMask:
    """Deprecated ComfyUI node: fade start -> midpoint -> end level over a mask batch."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createfademask"
    CATEGORY = "KJNodes/deprecated"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "invert": ("BOOLEAN", {"default": False}),
                 "frames": ("INT", {"default": 2,"min": 2, "max": 10000, "step": 1}),
                 "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                 "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
                 "start_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
                 "midpoint_level": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
                 "end_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}),
                 "midpoint_frame": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
        },
    }

    def createfademask(self, frames, width, height, invert, interpolation, start_level, midpoint_level, end_level, midpoint_frame):
        """Return a (frames, height, width) batch of uniform masks.

        The level goes from ``start_level`` at frame 0 to ``midpoint_level`` at
        ``midpoint_frame`` (0 means ``frames // 2``). It then goes to ``end_level``
        at the last frame, shaped by ``interpolation``.
        """
        easings = {
            "ease_in": lambda t: t * t,
            "ease_out": lambda t: 1 - (1 - t) * (1 - t),
            "ease_in_out": lambda t: 3 * t * t - 2 * t * t * t,
        }
        ease = easings.get(interpolation, lambda t: t)  # "linear"

        batch_size = frames
        if midpoint_frame == 0:
            midpoint_frame = batch_size // 2
        last_frame = batch_size - 1

        image_batch = np.zeros((batch_size, height, width), dtype=np.float32)
        for i in range(batch_size):
            if i <= midpoint_frame:
                t = ease(i / max(midpoint_frame, 1))
                level = start_level - t * (start_level - midpoint_level)
            else:
                # Only reached when midpoint_frame < last_frame, so the divisor is > 0.
                # Dividing by the distance to the *last* frame makes that frame land
                # exactly on end_level (the old divisor stopped one step short).
                t = ease((i - midpoint_frame) / (last_frame - midpoint_frame))
                level = midpoint_level - t * (midpoint_level - end_level)
            # masks live in [0, 1]
            image_batch[i] = np.clip(level, 0.0, 1.0)

        output = torch.from_numpy(image_batch)
        if invert:
            return (1.0 - output,)
        return (output,)
|
|
|
class CreateFadeMaskAdvanced:
    """ComfyUI node: uniform mask batch interpolated between "frame:(value)" keyframes."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createfademask"
    CATEGORY = "KJNodes/masking/generate"
    DESCRIPTION = """
Create a batch of masks interpolated between given frames and values.
Uses same syntax as Fizz' BatchValueSchedule.
First value is the frame index (note that this starts from 0, not 1)
and the second value inside the brackets is the float value of the mask in range 0.0 - 1.0

For example the default values:
0:(0.0)
7:(1.0)
15:(0.0)

Would create a mask batch of 16 frames, starting from black,
interpolating with the chosen curve to fully white at the 8th frame,
and interpolating from that to fully black at the 16th frame.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
                 "invert": ("BOOLEAN", {"default": False}),
                 "frames": ("INT", {"default": 16,"min": 2, "max": 10000, "step": 1}),
                 "width": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
                 "height": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
                 "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
        },
    }

    def createfademask(self, frames, width, height, invert, points_string, interpolation):
        """Return a (frames, height, width) batch whose level follows the keyframes.

        Keyframe handling:
        - A repeated frame keeps its last value.
        - The last listed value is held to the final frame when that frame is not
          given.
        - The first value is held before the first keyframe, instead of being
          extrapolated.
        - Levels are clipped to [0, 1].

        Raises:
            ValueError: an entry is not of the form ``frame:(value)``.
        """
        easings = {
            "ease_in": lambda t: t * t,
            "ease_out": lambda t: 1 - (1 - t) * (1 - t),
            "ease_in_out": lambda t: 3 * t * t - 2 * t * t * t,
        }
        ease = easings.get(interpolation, lambda t: t)  # "linear"

        # Parse entries; blank ones (trailing comma / whitespace) are skipped, and a
        # dict keeps frames unique so no segment can have zero length.
        keyed = {}
        for entry in points_string.split(','):
            entry = entry.strip()
            if not entry:
                continue
            frame_str, value_str = entry.split(':')
            keyed[int(frame_str.strip())] = float(value_str.strip()[1:-1])

        last_frame = frames - 1
        ordered = sorted(keyed.items())
        if not ordered or ordered[-1][0] != last_frame:
            keyed.setdefault(last_frame, ordered[-1][1] if ordered else 0.0)
        first_frame, first_value = min(keyed.items())
        if first_frame > 0:
            keyed[0] = first_value
        keyframes = sorted(keyed.items())
        if len(keyframes) == 1:  # only when frames == 1: make a degenerate flat segment
            keyframes.append((keyframes[0][0] + 1, keyframes[0][1]))

        image_batch = np.zeros((frames, height, width), dtype=np.float32)
        segment = 1
        for i in range(frames):
            while segment < len(keyframes) - 1 and i > keyframes[segment][0]:
                segment += 1
            (f0, v0), (f1, v1) = keyframes[segment - 1], keyframes[segment]
            t = min(max((i - f0) / (f1 - f0), 0.0), 1.0)
            t = ease(t)
            image_batch[i] = np.clip(v0 - t * (v0 - v1), 0.0, 1.0)

        output = torch.from_numpy(image_batch)
        if invert:
            return (1.0 - output,)
        return (output,)
|
|
|
class CreateMagicMask:
    """ComfyUI node: animated procedural "magic" texture masks with random transitions."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createmagicmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
                 "depth": ("INT", {"default": 12,"min": 1, "max": 500, "step": 1}),
                 "distortion": ("FLOAT", {"default": 1.5,"min": 0.0, "max": 100.0, "step": 0.01}),
                 "seed": ("INT", {"default": 123,"min": 0, "max": 99999999, "step": 1}),
                 "transitions": ("INT", {"default": 1,"min": 1, "max": 20, "step": 1}),
                 "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                 "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
        },
    }

    @staticmethod
    def _render_gray(tex, frame_width, frame_height):
        """Rasterise ``tex`` via ``imshow`` and return a (frame_height, frame_width) tensor in [0, 1].

        Uses an explicit Agg ``Figure`` rather than pyplot. pyplot keeps global
        state and is not safe to drive from worker threads.
        """
        from matplotlib.figure import Figure
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        # Same pixel density as before (10 in wide at width/10 dpi), but the height
        # now follows frame_height instead of always being square.
        dpi = frame_width / 10
        fig = Figure(figsize=(10, 10 * frame_height / frame_width), dpi=dpi)
        canvas = FigureCanvasAgg(fig)
        ax = fig.add_subplot(111)
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        ax.imshow(tex, aspect='auto')
        canvas.draw()

        pil_img = Image.fromarray(np.asarray(canvas.buffer_rgba())).convert("L")
        # Agg truncates fractional canvas sizes; enforce the exact requested size.
        if pil_img.size != (frame_width, frame_height):
            pil_img = pil_img.resize((frame_width, frame_height))
        return torch.tensor(np.array(pil_img)) / 255.0

    def createmagicmask(self, frames, transitions, depth, distortion, seed, frame_width, frame_height):
        """Return ``(masks, 1 - masks)``, each (frames, frame_height, frame_width).

        Each transition blends linearly between two random transforms of the
        coordinate grid. Frames are split across transitions so exactly ``frames``
        masks are produced.
        """
        from ..utility.magictex import coordinate_grid, random_transform, magic

        rng = np.random.default_rng(seed)
        out = []
        coords = coordinate_grid((frame_width, frame_height))

        # frames // transitions alone silently dropped the remainder; hand the
        # leftover frames to the first transitions.
        base_count, remainder = divmod(frames, transitions)

        base_params = {
            "coords": random_transform(coords, rng),  # kept: advances rng as before
            "depth": depth,
            "distortion": distortion,
        }
        for t in range(transitions):
            params1 = base_params.copy()
            params2 = base_params.copy()
            params1['coords'] = random_transform(coords, rng)
            params2['coords'] = random_transform(coords, rng)

            frames_in_transition = base_count + (1 if t < remainder else 0)
            for i in range(frames_in_transition):
                alpha = i / frames_in_transition
                params = params1.copy()
                params['coords'] = (1 - alpha) * params1['coords'] + alpha * params2['coords']
                tex = magic(**params)
                out.append(self._render_gray(tex, frame_width, frame_height))

        stacked = torch.stack(out, dim=0)
        return (stacked, 1.0 - stacked,)
|
|
|
class CreateShapeMask:
    """ComfyUI node: draw a circle, square or triangle (optionally growing) into masks."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createshapemask"
    CATEGORY = "KJNodes/masking/generate"
    DESCRIPTION = """
Creates a mask or batch of masks with the specified shape.
Locations are center locations.
Grow value is the amount to grow the shape on each frame, creating animated masks.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "shape": (
            [   'circle',
                'square',
                'triangle',
            ],
            {
            "default": 'circle'
             }),
                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
                "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
        },
    }

    def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, shape_width, shape_height, grow, shape):
        """Return ``(masks, 1 - masks)``.

        Frame ``i`` draws ``shape`` centred on ``(location_x, location_y)``, sized
        ``shape_* + i * grow`` with each dimension floored at 0. The mask is the
        first channel of the white-on-black drawing.
        """
        fill = "white"
        frame_masks = []
        for frame_idx in range(frames):
            canvas = Image.new("RGB", (frame_width, frame_height), "black")
            pen = ImageDraw.Draw(canvas)

            w = max(0, shape_width + frame_idx * grow)
            h = max(0, shape_height + frame_idx * grow)
            half_w = w // 2
            half_h = h // 2

            if shape in ('circle', 'square'):
                bbox = [(location_x - half_w, location_y - half_h),
                        (location_x + half_w, location_y + half_h)]
                draw_shape = pen.ellipse if shape == 'circle' else pen.rectangle
                draw_shape(bbox, fill=fill)
            elif shape == 'triangle':
                apex = (location_x, location_y - half_h)
                base_left = (location_x - half_w, location_y + half_h)
                base_right = (location_x + half_w, location_y + half_h)
                pen.polygon([apex, base_left, base_right], fill=fill)

            frame_masks.append(pil2tensor(canvas)[:, :, :, 0])

        stacked = torch.cat(frame_masks, dim=0)
        return (stacked, 1.0 - stacked,)
|
|
|
class CreateVoronoiMask:
    """ComfyUI node: animated Voronoi-edge masks between two random point sets."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createvoronoi"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                 "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
                 "num_points": ("INT", {"default": 15,"min": 1, "max": 4096, "step": 1}),
                 "line_width": ("INT", {"default": 4,"min": 1, "max": 4096, "step": 1}),
                 "speed": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
                 "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                 "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
        },
    }

    def createvoronoi(self, frames, num_points, line_width, speed, frame_width, frame_height):
        """Return ``(masks, 1 - masks)``, each (frames, frame_height, frame_width).

        Points move linearly from random start positions toward random end
        positions. ``t = clip(i * speed / (frames - 1), 0, 1)``. Voronoi ridges are
        drawn as black lines on white. Points come from the global
        ``np.random`` state, so output is not seeded.
        """
        from scipy.spatial import Voronoi
        # pyplot was never imported here (NameError on every call). Use an explicit
        # Agg Figure, which also avoids pyplot's global, non-thread-safe state.
        from matplotlib.figure import Figure
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        batch_size = frames
        out = []
        aspect_ratio = frame_width / frame_height

        # x is scaled to [0, aspect_ratio] once here; it must not be scaled again.
        start_points = np.random.rand(num_points, 2)
        start_points[:, 0] *= aspect_ratio
        end_points = np.random.rand(num_points, 2)
        end_points[:, 0] *= aspect_ratio

        dpi = 100  # fixed, so output size does not depend on rcParams['figure.dpi']
        for i in range(batch_size):
            t = np.clip((i * speed) / max(batch_size - 1, 1), 0, 1)
            points = (1 - t) * start_points + t * end_points
            vor = Voronoi(points)

            fig = Figure(figsize=(aspect_ratio * frame_height / dpi, frame_height / dpi), dpi=dpi)
            canvas = FigureCanvasAgg(fig)
            ax = fig.add_subplot(111)
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax.set_xlim([0, aspect_ratio])
            ax.set_ylim([0, 1])
            ax.axis('off')
            ax.margins(0, 0)
            ax.fill_between([0, 1], [0, 1], color='white')

            # Only finite ridges; index -1 marks a vertex at infinity.
            for simplex in vor.ridge_vertices:
                simplex = np.asarray(simplex)
                if np.all(simplex >= 0):
                    ax.plot(vor.vertices[simplex, 0], vor.vertices[simplex, 1], 'k-', linewidth=line_width)

            canvas.draw()
            pil_img = Image.fromarray(np.asarray(canvas.buffer_rgba())).convert("L")
            # Agg truncates fractional canvas sizes; enforce the exact requested size.
            if pil_img.size != (frame_width, frame_height):
                pil_img = pil_img.resize((frame_width, frame_height))
            out.append(torch.tensor(np.array(pil_img)) / 255.0)

        stacked = torch.stack(out, dim=0)
        return (stacked, 1.0 - stacked,)
|
|
|
class GetMaskSizeAndCount:
    """ComfyUI node: pass a mask through and report its width, height and batch count."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                       "mask": ("MASK",),
                 }}

    RETURN_TYPES = ("MASK","INT", "INT", "INT",)
    RETURN_NAMES = ("mask", "width", "height", "count",)
    FUNCTION = "getsize"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Returns the width, height and batch size of the mask,
and passes it through unchanged.

"""

    def getsize(self, mask):
        """Return the UI summary ``"CxWxH"`` and ``(mask, width, height, count)``.

        ``mask`` is indexed as (count, height, width, ...).
        """
        dims = mask.shape
        count, height, width = dims[0], dims[1], dims[2]
        summary = f"{count}x{width}x{height}"
        return {
            "ui": {"text": [summary]},
            "result": (mask, width, height, count),
        }
|
|
|
class GrowMaskWithBlur:
    """ComfyUI node: grow/shrink, optionally fill, temporally smooth and blur masks."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
                "incremental_expandrate": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
                "tapered_corners": ("BOOLEAN", {"default": True}),
                "flip_input": ("BOOLEAN", {"default": False}),
                "blur_radius": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100,
                    "step": 0.1
                }),
                "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "fill_holes": ("BOOLEAN", {"default": False}),
            },
        }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "expand_mask"
    DESCRIPTION = """
# GrowMaskWithBlur
- mask: Input mask or mask batch
- expand: Expand or contract mask or mask batch by a given amount
- incremental_expandrate: increase expand rate by a given amount per frame
- tapered_corners: use tapered corners
- flip_input: flip input mask
- blur_radius: value higher than 0 will blur the mask
- lerp_alpha: alpha value for interpolation between frames
- decay_factor: decay value for interpolation between frames
- fill_holes: fill holes in the mask (slow)"""

    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False):
        """Return ``(mask, 1 - mask)`` after per-frame morphology and optional blur.

        Per frame, in order:
        1. Dilate (``expand`` > 0) or erode (< 0) ``|round(expand)|`` times with a
           3x3 cross (``tapered_corners``) or square footprint. ``|expand|`` then
           grows by ``incremental_expandrate`` for the next frame.
        2. Optionally fill holes, giving a binary 0/1 mask.
        3. Optionally lerp with and/or add a decayed copy of the previous frame,
           renormalising to peak 1.

        A positive ``blur_radius`` Gaussian-blurs each frame via PIL.
        """
        alpha = lerp_alpha
        decay = decay_factor
        if flip_input:
            mask = 1.0 - mask
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()
        out = []
        previous_output = None
        current_expand = expand
        for m in growmask:
            output = m.numpy().astype(np.float32)
            for _ in range(abs(round(current_expand))):
                if current_expand < 0:
                    output = scipy.ndimage.grey_erosion(output, footprint=kernel)
                else:
                    output = scipy.ndimage.grey_dilation(output, footprint=kernel)
            if current_expand < 0:
                current_expand -= abs(incremental_expandrate)
            else:
                current_expand += abs(incremental_expandrate)
            if fill_holes:
                binary_mask = output > 0
                # Keep the filled mask in [0, 1] like every other path (it used to be
                # multiplied by 255, leaking 0..255 values into the MASK output).
                output = scipy.ndimage.binary_fill_holes(binary_mask).astype(np.float32)
            output = torch.from_numpy(output)
            if alpha < 1.0 and previous_output is not None:
                # Interpolate between the previous and current frame
                output = alpha * output + (1 - alpha) * previous_output
            if decay < 1.0 and previous_output is not None:
                # Add the decayed previous frame to the current frame
                output += decay * previous_output
                peak = output.max()
                # An all-zero frame would otherwise become 0/0 = NaN.
                if peak > 0:
                    output = output / peak
            previous_output = output
            out.append(output)

        if blur_radius != 0:
            # Convert each frame to PIL, blur it, and convert back.
            for idx, tensor in enumerate(out):
                pil_image = tensor2pil(tensor.cpu().detach())[0]
                pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
                out[idx] = pil2tensor(pil_image)
            blurred = torch.cat(out, dim=0)
            return (blurred, 1.0 - blurred)
        else:
            return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),)
|
|
|
class MaskBatchMulti:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
                "mask_1": ("MASK", ),
                "mask_2": ("MASK", ),
            },
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("masks",)
    FUNCTION = "combine"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Creates an image batch from multiple masks.
You can set how many inputs the node has,
with the **inputcount** and clicking update.
"""

    def combine(self, inputcount, **kwargs):
        """Concatenate ``mask_1`` ... ``mask_<inputcount>`` into one mask batch.

        Every mask whose dims after the first differ from ``mask_1``'s is
        resized to ``mask_1``'s (dim 1, dim 2) size with bicubic
        interpolation. It is then clamped back into its own original
        [min, max] range, because bicubic interpolation can overshoot (for
        example producing values below 0 or above 1 from a binary mask).

        Args:
            inputcount: number of ``mask_N`` keyword inputs to combine.
            **kwargs: ``mask_1`` ... ``mask_<inputcount>``. A missing key
                raises ``KeyError``.

        Returns:
            1-tuple holding the masks concatenated along dim 0, on
            ``mask_1``'s device.
        """
        first = kwargs["mask_1"]
        target_size = (first.shape[1], first.shape[2])
        parts = [first]
        for index in range(2, inputcount + 1):
            new_mask = kwargs[f"mask_{index}"]
            if new_mask.shape[1:] != first.shape[1:]:
                lo, hi = new_mask.min().item(), new_mask.max().item()
                new_mask = F.interpolate(new_mask.unsqueeze(1), size=target_size, mode="bicubic").squeeze(1)
                new_mask = new_mask.clamp(lo, hi)
            parts.append(new_mask.to(first.device))
        # A single concatenation instead of one torch.cat per input.
        return (torch.cat(parts, dim=0),)
|
|
|
class OffsetMask:
    # Older spellings of padding_mode (the torch F.pad names the previous
    # implementation checked for) are still accepted and mapped to UI names.
    _PADDING_ALIASES = {"replicate": "border", "reflect": "reflection"}

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "angle": ("INT", { "default": 0, "min": -360, "max": 360, "step": 1, "display": "number" }),
                "duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }),
                "roll": ("BOOLEAN", { "default": False }),
                "incremental": ("BOOLEAN", { "default": False }),
                "padding_mode": (
                    [
                        'empty',
                        'border',
                        'reflection',
                    ], {
                        "default": 'empty'
                    }),
            }
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "offset"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Offsets the mask by the specified amount.
- mask: Input mask or mask batch
- x: Horizontal offset
- y: Vertical offset
- angle: Angle in degrees
- roll: roll edge wrapping
- duplication_factor: Number of times to duplicate the mask to form a batch
- border padding_mode: Padding mode for the mask
"""

    @staticmethod
    def _clamp_shift(shift, size):
        """Limit a non-rolling shift to +/-(size - 1) pixels."""
        limit = max(size - 1, 0)
        return max(-limit, min(shift, limit))

    @staticmethod
    def _shift_axis(plane, shift, dim, mode):
        """Return 2D ``plane`` translated by ``shift`` pixels along ``dim``.

        Output index j reads source index j - shift. Source indices outside
        the axis are zero-filled ('empty'), clamped to the nearest edge
        ('border'), or mirrored without repeating the edge ('reflection',
        the same convention as torch's 'reflect' padding).
        """
        if shift == 0:
            return plane
        size = plane.shape[dim]
        src = torch.arange(size, device=plane.device) - shift
        if mode == "border":
            return plane.index_select(dim, src.clamp(0, size - 1))
        if mode == "reflection":
            if size > 1:
                period = 2 * (size - 1)
                src = src.remainder(period)
                src = torch.where(src >= size, period - src, src)
            else:
                src = torch.zeros_like(src)
            return plane.index_select(dim, src)
        # 'empty': gather from the clamped index, then zero what fell outside.
        shifted = plane.index_select(dim, src.clamp(0, size - 1))
        inside = (src >= 0) & (src < size)
        inside = inside.view(-1, 1) if dim == 0 else inside.view(1, -1)
        return torch.where(inside, shifted, torch.zeros_like(shifted))

    def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1, padding_mode="empty"):
        """Duplicate, rotate and translate a mask batch.

        Args:
            mask: mask tensor; repeated ``duplication_factor`` times along
                dim 0, then unpacked as (batch, height, width).
            x, y: offset in pixels. Positive values move content right/down.
            angle: rotation in degrees applied with ``TF.rotate`` before
                translating. Any non-zero angle, including negative ones,
                is applied.
            roll: wrap content around the edges with ``torch.roll`` instead
                of padding. Roll shifts are not clamped.
            incremental: scale angle and offsets by (frame index + 1), so each
                frame moves further than the previous one.
            duplication_factor: number of copies of the input in the batch.
            padding_mode: 'empty' (zeros), 'border' (repeat edge) or
                'reflection' (mirror). 'replicate'/'reflect' are accepted as
                aliases. Non-roll offsets are clamped to +/-(size - 1).

        Returns:
            1-tuple ``(mask,)``.

        Raises:
            ValueError: unknown ``padding_mode`` (only checked when not rolling).
        """
        # repeat() always copies, so the in-place writes below never touch
        # the caller's tensor.
        mask = mask.repeat(duplication_factor, 1, 1)
        batch_size, height, width = mask.shape

        if angle != 0:
            for i in range(batch_size):
                frame_angle = angle * (i + 1) if incremental else angle
                mask[i] = TF.rotate(mask[i].unsqueeze(0), frame_angle).squeeze(0)

        if roll:
            if incremental:
                for i in range(batch_size):
                    step = i + 1
                    mask[i] = torch.roll(mask[i], shifts=(y * step, x * step), dims=(0, 1))
            else:
                mask = torch.roll(mask, shifts=(y, x), dims=(1, 2))
            return (mask,)

        mode = self._PADDING_ALIASES.get(padding_mode, padding_mode)
        if mode not in ("empty", "border", "reflection"):
            raise ValueError(f"Unsupported padding_mode: {padding_mode!r}")

        for i in range(batch_size):
            step = i + 1 if incremental else 1
            shift_x = self._clamp_shift(x * step, width)
            shift_y = self._clamp_shift(y * step, height)
            frame = self._shift_axis(mask[i], shift_x, 1, mode)
            mask[i] = self._shift_axis(frame, shift_y, 0, mode)

        return (mask,)
|
|
|
class RoundMask:
    @classmethod
    def INPUT_TYPES(cls):
        required = {"mask": ("MASK",)}
        return {"required": required}

    RETURN_TYPES = ("MASK",)
    FUNCTION = "round"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Rounds the mask or batch of masks to a binary mask.
<img src="https://github.com/kijai/ComfyUI-KJNodes/assets/40791699/52c85202-f74e-4b96-9dac-c8bda5ddcc40" width="300" height="250" alt="RoundMask example">

"""

    def round(self, mask):
        """Round every value of ``mask`` to the nearest integer.

        Uses ``torch.round`` (half-to-even), so the result is strictly binary
        only when the input lies in [0, 1].

        Returns:
            1-tuple holding the rounded tensor.
        """
        rounded = torch.round(mask)
        return (rounded,)
|
|
|
class ResizeMask:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
                "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
                "keep_proportions": ("BOOLEAN", { "default": False }),
                "upscale_method": (s.upscale_methods,),
                "crop": (["disabled","center"],),
            }
        }

    RETURN_TYPES = ("MASK", "INT", "INT",)
    RETURN_NAMES = ("mask", "width", "height",)
    FUNCTION = "resize"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Resizes the mask or batch of masks to the specified width and height.
"""

    @staticmethod
    def _target_size(ow, oh, width, height, keep_proportions):
        """Compute the output (width, height) for an ``ow`` x ``oh`` mask.

        With ``keep_proportions`` the mask is scaled uniformly to fit inside
        width x height, where a 0 means "use the original size" for that side.
        Without it, width x height is used as given. A single 0 derives that
        side from the aspect ratio, and both 0 keeps the original size. Every
        computed side is at least 1 pixel.
        """
        if keep_proportions:
            width = ow if width == 0 else width
            height = oh if height == 0 else height
            ratio = min(width / ow, height / oh)
            return max(1, round(ow * ratio)), max(1, round(oh * ratio))
        if width == 0 and height == 0:
            return ow, oh
        if width == 0:
            width = max(1, round(ow * height / oh))
        elif height == 0:
            height = max(1, round(oh * width / ow))
        return width, height

    def resize(self, mask, width, height, keep_proportions, upscale_method,crop):
        """Resize a mask batch and return it with its new width and height.

        The mask's last two dims are read as (height, width). A channel axis
        is added with ``unsqueeze(1)`` for ``common_upscale`` and removed
        afterwards. NOTE(review): assumes common_upscale takes
        (samples, width, height, method, crop) on a (B, C, H, W) tensor, as
        the call below implies -- confirm against comfy.utils.

        Returns:
            ``(mask, width, height)`` with the sizes read back from the result.
        """
        oh, ow = mask.shape[-2], mask.shape[-1]
        width, height = self._target_size(ow, oh, width, height, keep_proportions)
        outputs = common_upscale(mask.unsqueeze(1), width, height, upscale_method, crop)
        outputs = outputs.squeeze(1)
        return (outputs, outputs.shape[2], outputs.shape[1],)
|
|
|
class RemapMaskRange:
    @classmethod
    def INPUT_TYPES(cls):
        mask_input = ("MASK",)
        min_input = ("FLOAT", {"default": 0.0, "min": -10.0, "max": 1.0, "step": 0.01})
        max_input = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
        return {"required": {"mask": mask_input, "min": min_input, "max": max_input}}

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "remap"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Sets new min and max values for the mask.
"""

    def remap(self, mask, min, max):
        """Rescale ``mask`` so that its peak maps to ``max`` and zero to ``min``.

        The mask is divided by its maximum value (or by 1 when that maximum is
        not positive, to avoid dividing by zero), mapped linearly onto
        [min, max], and finally clamped to [0.0, 1.0].

        Returns:
            1-tuple holding the remapped mask.
        """
        # `min`/`max` shadow the builtins, but they are the node's input names.
        low, high = min, max
        peak = torch.max(mask)
        if not peak > 0:
            peak = 1
        normalized = mask / peak
        remapped = normalized * (high - low) + low
        return (remapped.clamp(0.0, 1.0),)
|
|