# Credit to Acly for this module
# from https://github.com/Acly/comfyui-inpaint-nodes
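# Applies the Fooocus inpaint patch to a ComfyUI model: a small convolutional
# head whose output is added to the first UNet input block, plus a quantized
# weight patch merged into the model via a patched calculate_weight.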
import torch
import torch.nn.functional as F
import comfy.lora
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.model_management import cast_to_device

from .log import log_node_warn, log_node_error, log_node_info

# Inpaint
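# Keep a reference to the unpatched weight-calculation function so that
# calculate_weight_patched can delegate non-fooocus patches to it. Newer
# ComfyUI versions expose it as comfy.lora.calculate_weight, older ones as a
# method on ModelPatcher.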
if hasattr(comfy.lora, "calculate_weight"):
    original_calculate_weight = comfy.lora.calculate_weight
else:
    original_calculate_weight = ModelPatcher.calculate_weight
injected_model_patcher_calculate_weight = False

class InpaintHead(torch.nn.Module):
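    """Fooocus inpaint head: a single 5-to-320-channel 3x3 convolution.

    Input is the downsampled noise mask (1 channel) concatenated with the
    latent (4 channels); the output is added to the first UNet input block.
    """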
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    def forward(self, x):
        # Replicate-pad by one pixel so the 3x3 convolution keeps spatial size.
        x = F.pad(x, (1, 1, 1, 1), "replicate")
        return F.conv2d(x, weight=self.head)

def calculate_weight_patched(patches, weight, key, intermediate_type=torch.float32):
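    """Drop-in replacement for ComfyUI's calculate_weight that understands
    "fooocus" patches.

    A fooocus patch stores a weight offset quantized to the 0-255 range
    together with per-tensor (min, max) values; it is dequantized and added
    to the weight, scaled by alpha. All other patch types are forwarded to
    the original implementation.
    """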
    remaining = []

    for p in patches:
        alpha = p[0]
        v = p[1]

        is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
        if not is_fooocus_patch:
            remaining.append(p)
            continue

        if alpha != 0.0:
            v = v[1]
            w1 = cast_to_device(v[0], weight.device, torch.float32)
            if w1.shape == weight.shape:
                w_min = cast_to_device(v[1], weight.device, torch.float32)
                w_max = cast_to_device(v[2], weight.device, torch.float32)
                w1 = (w1 / 255.0) * (w_max - w_min) + w_min
                weight += alpha * cast_to_device(w1, weight.device, weight.dtype)
            else:
                # No node instance is in scope here, so warn under a generic
                # name instead of a per-node one.
                log_node_warn("InpaintWorker",
                    f"Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
                )

    # After the loop, delegate all non-fooocus patches to the original
    # implementation in one call; otherwise return the merged weight as-is.
    if len(remaining) > 0:
        return original_calculate_weight(remaining, weight, key, intermediate_type)
    return weight

def inject_patched_calculate_weight():
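    """Replace ComfyUI's weight calculation with the fooocus-aware version.

    Idempotent: the patch is injected at most once per process.
    """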
    global injected_model_patcher_calculate_weight
    if not injected_model_patcher_calculate_weight:
        print(
            "[comfyui-inpaint-nodes] Injecting fooocus-aware calculate_weight"
        )
        if hasattr(comfy.lora, "calculate_weight"):
            comfy.lora.calculate_weight = calculate_weight_patched
        else:
            ModelPatcher.calculate_weight = calculate_weight_patched
        injected_model_patcher_calculate_weight = True

class InpaintWorker:
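    """Applies the Fooocus inpaint head and LoRA patch to a ComfyUI model."""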
    def __init__(self, node_name):
        self.node_name = node_name if node_name is not None else ""

    def load_fooocus_patch(self, lora: dict, to_load: dict):
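        """Collect the "fooocus"-tagged patches for every model key in
        `to_load` that is present in the loaded patch file `lora`."""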
        patch_dict = {}
        loaded_keys = set()
        for key in to_load.values():
            if value := lora.get(key, None):
                patch_dict[key] = ("fooocus", value)
                loaded_keys.add(key)

        not_loaded = sum(1 for x in lora if x not in loaded_keys)
        if not_loaded > 0:
            log_node_info(self.node_name,
                f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
            )
        return patch_dict


    def patch(self, model, latent, patch):
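        """Apply a Fooocus inpaint patch (head model + LoRA) to `model`.

        Works on a clone, so the input model is left untouched. Returns a
        one-element tuple with the patched model, as ComfyUI nodes expect.
        """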
        base_model: BaseModel = model.model
        latent_pixels = base_model.process_latent_in(latent["samples"])
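        # Round the mask to {0, 1} and max-pool 8x8 down to latent resolution.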
        noise_mask = latent["noise_mask"].round()
        latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)

        inpaint_head_model, inpaint_lora = patch
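        # Head input: 1 mask channel + 4 latent channels = 5 channels.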
        feed = torch.cat([latent_mask, latent_pixels], dim=1)
        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        inpaint_head_feature = inpaint_head_model(feed)

        def input_block_patch(h, transformer_options):
            # Inject the head features into the first UNet input block only.
            if transformer_options["block"][1] == 0:
                h = h + inpaint_head_feature.to(h)
            return h

        # Patch keys may be LoRA key names or raw state-dict names, so build a
        # lookup that maps both onto the model's weight names.
        lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
        lora_keys.update({x: x for x in base_model.state_dict().keys()})
        loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)

        m = model.clone()
        m.set_model_input_block_patch(input_block_patch)
        patched = m.add_patches(loaded_lora, 1.0)

        not_patched_count = sum(1 for x in loaded_lora if x not in patched)
        if not_patched_count > 0:
            log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")

        inject_patched_calculate_weight()
        return (m,)
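
# Usage sketch (illustrative only; the variable names below are assumptions,
# not part of this module's API). A node would call InpaintWorker.patch with
# a ModelPatcher from a checkpoint loader, a latent dict holding "samples"
# and "noise_mask", and the (InpaintHead, lora_state_dict) pair loaded from a
# .fooocus.patch file:
#
#   worker = InpaintWorker(node_name="apply_fooocus_inpaint")
#   (patched_model,) = worker.patch(model, latent, (head_model, lora_dict))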