|
|
import comfy.utils |
|
|
import torch |
|
|
import gc |
|
|
import logging |
|
|
import comfy.model_management as model_management |
|
|
|
|
|
|
|
|
def clear_gpu_and_ram_cache():
    """Run the garbage collector and, when CUDA is available, drop CUDA caches.

    Calls ``gc.collect()`` unconditionally. If ``torch.cuda.is_available()``
    is true it then calls ``torch.cuda.empty_cache()`` followed by
    ``torch.cuda.ipc_collect()``. Returns ``None``.
    """
    gc.collect()

    # Nothing CUDA-related to release on machines without CUDA.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
|
|
|
|
|
|
|
def _smart_decode(vae, latent, tile_size=512):
    """Decode ``latent["samples"]`` with *vae*, retrying tiled on out-of-memory.

    A plain ``vae.decode`` is attempted first. If it raises
    ``model_management.OOM_EXCEPTION``, a warning is logged and
    ``vae.decode_tiled`` is used instead. The tile edge (*tile_size*) and the
    overlap (a quarter of *tile_size*) are both integer-divided by
    ``vae.spacial_compression_decode()`` before being passed on.

    Args:
        vae: Object providing ``decode``, ``decode_tiled`` and
            ``spacial_compression_decode``.
        latent: Mapping with a ``"samples"`` entry.
        tile_size: Tile edge used for the tiled fallback, before division by
            the compression factor.

    Returns:
        The decoded result. A 5-D result is reshaped to 4-D by merging its
        two leading axes; any other result is returned unchanged.
    """
    samples = latent["samples"]

    try:
        decoded = vae.decode(samples)
    except model_management.OOM_EXCEPTION:
        logging.warning("VAE decode OOM, using tiled decode")
        ratio = vae.spacial_compression_decode()
        tile = tile_size // ratio
        decoded = vae.decode_tiled(
            samples,
            tile_x=tile,
            tile_y=tile,
            overlap=(tile_size // 4) // ratio,
        )

    if len(decoded.shape) != 5:
        return decoded

    # Collapse the two leading axes, keeping the trailing three as they are.
    return decoded.reshape(-1, *decoded.shape[-3:])
|
|
|
|
|
|
|
|
class MagicUpscaleModule:
    """Decode a latent, resize the decoded image, and encode it again.

    Moved into mod/ as mg_upscale_module keeping class/key name.

    The class exposes ``INPUT_TYPES``, ``RETURN_TYPES``, ``RETURN_NAMES``,
    ``FUNCTION`` and ``CATEGORY``; presumably these are consumed by the
    ComfyUI node loader (not visible in this file) -- confirm before renaming.
    """

    # Interpolation modes offered for the ``upscale_method`` input.
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    # Stride used when the VAE cannot report a usable compression factor.
    _DEFAULT_STRIDE = 8

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for this node."""
        return {
            "required": {
                "samples": ("LATENT", {}),
                "vae": ("VAE", {}),
                "upscale_method": (cls.upscale_methods, {"default": "bilinear"}),
                "scale_by": ("FLOAT", {"default": 1.2, "min": 0.01, "max": 8.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("LATENT", "IMAGE")
    RETURN_NAMES = ("LATENT", "Upscaled Image")
    FUNCTION = "process_upscale"
    CATEGORY = "MagicNodes"

    @staticmethod
    def _aligned_size(size, scale_by, stride):
        """Scale *size* by *scale_by* and round up to a multiple of *stride*.

        The result is never smaller than one *stride*. Without that floor a
        small image combined with a small scale factor rounds to 0, and a
        zero-sized resize target makes the subsequent upscale fail.
        """
        scaled = round(size * scale_by)
        aligned = int(((scaled + stride - 1) // stride) * stride)
        return max(stride, aligned)

    def process_upscale(self, samples, vae, upscale_method, scale_by):
        """Decode *samples*, resize the image by *scale_by*, and re-encode it.

        Args:
            samples: Latent mapping; its ``"samples"`` entry is decoded.
            vae: Object used for decoding and for ``encode``; its
                ``spacial_compression_decode()`` value, when usable, is the
                alignment stride for the output size.
            upscale_method: Method name forwarded to
                ``comfy.utils.common_upscale``.
            scale_by: Multiplier applied to the decoded width and height.

        Returns:
            A 2-tuple ``({"samples": encoded}, up)`` where ``up`` is the
            resized image and ``encoded`` is the result of ``vae.encode`` on
            the first three entries of ``up``'s last axis.
        """
        clear_gpu_and_ram_cache()
        images = _smart_decode(vae, samples)

        # Move the last axis to position 1 for the resize call; presumably
        # this is (batch, height, width, channels) -> (batch, channels,
        # height, width) -- axis 2 is used as height and axis 3 as width.
        samples_t = images.movedim(-1, 1)

        try:
            stride = int(vae.spacial_compression_decode())
        except Exception:
            # Deliberate best effort: any failure to obtain the factor falls
            # back to the default stride rather than aborting the node.
            stride = MagicUpscaleModule._DEFAULT_STRIDE
        if stride <= 0:
            stride = MagicUpscaleModule._DEFAULT_STRIDE

        width_al = self._aligned_size(samples_t.shape[3], scale_by, stride)
        height_al = self._aligned_size(samples_t.shape[2], scale_by, stride)

        # The trailing "disabled" argument is passed through unchanged
        # (presumably the crop mode -- confirm against comfy.utils).
        up = comfy.utils.common_upscale(samples_t, width_al, height_al, upscale_method, "disabled")
        up = up.movedim(1, -1)

        # Only the first three entries of the last axis are encoded; the
        # returned image keeps every entry.
        encoded = vae.encode(up[:, :, :, :3])
        return ({"samples": encoded}, up)
|
|
|