import math

import torch
import comfy.utils


class PatchModelAddDownscale_MK3:
    """UNet model patch that dynamically resizes the feature map at one block.

    This node extends PatchModelAddDownscale_v2 with several transition curves
    and optional adaptive scaling. Over the sampling schedule it works in
    three phases:

    1. Full downscale phase (start_percent -> end_percent):
       The feature map entering the chosen input block is resized by
       1 / downscale_factor (optionally moderated by adaptive scaling).

    2. Transition phase (end_percent -> gradual_percent):
       The scale moves from 1 / downscale_factor back to 1.0. ``progress``
       runs from 0 at end_percent to 1 at gradual_percent and is shaped by
       the selected curve:
         - LINEAR:      progress (original v2 behaviour)
         - COSINE:      (1 - cos(pi * progress)) / 2, eases in and out
         - EXPONENTIAL: progress ** 2, slow start that speeds up
         - LOGARITHMIC: ln(1 + progress * (e - 1)), fast start that slows down
         - STEP:        progress rounded to a multiple of 0.25

    3. Final phase (after gradual_percent):
       The feature map is left at its original size.

    Before start_percent the input is also left untouched. The output block
    patch resizes the hidden state back to the spatial size of the matching
    skip connection whenever the two differ.

    Parameters (of ``patch``):
        model: The model to patch; a clone is returned.
        block_number: Block index to patch (1-32), matched against
            ``transformer_options["block"][1]``.
        downscale_factor: Base shrink factor (0.1-9.0); values below 1
            enlarge the feature map instead.
        start_percent: Sampling percent at which downscaling starts.
        end_percent: Percent at which the transition back begins.
        gradual_percent: Percent at which the transition completes.
        transition_mode: Curve used during the transition phase.
        downscale_after_skip: Register the input patch with
            ``set_model_input_block_patch_after_skip`` instead of
            ``set_model_input_block_patch``.
        downscale_method: Interpolation used when shrinking.
        upscale_method: Interpolation used when enlarging, including the
            output-block resize back to the skip connection's size.
        min_size: Minimum spatial size, measured on the last two dimensions
            of the patched feature map (not in image pixels).
        adaptive_scaling: Moderate a downscale factor based on the feature
            map's size and aspect ratio (never reversing it into an upscale).
        preserve_aspect: When True both sides use the same scale and
            ``min_size`` is only enforced through adaptive scaling; when
            False each side is clamped to ``min_size`` independently, which
            may change the aspect ratio.

    Example usage (gentle cosine transition)::

        (patched_model,) = PatchModelAddDownscale_MK3().patch(
            model=model, block_number=3, downscale_factor=2.0,
            start_percent=0.0, end_percent=0.35, gradual_percent=0.6,
            transition_mode="COSINE", downscale_after_skip=True,
            downscale_method="bicubic", upscale_method="bicubic",
            min_size=64, adaptive_scaling=True, preserve_aspect=True,
        )

    Code by:
        - Original: https://github.com/Jordach + comfyanon + kohya-ss
    """

    upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
    transition_modes = ["LINEAR", "COSINE", "EXPONENTIAL", "LOGARITHMIC", "STEP"]

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for the ComfyUI front end."""
        return {"required": {
            "model": ("MODEL",),
            "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
            "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
            "gradual_percent": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 1.0, "step": 0.001}),
            "transition_mode": (cls.transition_modes,),
            "downscale_after_skip": ("BOOLEAN", {"default": True}),
            "downscale_method": (cls.upscale_methods,),
            "upscale_method": (cls.upscale_methods,),
            "min_size": ("INT", {"default": 64, "min": 16, "max": 2048, "step": 8}),
            "adaptive_scaling": ("BOOLEAN", {"default": True}),
            "preserve_aspect": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/unet"

    @staticmethod
    def _transition_curve(progress, mode):
        """Map linear progress in [0, 1] onto the eased progress for ``mode``.

        Raises:
            ValueError: If ``mode`` is not one of ``transition_modes``.
        """
        if mode == "LINEAR":
            return progress
        if mode == "COSINE":
            return (1 - math.cos(progress * math.pi)) / 2
        if mode == "EXPONENTIAL":
            return math.pow(progress, 2)
        if mode == "LOGARITHMIC":
            return math.log(1 + progress * (math.e - 1))
        if mode == "STEP":
            # Quantise to multiples of 0.25; round() is round-half-to-even.
            return round(progress * 4) / 4
        raise ValueError(f"Unknown transition_mode: {mode!r}")

    def calculate_transition_factor(self, current_percent, end_percent,
                                    gradual_percent, downscale_factor,
                                    mode="LINEAR"):
        """Return the resize scale at ``current_percent`` of the transition.

        ``current_percent`` must increase as sampling advances. At or before
        ``end_percent`` the result is ``1 / downscale_factor``; at or after
        ``gradual_percent`` it is 1.0; in between it follows the ``mode``
        curve.

        Args:
            current_percent: Current position on an increasing axis.
            end_percent: Position where the transition starts.
            gradual_percent: Position where the transition completes.
            downscale_factor: Factor in effect before the transition.
            mode: One of ``transition_modes``.

        Returns:
            The scale to multiply spatial dimensions by.

        Raises:
            ValueError: If ``mode`` is unknown and the position lies strictly
                inside the transition.
        """
        start_scale = 1.0 / downscale_factor
        if current_percent <= end_percent:
            return start_scale
        if current_percent >= gradual_percent:
            return 1.0

        progress = (current_percent - end_percent) / (gradual_percent - end_percent)
        eased = self._transition_curve(progress, mode)
        return start_scale + (1.0 - start_scale) * eased

    def calculate_adaptive_factor(self, h, base_factor, min_size):
        """Moderate a downscale ``base_factor`` for the spatial size of ``h``.

        Only real downscales (``base_factor > 1``) are adjusted:
          * the factor is capped so the shorter spatial side of ``h`` does
            not drop below ``min_size``;
          * for aspect ratios above 2 it is softened by sqrt(2 / ratio).
        The result is clamped to [1.0, base_factor], so adaptation can only
        weaken the downscale, never turn it into an upscale.

        Args:
            h: Tensor whose last two ``shape`` entries are spatial dims.
            base_factor: Requested downscale factor.
            min_size: Smallest allowed length of the shorter spatial side.

        Returns:
            The adjusted factor; ``base_factor`` unchanged when it is <= 1.
        """
        if base_factor <= 1.0:
            return base_factor

        min_dim = min(h.shape[-2:])
        max_dim = max(h.shape[-2:])
        if min_dim <= 0:
            # Degenerate map: nothing meaningful to adapt to.
            return base_factor

        adjusted = base_factor
        if min_size > 0:
            # Prevent over-shrinking the shorter side.
            adjusted = min(adjusted, min_dim / min_size)

        aspect_ratio = max_dim / min_dim
        if aspect_ratio > 2:
            # Soften the shrink for extreme aspect ratios.
            adjusted *= math.sqrt(2 / aspect_ratio)

        # A factor below 1 would enlarge the map; clamp to "no resize".
        return max(adjusted, 1.0)

    def _scale_for_sigma(self, sigma, sigma_start, sigma_end, sigma_rescale,
                         factor, mode):
        """Return the resize scale for the current ``sigma``, or None.

        Like the phase checks in v2, this assumes sigma decreases as sampling
        advances, so sigma_start >= sigma_end >= sigma_rescale for
        start <= end <= gradual percents. None means "leave h untouched".
        """
        if sigma_end <= sigma <= sigma_start:
            return 1.0 / factor
        if sigma_rescale <= sigma < sigma_end:
            # Forward progress through the transition: 0 at sigma_end and
            # 1 at sigma_rescale. The denominator is > 0 inside this branch.
            progress = (sigma_end - sigma) / (sigma_end - sigma_rescale)
            return self.calculate_transition_factor(progress, 0.0, 1.0, factor, mode)
        return None

    @staticmethod
    def _target_size(height, width, scale_factor, min_size, preserve_aspect):
        """Return ``(new_height, new_width)`` for resizing by ``scale_factor``.

        With ``preserve_aspect`` both sides are scaled uniformly. Otherwise
        each side is clamped to at least ``min_size``, but never above its
        current length when shrinking. Both results are at least 1.
        """
        if preserve_aspect:
            new_h = round(height * scale_factor)
            new_w = round(width * scale_factor)
        else:
            new_h = max(round(height * scale_factor), min(min_size, height))
            new_w = max(round(width * scale_factor), min(min_size, width))
        return max(new_h, 1), max(new_w, 1)

    def patch(self, model, block_number, downscale_factor, start_percent,
              end_percent, gradual_percent, transition_mode,
              downscale_after_skip, downscale_method, upscale_method,
              min_size, adaptive_scaling, preserve_aspect):
        """Return a clone of ``model`` with the dynamic resize patches attached.

        See the class docstring for the meaning of each parameter.

        Returns:
            A one-element tuple containing the patched model clone.
        """
        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)
        sigma_rescale = model_sampling.percent_to_sigma(gradual_percent)

        def input_block_patch(h, transformer_options):
            # A factor of exactly 1 never changes the size.
            if downscale_factor == 1:
                return h
            if transformer_options["block"][1] != block_number:
                return h

            sigma = transformer_options["sigmas"][0].item()

            if adaptive_scaling:
                effective_factor = self.calculate_adaptive_factor(h, downscale_factor, min_size)
            else:
                effective_factor = downscale_factor

            scale_factor = self._scale_for_sigma(
                sigma, sigma_start, sigma_end, sigma_rescale,
                effective_factor, transition_mode,
            )
            if scale_factor is None:
                return h

            new_h, new_w = self._target_size(
                h.shape[-2], h.shape[-1], scale_factor, min_size, preserve_aspect,
            )
            # Skip identity resizes; interpolating to the same size is wasted work.
            if (new_h, new_w) == tuple(h.shape[-2:]):
                return h

            method = downscale_method if scale_factor < 1 else upscale_method
            return comfy.utils.common_upscale(h, new_w, new_h, method, "disabled")

        def output_block_patch(h, hsp, transformer_options):
            # Resize h back to the skip connection's spatial size so the two
            # match again before the output block uses them.
            if h.shape[2:] != hsp.shape[2:]:
                h = comfy.utils.common_upscale(
                    h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled"
                )
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return (m, )


NODE_CLASS_MAPPINGS = {
    "PatchModelAddDownscale_MK3": PatchModelAddDownscale_MK3,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PatchModelAddDownscale_MK3": "PatchModelAddDownscale MK3",
}