# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py import ldm_patched.modules.sd import ldm_patched.modules.utils import ldm_patched.modules.model_base import ldm_patched.modules.model_management import ldm_patched.utils.path_utils import json import os from ldm_patched.modules.args_parser import args class ModelMergeSimple: @classmethod def INPUT_TYPES(s): return {"required": { "model1": ("MODEL",), "model2": ("MODEL",), "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "merge" CATEGORY = "advanced/model_merging" def merge(self, model1, model2, ratio): m = model1.clone() kp = model2.get_key_patches("diffusion_model.") for k in kp: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) class ModelSubtract: @classmethod def INPUT_TYPES(s): return {"required": { "model1": ("MODEL",), "model2": ("MODEL",), "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "merge" CATEGORY = "advanced/model_merging" def merge(self, model1, model2, multiplier): m = model1.clone() kp = model2.get_key_patches("diffusion_model.") for k in kp: m.add_patches({k: kp[k]}, - multiplier, multiplier) return (m, ) class ModelAdd: @classmethod def INPUT_TYPES(s): return {"required": { "model1": ("MODEL",), "model2": ("MODEL",), }} RETURN_TYPES = ("MODEL",) FUNCTION = "merge" CATEGORY = "advanced/model_merging" def merge(self, model1, model2): m = model1.clone() kp = model2.get_key_patches("diffusion_model.") for k in kp: m.add_patches({k: kp[k]}, 1.0, 1.0) return (m, ) class CLIPMergeSimple: @classmethod def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",), "clip2": ("CLIP",), "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }} RETURN_TYPES = ("CLIP",) FUNCTION = "merge" CATEGORY = "advanced/model_merging" def merge(self, clip1, clip2, ratio): m = clip1.clone() kp = clip2.get_key_patches() for k in kp: if k.endswith(".position_ids") or k.endswith(".logit_scale"): continue m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) class ModelMergeBlocks: @classmethod def INPUT_TYPES(s): return {"required": { "model1": ("MODEL",), "model2": ("MODEL",), "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) }} RETURN_TYPES = ("MODEL",) FUNCTION = "merge" CATEGORY = "advanced/model_merging" def merge(self, model1, model2, **kwargs): m = model1.clone() kp = model2.get_key_patches("diffusion_model.") default_ratio = next(iter(kwargs.values())) for k in kp: ratio = default_ratio k_unet = k[len("diffusion_model."):] last_arg_size = 0 for arg in kwargs: if k_unet.startswith(arg) and last_arg_size < len(arg): ratio = kwargs[arg] last_arg_size = len(arg) m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, output_dir) prompt_info = "" if prompt is not None: prompt_info = json.dumps(prompt) metadata = {} enable_modelspec = True if isinstance(model.model, ldm_patched.modules.model_base.SDXL): metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" elif isinstance(model.model, ldm_patched.modules.model_base.SDXLRefiner): metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" else: enable_modelspec = False if enable_modelspec: metadata["modelspec.sai_model_spec"] = "1.0.0" metadata["modelspec.implementation"] = "sgm" metadata["modelspec.title"] = "{} {}".format(filename, counter) #TODO: # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", # "v2-inpainting" if model.model.model_type == ldm_patched.modules.model_base.ModelType.EPS: metadata["modelspec.predict_key"] = "epsilon" elif model.model.model_type == ldm_patched.modules.model_base.ModelType.V_PREDICTION: metadata["modelspec.predict_key"] = "v" if not args.disable_server_info: metadata["prompt"] = prompt_info if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) ldm_patched.modules.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) class CheckpointSave: def __init__(self): self.output_dir = ldm_patched.utils.path_utils.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip": ("CLIP",), "vae": ("VAE",), "filename_prefix": ("STRING", {"default": "checkpoints/ldm_patched"}),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} RETURN_TYPES = () FUNCTION = "save" OUTPUT_NODE = True CATEGORY = "advanced/model_merging" def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) return {} class CLIPSave: def __init__(self): self.output_dir = ldm_patched.utils.path_utils.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "clip": ("CLIP",), "filename_prefix": ("STRING", {"default": "clip/ldm_patched"}),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} RETURN_TYPES = () FUNCTION = "save" OUTPUT_NODE = True CATEGORY = "advanced/model_merging" def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): prompt_info = "" if prompt is not None: prompt_info = json.dumps(prompt) metadata = {} if not args.disable_server_info: metadata["prompt"] = prompt_info if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) ldm_patched.modules.model_management.load_models_gpu([clip.load_model()]) clip_sd = clip.get_sd() for prefix in ["clip_l.", "clip_g.", ""]: k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys())) current_clip_sd = {} for x in k: current_clip_sd[x] = clip_sd.pop(x) if len(current_clip_sd) == 0: continue p = prefix[:-1] replace_prefix = {} filename_prefix_ = filename_prefix if len(p) > 0: filename_prefix_ = "{}_{}".format(filename_prefix_, p) replace_prefix[prefix] = "" replace_prefix["transformer."] = "" full_output_folder, filename, counter, subfolder, filename_prefix_ = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix_, self.output_dir) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) current_clip_sd = ldm_patched.modules.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix) ldm_patched.modules.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata) return {} class VAESave: def __init__(self): self.output_dir = ldm_patched.utils.path_utils.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("VAE",), "filename_prefix": ("STRING", {"default": "vae/ldm_patched_vae"}),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} RETURN_TYPES = () FUNCTION = "save" OUTPUT_NODE = True CATEGORY = "advanced/model_merging" def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir) prompt_info = "" if prompt is not None: prompt_info = json.dumps(prompt) metadata = {} if not args.disable_server_info: metadata["prompt"] = prompt_info if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) ldm_patched.modules.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) return {} NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks, "ModelMergeSubtract": ModelSubtract, "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, "CLIPSave": CLIPSave, "VAESave": VAESave, }