# File size: 15,056 Bytes
# 4432cc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import comfy.sd
import comfy.utils
import comfy.model_base
import comfy.model_management
import comfy.model_sampling

import torch
import folder_paths # Assuming this is available from the original context
import json
import os

# Use ComfyUI's parsed command-line options when they can be imported. When this
# module is loaded outside a full ComfyUI install (e.g. for standalone testing),
# fall back to a tiny stand-in object that only exposes ``disable_metadata``.
try:
    from comfy.cli_args import args
except ImportError:
    class ArgsMock:
        """Offline substitute for ``comfy.cli_args.args``."""

        # Without ComfyUI's CLI, treat metadata as enabled.
        disable_metadata = False

    args = ArgsMock()


class ModelMergeSimple:
    """Blend the diffusion-model weights of two models.

    Per key the result is ``model1 * ratio + model2 * (1 - ratio)``, i.e. ``ratio`` is
    the share kept from ``model1``. This is the convention of the ComfyUI core node
    that is registered under the same ``"ModelMergeSimple"`` id, so saved workflows keep
    their meaning, and the default ``ratio = 1.0`` returns ``model1`` unchanged.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, ratio):
        """Return ``(merged_model,)`` built on a clone of ``model1``.

        ``add_patches(patches, strength_patch, strength_model)`` scales the existing
        weight by ``strength_model`` and adds ``strength_patch`` times the incoming
        patch, so ``model2`` contributes ``1 - ratio`` and ``model1`` keeps ``ratio``.
        """
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            # NOTE: an earlier revision swapped these two strengths, which inverted
            # the meaning of `ratio` relative to the core node with the same id.
            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
        return (m, )

class ModelMergeMultiSimple:
    """Weighted average of up to five models' diffusion weights.

    ``ratio{i}`` is the relative weight of ``model{i}``. Slots whose ratio is 0, or that
    have no model connected, are ignored. The remaining ratios are normalized to sum to 1,
    so the result is ``sum(w_i * model_i)`` with ``w_i = ratio_i / sum(active ratios)``.
    ``model1`` is required; ``model2``..``model5`` are optional, so two to five models can
    be merged without wiring in dummy inputs.
    """

    MAX_MODELS = 5  # number of model/ratio slots exposed by the node

    @classmethod
    def INPUT_TYPES(s):
        inputs = {"required": {}, "optional": {}}
        for i in range(1, s.MAX_MODELS + 1):
            # model1 must be connected; the other model sockets may be left empty.
            inputs["required" if i == 1 else "optional"][f"model{i}"] = ("MODEL",)
            inputs["required"][f"ratio{i}"] = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return inputs

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge_five" # Changed function name to avoid conflict if in same file

    CATEGORY = "advanced/model_merging"

    @classmethod
    def _collect_inputs(cls, kwargs):
        """Return ``(model, ratio)`` pairs for every connected slot, in slot order.

        Model and ratio travel together in one tuple, so an empty slot can never shift
        later ratios onto the wrong model. A missing ratio counts as 0.0.
        """
        pairs = []
        for i in range(1, cls.MAX_MODELS + 1):
            model = kwargs.get(f"model{i}")
            ratio = kwargs.get(f"ratio{i}") or 0.0
            if model is None:
                if ratio > 0:
                    print(f"Warning: Ratio {ratio} provided for model{i} but model is missing. Ignoring.")
                continue
            pairs.append((model, ratio))
        return pairs

    def merge_five(self, **kwargs):
        """Merge the connected models and return ``(merged_model,)``.

        Raises:
            ValueError: if no model is connected at all.

        If every ratio is 0, a clone of the first connected model is returned unchanged.
        """
        provided = self._collect_inputs(kwargs)
        if not provided:
            raise ValueError("No models provided for merging.")

        active = [(model, ratio) for model, ratio in provided if ratio > 0]
        if not active:
            print("Warning: All model ratios are 0. Returning the first provided model without changes.")
            return (provided[0][0].clone(), )

        # Every active ratio is > 0, so total > 0 and the divisions below are safe.
        total = sum(ratio for _, ratio in active)

        first_model, first_ratio = active[0]
        merged_model = first_model.clone()
        # Normalized weight already represented by merged_model.
        merged_weight = first_ratio / total

        for model, ratio in active[1:]:
            weight = ratio / total
            combined = merged_weight + weight
            # Rescale what is already merged by W/(W+w) and add the next model at
            # w/(W+w). merged_model then stays the normalized weighted average of the
            # models folded in so far, and after the last model W == 1. This assumes
            # the patcher applies patches in the order they are added.
            strength_next = weight / combined
            strength_self = merged_weight / combined
            key_patches = model.get_key_patches("diffusion_model.")
            for k in key_patches:
                merged_model.add_patches({k: key_patches[k]}, strength_next, strength_self)
            merged_weight = combined

        return (merged_model,)


# --- Other classes from your provided code for completeness ---
class ModelSubtract:
    """Scaled difference of two models' diffusion weights.

    Per key the result is ``(model1 - model2) * multiplier``. This is the convention of
    the ComfyUI core node registered under the same ``"ModelMergeSubtract"`` id. Feeding
    the output into ``ModelAdd`` (which computes ``model1 + model2``) gives the usual
    "add difference" merge ``A + (B - C) * multiplier``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model1": ("MODEL",),
                              "model2": ("MODEL",),
                              "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, multiplier):
        """Return ``(model1 * multiplier - model2 * multiplier,)`` as a patched clone of model1.

        ``add_patches(patches, strength_patch, strength_model)`` scales the existing
        weight by ``strength_model`` and adds ``strength_patch`` times the patch.
        """
        # NOTE: an earlier revision patched a clone, discarded it, and then re-patched
        # with (-multiplier, 1.0), i.e. model1 - model2 * multiplier. That broke the
        # add-difference recipe whenever multiplier != 1.
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, -multiplier, multiplier)
        return (m, )

