import os
import torch
from torch.nn import functional as F
from omegaconf import OmegaConf
import comfy.utils
import comfy.model_management as mm
import folder_paths
from nodes import ImageScaleBy
from nodes import ImageScale
import torch.cuda
from .sgm.util import instantiate_from_config
from .SUPIR.util import convert_dtype, load_state_dict
import open_clip
from contextlib import contextmanager
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPTextConfig,
)

script_directory = os.path.dirname(os.path.abspath(__file__))
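
# open_clip builds both a vision and a text tower by default; only the text tower
# is used here, so the vision builder is patched out while constructing CLIP below.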
def dummy_build_vision_tower(*args, **kwargs):
    return None


@contextmanager
def patch_build_vision_tower():
    # Monkey patch the CLIP class before creating an instance, then restore the original builder.
    original_build_vision_tower = open_clip.model._build_vision_tower
    open_clip.model._build_vision_tower = dummy_build_vision_tower
    try:
        yield
    finally:
        open_clip.model._build_vision_tower = original_build_vision_tower
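
# Rebuilds only the text transformer of a CLIP model from an OpenAI-format state
# dict; the head count is derived assuming 64-dim attention heads.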
def build_text_model_from_openai_state_dict(
    state_dict: dict,
    cast_dtype=torch.float16,
):
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    vision_cfg = None
    text_cfg = open_clip.CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    with patch_build_vision_tower():
        model = open_clip.CLIP(
            embed_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
            quick_gelu=True,
            cast_dtype=cast_dtype,
        )
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model
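
# ComfyUI node that upscales an image with SUPIR, loading the SUPIR and SDXL
# checkpoints on demand and caching the assembled model between runs.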
class SUPIR_Upscale:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "supir_model": (folder_paths.get_filename_list("checkpoints"),),
                "sdxl_model": (folder_paths.get_filename_list("checkpoints"),),
                "image": ("IMAGE",),
                "seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
                "resize_method": (s.upscale_methods, {"default": "lanczos"}),
                "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01}),
                "steps": ("INT", {"default": 45, "min": 3, "max": 4096, "step": 1}),
                "restoration_scale": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 6.0, "step": 1.0}),
                "cfg_scale": ("FLOAT", {"default": 4.0, "min": 0, "max": 100, "step": 0.01}),
                "a_prompt": ("STRING", {"multiline": True, "default": "high quality, detailed"}),
                "n_prompt": ("STRING", {"multiline": True, "default": "bad quality, blurry, messy"}),
                "s_churn": ("INT", {"default": 5, "min": 0, "max": 40, "step": 1}),
                "s_noise": ("FLOAT", {"default": 1.003, "min": 1.0, "max": 1.1, "step": 0.001}),
                "control_scale": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.05}),
                "cfg_scale_start": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.05}),
                "control_scale_start": ("FLOAT", {"default": 0.0, "min": 0, "max": 1.0, "step": 0.05}),
                "color_fix_type": (
                    [
                        'None',
                        'AdaIn',
                        'Wavelet',
                    ], {
                        "default": 'Wavelet'
                    }),
                "keep_model_loaded": ("BOOLEAN", {"default": True}),
                "use_tiled_vae": ("BOOLEAN", {"default": True}),
                "encoder_tile_size_pixels": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
                "decoder_tile_size_latent": ("INT", {"default": 64, "min": 32, "max": 8192, "step": 64}),
            },
            "optional": {
                "captions": ("STRING", {"forceInput": True, "multiline": False, "default": ""}),
                "diffusion_dtype": (
                    [
                        'fp16',
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
                "encoder_dtype": (
                    [
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 128, "step": 1}),
                "use_tiled_sampling": ("BOOLEAN", {"default": False}),
                "sampler_tile_size": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 32}),
                "sampler_tile_stride": ("INT", {"default": 512, "min": 32, "max": 2048, "step": 32}),
                "fp8_unet": ("BOOLEAN", {"default": False}),
                "fp8_vae": ("BOOLEAN", {"default": False}),
                "sampler": (
                    [
                        'RestoreDPMPP2MSampler',
                        'RestoreEDMSampler',
                    ], {
                        "default": 'RestoreEDMSampler'
                    }),
            }
        }
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("upscaled_image",)
    FUNCTION = "process"
    CATEGORY = "SUPIR"
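
    # Main entry point called by ComfyUI: assemble or reuse the model, resize the
    # input to a multiple of 64, then run SUPIR sampling batch by batch.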
    def process(self, steps, image, color_fix_type, seed, scale_by, cfg_scale, resize_method, s_churn, s_noise,
                encoder_tile_size_pixels, decoder_tile_size_latent,
                control_scale, cfg_scale_start, control_scale_start, restoration_scale, keep_model_loaded,
                a_prompt, n_prompt, sdxl_model, supir_model, use_tiled_vae, use_tiled_sampling=False,
                sampler_tile_size=128, sampler_tile_stride=64, captions="", diffusion_dtype="auto",
                encoder_dtype="auto", batch_size=1, fp8_unet=False, fp8_vae=False, sampler="RestoreEDMSampler"):
        device = mm.get_torch_device()
        mm.unload_all_models()

        SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model)
        SDXL_MODEL_PATH = folder_paths.get_full_path("checkpoints", sdxl_model)

        config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml")
        config_path_tiled = os.path.join(script_directory, "options/SUPIR_v0_tiled.yaml")
        clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json")
        tokenizer_path = os.path.join(script_directory, "configs/tokenizer")

        custom_config = {
            'sdxl_model': sdxl_model,
            'diffusion_dtype': diffusion_dtype,
            'encoder_dtype': encoder_dtype,
            'use_tiled_vae': use_tiled_vae,
            'supir_model': supir_model,
            'use_tiled_sampling': use_tiled_sampling,
            'fp8_unet': fp8_unet,
            'fp8_vae': fp8_vae,
            'sampler': sampler
        }
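
        # Resolve dtypes: 'auto' defers to ComfyUI's fp16/bf16 capability checks,
        # otherwise the user's explicit choice is converted to a torch dtype.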
        if diffusion_dtype == 'auto':
            try:
                if mm.should_use_fp16():
                    print("Diffusion using fp16")
                    dtype = torch.float16
                    model_dtype = 'fp16'
                elif mm.should_use_bf16():
                    print("Diffusion using bf16")
                    dtype = torch.bfloat16
                    model_dtype = 'bf16'
                else:
                    print("Diffusion using fp32")
                    dtype = torch.float32
                    model_dtype = 'fp32'
            except Exception:
                raise AttributeError("ComfyUI too old, can't autodetect properly. Set your dtypes manually.")
        else:
            print(f"Diffusion using {diffusion_dtype}")
            dtype = convert_dtype(diffusion_dtype)
            model_dtype = diffusion_dtype

        if encoder_dtype == 'auto':
            try:
                if mm.should_use_bf16():
                    print("Encoder using bf16")
                    vae_dtype = 'bf16'
                else:
                    print("Encoder using fp32")
                    vae_dtype = 'fp32'
            except Exception:
                raise AttributeError("ComfyUI too old, can't autodetect properly. Set your dtypes manually.")
        else:
            vae_dtype = encoder_dtype
            print(f"Encoder using {vae_dtype}")
        if not hasattr(self, "model") or self.model is None or self.current_config != custom_config:
            self.current_config = custom_config
            self.model = None
            mm.soft_empty_cache()

            if use_tiled_sampling:
                config = OmegaConf.load(config_path_tiled)
                config.model.params.sampler_config.params.tile_size = sampler_tile_size // 8
                config.model.params.sampler_config.params.tile_stride = sampler_tile_stride // 8
                config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.Tiled{sampler}"
                print("Using tiled sampling")
            else:
                config = OmegaConf.load(config_path)
                config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.{sampler}"
                print("Using non-tiled sampling")

            if mm.XFORMERS_IS_AVAILABLE:
                config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers"

            config.model.params.ae_dtype = vae_dtype
            config.model.params.diffusion_dtype = model_dtype

            self.model = instantiate_from_config(config.model).cpu()
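
            # Load the SUPIR and SDXL weights into the assembled model; strict=False
            # because each checkpoint only covers part of the combined state dict.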
            try:
                print(f"Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]")
                supir_state_dict = load_state_dict(SUPIR_MODEL_PATH)
            except Exception as e:
                raise Exception(f"Failed to load SUPIR model: {e}") from e
            try:
                print(f"Attempting to load SDXL model: [{SDXL_MODEL_PATH}]")
                sdxl_state_dict = load_state_dict(SDXL_MODEL_PATH)
            except Exception as e:
                raise Exception(f"Failed to load SDXL model: {e}") from e

            self.model.load_state_dict(supir_state_dict, strict=False)
            self.model.load_state_dict(sdxl_state_dict, strict=False)
            del supir_state_dict
            # first CLIP (text) model from the SDXL checkpoint
            try:
                print("Loading first clip model from SDXL checkpoint")
                replace_prefix = {}
                replace_prefix["conditioner.embedders.0.transformer."] = ""
                sd = comfy.utils.state_dict_prefix_replace(sdxl_state_dict, replace_prefix, filter_keys=False)
                clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path)
                self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
                self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config)
                self.model.conditioner.embedders[0].transformer.load_state_dict(sd, strict=False)
                self.model.conditioner.embedders[0].eval()
                for param in self.model.conditioner.embedders[0].parameters():
                    param.requires_grad = False
            except Exception as e:
                raise Exception(f"Failed to load first clip model from SDXL checkpoint: {e}") from e
            del sdxl_state_dict
            # second CLIP (OpenCLIP) model from the SDXL checkpoint
            try:
                print("Loading second clip model from SDXL checkpoint")
                replace_prefix2 = {}
                replace_prefix2["conditioner.embedders.1.model."] = ""
                sd = comfy.utils.state_dict_prefix_replace(sd, replace_prefix2, filter_keys=True)
                clip_g = build_text_model_from_openai_state_dict(sd, cast_dtype=dtype)
                self.model.conditioner.embedders[1].model = clip_g
            except Exception as e:
                raise Exception(f"Failed to load second clip model from SDXL checkpoint: {e}") from e
            del sd, clip_g

            mm.soft_empty_cache()

            self.model.to(dtype)

            # only the UNet and/or VAE are cast to fp8
            if fp8_unet:
                self.model.model.to(torch.float8_e4m3fn)
            if fp8_vae:
                self.model.first_stage_model.to(torch.float8_e4m3fn)

            if use_tiled_vae:
                self.model.init_tile_vae(encoder_tile_size=encoder_tile_size_pixels, decoder_tile_size=decoder_tile_size_latent)
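
        # Upscale by the requested factor, then bicubic-resize so both dimensions
        # are multiples of 64 before sampling.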
        upscaled_image, = ImageScaleBy.upscale(self, image, resize_method, scale_by)

        B, H, W, C = upscaled_image.shape
        new_height = H if H % 64 == 0 else ((H // 64) + 1) * 64
        new_width = W if W % 64 == 0 else ((W // 64) + 1) * 64
        upscaled_image = upscaled_image.permute(0, 3, 1, 2)
        resized_image = F.interpolate(upscaled_image, size=(new_height, new_width), mode='bicubic', align_corners=False)
        resized_image = resized_image.to(device)

        captions_list = []
        captions_list.append(captions)
        print("captions: ", captions_list)

        use_linear_CFG = cfg_scale_start > 0
        use_linear_control_scale = control_scale_start > 0
        out = []
        pbar = comfy.utils.ProgressBar(B)

        batched_images = [resized_image[i:i + batch_size] for i in range(0, len(resized_image), batch_size)]
        captions_list = captions_list * resized_image.shape[0]
        batched_captions = [captions_list[i:i + batch_size] for i in range(0, len(captions_list), batch_size)]

        mm.soft_empty_cache()
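
        # Sample each batch with SUPIR; the single caption string is repeated for
        # every image in the batch, and the progress bar advances once per batch.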
        i = 1
        for imgs, caps in zip(batched_images, batched_captions):
            try:
                samples = self.model.batchify_sample(imgs, caps, num_steps=steps,
                                                     restoration_scale=restoration_scale, s_churn=s_churn,
                                                     s_noise=s_noise, cfg_scale=cfg_scale, control_scale=control_scale,
                                                     seed=seed,
                                                     num_samples=1, p_p=a_prompt, n_p=n_prompt,
                                                     color_fix_type=color_fix_type,
                                                     use_linear_CFG=use_linear_CFG,
                                                     use_linear_control_scale=use_linear_control_scale,
                                                     cfg_scale_start=cfg_scale_start,
                                                     control_scale_start=control_scale_start)
            except torch.cuda.OutOfMemoryError as e:
                mm.free_memory(mm.get_total_memory(mm.get_torch_device()), mm.get_torch_device())
                self.model = None
                mm.soft_empty_cache()
print("It's likely that too large of an image or batch_size for SUPIR was used," | |
" and it has devoured all of the memory it had reserved, you may need to restart ComfyUI. Make sure you are using tiled_vae, " | |
" you can also try using fp8 for reduced memory usage if your system supports it.") | |
raise e | |
out.append(samples.squeeze(0).cpu()) | |
print("Sampled ", i * len(imgs), " out of ", B) | |
i = i + 1 | |
pbar.update(1) | |
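
        # Optionally release the model, stack outputs back to BHWC float32, and
        # resize to the pre-rounding resolution (W x H) from the initial upscale.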
        if not keep_model_loaded:
            self.model = None
            mm.soft_empty_cache()

        if len(out[0].shape) == 4:
            out_stacked = torch.cat(out, dim=0).cpu().to(torch.float32).permute(0, 2, 3, 1)
        else:
            out_stacked = torch.stack(out, dim=0).cpu().to(torch.float32).permute(0, 2, 3, 1)

        final_image, = ImageScale.upscale(self, out_stacked, resize_method, W, H, crop="disabled")

        return (final_image,)
NODE_CLASS_MAPPINGS = {
    "SUPIR_Upscale": SUPIR_Upscale
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "SUPIR_Upscale": "SUPIR_Upscale"
}