# extensions/adetailer/scripts/!adetailer.py
# Uploaded by toto10 ("Upload folder using huggingface_hub (#1)"), commit 34097e9, 20 kB.
from __future__ import annotations
import os
import platform
import re
import sys
import traceback
from contextlib import contextmanager, suppress
from copy import copy, deepcopy
from pathlib import Path
from textwrap import dedent
from typing import Any
import gradio as gr
import torch
import modules # noqa: F401
from adetailer import (
AFTER_DETAILER,
__version__,
get_models,
mediapipe_predict,
ultralytics_predict,
)
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, EnableChecker
from adetailer.common import PredictOutput
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
from adetailer.ui import adui, ordinal, suffix
from controlnet_ext import ControlNetExt, controlnet_exists
from controlnet_ext.restore import (
CNHijackRestore,
cn_allow_script_control,
cn_restore_unet_hook,
)
from sd_webui import images, safe, script_callbacks, scripts, shared
from sd_webui.paths import data_path, models_path
from sd_webui.processing import (
StableDiffusionProcessingImg2Img,
create_infotext,
process_images,
)
from sd_webui.shared import cmd_opts, opts, state
# Use rich's print when it is installed; otherwise the builtin print is kept.
with suppress(ImportError):
    from rich import print

no_huggingface = getattr(cmd_opts, "ad_no_huggingface", False)
adetailer_dir = Path(models_path, "adetailer")
model_mapping = get_models(adetailer_dir, huggingface=not no_huggingface)
# Filled in by on_after_component() once the webui builds its Generate buttons.
txt2img_submit_button = img2img_submit_button = None

# Default comma-separated list of always-on script names kept during ADetailer passes.
SCRIPT_DEFAULT = "dynamic_prompting,dynamic_thresholding,wildcard_recursive,wildcards"

# Create <models_path>/adetailer when its parent exists and is writable.
if (
    not adetailer_dir.exists()
    and adetailer_dir.parent.exists()
    and os.access(adetailer_dir.parent, os.W_OK)
):
    # exist_ok: another process may create the folder between the exists()
    # check above and this call; that must not abort module import.
    adetailer_dir.mkdir(exist_ok=True)

print(
    f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
)
@contextmanager
def change_torch_load():
    """Temporarily route ``torch.load`` through ``safe.unsafe_torch_load``.

    The loader that was active on entry is put back when the ``with`` block
    exits, whether it finishes normally or raises.
    """
    saved_loader = torch.load
    try:
        torch.load = safe.unsafe_torch_load
        yield
    finally:
        # Always undo the patch so the rest of the webui sees the original loader.
        torch.load = saved_loader
@contextmanager
def pause_total_tqdm():
    """Force the ``multiple_tqdm`` option off while the ``with`` block runs.

    On exit the option is set back to the value it had on entry (``True``
    if it was not set), even if the block raises.
    """
    saved_value = opts.data.get("multiple_tqdm", True)
    try:
        opts.data["multiple_tqdm"] = False
        yield
    finally:
        opts.data["multiple_tqdm"] = saved_value
