pantat88 committed
Commit
dc5af96
1 Parent(s): 7493ff4

Upload controlnet.py

Files changed (1)
controlnet.py +1121 -0
controlnet.py ADDED
@@ -0,0 +1,1121 @@
import gc
import os
import logging
import re
from collections import OrderedDict
from copy import copy
from typing import Dict, Optional, Tuple
import modules.scripts as scripts
from modules import shared, devices, script_callbacks, processing, masking, images
import gradio as gr
import time


from einops import rearrange
from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version, utils
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.processor import *
from scripts.adapter import Adapter, StyleAdapter, Adapter_light
from scripts.controlnet_lllite import PlugableControlLLLite, clear_all_lllite
from scripts.controlmodel_ipadapter import PlugableIPAdapter, clear_all_ip_adapter
from scripts.utils import load_state_dict, get_unique_axis0
from scripts.hook import ControlParams, UnetHook, ControlModelType, HackedImageRNG
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img
from modules.images import save_image
from scripts.infotext import Infotext

import cv2
import numpy as np
import torch

from pathlib import Path
from PIL import Image, ImageFilter, ImageOps
from scripts.lvminthin import lvmin_thin, nake_nms
from scripts.processor import model_free_preprocessors
from scripts.controlnet_model_guess import build_model_by_guess


gradio_compat = True
try:
    from distutils.version import LooseVersion
    from importlib_metadata import version
    if LooseVersion(version("gradio")) < LooseVersion("3.10"):
        gradio_compat = False
except ImportError:
    pass


# Gradio 3.32 bug fix
import tempfile
gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_tempfile_path, exist_ok=True)


def clear_all_secondary_control_models():
    clear_all_lllite()
    clear_all_ip_adapter()


def find_closest_lora_model_name(search: str):
    if not search:
        return None
    if search in global_state.cn_models:
        return search
    search = search.lower()
    if search in global_state.cn_models_names:
        return global_state.cn_models_names.get(search)
    applicable = [name for name in global_state.cn_models_names.keys()
                  if search in name.lower()]
    if not applicable:
        return None
    applicable = sorted(applicable, key=lambda name: len(name))
    return global_state.cn_models_names[applicable[0]]
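# Illustrative lookups (a sketch; the model names are hypothetical and assume
# global_state.update_cn_models() has already populated the registries):
#   find_closest_lora_model_name("canny")
#       -> shortest registered name containing "canny", e.g. "control_sd15_canny [fef5e48e]"
#   find_closest_lora_model_name("no-such-model")
#       -> None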


def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img):
    p.__class__ = processing.StableDiffusionProcessingTxt2Img
    dummy = processing.StableDiffusionProcessingTxt2Img()
    for k, v in dummy.__dict__.items():
        if hasattr(p, k):
            continue
        setattr(p, k, v)


global_state.update_cn_models()


def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
    if image is None:
        return None

    if isinstance(image, (tuple, list)):
        image = {'image': image[0], 'mask': image[1]}
    elif not isinstance(image, dict):
        image = {'image': image, 'mask': None}
    else:  # type(image) is dict
        # copy to enable modifying the dict and prevent response serialization error
        image = dict(image)

    if isinstance(image['image'], str):
        if os.path.exists(image['image']):
            image['image'] = np.array(Image.open(image['image'])).astype('uint8')
        elif image['image']:
            image['image'] = external_code.to_base64_nparray(image['image'])
        else:
            image['image'] = None

    # If there is no image, return image with None image and None mask
    if image['image'] is None:
        image['mask'] = None
        return image

    if isinstance(image['mask'], str):
        if os.path.exists(image['mask']):
            image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
        elif image['mask']:
            image['mask'] = external_code.to_base64_nparray(image['mask'])
        else:
            image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
    elif image['mask'] is None:
        image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)

    return image
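# Inputs this helper normalizes, sketched (hypothetical ndarray values):
#   image_dict_from_any((rgb, mask))        -> {'image': rgb, 'mask': mask}
#   image_dict_from_any(rgb)                -> {'image': rgb, 'mask': zeros_like(rgb)}
#   image_dict_from_any("/path/to/img.png") -> loads the file; a non-path string
#   is decoded as base64. A missing mask becomes an all-zero array.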


def image_has_mask(input_image: np.ndarray) -> bool:
    """
    Determine if an image has an alpha channel (mask) that is not empty.

    The function checks if the input image has three dimensions (height, width, channels),
    and if the third dimension (channel dimension) is of size 4 (presumably RGB + alpha).
    Then it checks if the maximum value in the alpha channel is greater than 127. This is
    presumably to check if there is any non-transparent (or semi-transparent) pixel in the
    image. A pixel is considered non-transparent if its alpha value is above 127.

    Args:
        input_image (np.ndarray): A 3D numpy array representing an image. The dimensions
        should represent [height, width, channels].

    Returns:
        bool: True if the image has a non-empty alpha channel, False otherwise.
    """
    return (
        input_image.ndim == 3 and
        input_image.shape[2] == 4 and
        np.max(input_image[:, :, 3]) > 127
    )
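# Quick check, sketched (hypothetical arrays):
#   rgba = np.zeros((64, 64, 4), dtype=np.uint8)
#   image_has_mask(rgba)   # False: alpha channel is all zero
#   rgba[0, 0, 3] = 255
#   image_has_mask(rgba)   # True: at least one alpha value exceeds 127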


def prepare_mask(
    mask: Image.Image, p: processing.StableDiffusionProcessing
) -> Image.Image:
    """
    Prepare an image mask for the inpainting process.

    This function takes as input a PIL Image object and an instance of the
    StableDiffusionProcessing class, and performs the following steps to prepare the mask:

    1. Convert the mask to grayscale (mode "L").
    2. If the 'inpainting_mask_invert' attribute of the processing instance is True,
       invert the mask colors.
    3. If the 'mask_blur' attribute of the processing instance is greater than 0,
       apply a Gaussian blur to the mask with a radius equal to 'mask_blur'.

    Args:
        mask (Image.Image): The input mask as a PIL Image object.
        p (processing.StableDiffusionProcessing): An instance of the StableDiffusionProcessing class
        containing the processing parameters.

    Returns:
        mask (Image.Image): The prepared mask as a PIL Image object.
    """
    mask = mask.convert("L")
    if getattr(p, "inpainting_mask_invert", False):
        mask = ImageOps.invert(mask)

    if hasattr(p, 'mask_blur_x'):
        if getattr(p, "mask_blur_x", 0) > 0:
            np_mask = np.array(mask)
            kernel_size = 2 * int(2.5 * p.mask_blur_x + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), p.mask_blur_x)
            mask = Image.fromarray(np_mask)
        if getattr(p, "mask_blur_y", 0) > 0:
            np_mask = np.array(mask)
            kernel_size = 2 * int(2.5 * p.mask_blur_y + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), p.mask_blur_y)
            mask = Image.fromarray(np_mask)
    else:
        if getattr(p, "mask_blur", 0) > 0:
            mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))

    return mask
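# Worked example of the kernel-size formula above (assuming mask_blur_x = 4):
#   kernel_size = 2 * int(2.5 * 4 + 0.5) + 1 = 2 * 10 + 1 = 21
# The result is always odd, which cv2.GaussianBlur requires for its kernel.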


def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]:
    """
    Set the random seed for NumPy based on the provided parameters.

    Args:
        p (processing.StableDiffusionProcessing): The instance of the StableDiffusionProcessing class.

    Returns:
        Optional[int]: The computed random seed if successful, or None if an exception occurs.

    This function sets the random seed for NumPy using the seed and subseed values from the given instance of
    StableDiffusionProcessing. If either seed or subseed is -1, it uses the first value from `all_seeds`.
    Otherwise, it takes the maximum of the provided seed value and 0.

    The final random seed is computed by adding the seed and subseed values, applying a bitwise AND operation
    with 0xFFFFFFFF to ensure it fits within a 32-bit integer.
    """
    try:
        tmp_seed = int(p.all_seeds[0] if p.seed == -1 else max(int(p.seed), 0))
        tmp_subseed = int(p.all_seeds[0] if p.subseed == -1 else max(int(p.subseed), 0))
        seed = (tmp_seed + tmp_subseed) & 0xFFFFFFFF
        np.random.seed(seed)
        return seed
    except Exception as e:
        logger.warning(e)
        logger.warning('Warning: Failed to use consistent random seed.')
        return None
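# Worked example of the seed arithmetic (hypothetical values):
#   p.seed = 3, p.subseed = 7  ->  np.random.seed((3 + 7) & 0xFFFFFFFF), i.e. seed 10
# Note that both fallbacks read p.all_seeds[0]: a subseed of -1 does not fall
# back to a separate all_subseeds list in the code above.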


class Script(scripts.Script, metaclass=(
        utils.TimeMeta if logger.level == logging.DEBUG else type)):

    model_cache = OrderedDict()

    def __init__(self) -> None:
        super().__init__()
        self.latest_network = None
        self.preprocessor = global_state.cache_preprocessors(global_state.cn_preprocessor_modules)
        self.unloadable = global_state.cn_preprocessor_unloadable
        self.input_image = None
        self.latest_model_hash = ""
        self.enabled_units = []
        self.detected_map = []
        self.post_processors = []
        self.noise_modifier = None
        batch_hijack.instance.process_batch_callbacks.append(self.batch_tab_process)
        batch_hijack.instance.process_batch_each_callbacks.append(self.batch_tab_process_each)
        batch_hijack.instance.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each)
        batch_hijack.instance.postprocess_batch_callbacks.insert(0, self.batch_tab_postprocess)

    def title(self):
        return "ControlNet"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    @staticmethod
    def get_default_ui_unit(is_ui=True):
        cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit
        return cls(
            enabled=False,
            module="none",
            model="None"
        )

    def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str) -> Tuple[ControlNetUiGroup, gr.State]:
        group = ControlNetUiGroup(
            gradio_compat,
            Script.get_default_ui_unit(),
            self.preprocessor,
        )
        group.render(tabname, elem_id_tabname, is_img2img)
        group.register_callbacks(is_img2img)
        return group, group.render_and_register_unit(tabname, is_img2img)

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """
        infotext = Infotext()

        controls = ()
        max_models = shared.opts.data.get("control_net_unit_count", 3)
        elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet"
        with gr.Group(elem_id=elem_id_tabname):
            with gr.Accordion(f"ControlNet {controlnet_version.version_flag}", open=False, elem_id="controlnet"):
                if max_models > 1:
                    with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"):
                        for i in range(max_models):
                            with gr.Tab(f"ControlNet Unit {i}",
                                        elem_classes=['cnet-unit-tab']):
                                group, state = self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname)
                                infotext.register_unit(i, group)
                                controls += (state,)
                else:
                    with gr.Column():
                        group, state = self.uigroup("ControlNet", is_img2img, elem_id_tabname)
                        infotext.register_unit(0, group)
                        controls += (state,)

        if shared.opts.data.get("control_net_sync_field_args", True):
            self.infotext_fields = infotext.infotext_fields
            self.paste_field_names = infotext.paste_field_names

        return controls

    @staticmethod
    def clear_control_model_cache():
        Script.model_cache.clear()
        gc.collect()
        devices.torch_gc()

    @staticmethod
    def load_control_model(p, unet, model):
        if model in Script.model_cache:
            logger.info(f"Loading model from cache: {model}")
            return Script.model_cache[model]

        # Remove model from cache to clear space before building another model
        if len(Script.model_cache) > 0 and len(Script.model_cache) >= shared.opts.data.get("control_net_model_cache_size", 2):
            Script.model_cache.popitem(last=False)
            gc.collect()
            devices.torch_gc()

        model_net = Script.build_control_model(p, unet, model)

        if shared.opts.data.get("control_net_model_cache_size", 2) > 0:
            Script.model_cache[model] = model_net

        return model_net
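    # Cache sketch: model_cache is an OrderedDict treated as a FIFO; with a
    # cache size of 2, loading hypothetical models A, B, then C evicts A via
    # popitem(last=False) before C is built.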

    @staticmethod
    def build_control_model(p, unet, model):
        if model is None or model == 'None':
            raise RuntimeError("You have not selected any ControlNet Model.")

        model_path = global_state.cn_models.get(model, None)
        if model_path is None:
            model = find_closest_lora_model_name(model)
            model_path = global_state.cn_models.get(model, None)

        if model_path is None:
            raise RuntimeError(f"model not found: {model}")

        # trim '"' at start/end
        if model_path.startswith("\"") and model_path.endswith("\""):
            model_path = model_path[1:-1]

        if not os.path.exists(model_path):
            raise ValueError(f"file not found: {model_path}")

        logger.info(f"Loading model: {model}")
        state_dict = load_state_dict(model_path)
        network = build_model_by_guess(state_dict, unet, model_path)
        network.to('cpu', dtype=p.sd_model.dtype)
        logger.info(f"ControlNet model {model} loaded.")
        return network

    @staticmethod
    def get_remote_call(p, attribute, default=None, idx=0, strict=False, force=False):
        if not force and not shared.opts.data.get("control_net_allow_script_control", False):
            return default

        def get_element(obj, strict=False):
            if not isinstance(obj, list):
                return obj if not strict or idx == 0 else None
            elif idx < len(obj):
                return obj[idx]
            else:
                return None

        attribute_value = get_element(getattr(p, attribute, None), strict)
        default_value = get_element(default)
        return attribute_value if attribute_value is not None else default_value

    @staticmethod
    def parse_remote_call(p, unit: external_code.ControlNetUnit, idx):
        selector = Script.get_remote_call

        unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True)
        unit.module = selector(p, "control_net_module", unit.module, idx)
        unit.model = selector(p, "control_net_model", unit.model, idx)
        unit.weight = selector(p, "control_net_weight", unit.weight, idx)
        unit.image = selector(p, "control_net_image", unit.image, idx)
        unit.resize_mode = selector(p, "control_net_resize_mode", unit.resize_mode, idx)
        unit.low_vram = selector(p, "control_net_lowvram", unit.low_vram, idx)
        unit.processor_res = selector(p, "control_net_pres", unit.processor_res, idx)
        unit.threshold_a = selector(p, "control_net_pthr_a", unit.threshold_a, idx)
        unit.threshold_b = selector(p, "control_net_pthr_b", unit.threshold_b, idx)
        unit.guidance_start = selector(p, "control_net_guidance_start", unit.guidance_start, idx)
        unit.guidance_end = selector(p, "control_net_guidance_end", unit.guidance_end, idx)
        # Backward compatibility. See https://github.com/Mikubill/sd-webui-controlnet/issues/1740
        # for more details.
        unit.guidance_end = selector(p, "control_net_guidance_strength", unit.guidance_end, idx)
        unit.control_mode = selector(p, "control_net_control_mode", unit.control_mode, idx)
        unit.pixel_perfect = selector(p, "control_net_pixel_perfect", unit.pixel_perfect, idx)

        return unit
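    # Remote-control sketch (works only when the "Allow other script to control
    # this extension" option is on; keys below are the real field names used
    # above, values are hypothetical):
    #   p.control_net_enabled = [True, False]  # list -> per-unit, scalar -> all units
    #   p.control_net_weight = 0.8
    # parse_remote_call(p, unit, idx) then overwrites the matching unit fields.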

    @staticmethod
    def detectmap_proc(detected_map, module, resize_mode, h, w):

        if 'inpaint' in module:
            detected_map = detected_map.astype(np.float32)
        else:
            detected_map = HWC3(detected_map)

        def safe_numpy(x):
            # A very safe method to make sure that Apple/Mac works
            y = x

            # below is very boring but do not change these. If you change these Apple or Mac may fail.
            y = y.copy()
            y = np.ascontiguousarray(y)
            y = y.copy()
            return y

        def get_pytorch_control(x):
            # A very safe method to make sure that Apple/Mac works
            y = x

            # below is very boring but do not change these. If you change these Apple or Mac may fail.
            y = torch.from_numpy(y)
            y = y.float() / 255.0
            y = rearrange(y, 'h w c -> 1 c h w')
            y = y.clone()
            y = y.to(devices.get_device_for("controlnet"))
            y = y.clone()
            return y

        def high_quality_resize(x, size):
            # Written by lvmin
            # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges

            inpaint_mask = None
            if x.ndim == 3 and x.shape[2] == 4:
                inpaint_mask = x[:, :, 3]
                x = x[:, :, 0:3]

            if x.shape[0] != size[1] or x.shape[1] != size[0]:
                new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
                new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
                unique_color_count = len(get_unique_axis0(x.reshape(-1, x.shape[2])))
                is_one_pixel_edge = False
                is_binary = False
                if unique_color_count == 2:
                    is_binary = np.min(x) < 16 and np.max(x) > 240
                    if is_binary:
                        xc = x
                        xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
                        xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
                        one_pixel_edge_count = np.where(xc < x)[0].shape[0]
                        all_edge_count = np.where(x > 127)[0].shape[0]
                        is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count

                if 2 < unique_color_count < 200:
                    interpolation = cv2.INTER_NEAREST
                elif new_size_is_smaller:
                    interpolation = cv2.INTER_AREA
                else:
                    interpolation = cv2.INTER_CUBIC  # Must be CUBIC because we now use nms. NEVER CHANGE THIS

                y = cv2.resize(x, size, interpolation=interpolation)
                if inpaint_mask is not None:
                    inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)

                if is_binary:
                    y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
                    if is_one_pixel_edge:
                        y = nake_nms(y)
                        _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                        y = lvmin_thin(y, prunings=new_size_is_bigger)
                    else:
                        _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                    y = np.stack([y] * 3, axis=2)
            else:
                y = x

            if inpaint_mask is not None:
                inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
                inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
                y = np.concatenate([y, inpaint_mask], axis=2)

            return y

        if resize_mode == external_code.ResizeMode.RESIZE:
            detected_map = high_quality_resize(detected_map, (w, h))
            detected_map = safe_numpy(detected_map)
            return get_pytorch_control(detected_map), detected_map

        old_h, old_w, _ = detected_map.shape
        old_w = float(old_w)
        old_h = float(old_h)
        k0 = float(h) / old_h
        k1 = float(w) / old_w

        safeint = lambda x: int(np.round(x))

        if resize_mode == external_code.ResizeMode.OUTER_FIT:
            k = min(k0, k1)
            borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
            high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
            if len(high_quality_border_color) == 4:
                # Inpaint hijack
                high_quality_border_color[3] = 255
            high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
            detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
            new_h, new_w, _ = detected_map.shape
            pad_h = max(0, (h - new_h) // 2)
            pad_w = max(0, (w - new_w) // 2)
            high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = detected_map
            detected_map = high_quality_background
            detected_map = safe_numpy(detected_map)
            return get_pytorch_control(detected_map), detected_map
        else:
            k = max(k0, k1)
            detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
            new_h, new_w, _ = detected_map.shape
            pad_h = max(0, (new_h - h) // 2)
            pad_w = max(0, (new_w - w) // 2)
            detected_map = detected_map[pad_h:pad_h + h, pad_w:pad_w + w]
            detected_map = safe_numpy(detected_map)
            return get_pytorch_control(detected_map), detected_map
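    # Resize-mode semantics of detectmap_proc, sketched from the branches above:
    #   RESIZE    - stretch the map to exactly (w, h)
    #   OUTER_FIT - scale by min(k0, k1) so the map fits inside (w, h), then pad
    #               with the median border color
    #   otherwise - scale by max(k0, k1) so the map covers (w, h), then center-crop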

    @staticmethod
    def get_enabled_units(p):
        units = external_code.get_all_units_in_processing(p)
        if len(units) == 0:
            # fill a null group
            remote_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0)
            if remote_unit.enabled:
                units.append(remote_unit)

        enabled_units = [
            copy(local_unit)
            for idx, unit in enumerate(units)
            for local_unit in (Script.parse_remote_call(p, unit, idx),)
            if local_unit.enabled
        ]
        Infotext.write_infotext(enabled_units, p)
        return enabled_units

    @staticmethod
    def choose_input_image(
            p: processing.StableDiffusionProcessing,
            unit: external_code.ControlNetUnit,
            idx: int
    ) -> Tuple[np.ndarray, bool]:
        """ Choose input image from following sources with descending priority:
         - p.image_control: [Deprecated] Legacy way to pass image to controlnet.
         - p.control_net_input_image: [Deprecated] Legacy way to pass image to controlnet.
         - unit.image:
           - ControlNet tab input image.
           - Input image from API call.
         - p.init_images: A1111 img2img tab input image.

        Returns:
            - The input image in ndarray form.
            - Whether input image is from A1111.
        """
        image_from_a1111 = False

        p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
        image = image_dict_from_any(unit.image)

        if batch_hijack.instance.is_batch and getattr(p, "image_control", None) is not None:
            logger.warning("Warn: Using legacy field 'p.image_control'.")
            input_image = HWC3(np.asarray(p.image_control))
        elif p_input_image is not None:
            logger.warning("Warn: Using legacy field 'p.control_net_input_image'.")
            if isinstance(p_input_image, dict) and "mask" in p_input_image and "image" in p_input_image:
                color = HWC3(np.asarray(p_input_image['image']))
                alpha = np.asarray(p_input_image['mask'])[..., None]
                input_image = np.concatenate([color, alpha], axis=2)
            else:
                input_image = HWC3(np.asarray(p_input_image))
        elif image is not None:
            while len(image['mask'].shape) < 3:
                image['mask'] = image['mask'][..., np.newaxis]

            # Need to check the image for API compatibility
            if isinstance(image['image'], str):
                from modules.api.api import decode_base64_to_image
                input_image = HWC3(np.asarray(decode_base64_to_image(image['image'])))
            else:
                input_image = HWC3(image['image'])

            have_mask = 'mask' in image and not (
                (image['mask'][:, :, 0] <= 5).all() or
                (image['mask'][:, :, 0] >= 250).all()
            )

            if 'inpaint' in unit.module:
                logger.info("using inpaint as input")
                color = HWC3(image['image'])
                if have_mask:
                    alpha = image['mask'][:, :, 0:1]
                else:
                    alpha = np.zeros_like(color)[:, :, 0:1]
                input_image = np.concatenate([color, alpha], axis=2)
            else:
                if have_mask and not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False):
                    logger.info("using mask as input")
                    input_image = HWC3(image['mask'][:, :, 0])
                    unit.module = 'none'  # Always use black bg and white line
        else:
            # use img2img init_image as default
            input_image = getattr(p, "init_images", [None])[0]
            if input_image is None:
                if batch_hijack.instance.is_batch:
                    shared.state.interrupted = True
                raise ValueError('controlnet is enabled but no input image is given')

            input_image = HWC3(np.asarray(input_image))
            image_from_a1111 = True

        assert isinstance(input_image, np.ndarray)
        return input_image, image_from_a1111

    @staticmethod
    def bound_check_params(unit: external_code.ControlNetUnit) -> None:
        """
        Checks and corrects negative parameters in ControlNetUnit 'unit'.
        Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
        their default values if negative.

        Args:
            unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
        """
        cfg = preprocessor_sliders_config.get(
            global_state.get_module_basename(unit.module), [])
        defaults = {
            param: cfg_default['value']
            for param, cfg_default in zip(
                ("processor_res", "threshold_a", "threshold_b"), cfg)
            if cfg_default is not None
        }
        for param, default_value in defaults.items():
            value = getattr(unit, param)
            if value < 0:
                setattr(unit, param, default_value)
                logger.warning(f'[{unit.module}.{param}] Invalid value({value}), using default value {default_value}.')
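    # Illustrative correction (hypothetical unit): with unit.module = 'canny' and
    # unit.processor_res = -1, bound_check_params resets processor_res to the
    # slider default from preprocessor_sliders_config and logs a warning.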

    def controlnet_main_entry(self, p):
        sd_ldm = p.sd_model
        unet = sd_ldm.model.diffusion_model
        self.noise_modifier = None

        setattr(p, 'controlnet_control_loras', [])

        if self.latest_network is not None:
            # always restore (~0.05s)
            self.latest_network.restore()

        # always clear (~0.05s)
        clear_all_secondary_control_models()

        if not batch_hijack.instance.is_batch:
            self.enabled_units = Script.get_enabled_units(p)

        if len(self.enabled_units) == 0:
            self.latest_network = None
            return

        detected_maps = []
        forward_params = []
        post_processors = []

        # cache stuff
        if self.latest_model_hash != p.sd_model.sd_model_hash:
            Script.clear_control_model_cache()

        for idx, unit in enumerate(self.enabled_units):
            unit.module = global_state.get_module_basename(unit.module)

        # unload unused preproc
        module_list = [unit.module for unit in self.enabled_units]
        for key in self.unloadable:
            if key not in module_list:
                self.unloadable.get(key, lambda: None)()

        self.latest_model_hash = p.sd_model.sd_model_hash
        for idx, unit in enumerate(self.enabled_units):
            Script.bound_check_params(unit)

            resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
            control_mode = external_code.control_mode_from_value(unit.control_mode)

            if unit.module in model_free_preprocessors:
                model_net = None
            else:
                model_net = Script.load_control_model(p, unet, unit.model)
                model_net.reset()

                if getattr(model_net, 'is_control_lora', False):
                    control_lora = model_net.control_model
                    bind_control_lora(unet, control_lora)
                    p.controlnet_control_loras.append(control_lora)

            input_image, image_from_a1111 = Script.choose_input_image(p, unit, idx)
            if image_from_a1111:
                a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
                if a1111_i2i_resize_mode is not None:
                    resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

            a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
            if 'inpaint' in unit.module and not image_has_mask(input_image) and a1111_mask_image is not None:
                a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
                if a1111_mask.ndim == 2:
                    if a1111_mask.shape[0] == input_image.shape[0]:
                        if a1111_mask.shape[1] == input_image.shape[1]:
                            input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
                            a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
                            if a1111_i2i_resize_mode is not None:
                                resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

            if 'reference' not in unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) \
                    and p.inpaint_full_res and a1111_mask_image is not None:
                logger.debug("A1111 inpaint mask START")
                input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
                input_image = [Image.fromarray(x) for x in input_image]

                mask = prepare_mask(a1111_mask_image, p)

                crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height)

                input_image = [
                    images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
                    for i in input_image
                ]

                input_image = [x.crop(crop_region) for x in input_image]
                input_image = [
                    images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
                    for x in input_image
                ]

                input_image = [np.asarray(x)[:, :, 0] for x in input_image]
                input_image = np.stack(input_image, axis=2)
                logger.debug("A1111 inpaint mask END")

            if unit.module == 'inpaint_only' and issubclass(type(p), StableDiffusionProcessingImg2Img) and p.image_mask is not None:
                logger.warning('A1111 inpaint and ControlNet inpaint duplicated. ControlNet support enabled.')
                unit.module = 'inpaint'

            # safe numpy
            logger.debug("Safe numpy conversion START")
            input_image = np.ascontiguousarray(input_image.copy()).copy()
            logger.debug("Safe numpy conversion END")

            logger.info(f"Loading preprocessor: {unit.module}")
            preprocessor = self.preprocessor[unit.module]

            high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)

            h = (p.height // 8) * 8
            w = (p.width // 8) * 8

            if high_res_fix:
                if p.hr_resize_x == 0 and p.hr_resize_y == 0:
                    hr_y = int(p.height * p.hr_scale)
                    hr_x = int(p.width * p.hr_scale)
                else:
                    hr_y, hr_x = p.hr_resize_y, p.hr_resize_x
                hr_y = (hr_y // 8) * 8
                hr_x = (hr_x // 8) * 8
            else:
                hr_y = h
                hr_x = w

            if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
                # inpaint_only+lama is special and requires an outpaint fix
                _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)

            control_model_type = ControlModelType.ControlNet
            global_average_pooling = False

            if 'reference' in unit.module:
                control_model_type = ControlModelType.AttentionInjection
            elif 'revision' in unit.module:
                control_model_type = ControlModelType.ReVision
            elif hasattr(model_net, 'control_model') and (isinstance(model_net.control_model, Adapter) or isinstance(model_net.control_model, Adapter_light)):
                control_model_type = ControlModelType.T2I_Adapter
            elif hasattr(model_net, 'control_model') and isinstance(model_net.control_model, StyleAdapter):
                control_model_type = ControlModelType.T2I_StyleAdapter
            elif isinstance(model_net, PlugableIPAdapter):
                control_model_type = ControlModelType.IPAdapter
            elif isinstance(model_net, PlugableControlLLLite):
                control_model_type = ControlModelType.Controlllite

            if control_model_type is ControlModelType.ControlNet:
                global_average_pooling = model_net.control_model.global_average_pooling

            preprocessor_resolution = unit.processor_res
            if unit.pixel_perfect:
                preprocessor_resolution = external_code.pixel_perfect_resolution(
                    input_image,
                    target_H=h,
                    target_W=w,
                    resize_mode=resize_mode
                )

            logger.info(f'preprocessor resolution = {preprocessor_resolution}')
            # Preprocessor result may depend on numpy random operations, use the
            # random seed in `StableDiffusionProcessing` to make the
            # preprocessor result reproducible.
            # Currently the following preprocessors use numpy random:
            # - shuffle
            seed = set_numpy_seed(p)
            logger.debug(f"Use numpy seed {seed}.")
            detected_map, is_image = preprocessor(
                input_image,
                res=preprocessor_resolution,
                thr_a=unit.threshold_a,
                thr_b=unit.threshold_b,
            )

            if high_res_fix:
                if is_image:
                    hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
                    detected_maps.append((hr_detected_map, unit.module))
                else:
                    hr_control = detected_map
            else:
                hr_control = None

            if is_image:
                control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
                detected_maps.append((detected_map, unit.module))
            else:
                control = detected_map
                detected_maps.append((input_image, unit.module))

            if control_model_type == ControlModelType.T2I_StyleAdapter:
                control = control['last_hidden_state']

            if control_model_type == ControlModelType.ReVision:
                control = control['image_embeds']

            preprocessor_dict = dict(
                name=unit.module,
                preprocessor_resolution=preprocessor_resolution,
                threshold_a=unit.threshold_a,
                threshold_b=unit.threshold_b
            )

            forward_param = ControlParams(
                control_model=model_net,
                preprocessor=preprocessor_dict,
                hint_cond=control,
                weight=unit.weight,
                guidance_stopped=False,
                start_guidance_percent=unit.guidance_start,
                stop_guidance_percent=unit.guidance_end,
                advanced_weighting=None,
                control_model_type=control_model_type,
                global_average_pooling=global_average_pooling,
                hr_hint_cond=hr_control,
                soft_injection=control_mode != external_code.ControlMode.BALANCED,
                cfg_injection=control_mode == external_code.ControlMode.CONTROL,
            )
            forward_params.append(forward_param)

            if 'inpaint_only' in unit.module:
                final_inpaint_feed = hr_control if hr_control is not None else control
                final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
                final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
                final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
                final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
                sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7)
                final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
                final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
                _, Hmask, Wmask = final_inpaint_mask.shape
                final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
                final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())

                def inpaint_only_post_processing(x):
                    _, H, W = x.shape
                    if Hmask != H or Wmask != W:
                        logger.error('Error: ControlNet found a post-processing resolution mismatch. This could be caused by other extensions hacking the processing pipeline.')
                        return x
                    r = final_inpaint_raw.to(x.dtype).to(x.device)
                    m = final_inpaint_mask.to(x.dtype).to(x.device)
                    y = m * x.clip(0, 1) + (1 - m) * r
                    y = y.clip(0, 1)
                    return y

                post_processors.append(inpaint_only_post_processing)

            if 'recolor' in unit.module:
                final_feed = hr_control if hr_control is not None else control
                final_feed = final_feed.detach().cpu().numpy()
                final_feed = np.ascontiguousarray(final_feed).copy()
                final_feed = final_feed[0, 0, :, :].astype(np.float32)
                final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8)
                Hfeed, Wfeed = final_feed.shape

                if 'luminance' in unit.module:

                    def recolor_luminance_post_processing(x):
                        C, H, W = x.shape
                        if Hfeed != H or Wfeed != W or C != 3:
                            logger.error('Error: ControlNet found a post-processing resolution mismatch. This could be caused by other extensions hacking the processing pipeline.')
                            return x
                        h = x.detach().cpu().numpy().transpose((1, 2, 0))
                        h = (h * 255).clip(0, 255).astype(np.uint8)
                        h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB)
                        h[:, :, 0] = final_feed
                        h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB)
                        h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
                        y = torch.from_numpy(h).clip(0, 1).to(x)
                        return y

                    post_processors.append(recolor_luminance_post_processing)

                if 'intensity' in unit.module:

                    def recolor_intensity_post_processing(x):
                        C, H, W = x.shape
                        if Hfeed != H or Wfeed != W or C != 3:
                            logger.error('Error: ControlNet found a post-processing resolution mismatch. This could be caused by other extensions hacking the processing pipeline.')
                            return x
                        h = x.detach().cpu().numpy().transpose((1, 2, 0))
                        h = (h * 255).clip(0, 255).astype(np.uint8)
                        h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV)
                        h[:, :, 2] = final_feed
                        h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB)
                        h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
                        y = torch.from_numpy(h).clip(0, 1).to(x)
                        return y

                    post_processors.append(recolor_intensity_post_processing)

            if '+lama' in unit.module:
                forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control)
                self.noise_modifier = forward_param.used_hint_cond_latent

            del model_net

        is_low_vram = any(unit.low_vram for unit in self.enabled_units)

        self.latest_network = UnetHook(lowvram=is_low_vram)
        self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)

        for param in forward_params:
            if param.control_model_type == ControlModelType.IPAdapter:
                param.control_model.hook(
                    model=unet,
                    clip_vision_output=param.hint_cond,
                    weight=param.weight,
                    dtype=torch.float32,
                    start=param.start_guidance_percent,
                    end=param.stop_guidance_percent
                )
            if param.control_model_type == ControlModelType.Controlllite:
                param.control_model.hook(
                    model=unet,
                    cond=param.hint_cond,
                    weight=param.weight,
                    start=param.start_guidance_percent,
                    end=param.stop_guidance_percent
                )

        self.detected_map = detected_maps
        self.post_processors = post_processors
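    # Flow sketch of controlnet_main_entry above: (1) restore any previous
    # UnetHook and clear secondary models; (2) resolve enabled units and unload
    # unused preprocessors; (3) per unit, load the control model, choose the
    # input image, run the preprocessor, and build a ControlParams entry;
    # (4) hook the UNet once with all ControlParams (IP-Adapter and ControlLLLite
    # models attach their own hooks afterwards).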

    def controlnet_hack(self, p):
        t = time.time()
        self.controlnet_main_entry(p)
        if len(self.enabled_units) > 0:
            logger.info(f'ControlNet Hooked - Time = {time.time() - t}')
        return

    @staticmethod
    def process_has_sdxl_refiner(p):
        return getattr(p, 'refiner_checkpoint', None) is not None

    def process(self, p, *args, **kwargs):
        if not Script.process_has_sdxl_refiner(p):
            self.controlnet_hack(p)
        return

    def before_process_batch(self, p, *args, **kwargs):
        if self.noise_modifier is not None:
            p.rng = HackedImageRNG(rng=p.rng,
                                   noise_modifier=self.noise_modifier,
                                   sd_model=p.sd_model)
        self.noise_modifier = None
        if Script.process_has_sdxl_refiner(p):
            self.controlnet_hack(p)
        return

    def postprocess_batch(self, p, *args, **kwargs):
        images = kwargs.get('images', [])
        for post_processor in self.post_processors:
            for i in range(len(images)):
                images[i] = post_processor(images[i])
        return

    def postprocess(self, p, processed, *args):
        clear_all_secondary_control_models()

        self.noise_modifier = None

        for control_lora in getattr(p, 'controlnet_control_loras', []):
            unbind_control_lora(control_lora)
        p.controlnet_control_loras = []

        self.post_processors = []
        setattr(p, 'controlnet_vae_cache', None)

        processor_params_flag = (', '.join(getattr(processed, 'extra_generation_params', []))).lower()
        self.post_processors = []

        if not batch_hijack.instance.is_batch:
            self.enabled_units.clear()

        if shared.opts.data.get("control_net_detectmap_autosaving", False) and self.latest_network is not None:
            for detect_map, module in self.detected_map:
                detectmap_dir = os.path.join(shared.opts.data.get("control_net_detectedmap_dir", ""), module)
                if not os.path.isabs(detectmap_dir):
                    detectmap_dir = os.path.join(p.outpath_samples, detectmap_dir)
                if module != "none":
                    os.makedirs(detectmap_dir, exist_ok=True)
                    img = Image.fromarray(np.ascontiguousarray(detect_map.clip(0, 255).astype(np.uint8)).copy())
                    save_image(img, detectmap_dir, module)

        if self.latest_network is None:
            return

        if not batch_hijack.instance.is_batch:
            if not shared.opts.data.get("control_net_no_detectmap", False):
                if 'sd upscale' not in processor_params_flag:
                    if self.detected_map is not None:
                        for detect_map, module in self.detected_map:
                            if detect_map is None:
                                continue
                            detect_map = np.ascontiguousarray(detect_map.copy()).copy()
                            detect_map = external_code.visualize_inpaint_mask(detect_map)
                            processed.images.extend([
                                Image.fromarray(
                                    detect_map.clip(0, 255).astype(np.uint8)
                                )
                            ])

        self.input_image = None
        self.latest_network.restore()
        self.latest_network = None
        self.detected_map.clear()

        gc.collect()
        devices.torch_gc()

    def batch_tab_process(self, p, batches, *args, **kwargs):
        self.enabled_units = self.get_enabled_units(p)
        for unit_i, unit in enumerate(self.enabled_units):
            unit.batch_images = iter([batch[unit_i] for batch in batches])

    def batch_tab_process_each(self, p, *args, **kwargs):
        for unit_i, unit in enumerate(self.enabled_units):
            if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0:
                continue

            unit.image = next(unit.batch_images)

    def batch_tab_postprocess_each(self, p, processed, *args, **kwargs):
        for unit_i, unit in enumerate(self.enabled_units):
            if getattr(unit, 'loopback', False):
                output_images = getattr(processed, 'images', [])[processed.index_of_first_image:]
                if output_images:
                    unit.image = np.array(output_images[0])
                else:
                    logger.warning(f'Warning: No loopback image found for controlnet unit {unit_i}. Using control map from last batch iteration instead.')

    def batch_tab_postprocess(self, p, *args, **kwargs):
        self.enabled_units.clear()
        self.input_image = None
        if self.latest_network is None:
            return

        self.latest_network.restore()
        self.latest_network = None
        self.detected_map.clear()


def on_ui_settings():
    section = ('control_net', "ControlNet")
    shared.opts.add_option("control_net_detectedmap_dir", shared.OptionInfo(
        global_state.default_detectedmap_dir, "Directory for detected maps auto saving", section=section))
    shared.opts.add_option("control_net_models_path", shared.OptionInfo(
        "", "Extra path to scan for ControlNet models (e.g. training output directory)", section=section))
    shared.opts.add_option("control_net_modules_path", shared.OptionInfo(
        "", "Path to directory containing annotator model directories (requires restart, overrides corresponding command line flag)", section=section))
    shared.opts.add_option("control_net_unit_count", shared.OptionInfo(
        3, "Multi-ControlNet: ControlNet unit number (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
    shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo(
        1, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
    shared.opts.add_option("control_net_inpaint_blur_sigma", shared.OptionInfo(
        7, "ControlNet inpainting Gaussian blur sigma", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=section))
    shared.opts.add_option("control_net_no_high_res_fix", shared.OptionInfo(
        False, "Do not apply ControlNet during highres fix", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo(
        False, "Do not append detectmap to output", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("control_net_detectmap_autosaving", shared.OptionInfo(
        False, "Allow detectmap auto saving", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("control_net_allow_script_control", shared.OptionInfo(
        True, "Allow other script to control this extension", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("control_net_sync_field_args", shared.OptionInfo(
        False, "Paste ControlNet parameters in infotext", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_show_batch_images_in_ui", shared.OptionInfo(
        False, "Show batch images in gradio gallery output", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_increment_seed_during_batch", shared.OptionInfo(
        False, "Increment seed after each controlnet batch iteration", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_disable_control_type", shared.OptionInfo(
        False, "Disable control type selection", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_disable_openpose_edit", shared.OptionInfo(
        False, "Disable openpose edit", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_ignore_noninpaint_mask", shared.OptionInfo(
        False, "Ignore mask on ControlNet input image if control type is not inpaint",
        gr.Checkbox, {"interactive": True}, section=section))


batch_hijack.instance.do_hijack()
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)