diff --git a/modules/anisotropic.py b/modules/anisotropic.py
new file mode 100644
index 0000000000000000000000000000000000000000..576822240762b7dfcfb27e49364314ee1cb436d9
--- /dev/null
+++ b/modules/anisotropic.py
@@ -0,0 +1,200 @@
+import torch
+
+
+Tensor = torch.Tensor
+Device = torch.DeviceObjType
+Dtype = torch.Type
+pad = torch.nn.functional.pad
+
+
+def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
+    ky, kx = _unpack_2d_ks(kernel_size)
+    return (ky - 1) // 2, (kx - 1) // 2
+
+
+def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
+    if isinstance(kernel_size, int):
+        ky = kx = kernel_size
+    else:
+        assert len(kernel_size) == 2, '2D Kernel size should have a length of 2.'
+        ky, kx = kernel_size
+
+    ky = int(ky)
+    kx = int(kx)
+    return ky, kx
+
+
+def gaussian(
+    window_size: int, sigma: Tensor | float, *, device: Device | None = None, dtype: Dtype | None = None
+) -> Tensor:
+
+    batch_size = sigma.shape[0]
+
+    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+    if window_size % 2 == 0:
+        x = x + 0.5
+
+    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+    return gauss / gauss.sum(-1, keepdim=True)
+
+
+def get_gaussian_kernel1d(
+    kernel_size: int,
+    sigma: float | Tensor,
+    force_even: bool = False,
+    *,
+    device: Device | None = None,
+    dtype: Dtype | None = None,
+) -> Tensor:
+
+    return gaussian(kernel_size, sigma, device=device, dtype=dtype)
+
+
+def get_gaussian_kernel2d(
+    kernel_size: tuple[int, int] | int,
+    sigma: tuple[float, float] | Tensor,
+    force_even: bool = False,
+    *,
+    device: Device | None = None,
+    dtype: Dtype | None = None,
+) -> Tensor:
+
+    sigma = torch.Tensor([[sigma, sigma]]).to(device=device, dtype=dtype)
+
+    ksize_y, ksize_x = _unpack_2d_ks(kernel_size)
+    sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None]
+
+    kernel_y = get_gaussian_kernel1d(ksize_y, sigma_y, force_even, device=device, dtype=dtype)[..., None]
+    kernel_x = get_gaussian_kernel1d(ksize_x, sigma_x, force_even, device=device, dtype=dtype)[..., None]
+
+    return kernel_y * kernel_x.view(-1, 1, ksize_x)
+
+
+def _bilateral_blur(
+    input: Tensor,
+    guidance: Tensor | None,
+    kernel_size: tuple[int, int] | int,
+    sigma_color: float | Tensor,
+    sigma_space: tuple[float, float] | Tensor,
+    border_type: str = 'reflect',
+    color_distance_type: str = 'l1',
+) -> Tensor:
+
+    if isinstance(sigma_color, Tensor):
+        sigma_color = sigma_color.to(device=input.device, dtype=input.dtype).view(-1, 1, 1, 1, 1)
+
+    ky, kx = _unpack_2d_ks(kernel_size)
+    pad_y, pad_x = _compute_zero_padding(kernel_size)
+
+    padded_input = pad(input, (pad_x, pad_x, pad_y, pad_y), mode=border_type)
+    unfolded_input = padded_input.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2)  # (B, C, H, W, Ky x Kx)
+
+    if guidance is None:
+        guidance = input
+        unfolded_guidance = unfolded_input
+    else:
+        padded_guidance = pad(guidance, (pad_x, pad_x, pad_y, pad_y), mode=border_type)
+        unfolded_guidance = padded_guidance.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2)  # (B, C, H, W, Ky x Kx)
+
+    diff = unfolded_guidance - guidance.unsqueeze(-1)
+    if color_distance_type == "l1":
+        color_distance_sq = diff.abs().sum(1, keepdim=True).square()
+    elif color_distance_type == "l2":
+        color_distance_sq = diff.square().sum(1, keepdim=True)
+    else:
+        raise ValueError("color_distance_type only accepts l1 or l2")
+    color_kernel = (-0.5 / sigma_color**2 * color_distance_sq).exp()  # (B, 1, H, W, Ky x Kx)
+
+    space_kernel = get_gaussian_kernel2d(kernel_size, sigma_space, device=input.device, dtype=input.dtype)
+    space_kernel = space_kernel.view(-1, 1, 1, 1, kx * ky)
+
+    kernel = space_kernel * color_kernel
+    out = (unfolded_input * kernel).sum(-1) / kernel.sum(-1)
+    return out
+
+
+def bilateral_blur(
+    input: Tensor,
+    kernel_size: tuple[int, int] | int = (13, 13),
+    sigma_color: float | Tensor = 3.0,
+    sigma_space: tuple[float, float] | Tensor = 3.0,
+    border_type: str = 'reflect',
+    color_distance_type: str = 'l1',
+) -> Tensor:
+    return _bilateral_blur(input, None, kernel_size, sigma_color, sigma_space, border_type, color_distance_type)
+
+
+def adaptive_anisotropic_filter(x, g=None):
+    if g is None:
+        g = x
+    s, m = torch.std_mean(g, dim=(1, 2, 3), keepdim=True)
+    s = s + 1e-5
+    guidance = (g - m) / s
+    y = _bilateral_blur(x, guidance,
+                        kernel_size=(13, 13),
+                        sigma_color=3.0,
+                        sigma_space=3.0,
+                        border_type='reflect',
+                        color_distance_type='l1')
+    return y
+
+
+def joint_bilateral_blur(
+    input: Tensor,
+    guidance: Tensor,
+    kernel_size: tuple[int, int] | int,
+    sigma_color: float | Tensor,
+    sigma_space: tuple[float, float] | Tensor,
+    border_type: str = 'reflect',
+    color_distance_type: str = 'l1',
+) -> Tensor:
+    return _bilateral_blur(input, guidance, kernel_size, sigma_color, sigma_space, border_type, color_distance_type)
+
+
+class _BilateralBlur(torch.nn.Module):
+    def __init__(
+        self,
+        kernel_size: tuple[int, int] | int,
+        sigma_color: float | Tensor,
+        sigma_space: tuple[float, float] | Tensor,
+        border_type: str = 'reflect',
+        color_distance_type: str = "l1",
+    ) -> None:
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.sigma_color = sigma_color
+        self.sigma_space = sigma_space
+        self.border_type = border_type
+        self.color_distance_type = color_distance_type
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}"
+            f"(kernel_size={self.kernel_size}, "
+            f"sigma_color={self.sigma_color}, "
+            f"sigma_space={self.sigma_space}, "
+            f"border_type={self.border_type}, "
+            f"color_distance_type={self.color_distance_type})"
+        )
+
+
+class BilateralBlur(_BilateralBlur):
+    def forward(self, input: Tensor) -> Tensor:
+        return bilateral_blur(
+            input, self.kernel_size, self.sigma_color, self.sigma_space, self.border_type, self.color_distance_type
+        )
+
+
+class JointBilateralBlur(_BilateralBlur):
+    def forward(self, input: Tensor, guidance: Tensor) -> Tensor:
+        return joint_bilateral_blur(
+            input,
+            guidance,
+            self.kernel_size,
+            self.sigma_color,
+            self.sigma_space,
+            self.border_type,
+            self.color_distance_type,
+        )
diff --git a/modules/async_worker.py b/modules/async_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8661f4ddf1b1ce491f1008395fb09a527b5deb7
--- /dev/null
+++ b/modules/async_worker.py
@@ -0,0 +1,914 @@
+import threading
+from modules.patch import PatchSettings, patch_settings, patch_all
+
+patch_all()
+
+class AsyncTask:
+    def __init__(self, args):
+        self.args = args
+        self.yields = []
+        self.results = []
+        self.last_stop = False
+        self.processing = False
+
+
+async_tasks = []
+
+
+def worker():
+    global async_tasks
+
+    import os
+    import traceback
+    import math
+    import numpy as np
+    import cv2
+    import torch
+    import time
+    import shared
+    import random
+    import copy
+    import modules.default_pipeline as pipeline
+    import modules.core as core
+    import modules.flags as flags
+    import modules.config
+    import modules.patch
+    import ldm_patched.modules.model_management
+    import extras.preprocessors as preprocessors
+    import modules.inpaint_worker as inpaint_worker
+    import modules.constants as constants
+    import extras.ip_adapter as ip_adapter
+    import extras.face_crop
+    import fooocus_version
+    import args_manager
+
+    from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
+    from modules.private_logger import log
+    from extras.expansion import safe_str
+    from modules.util import remove_empty_str, HWC3, resize_image, \
+        get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix
+    from modules.upscaler import perform_upscale
+    from modules.flags import Performance
+    from modules.meta_parser import get_metadata_parser, MetadataScheme
+
+    pid = os.getpid()
+    print(f'Started worker with PID
{pid}') + + try: + async_gradio_app = shared.gradio_root + flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}''' + if async_gradio_app.share: + flag += f''' or {async_gradio_app.share_url}''' + print(flag) + except Exception as e: + print(e) + + def progressbar(async_task, number, text): + print(f'[Fooocus] {text}') + async_task.yields.append(['preview', (number, text, None)]) + + def yield_result(async_task, imgs, do_not_show_finished_images=False): + if not isinstance(imgs, list): + imgs = [imgs] + + async_task.results = async_task.results + imgs + + if do_not_show_finished_images: + return + + async_task.yields.append(['results', async_task.results]) + return + + def build_image_wall(async_task): + results = [] + + if len(async_task.results) < 2: + return + + for img in async_task.results: + if isinstance(img, str) and os.path.exists(img): + img = cv2.imread(img) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if not isinstance(img, np.ndarray): + return + if img.ndim != 3: + return + results.append(img) + + H, W, C = results[0].shape + + for img in results: + Hn, Wn, Cn = img.shape + if H != Hn: + return + if W != Wn: + return + if C != Cn: + return + + cols = float(len(results)) ** 0.5 + cols = int(math.ceil(cols)) + rows = float(len(results)) / float(cols) + rows = int(math.ceil(rows)) + + wall = np.zeros(shape=(H * rows, W * cols, C), dtype=np.uint8) + + for y in range(rows): + for x in range(cols): + if y * cols + x < len(results): + img = results[y * cols + x] + wall[y * H:y * H + H, x * W:x * W + W, :] = img + + # must use deep copy otherwise gradio is super laggy. Do not use list.append() . + async_task.results = async_task.results + [wall] + return + + def apply_enabled_loras(loras): + enabled_loras = [] + for lora_enabled, lora_model, lora_weight in loras: + if lora_enabled: + enabled_loras.append([lora_model, lora_weight]) + + return enabled_loras + + @torch.no_grad() + @torch.inference_mode() + def handler(async_task): + execution_start_time = time.perf_counter() + async_task.processing = True + + args = async_task.args + args.reverse() + + prompt = args.pop() + negative_prompt = args.pop() + style_selections = args.pop() + performance_selection = Performance(args.pop()) + aspect_ratios_selection = args.pop() + image_number = args.pop() + output_format = args.pop() + image_seed = args.pop() + sharpness = args.pop() + guidance_scale = args.pop() + base_model_name = args.pop() + refiner_model_name = args.pop() + refiner_switch = args.pop() + loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)]) + input_image_checkbox = args.pop() + current_tab = args.pop() + uov_method = args.pop() + uov_input_image = args.pop() + outpaint_selections = args.pop() + inpaint_input_image = args.pop() + inpaint_additional_prompt = args.pop() + inpaint_mask_image_upload = args.pop() + + disable_preview = args.pop() + disable_intermediate_results = args.pop() + disable_seed_increment = args.pop() + adm_scaler_positive = args.pop() + adm_scaler_negative = args.pop() + adm_scaler_end = args.pop() + adaptive_cfg = args.pop() + sampler_name = args.pop() + scheduler_name = args.pop() + overwrite_step = args.pop() + overwrite_switch = args.pop() + overwrite_width = args.pop() + overwrite_height = args.pop() + overwrite_vary_strength = args.pop() + overwrite_upscale_strength = args.pop() + 
mixing_image_prompt_and_vary_upscale = args.pop() + mixing_image_prompt_and_inpaint = args.pop() + debugging_cn_preprocessor = args.pop() + skipping_cn_preprocessor = args.pop() + canny_low_threshold = args.pop() + canny_high_threshold = args.pop() + refiner_swap_method = args.pop() + controlnet_softness = args.pop() + freeu_enabled = args.pop() + freeu_b1 = args.pop() + freeu_b2 = args.pop() + freeu_s1 = args.pop() + freeu_s2 = args.pop() + debugging_inpaint_preprocessor = args.pop() + inpaint_disable_initial_latent = args.pop() + inpaint_engine = args.pop() + inpaint_strength = args.pop() + inpaint_respective_field = args.pop() + inpaint_mask_upload_checkbox = args.pop() + invert_mask_checkbox = args.pop() + inpaint_erode_or_dilate = args.pop() + + save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False + metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS + + cn_tasks = {x: [] for x in flags.ip_list} + for _ in range(flags.controlnet_image_count): + cn_img = args.pop() + cn_stop = args.pop() + cn_weight = args.pop() + cn_type = args.pop() + if cn_img is not None: + cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight]) + + outpaint_selections = [o.lower() for o in outpaint_selections] + base_model_additional_loras = [] + raw_style_selections = copy.deepcopy(style_selections) + uov_method = uov_method.lower() + + if fooocus_expansion in style_selections: + use_expansion = True + style_selections.remove(fooocus_expansion) + else: + use_expansion = False + + use_style = len(style_selections) > 0 + + if base_model_name == refiner_model_name: + print(f'Refiner disabled because base model and refiner are same.') + refiner_model_name = 'None' + + steps = performance_selection.steps() + + if performance_selection == Performance.EXTREME_SPEED: + print('Enter LCM mode.') + progressbar(async_task, 1, 'Downloading LCM components ...') + loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)] + + if refiner_model_name != 'None': + print(f'Refiner disabled in LCM mode.') + + refiner_model_name = 'None' + sampler_name = 'lcm' + scheduler_name = 'lcm' + sharpness = 0.0 + guidance_scale = 1.0 + adaptive_cfg = 1.0 + refiner_switch = 1.0 + adm_scaler_positive = 1.0 + adm_scaler_negative = 1.0 + adm_scaler_end = 0.0 + + print(f'[Parameters] Adaptive CFG = {adaptive_cfg}') + print(f'[Parameters] Sharpness = {sharpness}') + print(f'[Parameters] ControlNet Softness = {controlnet_softness}') + print(f'[Parameters] ADM Scale = ' + f'{adm_scaler_positive} : ' + f'{adm_scaler_negative} : ' + f'{adm_scaler_end}') + + patch_settings[pid] = PatchSettings( + sharpness, + adm_scaler_end, + adm_scaler_positive, + adm_scaler_negative, + controlnet_softness, + adaptive_cfg + ) + + cfg_scale = float(guidance_scale) + print(f'[Parameters] CFG = {cfg_scale}') + + initial_latent = None + denoising_strength = 1.0 + tiled = False + + width, height = aspect_ratios_selection.replace('×', ' ').split(' ')[:2] + width, height = int(width), int(height) + + skip_prompt_processing = False + + inpaint_worker.current_task = None + inpaint_parameterized = inpaint_engine != 'None' + inpaint_image = None + inpaint_mask = None + inpaint_head_model_path = None + + use_synthetic_refiner = False + + controlnet_canny_path = None + controlnet_cpds_path = None + clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None + + seed = int(image_seed) + print(f'[Parameters] Seed = {seed}') + + goals = [] + tasks 
= [] + + if input_image_checkbox: + if (current_tab == 'uov' or ( + current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \ + and uov_method != flags.disabled and uov_input_image is not None: + uov_input_image = HWC3(uov_input_image) + if 'vary' in uov_method: + goals.append('vary') + elif 'upscale' in uov_method: + goals.append('upscale') + if 'fast' in uov_method: + skip_prompt_processing = True + else: + steps = performance_selection.steps_uov() + + progressbar(async_task, 1, 'Downloading upscale models ...') + modules.config.downloading_upscale_model() + if (current_tab == 'inpaint' or ( + current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \ + and isinstance(inpaint_input_image, dict): + inpaint_image = inpaint_input_image['image'] + inpaint_mask = inpaint_input_image['mask'][:, :, 0] + + if inpaint_mask_upload_checkbox: + if isinstance(inpaint_mask_image_upload, np.ndarray): + if inpaint_mask_image_upload.ndim == 3: + H, W, C = inpaint_image.shape + inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H) + inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2) + inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255 + inpaint_mask = np.maximum(inpaint_mask, inpaint_mask_image_upload) + + if int(inpaint_erode_or_dilate) != 0: + inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate) + + if invert_mask_checkbox: + inpaint_mask = 255 - inpaint_mask + + inpaint_image = HWC3(inpaint_image) + if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \ + and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0): + progressbar(async_task, 1, 'Downloading upscale models ...') + modules.config.downloading_upscale_model() + if inpaint_parameterized: + progressbar(async_task, 1, 'Downloading inpainter ...') + inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models( + inpaint_engine) + base_model_additional_loras += [(inpaint_patch_model_path, 1.0)] + print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}') + if refiner_model_name == 'None': + use_synthetic_refiner = True + refiner_switch = 0.5 + else: + inpaint_head_model_path, inpaint_patch_model_path = None, None + print(f'[Inpaint] Parameterized inpaint is disabled.') + if inpaint_additional_prompt != '': + if prompt == '': + prompt = inpaint_additional_prompt + else: + prompt = inpaint_additional_prompt + '\n' + prompt + goals.append('inpaint') + if current_tab == 'ip' or \ + mixing_image_prompt_and_vary_upscale or \ + mixing_image_prompt_and_inpaint: + goals.append('cn') + progressbar(async_task, 1, 'Downloading control models ...') + if len(cn_tasks[flags.cn_canny]) > 0: + controlnet_canny_path = modules.config.downloading_controlnet_canny() + if len(cn_tasks[flags.cn_cpds]) > 0: + controlnet_cpds_path = modules.config.downloading_controlnet_cpds() + if len(cn_tasks[flags.cn_ip]) > 0: + clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip') + if len(cn_tasks[flags.cn_ip_face]) > 0: + clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters( + 'face') + progressbar(async_task, 1, 'Loading control models ...') + + # Load or unload CNs + pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path]) + ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path) + ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, 
ip_adapter_face_path) + + if overwrite_step > 0: + steps = overwrite_step + + switch = int(round(steps * refiner_switch)) + + if overwrite_switch > 0: + switch = overwrite_switch + + if overwrite_width > 0: + width = overwrite_width + + if overwrite_height > 0: + height = overwrite_height + + print(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}') + print(f'[Parameters] Steps = {steps} - {switch}') + + progressbar(async_task, 1, 'Initializing ...') + + if not skip_prompt_processing: + + prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='') + negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='') + + prompt = prompts[0] + negative_prompt = negative_prompts[0] + + if prompt == '': + # disable expansion when empty since it is not meaningful and influences image prompt + use_expansion = False + + extra_positive_prompts = prompts[1:] if len(prompts) > 1 else [] + extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] + + progressbar(async_task, 3, 'Loading models ...') + pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, + loras=loras, base_model_additional_loras=base_model_additional_loras, + use_synthetic_refiner=use_synthetic_refiner) + + progressbar(async_task, 3, 'Processing prompts ...') + tasks = [] + + for i in range(image_number): + if disable_seed_increment: + task_seed = seed + else: + task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not + + task_rng = random.Random(task_seed) # may bind to inpaint noise in the future + task_prompt = apply_wildcards(prompt, task_rng) + task_prompt = apply_arrays(task_prompt, i) + task_negative_prompt = apply_wildcards(negative_prompt, task_rng) + task_extra_positive_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_positive_prompts] + task_extra_negative_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_negative_prompts] + + positive_basic_workloads = [] + negative_basic_workloads = [] + + if use_style: + for s in style_selections: + p, n = apply_style(s, positive=task_prompt) + positive_basic_workloads = positive_basic_workloads + p + negative_basic_workloads = negative_basic_workloads + n + else: + positive_basic_workloads.append(task_prompt) + + negative_basic_workloads.append(task_negative_prompt) # Always use independent workload for negative. 
+ + positive_basic_workloads = positive_basic_workloads + task_extra_positive_prompts + negative_basic_workloads = negative_basic_workloads + task_extra_negative_prompts + + positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=task_prompt) + negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=task_negative_prompt) + + tasks.append(dict( + task_seed=task_seed, + task_prompt=task_prompt, + task_negative_prompt=task_negative_prompt, + positive=positive_basic_workloads, + negative=negative_basic_workloads, + expansion='', + c=None, + uc=None, + positive_top_k=len(positive_basic_workloads), + negative_top_k=len(negative_basic_workloads), + log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts), + log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts), + )) + + if use_expansion: + for i, t in enumerate(tasks): + progressbar(async_task, 5, f'Preparing Fooocus text #{i + 1} ...') + expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed']) + print(f'[Prompt Expansion] {expansion}') + t['expansion'] = expansion + t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy. + + for i, t in enumerate(tasks): + progressbar(async_task, 7, f'Encoding positive #{i + 1} ...') + t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k']) + + for i, t in enumerate(tasks): + if abs(float(cfg_scale) - 1.0) < 1e-4: + t['uc'] = pipeline.clone_cond(t['c']) + else: + progressbar(async_task, 10, f'Encoding negative #{i + 1} ...') + t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k']) + + if len(goals) > 0: + progressbar(async_task, 13, 'Image processing ...') + + if 'vary' in goals: + if 'subtle' in uov_method: + denoising_strength = 0.5 + if 'strong' in uov_method: + denoising_strength = 0.85 + if overwrite_vary_strength > 0: + denoising_strength = overwrite_vary_strength + + shape_ceil = get_image_shape_ceil(uov_input_image) + if shape_ceil < 1024: + print(f'[Vary] Image is resized because it is too small.') + shape_ceil = 1024 + elif shape_ceil > 2048: + print(f'[Vary] Image is resized because it is too big.') + shape_ceil = 2048 + + uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil) + + initial_pixels = core.numpy_to_pytorch(uov_input_image) + progressbar(async_task, 13, 'VAE encoding ...') + + candidate_vae, _ = pipeline.get_candidate_vae( + steps=steps, + switch=switch, + denoise=denoising_strength, + refiner_swap_method=refiner_swap_method + ) + + initial_latent = core.encode_vae(vae=candidate_vae, pixels=initial_pixels) + B, C, H, W = initial_latent['samples'].shape + width = W * 8 + height = H * 8 + print(f'Final resolution is {str((height, width))}.') + + if 'upscale' in goals: + H, W, C = uov_input_image.shape + progressbar(async_task, 13, f'Upscaling image from {str((H, W))} ...') + uov_input_image = perform_upscale(uov_input_image) + print(f'Image upscaled.') + + if '1.5x' in uov_method: + f = 1.5 + elif '2x' in uov_method: + f = 2.0 + else: + f = 1.0 + + shape_ceil = get_shape_ceil(H * f, W * f) + + if shape_ceil < 1024: + print(f'[Upscale] Image is resized because it is too small.') + uov_input_image = set_image_shape_ceil(uov_input_image, 1024) + shape_ceil = 1024 + else: + uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f) + + image_is_super_large = shape_ceil > 2800 + + if 'fast' in uov_method: + direct_return = True + elif image_is_super_large: + print('Image is too large. 
Directly returned the SR image. ' + 'Usually directly return SR image at 4K resolution ' + 'yields better results than SDXL diffusion.') + direct_return = True + else: + direct_return = False + + if direct_return: + d = [('Upscale (Fast)', 'upscale_fast', '2x')] + uov_input_image_path = log(uov_input_image, d, output_format=output_format) + yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True) + return + + tiled = True + denoising_strength = 0.382 + + if overwrite_upscale_strength > 0: + denoising_strength = overwrite_upscale_strength + + initial_pixels = core.numpy_to_pytorch(uov_input_image) + progressbar(async_task, 13, 'VAE encoding ...') + + candidate_vae, _ = pipeline.get_candidate_vae( + steps=steps, + switch=switch, + denoise=denoising_strength, + refiner_swap_method=refiner_swap_method + ) + + initial_latent = core.encode_vae( + vae=candidate_vae, + pixels=initial_pixels, tiled=True) + B, C, H, W = initial_latent['samples'].shape + width = W * 8 + height = H * 8 + print(f'Final resolution is {str((height, width))}.') + + if 'inpaint' in goals: + if len(outpaint_selections) > 0: + H, W, C = inpaint_image.shape + if 'top' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[int(H * 0.3), 0], [0, 0], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[int(H * 0.3), 0], [0, 0]], mode='constant', + constant_values=255) + if 'bottom' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[0, int(H * 0.3)], [0, 0], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, int(H * 0.3)], [0, 0]], mode='constant', + constant_values=255) + + H, W, C = inpaint_image.shape + if 'left' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[0, 0], [int(H * 0.3), 0], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(H * 0.3), 0]], mode='constant', + constant_values=255) + if 'right' in outpaint_selections: + inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(H * 0.3)], [0, 0]], mode='edge') + inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(H * 0.3)]], mode='constant', + constant_values=255) + + inpaint_image = np.ascontiguousarray(inpaint_image.copy()) + inpaint_mask = np.ascontiguousarray(inpaint_mask.copy()) + inpaint_strength = 1.0 + inpaint_respective_field = 1.0 + + denoising_strength = inpaint_strength + + inpaint_worker.current_task = inpaint_worker.InpaintWorker( + image=inpaint_image, + mask=inpaint_mask, + use_fill=denoising_strength > 0.99, + k=inpaint_respective_field + ) + + if debugging_inpaint_preprocessor: + yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), + do_not_show_finished_images=True) + return + + progressbar(async_task, 13, 'VAE Inpaint encoding ...') + + inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill) + inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image) + inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask) + + candidate_vae, candidate_vae_swap = pipeline.get_candidate_vae( + steps=steps, + switch=switch, + denoise=denoising_strength, + refiner_swap_method=refiner_swap_method + ) + + latent_inpaint, latent_mask = core.encode_vae_inpaint( + mask=inpaint_pixel_mask, + vae=candidate_vae, + pixels=inpaint_pixel_image) + + latent_swap = None + if candidate_vae_swap is not None: + progressbar(async_task, 13, 'VAE SD15 encoding ...') + latent_swap = core.encode_vae( + vae=candidate_vae_swap, + 
pixels=inpaint_pixel_fill)['samples'] + + progressbar(async_task, 13, 'VAE encoding ...') + latent_fill = core.encode_vae( + vae=candidate_vae, + pixels=inpaint_pixel_fill)['samples'] + + inpaint_worker.current_task.load_latent( + latent_fill=latent_fill, latent_mask=latent_mask, latent_swap=latent_swap) + + if inpaint_parameterized: + pipeline.final_unet = inpaint_worker.current_task.patch( + inpaint_head_model_path=inpaint_head_model_path, + inpaint_latent=latent_inpaint, + inpaint_latent_mask=latent_mask, + model=pipeline.final_unet + ) + + if not inpaint_disable_initial_latent: + initial_latent = {'samples': latent_fill} + + B, C, H, W = latent_fill.shape + height, width = H * 8, W * 8 + final_height, final_width = inpaint_worker.current_task.image.shape[:2] + print(f'Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.') + + if 'cn' in goals: + for task in cn_tasks[flags.cn_canny]: + cn_img, cn_stop, cn_weight = task + cn_img = resize_image(HWC3(cn_img), width=width, height=height) + + if not skipping_cn_preprocessor: + cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold) + + cn_img = HWC3(cn_img) + task[0] = core.numpy_to_pytorch(cn_img) + if debugging_cn_preprocessor: + yield_result(async_task, cn_img, do_not_show_finished_images=True) + return + for task in cn_tasks[flags.cn_cpds]: + cn_img, cn_stop, cn_weight = task + cn_img = resize_image(HWC3(cn_img), width=width, height=height) + + if not skipping_cn_preprocessor: + cn_img = preprocessors.cpds(cn_img) + + cn_img = HWC3(cn_img) + task[0] = core.numpy_to_pytorch(cn_img) + if debugging_cn_preprocessor: + yield_result(async_task, cn_img, do_not_show_finished_images=True) + return + for task in cn_tasks[flags.cn_ip]: + cn_img, cn_stop, cn_weight = task + cn_img = HWC3(cn_img) + + # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75 + cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0) + + task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path) + if debugging_cn_preprocessor: + yield_result(async_task, cn_img, do_not_show_finished_images=True) + return + for task in cn_tasks[flags.cn_ip_face]: + cn_img, cn_stop, cn_weight = task + cn_img = HWC3(cn_img) + + if not skipping_cn_preprocessor: + cn_img = extras.face_crop.crop_image(cn_img) + + # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75 + cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0) + + task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path) + if debugging_cn_preprocessor: + yield_result(async_task, cn_img, do_not_show_finished_images=True) + return + + all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face] + + if len(all_ip_tasks) > 0: + pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks) + + if freeu_enabled: + print(f'FreeU is enabled!') + pipeline.final_unet = core.apply_freeu( + pipeline.final_unet, + freeu_b1, + freeu_b2, + freeu_s1, + freeu_s2 + ) + + all_steps = steps * image_number + + print(f'[Parameters] Denoising Strength = {denoising_strength}') + + if isinstance(initial_latent, dict) and 'samples' in initial_latent: + log_shape = initial_latent['samples'].shape + else: + log_shape = f'Image Space {(height, width)}' + + print(f'[Parameters] Initial Latent shape: {log_shape}') + + preparation_time = time.perf_counter() - execution_start_time + print(f'Preparation 
time: {preparation_time:.2f} seconds') + + final_sampler_name = sampler_name + final_scheduler_name = scheduler_name + + if scheduler_name == 'lcm': + final_scheduler_name = 'sgm_uniform' + if pipeline.final_unet is not None: + pipeline.final_unet = core.opModelSamplingDiscrete.patch( + pipeline.final_unet, + sampling='lcm', + zsnr=False)[0] + if pipeline.final_refiner_unet is not None: + pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch( + pipeline.final_refiner_unet, + sampling='lcm', + zsnr=False)[0] + print('Using lcm scheduler.') + + async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)]) + + def callback(step, x0, x, total_steps, y): + done_steps = current_task_id * steps + step + async_task.yields.append(['preview', ( + int(15.0 + 85.0 * float(done_steps) / float(all_steps)), + f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)]) + + for current_task_id, task in enumerate(tasks): + execution_start_time = time.perf_counter() + + try: + if async_task.last_stop is not False: + ldm_patched.model_management.interrupt_current_processing() + positive_cond, negative_cond = task['c'], task['uc'] + + if 'cn' in goals: + for cn_flag, cn_path in [ + (flags.cn_canny, controlnet_canny_path), + (flags.cn_cpds, controlnet_cpds_path) + ]: + for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]: + positive_cond, negative_cond = core.apply_controlnet( + positive_cond, negative_cond, + pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop) + + imgs = pipeline.process_diffusion( + positive_cond=positive_cond, + negative_cond=negative_cond, + steps=steps, + switch=switch, + width=width, + height=height, + image_seed=task['task_seed'], + callback=callback, + sampler_name=final_sampler_name, + scheduler_name=final_scheduler_name, + latent=initial_latent, + denoise=denoising_strength, + tiled=tiled, + cfg_scale=cfg_scale, + refiner_swap_method=refiner_swap_method, + disable_preview=disable_preview + ) + + del task['c'], task['uc'], positive_cond, negative_cond # Save memory + + if inpaint_worker.current_task is not None: + imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] + + img_paths = [] + for x in imgs: + d = [('Prompt', 'prompt', task['log_positive_prompt']), + ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']), + ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']), + ('Styles', 'styles', str(raw_style_selections)), + ('Performance', 'performance', performance_selection.value)] + + if performance_selection.steps() != steps: + d.append(('Steps', 'steps', steps)) + + d += [('Resolution', 'resolution', str((width, height))), + ('Guidance Scale', 'guidance_scale', guidance_scale), + ('Sharpness', 'sharpness', sharpness), + ('ADM Guidance', 'adm_guidance', str(( + modules.patch.patch_settings[pid].positive_adm_scale, + modules.patch.patch_settings[pid].negative_adm_scale, + modules.patch.patch_settings[pid].adm_scaler_end))), + ('Base Model', 'base_model', base_model_name), + ('Refiner Model', 'refiner_model', refiner_model_name), + ('Refiner Switch', 'refiner_switch', refiner_switch)] + + if refiner_model_name != 'None': + if overwrite_switch > 0: + d.append(('Overwrite Switch', 'overwrite_switch', overwrite_switch)) + if refiner_swap_method != flags.refiner_swap_method: + d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method)) + if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr: + d.append(('CFG 
Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg)) + + d.append(('Sampler', 'sampler', sampler_name)) + d.append(('Scheduler', 'scheduler', scheduler_name)) + d.append(('Seed', 'seed', task['task_seed'])) + + if freeu_enabled: + d.append(('FreeU', 'freeu', str((freeu_b1, freeu_b2, freeu_s1, freeu_s2)))) + + for li, (n, w) in enumerate(loras): + if n != 'None': + d.append((f'LoRA {li + 1}', f'lora_combined_{li + 1}', f'{n} : {w}')) + + metadata_parser = None + if save_metadata_to_images: + metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme) + metadata_parser.set_data(task['log_positive_prompt'], task['positive'], + task['log_negative_prompt'], task['negative'], + steps, base_model_name, refiner_model_name, loras) + d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images)) + d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version)) + img_paths.append(log(x, d, metadata_parser, output_format)) + + yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results) + except ldm_patched.modules.model_management.InterruptProcessingException as e: + if async_task.last_stop == 'skip': + print('User skipped') + async_task.last_stop = False + continue + else: + print('User stopped') + break + + execution_time = time.perf_counter() - execution_start_time + print(f'Generating and saving time: {execution_time:.2f} seconds') + async_task.processing = False + return + + while True: + time.sleep(0.01) + if len(async_tasks) > 0: + task = async_tasks.pop(0) + generate_image_grid = task.args.pop(0) + + try: + handler(task) + if generate_image_grid: + build_image_wall(task) + task.yields.append(['finish', task.results]) + pipeline.prepare_text_encoder(async_call=True) + except: + traceback.print_exc() + task.yields.append(['finish', task.results]) + finally: + if pid in modules.patch.patch_settings: + del modules.patch.patch_settings[pid] + pass + + +threading.Thread(target=worker, daemon=True).start() diff --git a/modules/auth.py b/modules/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba111424523c19174f8b741b3bbac7b43b7bb6c --- /dev/null +++ b/modules/auth.py @@ -0,0 +1,41 @@ +import json +import hashlib +import modules.constants as constants + +from os.path import exists + + +def auth_list_to_dict(auth_list): + auth_dict = {} + for auth_data in auth_list: + if 'user' in auth_data: + if 'hash' in auth_data: + auth_dict |= {auth_data['user']: auth_data['hash']} + elif 'pass' in auth_data: + auth_dict |= {auth_data['user']: hashlib.sha256(bytes(auth_data['pass'], encoding='utf-8')).hexdigest()} + return auth_dict + + +def load_auth_data(filename=None): + auth_dict = None + if filename != None and exists(filename): + with open(filename, encoding='utf-8') as auth_file: + try: + auth_obj = json.load(auth_file) + if isinstance(auth_obj, list) and len(auth_obj) > 0: + auth_dict = auth_list_to_dict(auth_obj) + except Exception as e: + print('load_auth_data, e: ' + str(e)) + return auth_dict + + +auth_dict = load_auth_data(constants.AUTH_FILENAME) + +auth_enabled = auth_dict != None + + +def check_auth(user, password): + if user not in auth_dict: + return False + else: + return hashlib.sha256(bytes(password, encoding='utf-8')).hexdigest() == auth_dict[user] diff --git a/modules/config.py b/modules/config.py new file mode 100644 index 
0000000000000000000000000000000000000000..a68bd2187f7f766708482edabd8bdb0647e3cacb --- /dev/null +++ b/modules/config.py @@ -0,0 +1,607 @@ +import os +import json +import math +import numbers +import args_manager +import modules.flags +import modules.sdxl_styles + +from modules.model_loader import load_file_from_url +from modules.util import get_files_from_folder, makedirs_with_log +from modules.flags import Performance, MetadataScheme + +def get_config_path(key, default_value): + env = os.getenv(key) + if env is not None and isinstance(env, str): + print(f"Environment: {key} = {env}") + return env + else: + return os.path.abspath(default_value) + +config_path = get_config_path('config_path', "./config.txt") +config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt") +config_dict = {} +always_save_keys = [] +visited_keys = [] + +try: + with open(os.path.abspath(f'./presets/default.json'), "r", encoding="utf-8") as json_file: + config_dict.update(json.load(json_file)) +except Exception as e: + print(f'Load default preset failed.') + print(e) + +try: + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as json_file: + config_dict.update(json.load(json_file)) + always_save_keys = list(config_dict.keys()) +except Exception as e: + print(f'Failed to load config file "{config_path}" . The reason is: {str(e)}') + print('Please make sure that:') + print(f'1. The file "{config_path}" is a valid text file, and you have access to read it.') + print('2. Use "\\\\" instead of "\\" when describing paths.') + print('3. There is no "," before the last "}".') + print('4. All key/value formats are correct.') + + +def try_load_deprecated_user_path_config(): + global config_dict + + if not os.path.exists('user_path_config.txt'): + return + + try: + deprecated_config_dict = json.load(open('user_path_config.txt', "r", encoding="utf-8")) + + def replace_config(old_key, new_key): + if old_key in deprecated_config_dict: + config_dict[new_key] = deprecated_config_dict[old_key] + del deprecated_config_dict[old_key] + + replace_config('modelfile_path', 'path_checkpoints') + replace_config('lorafile_path', 'path_loras') + replace_config('embeddings_path', 'path_embeddings') + replace_config('vae_approx_path', 'path_vae_approx') + replace_config('upscale_models_path', 'path_upscale_models') + replace_config('inpaint_models_path', 'path_inpaint') + replace_config('controlnet_models_path', 'path_controlnet') + replace_config('clip_vision_models_path', 'path_clip_vision') + replace_config('fooocus_expansion_path', 'path_fooocus_expansion') + replace_config('temp_outputs_path', 'path_outputs') + + if deprecated_config_dict.get("default_model", None) == 'juggernautXL_version6Rundiffusion.safetensors': + os.replace('user_path_config.txt', 'user_path_config-deprecated.txt') + print('Config updated successfully in silence. ' + 'A backup of previous config is written to "user_path_config-deprecated.txt".') + return + + if input("Newer models and configs are available. " + "Download and update files? [Y/n]:") in ['n', 'N', 'No', 'no', 'NO']: + config_dict.update(deprecated_config_dict) + print('Loading using deprecated old models and deprecated old configs.') + return + else: + os.replace('user_path_config.txt', 'user_path_config-deprecated.txt') + print('Config updated successfully by user. 
' + 'A backup of previous config is written to "user_path_config-deprecated.txt".') + return + except Exception as e: + print('Processing deprecated config failed') + print(e) + return + + +try_load_deprecated_user_path_config() + +preset = args_manager.args.preset + +if isinstance(preset, str): + preset_path = os.path.abspath(f'./presets/{preset}.json') + try: + if os.path.exists(preset_path): + with open(preset_path, "r", encoding="utf-8") as json_file: + config_dict.update(json.load(json_file)) + print(f'Loaded preset: {preset_path}') + else: + raise FileNotFoundError + except Exception as e: + print(f'Load preset [{preset_path}] failed') + print(e) + + +def get_path_output() -> str: + """ + Checking output path argument and overriding default path. + """ + global config_dict + path_output = get_dir_or_set_default('path_outputs', '../outputs/', make_directory=True) + if args_manager.args.output_path: + print(f'[CONFIG] Overriding config value path_outputs with {args_manager.args.output_path}') + config_dict['path_outputs'] = path_output = args_manager.args.output_path + return path_output + + +def get_dir_or_set_default(key, default_value, as_array=False, make_directory=False): + global config_dict, visited_keys, always_save_keys + + if key not in visited_keys: + visited_keys.append(key) + + if key not in always_save_keys: + always_save_keys.append(key) + + v = os.getenv(key) + if v is not None: + print(f"Environment: {key} = {v}") + config_dict[key] = v + else: + v = config_dict.get(key, None) + + if isinstance(v, str): + if make_directory: + makedirs_with_log(v) + if os.path.exists(v) and os.path.isdir(v): + return v if not as_array else [v] + elif isinstance(v, list): + if make_directory: + for d in v: + makedirs_with_log(d) + if all([os.path.exists(d) and os.path.isdir(d) for d in v]): + return v + + if v is not None: + print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.') + if isinstance(default_value, list): + dp = [] + for path in default_value: + abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path)) + dp.append(abs_path) + os.makedirs(abs_path, exist_ok=True) + else: + dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value)) + os.makedirs(dp, exist_ok=True) + if as_array: + dp = [dp] + config_dict[key] = dp + return dp + + +paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True) +paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True) +path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/') +path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/') +path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/') +path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/') +path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/') +path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/') +path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion') +path_outputs = get_path_output() + +def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): + global config_dict, visited_keys + + if key not in visited_keys: + visited_keys.append(key) + + v = os.getenv(key) + if v is not None: + print(f"Environment: {key} = {v}") + config_dict[key] = v 
+ + if key not in config_dict: + config_dict[key] = default_value + return default_value + + v = config_dict.get(key, None) + if not disable_empty_as_none: + if v is None or v == '': + v = 'None' + if validator(v): + return v + else: + if v is not None: + print(f'Failed to load config key: {json.dumps({key:v})} is invalid; will use {json.dumps({key:default_value})} instead.') + config_dict[key] = default_value + return default_value + + +default_base_model_name = get_config_item_or_set_default( + key='default_model', + default_value='model.safetensors', + validator=lambda x: isinstance(x, str) +) +previous_default_models = get_config_item_or_set_default( + key='previous_default_models', + default_value=[], + validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x) +) +default_refiner_model_name = get_config_item_or_set_default( + key='default_refiner', + default_value='None', + validator=lambda x: isinstance(x, str) +) +default_refiner_switch = get_config_item_or_set_default( + key='default_refiner_switch', + default_value=0.8, + validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1 +) +default_loras_min_weight = get_config_item_or_set_default( + key='default_loras_min_weight', + default_value=-2, + validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 +) +default_loras_max_weight = get_config_item_or_set_default( + key='default_loras_max_weight', + default_value=2, + validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 +) +default_loras = get_config_item_or_set_default( + key='default_loras', + default_value=[ + [ + "None", + 1.0 + ], + [ + "None", + 1.0 + ], + [ + "None", + 1.0 + ], + [ + "None", + 1.0 + ], + [ + "None", + 1.0 + ] + ], + validator=lambda x: isinstance(x, list) and all(len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) for y in x) +) +default_max_lora_number = get_config_item_or_set_default( + key='default_max_lora_number', + default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5, + validator=lambda x: isinstance(x, int) and x >= 1 +) +default_cfg_scale = get_config_item_or_set_default( + key='default_cfg_scale', + default_value=7.0, + validator=lambda x: isinstance(x, numbers.Number) +) +default_sample_sharpness = get_config_item_or_set_default( + key='default_sample_sharpness', + default_value=2.0, + validator=lambda x: isinstance(x, numbers.Number) +) +default_sampler = get_config_item_or_set_default( + key='default_sampler', + default_value='dpmpp_2m_sde_gpu', + validator=lambda x: x in modules.flags.sampler_list +) +default_scheduler = get_config_item_or_set_default( + key='default_scheduler', + default_value='karras', + validator=lambda x: x in modules.flags.scheduler_list +) +default_styles = get_config_item_or_set_default( + key='default_styles', + default_value=[ + "Fooocus V2", + "Fooocus Enhance", + "Fooocus Sharp" + ], + validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x) +) +default_prompt_negative = get_config_item_or_set_default( + key='default_prompt_negative', + default_value='', + validator=lambda x: isinstance(x, str), + disable_empty_as_none=True +) +default_prompt = get_config_item_or_set_default( + key='default_prompt', + default_value='', + validator=lambda x: isinstance(x, str), + disable_empty_as_none=True +) +default_performance = get_config_item_or_set_default( + key='default_performance', + default_value=Performance.SPEED.value, + validator=lambda x: x in 
Performance.list() +) +default_advanced_checkbox = get_config_item_or_set_default( + key='default_advanced_checkbox', + default_value=False, + validator=lambda x: isinstance(x, bool) +) +default_max_image_number = get_config_item_or_set_default( + key='default_max_image_number', + default_value=32, + validator=lambda x: isinstance(x, int) and x >= 1 +) +default_output_format = get_config_item_or_set_default( + key='default_output_format', + default_value='png', + validator=lambda x: x in modules.flags.output_formats +) +default_image_number = get_config_item_or_set_default( + key='default_image_number', + default_value=2, + validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number +) +checkpoint_downloads = get_config_item_or_set_default( + key='checkpoint_downloads', + default_value={}, + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) +) +lora_downloads = get_config_item_or_set_default( + key='lora_downloads', + default_value={}, + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) +) +embeddings_downloads = get_config_item_or_set_default( + key='embeddings_downloads', + default_value={}, + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) +) +available_aspect_ratios = get_config_item_or_set_default( + key='available_aspect_ratios', + default_value=[ + '704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152', + '896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960', + '1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768', + '1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640', + '1664*576', '1728*576' + ], + validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1 +) +default_aspect_ratio = get_config_item_or_set_default( + key='default_aspect_ratio', + default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0], + validator=lambda x: x in available_aspect_ratios +) +default_inpaint_engine_version = get_config_item_or_set_default( + key='default_inpaint_engine_version', + default_value='v2.6', + validator=lambda x: x in modules.flags.inpaint_engine_versions +) +default_cfg_tsnr = get_config_item_or_set_default( + key='default_cfg_tsnr', + default_value=7.0, + validator=lambda x: isinstance(x, numbers.Number) +) +default_overwrite_step = get_config_item_or_set_default( + key='default_overwrite_step', + default_value=-1, + validator=lambda x: isinstance(x, int) +) +default_overwrite_switch = get_config_item_or_set_default( + key='default_overwrite_switch', + default_value=-1, + validator=lambda x: isinstance(x, int) +) +example_inpaint_prompts = get_config_item_or_set_default( + key='example_inpaint_prompts', + default_value=[ + 'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes' + ], + validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) +) +default_save_metadata_to_images = get_config_item_or_set_default( + key='default_save_metadata_to_images', + default_value=False, + validator=lambda x: isinstance(x, bool) +) +default_metadata_scheme = get_config_item_or_set_default( + key='default_metadata_scheme', + default_value=MetadataScheme.FOOOCUS.value, + validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x] +) +metadata_created_by = 
get_config_item_or_set_default( + key='metadata_created_by', + default_value='', + validator=lambda x: isinstance(x, str) +) + +example_inpaint_prompts = [[x] for x in example_inpaint_prompts] + +config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [['None', 1.0] for _ in range(default_max_lora_number - len(default_loras))] + +possible_preset_keys = [ + "default_model", + "default_refiner", + "default_refiner_switch", + "default_loras_min_weight", + "default_loras_max_weight", + "default_loras", + "default_max_lora_number", + "default_cfg_scale", + "default_sample_sharpness", + "default_sampler", + "default_scheduler", + "default_performance", + "default_prompt", + "default_prompt_negative", + "default_styles", + "default_aspect_ratio", + "default_save_metadata_to_images", + "checkpoint_downloads", + "embeddings_downloads", + "lora_downloads", +] + + +REWRITE_PRESET = False + +if REWRITE_PRESET and isinstance(args_manager.args.preset, str): + save_path = 'presets/' + args_manager.args.preset + '.json' + with open(save_path, "w", encoding="utf-8") as json_file: + json.dump({k: config_dict[k] for k in possible_preset_keys}, json_file, indent=4) + print(f'Preset saved to {save_path}. Exiting ...') + exit(0) + + +def add_ratio(x): + a, b = x.replace('*', ' ').split(' ')[:2] + a, b = int(a), int(b) + g = math.gcd(a, b) + return f'{a}×{b} \U00002223 {a // g}:{b // g}' + + +default_aspect_ratio = add_ratio(default_aspect_ratio) +available_aspect_ratios = [add_ratio(x) for x in available_aspect_ratios] + + +# Only write config in the first launch. +if not os.path.exists(config_path): + with open(config_path, "w", encoding="utf-8") as json_file: + json.dump({k: config_dict[k] for k in always_save_keys}, json_file, indent=4) + + +# Always write tutorials. +with open(config_example_path, "w", encoding="utf-8") as json_file: + cpa = config_path.replace("\\", "\\\\") + json_file.write(f'You can modify your "{cpa}" using the below keys, formats, and examples.\n' + f'Do not modify this file. Modifications in this file will not take effect.\n' + f'This file is a tutorial and example. Please edit "{cpa}" to really change any settings.\n' + + 'Remember to split the paths with "\\\\" rather than "\\", ' + 'and there is no "," before the last "}". 
\n\n\n') + json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4) + +model_filenames = [] +lora_filenames = [] +sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' + + +def get_model_filenames(folder_paths, name_filter=None): + extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] + files = [] + for folder in folder_paths: + files += get_files_from_folder(folder, extensions, name_filter) + return files + + +def update_all_model_names(): + global model_filenames, lora_filenames + model_filenames = get_model_filenames(paths_checkpoints) + lora_filenames = get_model_filenames(paths_loras) + return + + +def downloading_inpaint_models(v): + assert v in modules.flags.inpaint_engine_versions + + load_file_from_url( + url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth', + model_dir=path_inpaint, + file_name='fooocus_inpaint_head.pth' + ) + head_file = os.path.join(path_inpaint, 'fooocus_inpaint_head.pth') + patch_file = None + + if v == 'v1': + load_file_from_url( + url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch', + model_dir=path_inpaint, + file_name='inpaint.fooocus.patch' + ) + patch_file = os.path.join(path_inpaint, 'inpaint.fooocus.patch') + + if v == 'v2.5': + load_file_from_url( + url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch', + model_dir=path_inpaint, + file_name='inpaint_v25.fooocus.patch' + ) + patch_file = os.path.join(path_inpaint, 'inpaint_v25.fooocus.patch') + + if v == 'v2.6': + load_file_from_url( + url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v26.fooocus.patch', + model_dir=path_inpaint, + file_name='inpaint_v26.fooocus.patch' + ) + patch_file = os.path.join(path_inpaint, 'inpaint_v26.fooocus.patch') + + return head_file, patch_file + + +def downloading_sdxl_lcm_lora(): + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors', + model_dir=paths_loras[0], + file_name=sdxl_lcm_lora + ) + return sdxl_lcm_lora + + +def downloading_controlnet_canny(): + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors', + model_dir=path_controlnet, + file_name='control-lora-canny-rank128.safetensors' + ) + return os.path.join(path_controlnet, 'control-lora-canny-rank128.safetensors') + + +def downloading_controlnet_cpds(): + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors', + model_dir=path_controlnet, + file_name='fooocus_xl_cpds_128.safetensors' + ) + return os.path.join(path_controlnet, 'fooocus_xl_cpds_128.safetensors') + + +def downloading_ip_adapters(v): + assert v in ['ip', 'face'] + + results = [] + + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/clip_vision_vit_h.safetensors', + model_dir=path_clip_vision, + file_name='clip_vision_vit_h.safetensors' + ) + results += [os.path.join(path_clip_vision, 'clip_vision_vit_h.safetensors')] + + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_ip_negative.safetensors', + model_dir=path_controlnet, + file_name='fooocus_ip_negative.safetensors' + ) + results += [os.path.join(path_controlnet, 'fooocus_ip_negative.safetensors')] + + if v == 'ip': + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus_sdxl_vit-h.bin', + model_dir=path_controlnet, + 
file_name='ip-adapter-plus_sdxl_vit-h.bin' + ) + results += [os.path.join(path_controlnet, 'ip-adapter-plus_sdxl_vit-h.bin')] + + if v == 'face': + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus-face_sdxl_vit-h.bin', + model_dir=path_controlnet, + file_name='ip-adapter-plus-face_sdxl_vit-h.bin' + ) + results += [os.path.join(path_controlnet, 'ip-adapter-plus-face_sdxl_vit-h.bin')] + + return results + + +def downloading_upscale_model(): + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin', + model_dir=path_upscale_models, + file_name='fooocus_upscaler_s409985e5.bin' + ) + return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') + + +update_all_model_names() diff --git a/modules/constants.py b/modules/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..667fa8682306e192465f11733fc9814bacedfe89 --- /dev/null +++ b/modules/constants.py @@ -0,0 +1,5 @@ +# as in k-diffusion (sampling.py) +MIN_SEED = 0 +MAX_SEED = 2**63 - 1 + +AUTH_FILENAME = 'auth.json' diff --git a/modules/core.py b/modules/core.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc449661d9c636e096b8e9555daa3bebb5f50e7 --- /dev/null +++ b/modules/core.py @@ -0,0 +1,339 @@ +import os +import einops +import torch +import numpy as np + +import ldm_patched.modules.model_management +import ldm_patched.modules.model_detection +import ldm_patched.modules.model_patcher +import ldm_patched.modules.utils +import ldm_patched.modules.controlnet +import modules.sample_hijack +import ldm_patched.modules.samplers +import ldm_patched.modules.latent_formats + +from ldm_patched.modules.sd import load_checkpoint_guess_config +from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, \ + ControlNetApplyAdvanced +from ldm_patched.contrib.external_freelunch import FreeU_V2 +from ldm_patched.modules.sample import prepare_mask +from modules.lora import match_lora +from modules.util import get_file_from_folder_list +from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip +from modules.config import path_embeddings +from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete + + +opEmptyLatentImage = EmptyLatentImage() +opVAEDecode = VAEDecode() +opVAEEncode = VAEEncode() +opVAEDecodeTiled = VAEDecodeTiled() +opVAEEncodeTiled = VAEEncodeTiled() +opControlNetApplyAdvanced = ControlNetApplyAdvanced() +opFreeU = FreeU_V2() +opModelSamplingDiscrete = ModelSamplingDiscrete() + + +class StableDiffusionModel: + def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None): + self.unet = unet + self.vae = vae + self.clip = clip + self.clip_vision = clip_vision + self.filename = filename + self.unet_with_lora = unet + self.clip_with_lora = clip + self.visited_loras = '' + + self.lora_key_map_unet = {} + self.lora_key_map_clip = {} + + if self.unet is not None: + self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet) + self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()}) + + if self.clip is not None: + self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip) + self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()}) + + @torch.no_grad() + @torch.inference_mode() + def refresh_loras(self, loras): + assert isinstance(loras, list) + + if 
self.visited_loras == str(loras): + return + + self.visited_loras = str(loras) + + if self.unet is None: + return + + print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].') + + loras_to_load = [] + + for name, weight in loras: + if name == 'None': + continue + + if os.path.exists(name): + lora_filename = name + else: + lora_filename = get_file_from_folder_list(name, modules.config.paths_loras) + + if not os.path.exists(lora_filename): + print(f'Lora file not found: {lora_filename}') + continue + + loras_to_load.append((lora_filename, weight)) + + self.unet_with_lora = self.unet.clone() if self.unet is not None else None + self.clip_with_lora = self.clip.clone() if self.clip is not None else None + + for lora_filename, weight in loras_to_load: + lora_unmatch = ldm_patched.modules.utils.load_torch_file(lora_filename, safe_load=False) + lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet) + lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip) + + if len(lora_unmatch) > 12: + # model mismatch + continue + + if len(lora_unmatch) > 0: + print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] ' + f'with unmatched keys {list(lora_unmatch.keys())}') + + if self.unet_with_lora is not None and len(lora_unet) > 0: + loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight) + print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] ' + f'with {len(loaded_keys)} keys at weight {weight}.') + for item in lora_unet: + if item not in loaded_keys: + print("UNet LoRA key skipped: ", item) + + if self.clip_with_lora is not None and len(lora_clip) > 0: + loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight) + print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] ' + f'with {len(loaded_keys)} keys at weight {weight}.') + for item in lora_clip: + if item not in loaded_keys: + print("CLIP LoRA key skipped: ", item) + + +@torch.no_grad() +@torch.inference_mode() +def apply_freeu(model, b1, b2, s1, s2): + return opFreeU.patch(model=model, b1=b1, b2=b2, s1=s1, s2=s2)[0] + + +@torch.no_grad() +@torch.inference_mode() +def load_controlnet(ckpt_filename): + return ldm_patched.modules.controlnet.load_controlnet(ckpt_filename) + + +@torch.no_grad() +@torch.inference_mode() +def apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent): + return opControlNetApplyAdvanced.apply_controlnet(positive=positive, negative=negative, control_net=control_net, + image=image, strength=strength, start_percent=start_percent, end_percent=end_percent) + + +@torch.no_grad() +@torch.inference_mode() +def load_model(ckpt_filename): + unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings) + return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) + + +@torch.no_grad() +@torch.inference_mode() +def generate_empty_latent(width=1024, height=1024, batch_size=1): + return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0] + + +@torch.no_grad() +@torch.inference_mode() +def decode_vae(vae, latent_image, tiled=False): + if tiled: + return opVAEDecodeTiled.decode(samples=latent_image, vae=vae, tile_size=512)[0] + else: + return opVAEDecode.decode(samples=latent_image, vae=vae)[0] + + +@torch.no_grad() +@torch.inference_mode() +def encode_vae(vae, pixels, tiled=False): + if tiled: + return opVAEEncodeTiled.encode(pixels=pixels, vae=vae, tile_size=512)[0] + else: + return 
opVAEEncode.encode(pixels=pixels, vae=vae)[0] + + +@torch.no_grad() +@torch.inference_mode() +def encode_vae_inpaint(vae, pixels, mask): + assert mask.ndim == 3 and pixels.ndim == 4 + assert mask.shape[-1] == pixels.shape[-2] + assert mask.shape[-2] == pixels.shape[-3] + + w = mask.round()[..., None] + pixels = pixels * (1 - w) + 0.5 * w + + latent = vae.encode(pixels) + B, C, H, W = latent.shape + + latent_mask = mask[:, None, :, :] + latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round() + latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent) + + return latent, latent_mask + + +class VAEApprox(torch.nn.Module): + def __init__(self): + super(VAEApprox, self).__init__() + self.conv1 = torch.nn.Conv2d(4, 8, (7, 7)) + self.conv2 = torch.nn.Conv2d(8, 16, (5, 5)) + self.conv3 = torch.nn.Conv2d(16, 32, (3, 3)) + self.conv4 = torch.nn.Conv2d(32, 64, (3, 3)) + self.conv5 = torch.nn.Conv2d(64, 32, (3, 3)) + self.conv6 = torch.nn.Conv2d(32, 16, (3, 3)) + self.conv7 = torch.nn.Conv2d(16, 8, (3, 3)) + self.conv8 = torch.nn.Conv2d(8, 3, (3, 3)) + self.current_type = None + + def forward(self, x): + extra = 11 + x = torch.nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) + x = torch.nn.functional.pad(x, (extra, extra, extra, extra)) + for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]: + x = layer(x) + x = torch.nn.functional.leaky_relu(x, 0.1) + return x + + +VAE_approx_models = {} + + +@torch.no_grad() +@torch.inference_mode() +def get_previewer(model): + global VAE_approx_models + + from modules.config import path_vae_approx + is_sdxl = isinstance(model.model.latent_format, ldm_patched.modules.latent_formats.SDXL) + vae_approx_filename = os.path.join(path_vae_approx, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth') + + if vae_approx_filename in VAE_approx_models: + VAE_approx_model = VAE_approx_models[vae_approx_filename] + else: + sd = torch.load(vae_approx_filename, map_location='cpu') + VAE_approx_model = VAEApprox() + VAE_approx_model.load_state_dict(sd) + del sd + VAE_approx_model.eval() + + if ldm_patched.modules.model_management.should_use_fp16(): + VAE_approx_model.half() + VAE_approx_model.current_type = torch.float16 + else: + VAE_approx_model.float() + VAE_approx_model.current_type = torch.float32 + + VAE_approx_model.to(ldm_patched.modules.model_management.get_torch_device()) + VAE_approx_models[vae_approx_filename] = VAE_approx_model + + @torch.no_grad() + @torch.inference_mode() + def preview_function(x0, step, total_steps): + with torch.no_grad(): + x_sample = x0.to(VAE_approx_model.current_type) + x_sample = VAE_approx_model(x_sample) * 127.5 + 127.5 + x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')[0] + x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8) + return x_sample + + return preview_function + + +@torch.no_grad() +@torch.inference_mode() +def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu', + scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, + force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1, + previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None, disable_preview=False): + + if sigmas is not None: + sigmas = sigmas.clone().to(ldm_patched.modules.model_management.get_torch_device()) + + latent_image = latent["samples"] + + if disable_noise: + noise = 
torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = ldm_patched.modules.sample.prepare_noise(latent_image, seed, batch_inds) + + if isinstance(noise_mean, torch.Tensor): + noise = noise + noise_mean - torch.mean(noise, dim=1, keepdim=True) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + previewer = get_previewer(model) + + if previewer_start is None: + previewer_start = 0 + + if previewer_end is None: + previewer_end = steps + + def callback(step, x0, x, total_steps): + ldm_patched.modules.model_management.throw_exception_if_processing_interrupted() + y = None + if previewer is not None and not disable_preview: + y = previewer(x0, previewer_start + step, previewer_end) + if callback_function is not None: + callback_function(previewer_start + step, x0, x, previewer_end, y) + + disable_pbar = False + modules.sample_hijack.current_refiner = refiner + modules.sample_hijack.refiner_switch_step = refiner_switch + ldm_patched.modules.samplers.sample = modules.sample_hijack.sample_hacked + + try: + samples = ldm_patched.modules.sample.sample(model, + noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, + start_step=start_step, + last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask, + callback=callback, + disable_pbar=disable_pbar, seed=seed, sigmas=sigmas) + + out = latent.copy() + out["samples"] = samples + finally: + modules.sample_hijack.current_refiner = None + + return out + + +@torch.no_grad() +@torch.inference_mode() +def pytorch_to_numpy(x): + return [np.clip(255. 
* y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x] + + +@torch.no_grad() +@torch.inference_mode() +def numpy_to_pytorch(x): + y = x.astype(np.float32) / 255.0 + y = y[None] + y = np.ascontiguousarray(y.copy()) + y = torch.from_numpy(y).float() + return y diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f8edfae105fa45a103f9e1463e6abbee2c19444c --- /dev/null +++ b/modules/default_pipeline.py @@ -0,0 +1,498 @@ +import modules.core as core +import os +import torch +import modules.patch +import modules.config +import ldm_patched.modules.model_management +import ldm_patched.modules.latent_formats +import modules.inpaint_worker +import extras.vae_interpose as vae_interpose +from extras.expansion import FooocusExpansion + +from ldm_patched.modules.model_base import SDXL, SDXLRefiner +from modules.sample_hijack import clip_separate +from modules.util import get_file_from_folder_list + + +model_base = core.StableDiffusionModel() +model_refiner = core.StableDiffusionModel() + +final_expansion = None +final_unet = None +final_clip = None +final_vae = None +final_refiner_unet = None +final_refiner_vae = None + +loaded_ControlNets = {} + + +@torch.no_grad() +@torch.inference_mode() +def refresh_controlnets(model_paths): + global loaded_ControlNets + cache = {} + for p in model_paths: + if p is not None: + if p in loaded_ControlNets: + cache[p] = loaded_ControlNets[p] + else: + cache[p] = core.load_controlnet(p) + loaded_ControlNets = cache + return + + +@torch.no_grad() +@torch.inference_mode() +def assert_model_integrity(): + error_message = None + + if not isinstance(model_base.unet_with_lora.model, SDXL): + error_message = 'You have selected base model other than SDXL. This is not supported yet.' 
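A brief note on the pytorch_to_numpy and numpy_to_pytorch helpers that close modules/core.py above: they convert between single HWC uint8 images and batched float tensors with values in [0, 1]. A standalone restatement of the same convention (dummy data only; it does not import the module itself):

import numpy as np
import torch

def numpy_to_pytorch(x):
    # HWC uint8 in [0, 255] -> (1, H, W, C) float32 in [0, 1]
    y = x.astype(np.float32) / 255.0
    return torch.from_numpy(np.ascontiguousarray(y[None].copy())).float()

def pytorch_to_numpy(x):
    # batched float tensor in [0, 1] -> list of HWC uint8 images
    return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]

dummy = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
batch = numpy_to_pytorch(dummy)    # torch.Size([1, 64, 64, 3])
images = pytorch_to_numpy(batch)   # list with one (64, 64, 3) uint8 array
assert batch.shape == (1, 64, 64, 3) and images[0].dtype == np.uint8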
+ + if error_message is not None: + raise NotImplementedError(error_message) + + return True + + +@torch.no_grad() +@torch.inference_mode() +def refresh_base_model(name): + global model_base + + filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) + + if model_base.filename == filename: + return + + model_base = core.StableDiffusionModel() + model_base = core.load_model(filename) + print(f'Base model loaded: {model_base.filename}') + return + + +@torch.no_grad() +@torch.inference_mode() +def refresh_refiner_model(name): + global model_refiner + + filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) + + if model_refiner.filename == filename: + return + + model_refiner = core.StableDiffusionModel() + + if name == 'None': + print(f'Refiner unloaded.') + return + + model_refiner = core.load_model(filename) + print(f'Refiner model loaded: {model_refiner.filename}') + + if isinstance(model_refiner.unet.model, SDXL): + model_refiner.clip = None + model_refiner.vae = None + elif isinstance(model_refiner.unet.model, SDXLRefiner): + model_refiner.clip = None + model_refiner.vae = None + else: + model_refiner.clip = None + + return + + +@torch.no_grad() +@torch.inference_mode() +def synthesize_refiner_model(): + global model_base, model_refiner + + print('Synthetic Refiner Activated') + model_refiner = core.StableDiffusionModel( + unet=model_base.unet, + vae=model_base.vae, + clip=model_base.clip, + clip_vision=model_base.clip_vision, + filename=model_base.filename + ) + model_refiner.vae = None + model_refiner.clip = None + model_refiner.clip_vision = None + + return + + +@torch.no_grad() +@torch.inference_mode() +def refresh_loras(loras, base_model_additional_loras=None): + global model_base, model_refiner + + if not isinstance(base_model_additional_loras, list): + base_model_additional_loras = [] + + model_base.refresh_loras(loras + base_model_additional_loras) + model_refiner.refresh_loras(loras) + + return + + +@torch.no_grad() +@torch.inference_mode() +def clip_encode_single(clip, text, verbose=False): + cached = clip.fcs_cond_cache.get(text, None) + if cached is not None: + if verbose: + print(f'[CLIP Cached] {text}') + return cached + tokens = clip.tokenize(text) + result = clip.encode_from_tokens(tokens, return_pooled=True) + clip.fcs_cond_cache[text] = result + if verbose: + print(f'[CLIP Encoded] {text}') + return result + + +@torch.no_grad() +@torch.inference_mode() +def clone_cond(conds): + results = [] + + for c, p in conds: + p = p["pooled_output"] + + if isinstance(c, torch.Tensor): + c = c.clone() + + if isinstance(p, torch.Tensor): + p = p.clone() + + results.append([c, {"pooled_output": p}]) + + return results + + +@torch.no_grad() +@torch.inference_mode() +def clip_encode(texts, pool_top_k=1): + global final_clip + + if final_clip is None: + return None + if not isinstance(texts, list): + return None + if len(texts) == 0: + return None + + cond_list = [] + pooled_acc = 0 + + for i, text in enumerate(texts): + cond, pooled = clip_encode_single(final_clip, text) + cond_list.append(cond) + if i < pool_top_k: + pooled_acc += pooled + + return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]] + + +@torch.no_grad() +@torch.inference_mode() +def clear_all_caches(): + final_clip.fcs_cond_cache = {} + + +@torch.no_grad() +@torch.inference_mode() +def prepare_text_encoder(async_call=True): + if async_call: + # TODO: make sure that this is always called in an async way so that users cannot feel it. 
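For orientation, the refresh_everything function defined a few lines below is the single entry point that (re)wires the module-level final_unet / final_clip / final_vae / final_refiner_* handles whenever the UI switches models or LoRAs. A typical call might look like the sketch below; the checkpoint filename is a placeholder, and the call assumes this repository layout is importable and that the file exists under one of modules.config.paths_checkpoints:

import modules.default_pipeline as pipeline

pipeline.refresh_everything(
    refiner_model_name='None',                          # no separate refiner
    base_model_name='my_sdxl_checkpoint.safetensors',   # placeholder SDXL checkpoint name
    loras=[['None', 1.0] for _ in range(5)],            # five empty LoRA slots, as in the default config
)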
+ pass + assert_model_integrity() + ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher]) + return + + +@torch.no_grad() +@torch.inference_mode() +def refresh_everything(refiner_model_name, base_model_name, loras, + base_model_additional_loras=None, use_synthetic_refiner=False): + global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion + + final_unet = None + final_clip = None + final_vae = None + final_refiner_unet = None + final_refiner_vae = None + + if use_synthetic_refiner and refiner_model_name == 'None': + print('Synthetic Refiner Activated') + refresh_base_model(base_model_name) + synthesize_refiner_model() + else: + refresh_refiner_model(refiner_model_name) + refresh_base_model(base_model_name) + + refresh_loras(loras, base_model_additional_loras=base_model_additional_loras) + assert_model_integrity() + + final_unet = model_base.unet_with_lora + final_clip = model_base.clip_with_lora + final_vae = model_base.vae + + final_refiner_unet = model_refiner.unet_with_lora + final_refiner_vae = model_refiner.vae + + if final_expansion is None: + final_expansion = FooocusExpansion() + + prepare_text_encoder(async_call=True) + clear_all_caches() + return + + +refresh_everything( + refiner_model_name=modules.config.default_refiner_model_name, + base_model_name=modules.config.default_base_model_name, + loras=modules.config.default_loras +) + + +@torch.no_grad() +@torch.inference_mode() +def vae_parse(latent): + if final_refiner_vae is None: + return latent + + result = vae_interpose.parse(latent["samples"]) + return {'samples': result} + + +@torch.no_grad() +@torch.inference_mode() +def calculate_sigmas_all(sampler, model, scheduler, steps): + from ldm_patched.modules.samplers import calculate_sigmas_scheduler + + discard_penultimate_sigma = False + if sampler in ['dpm_2', 'dpm_2_ancestral']: + steps += 1 + discard_penultimate_sigma = True + + sigmas = calculate_sigmas_scheduler(model, scheduler, steps) + + if discard_penultimate_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + return sigmas + + +@torch.no_grad() +@torch.inference_mode() +def calculate_sigmas(sampler, model, scheduler, steps, denoise): + if denoise is None or denoise > 0.9999: + sigmas = calculate_sigmas_all(sampler, model, scheduler, steps) + else: + new_steps = int(steps / denoise) + sigmas = calculate_sigmas_all(sampler, model, scheduler, new_steps) + sigmas = sigmas[-(steps + 1):] + return sigmas + + +@torch.no_grad() +@torch.inference_mode() +def get_candidate_vae(steps, switch, denoise=1.0, refiner_swap_method='joint'): + assert refiner_swap_method in ['joint', 'separate', 'vae'] + + if final_refiner_vae is not None and final_refiner_unet is not None: + if denoise > 0.9: + return final_vae, final_refiner_vae + else: + if denoise > (float(steps - switch) / float(steps)) ** 0.834: # karras 0.834 + return final_vae, None + else: + return final_refiner_vae, None + + return final_vae, final_refiner_vae + + +@torch.no_grad() +@torch.inference_mode() +def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0, refiner_swap_method='joint', disable_preview=False): + target_unet, target_vae, target_refiner_unet, target_refiner_vae, target_clip \ + = final_unet, final_vae, final_refiner_unet, final_refiner_vae, final_clip + + assert refiner_swap_method in ['joint', 'separate', 'vae'] + + if final_refiner_vae is 
not None and final_refiner_unet is not None: + # Refiner Use Different VAE (then it is SD15) + if denoise > 0.9: + refiner_swap_method = 'vae' + else: + refiner_swap_method = 'joint' + if denoise > (float(steps - switch) / float(steps)) ** 0.834: # karras 0.834 + target_unet, target_vae, target_refiner_unet, target_refiner_vae \ + = final_unet, final_vae, None, None + print(f'[Sampler] only use Base because of partial denoise.') + else: + positive_cond = clip_separate(positive_cond, target_model=final_refiner_unet.model, target_clip=final_clip) + negative_cond = clip_separate(negative_cond, target_model=final_refiner_unet.model, target_clip=final_clip) + target_unet, target_vae, target_refiner_unet, target_refiner_vae \ + = final_refiner_unet, final_refiner_vae, None, None + print(f'[Sampler] only use Refiner because of partial denoise.') + + print(f'[Sampler] refiner_swap_method = {refiner_swap_method}') + + if latent is None: + initial_latent = core.generate_empty_latent(width=width, height=height, batch_size=1) + else: + initial_latent = latent + + minmax_sigmas = calculate_sigmas(sampler=sampler_name, scheduler=scheduler_name, model=final_unet.model, steps=steps, denoise=denoise) + sigma_min, sigma_max = minmax_sigmas[minmax_sigmas > 0].min(), minmax_sigmas.max() + sigma_min = float(sigma_min.cpu().numpy()) + sigma_max = float(sigma_max.cpu().numpy()) + print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}') + + modules.patch.BrownianTreeNoiseSamplerPatched.global_init( + initial_latent['samples'].to(ldm_patched.modules.model_management.get_torch_device()), + sigma_min, sigma_max, seed=image_seed, cpu=False) + + decoded_latent = None + + if refiner_swap_method == 'joint': + sampled_latent = core.ksampler( + model=target_unet, + refiner=target_refiner_unet, + positive=positive_cond, + negative=negative_cond, + latent=initial_latent, + steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, + seed=image_seed, + denoise=denoise, + callback_function=callback, + cfg=cfg_scale, + sampler_name=sampler_name, + scheduler=scheduler_name, + refiner_switch=switch, + previewer_start=0, + previewer_end=steps, + disable_preview=disable_preview + ) + decoded_latent = core.decode_vae(vae=target_vae, latent_image=sampled_latent, tiled=tiled) + + if refiner_swap_method == 'separate': + sampled_latent = core.ksampler( + model=target_unet, + positive=positive_cond, + negative=negative_cond, + latent=initial_latent, + steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=False, + seed=image_seed, + denoise=denoise, + callback_function=callback, + cfg=cfg_scale, + sampler_name=sampler_name, + scheduler=scheduler_name, + previewer_start=0, + previewer_end=steps, + disable_preview=disable_preview + ) + print('Refiner swapped by changing ksampler. 
Noise preserved.') + + target_model = target_refiner_unet + if target_model is None: + target_model = target_unet + print('Use base model to refine itself - this may be because of developer mode.') + + sampled_latent = core.ksampler( + model=target_model, + positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=target_clip), + negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=target_clip), + latent=sampled_latent, + steps=steps, start_step=switch, last_step=steps, disable_noise=True, force_full_denoise=True, + seed=image_seed, + denoise=denoise, + callback_function=callback, + cfg=cfg_scale, + sampler_name=sampler_name, + scheduler=scheduler_name, + previewer_start=switch, + previewer_end=steps, + disable_preview=disable_preview + ) + + target_model = target_refiner_vae + if target_model is None: + target_model = target_vae + decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled) + + if refiner_swap_method == 'vae': + modules.patch.patch_settings[os.getpid()].eps_record = 'vae' + + if modules.inpaint_worker.current_task is not None: + modules.inpaint_worker.current_task.unswap() + + sampled_latent = core.ksampler( + model=target_unet, + positive=positive_cond, + negative=negative_cond, + latent=initial_latent, + steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=True, + seed=image_seed, + denoise=denoise, + callback_function=callback, + cfg=cfg_scale, + sampler_name=sampler_name, + scheduler=scheduler_name, + previewer_start=0, + previewer_end=steps, + disable_preview=disable_preview + ) + print('Fooocus VAE-based swap.') + + target_model = target_refiner_unet + if target_model is None: + target_model = target_unet + print('Use base model to refine itself - this may be because of developer mode.') + + sampled_latent = vae_parse(sampled_latent) + + k_sigmas = 1.4 + sigmas = calculate_sigmas(sampler=sampler_name, + scheduler=scheduler_name, + model=target_model.model, + steps=steps, + denoise=denoise)[switch:] * k_sigmas + len_sigmas = len(sigmas) - 1 + + noise_mean = torch.mean(modules.patch.patch_settings[os.getpid()].eps_record, dim=1, keepdim=True) + + if modules.inpaint_worker.current_task is not None: + modules.inpaint_worker.current_task.swap() + + sampled_latent = core.ksampler( + model=target_model, + positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=target_clip), + negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=target_clip), + latent=sampled_latent, + steps=len_sigmas, start_step=0, last_step=len_sigmas, disable_noise=False, force_full_denoise=True, + seed=image_seed+1, + denoise=denoise, + callback_function=callback, + cfg=cfg_scale, + sampler_name=sampler_name, + scheduler=scheduler_name, + previewer_start=switch, + previewer_end=steps, + sigmas=sigmas, + noise_mean=noise_mean, + disable_preview=disable_preview + ) + + target_model = target_refiner_vae + if target_model is None: + target_model = target_vae + decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled) + + images = core.pytorch_to_numpy(decoded_latent) + modules.patch.patch_settings[os.getpid()].eps_record = None + return images diff --git a/modules/flags.py b/modules/flags.py new file mode 100644 index 0000000000000000000000000000000000000000..6f12bc8f3f27c4b9ae06f2ee7ac0a90e46122b16 --- /dev/null +++ b/modules/flags.py @@ -0,0 +1,125 @@ +from enum import IntEnum, Enum + +disabled = 
'Disabled' +enabled = 'Enabled' +subtle_variation = 'Vary (Subtle)' +strong_variation = 'Vary (Strong)' +upscale_15 = 'Upscale (1.5x)' +upscale_2 = 'Upscale (2x)' +upscale_fast = 'Upscale (Fast 2x)' + +uov_list = [ + disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast +] + +CIVITAI_NO_KARRAS = ["euler", "euler_ancestral", "heun", "dpm_fast", "dpm_adaptive", "ddim", "uni_pc"] + +# fooocus: a1111 (Civitai) +KSAMPLER = { + "euler": "Euler", + "euler_ancestral": "Euler a", + "heun": "Heun", + "heunpp2": "", + "dpm_2": "DPM2", + "dpm_2_ancestral": "DPM2 a", + "lms": "LMS", + "dpm_fast": "DPM fast", + "dpm_adaptive": "DPM adaptive", + "dpmpp_2s_ancestral": "DPM++ 2S a", + "dpmpp_sde": "DPM++ SDE", + "dpmpp_sde_gpu": "DPM++ SDE", + "dpmpp_2m": "DPM++ 2M", + "dpmpp_2m_sde": "DPM++ 2M SDE", + "dpmpp_2m_sde_gpu": "DPM++ 2M SDE", + "dpmpp_3m_sde": "", + "dpmpp_3m_sde_gpu": "", + "ddpm": "", + "lcm": "LCM" +} + +SAMPLER_EXTRA = { + "ddim": "DDIM", + "uni_pc": "UniPC", + "uni_pc_bh2": "" +} + +SAMPLERS = KSAMPLER | SAMPLER_EXTRA + +KSAMPLER_NAMES = list(KSAMPLER.keys()) + +SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"] +SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys()) + +sampler_list = SAMPLER_NAMES +scheduler_list = SCHEDULER_NAMES + +refiner_swap_method = 'joint' + +cn_ip = "ImagePrompt" +cn_ip_face = "FaceSwap" +cn_canny = "PyraCanny" +cn_cpds = "CPDS" + +ip_list = [cn_ip, cn_canny, cn_cpds, cn_ip_face] +default_ip = cn_ip + +default_parameters = { + cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0) +} # stop, weight + +output_formats = ['png', 'jpg', 'webp'] + +inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6'] +inpaint_option_default = 'Inpaint or Outpaint (default)' +inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)' +inpaint_option_modify = 'Modify Content (add objects, change background, etc.)' +inpaint_options = [inpaint_option_default, inpaint_option_detail, inpaint_option_modify] + +desc_type_photo = 'Photograph' +desc_type_anime = 'Art/Anime' + + +class MetadataScheme(Enum): + FOOOCUS = 'fooocus' + A1111 = 'a1111' + + +metadata_scheme = [ + (f'{MetadataScheme.FOOOCUS.value} (json)', MetadataScheme.FOOOCUS.value), + (f'{MetadataScheme.A1111.value} (plain text)', MetadataScheme.A1111.value), +] + +lora_count = 5 + +controlnet_image_count = 4 + + +class Steps(IntEnum): + QUALITY = 60 + SPEED = 30 + EXTREME_SPEED = 8 + + +class StepsUOV(IntEnum): + QUALITY = 36 + SPEED = 18 + EXTREME_SPEED = 8 + + +class Performance(Enum): + QUALITY = 'Quality' + SPEED = 'Speed' + EXTREME_SPEED = 'Extreme Speed' + + @classmethod + def list(cls) -> list: + return list(map(lambda c: c.value, cls)) + + def steps(self) -> int | None: + return Steps[self.name].value if Steps[self.name] else None + + def steps_uov(self) -> int | None: + return StepsUOV[self.name].value if Steps[self.name] else None + + +performance_selections = Performance.list() diff --git a/modules/gradio_hijack.py b/modules/gradio_hijack.py new file mode 100644 index 0000000000000000000000000000000000000000..181429ec39a0336ffa43ebf23e4fa2b87dd97674 --- /dev/null +++ b/modules/gradio_hijack.py @@ -0,0 +1,480 @@ +"""gr.Image() component.""" + +from __future__ import annotations + +import warnings +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import PIL +import PIL.ImageOps +import gradio.routes +import importlib + +from gradio_client import utils as 
client_utils +from gradio_client.documentation import document, set_documentation_group +from gradio_client.serializing import ImgSerializable +from PIL import Image as _Image # using _ to minimize namespace pollution + +from gradio import processing_utils, utils +from gradio.components.base import IOComponent, _Keywords, Block +from gradio.deprecation import warn_style_method_deprecation +from gradio.events import ( + Changeable, + Clearable, + Editable, + EventListenerMethod, + Selectable, + Streamable, + Uploadable, +) +from gradio.interpretation import TokenInterpretable + +set_documentation_group("component") +_Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 + + +@document() +class Image( + Editable, + Clearable, + Changeable, + Streamable, + Selectable, + Uploadable, + IOComponent, + ImgSerializable, + TokenInterpretable, +): + """ + Creates an image component that can be used to upload/draw images (as an input) or display images (as an output). + Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch` AND source is one of `upload` or `webcam`. In these cases, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`. + Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image. + Examples-format: a {str} filepath to a local file that contains the image. + Demos: image_mod, image_mod_default_image + Guides: image-classification-in-pytorch, image-classification-in-tensorflow, image-classification-with-vision-transformers, building-a-pictionary_app, create-your-own-friends-with-a-gan + """ + + def __init__( + self, + value: str | _Image.Image | np.ndarray | None = None, + *, + shape: tuple[int, int] | None = None, + height: int | None = None, + width: int | None = None, + image_mode: Literal[ + "1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F" + ] = "RGB", + invert_colors: bool = False, + source: Literal["upload", "webcam", "canvas"] = "upload", + tool: Literal["editor", "select", "sketch", "color-sketch"] | None = None, + type: Literal["numpy", "pil", "filepath"] = "numpy", + label: str | None = None, + every: float | None = None, + show_label: bool | None = None, + show_download_button: bool = True, + container: bool = True, + scale: int | None = None, + min_width: int = 160, + interactive: bool | None = None, + visible: bool = True, + streaming: bool = False, + elem_id: str | None = None, + elem_classes: list[str] | str | None = None, + mirror_webcam: bool = True, + brush_radius: float | None = None, + brush_color: str = "#000000", + mask_opacity: float = 0.7, + show_share_button: bool | None = None, + **kwargs, + ): + """ + Parameters: + value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component. + shape: (width, height) shape to crop and resize image when passed to function. If None, matches input image size. Pass None for either width or height to only crop and resize the other. + height: Height of the displayed image in pixels. + width: Width of the displayed image in pixels. + image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. 
+ invert_colors: whether to invert the image as a preprocessing step. + source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools. + tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to created a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively. + type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. + label: component name in interface. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + show_label: if True, will display label. + show_download_button: If True, will display button to download image. + container: If True, will place the component in a container - providing some extra padding around the border. + scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer. + min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. + interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images. If not provided, this is inferred based on whether the component is used as an input or output. + visible: If False, component will be hidden. + streaming: If True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'webcam'. + elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. + elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. + mirror_webcam: If True webcam will be mirrored. Default is True. + brush_radius: Size of the brush for Sketch. Default is None which chooses a sensible default + brush_color: Color of the brush for Sketch as hex string. Default is "#000000". + mask_opacity: Opacity of mask drawn on image, as a value between 0 and 1. + show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise. 
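Given the parameters documented above, a typical construction of this hijacked component inside a Gradio Blocks layout might look like the following sketch (Gradio 3.x API assumed; the variable names and values are illustrative):

import gradio as gr
from modules.gradio_hijack import Image

with gr.Blocks() as demo:
    # With tool='sketch' on an uploaded image, the handler receives a dict
    # with 'image' and 'mask' keys instead of a bare array.
    inpaint_input = Image(
        source='upload',
        tool='sketch',
        type='numpy',
        label='Image to inpaint',
        brush_color='#FFFFFF',
        mask_opacity=0.67,
    )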
+ """ + self.brush_radius = brush_radius + self.brush_color = brush_color + self.mask_opacity = mask_opacity + self.mirror_webcam = mirror_webcam + valid_types = ["numpy", "pil", "filepath"] + if type not in valid_types: + raise ValueError( + f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}" + ) + self.type = type + self.shape = shape + self.height = height + self.width = width + self.image_mode = image_mode + valid_sources = ["upload", "webcam", "canvas"] + if source not in valid_sources: + raise ValueError( + f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}" + ) + self.source = source + if tool is None: + self.tool = "sketch" if source == "canvas" else "editor" + else: + self.tool = tool + self.invert_colors = invert_colors + self.streaming = streaming + self.show_download_button = show_download_button + if streaming and source != "webcam": + raise ValueError("Image streaming only available if source is 'webcam'.") + self.select: EventListenerMethod + """ + Event listener for when the user clicks on a pixel within the image. + Uses event data gradio.SelectData to carry `index` to refer to the [x, y] coordinates of the clicked pixel. + See EventData documentation on how to use this event data. + """ + self.show_share_button = ( + (utils.get_space() is not None) + if show_share_button is None + else show_share_button + ) + IOComponent.__init__( + self, + label=label, + every=every, + show_label=show_label, + container=container, + scale=scale, + min_width=min_width, + interactive=interactive, + visible=visible, + elem_id=elem_id, + elem_classes=elem_classes, + value=value, + **kwargs, + ) + TokenInterpretable.__init__(self) + + def get_config(self): + return { + "image_mode": self.image_mode, + "shape": self.shape, + "height": self.height, + "width": self.width, + "source": self.source, + "tool": self.tool, + "value": self.value, + "streaming": self.streaming, + "mirror_webcam": self.mirror_webcam, + "brush_radius": self.brush_radius, + "brush_color": self.brush_color, + "mask_opacity": self.mask_opacity, + "selectable": self.selectable, + "show_share_button": self.show_share_button, + "show_download_button": self.show_download_button, + **IOComponent.get_config(self), + } + + @staticmethod + def update( + value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, + height: int | None = None, + width: int | None = None, + label: str | None = None, + show_label: bool | None = None, + show_download_button: bool | None = None, + container: bool | None = None, + scale: int | None = None, + min_width: int | None = None, + interactive: bool | None = None, + visible: bool | None = None, + brush_radius: float | None = None, + brush_color: str | None = None, + mask_opacity: float | None = None, + show_share_button: bool | None = None, + ): + return { + "height": height, + "width": width, + "label": label, + "show_label": show_label, + "show_download_button": show_download_button, + "container": container, + "scale": scale, + "min_width": min_width, + "interactive": interactive, + "visible": visible, + "value": value, + "brush_radius": brush_radius, + "brush_color": brush_color, + "mask_opacity": mask_opacity, + "show_share_button": show_share_button, + "__type__": "update", + } + + def _format_image( + self, im: _Image.Image | None + ) -> np.ndarray | _Image.Image | str | None: + """Helper method to format an image based on self.type""" + if im is None: + return im + fmt = im.format + if self.type == "pil": 
+ return im + elif self.type == "numpy": + return np.array(im) + elif self.type == "filepath": + path = self.pil_to_temp_file( + im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png" + ) + self.temp_files.add(path) + return path + else: + raise ValueError( + "Unknown type: " + + str(self.type) + + ". Please choose from: 'numpy', 'pil', 'filepath'." + ) + + def preprocess( + self, x: str | dict[str, str] + ) -> np.ndarray | _Image.Image | str | dict | None: + """ + Parameters: + x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data + Returns: + image in requested format, or (if tool == "sketch") a dict of image and mask in requested format + """ + if x is None: + return x + + mask = None + + if self.tool == "sketch" and self.source in ["upload", "webcam"]: + if isinstance(x, dict): + x, mask = x["image"], x["mask"] + + assert isinstance(x, str) + im = processing_utils.decode_base64_to_image(x) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + im = im.convert(self.image_mode) + if self.shape is not None: + im = processing_utils.resize_and_crop(im, self.shape) + if self.invert_colors: + im = PIL.ImageOps.invert(im) + if ( + self.source == "webcam" + and self.mirror_webcam is True + and self.tool != "color-sketch" + ): + im = PIL.ImageOps.mirror(im) + + if self.tool == "sketch" and self.source in ["upload", "webcam"]: + if mask is not None: + mask_im = processing_utils.decode_base64_to_image(mask) + if mask_im.mode == "RGBA": # whiten any opaque pixels in the mask + alpha_data = mask_im.getchannel("A").convert("L") + mask_im = _Image.merge("RGB", [alpha_data, alpha_data, alpha_data]) + return { + "image": self._format_image(im), + "mask": self._format_image(mask_im), + } + else: + return { + "image": self._format_image(im), + "mask": None, + } + + return self._format_image(im) + + def postprocess( + self, y: np.ndarray | _Image.Image | str | Path | None + ) -> str | None: + """ + Parameters: + y: image as a numpy array, PIL Image, string/Path filepath, or string URL + Returns: + base64 url data + """ + if y is None: + return None + if isinstance(y, np.ndarray): + return processing_utils.encode_array_to_base64(y) + elif isinstance(y, _Image.Image): + return processing_utils.encode_pil_to_base64(y) + elif isinstance(y, (str, Path)): + return client_utils.encode_url_or_file_to_base64(y) + else: + raise ValueError("Cannot process this value as an Image") + + def set_interpret_parameters(self, segments: int = 16): + """ + Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value. + Parameters: + segments: Number of interpretation segments to split image into. + """ + self.interpretation_segments = segments + return self + + def _segment_by_slic(self, x): + """ + Helper method that segments an image into superpixels using slic. + Parameters: + x: base64 representation of an image + """ + x = processing_utils.decode_base64_to_image(x) + if self.shape is not None: + x = processing_utils.resize_and_crop(x, self.shape) + resized_and_cropped_image = np.array(x) + try: + from skimage.segmentation import slic + except (ImportError, ModuleNotFoundError) as err: + raise ValueError( + "Error: running this interpretation for images requires scikit-image, please install it first." 
+ ) from err + try: + segments_slic = slic( + resized_and_cropped_image, + self.interpretation_segments, + compactness=10, + sigma=1, + start_label=1, + ) + except TypeError: # For skimage 0.16 and older + segments_slic = slic( + resized_and_cropped_image, + self.interpretation_segments, + compactness=10, + sigma=1, + ) + return segments_slic, resized_and_cropped_image + + def tokenize(self, x): + """ + Segments image into tokens, masks, and leave-one-out-tokens + Parameters: + x: base64 representation of an image + Returns: + tokens: list of tokens, used by the get_masked_input() method + leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method + masks: list of masks, used by the get_interpretation_neighbors() method + """ + segments_slic, resized_and_cropped_image = self._segment_by_slic(x) + tokens, masks, leave_one_out_tokens = [], [], [] + replace_color = np.mean(resized_and_cropped_image, axis=(0, 1)) + for segment_value in np.unique(segments_slic): + mask = segments_slic == segment_value + image_screen = np.copy(resized_and_cropped_image) + image_screen[segments_slic == segment_value] = replace_color + leave_one_out_tokens.append( + processing_utils.encode_array_to_base64(image_screen) + ) + token = np.copy(resized_and_cropped_image) + token[segments_slic != segment_value] = 0 + tokens.append(token) + masks.append(mask) + return tokens, leave_one_out_tokens, masks + + def get_masked_inputs(self, tokens, binary_mask_matrix): + masked_inputs = [] + for binary_mask_vector in binary_mask_matrix: + masked_input = np.zeros_like(tokens[0], dtype=int) + for token, b in zip(tokens, binary_mask_vector): + masked_input = masked_input + token * int(b) + masked_inputs.append(processing_utils.encode_array_to_base64(masked_input)) + return masked_inputs + + def get_interpretation_scores( + self, x, neighbors, scores, masks, tokens=None, **kwargs + ) -> list[list[float]]: + """ + Returns: + A 2D array representing the interpretation score of each pixel of the image. + """ + x = processing_utils.decode_base64_to_image(x) + if self.shape is not None: + x = processing_utils.resize_and_crop(x, self.shape) + x = np.array(x) + output_scores = np.zeros((x.shape[0], x.shape[1])) + + for score, mask in zip(scores, masks): + output_scores += score * mask + + max_val, min_val = np.max(output_scores), np.min(output_scores) + if max_val > 0: + output_scores = (output_scores - min_val) / (max_val - min_val) + return output_scores.tolist() + + def style(self, *, height: int | None = None, width: int | None = None, **kwargs): + """ + This method is deprecated. Please set these arguments in the constructor instead. 
+ """ + warn_style_method_deprecation() + if height is not None: + self.height = height + if width is not None: + self.width = width + return self + + def check_streamable(self): + if self.source != "webcam": + raise ValueError("Image streaming only available if source is 'webcam'.") + + def as_example(self, input_data: str | None) -> str: + if input_data is None: + return "" + elif ( + self.root_url + ): # If an externally hosted image, don't convert to absolute path + return input_data + return str(utils.abspath(input_data)) + + +all_components = [] + +if not hasattr(Block, 'original__init__'): + Block.original_init = Block.__init__ + + +def blk_ini(self, *args, **kwargs): + all_components.append(self) + return Block.original_init(self, *args, **kwargs) + + +Block.__init__ = blk_ini + + +gradio.routes.asyncio = importlib.reload(gradio.routes.asyncio) + +if not hasattr(gradio.routes.asyncio, 'original_wait_for'): + gradio.routes.asyncio.original_wait_for = gradio.routes.asyncio.wait_for + + +def patched_wait_for(fut, timeout): + del timeout + return gradio.routes.asyncio.original_wait_for(fut, timeout=65535) + + +gradio.routes.asyncio.wait_for = patched_wait_for + diff --git a/modules/html.py b/modules/html.py new file mode 100644 index 0000000000000000000000000000000000000000..769151a9ff86e460d69d3598fcac0481d59cf17b --- /dev/null +++ b/modules/html.py @@ -0,0 +1,146 @@ +css = ''' +.loader-container { + display: flex; /* Use flex to align items horizontally */ + align-items: center; /* Center items vertically within the container */ + white-space: nowrap; /* Prevent line breaks within the container */ +} + +.loader { + border: 8px solid #f3f3f3; /* Light grey */ + border-top: 8px solid #3498db; /* Blue */ + border-radius: 50%; + width: 30px; + height: 30px; + animation: spin 2s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +/* Style the progress bar */ +progress { + appearance: none; /* Remove default styling */ + height: 20px; /* Set the height of the progress bar */ + border-radius: 5px; /* Round the corners of the progress bar */ + background-color: #f3f3f3; /* Light grey background */ + width: 100%; +} + +/* Style the progress bar container */ +.progress-container { + margin-left: 20px; + margin-right: 20px; + flex-grow: 1; /* Allow the progress container to take up remaining space */ +} + +/* Set the color of the progress bar fill */ +progress::-webkit-progress-value { + background-color: #3498db; /* Blue color for the fill */ +} + +progress::-moz-progress-bar { + background-color: #3498db; /* Blue color for the fill in Firefox */ +} + +/* Style the text on the progress bar */ +progress::after { + content: attr(value '%'); /* Display the progress value followed by '%' */ + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + color: white; /* Set text color */ + font-size: 14px; /* Set font size */ +} + +/* Style other texts */ +.loader-container > span { + margin-left: 5px; /* Add spacing between the progress bar and the text */ +} + +.progress-bar > .generating { + display: none !important; +} + +.progress-bar{ + height: 30px !important; +} + +.type_row{ + height: 80px !important; +} + +.type_row_half{ + height: 32px !important; +} + +.scroll-hide{ + resize: none !important; +} + +.refresh_button{ + border: none !important; + background: none !important; + font-size: none !important; + box-shadow: none !important; +} + +.advanced_check_row{ + width: 250px !important; +} + 
+.min_check{ + min-width: min(1px, 100%) !important; +} + +.resizable_area { + resize: vertical; + overflow: auto !important; +} + +.aspect_ratios label { + width: 140px !important; +} + +.aspect_ratios label span { + white-space: nowrap !important; +} + +.aspect_ratios label input { + margin-left: -5px !important; +} + +.lora_enable label { + height: 100%; +} + +.lora_enable label input { + margin: auto; +} + +.lora_enable label span { + display: none; +} + +@-moz-document url-prefix() { + .lora_weight input[type=number] { + width: 80px; + } +} + +''' +progress_html = ''' +
+<div class="loader-container">
+  <div class="loader"></div>
+  <div class="progress-container">
+    <progress value="*number*" max="100"></progress>
+  </div>
+  <span>*text*</span>
+</div>
+''' + + +def make_progress_html(number, text): + return progress_html.replace('*number*', str(number)).replace('*text*', text) diff --git a/modules/inpaint_worker.py b/modules/inpaint_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..43a7ae23e9bd2cebda69b94013bf1661bd8fd952 --- /dev/null +++ b/modules/inpaint_worker.py @@ -0,0 +1,264 @@ +import torch +import numpy as np + +from PIL import Image, ImageFilter +from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil +from modules.upscaler import perform_upscale +import cv2 + + +inpaint_head_model = None + + +class InpaintHead(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu')) + + def __call__(self, x): + x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate") + return torch.nn.functional.conv2d(input=x, weight=self.head) + + +current_task = None + + +def box_blur(x, k): + x = Image.fromarray(x) + x = x.filter(ImageFilter.BoxBlur(k)) + return np.array(x) + + +def max_filter_opencv(x, ksize=3): + # Use OpenCV maximum filter + # Make sure the input type is int16 + return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16)) + + +def morphological_open(x): + # Convert array to int16 type via threshold operation + x_int16 = np.zeros_like(x, dtype=np.int16) + x_int16[x > 127] = 256 + + for i in range(32): + # Use int16 type to avoid overflow + maxed = max_filter_opencv(x_int16, ksize=3) - 8 + x_int16 = np.maximum(maxed, x_int16) + + # Clip negative values to 0 and convert back to uint8 type + x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8) + return x_uint8 + + +def up255(x, t=0): + y = np.zeros_like(x).astype(np.uint8) + y[x > t] = 255 + return y + + +def imsave(x, path): + x = Image.fromarray(x) + x.save(path) + + +def regulate_abcd(x, a, b, c, d): + H, W = x.shape[:2] + if a < 0: + a = 0 + if a > H: + a = H + if b < 0: + b = 0 + if b > H: + b = H + if c < 0: + c = 0 + if c > W: + c = W + if d < 0: + d = 0 + if d > W: + d = W + return int(a), int(b), int(c), int(d) + + +def compute_initial_abcd(x): + indices = np.where(x) + a = np.min(indices[0]) + b = np.max(indices[0]) + c = np.min(indices[1]) + d = np.max(indices[1]) + abp = (b + a) // 2 + abm = (b - a) // 2 + cdp = (d + c) // 2 + cdm = (d - c) // 2 + l = int(max(abm, cdm) * 1.15) + a = abp - l + b = abp + l + 1 + c = cdp - l + d = cdp + l + 1 + a, b, c, d = regulate_abcd(x, a, b, c, d) + return a, b, c, d + + +def solve_abcd(x, a, b, c, d, k): + k = float(k) + assert 0.0 <= k <= 1.0 + + H, W = x.shape[:2] + if k == 1.0: + return 0, H, 0, W + while True: + if b - a >= H * k and d - c >= W * k: + break + + add_h = (b - a) < (d - c) + add_w = not add_h + + if b - a == H: + add_w = True + + if d - c == W: + add_h = True + + if add_h: + a -= 1 + b += 1 + + if add_w: + c -= 1 + d += 1 + + a, b, c, d = regulate_abcd(x, a, b, c, d) + return a, b, c, d + + +def fooocus_fill(image, mask): + current_image = image.copy() + raw_image = image.copy() + area = np.where(mask < 127) + store = raw_image[area] + + for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]: + for _ in range(repeats): + current_image = box_blur(current_image, k) + current_image[area] = store + + return current_image + + +class InpaintWorker: + def __init__(self, image, mask, use_fill=True, k=0.618): + a, b, c, d = compute_initial_abcd(mask > 0) + a, b, c, d = solve_abcd(mask, a, b, c, d, k=k) + + # interested 
area + self.interested_area = (a, b, c, d) + self.interested_mask = mask[a:b, c:d] + self.interested_image = image[a:b, c:d] + + # super resolution + if get_image_shape_ceil(self.interested_image) < 1024: + self.interested_image = perform_upscale(self.interested_image) + + # resize to make images ready for diffusion + self.interested_image = set_image_shape_ceil(self.interested_image, 1024) + self.interested_fill = self.interested_image.copy() + H, W, C = self.interested_image.shape + + # process mask + self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127) + + # compute filling + if use_fill: + self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask) + + # soft pixels + self.mask = morphological_open(mask) + self.image = image + + # ending + self.latent = None + self.latent_after_swap = None + self.swapped = False + self.latent_mask = None + self.inpaint_head_feature = None + return + + def load_latent(self, latent_fill, latent_mask, latent_swap=None): + self.latent = latent_fill + self.latent_mask = latent_mask + self.latent_after_swap = latent_swap + return + + def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model): + global inpaint_head_model + + if inpaint_head_model is None: + inpaint_head_model = InpaintHead() + sd = torch.load(inpaint_head_model_path, map_location='cpu') + inpaint_head_model.load_state_dict(sd) + + feed = torch.cat([ + inpaint_latent_mask, + model.model.process_latent_in(inpaint_latent) + ], dim=1) + + inpaint_head_model.to(device=feed.device, dtype=feed.dtype) + inpaint_head_feature = inpaint_head_model(feed) + + def input_block_patch(h, transformer_options): + if transformer_options["block"][1] == 0: + h = h + inpaint_head_feature.to(h) + return h + + m = model.clone() + m.set_model_input_block_patch(input_block_patch) + return m + + def swap(self): + if self.swapped: + return + + if self.latent is None: + return + + if self.latent_after_swap is None: + return + + self.latent, self.latent_after_swap = self.latent_after_swap, self.latent + self.swapped = True + return + + def unswap(self): + if not self.swapped: + return + + if self.latent is None: + return + + if self.latent_after_swap is None: + return + + self.latent, self.latent_after_swap = self.latent_after_swap, self.latent + self.swapped = False + return + + def color_correction(self, img): + fg = img.astype(np.float32) + bg = self.image.copy().astype(np.float32) + w = self.mask[:, :, None].astype(np.float32) / 255.0 + y = fg * w + bg * (1 - w) + return y.clip(0, 255).astype(np.uint8) + + def post_process(self, img): + a, b, c, d = self.interested_area + content = resample_image(img, d - c, b - a) + result = self.image.copy() + result[a:b, c:d] = content + result = self.color_correction(result) + return result + + def visualize_mask_processing(self): + return [self.interested_fill, self.interested_mask, self.interested_image] + diff --git a/modules/launch_util.py b/modules/launch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b483d5158ca5eeeff6f385b1a94990f9e5f6e871 --- /dev/null +++ b/modules/launch_util.py @@ -0,0 +1,103 @@ +import os +import importlib +import importlib.util +import subprocess +import sys +import re +import logging +import importlib.metadata +import packaging.version +from packaging.requirements import Requirement + + + + +logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... 
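# A minimal sketch (illustration only, using an assumed toy mask) of how the crop-box
# helpers defined above in modules/inpaint_worker.py fit together: compute_initial_abcd()
# finds a roughly square box around the nonzero mask pixels, solve_abcd() grows it until
# it covers at least a fraction k of each image dimension, and regulate_abcd() keeps every
# coordinate inside the image.
import numpy as np
from modules.inpaint_worker import compute_initial_abcd, solve_abcd

toy_mask = np.zeros((256, 256), dtype=np.uint8)
toy_mask[100:140, 60:90] = 255                          # pretend this region was painted for inpainting

a, b, c, d = compute_initial_abcd(toy_mask > 0)         # tight, slightly padded square box
a, b, c, d = solve_abcd(toy_mask, a, b, c, d, k=0.618)  # grow until >= 61.8% of height and width
interested = toy_mask[a:b, c:d]                         # the "interested" crop that InpaintWorker processes
print((a, b, c, d), interested.shape)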
+logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + +re_requirement = re.compile(r"\s*([-\w]+)\s*(?:==\s*([-+.\w]+))?\s*") + +python = sys.executable +default_command_live = (os.environ.get('LAUNCH_LIVE_OUTPUT') == "1") +index_url = os.environ.get('INDEX_URL', "") + +modules_path = os.path.dirname(os.path.realpath(__file__)) +script_path = os.path.dirname(modules_path) + + +def is_installed(package): + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None + + +def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str: + if desc is not None: + print(desc) + + run_kwargs = { + "args": command, + "shell": True, + "env": os.environ if custom_env is None else custom_env, + "encoding": 'utf8', + "errors": 'ignore', + } + + if not live: + run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE + + result = subprocess.run(**run_kwargs) + + if result.returncode != 0: + error_bits = [ + f"{errdesc or 'Error running command'}.", + f"Command: {command}", + f"Error code: {result.returncode}", + ] + if result.stdout: + error_bits.append(f"stdout: {result.stdout}") + if result.stderr: + error_bits.append(f"stderr: {result.stderr}") + raise RuntimeError("\n".join(error_bits)) + + return (result.stdout or "") + + +def run_pip(command, desc=None, live=default_command_live): + try: + index_url_line = f' --index-url {index_url}' if index_url != '' else '' + return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", + errdesc=f"Couldn't install {desc}", live=live) + except Exception as e: + print(e) + print(f'CMD Failed {desc}: {command}') + return None + + +def requirements_met(requirements_file): + with open(requirements_file, "r", encoding="utf8") as file: + for line in file: + line = line.strip() + if line == "" or line.startswith('#'): + continue + + requirement = Requirement(line) + package = requirement.name + + try: + version_installed = importlib.metadata.version(package) + installed_version = packaging.version.parse(version_installed) + + # Check if the installed version satisfies the requirement + if installed_version not in requirement.specifier: + print(f"Version mismatch for {package}: Installed version {version_installed} does not meet requirement {requirement}") + return False + except Exception as e: + print(f"Error checking version for {package}: {e}") + return False + + return True + diff --git a/modules/localization.py b/modules/localization.py new file mode 100644 index 0000000000000000000000000000000000000000..b21d4a564d134ac0be00d83c7005627d601d206e --- /dev/null +++ b/modules/localization.py @@ -0,0 +1,60 @@ +import json +import os + + +current_translation = {} +localization_root = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'language') + + +def localization_js(filename): + global current_translation + + if isinstance(filename, str): + full_name = os.path.abspath(os.path.join(localization_root, filename + '.json')) + if os.path.exists(full_name): + try: + with open(full_name, encoding='utf-8') as f: + current_translation = json.load(f) + assert isinstance(current_translation, dict) + for k, v in current_translation.items(): + assert isinstance(k, str) + assert isinstance(v, str) + except Exception as e: + print(str(e)) + print(f'Failed to load localization file {full_name}') + + # current_translation = {k: 'XXX' for k in current_translation.keys()} 
# use this to see if all texts are covered + + return f"window.localization = {json.dumps(current_translation)}" + + +def dump_english_config(components): + all_texts = [] + for c in components: + label = getattr(c, 'label', None) + value = getattr(c, 'value', None) + choices = getattr(c, 'choices', None) + info = getattr(c, 'info', None) + + if isinstance(label, str): + all_texts.append(label) + if isinstance(value, str): + all_texts.append(value) + if isinstance(info, str): + all_texts.append(info) + if isinstance(choices, list): + for x in choices: + if isinstance(x, str): + all_texts.append(x) + if isinstance(x, tuple): + for y in x: + if isinstance(y, str): + all_texts.append(y) + + config_dict = {k: k for k in all_texts if k != "" and 'progress-container' not in k} + full_name = os.path.abspath(os.path.join(localization_root, 'en.json')) + + with open(full_name, "w", encoding="utf-8") as json_file: + json.dump(config_dict, json_file, indent=4) + + return diff --git a/modules/lora.py b/modules/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..088545c708943aa8e51e8b2bfd32b2a9968b639f --- /dev/null +++ b/modules/lora.py @@ -0,0 +1,152 @@ +def match_lora(lora, to_load): + patch_dict = {} + loaded_keys = set() + for x in to_load: + real_load_key = to_load[x] + if real_load_key in lora: + patch_dict[real_load_key] = ('fooocus', lora[real_load_key]) + loaded_keys.add(real_load_key) + continue + + alpha_name = "{}.alpha".format(x) + alpha = None + if alpha_name in lora.keys(): + alpha = lora[alpha_name].item() + loaded_keys.add(alpha_name) + + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name ="{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + + + ######## loha + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = 
"{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) + + #glora + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) + + w_norm_name = "{}.w_norm".format(x) + b_norm_name = "{}.b_norm".format(x) + w_norm = lora.get(w_norm_name, None) + b_norm = lora.get(b_norm_name, None) + + if w_norm is not None: + loaded_keys.add(w_norm_name) + patch_dict[to_load[x]] = ("diff", (w_norm,)) + if b_norm is not None: + loaded_keys.add(b_norm_name) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) + + diff_name = "{}.diff".format(x) + diff_weight = lora.get(diff_name, None) + if diff_weight is not None: + patch_dict[to_load[x]] = ("diff", (diff_weight,)) + loaded_keys.add(diff_name) + + diff_bias_name = "{}.diff_b".format(x) + diff_bias = lora.get(diff_bias_name, None) + if diff_bias is not None: + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) + loaded_keys.add(diff_bias_name) + + remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys} + return patch_dict, remaining_dict diff --git a/modules/meta_parser.py b/modules/meta_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..546c093fa008da831fb17b0e6a2cc256467315b2 --- /dev/null +++ b/modules/meta_parser.py @@ -0,0 +1,573 @@ +import json +import os +import re +from abc import ABC, abstractmethod +from pathlib import Path + +import gradio as gr +from PIL import Image + +import fooocus_version +import modules.config +import modules.sdxl_styles +from modules.flags import MetadataScheme, Performance, Steps +from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS +from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, calculate_sha256 + +re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' +re_param = re.compile(re_param_code) +re_imagesize = re.compile(r"^(\d+)x(\d+)$") + +hash_cache = {} + + +def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): + loaded_parameter_dict = raw_metadata + if isinstance(raw_metadata, str): + loaded_parameter_dict = json.loads(raw_metadata) + assert 
isinstance(loaded_parameter_dict, dict) + + results = [len(loaded_parameter_dict) > 0, 1] + + get_str('prompt', 'Prompt', loaded_parameter_dict, results) + get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results) + get_list('styles', 'Styles', loaded_parameter_dict, results) + get_str('performance', 'Performance', loaded_parameter_dict, results) + get_steps('steps', 'Steps', loaded_parameter_dict, results) + get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results) + get_resolution('resolution', 'Resolution', loaded_parameter_dict, results) + get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results) + get_float('sharpness', 'Sharpness', loaded_parameter_dict, results) + get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results) + get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results) + get_float('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results) + get_str('base_model', 'Base Model', loaded_parameter_dict, results) + get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results) + get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) + get_str('sampler', 'Sampler', loaded_parameter_dict, results) + get_str('scheduler', 'Scheduler', loaded_parameter_dict, results) + get_seed('seed', 'Seed', loaded_parameter_dict, results) + + if is_generating: + results.append(gr.update()) + else: + results.append(gr.update(visible=True)) + + results.append(gr.update(visible=False)) + + get_freeu('freeu', 'FreeU', loaded_parameter_dict, results) + + for i in range(modules.config.default_max_lora_number): + get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results) + + return results + + +def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + assert isinstance(h, str) + results.append(h) + except: + results.append(gr.update()) + + +def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + h = eval(h) + assert isinstance(h, list) + results.append(h) + except: + results.append(gr.update()) + + +def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + assert h is not None + h = float(h) + results.append(h) + except: + results.append(gr.update()) + + +def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + assert h is not None + h = int(h) + # if not in steps or in steps and performance is not the same + if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', '_').casefold(): + results.append(h) + return + results.append(-1) + except: + results.append(-1) + + +def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + width, height = eval(h) + formatted = modules.config.add_ratio(f'{width}*{height}') + if formatted in modules.config.available_aspect_ratios: + results.append(formatted) + results.append(-1) + results.append(-1) + else: + results.append(gr.update()) + results.append(int(width)) + results.append(int(height)) + 
except: + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) + + +def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + assert h is not None + h = int(h) + results.append(False) + results.append(h) + except: + results.append(gr.update()) + results.append(gr.update()) + + +def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + p, n, e = eval(h) + results.append(float(p)) + results.append(float(n)) + results.append(float(e)) + except: + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) + + +def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None): + try: + h = source_dict.get(key, source_dict.get(fallback, default)) + b1, b2, s1, s2 = eval(h) + results.append(True) + results.append(float(b1)) + results.append(float(b2)) + results.append(float(s1)) + results.append(float(s2)) + except: + results.append(False) + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) + results.append(gr.update()) + + +def get_lora(key: str, fallback: str | None, source_dict: dict, results: list): + try: + n, w = source_dict.get(key, source_dict.get(fallback)).split(' : ') + w = float(w) + results.append(True) + results.append(n) + results.append(w) + except: + results.append(True) + results.append('None') + results.append(1) + + +def get_sha256(filepath): + global hash_cache + if filepath not in hash_cache: + hash_cache[filepath] = calculate_sha256(filepath) + + return hash_cache[filepath] + + +def parse_meta_from_preset(preset_content): + assert isinstance(preset_content, dict) + preset_prepared = {} + items = preset_content + + for settings_key, meta_key in modules.config.possible_preset_keys.items(): + if settings_key == "default_loras": + loras = getattr(modules.config, settings_key) + if settings_key in items: + loras = items[settings_key] + for index, lora in enumerate(loras[:5]): + preset_prepared[f'lora_combined_{index + 1}'] = ' : '.join(map(str, lora)) + elif settings_key == "default_aspect_ratio": + if settings_key in items and items[settings_key] is not None: + default_aspect_ratio = items[settings_key] + width, height = default_aspect_ratio.split('*') + else: + default_aspect_ratio = getattr(modules.config, settings_key) + width, height = default_aspect_ratio.split('×') + height = height[:height.index(" ")] + preset_prepared[meta_key] = (width, height) + else: + preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[ + settings_key] is not None else getattr(modules.config, settings_key) + + if settings_key == "default_styles" or settings_key == "default_aspect_ratio": + preset_prepared[meta_key] = str(preset_prepared[meta_key]) + + return preset_prepared + + +class MetadataParser(ABC): + def __init__(self): + self.raw_prompt: str = '' + self.full_prompt: str = '' + self.raw_negative_prompt: str = '' + self.full_negative_prompt: str = '' + self.steps: int = 30 + self.base_model_name: str = '' + self.base_model_hash: str = '' + self.refiner_model_name: str = '' + self.refiner_model_hash: str = '' + self.loras: list = [] + + @abstractmethod + def get_scheme(self) -> MetadataScheme: + raise NotImplementedError + + @abstractmethod + def parse_json(self, metadata: dict | str) -> dict: + raise 
NotImplementedError + + @abstractmethod + def parse_string(self, metadata: dict) -> str: + raise NotImplementedError + + def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, + refiner_model_name, loras): + self.raw_prompt = raw_prompt + self.full_prompt = full_prompt + self.raw_negative_prompt = raw_negative_prompt + self.full_negative_prompt = full_negative_prompt + self.steps = steps + self.base_model_name = Path(base_model_name).stem + + base_model_path = get_file_from_folder_list(base_model_name, modules.config.paths_checkpoints) + self.base_model_hash = get_sha256(base_model_path) + + if refiner_model_name not in ['', 'None']: + self.refiner_model_name = Path(refiner_model_name).stem + refiner_model_path = get_file_from_folder_list(refiner_model_name, modules.config.paths_checkpoints) + self.refiner_model_hash = get_sha256(refiner_model_path) + + self.loras = [] + for (lora_name, lora_weight) in loras: + if lora_name != 'None': + lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras) + lora_hash = get_sha256(lora_path) + self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) + + +class A1111MetadataParser(MetadataParser): + def get_scheme(self) -> MetadataScheme: + return MetadataScheme.A1111 + + fooocus_to_a1111 = { + 'raw_prompt': 'Raw prompt', + 'raw_negative_prompt': 'Raw negative prompt', + 'negative_prompt': 'Negative prompt', + 'styles': 'Styles', + 'performance': 'Performance', + 'steps': 'Steps', + 'sampler': 'Sampler', + 'scheduler': 'Scheduler', + 'guidance_scale': 'CFG scale', + 'seed': 'Seed', + 'resolution': 'Size', + 'sharpness': 'Sharpness', + 'adm_guidance': 'ADM Guidance', + 'refiner_swap_method': 'Refiner Swap Method', + 'adaptive_cfg': 'Adaptive CFG', + 'overwrite_switch': 'Overwrite Switch', + 'freeu': 'FreeU', + 'base_model': 'Model', + 'base_model_hash': 'Model hash', + 'refiner_model': 'Refiner', + 'refiner_model_hash': 'Refiner hash', + 'lora_hashes': 'Lora hashes', + 'lora_weights': 'Lora weights', + 'created_by': 'User', + 'version': 'Version' + } + + def parse_json(self, metadata: str) -> dict: + metadata_prompt = '' + metadata_negative_prompt = '' + + done_with_prompt = False + + *lines, lastline = metadata.strip().split("\n") + if len(re_param.findall(lastline)) < 3: + lines.append(lastline) + lastline = '' + + for line in lines: + line = line.strip() + if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"): + done_with_prompt = True + line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip() + if done_with_prompt: + metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line + else: + metadata_prompt += ('' if metadata_prompt == '' else "\n") + line + + found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt) + + data = { + 'prompt': prompt, + 'negative_prompt': negative_prompt + } + + for k, v in re_param.findall(lastline): + try: + if v != '' and v[0] == '"' and v[-1] == '"': + v = unquote(v) + + m = re_imagesize.match(v) + if m is not None: + data['resolution'] = str((m.group(1), m.group(2))) + else: + data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v + except Exception: + print(f"Error parsing \"{k}: {v}\"") + + # workaround for multiline prompts + if 'raw_prompt' in data: + data['prompt'] = data['raw_prompt'] + raw_prompt = data['raw_prompt'].replace("\n", ', ') + if metadata_prompt != raw_prompt and 
modules.sdxl_styles.fooocus_expansion not in found_styles: + found_styles.append(modules.sdxl_styles.fooocus_expansion) + + if 'raw_negative_prompt' in data: + data['negative_prompt'] = data['raw_negative_prompt'] + + data['styles'] = str(found_styles) + + # try to load performance based on steps, fallback for direct A1111 imports + if 'steps' in data and 'performance' not in data: + try: + data['performance'] = Performance[Steps(int(data['steps'])).name].value + except ValueError | KeyError: + pass + + if 'sampler' in data: + data['sampler'] = data['sampler'].replace(' Karras', '') + # get key + for k, v in SAMPLERS.items(): + if v == data['sampler']: + data['sampler'] = k + break + + for key in ['base_model', 'refiner_model']: + if key in data: + for filename in modules.config.model_filenames: + path = Path(filename) + if data[key] == path.stem: + data[key] = filename + break + + if 'lora_hashes' in data: + lora_filenames = modules.config.lora_filenames.copy() + if modules.config.sdxl_lcm_lora in lora_filenames: + lora_filenames.remove(modules.config.sdxl_lcm_lora) + for li, lora in enumerate(data['lora_hashes'].split(', ')): + lora_name, lora_hash, lora_weight = lora.split(': ') + for filename in lora_filenames: + path = Path(filename) + if lora_name == path.stem: + data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}' + break + + return data + + def parse_string(self, metadata: dict) -> str: + data = {k: v for _, k, v in metadata} + + width, height = eval(data['resolution']) + + sampler = data['sampler'] + scheduler = data['scheduler'] + if sampler in SAMPLERS and SAMPLERS[sampler] != '': + sampler = SAMPLERS[sampler] + if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras': + sampler += f' Karras' + + generation_params = { + self.fooocus_to_a1111['steps']: self.steps, + self.fooocus_to_a1111['sampler']: sampler, + self.fooocus_to_a1111['seed']: data['seed'], + self.fooocus_to_a1111['resolution']: f'{width}x{height}', + self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'], + self.fooocus_to_a1111['sharpness']: data['sharpness'], + self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'], + self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem, + self.fooocus_to_a1111['base_model_hash']: self.base_model_hash, + + self.fooocus_to_a1111['performance']: data['performance'], + self.fooocus_to_a1111['scheduler']: scheduler, + # workaround for multiline prompts + self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, + self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt, + } + + if self.refiner_model_name not in ['', 'None']: + generation_params |= { + self.fooocus_to_a1111['refiner_model']: self.refiner_model_name, + self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash + } + + for key in ['adaptive_cfg', 'overwrite_switch', 'refiner_swap_method', 'freeu']: + if key in data: + generation_params[self.fooocus_to_a1111[key]] = data[key] + + lora_hashes = [] + for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras): + # workaround for Fooocus not knowing LoRA name in LoRA metadata + lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}') + lora_hashes_string = ', '.join(lora_hashes) + + generation_params |= { + self.fooocus_to_a1111['lora_hashes']: lora_hashes_string, + self.fooocus_to_a1111['version']: data['version'] + } + + if modules.config.metadata_created_by != '': + generation_params[self.fooocus_to_a1111['created_by']] = modules.config.metadata_created_by + + generation_params_text = ", 
".join( + [k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if + v is not None]) + positive_prompt_resolved = ', '.join(self.full_prompt) + negative_prompt_resolved = ', '.join(self.full_negative_prompt) + negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" + return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() + + +class FooocusMetadataParser(MetadataParser): + def get_scheme(self) -> MetadataScheme: + return MetadataScheme.FOOOCUS + + def parse_json(self, metadata: dict) -> dict: + model_filenames = modules.config.model_filenames.copy() + lora_filenames = modules.config.lora_filenames.copy() + if modules.config.sdxl_lcm_lora in lora_filenames: + lora_filenames.remove(modules.config.sdxl_lcm_lora) + + for key, value in metadata.items(): + if value in ['', 'None']: + continue + if key in ['base_model', 'refiner_model']: + metadata[key] = self.replace_value_with_filename(key, value, model_filenames) + elif key.startswith('lora_combined_'): + metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) + else: + continue + + return metadata + + def parse_string(self, metadata: list) -> str: + for li, (label, key, value) in enumerate(metadata): + # remove model folder paths from metadata + if key.startswith('lora_combined_'): + name, weight = value.split(' : ') + name = Path(name).stem + value = f'{name} : {weight}' + metadata[li] = (label, key, value) + + res = {k: v for _, k, v in metadata} + + res['full_prompt'] = self.full_prompt + res['full_negative_prompt'] = self.full_negative_prompt + res['steps'] = self.steps + res['base_model'] = self.base_model_name + res['base_model_hash'] = self.base_model_hash + + if self.refiner_model_name not in ['', 'None']: + res['refiner_model'] = self.refiner_model_name + res['refiner_model_hash'] = self.refiner_model_hash + + res['loras'] = self.loras + + if modules.config.metadata_created_by != '': + res['created_by'] = modules.config.metadata_created_by + + return json.dumps(dict(sorted(res.items()))) + + @staticmethod + def replace_value_with_filename(key, value, filenames): + for filename in filenames: + path = Path(filename) + if key.startswith('lora_combined_'): + name, weight = value.split(' : ') + if name == path.stem: + return f'{filename} : {weight}' + elif value == path.stem: + return filename + + +def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser: + match metadata_scheme: + case MetadataScheme.FOOOCUS: + return FooocusMetadataParser() + case MetadataScheme.A1111: + return A1111MetadataParser() + case _: + raise NotImplementedError + + +def read_info_from_image(filepath) -> tuple[str | None, MetadataScheme | None]: + with Image.open(filepath) as image: + items = (image.info or {}).copy() + + parameters = items.pop('parameters', None) + metadata_scheme = items.pop('fooocus_scheme', None) + exif = items.pop('exif', None) + + if parameters is not None and is_json(parameters): + parameters = json.loads(parameters) + elif exif is not None: + exif = image.getexif() + # 0x9286 = UserComment + parameters = exif.get(0x9286, None) + # 0x927C = MakerNote + metadata_scheme = exif.get(0x927C, None) + + if is_json(parameters): + parameters = json.loads(parameters) + + try: + metadata_scheme = MetadataScheme(metadata_scheme) + except ValueError: + metadata_scheme = None + + # broad fallback + if isinstance(parameters, dict): + metadata_scheme = MetadataScheme.FOOOCUS + + if isinstance(parameters, 
str): + metadata_scheme = MetadataScheme.A1111 + + return parameters, metadata_scheme + + +def get_exif(metadata: str | None, metadata_scheme: str): + exif = Image.Exif() + # tags see see https://github.com/python-pillow/Pillow/blob/9.2.x/src/PIL/ExifTags.py + # 0x9286 = UserComment + exif[0x9286] = metadata + # 0x0131 = Software + exif[0x0131] = 'Fooocus v' + fooocus_version.version + # 0x927C = MakerNote + exif[0x927C] = metadata_scheme + return exif \ No newline at end of file diff --git a/modules/model_loader.py b/modules/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba336a915ae234b7cd5f9f2576d4edf779738ba --- /dev/null +++ b/modules/model_loader.py @@ -0,0 +1,26 @@ +import os +from urllib.parse import urlparse +from typing import Optional + + +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: Optional[str] = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file diff --git a/modules/ops.py b/modules/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0e775634314d1b71811258cff87b2178e1c740 --- /dev/null +++ b/modules/ops.py @@ -0,0 +1,19 @@ +import torch +import contextlib + + +@contextlib.contextmanager +def use_patched_ops(operations): + op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm'] + backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names} + + try: + for op_name in op_names: + setattr(torch.nn, op_name, getattr(operations, op_name)) + + yield + + finally: + for op_name in op_names: + setattr(torch.nn, op_name, backups[op_name]) + return diff --git a/modules/patch.py b/modules/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2dd8f477902e68a467e8f89888934a762f4bb1 --- /dev/null +++ b/modules/patch.py @@ -0,0 +1,513 @@ +import os +import torch +import time +import math +import ldm_patched.modules.model_base +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.modules.model_management +import modules.anisotropic as anisotropic +import ldm_patched.ldm.modules.attention +import ldm_patched.k_diffusion.sampling +import ldm_patched.modules.sd1_clip +import modules.inpaint_worker as inpaint_worker +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.ldm.modules.diffusionmodules.model +import ldm_patched.modules.sd +import ldm_patched.controlnet.cldm +import ldm_patched.modules.model_patcher +import ldm_patched.modules.samplers +import ldm_patched.modules.args_parser +import warnings +import safetensors.torch +import modules.constants as constants + +from ldm_patched.modules.samplers import calc_cond_uncond_batch +from ldm_patched.k_diffusion.sampling import BatchedBrownianTree +from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control +from modules.patch_precision import patch_all_precision +from modules.patch_clip import patch_all_clip + + +class PatchSettings: + def __init__(self, + sharpness=2.0, + 
adm_scaler_end=0.3, + positive_adm_scale=1.5, + negative_adm_scale=0.8, + controlnet_softness=0.25, + adaptive_cfg=7.0): + self.sharpness = sharpness + self.adm_scaler_end = adm_scaler_end + self.positive_adm_scale = positive_adm_scale + self.negative_adm_scale = negative_adm_scale + self.controlnet_softness = controlnet_softness + self.adaptive_cfg = adaptive_cfg + self.global_diffusion_progress = 0 + self.eps_record = None + + +patch_settings = {} + + +def calculate_weight_patched(self, patches, weight, key): + for p in patches: + alpha = p[0] + v = p[1] + strength_model = p[2] + + if strength_model != 1.0: + weight *= strength_model + + if isinstance(v, list): + v = (self.calculate_weight(v[1:], v[0].clone(), key),) + + if len(v) == 1: + patch_type = "diff" + elif len(v) == 2: + patch_type = v[0] + v = v[1] + + if patch_type == "diff": + w1 = v[0] + if alpha != 0.0: + if w1.shape != weight.shape: + print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + else: + weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) + elif patch_type == "lora": + mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32) + mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32) + if v[2] is not None: + alpha *= v[2] / mat2.shape[0] + if v[3] is not None: + mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), + mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) + try: + weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape( + weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + elif patch_type == "fooocus": + w1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32) + w_min = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32) + w_max = ldm_patched.modules.model_management.cast_to_device(v[2], weight.device, torch.float32) + w1 = (w1 / 255.0) * (w_max - w_min) + w_min + if alpha != 0.0: + if w1.shape != weight.shape: + print("WARNING SHAPE MISMATCH {} FOOOCUS WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + else: + weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) + elif patch_type == "lokr": + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, torch.float32)) + else: + w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32), + 
ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32)) + else: + w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha *= v[2] / dim + + try: + weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + elif patch_type == "loha": + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha *= v[2] / w1b.shape[0] + w2a = v[3] + w2b = v[4] + if v[5] is not None: # cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32)) + else: + m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32)) + m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32), + ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32)) + + try: + weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + elif patch_type == "glora": + if v[4] is not None: + alpha *= v[4] / v[0].shape[0] + + a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) + a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) + b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) + b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) + + weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + else: + print("patch type not recognized", patch_type, key) + + return weight + + +class BrownianTreeNoiseSamplerPatched: + transform = None + tree = None + + @staticmethod + def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): + if ldm_patched.modules.model_management.directml_enabled: + cpu = True + + t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max)) + + BrownianTreeNoiseSamplerPatched.transform = transform + BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) + + def __init__(self, *args, **kwargs): + pass + + @staticmethod + def __call__(sigma, sigma_next): + transform = BrownianTreeNoiseSamplerPatched.transform + tree = BrownianTreeNoiseSamplerPatched.tree + + t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next)) + return tree(t0, t1) / (t1 - t0).abs().sqrt() + + +def compute_cfg(uncond, cond, cfg_scale, t): + pid = os.getpid() + mimic_cfg = float(patch_settings[pid].adaptive_cfg) 
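    # real_eps below is ordinary classifier-free guidance at the user's cfg_scale;
    # when cfg_scale exceeds the configured adaptive_cfg, a second prediction at the
    # milder mimic_cfg is blended in with weight (1 - t), where t is the global
    # diffusion progress (0 at the start of sampling, 1 at the end), so the reduced
    # guidance dominates early steps and the full cfg_scale is restored late.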
+ real_cfg = float(cfg_scale) + + real_eps = uncond + real_cfg * (cond - uncond) + + if cfg_scale > patch_settings[pid].adaptive_cfg: + mimicked_eps = uncond + mimic_cfg * (cond - uncond) + return real_eps * t + mimicked_eps * (1 - t) + else: + return real_eps + + +def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None): + pid = os.getpid() + + if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False): + final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0] + + if patch_settings[pid].eps_record is not None: + patch_settings[pid].eps_record = ((x - final_x0) / timestep).cpu() + + return final_x0 + + positive_x0, negative_x0 = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + + positive_eps = x - positive_x0 + negative_eps = x - negative_x0 + + alpha = 0.001 * patch_settings[pid].sharpness * patch_settings[pid].global_diffusion_progress + + positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0) + positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha) + + final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, + cfg_scale=cond_scale, t=patch_settings[pid].global_diffusion_progress) + + if patch_settings[pid].eps_record is not None: + patch_settings[pid].eps_record = (final_eps / timestep).cpu() + + return x - final_eps + + +def round_to_64(x): + h = float(x) + h = h / 64.0 + h = round(h) + h = int(h) + h = h * 64 + return h + + +def sdxl_encode_adm_patched(self, **kwargs): + clip_pooled = ldm_patched.modules.model_base.sdxl_pooled(kwargs, self.noise_augmentor) + width = kwargs.get("width", 1024) + height = kwargs.get("height", 1024) + target_width = width + target_height = height + pid = os.getpid() + + if kwargs.get("prompt_type", "") == "negative": + width = float(width) * patch_settings[pid].negative_adm_scale + height = float(height) * patch_settings[pid].negative_adm_scale + elif kwargs.get("prompt_type", "") == "positive": + width = float(width) * patch_settings[pid].positive_adm_scale + height = float(height) * patch_settings[pid].positive_adm_scale + + def embedder(number_list): + h = self.embedder(torch.tensor(number_list, dtype=torch.float32)) + h = torch.flatten(h).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) + return h + + width, height = int(width), int(height) + target_width, target_height = round_to_64(target_width), round_to_64(target_height) + + adm_emphasized = embedder([height, width, 0, 0, target_height, target_width]) + adm_consistent = embedder([target_height, target_width, 0, 0, target_height, target_width]) + + clip_pooled = clip_pooled.to(adm_emphasized) + final_adm = torch.cat((clip_pooled, adm_emphasized, clip_pooled, adm_consistent), dim=1) + + return final_adm + + +def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): + if inpaint_worker.current_task is not None: + latent_processor = self.inner_model.inner_model.process_latent_in + inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x) + inpaint_mask = inpaint_worker.current_task.latent_mask.to(x) + + if getattr(self, 'energy_generator', None) is None: + # avoid bad results by using different seeds. 
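            # The lines that follow seed a separate CPU generator with (seed + 1) so the
            # injected noise differs from the sampler's own noise, scale that noise by the
            # current sigma, and substitute (known latent + noise) into x wherever the
            # latent mask is 0; the region being generated (mask == 1) is left for the
            # model to denoise, and after the inner model runs the known latent is
            # composited back over the preserved region.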
+ self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED) + + energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1)) + current_energy = torch.randn( + x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma + x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask) + + out = self.inner_model(x, sigma, + cond=cond, + uncond=uncond, + cond_scale=cond_scale, + model_options=model_options, + seed=seed) + + out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask) + else: + out = self.inner_model(x, sigma, + cond=cond, + uncond=uncond, + cond_scale=cond_scale, + model_options=model_options, + seed=seed) + return out + + +def timed_adm(y, timesteps): + if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632: + y_mask = (timesteps > 999.0 * (1.0 - float(patch_settings[os.getpid()].adm_scaler_end))).to(y)[..., None] + y_with_adm = y[..., :2816].clone() + y_without_adm = y[..., 2816:].clone() + return y_with_adm * y_mask + y_without_adm * (1.0 - y_mask) + return y + + +def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs): + t_emb = ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + pid = os.getpid() + + guided_hint = self.input_hint_block(hint, emb, context) + + y = timed_adm(y, timesteps) + + outs = [] + + hs = [] + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + if patch_settings[pid].controlnet_softness > 0: + for i in range(10): + k = 1.0 - float(i) / 9.0 + outs[i] = outs[i] * (1.0 - patch_settings[pid].controlnet_softness * k) + + return outs + + +def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): + self.current_step = 1.0 - timesteps.to(x) / 999.0 + patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0]) + + y = timed_adm(y, timesteps) + + transformer_options["original_shape"] = list(x.shape) + transformer_options["transformer_index"] = 0 + transformer_patches = transformer_options.get("patches", {}) + + num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) + image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + time_context = kwargs.get("time_context", None) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for id, module in enumerate(self.input_blocks): + transformer_options["block"] = ("input", id) + h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, 
num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + h = apply_control(h, control, 'input') + if "input_block_patch" in transformer_patches: + patch = transformer_patches["input_block_patch"] + for p in patch: + h = p(h, transformer_options) + + hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) + + transformer_options["block"] = ("middle", 0) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + h = apply_control(h, control, 'middle') + + for id, module in enumerate(self.output_blocks): + transformer_options["block"] = ("output", id) + hsp = hs.pop() + hsp = apply_control(hsp, control, 'output') + + if "output_block_patch" in transformer_patches: + patch = transformer_patches["output_block_patch"] + for p in patch: + h, hsp = p(h, hsp, transformer_options) + + h = torch.cat([h, hsp], dim=1) + del hsp + if len(hs) > 0: + output_shape = hs[-1].shape + else: + output_shape = None + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +def patched_load_models_gpu(*args, **kwargs): + execution_start_time = time.perf_counter() + y = ldm_patched.modules.model_management.load_models_gpu_origin(*args, **kwargs) + moving_time = time.perf_counter() - execution_start_time + if moving_time > 0.1: + print(f'[Fooocus Model Management] Moving model(s) has taken {moving_time:.2f} seconds') + return y + + +def build_loaded(module, loader_name): + original_loader_name = loader_name + '_origin' + + if not hasattr(module, original_loader_name): + setattr(module, original_loader_name, getattr(module, loader_name)) + + original_loader = getattr(module, original_loader_name) + + def loader(*args, **kwargs): + result = None + try: + result = original_loader(*args, **kwargs) + except Exception as e: + result = None + exp = str(e) + '\n' + for path in list(args) + list(kwargs.values()): + if isinstance(path, str): + if os.path.exists(path): + exp += f'File corrupted: {path} \n' + corrupted_backup_file = path + '.corrupted' + if os.path.exists(corrupted_backup_file): + os.remove(corrupted_backup_file) + os.replace(path, corrupted_backup_file) + if os.path.exists(path): + os.remove(path) + exp += f'Fooocus has tried to move the corrupted file to {corrupted_backup_file} \n' + exp += f'You may try again now and Fooocus will download models again. 
\n' + raise ValueError(exp) + return result + + setattr(module, loader_name, loader) + return + + +def patch_all(): + if ldm_patched.modules.model_management.directml_enabled: + ldm_patched.modules.model_management.lowvram_available = True + ldm_patched.modules.model_management.OOM_EXCEPTION = Exception + + patch_all_precision() + patch_all_clip() + + if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'): + ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu + + ldm_patched.modules.model_management.load_models_gpu = patched_load_models_gpu + ldm_patched.modules.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched + ldm_patched.controlnet.cldm.ControlNet.forward = patched_cldm_forward + ldm_patched.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward + ldm_patched.modules.model_base.SDXL.encode_adm = sdxl_encode_adm_patched + ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward + ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched + ldm_patched.modules.samplers.sampling_function = patched_sampling_function + + warnings.filterwarnings(action='ignore', module='torchsde') + + build_loaded(safetensors.torch, 'load_file') + build_loaded(torch, 'load') + + return diff --git a/modules/patch_clip.py b/modules/patch_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..06b7f01bb857b01995ff7b0326813f98f92ea76d --- /dev/null +++ b/modules/patch_clip.py @@ -0,0 +1,195 @@ +# Consistent with Kohya/A1111 to reduce differences between model training and inference. + +import os +import torch +import ldm_patched.controlnet.cldm +import ldm_patched.k_diffusion.sampling +import ldm_patched.ldm.modules.attention +import ldm_patched.ldm.modules.diffusionmodules.model +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.modules.args_parser +import ldm_patched.modules.model_base +import ldm_patched.modules.model_management +import ldm_patched.modules.model_patcher +import ldm_patched.modules.samplers +import ldm_patched.modules.sd +import ldm_patched.modules.sd1_clip +import ldm_patched.modules.clip_vision +import ldm_patched.modules.ops as ops + +from modules.ops import use_patched_ops +from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection + + +def patched_encode_token_weights(self, token_weight_pairs): + to_encode = list() + max_token_len = 0 + has_weights = False + for x in token_weight_pairs: + tokens = list(map(lambda a: a[0], x)) + max_token_len = max(len(tokens), max_token_len) + has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) + to_encode.append(tokens) + + sections = len(to_encode) + if has_weights or sections == 0: + to_encode.append(ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len)) + + out, pooled = self.encode(to_encode) + if pooled is not None: + first_pooled = pooled[0:1].to(ldm_patched.modules.model_management.intermediate_device()) + else: + first_pooled = pooled + + output = [] + for k in range(0, sections): + z = out[k:k + 1] + if has_weights: + original_mean = z.mean() + z_empty = out[-1] + for i in range(len(z)): + for j in range(len(z[i])): + weight = token_weight_pairs[k][j][1] + if weight != 1.0: + z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] + 
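+ # Restore the original mean after per-token weighting so emphasis does not change the overall magnitude of the embedding (the A1111/Kohya-style behaviour this file aims to stay consistent with).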
new_mean = z.mean() + z = z * (original_mean / new_mean) + output.append(z) + + if len(output) == 0: + return out[-1:].to(ldm_patched.modules.model_management.intermediate_device()), first_pooled + return torch.cat(output, dim=-2).to(ldm_patched.modules.model_management.intermediate_device()), first_pooled + + +def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None, + textmodel_json_config=None, dtype=None, special_tokens=None, + layer_norm_hidden_state=True, **kwargs): + torch.nn.Module.__init__(self) + assert layer in self.LAYERS + + if special_tokens is None: + special_tokens = {"start": 49406, "end": 49407, "pad": 49407} + + if textmodel_json_config is None: + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)), + "sd1_clip_config.json") + + config = CLIPTextConfig.from_json_file(textmodel_json_config) + self.num_layers = config.num_hidden_layers + + with use_patched_ops(ops.manual_cast): + with modeling_utils.no_init_weights(): + self.transformer = CLIPTextModel(config) + + if dtype is not None: + self.transformer.to(dtype) + + self.transformer.text_model.embeddings.to(torch.float32) + + if freeze: + self.freeze() + + self.max_length = max_length + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.enable_attention_masks = False + + self.layer_norm_hidden_state = layer_norm_hidden_state + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.clip_layer(layer_idx) + self.layer_default = (self.layer, self.layer_idx) + + +def patched_SDClipModel_forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = self.set_up_textual_embeddings(tokens, backup_embeds) + tokens = torch.LongTensor(tokens).to(device) + + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, + output_hidden_states=self.layer == "hidden") + self.transformer.set_input_embeddings(backup_embeds) + + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + if self.layer_norm_hidden_state: + z = self.transformer.text_model.final_layer_norm(z) + + if hasattr(outputs, "pooler_output"): + pooled_output = outputs.pooler_output.float() + else: + pooled_output = None + + if self.text_projection is not None and pooled_output is not None: + pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() + + return z.float(), pooled_output + + +def patched_ClipVisionModel__init__(self, json_config): + config = CLIPVisionConfig.from_json_file(json_config) + + self.load_device = ldm_patched.modules.model_management.text_encoder_device() + self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device() + + if ldm_patched.modules.model_management.should_use_fp16(self.load_device, 
prioritize_performance=False): + self.dtype = torch.float16 + else: + self.dtype = torch.float32 + + with use_patched_ops(ops.manual_cast): + with modeling_utils.no_init_weights(): + self.model = CLIPVisionModelWithProjection(config) + + self.model.to(self.dtype) + self.patcher = ldm_patched.modules.model_patcher.ModelPatcher( + self.model, + load_device=self.load_device, + offload_device=self.offload_device + ) + + +def patched_ClipVisionModel_encode_image(self, image): + ldm_patched.modules.model_management.load_model_gpu(self.patcher) + pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device)) + outputs = self.model(pixel_values=pixel_values, output_hidden_states=True) + + for k in outputs: + t = outputs[k] + if t is not None: + if k == 'hidden_states': + outputs["penultimate_hidden_states"] = t[-2].to(ldm_patched.modules.model_management.intermediate_device()) + outputs["hidden_states"] = None + else: + outputs[k] = t.to(ldm_patched.modules.model_management.intermediate_device()) + + return outputs + + +def patch_all_clip(): + ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = patched_encode_token_weights + ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__ + ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward + ldm_patched.modules.clip_vision.ClipVisionModel.__init__ = patched_ClipVisionModel__init__ + ldm_patched.modules.clip_vision.ClipVisionModel.encode_image = patched_ClipVisionModel_encode_image + return diff --git a/modules/patch_precision.py b/modules/patch_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..83569bdd15f5ab0cac2c57353626c4e843bd264d --- /dev/null +++ b/modules/patch_precision.py @@ -0,0 +1,60 @@ +# Consistent with Kohya to reduce differences between model training and inference. + +import torch +import math +import einops +import numpy as np + +import ldm_patched.ldm.modules.diffusionmodules.openaimodel +import ldm_patched.modules.model_sampling +import ldm_patched.modules.sd1_clip + +from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule + + +def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + # Consistent with Kohya to reduce differences between model training and inference. + + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = einops.repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + # Consistent with Kohya to reduce differences between model training and inference. + + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) + self.set_sigmas(sigmas) + return + + +def patch_all_precision(): + ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding + ldm_patched.modules.model_sampling.ModelSamplingDiscrete._register_schedule = patched_register_schedule + return diff --git a/modules/private_logger.py b/modules/private_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..01e570a7d96375a15a81b6f07a678b1f7eda743e --- /dev/null +++ b/modules/private_logger.py @@ -0,0 +1,130 @@ +import os +import args_manager +import modules.config +import json +import urllib.parse + +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from modules.util import generate_temp_filename +from modules.meta_parser import MetadataParser, get_exif + +log_cache = {} + + +def get_current_html_path(output_format=None): + output_format = output_format if output_format else modules.config.default_output_format + date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs, + extension=output_format) + html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html') + return html_name + + +def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None) -> str: + path_outputs = args_manager.args.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs + output_format = output_format if output_format else modules.config.default_output_format + date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format) + os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True) + + parsed_parameters = metadata_parser.parse_string(metadata.copy()) if metadata_parser is not None else '' + image = Image.fromarray(img) + + if output_format == 'png': + if parsed_parameters != '': + pnginfo = PngInfo() + pnginfo.add_text('parameters', parsed_parameters) + pnginfo.add_text('fooocus_scheme', metadata_parser.get_scheme().value) + else: + pnginfo = None + image.save(local_temp_filename, pnginfo=pnginfo) + elif output_format == 'jpg': + image.save(local_temp_filename, quality=95, optimize=True, progressive=True, exif=get_exif(parsed_parameters, metadata_parser.get_scheme().value) if metadata_parser else Image.Exif()) + elif output_format == 'webp': + image.save(local_temp_filename, quality=95, lossless=False, exif=get_exif(parsed_parameters, metadata_parser.get_scheme().value) if metadata_parser else Image.Exif()) + else: + image.save(local_temp_filename) + + if args_manager.args.disable_image_log: + return local_temp_filename + + html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html') + + css_styles = ( + "" + ) + + js = ( + """""" + ) + + begin_part = f"Fooocus Log {date_string}{css_styles}{js}

Fooocus Log {date_string} (private)

\n

Metadata is embedded if enabled in the config or developer debug mode. You can find the information for each image in line Metadata Scheme.

\n\n" + end_part = f'\n' + + middle_part = log_cache.get(html_name, "") + + if middle_part == "": + if os.path.exists(html_name): + existing_split = open(html_name, 'r', encoding='utf-8').read().split('') + if len(existing_split) == 3: + middle_part = existing_split[1] + else: + middle_part = existing_split[0] + + div_name = only_name.replace('.', '_') + item = f"

\n" + item += f"" + item += "" + item += "
{only_name}
" + for label, key, value in metadata: + value_txt = str(value).replace('\n', '
') + item += f"\n" + item += "" + + js_txt = urllib.parse.quote(json.dumps({k: v for _, k, v in metadata}, indent=0), safe='') + item += f"
" + + item += "
\n\n" + + middle_part = item + middle_part + + with open(html_name, 'w', encoding='utf-8') as f: + f.write(begin_part + middle_part + end_part) + + print(f'Image generated with private log at: {html_name}') + + log_cache[html_name] = middle_part + + return local_temp_filename diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py new file mode 100644 index 0000000000000000000000000000000000000000..5936a096d9f0afaac0a672f72cee5f84b23496ad --- /dev/null +++ b/modules/sample_hijack.py @@ -0,0 +1,184 @@ +import torch +import ldm_patched.modules.samplers +import ldm_patched.modules.model_management + +from collections import namedtuple +from ldm_patched.contrib.external_custom_sampler import SDTurboScheduler +from ldm_patched.k_diffusion import sampling as k_diffusion_sampling +from ldm_patched.modules.samplers import normal_scheduler, simple_scheduler, ddim_scheduler +from ldm_patched.modules.model_base import SDXLRefiner, SDXL +from ldm_patched.modules.conds import CONDRegular +from ldm_patched.modules.sample import get_additional_models, get_models_from_cond, cleanup_additional_models +from ldm_patched.modules.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \ + create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_model_conds + + +current_refiner = None +refiner_switch_step = -1 + + +@torch.no_grad() +@torch.inference_mode() +def clip_separate_inner(c, p, target_model=None, target_clip=None): + if target_model is None or isinstance(target_model, SDXLRefiner): + c = c[..., -1280:].clone() + elif isinstance(target_model, SDXL): + c = c.clone() + else: + p = None + c = c[..., :768].clone() + + final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm + + final_layer_norm_origin_device = final_layer_norm.weight.device + final_layer_norm_origin_dtype = final_layer_norm.weight.dtype + + c_origin_device = c.device + c_origin_dtype = c.dtype + + final_layer_norm.to(device='cpu', dtype=torch.float32) + c = c.to(device='cpu', dtype=torch.float32) + + c = torch.chunk(c, int(c.size(1)) // 77, 1) + c = [final_layer_norm(ci) for ci in c] + c = torch.cat(c, dim=1) + + final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype) + c = c.to(device=c_origin_device, dtype=c_origin_dtype) + return c, p + + +@torch.no_grad() +@torch.inference_mode() +def clip_separate(cond, target_model=None, target_clip=None): + results = [] + + for c, px in cond: + p = px.get('pooled_output', None) + c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip) + p = {} if p is None else {'pooled_output': p.clone()} + results.append([c, p]) + + return results + + +@torch.no_grad() +@torch.inference_mode() +def clip_separate_after_preparation(cond, target_model=None, target_clip=None): + results = [] + + for x in cond: + p = x.get('pooled_output', None) + c = x['model_conds']['c_crossattn'].cond + + c, p = clip_separate_inner(c, p, target_model=target_model, target_clip=target_clip) + + result = {'model_conds': {'c_crossattn': CONDRegular(c)}} + + if p is not None: + result['pooled_output'] = p.clone() + + results.append(result) + + return results + + +@torch.no_grad() +@torch.inference_mode() +def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + global current_refiner + + positive = positive[:] + negative = 
negative[:] + + resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) + resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) + + model_wrap = wrap_model(model) + + calculate_start_end_timesteps(model, negative) + calculate_start_end_timesteps(model, positive) + + if latent_image is not None: + latent_image = model.process_latent_in(latent_image) + + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) + + #make sure each cond area has an opposite one with the same area + for c in positive: + create_cond_with_same_area_if_none(negative, c) + for c in negative: + create_cond_with_same_area_if_none(positive, c) + + # pre_run_control(model, negative + positive) + pre_run_control(model, positive) # negative is not necessary in Fooocus, 0.5s faster. + + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} + + if current_refiner is not None and hasattr(current_refiner.model, 'extra_conds'): + positive_refiner = clip_separate_after_preparation(positive, target_model=current_refiner.model) + negative_refiner = clip_separate_after_preparation(negative, target_model=current_refiner.model) + + positive_refiner = encode_model_conds(current_refiner.model.extra_conds, positive_refiner, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative_refiner = encode_model_conds(current_refiner.model.extra_conds, negative_refiner, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) + + def refiner_switch(): + cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + + extra_args["cond"] = positive_refiner + extra_args["uncond"] = negative_refiner + + # clear ip-adapter for refiner + extra_args['model_options'] = {k: {} if k == 'transformer_options' else v for k, v in extra_args['model_options'].items()} + + models, inference_memory = get_additional_models(positive_refiner, negative_refiner, current_refiner.model_dtype()) + ldm_patched.modules.model_management.load_models_gpu( + [current_refiner] + models, + model.memory_required([noise.shape[0] * 2] + list(noise.shape[1:])) + inference_memory) + + model_wrap.inner_model = current_refiner.model + print('Refiner Swapped') + return + + def callback_wrap(step, x0, x, total_steps): + if step == refiner_switch_step and current_refiner is not None: + refiner_switch() + if callback is not None: + # residual_noise_preview = x - x0 + # residual_noise_preview /= residual_noise_preview.std() + # residual_noise_preview *= x0.std() + callback(step, x0, x, total_steps) + + samples = sampler.sample(model_wrap, sigmas, extra_args, callback_wrap, noise, latent_image, denoise_mask, disable_pbar) + return model.process_latent_out(samples.to(torch.float32)) + + +@torch.no_grad() +@torch.inference_mode() +def calculate_sigmas_scheduler_hacked(model, scheduler_name, steps): + if scheduler_name == "karras": + sigmas = 
k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + elif scheduler_name == "exponential": + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + elif scheduler_name == "normal": + sigmas = normal_scheduler(model, steps) + elif scheduler_name == "simple": + sigmas = simple_scheduler(model, steps) + elif scheduler_name == "ddim_uniform": + sigmas = ddim_scheduler(model, steps) + elif scheduler_name == "sgm_uniform": + sigmas = normal_scheduler(model, steps, sgm=True) + elif scheduler_name == "turbo": + sigmas = SDTurboScheduler().get_sigmas(namedtuple('Patcher', ['model'])(model=model), steps=steps, denoise=1.0)[0] + else: + raise TypeError("error invalid scheduler") + return sigmas + + +ldm_patched.modules.samplers.calculate_sigmas_scheduler = calculate_sigmas_scheduler_hacked +ldm_patched.modules.samplers.sample = sample_hacked diff --git a/modules/sdxl_styles.py b/modules/sdxl_styles.py new file mode 100644 index 0000000000000000000000000000000000000000..2a310024cdd0f96cb20341f811a50146000b586b --- /dev/null +++ b/modules/sdxl_styles.py @@ -0,0 +1,117 @@ +import os +import re +import json +import math + +from modules.util import get_files_from_folder + + +# cannot use modules.config - validators causing circular imports +styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) +wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/')) +wildcards_max_bfs_depth = 64 + + +def normalize_key(k): + k = k.replace('-', ' ') + words = k.split(' ') + words = [w[:1].upper() + w[1:].lower() for w in words] + k = ' '.join(words) + k = k.replace('3d', '3D') + k = k.replace('Sai', 'SAI') + k = k.replace('Mre', 'MRE') + k = k.replace('(s', '(S') + return k + + +styles = {} + +styles_files = get_files_from_folder(styles_path, ['.json']) + +for x in ['sdxl_styles_fooocus.json', + 'sdxl_styles_sai.json', + 'sdxl_styles_mre.json', + 'sdxl_styles_twri.json', + 'sdxl_styles_diva.json', + 'sdxl_styles_marc_k3nt3l.json']: + if x in styles_files: + styles_files.remove(x) + styles_files.append(x) + +for styles_file in styles_files: + try: + with open(os.path.join(styles_path, styles_file), encoding='utf-8') as f: + for entry in json.load(f): + name = normalize_key(entry['name']) + prompt = entry['prompt'] if 'prompt' in entry else '' + negative_prompt = entry['negative_prompt'] if 'negative_prompt' in entry else '' + styles[name] = (prompt, negative_prompt) + except Exception as e: + print(str(e)) + print(f'Failed to load style file {styles_file}') + +style_keys = list(styles.keys()) +fooocus_expansion = "Fooocus V2" +legal_style_names = [fooocus_expansion] + style_keys + + +def apply_style(style, positive): + p, n = styles[style] + return p.replace('{prompt}', positive).splitlines(), n.splitlines() + + +def apply_wildcards(wildcard_text, rng, directory=wildcards_path): + for _ in range(wildcards_max_bfs_depth): + placeholders = re.findall(r'__([\w-]+)__', wildcard_text) + if len(placeholders) == 0: + return wildcard_text + + print(f'[Wildcards] processing: {wildcard_text}') + for placeholder in placeholders: + try: + words = open(os.path.join(directory, f'{placeholder}.txt'), encoding='utf-8').read().splitlines() + words = [x for x in words if x != ''] + assert len(words) > 0 + wildcard_text = wildcard_text.replace(f'__{placeholder}__', 
rng.choice(words), 1) + except: + print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. ' + f'Using "{placeholder}" as a normal word.') + wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder) + print(f'[Wildcards] {wildcard_text}') + + print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}') + return wildcard_text + +def get_words(arrays, totalMult, index): + if(len(arrays) == 1): + return [arrays[0].split(',')[index]] + else: + words = arrays[0].split(',') + word = words[index % len(words)] + index -= index % len(words) + index /= len(words) + index = math.floor(index) + return [word] + get_words(arrays[1:], math.floor(totalMult/len(words)), index) + + +def apply_arrays(text, index): + arrays = re.findall(r'\[\[(.*?)\]\]', text) + if len(arrays) == 0: + return text + + print(f'[Arrays] processing: {text}') + mult = 1 + for arr in arrays: + words = arr.split(',') + mult *= len(words) + + index %= mult + chosen_words = get_words(arrays, mult, index) + + i = 0 + for arr in arrays: + text = text.replace(f'[[{arr}]]', chosen_words[i], 1) + i = i+1 + + return text + diff --git a/modules/style_sorter.py b/modules/style_sorter.py new file mode 100644 index 0000000000000000000000000000000000000000..49142bc7926e06ee29f5678de1a9acc13dac5b70 --- /dev/null +++ b/modules/style_sorter.py @@ -0,0 +1,59 @@ +import os +import gradio as gr +import modules.localization as localization +import json + + +all_styles = [] + + +def try_load_sorted_styles(style_names, default_selected): + global all_styles + + all_styles = style_names + + try: + if os.path.exists('sorted_styles.json'): + with open('sorted_styles.json', 'rt', encoding='utf-8') as fp: + sorted_styles = [] + for x in json.load(fp): + if x in all_styles: + sorted_styles.append(x) + for x in all_styles: + if x not in sorted_styles: + sorted_styles.append(x) + all_styles = sorted_styles + except Exception as e: + print('Load style sorting failed.') + print(e) + + unselected = [y for y in all_styles if y not in default_selected] + all_styles = default_selected + unselected + + return + + +def sort_styles(selected): + global all_styles + unselected = [y for y in all_styles if y not in selected] + sorted_styles = selected + unselected + try: + with open('sorted_styles.json', 'wt', encoding='utf-8') as fp: + json.dump(sorted_styles, fp, indent=4) + except Exception as e: + print('Write style sorting failed.') + print(e) + all_styles = sorted_styles + return gr.CheckboxGroup.update(choices=sorted_styles) + + +def localization_key(x): + return x + localization.current_translation.get(x, '') + + +def search_styles(selected, query): + unselected = [y for y in all_styles if y not in selected] + matched = [y for y in unselected if query.lower() in localization_key(y).lower()] if len(query.replace(' ', '')) > 0 else [] + unmatched = [y for y in unselected if y not in matched] + sorted_styles = matched + selected + unmatched + return gr.CheckboxGroup.update(choices=sorted_styles) diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..bebf9f8ca7860c700f52ea5d3d3586917f17d34b --- /dev/null +++ b/modules/ui_gradio_extensions.py @@ -0,0 +1,67 @@ +# based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/v1.6.0/modules/ui_gradio_extensions.py + +import os +import gradio as gr +import args_manager + +from modules.localization import localization_js + + +GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse 
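+ # Keep a reference to Gradio's original TemplateResponse so reload_javascript() below can wrap it and inject the script and stylesheet tags built by javascript_html() and css_html() into every rendered page.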
+ +modules_path = os.path.dirname(os.path.realpath(__file__)) +script_path = os.path.dirname(modules_path) + + +def webpath(fn): + if fn.startswith(script_path): + web_path = os.path.relpath(fn, script_path).replace('\\', '/') + else: + web_path = os.path.abspath(fn) + + return f'file={web_path}?{os.path.getmtime(fn)}' + + +def javascript_html(): + script_js_path = webpath('javascript/script.js') + context_menus_js_path = webpath('javascript/contextMenus.js') + localization_js_path = webpath('javascript/localization.js') + zoom_js_path = webpath('javascript/zoom.js') + edit_attention_js_path = webpath('javascript/edit-attention.js') + viewer_js_path = webpath('javascript/viewer.js') + image_viewer_js_path = webpath('javascript/imageviewer.js') + samples_path = webpath(os.path.abspath('./sdxl_styles/samples/fooocus_v2.jpg')) + head = f'\n' + head += f'\n' + head += f'\n' + head += f'\n' + head += f'\n' + head += f'\n' + head += f'\n' + head += f'\n' + head += f'\n' + + if args_manager.args.theme: + head += f'\n' + + return head + + +def css_html(): + style_css_path = webpath('css/style.css') + head = f'' + return head + + +def reload_javascript(): + js = javascript_html() + css = css_html() + + def template_response(*args, **kwargs): + res = GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) + res.body = res.body.replace(b'', f'{css}'.encode("utf8")) + res.init_headers() + return res + + gr.routes.templates.TemplateResponse = template_response diff --git a/modules/upscaler.py b/modules/upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..974e4f37c8756df56b9e64143bddff1f1378bc83 --- /dev/null +++ b/modules/upscaler.py @@ -0,0 +1,34 @@ +import os +import torch +import modules.core as core + +from ldm_patched.pfn.architecture.RRDB import RRDBNet as ESRGAN +from ldm_patched.contrib.external_upscale_model import ImageUpscaleWithModel +from collections import OrderedDict +from modules.config import path_upscale_models + +model_filename = os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') +opImageUpscaleWithModel = ImageUpscaleWithModel() +model = None + + +def perform_upscale(img): + global model + + print(f'Upscaling image with shape {str(img.shape)} ...') + + if model is None: + sd = torch.load(model_filename) + sdo = OrderedDict() + for k, v in sd.items(): + sdo[k.replace('residual_block_', 'RDB')] = v + del sd + model = ESRGAN(sdo) + model.cpu() + model.eval() + + img = core.numpy_to_pytorch(img) + img = opImageUpscaleWithModel.upscale(model, img)[0] + img = core.pytorch_to_numpy(img)[0] + + return img diff --git a/modules/util.py b/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..c7923ec8260286d4bcf858a250b0e75c7e51d97d --- /dev/null +++ b/modules/util.py @@ -0,0 +1,362 @@ +import typing + +import numpy as np +import datetime +import random +import math +import os +import cv2 +import json + +from PIL import Image +from hashlib import sha256 + +import modules.sdxl_styles + +LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +HASH_SHA256_LENGTH = 10 + +def erode_or_dilate(x, k): + k = int(k) + if k > 0: + return cv2.dilate(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=k) + if k < 0: + return cv2.erode(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=-k) + return x + + +def resample_image(im, width, height): + im = Image.fromarray(im) + im = im.resize((int(width), int(height)), resample=LANCZOS) + 
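+ # Hand the result back as a numpy array; the other helpers in this module (OpenCV/numpy based) do not operate on PIL images.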
return np.array(im) + + +def resize_image(im, width, height, resize_mode=1): + """ + Resizes an image with the specified resize_mode, width, and height. + + Args: + resize_mode: The mode to use when resizing the image. + 0: Resize the image to the specified width and height. + 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. + 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + im: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + """ + + im = Image.fromarray(im) + + def resize(im, w, h): + return im.resize((w, h), resample=LANCZOS) + + if resize_mode == 0: + res = resize(im, width, height) + + elif resize_mode == 1: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio > src_ratio else im.width * height // im.height + src_h = height if ratio <= src_ratio else im.height * width // im.width + + resized = resize(im, src_w, src_h) + res = Image.new("RGB", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + else: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio < src_ratio else im.width * height // im.height + src_h = height if ratio >= src_ratio else im.height * width // im.width + + resized = resize(im, src_w, src_h) + res = Image.new("RGB", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + if fill_height > 0: + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + if fill_width > 0: + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + + return np.array(res) + + +def get_shape_ceil(h, w): + return math.ceil(((h * w) ** 0.5) / 64.0) * 64.0 + + +def get_image_shape_ceil(im): + H, W = im.shape[:2] + return get_shape_ceil(H, W) + + +def set_image_shape_ceil(im, shape_ceil): + shape_ceil = float(shape_ceil) + + H_origin, W_origin, _ = im.shape + H, W = H_origin, W_origin + + for _ in range(256): + current_shape_ceil = get_shape_ceil(H, W) + if abs(current_shape_ceil - shape_ceil) < 0.1: + break + k = shape_ceil / current_shape_ceil + H = int(round(float(H) * k / 64.0) * 64) + W = int(round(float(W) * k / 64.0) * 64) + + if H == H_origin and W == W_origin: + return im + + return resample_image(im, width=W, height=H) + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def remove_empty_str(items, default=None): + items = [x for x in items if x != ""] + if len(items) == 0 and default is not None: + return [default] 
+ return items + + +def join_prompts(*args, **kwargs): + prompts = [str(x) for x in args if str(x) != ""] + if len(prompts) == 0: + return "" + if len(prompts) == 1: + return prompts[0] + return ', '.join(prompts) + + +def generate_temp_filename(folder='./outputs/', extension='png'): + current_time = datetime.datetime.now() + date_string = current_time.strftime("%Y-%m-%d") + time_string = current_time.strftime("%Y-%m-%d_%H-%M-%S") + random_number = random.randint(1000, 9999) + filename = f"{time_string}_{random_number}.{extension}" + result = os.path.join(folder, date_string, filename) + return date_string, os.path.abspath(result), filename + + +def get_files_from_folder(folder_path, exensions=None, name_filter=None): + if not os.path.isdir(folder_path): + raise ValueError("Folder path is not a valid directory.") + + filenames = [] + + for root, dirs, files in os.walk(folder_path, topdown=False): + relative_path = os.path.relpath(root, folder_path) + if relative_path == ".": + relative_path = "" + for filename in sorted(files, key=lambda s: s.casefold()): + _, file_extension = os.path.splitext(filename) + if (exensions is None or file_extension.lower() in exensions) and (name_filter is None or name_filter in _): + path = os.path.join(relative_path, filename) + filenames.append(path) + + return filenames + + +def calculate_sha256(filename, length=HASH_SHA256_LENGTH) -> str: + hash_sha256 = sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + res = hash_sha256.hexdigest() + return res[:length] if length else res + + +def quote(text): + if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text): + return text + + return json.dumps(text, ensure_ascii=False) + + +def unquote(text): + if len(text) == 0 or text[0] != '"' or text[-1] != '"': + return text + + try: + return json.loads(text) + except Exception: + return text + + +def unwrap_style_text_from_prompt(style_text, prompt): + """ + Checks the prompt to see if the style text is wrapped around it. If so, + returns True plus the prompt text without the style text. Otherwise, returns + False with the original prompt. + + Note that the "cleaned" version of the style text is only used for matching + purposes here. It isn't returned; the original style text is not modified. + """ + stripped_prompt = prompt + stripped_style_text = style_text + if "{prompt}" in stripped_style_text: + # Work out whether the prompt is wrapped in the style text. If so, we + # return True and the "inner" prompt text that isn't part of the style. + try: + left, right = stripped_style_text.split("{prompt}", 2) + except ValueError as e: + # If the style text has multple "{prompt}"s, we can't split it into + # two parts. This is an error, but we can't do anything about it. + print(f"Unable to compare style text to prompt:\n{style_text}") + print(f"Error: {e}") + return False, prompt, '' + + left_pos = stripped_prompt.find(left) + right_pos = stripped_prompt.find(right) + if 0 <= left_pos < right_pos: + real_prompt = stripped_prompt[left_pos + len(left):right_pos] + prompt = stripped_prompt.replace(left + real_prompt + right, '', 1) + if prompt.startswith(", "): + prompt = prompt[2:] + if prompt.endswith(", "): + prompt = prompt[:-2] + return True, prompt, real_prompt + else: + # Work out whether the given prompt starts with the style text. If so, we + # return True and the prompt text up to where the style text starts. 
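+ # In practice the check below is endswith: a style without "{prompt}" is appended to the prompt, so a match means the prompt ends with the raw style text.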
+ if stripped_prompt.endswith(stripped_style_text): + prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] + if prompt.endswith(", "): + prompt = prompt[:-2] + return True, prompt, prompt + + return False, prompt, '' + + +def extract_original_prompts(style, prompt, negative_prompt): + """ + Takes a style and compares it to the prompt and negative prompt. If the style + matches, returns True plus the prompt and negative prompt with the style text + removed. Otherwise, returns False with the original prompt and negative prompt. + """ + if not style.prompt and not style.negative_prompt: + return False, prompt, negative_prompt + + match_positive, extracted_positive, real_prompt = unwrap_style_text_from_prompt( + style.prompt, prompt + ) + if not match_positive: + return False, prompt, negative_prompt, '' + + match_negative, extracted_negative, _ = unwrap_style_text_from_prompt( + style.negative_prompt, negative_prompt + ) + if not match_negative: + return False, prompt, negative_prompt, '' + + return True, extracted_positive, extracted_negative, real_prompt + + +def extract_styles_from_prompt(prompt, negative_prompt): + extracted = [] + applicable_styles = [] + + for style_name, (style_prompt, style_negative_prompt) in modules.sdxl_styles.styles.items(): + applicable_styles.append(PromptStyle(name=style_name, prompt=style_prompt, negative_prompt=style_negative_prompt)) + + real_prompt = '' + + while True: + found_style = None + + for style in applicable_styles: + is_match, new_prompt, new_neg_prompt, new_real_prompt = extract_original_prompts( + style, prompt, negative_prompt + ) + if is_match: + found_style = style + prompt = new_prompt + negative_prompt = new_neg_prompt + if real_prompt == '' and new_real_prompt != '' and new_real_prompt != prompt: + real_prompt = new_real_prompt + break + + if not found_style: + break + + applicable_styles.remove(found_style) + extracted.append(found_style.name) + + # add prompt expansion if not all styles could be resolved + if prompt != '': + if real_prompt != '': + extracted.append(modules.sdxl_styles.fooocus_expansion) + else: + # find real_prompt when only prompt expansion is selected + first_word = prompt.split(', ')[0] + first_word_positions = [i for i in range(len(prompt)) if prompt.startswith(first_word, i)] + if len(first_word_positions) > 1: + real_prompt = prompt[:first_word_positions[-1]] + extracted.append(modules.sdxl_styles.fooocus_expansion) + if real_prompt.endswith(', '): + real_prompt = real_prompt[:-2] + + return list(reversed(extracted)), real_prompt, negative_prompt + + +class PromptStyle(typing.NamedTuple): + name: str + prompt: str + negative_prompt: str + + +def is_json(data: str) -> bool: + try: + loaded_json = json.loads(data) + assert isinstance(loaded_json, dict) + except (ValueError, AssertionError): + return False + return True + + +def get_file_from_folder_list(name, folders): + for folder in folders: + filename = os.path.abspath(os.path.realpath(os.path.join(folder, name))) + if os.path.isfile(filename): + return filename + + return os.path.abspath(os.path.realpath(os.path.join(folders[0], name))) + + +def ordinal_suffix(number: int) -> str: + return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th') + + +def makedirs_with_log(path): + try: + os.makedirs(path, exist_ok=True) + except OSError as error: + print(f'Directory {path} could not be created, reason: {error}')
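+
+# Illustrative usage only; the paths below are hypothetical and not part of the patch:
+#
+#     >>> get_file_from_folder_list('model.safetensors', ['/ckpts/main', '/ckpts/extra'])
+#     '/ckpts/main/model.safetensors'   # falls back to a path under the first folder when the file is absent everywhere
+#     >>> ordinal_suffix(2), ordinal_suffix(11), ordinal_suffix(23)
+#     ('nd', 'th', 'rd')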