class AfterDetailerScript(scripts.Script):
    """Always-visible webui script that re-inpaints detected regions of an image.

    ``postprocess_image`` runs every enabled ADetailer unit on the image it is
    given. Each unit runs a detector (mediapipe or ultralytics), turns the
    detections into masks, and inpaints the masks one by one with an img2img
    pass derived from the parent processing object.
    """

    def __init__(self):
        super().__init__()
        self.ultralytics_device = self.get_ultralytics_device()
        # ControlNet bridge, created lazily by init_controlnet_ext().
        self.controlnet_ext = None
        # ControlNet always-on script object (and its latest_network) recorded
        # by script_filter() when ControlNet is kept for an ADetailer pass.
        self.cn_script = None
        self.cn_latest_network = None

    def title(self):
        return AFTER_DETAILER

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the ADetailer UI and remember its infotext field bindings."""
        num_models = opts.data.get("ad_max_models", 2)
        model_list = list(model_mapping.keys())

        components, infotext_fields = adui(
            num_models,
            is_img2img,
            model_list,
            txt2img_submit_button,
            img2img_submit_button,
        )

        self.infotext_fields = infotext_fields
        return components

    def init_controlnet_ext(self) -> None:
        """Create the ControlNet bridge once.

        An ``ImportError`` raised while initializing ControlNet is printed to
        stderr instead of propagating.
        """
        if self.controlnet_ext is not None:
            return
        self.controlnet_ext = ControlNetExt()

        if controlnet_exists:
            try:
                self.controlnet_ext.init_controlnet()
            except ImportError:
                error = traceback.format_exc()
                print(
                    f"[-] ADetailer: ControlNetExt init failed:\n{error}",
                    file=sys.stderr,
                )

    def update_controlnet_args(self, p, args: ADetailerArgs) -> None:
        """Apply this unit's ControlNet model/weight/guidance to ``p``.

        Does nothing when ControlNet is unavailable or the unit's ControlNet
        model is ``"None"``.
        """
        if self.controlnet_ext is None:
            self.init_controlnet_ext()

        if (
            self.controlnet_ext is not None
            and self.controlnet_ext.cn_available
            and args.ad_controlnet_model != "None"
        ):
            self.controlnet_ext.update_scripts_args(
                p,
                model=args.ad_controlnet_model,
                weight=args.ad_controlnet_weight,
                guidance_start=args.ad_controlnet_guidance_start,
                guidance_end=args.ad_controlnet_guidance_end,
            )

    def is_ad_enabled(self, *args_) -> bool:
        """Return whether ADetailer is enabled for the given script arguments.

        Raises
        ------
        ValueError
            If no arguments, or only a single bool, were passed.
        """
        if len(args_) == 0 or (len(args_) == 1 and isinstance(args_[0], bool)):
            message = f"""
                       [-] ADetailer: Not enough arguments passed to ADetailer.
                       input: {args_!r}
                       """
            raise ValueError(dedent(message))

        a0 = args_[0]
        a1 = args_[1] if len(args_) > 1 else None
        checker = EnableChecker(a0=a0, a1=a1)
        return checker.is_enabled()

    def get_args(self, *args_) -> list[ADetailerArgs]:
        """Validate every dict in ``args_`` into an ``ADetailerArgs``.

        Non-dict entries (such as a bool enable flag) are ignored. Callers are
        expected to have passed ``args_`` through ``is_ad_enabled`` first.

        Raises
        ------
        ValueError
            If ``args_`` holds no dicts, or a dict fails validation. In the
            second case the message lists every known argument with its value.
        """
        args = [arg for arg in args_ if isinstance(arg, dict)]

        if not args:
            message = f"[-] ADetailer: Invalid arguments passed to ADetailer: {args_!r}"
            raise ValueError(message)

        all_inputs = []

        for n, arg_dict in enumerate(args, 1):
            try:
                inp = ADetailerArgs(**arg_dict)
            except ValueError as e:
                msgs = [
                    f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
                ]
                for attr in ALL_ARGS.attrs:
                    arg = arg_dict.get(attr)
                    dtype = type(arg)
                    arg = "DEFAULT" if arg is None else repr(arg)
                    msgs.append(f" {attr}: {arg} ({dtype})")
                raise ValueError("\n".join(msgs)) from e

            all_inputs.append(inp)

        return all_inputs

    def extra_params(self, arg_list: list[ADetailerArgs]) -> dict:
        """Merge every unit's infotext params (suffixed per unit) plus the version."""
        params = {}
        for n, args in enumerate(arg_list):
            params.update(args.extra_params(suffix=suffix(n)))
        params["ADetailer version"] = __version__
        return params

    @staticmethod
    def get_ultralytics_device() -> str:
        """Return the ultralytics device string; ``""`` means autodetect.

        macOS always autodetects. ``--lowvram`` / ``--medvram`` force ``"cpu"``.
        """
        device = ""
        if platform.system() == "Darwin":
            return device

        if any(getattr(cmd_opts, vram, False) for vram in ["lowvram", "medvram"]):
            device = "cpu"

        return device

    @staticmethod
    def _cyclic_get(values, i: int, default):
        """Return ``values[i]``, wrapping ``i`` modulo the length when it is past the end.

        Returns ``default`` when ``values`` is empty.
        """
        if not values:
            return default
        if i < len(values):
            return values[i]
        return values[i % len(values)]

    def prompt_blank_replacement(
        self, all_prompts: list[str], i: int, default: str
    ) -> str:
        """Return image ``i``'s prompt, used to fill empty ADetailer prompts."""
        return self._cyclic_get(all_prompts, i, default)

    def _get_prompt(
        self, ad_prompt: str, all_prompts: list[str], i: int, default: str
    ) -> list[str]:
        """Split ``ad_prompt`` on ``[SEP]`` and replace empty parts with image ``i``'s prompt."""
        prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
        blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
        return [prompt or blank_replacement for prompt in prompts]

    def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
        """Return the per-mask (prompts, negative prompts) for the current image."""
        i = p._idx

        prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt)
        negative_prompt = self._get_prompt(
            args.ad_negative_prompt, p.all_negative_prompts, i, p.negative_prompt
        )

        return prompt, negative_prompt

    def get_seed(self, p) -> tuple[int, int]:
        """Return the (seed, subseed) of the current image ``p._idx``.

        Falls back to ``p.seed`` / ``p.subseed`` when the per-image lists are
        empty, and wraps around when the index is past their end.
        """
        i = p._idx
        seed = self._cyclic_get(p.all_seeds, i, p.seed)
        subseed = self._cyclic_get(p.all_subseeds, i, p.subseed)
        return seed, subseed

    def get_width_height(self, p, args: ADetailerArgs) -> tuple[int, int]:
        """Return the inpaint size: the unit's override if enabled, else ``p``'s size."""
        if args.ad_use_inpaint_width_height:
            width = args.ad_inpaint_width
            height = args.ad_inpaint_height
        else:
            width = p.width
            height = p.height

        return width, height

    def get_steps(self, p, args: ADetailerArgs) -> int:
        """Return the unit's step override if enabled, else ``p.steps``."""
        if args.ad_use_steps:
            return args.ad_steps
        return p.steps

    def get_cfg_scale(self, p, args: ADetailerArgs) -> float:
        """Return the unit's CFG override if enabled, else ``p.cfg_scale``."""
        if args.ad_use_cfg_scale:
            return args.ad_cfg_scale
        return p.cfg_scale

    def infotext(self, p) -> str:
        """Build ``p``'s infotext string with the webui's ``create_infotext``."""
        return create_infotext(
            p, p.all_prompts, p.all_seeds, p.all_subseeds, None, 0, 0
        )

    def write_params_txt(self, p) -> None:
        """Write ``p``'s infotext to ``params.txt`` in the webui data directory."""
        infotext = self.infotext(p)
        params_txt = Path(data_path, "params.txt")
        params_txt.write_text(infotext, encoding="utf-8")

    def script_filter(self, p, args: ADetailerArgs):
        """Return copies of ``p``'s script runner and args for an ADetailer pass.

        ControlNet units in the copied args are disabled. Unless the
        ``ad_only_seleted_scripts`` option is off, only the always-on scripts
        named in ``ad_script_names`` are kept, plus ``controlnet`` when this
        unit selects a ControlNet model.
        """
        script_runner = copy(p.scripts)
        script_args = deepcopy(p.script_args)
        self.disable_controlnet_units(script_args)

        ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True)
        if not ad_only_seleted_scripts:
            return script_runner, script_args

        ad_script_names = opts.data.get("ad_script_names", SCRIPT_DEFAULT)
        # Accept each name both as typed and whitespace-stripped.
        script_names_set = {
            name
            for script_name in ad_script_names.split(",")
            for name in (script_name, script_name.strip())
        }

        if args.ad_controlnet_model != "None":
            script_names_set.add("controlnet")

        filtered_alwayson = []
        for script_object in script_runner.alwayson_scripts:
            filepath = script_object.filename
            filename = Path(filepath).stem
            if filename in script_names_set:
                filtered_alwayson.append(script_object)
                if filename == "controlnet":
                    self.cn_script = script_object
                    self.cn_latest_network = script_object.latest_network

        script_runner.alwayson_scripts = filtered_alwayson
        return script_runner, script_args

    def disable_controlnet_units(self, script_args: list[Any]) -> None:
        """Disable, in place, every ControlNet unit found in ``script_args``.

        A unit is either an object whose class name contains "controlnet", or
        a dict that has a ``"module"`` key.
        """
        for obj in script_args:
            if "controlnet" in obj.__class__.__name__.lower():
                if hasattr(obj, "enabled"):
                    obj.enabled = False
                if hasattr(obj, "input_mode"):
                    obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")

            elif isinstance(obj, dict) and "module" in obj:
                obj["enabled"] = False

    def get_i2i_p(self, p, args: ADetailerArgs, image):
        """Build the img2img inpainting job for one unit, based on ``p``."""
        seed, subseed = self.get_seed(p)
        width, height = self.get_width_height(p, args)
        steps = self.get_steps(p, args)
        cfg_scale = self.get_cfg_scale(p, args)

        sampler_name = p.sampler_name
        # NOTE(review): presumably these samplers don't work for img2img
        # inpainting here, so Euler is substituted — confirm.
        if sampler_name in ["PLMS", "UniPC"]:
            sampler_name = "Euler"

        i2i = StableDiffusionProcessingImg2Img(
            init_images=[image],
            resize_mode=0,
            denoising_strength=args.ad_denoising_strength,
            mask=None,
            mask_blur=args.ad_mask_blur,
            inpainting_fill=1,
            inpaint_full_res=args.ad_inpaint_only_masked,
            inpaint_full_res_padding=args.ad_inpaint_only_masked_padding,
            inpainting_mask_invert=0,
            sd_model=p.sd_model,
            outpath_samples=p.outpath_samples,
            outpath_grids=p.outpath_grids,
            prompt="",  # replace later
            negative_prompt="",
            styles=p.styles,
            seed=seed,
            subseed=subseed,
            subseed_strength=p.subseed_strength,
            seed_resize_from_h=p.seed_resize_from_h,
            seed_resize_from_w=p.seed_resize_from_w,
            sampler_name=sampler_name,
            batch_size=1,
            n_iter=1,
            steps=steps,
            cfg_scale=cfg_scale,
            width=width,
            height=height,
            restore_faces=args.ad_restore_face,
            tiling=p.tiling,
            # Copy so that scripts run during the nested pass cannot add keys
            # to the parent image's generation params / infotext.
            extra_generation_params=copy(p.extra_generation_params),
            do_not_save_samples=True,
            do_not_save_grid=True,
        )

        i2i.scripts, i2i.script_args = self.script_filter(p, args)
        # Prevents ADetailer from running again inside its own img2img pass.
        i2i._disable_adetailer = True

        if args.ad_controlnet_model != "None":
            self.update_controlnet_args(i2i, args)
        else:
            i2i.control_net_enabled = False

        return i2i

    def save_image(self, p, image, *, condition: str, suffix: str) -> None:
        """Save ``image`` next to ``p``'s samples if the ``condition`` option is on."""
        if not opts.data.get(condition, False):
            return

        i = p._idx
        seed, _ = self.get_seed(p)
        images.save_image(
            image=image,
            path=p.outpath_samples,
            basename="",
            seed=seed,
            prompt=p.all_prompts[i] if i < len(p.all_prompts) else p.prompt,
            extension=opts.samples_format,
            info=self.infotext(p),
            p=p,
            suffix=suffix,
        )

    def get_ad_model(self, name: str):
        """Return the model path registered under ``name``.

        Raises
        ------
        ValueError
            If ``name`` is not a known model.
        """
        if name not in model_mapping:
            msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}"
            raise ValueError(msg)
        return model_mapping[name]

    def sort_bboxes(self, pred: PredictOutput) -> PredictOutput:
        """Sort detections according to the ``ad_bbox_sortby`` option."""
        sortby = opts.data.get("ad_bbox_sortby", BBOX_SORTBY[0])
        sortby_idx = BBOX_SORTBY.index(sortby)
        pred = sort_bboxes(pred, sortby_idx)
        return pred

    def pred_preprocessing(self, pred: PredictOutput, args: ADetailerArgs):
        """Filter detections by the unit's min/max mask ratio, sort them, and return processed masks."""
        pred = filter_by_ratio(
            pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
        )
        pred = self.sort_bboxes(pred)
        return mask_preprocess(
            pred.masks,
            kernel=args.ad_dilate_erode,
            x_offset=args.ad_x_offset,
            y_offset=args.ad_y_offset,
            merge_invert=args.ad_mask_merge_invert,
        )

    def i2i_prompts_replace(
        self, i2i, prompts: list[str], negative_prompts: list[str], j: int
    ) -> None:
        """Give mask ``j`` its prompt pair; the last pair is reused when masks outnumber prompts."""
        i1 = min(j, len(prompts) - 1)
        i2 = min(j, len(negative_prompts) - 1)
        prompt = prompts[i1]
        negative_prompt = negative_prompts[i2]
        i2i.prompt = prompt
        i2i.negative_prompt = negative_prompt

    def is_need_call_process(self, p) -> bool:
        """True when the current image is the last of its batch, but not the last image overall."""
        i = p._idx
        n_iter = p.iteration
        bs = p.batch_size
        return (i == (n_iter + 1) * bs - 1) and (i != len(p.all_prompts) - 1)

    def process(self, p, *args_):
        """Webui hook: record ADetailer params in the infotext and reset the image counter."""
        if getattr(p, "_disable_adetailer", False):
            return

        if self.is_ad_enabled(*args_):
            arg_list = self.get_args(*args_)
            extra_params = self.extra_params(arg_list)
            p.extra_generation_params.update(extra_params)

        # postprocess_image() increments this before handling each image.
        p._idx = -1

    def _postprocess_image(self, p, pp, args: ADetailerArgs, *, n: int = 0) -> bool:
        """Run the ``n``-th ADetailer unit (``args``) on ``pp.image``.

        Masks are inpainted one after another, each pass starting from the
        previous pass's output. ``pp.image`` is replaced with the final result.

        Returns
        -------
        bool
            `True` if image was processed, `False` otherwise.
        """
        if state.interrupted:
            return False

        i = p._idx

        i2i = self.get_i2i_p(p, args, pp.image)
        seed, subseed = self.get_seed(p)
        ad_prompts, ad_negatives = self.get_prompt(p, args)

        is_mediapipe = args.ad_model.lower().startswith("mediapipe")

        kwargs = {}
        if is_mediapipe:
            predictor = mediapipe_predict
            ad_model = args.ad_model
        else:
            predictor = ultralytics_predict
            ad_model = self.get_ad_model(args.ad_model)
            kwargs["device"] = self.ultralytics_device

        with change_torch_load():
            pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)

        masks = self.pred_preprocessing(pred, args)

        if not masks:
            print(
                f"[-] ADetailer: nothing detected on image {i + 1} with {ordinal(n + 1)} settings."
            )
            return False

        self.save_image(
            p,
            pred.preview,
            condition="ad_save_previews",
            suffix="-ad-preview" + suffix(n, "-"),
        )

        steps = len(masks)
        state.job_count += steps

        if is_mediapipe:
            print(f"mediapipe: {steps} detected.")

        final_image = None
        p2 = copy(i2i)
        for j in range(steps):
            p2.image_mask = masks[j]
            self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)

            # A prompt that is just "[SKIP]" (surrounding whitespace allowed)
            # leaves this mask untouched.
            if not re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
                if args.ad_controlnet_model == "None":
                    cn_restore_unet_hook(p2, self.cn_latest_network)
                processed = process_images(p2)

                # Guard against a result without images (presumably e.g. an
                # interrupted run): keep earlier passes instead of IndexError.
                if processed is None or not processed.images:
                    break
                final_image = processed.images[0]

                # Chain the next mask onto this pass's output.
                p2 = copy(i2i)
                p2.init_images = [final_image]

            # Mask j + 1 is inpainted with seed + j + 1.
            p2.seed = seed + j + 1
            p2.subseed = subseed + j + 1

        if final_image is None:
            return False

        pp.image = final_image
        return True

    def postprocess_image(self, p, pp, *args_):
        """Webui hook: run every enabled ADetailer unit on ``pp.image``."""
        if getattr(p, "_disable_adetailer", False):
            return

        if not self.is_ad_enabled(*args_):
            return

        p._idx = getattr(p, "_idx", -1) + 1
        init_image = copy(pp.image)
        arg_list = self.get_args(*args_)

        is_processed = False
        with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
            for n, args in enumerate(arg_list):
                if args.ad_model == "None":
                    continue
                is_processed |= self._postprocess_image(p, pp, args, n=n)

        if is_processed:
            self.save_image(
                p, init_image, condition="ad_save_images_before", suffix="-ad-before"
            )

        # NOTE(review): presumably re-runs ControlNet's process() so its state
        # is set up again for the next batch — confirm.
        if self.cn_script is not None and self.is_need_call_process(p):
            self.cn_script.process(p)

        # Writing params.txt is best-effort: report a failure but never abort
        # image generation because of it.
        try:
            if p._idx == len(p.all_prompts) - 1:
                self.write_params_txt(p)
        except Exception:
            print(
                f"[-] ADetailer: failed to write params.txt:\n{traceback.format_exc()}",
                file=sys.stderr,
            )
