|
import comfy.sd
|
|
import comfy.utils
|
|
import comfy.model_base
|
|
import comfy.model_management
|
|
import comfy.model_sampling
|
|
|
|
import torch
|
|
import folder_paths
|
|
import json
|
|
import os
|
|
|
|
|
|
# Use the ``args`` object exported by ``comfy.cli_args`` when that module can
# be imported; otherwise provide a minimal stand-in so that reads of
# ``args.disable_metadata`` in this module still resolve.
try:
    from comfy.cli_args import args
except ImportError:

    class ArgsMock:
        """Fallback replacement for ``comfy.cli_args.args``."""

        # The only attribute this fallback exposes.
        disable_metadata = False

    args = ArgsMock()
|
|
|
|
|
|
class ModelMergeSimple:
    """Node that blends the ``diffusion_model.`` weights of two models."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: two MODEL sockets and a ``ratio`` float in [0, 1]."""
        ratio_options = {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}
        required = {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "ratio": ("FLOAT", ratio_options),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, ratio):
        """Patch a clone of ``model1`` with every key patch of ``model2``.

        Each patch is added with ``ratio`` as the second argument of
        ``add_patches`` and ``1.0 - ratio`` as the third.
        NOTE(review): presumably these are the patch strength and the
        strength of the existing weights, so ``ratio`` is the share taken
        from ``model2`` -- confirm against ``add_patches``.

        Returns:
            A 1-tuple holding the patched clone; ``model1`` itself is not
            patched by this method.
        """
        merged = model1.clone()
        keep = 1.0 - ratio
        for key, patch in model2.get_key_patches("diffusion_model.").items():
            merged.add_patches({key: patch}, ratio, keep)
        return (merged,)
|
|
|
|
class ModelMergeMultiSimple:
    """Node that merges up to five models into one ratio-weighted blend.

    Each slot ``i`` (1..5) supplies ``model{i}`` and ``ratio{i}``. Slots
    whose ratio is not greater than 0 do not take part in the blend. The
    remaining ratios are normalised so they sum to 1, and the models are
    folded into a clone of the first participating model one at a time.
    """

    # Number of (model, ratio) input slots exposed by the node.
    _SLOT_COUNT = 5

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: ``model1..5`` and ``ratio1..5``, all required."""
        required = {}
        for i in range(1, s._SLOT_COUNT + 1):
            required[f"model{i}"] = ("MODEL",)
            required[f"ratio{i}"] = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge_five"

    CATEGORY = "advanced/model_merging"

    def merge_five(self, **kwargs):
        """Blend the supplied models according to their ratios.

        Args:
            **kwargs: ``model1..model5`` and ``ratio1..ratio5``. A missing or
                ``None`` model skips that slot; a missing or ``None`` ratio
                is treated as 0.

        Returns:
            A 1-tuple holding the merged model. When every ratio is 0 the
            result is a clone of the first supplied model; when exactly one
            model has a positive ratio the result is a clone of that model.

        Raises:
            ValueError: if no model was supplied at all.
        """
        provided = []  # every supplied model, in slot order
        weighted = []  # (model, ratio) pairs that take part in the blend

        for i in range(1, self._SLOT_COUNT + 1):
            model = kwargs.get(f"model{i}")
            ratio = kwargs.get(f"ratio{i}")
            if ratio is None:
                # A slot without a ratio contributes nothing instead of
                # failing on the comparison below.
                ratio = 0.0

            if model is None:
                if ratio > 0:
                    print(f"Warning: Ratio {ratio} provided for model{i} but model is missing. Ignoring.")
                # Skip the slot entirely so models and ratios can never get
                # out of step with each other.
                continue

            provided.append(model)
            if ratio > 0:
                weighted.append((model, ratio))

        if not provided:
            raise ValueError("No models provided for merging.")

        if not weighted:
            print("Warning: All model ratios are 0. Returning the first provided model without changes.")
            return (provided[0].clone(),)

        # Every ratio in `weighted` is > 0, so the total is strictly positive.
        total = sum(ratio for _, ratio in weighted)

        base_model, base_ratio = weighted[0]
        merged = base_model.clone()
        cumulative = base_ratio / total  # normalised weight already in `merged`

        for model, ratio in weighted[1:]:
            weight = ratio / total
            denominator = cumulative + weight
            # Relative shares of the incoming model and of the blend so far.
            # NOTE(review): presumably add_patches takes the patch strength
            # followed by the strength of the existing weights, which makes
            # the running result a weighted average -- confirm.
            strength_next = weight / denominator
            strength_self = cumulative / denominator

            key_patches = model.get_key_patches("diffusion_model.")
            for k in key_patches:
                merged.add_patches({k: key_patches[k]}, strength_next, strength_self)

            cumulative += weight

        return (merged,)
|
|
|
|
|
|
|
|
class ModelSubtract:
    """Node that subtracts the ``diffusion_model.`` weights of one model from another."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: two MODEL sockets and a ``multiplier`` in [-10, 10]."""
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, multiplier):
        """Patch a clone of ``model1`` with the negated key patches of ``model2``.

        Every key patch of ``model2`` is added with ``-multiplier`` as the
        second argument of ``add_patches`` and ``1.0`` as the third.
        NOTE(review): presumably this yields
        ``model1 - multiplier * model2`` -- confirm against ``add_patches``.

        The clone and the key-patch lookup are performed exactly once; an
        earlier pass whose result was immediately overwritten has been
        removed because it never affected the returned model.

        Returns:
            A 1-tuple holding the patched clone.
        """
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, -multiplier, 1.0)
        return (m,)