class ModelAdd:
    """Sum two models' diffusion weights: ``model1 + model2`` for every shared key."""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2):
        """Return ``(result,)`` where ``result`` is a clone of model1 patched with model2.

        Both the existing weight and the incoming patch are applied at strength 1.0.
        """
        result = model1.clone()
        incoming = model2.get_key_patches("diffusion_model.")
        for key, patch in incoming.items():
            result.add_patches({key: patch}, 1.0, 1.0)
        return (result,)

# The remaining classes (CLIPMergeSimple, CLIPSubtract, CLIPAdd, ModelMergeBlocks,
# save_checkpoint, CheckpointSave, etc.) must also be defined in this module, because
# NODE_CLASS_MAPPINGS at the end of the file references them. This applies both when the
# file is run as a standalone script for testing and when ComfyUI loads it; ComfyUI
# discovers the nodes through that mapping.

# Every node class, including the new ModelMergeMultiSimple, must be registered in
# NODE_CLASS_MAPPINGS. For reference, the original mappings were:
# NODE_CLASS_MAPPINGS = {
#     "ModelMergeSimple": ModelMergeSimple,
#     "ModelMergeBlocks": ModelMergeBlocks,
#     "ModelMergeSubtract": ModelSubtract,
#     "ModelMergeAdd": ModelAdd,
#     # ... other mappings
# }

# NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS (including the new node) are
# defined at the end of the file.
# Placeholder for the rest of the code structure
class CLIPMergeSimple:
    @classmethod
    def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",),"clip2": ("CLIP",),"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"
    def merge(self, clip1, clip2, ratio): return (clip1, ) # Placeholder

class CLIPSubtract:
    """Scaled difference of two CLIP text encoders: ``(clip1 - clip2) * multiplier`` per key.

    Keys ending in ``.position_ids`` or ``.logit_scale`` are left as they are in
    ``clip1``, as ComfyUI core does.
    """

    # Key suffixes copied from clip1 untouched (see class docstring).
    SKIP_KEY_SUFFIXES = (".position_ids", ".logit_scale")

    @classmethod
    def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",),"clip2": ("CLIP",),"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),}}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, multiplier):
        """Return ``(clip1 * multiplier - clip2 * multiplier,)`` as a patched clone of clip1."""
        m = clip1.clone()
        kp = clip2.get_key_patches()
        for k in kp:
            if k.endswith(self.SKIP_KEY_SUFFIXES):
                continue
            m.add_patches({k: kp[k]}, -multiplier, multiplier)
        return (m, )

class CLIPAdd:
    """Sum two CLIP text encoders: ``clip1 + clip2`` per key.

    Keys ending in ``.position_ids`` or ``.logit_scale`` are left as they are in
    ``clip1``, as ComfyUI core does.
    """

    # Key suffixes copied from clip1 untouched (see class docstring).
    SKIP_KEY_SUFFIXES = (".position_ids", ".logit_scale")

    @classmethod
    def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",),"clip2": ("CLIP",),}}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2):
        """Return ``(clip1 + clip2,)`` as a patched clone of clip1."""
        m = clip1.clone()
        kp = clip2.get_key_patches()
        for k in kp:
            if k.endswith(self.SKIP_KEY_SUFFIXES):
                continue
            m.add_patches({k: kp[k]}, 1.0, 1.0)
        return (m, )