def on_after_component(component, **_kwargs):
    """Remember the txt2img / img2img Generate buttons as the webui builds them.

    Components without a matching ``elem_id`` are ignored.
    """
    global txt2img_submit_button, img2img_submit_button

    elem_id = getattr(component, "elem_id", None)
    if elem_id == "txt2img_generate":
        txt2img_submit_button = component
    elif elem_id == "img2img_generate":
        img2img_submit_button = component
def on_ui_settings():
    """Register ADetailer's options in their own section of the webui settings."""
    section = ("ADetailer", AFTER_DETAILER)
    shared.opts.add_option(
        "ad_max_models",
        shared.OptionInfo(
            default=2,
            label="Max models",
            component=gr.Slider,
            component_args={"minimum": 1, "maximum": 5, "step": 1},
            section=section,
        ),
    )

    shared.opts.add_option(
        "ad_save_previews",
        shared.OptionInfo(False, "Save mask previews", section=section),
    )

    shared.opts.add_option(
        "ad_save_images_before",
        shared.OptionInfo(False, "Save images before ADetailer", section=section),
    )

    # NOTE: the key is spelled "seleted" and must match the lookup in
    # AfterDetailerScript.script_filter; presumably it is also persisted in
    # users' settings, so it should not be renamed here alone.
    shared.opts.add_option(
        "ad_only_seleted_scripts",
        shared.OptionInfo(
            True, "Apply only selected scripts to ADetailer", section=section
        ),
    )

    textbox_args = {
        "placeholder": "comma-separated list of script names",
        "interactive": True,
    }

    shared.opts.add_option(
        "ad_script_names",
        shared.OptionInfo(
            default=SCRIPT_DEFAULT,
            label="Script names to apply to ADetailer (separated by comma)",
            component=gr.Textbox,
            component_args=textbox_args,
            section=section,
        ),
    )

    # Default to the first choice so the default is always one of the Radio's
    # choices and agrees with the BBOX_SORTBY[0] fallback that
    # AfterDetailerScript.sort_bboxes passes to BBOX_SORTBY.index().
    shared.opts.add_option(
        "ad_bbox_sortby",
        shared.OptionInfo(
            default=BBOX_SORTBY[0],
            label="Sort bounding boxes by",
            component=gr.Radio,
            component_args={"choices": BBOX_SORTBY},
            section=section,
        ),
    )
# Register the settings page, and the per-component hook that captures the
# txt2img/img2img Generate buttons later passed to the ADetailer UI.
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_after_component(on_after_component)