Asif782 commited on
Commit
f2b239d
1 Parent(s): c9df518

Upload !adetailer.py

Browse files
Files changed (1) hide show
  1. !adetailer.py +1000 -0
!adetailer.py ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import platform
5
+ import re
6
+ import sys
7
+ import traceback
8
+ from contextlib import contextmanager, suppress
9
+ from copy import copy
10
+ from functools import partial
11
+ from pathlib import Path
12
+ from textwrap import dedent
13
+ from typing import TYPE_CHECKING, Any, NamedTuple
14
+
15
+ import gradio as gr
16
+ import torch
17
+ from PIL import Image
18
+ from rich import print
19
+ from torchvision.transforms.functional import to_pil_image
20
+
21
+ import modules
22
+ from adetailer import (
23
+ AFTER_DETAILER,
24
+ __version__,
25
+ get_models,
26
+ mediapipe_predict,
27
+ ultralytics_predict,
28
+ )
29
+ from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
30
+ from adetailer.common import PredictOutput
31
+ from adetailer.mask import (
32
+ filter_by_ratio,
33
+ filter_k_largest,
34
+ mask_preprocess,
35
+ sort_bboxes,
36
+ )
37
+ from adetailer.traceback import rich_traceback
38
+ from adetailer.ui import WebuiInfo, adui, ordinal, suffix
39
+ from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models
40
+ from controlnet_ext.restore import (
41
+ CNHijackRestore,
42
+ cn_allow_script_control,
43
+ )
44
+ from modules import images, paths, safe, script_callbacks, scripts, shared
45
+ from modules.devices import NansException
46
+ from modules.processing import (
47
+ Processed,
48
+ StableDiffusionProcessingImg2Img,
49
+ create_infotext,
50
+ process_images,
51
+ )
52
+ from modules.sd_samplers import all_samplers
53
+ from modules.shared import cmd_opts, opts, state
54
+
55
+ if TYPE_CHECKING:
56
+ from fastapi import FastAPI
57
+
58
+ no_huggingface = getattr(cmd_opts, "ad_no_huggingface", False)
59
+ adetailer_dir = Path(paths.models_path, "adetailer")
60
+ extra_models_dir = shared.opts.data.get("ad_extra_models_dir", "")
61
+ model_mapping = get_models(
62
+ adetailer_dir, extra_dir=extra_models_dir, huggingface=not no_huggingface
63
+ )
64
+ txt2img_submit_button = img2img_submit_button = None
65
+ SCRIPT_DEFAULT = "dynamic_prompting,dynamic_thresholding,wildcard_recursive,wildcards,lora_block_weight,negpip"
66
+
67
+ if (
68
+ not adetailer_dir.exists()
69
+ and adetailer_dir.parent.exists()
70
+ and os.access(adetailer_dir.parent, os.W_OK)
71
+ ):
72
+ adetailer_dir.mkdir()
73
+
74
+ print(
75
+ f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
76
+ )
77
+
78
+
79
+ @contextmanager
80
+ def change_torch_load():
81
+ orig = torch.load
82
+ try:
83
+ torch.load = safe.unsafe_torch_load
84
+ yield
85
+ finally:
86
+ torch.load = orig
87
+
88
+
89
+ @contextmanager
90
+ def pause_total_tqdm():
91
+ orig = opts.data.get("multiple_tqdm", True)
92
+ try:
93
+ opts.data["multiple_tqdm"] = False
94
+ yield
95
+ finally:
96
+ opts.data["multiple_tqdm"] = orig
97
+
98
+
99
+ @contextmanager
100
+ def preseve_prompts(p):
101
+ all_pt = copy(p.all_prompts)
102
+ all_ng = copy(p.all_negative_prompts)
103
+ try:
104
+ yield
105
+ finally:
106
+ p.all_prompts = all_pt
107
+ p.all_negative_prompts = all_ng
108
+
109
+
110
+ class AfterDetailerScript(scripts.Script):
111
+ def __init__(self):
112
+ super().__init__()
113
+ self.ultralytics_device = self.get_ultralytics_device()
114
+
115
+ self.controlnet_ext = None
116
+
117
+ def __repr__(self):
118
+ return f"{self.__class__.__name__}(version={__version__})"
119
+
120
+ def title(self):
121
+ return AFTER_DETAILER
122
+
123
+ def show(self, is_img2img):
124
+ return scripts.AlwaysVisible
125
+
126
+ def ui(self, is_img2img):
127
+ num_models = opts.data.get("ad_max_models", 2)
128
+ ad_model_list = list(model_mapping.keys())
129
+ sampler_names = [sampler.name for sampler in all_samplers]
130
+
131
+ try:
132
+ checkpoint_list = modules.sd_models.checkpoint_tiles(use_shorts=True)
133
+ except TypeError:
134
+ checkpoint_list = modules.sd_models.checkpoint_tiles()
135
+ vae_list = modules.shared_items.sd_vae_items()
136
+
137
+ webui_info = WebuiInfo(
138
+ ad_model_list=ad_model_list,
139
+ sampler_names=sampler_names,
140
+ t2i_button=txt2img_submit_button,
141
+ i2i_button=img2img_submit_button,
142
+ checkpoints_list=checkpoint_list,
143
+ vae_list=vae_list,
144
+ )
145
+
146
+ components, infotext_fields = adui(num_models, is_img2img, webui_info)
147
+
148
+ self.infotext_fields = infotext_fields
149
+ return components
150
+
151
+ def init_controlnet_ext(self) -> None:
152
+ if self.controlnet_ext is not None:
153
+ return
154
+ self.controlnet_ext = ControlNetExt()
155
+
156
+ if controlnet_exists:
157
+ try:
158
+ self.controlnet_ext.init_controlnet()
159
+ except ImportError:
160
+ error = traceback.format_exc()
161
+ print(
162
+ f"[-] ADetailer: ControlNetExt init failed:\n{error}",
163
+ file=sys.stderr,
164
+ )
165
+
166
+ def update_controlnet_args(self, p, args: ADetailerArgs) -> None:
167
+ if self.controlnet_ext is None:
168
+ self.init_controlnet_ext()
169
+
170
+ if (
171
+ self.controlnet_ext is not None
172
+ and self.controlnet_ext.cn_available
173
+ and args.ad_controlnet_model != "None"
174
+ ):
175
+ self.controlnet_ext.update_scripts_args(
176
+ p,
177
+ model=args.ad_controlnet_model,
178
+ module=args.ad_controlnet_module,
179
+ weight=args.ad_controlnet_weight,
180
+ guidance_start=args.ad_controlnet_guidance_start,
181
+ guidance_end=args.ad_controlnet_guidance_end,
182
+ )
183
+
184
+ def is_ad_enabled(self, *args_) -> bool:
185
+ arg_list = [arg for arg in args_ if isinstance(arg, dict)]
186
+ if not args_ or not arg_list:
187
+ message = f"""
188
+ [-] ADetailer: Invalid arguments passed to ADetailer.
189
+ input: {args_!r}
190
+ ADetailer disabled.
191
+ """
192
+ print(dedent(message), file=sys.stderr)
193
+ return False
194
+
195
+ ad_enabled = args_[0] if isinstance(args_[0], bool) else True
196
+ not_none = any(arg.get("ad_model", "None") != "None" for arg in arg_list)
197
+ return ad_enabled and not_none
198
+
199
+ def check_skip_img2img(self, p, *args_) -> None:
200
+ if (
201
+ hasattr(p, "_ad_skip_img2img")
202
+ or not hasattr(p, "init_images")
203
+ or not p.init_images
204
+ ):
205
+ return
206
+
207
+ if len(args_) >= 2 and isinstance(args_[1], bool):
208
+ p._ad_skip_img2img = args_[1]
209
+ if args_[1]:
210
+ p._ad_orig = SkipImg2ImgOrig(
211
+ steps=p.steps,
212
+ sampler_name=p.sampler_name,
213
+ width=p.width,
214
+ height=p.height,
215
+ )
216
+ p.steps = 1
217
+ p.sampler_name = "Euler"
218
+ p.width = 128
219
+ p.height = 128
220
+ else:
221
+ p._ad_skip_img2img = False
222
+
223
+ @staticmethod
224
+ def get_i(p) -> int:
225
+ it = p.iteration
226
+ bs = p.batch_size
227
+ i = p.batch_index
228
+ return it * bs + i
229
+
230
+ def get_args(self, p, *args_) -> list[ADetailerArgs]:
231
+ """
232
+ `args_` is at least 1 in length by `is_ad_enabled` immediately above
233
+ """
234
+ args = [arg for arg in args_ if isinstance(arg, dict)]
235
+
236
+ if not args:
237
+ message = f"[-] ADetailer: Invalid arguments passed to ADetailer: {args_!r}"
238
+ raise ValueError(message)
239
+
240
+ if hasattr(p, "_ad_xyz"):
241
+ args[0] = {**args[0], **p._ad_xyz}
242
+
243
+ all_inputs = []
244
+
245
+ for n, arg_dict in enumerate(args, 1):
246
+ try:
247
+ inp = ADetailerArgs(**arg_dict)
248
+ except ValueError as e:
249
+ msgs = [
250
+ f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
251
+ ]
252
+ for attr in ALL_ARGS.attrs:
253
+ arg = arg_dict.get(attr)
254
+ dtype = type(arg)
255
+ arg = "DEFAULT" if arg is None else repr(arg)
256
+ msgs.append(f" {attr}: {arg} ({dtype})")
257
+ raise ValueError("\n".join(msgs)) from e
258
+
259
+ all_inputs.append(inp)
260
+
261
+ return all_inputs
262
+
263
+ def extra_params(self, arg_list: list[ADetailerArgs]) -> dict:
264
+ params = {}
265
+ for n, args in enumerate(arg_list):
266
+ params.update(args.extra_params(suffix=suffix(n)))
267
+ params["ADetailer version"] = __version__
268
+ return params
269
+
270
+ @staticmethod
271
+ def get_ultralytics_device() -> str:
272
+ if "adetailer" in shared.cmd_opts.use_cpu:
273
+ return "cpu"
274
+
275
+ if platform.system() == "Darwin":
276
+ return ""
277
+
278
+ vram_args = ["lowvram", "medvram", "medvram_sdxl"]
279
+ if any(getattr(cmd_opts, vram, False) for vram in vram_args):
280
+ return "cpu"
281
+
282
+ return ""
283
+
284
+ def prompt_blank_replacement(
285
+ self, all_prompts: list[str], i: int, default: str
286
+ ) -> str:
287
+ if not all_prompts:
288
+ return default
289
+ if i < len(all_prompts):
290
+ return all_prompts[i]
291
+ j = i % len(all_prompts)
292
+ return all_prompts[j]
293
+
294
+ def _get_prompt(
295
+ self,
296
+ ad_prompt: str,
297
+ all_prompts: list[str],
298
+ i: int,
299
+ default: str,
300
+ replacements: list[PromptSR],
301
+ ) -> list[str]:
302
+ prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
303
+ blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
304
+ for n in range(len(prompts)):
305
+ if not prompts[n]:
306
+ prompts[n] = blank_replacement
307
+ elif "[PROMPT]" in prompts[n]:
308
+ prompts[n] = prompts[n].replace("[PROMPT]", f" {blank_replacement} ")
309
+
310
+ for pair in replacements:
311
+ prompts[n] = prompts[n].replace(pair.s, pair.r)
312
+ return prompts
313
+
314
+ def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
315
+ i = self.get_i(p)
316
+ prompt_sr = p._ad_xyz_prompt_sr if hasattr(p, "_ad_xyz_prompt_sr") else []
317
+
318
+ prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt, prompt_sr)
319
+ negative_prompt = self._get_prompt(
320
+ args.ad_negative_prompt,
321
+ p.all_negative_prompts,
322
+ i,
323
+ p.negative_prompt,
324
+ prompt_sr,
325
+ )
326
+
327
+ return prompt, negative_prompt
328
+
329
+ def get_seed(self, p) -> tuple[int, int]:
330
+ i = self.get_i(p)
331
+
332
+ if not p.all_seeds:
333
+ seed = p.seed
334
+ elif i < len(p.all_seeds):
335
+ seed = p.all_seeds[i]
336
+ else:
337
+ j = i % len(p.all_seeds)
338
+ seed = p.all_seeds[j]
339
+
340
+ if not p.all_subseeds:
341
+ subseed = p.subseed
342
+ elif i < len(p.all_subseeds):
343
+ subseed = p.all_subseeds[i]
344
+ else:
345
+ j = i % len(p.all_subseeds)
346
+ subseed = p.all_subseeds[j]
347
+
348
+ return seed, subseed
349
+
350
+ def get_width_height(self, p, args: ADetailerArgs) -> tuple[int, int]:
351
+ if args.ad_use_inpaint_width_height:
352
+ width = args.ad_inpaint_width
353
+ height = args.ad_inpaint_height
354
+ elif hasattr(p, "_ad_orig"):
355
+ width = p._ad_orig.width
356
+ height = p._ad_orig.height
357
+ else:
358
+ width = p.width
359
+ height = p.height
360
+
361
+ return width, height
362
+
363
+ def get_steps(self, p, args: ADetailerArgs) -> int:
364
+ if args.ad_use_steps:
365
+ return args.ad_steps
366
+ if hasattr(p, "_ad_orig"):
367
+ return p._ad_orig.steps
368
+ return p.steps
369
+
370
+ def get_cfg_scale(self, p, args: ADetailerArgs) -> float:
371
+ return args.ad_cfg_scale if args.ad_use_cfg_scale else p.cfg_scale
372
+
373
+ def get_sampler(self, p, args: ADetailerArgs) -> str:
374
+ if args.ad_use_sampler:
375
+ return args.ad_sampler
376
+ if hasattr(p, "_ad_orig"):
377
+ return p._ad_orig.sampler_name
378
+ return p.sampler_name
379
+
380
+ def get_override_settings(self, p, args: ADetailerArgs) -> dict[str, Any]:
381
+ d = {}
382
+
383
+ if args.ad_use_clip_skip:
384
+ d["CLIP_stop_at_last_layers"] = args.ad_clip_skip
385
+
386
+ if (
387
+ args.ad_use_checkpoint
388
+ and args.ad_checkpoint
389
+ and args.ad_checkpoint not in ("None", "Use same checkpoint")
390
+ ):
391
+ d["sd_model_checkpoint"] = args.ad_checkpoint
392
+
393
+ if (
394
+ args.ad_use_vae
395
+ and args.ad_vae
396
+ and args.ad_vae not in ("None", "Use same VAE")
397
+ ):
398
+ d["sd_vae"] = args.ad_vae
399
+ return d
400
+
401
+ def get_initial_noise_multiplier(self, p, args: ADetailerArgs) -> float | None:
402
+ return args.ad_noise_multiplier if args.ad_use_noise_multiplier else None
403
+
404
+ @staticmethod
405
+ def infotext(p) -> str:
406
+ return create_infotext(
407
+ p, p.all_prompts, p.all_seeds, p.all_subseeds, None, 0, 0
408
+ )
409
+
410
+ def write_params_txt(self, content: str) -> None:
411
+ params_txt = Path(paths.data_path, "params.txt")
412
+ with suppress(Exception):
413
+ params_txt.write_text(content, encoding="utf-8")
414
+
415
+ @staticmethod
416
+ def script_args_copy(script_args):
417
+ type_: type[list] | type[tuple] = type(script_args)
418
+ result = []
419
+ for arg in script_args:
420
+ try:
421
+ a = copy(arg)
422
+ except TypeError:
423
+ a = arg
424
+ result.append(a)
425
+ return type_(result)
426
+
427
+ def script_filter(self, p, args: ADetailerArgs):
428
+ script_runner = copy(p.scripts)
429
+ script_args = self.script_args_copy(p.script_args)
430
+
431
+ ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True)
432
+ if not ad_only_seleted_scripts:
433
+ return script_runner, script_args
434
+
435
+ ad_script_names = opts.data.get("ad_script_names", SCRIPT_DEFAULT)
436
+ script_names_set = {
437
+ name
438
+ for script_name in ad_script_names.split(",")
439
+ for name in (script_name, script_name.strip())
440
+ }
441
+
442
+ if args.ad_controlnet_model != "None":
443
+ script_names_set.add("controlnet")
444
+
445
+ filtered_alwayson = []
446
+ for script_object in script_runner.alwayson_scripts:
447
+ filepath = script_object.filename
448
+ filename = Path(filepath).stem
449
+ if filename in script_names_set:
450
+ filtered_alwayson.append(script_object)
451
+
452
+ script_runner.alwayson_scripts = filtered_alwayson
453
+ return script_runner, script_args
454
+
455
+ def disable_controlnet_units(
456
+ self, script_args: list[Any] | tuple[Any, ...]
457
+ ) -> None:
458
+ for obj in script_args:
459
+ if "controlnet" in obj.__class__.__name__.lower():
460
+ if hasattr(obj, "enabled"):
461
+ obj.enabled = False
462
+ if hasattr(obj, "input_mode"):
463
+ obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
464
+
465
+ elif isinstance(obj, dict) and "module" in obj:
466
+ obj["enabled"] = False
467
+
468
+ def get_i2i_p(self, p, args: ADetailerArgs, image):
469
+ seed, subseed = self.get_seed(p)
470
+ width, height = self.get_width_height(p, args)
471
+ steps = self.get_steps(p, args)
472
+ cfg_scale = self.get_cfg_scale(p, args)
473
+ initial_noise_multiplier = self.get_initial_noise_multiplier(p, args)
474
+ sampler_name = self.get_sampler(p, args)
475
+ override_settings = self.get_override_settings(p, args)
476
+
477
+ i2i = StableDiffusionProcessingImg2Img(
478
+ init_images=[image],
479
+ resize_mode=0,
480
+ denoising_strength=args.ad_denoising_strength,
481
+ mask=None,
482
+ mask_blur=args.ad_mask_blur,
483
+ inpainting_fill=1,
484
+ inpaint_full_res=args.ad_inpaint_only_masked,
485
+ inpaint_full_res_padding=args.ad_inpaint_only_masked_padding,
486
+ inpainting_mask_invert=0,
487
+ initial_noise_multiplier=initial_noise_multiplier,
488
+ sd_model=p.sd_model,
489
+ outpath_samples=p.outpath_samples,
490
+ outpath_grids=p.outpath_grids,
491
+ prompt="", # replace later
492
+ negative_prompt="",
493
+ styles=p.styles,
494
+ seed=seed,
495
+ subseed=subseed,
496
+ subseed_strength=p.subseed_strength,
497
+ seed_resize_from_h=p.seed_resize_from_h,
498
+ seed_resize_from_w=p.seed_resize_from_w,
499
+ sampler_name=sampler_name,
500
+ batch_size=1,
501
+ n_iter=1,
502
+ steps=steps,
503
+ cfg_scale=cfg_scale,
504
+ width=width,
505
+ height=height,
506
+ restore_faces=args.ad_restore_face,
507
+ tiling=p.tiling,
508
+ extra_generation_params=p.extra_generation_params,
509
+ do_not_save_samples=True,
510
+ do_not_save_grid=True,
511
+ override_settings=override_settings,
512
+ )
513
+
514
+ i2i.cached_c = [None, None]
515
+ i2i.cached_uc = [None, None]
516
+ i2i.scripts, i2i.script_args = self.script_filter(p, args)
517
+ i2i._ad_disabled = True
518
+ i2i._ad_inner = True
519
+
520
+ if args.ad_controlnet_model != "Passthrough":
521
+ self.disable_controlnet_units(i2i.script_args)
522
+
523
+ if args.ad_controlnet_model not in ["None", "Passthrough"]:
524
+ self.update_controlnet_args(i2i, args)
525
+ elif args.ad_controlnet_model == "None":
526
+ i2i.control_net_enabled = False
527
+
528
+ return i2i
529
+
530
+ def save_image(self, p, image, *, condition: str, suffix: str) -> None:
531
+ i = self.get_i(p)
532
+ if p.all_prompts:
533
+ i %= len(p.all_prompts)
534
+ save_prompt = p.all_prompts[i]
535
+ else:
536
+ save_prompt = p.prompt
537
+ seed, _ = self.get_seed(p)
538
+
539
+ if opts.data.get(condition, False):
540
+ images.save_image(
541
+ image=image,
542
+ path=p.outpath_samples,
543
+ basename="",
544
+ seed=seed,
545
+ prompt=save_prompt,
546
+ extension=opts.samples_format,
547
+ info=self.infotext(p),
548
+ p=p,
549
+ suffix=suffix,
550
+ )
551
+
552
+ def get_ad_model(self, name: str):
553
+ if name not in model_mapping:
554
+ msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}"
555
+ raise ValueError(msg)
556
+ return model_mapping[name]
557
+
558
+ def sort_bboxes(self, pred: PredictOutput) -> PredictOutput:
559
+ sortby = opts.data.get("ad_bbox_sortby", BBOX_SORTBY[0])
560
+ sortby_idx = BBOX_SORTBY.index(sortby)
561
+ return sort_bboxes(pred, sortby_idx)
562
+
563
+ def pred_preprocessing(self, pred: PredictOutput, args: ADetailerArgs):
564
+ pred = filter_by_ratio(
565
+ pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
566
+ )
567
+ pred = filter_k_largest(pred, k=args.ad_mask_k_largest)
568
+ pred = self.sort_bboxes(pred)
569
+ return mask_preprocess(
570
+ pred.masks,
571
+ kernel=args.ad_dilate_erode,
572
+ x_offset=args.ad_x_offset,
573
+ y_offset=args.ad_y_offset,
574
+ merge_invert=args.ad_mask_merge_invert,
575
+ )
576
+
577
+ @staticmethod
578
+ def ensure_rgb_image(image: Any):
579
+ if not isinstance(image, Image.Image):
580
+ image = to_pil_image(image)
581
+ if image.mode != "RGB":
582
+ image = image.convert("RGB")
583
+ return image
584
+
585
+ @staticmethod
586
+ def i2i_prompts_replace(
587
+ i2i, prompts: list[str], negative_prompts: list[str], j: int
588
+ ) -> None:
589
+ i1 = min(j, len(prompts) - 1)
590
+ i2 = min(j, len(negative_prompts) - 1)
591
+ prompt = prompts[i1]
592
+ negative_prompt = negative_prompts[i2]
593
+ i2i.prompt = prompt
594
+ i2i.negative_prompt = negative_prompt
595
+
596
+ @staticmethod
597
+ def compare_prompt(p, processed, n: int = 0):
598
+ if p.prompt != processed.all_prompts[0]:
599
+ print(
600
+ f"[-] ADetailer: applied {ordinal(n + 1)} ad_prompt: {processed.all_prompts[0]!r}"
601
+ )
602
+
603
+ if p.negative_prompt != processed.all_negative_prompts[0]:
604
+ print(
605
+ f"[-] ADetailer: applied {ordinal(n + 1)} ad_negative_prompt: {processed.all_negative_prompts[0]!r}"
606
+ )
607
+
608
+ @staticmethod
609
+ def need_call_process(p) -> bool:
610
+ if p.scripts is None:
611
+ return False
612
+ i = p.batch_index
613
+ bs = p.batch_size
614
+ return i == bs - 1
615
+
616
+ @staticmethod
617
+ def need_call_postprocess(p) -> bool:
618
+ if p.scripts is None:
619
+ return False
620
+ return p.batch_index == 0
621
+
622
+ @staticmethod
623
+ def get_i2i_init_image(p, pp):
624
+ if getattr(p, "_ad_skip_img2img", False):
625
+ return p.init_images[0]
626
+ return pp.image
627
+
628
+ @staticmethod
629
+ def get_each_tap_seed(seed: int, i: int):
630
+ use_same_seed = shared.opts.data.get("ad_same_seed_for_each_tap", False)
631
+ return seed if use_same_seed else seed + i
632
+
633
+ @staticmethod
634
+ def is_img2img_inpaint(p) -> bool:
635
+ return hasattr(p, "image_mask") and bool(p.image_mask)
636
+
637
+ @rich_traceback
638
+ def process(self, p, *args_):
639
+ if getattr(p, "_ad_disabled", False):
640
+ return
641
+
642
+ # if self.is_img2img_inpaint(p):
643
+ # p._ad_disabled = True
644
+ # msg = "[-] ADetailer: img2img inpainting detected. adetailer disabled."
645
+ # print(msg)
646
+ # return
647
+
648
+ if self.is_ad_enabled(*args_):
649
+ arg_list = self.get_args(p, *args_)
650
+ self.check_skip_img2img(p, *args_)
651
+ extra_params = self.extra_params(arg_list)
652
+ p.extra_generation_params.update(extra_params)
653
+ else:
654
+ p._ad_disabled = True
655
+
656
+ def _postprocess_image_inner(
657
+ self, p, pp, args: ADetailerArgs, *, n: int = 0
658
+ ) -> bool:
659
+ """
660
+ Returns
661
+ -------
662
+ bool
663
+
664
+ `True` if image was processed, `False` otherwise.
665
+ """
666
+ if state.interrupted or state.skipped:
667
+ return False
668
+
669
+ i = self.get_i(p)
670
+
671
+ i2i = self.get_i2i_p(p, args, pp.image)
672
+ seed, subseed = self.get_seed(p)
673
+ ad_prompts, ad_negatives = self.get_prompt(p, args)
674
+
675
+ is_mediapipe = args.ad_model.lower().startswith("mediapipe")
676
+
677
+ kwargs = {}
678
+ if is_mediapipe:
679
+ predictor = mediapipe_predict
680
+ ad_model = args.ad_model
681
+ else:
682
+ predictor = ultralytics_predict
683
+ ad_model = self.get_ad_model(args.ad_model)
684
+ kwargs["device"] = self.ultralytics_device
685
+
686
+ with change_torch_load():
687
+ pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
688
+
689
+ masks = self.pred_preprocessing(pred, args)
690
+ shared.state.assign_current_image(pred.preview)
691
+
692
+ if not masks:
693
+ print(
694
+ f"[-] ADetailer: nothing detected on image {i + 1} with {ordinal(n + 1)} settings."
695
+ )
696
+ return False
697
+
698
+ self.save_image(
699
+ p,
700
+ pred.preview,
701
+ condition="ad_save_previews",
702
+ suffix="-ad-preview" + suffix(n, "-"),
703
+ )
704
+
705
+ steps = len(masks)
706
+ processed = None
707
+ state.job_count += steps
708
+
709
+ if is_mediapipe:
710
+ print(f"mediapipe: {steps} detected.")
711
+
712
+ p2 = copy(i2i)
713
+ for j in range(steps):
714
+ p2.image_mask = masks[j]
715
+ p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
716
+ self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
717
+
718
+ if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
719
+ continue
720
+
721
+ p2.seed = self.get_each_tap_seed(seed, j)
722
+ p2.subseed = self.get_each_tap_seed(subseed, j)
723
+
724
+ try:
725
+ processed = process_images(p2)
726
+ except NansException as e:
727
+ msg = f"[-] ADetailer: 'NansException' occurred with {ordinal(n + 1)} settings.\n{e}"
728
+ print(msg, file=sys.stderr)
729
+ continue
730
+ finally:
731
+ p2.close()
732
+
733
+ self.compare_prompt(p2, processed, n=n)
734
+ p2 = copy(i2i)
735
+ p2.init_images = [processed.images[0]]
736
+
737
+ if processed is not None:
738
+ pp.image = processed.images[0]
739
+ return True
740
+
741
+ return False
742
+
743
+ @rich_traceback
744
+ def postprocess_image(self, p, pp, *args_):
745
+ if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
746
+ return
747
+
748
+ pp.image = self.get_i2i_init_image(p, pp)
749
+ pp.image = self.ensure_rgb_image(pp.image)
750
+ init_image = copy(pp.image)
751
+ arg_list = self.get_args(p, *args_)
752
+ params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")
753
+
754
+ if self.need_call_postprocess(p):
755
+ dummy = Processed(p, [], p.seed, "")
756
+ with preseve_prompts(p):
757
+ p.scripts.postprocess(copy(p), dummy)
758
+
759
+ is_processed = False
760
+ with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
761
+ for n, args in enumerate(arg_list):
762
+ if args.ad_model == "None":
763
+ continue
764
+ is_processed |= self._postprocess_image_inner(p, pp, args, n=n)
765
+
766
+ if is_processed and not getattr(p, "_ad_skip_img2img", False):
767
+ self.save_image(
768
+ p, init_image, condition="ad_save_images_before", suffix="-ad-before"
769
+ )
770
+
771
+ if self.need_call_process(p):
772
+ with preseve_prompts(p):
773
+ copy_p = copy(p)
774
+ if hasattr(p.scripts, "before_process"):
775
+ p.scripts.before_process(copy_p)
776
+ p.scripts.process(copy_p)
777
+
778
+ self.write_params_txt(params_txt_content)
779
+
780
+
781
+ def on_after_component(component, **_kwargs):
782
+ global txt2img_submit_button, img2img_submit_button
783
+ if getattr(component, "elem_id", None) == "txt2img_generate":
784
+ txt2img_submit_button = component
785
+ return
786
+
787
+ if getattr(component, "elem_id", None) == "img2img_generate":
788
+ img2img_submit_button = component
789
+
790
+
791
+ def on_ui_settings():
792
+ section = ("ADetailer", AFTER_DETAILER)
793
+ shared.opts.add_option(
794
+ "ad_max_models",
795
+ shared.OptionInfo(
796
+ default=2,
797
+ label="Max models",
798
+ component=gr.Slider,
799
+ component_args={"minimum": 1, "maximum": 10, "step": 1},
800
+ section=section,
801
+ ),
802
+ )
803
+
804
+ shared.opts.add_option(
805
+ "ad_extra_models_dir",
806
+ shared.OptionInfo(
807
+ default="",
808
+ label="Extra path to scan adetailer models",
809
+ component=gr.Textbox,
810
+ section=section,
811
+ ),
812
+ )
813
+
814
+ shared.opts.add_option(
815
+ "ad_save_previews",
816
+ shared.OptionInfo(False, "Save mask previews", section=section),
817
+ )
818
+
819
+ shared.opts.add_option(
820
+ "ad_save_images_before",
821
+ shared.OptionInfo(False, "Save images before ADetailer", section=section),
822
+ )
823
+
824
+ shared.opts.add_option(
825
+ "ad_only_seleted_scripts",
826
+ shared.OptionInfo(
827
+ True, "Apply only selected scripts to ADetailer", section=section
828
+ ),
829
+ )
830
+
831
+ textbox_args = {
832
+ "placeholder": "comma-separated list of script names",
833
+ "interactive": True,
834
+ }
835
+
836
+ shared.opts.add_option(
837
+ "ad_script_names",
838
+ shared.OptionInfo(
839
+ default=SCRIPT_DEFAULT,
840
+ label="Script names to apply to ADetailer (separated by comma)",
841
+ component=gr.Textbox,
842
+ component_args=textbox_args,
843
+ section=section,
844
+ ),
845
+ )
846
+
847
+ shared.opts.add_option(
848
+ "ad_bbox_sortby",
849
+ shared.OptionInfo(
850
+ default="None",
851
+ label="Sort bounding boxes by",
852
+ component=gr.Radio,
853
+ component_args={"choices": BBOX_SORTBY},
854
+ section=section,
855
+ ),
856
+ )
857
+
858
+ shared.opts.add_option(
859
+ "ad_same_seed_for_each_tap",
860
+ shared.OptionInfo(
861
+ False, "Use same seed for each tab in adetailer", section=section
862
+ ),
863
+ )
864
+
865
+
866
+ # xyz_grid
867
+
868
+
869
+ class PromptSR(NamedTuple):
870
+ s: str
871
+ r: str
872
+
873
+
874
+ def set_value(p, x: Any, xs: Any, *, field: str):
875
+ if not hasattr(p, "_ad_xyz"):
876
+ p._ad_xyz = {}
877
+ p._ad_xyz[field] = x
878
+
879
+
880
+ def search_and_replace_prompt(p, x: Any, xs: Any, replace_in_main_prompt: bool):
881
+ if replace_in_main_prompt:
882
+ p.prompt = p.prompt.replace(xs[0], x)
883
+ p.negative_prompt = p.negative_prompt.replace(xs[0], x)
884
+
885
+ if not hasattr(p, "_ad_xyz_prompt_sr"):
886
+ p._ad_xyz_prompt_sr = []
887
+ p._ad_xyz_prompt_sr.append(PromptSR(s=xs[0], r=x))
888
+
889
+
890
+ def make_axis_on_xyz_grid():
891
+ xyz_grid = None
892
+ for script in scripts.scripts_data:
893
+ if script.script_class.__module__ == "xyz_grid.py":
894
+ xyz_grid = script.module
895
+ break
896
+
897
+ if xyz_grid is None:
898
+ return
899
+
900
+ model_list = ["None", *model_mapping.keys()]
901
+ samplers = [sampler.name for sampler in all_samplers]
902
+
903
+ axis = [
904
+ xyz_grid.AxisOption(
905
+ "[ADetailer] ADetailer model 1st",
906
+ str,
907
+ partial(set_value, field="ad_model"),
908
+ choices=lambda: model_list,
909
+ ),
910
+ xyz_grid.AxisOption(
911
+ "[ADetailer] ADetailer prompt 1st",
912
+ str,
913
+ partial(set_value, field="ad_prompt"),
914
+ ),
915
+ xyz_grid.AxisOption(
916
+ "[ADetailer] ADetailer negative prompt 1st",
917
+ str,
918
+ partial(set_value, field="ad_negative_prompt"),
919
+ ),
920
+ xyz_grid.AxisOption(
921
+ "[ADetailer] Prompt S/R (AD 1st)",
922
+ str,
923
+ partial(search_and_replace_prompt, replace_in_main_prompt=False),
924
+ ),
925
+ xyz_grid.AxisOption(
926
+ "[ADetailer] Prompt S/R (AD 1st and main prompt)",
927
+ str,
928
+ partial(search_and_replace_prompt, replace_in_main_prompt=True),
929
+ ),
930
+ xyz_grid.AxisOption(
931
+ "[ADetailer] Mask erosion / dilation 1st",
932
+ int,
933
+ partial(set_value, field="ad_dilate_erode"),
934
+ ),
935
+ xyz_grid.AxisOption(
936
+ "[ADetailer] Inpaint denoising strength 1st",
937
+ float,
938
+ partial(set_value, field="ad_denoising_strength"),
939
+ ),
940
+ xyz_grid.AxisOption(
941
+ "[ADetailer] Inpaint only masked 1st",
942
+ str,
943
+ partial(set_value, field="ad_inpaint_only_masked"),
944
+ choices=lambda: ["True", "False"],
945
+ ),
946
+ xyz_grid.AxisOption(
947
+ "[ADetailer] Inpaint only masked padding 1st",
948
+ int,
949
+ partial(set_value, field="ad_inpaint_only_masked_padding"),
950
+ ),
951
+ xyz_grid.AxisOption(
952
+ "[ADetailer] ADetailer sampler 1st",
953
+ str,
954
+ partial(set_value, field="ad_sampler"),
955
+ choices=lambda: samplers,
956
+ ),
957
+ xyz_grid.AxisOption(
958
+ "[ADetailer] ControlNet model 1st",
959
+ str,
960
+ partial(set_value, field="ad_controlnet_model"),
961
+ choices=lambda: ["None", *get_cn_models()],
962
+ ),
963
+ ]
964
+
965
+ if not any(x.label.startswith("[ADetailer]") for x in xyz_grid.axis_options):
966
+ xyz_grid.axis_options.extend(axis)
967
+
968
+
969
+ def on_before_ui():
970
+ try:
971
+ make_axis_on_xyz_grid()
972
+ except Exception:
973
+ error = traceback.format_exc()
974
+ print(
975
+ f"[-] ADetailer: xyz_grid error:\n{error}",
976
+ file=sys.stderr,
977
+ )
978
+
979
+
980
+ # api
981
+
982
+
983
+ def add_api_endpoints(_: gr.Blocks, app: FastAPI):
984
+ @app.get("/adetailer/v1/version")
985
+ def version():
986
+ return {"version": __version__}
987
+
988
+ @app.get("/adetailer/v1/schema")
989
+ def schema():
990
+ return ADetailerArgs.schema()
991
+
992
+ @app.get("/adetailer/v1/ad_model")
993
+ def ad_model():
994
+ return {"ad_model": list(model_mapping)}
995
+
996
+
997
+ script_callbacks.on_ui_settings(on_ui_settings)
998
+ script_callbacks.on_after_component(on_after_component)
999
+ script_callbacks.on_app_started(add_api_endpoints)
1000
+ script_callbacks.on_before_ui(on_before_ui)