class ModelMergeBlocks:
    """Merge two models with a separate ratio per block group.

    Each keyword ratio is used as a key prefix, matched against the diffusion-model key
    with ``"diffusion_model."`` stripped; the longest matching prefix wins. Assuming an
    SD-style U-Net key layout, ``input`` matches ``input_blocks.*``, ``middle`` matches
    ``middle_block.*``, and ``out`` matches both ``output_blocks.*`` and ``out.*``.
    Keys that match no prefix use the first ratio supplied (normally ``input``).
    Per key the result is ``model1 * ratio + model2 * (1 - ratio)``, the same
    convention ComfyUI core uses for ModelMergeSimple.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, **kwargs):
        """Return ``(merged_model,)``; ``kwargs`` maps key prefixes to model1 ratios."""
        prefix = "diffusion_model."
        m = model1.clone()
        kp = model2.get_key_patches(prefix)
        # Fallback ratio for unmatched keys; 1.0 (keep model1) if no ratios were given.
        default_ratio = next(iter(kwargs.values()), 1.0)

        for k in kp:
            k_unet = k[len(prefix):]
            ratio = default_ratio
            best_len = 0
            for block_prefix, block_ratio in kwargs.items():
                # Longest matching prefix wins, so more specific names override shorter ones.
                if k_unet.startswith(block_prefix) and len(block_prefix) > best_len:
                    ratio = block_ratio
                    best_len = len(block_prefix)
            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
        return (m, )

def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None,
                    output_dir=None, prompt=None, extra_pnginfo=None):
    """Placeholder for writing a merged checkpoint to disk.

    The signature is kept so callers can be wired up. The body is a stub: every argument
    is ignored, nothing is written, and ``None`` is returned.
    """
    return None

class CheckpointSave:
    """Output node intended to save model, CLIP and VAE together as a checkpoint.

    NOTE: ``save`` is currently a placeholder. It writes nothing and returns an empty dict.
    """

    def __init__(self):
        # Resolved once per node instance; not used by the placeholder ``save``.
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "clip": ("CLIP",),
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Placeholder: accept the inputs and report no UI output."""
        return {}

class CLIPSave:
    """Output node intended to save a CLIP text encoder.

    NOTE: ``save`` is currently a placeholder. It writes nothing and returns an empty dict.
    """

    def __init__(self):
        # Resolved once per node instance; not used by the placeholder ``save``.
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "clip": ("CLIP",),
            "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
        """Placeholder: accept the inputs and report no UI output."""
        return {}

class VAESave:
    """Output node intended to save a VAE.

    NOTE: ``save`` is currently a placeholder. It writes nothing and returns an empty dict.
    """

    def __init__(self):
        # Resolved once per node instance; not used by the placeholder ``save``.
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Placeholder: accept the inputs and report no UI output."""
        return {}

class ModelSave:
    """Output node intended to save a diffusion model on its own.

    NOTE: ``save`` is currently a placeholder. It writes nothing and returns an empty dict.
    """

    def __init__(self):
        # Resolved once per node instance; not used by the placeholder ``save``.
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
        """Placeholder: accept the inputs and report no UI output."""
        return {}


# Node id -> implementing class; ComfyUI discovers this module's nodes through this
# mapping. Keyword order is preserved (PEP 468), so entries keep the order shown.
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSimple=ModelMergeSimple,
    ModelMergeMultiSimple=ModelMergeMultiSimple,  # up-to-five-model weighted merge
    ModelMergeBlocks=ModelMergeBlocks,
    ModelMergeSubtract=ModelSubtract,
    ModelMergeAdd=ModelAdd,
    CheckpointSave=CheckpointSave,
    CLIPMergeSimple=CLIPMergeSimple,
    CLIPMergeSubtract=CLIPSubtract,
    CLIPMergeAdd=CLIPAdd,
    CLIPSave=CLIPSave,
    VAESave=VAESave,
    ModelSave=ModelSave,
)

# Node id -> human-readable title shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    ModelMergeSimple="Model Merge Simple (2 Models)",
    ModelMergeMultiSimple="Model Merge Multi Simple (5 Models)",
    ModelMergeBlocks="Model Merge Blocks",
    ModelMergeSubtract="Model Subtract",
    ModelMergeAdd="Model Add",
    CheckpointSave="Save Checkpoint",
    CLIPMergeSimple="CLIP Merge Simple",
    CLIPMergeSubtract="CLIP Subtract",
    CLIPMergeAdd="CLIP Add",
    CLIPSave="CLIP Save",
    VAESave="VAE Save",
    ModelSave="Model Save",
)

# Import-time confirmation that the module (and its mappings) loaded.
print("Custom model merging nodes loaded.")