|
|
|
|
class ModelAdd:
    """Node that adds the ``diffusion_model.`` key patches of one model onto another."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: two required MODEL sockets."""
        required = {name: ("MODEL",) for name in ("model1", "model2")}
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2):
        """Patch a clone of ``model1`` with every key patch of ``model2``.

        Each patch is passed to ``add_patches`` with ``1.0`` for both
        trailing arguments.

        Returns:
            A 1-tuple holding the patched clone.
        """
        result = model1.clone()
        for key, patch in model2.get_key_patches("diffusion_model.").items():
            result.add_patches({key: patch}, 1.0, 1.0)
        return (result,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPMergeSimple:
    """CLIP merge node; ``merge`` is a pass-through that returns ``clip1`` as is."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: two CLIP sockets and a ``ratio`` float in [0, 1]."""
        ratio_options = {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}
        required = {
            "clip1": ("CLIP",),
            "clip2": ("CLIP",),
            "ratio": ("FLOAT", ratio_options),
        }
        return {"required": required}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, ratio):
        """Return ``(clip1,)`` unchanged; ``clip2`` and ``ratio`` are not used."""
        result = (clip1,)
        return result
|
|
|
|
class CLIPSubtract:
    """CLIP subtract node; ``merge`` is a pass-through that returns ``clip1`` as is."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: two CLIP sockets and a ``multiplier`` in [-10, 10]."""
        multiplier_options = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        required = {
            "clip1": ("CLIP",),
            "clip2": ("CLIP",),
            "multiplier": ("FLOAT", multiplier_options),
        }
        return {"required": required}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, multiplier):
        """Return ``(clip1,)`` unchanged; ``clip2`` and ``multiplier`` are not used."""
        result = (clip1,)
        return result
|
|
|
|
class CLIPAdd:
    """CLIP add node; ``merge`` is a pass-through that returns ``clip1`` as is."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: two required CLIP sockets."""
        required = {name: ("CLIP",) for name in ("clip1", "clip2")}
        return {"required": required}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2):
        """Return ``(clip1,)`` unchanged; ``clip2`` is not used."""
        result = (clip1,)
        return result
|
|
|
|
class ModelMergeBlocks:
    """Block-wise merge node; ``merge`` is a pass-through that returns ``model1`` as is."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: two MODEL sockets plus ``input``/``middle``/``out`` floats."""
        required = {"model1": ("MODEL",), "model2": ("MODEL",)}
        # One independent options dict per float input, in declaration order.
        for name in ("input", "middle", "out"):
            required[name] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, **kwargs):
        """Return ``(model1,)`` unchanged; ``model2`` and the keyword ratios are not used."""
        result = (model1,)
        return result
|
|
|
|
def save_checkpoint(
    model,
    clip=None,
    vae=None,
    clip_vision=None,
    filename_prefix=None,
    output_dir=None,
    prompt=None,
    extra_pnginfo=None,
):
    """Placeholder: accept the checkpoint-saving arguments and do nothing.

    No argument is read and nothing is written; the function always
    returns ``None``.
    """
    return None
|
|
|
|
class CheckpointSave:
    """Output node for saving a checkpoint; ``save`` is currently a no-op stub."""

    def __init__(self):
        """Store the directory reported by ``folder_paths.get_output_directory()``."""
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: MODEL, CLIP, VAE, a filename prefix and hidden prompt data."""
        required = {
            "model": ("MODEL",),
            "clip": ("CLIP",),
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty dict; none of the arguments are used and nothing is written."""
        return dict()
|
|
|
|
class CLIPSave:
    """Output node for saving a CLIP model; ``save`` is currently a no-op stub."""

    def __init__(self):
        """Store the directory reported by ``folder_paths.get_output_directory()``."""
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: a CLIP socket, a filename prefix and hidden prompt data."""
        required = {
            "clip": ("CLIP",),
            "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty dict; none of the arguments are used and nothing is written."""
        return dict()
|
|
|
|
class VAESave:
    """Output node for saving a VAE; ``save`` is currently a no-op stub."""

    def __init__(self):
        """Store the directory reported by ``folder_paths.get_output_directory()``."""
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: a VAE socket, a filename prefix and hidden prompt data."""
        required = {
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty dict; none of the arguments are used and nothing is written."""
        return dict()
|
|
|
|
class ModelSave:
    """Output node for saving a diffusion model; ``save`` is currently a no-op stub."""

    def __init__(self):
        """Store the directory reported by ``folder_paths.get_output_directory()``."""
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: a MODEL socket, a filename prefix and hidden prompt data."""
        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty dict; none of the arguments are used and nothing is written."""
        return dict()
|
|
|
|
|
|
# Registry mapping each node identifier to the class implementing it.
# Note that the Subtract/Add identifiers carry a "Merge" infix that the class
# names themselves do not have.
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSimple=ModelMergeSimple,
    ModelMergeMultiSimple=ModelMergeMultiSimple,
    ModelMergeBlocks=ModelMergeBlocks,
    ModelMergeSubtract=ModelSubtract,
    ModelMergeAdd=ModelAdd,
    CheckpointSave=CheckpointSave,
    CLIPMergeSimple=CLIPMergeSimple,
    CLIPMergeSubtract=CLIPSubtract,
    CLIPMergeAdd=CLIPAdd,
    CLIPSave=CLIPSave,
    VAESave=VAESave,
    ModelSave=ModelSave,
)
|
|
|
|
# Human-readable title for each node identifier; the keys mirror those of
# NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    ModelMergeSimple="Model Merge Simple (2 Models)",
    ModelMergeMultiSimple="Model Merge Multi Simple (5 Models)",
    ModelMergeBlocks="Model Merge Blocks",
    ModelMergeSubtract="Model Subtract",
    ModelMergeAdd="Model Add",
    CheckpointSave="Save Checkpoint",
    CLIPMergeSimple="CLIP Merge Simple",
    CLIPMergeSubtract="CLIP Subtract",
    CLIPMergeAdd="CLIP Add",
    CLIPSave="CLIP Save",
    VAESave="VAE Save",
    ModelSave="Model Save",
)

# Emitted once, when this module is imported.
print("Custom model merging nodes loaded.")