Adityadn committed on
Commit
611f1b5
1 Parent(s): ef0edf3

Upload 59 files

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. modules/__pycache__/anisotropic.cpython-310.pyc +0 -0
  2. modules/__pycache__/async_worker.cpython-310.pyc +0 -0
  3. modules/__pycache__/auth.cpython-310.pyc +0 -0
  4. modules/__pycache__/config.cpython-310.pyc +0 -0
  5. modules/__pycache__/config.cpython-312.pyc +0 -0
  6. modules/__pycache__/constants.cpython-310.pyc +0 -0
  7. modules/__pycache__/core.cpython-310.pyc +0 -0
  8. modules/__pycache__/default_pipeline.cpython-310.pyc +0 -0
  9. modules/__pycache__/flags.cpython-310.pyc +0 -0
  10. modules/__pycache__/flags.cpython-312.pyc +0 -0
  11. modules/__pycache__/gradio_hijack.cpython-310.pyc +0 -0
  12. modules/__pycache__/html.cpython-310.pyc +0 -0
  13. modules/__pycache__/inpaint_worker.cpython-310.pyc +0 -0
  14. modules/__pycache__/launch_util.cpython-310.pyc +0 -0
  15. modules/__pycache__/launch_util.cpython-312.pyc +0 -0
  16. modules/__pycache__/localization.cpython-310.pyc +0 -0
  17. modules/__pycache__/lora.cpython-310.pyc +0 -0
  18. modules/__pycache__/meta_parser.cpython-310.pyc +0 -0
  19. modules/__pycache__/model_loader.cpython-310.pyc +0 -0
  20. modules/__pycache__/ops.cpython-310.pyc +0 -0
  21. modules/__pycache__/patch.cpython-310.pyc +0 -0
  22. modules/__pycache__/patch_clip.cpython-310.pyc +0 -0
  23. modules/__pycache__/patch_precision.cpython-310.pyc +0 -0
  24. modules/__pycache__/private_logger.cpython-310.pyc +0 -0
  25. modules/__pycache__/sample_hijack.cpython-310.pyc +0 -0
  26. modules/__pycache__/sdxl_styles.cpython-310.pyc +0 -0
  27. modules/__pycache__/sdxl_styles.cpython-312.pyc +0 -0
  28. modules/__pycache__/style_sorter.cpython-310.pyc +0 -0
  29. modules/__pycache__/ui_gradio_extensions.cpython-310.pyc +0 -0
  30. modules/__pycache__/upscaler.cpython-310.pyc +0 -0
  31. modules/__pycache__/util.cpython-310.pyc +0 -0
  32. modules/__pycache__/util.cpython-312.pyc +0 -0
  33. modules/anisotropic.py +200 -0
  34. modules/async_worker.py +914 -0
  35. modules/auth.py +41 -0
  36. modules/config.py +607 -0
  37. modules/constants.py +5 -0
  38. modules/core.py +339 -0
  39. modules/default_pipeline.py +498 -0
  40. modules/flags.py +125 -0
  41. modules/gradio_hijack.py +480 -0
  42. modules/html.py +146 -0
  43. modules/inpaint_worker.py +264 -0
  44. modules/launch_util.py +103 -0
  45. modules/localization.py +60 -0
  46. modules/lora.py +152 -0
  47. modules/meta_parser.py +573 -0
  48. modules/model_loader.py +26 -0
  49. modules/ops.py +19 -0
  50. modules/patch.py +513 -0
modules/__pycache__/anisotropic.cpython-310.pyc ADDED
Binary file (5.81 kB). View file
 
modules/__pycache__/async_worker.cpython-310.pyc ADDED
Binary file (21 kB). View file
 
modules/__pycache__/auth.cpython-310.pyc ADDED
Binary file (1.39 kB). View file
 
modules/__pycache__/config.cpython-310.pyc ADDED
Binary file (18.4 kB). View file
 
modules/__pycache__/config.cpython-312.pyc ADDED
Binary file (29.3 kB). View file
 
modules/__pycache__/constants.cpython-310.pyc ADDED
Binary file (309 Bytes). View file
 
modules/__pycache__/core.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
modules/__pycache__/default_pipeline.cpython-310.pyc ADDED
Binary file (9.98 kB). View file
 
modules/__pycache__/flags.cpython-310.pyc ADDED
Binary file (3.72 kB). View file
 
modules/__pycache__/flags.cpython-312.pyc ADDED
Binary file (4.89 kB). View file
 
modules/__pycache__/gradio_hijack.cpython-310.pyc ADDED
Binary file (17.3 kB). View file
 
modules/__pycache__/html.cpython-310.pyc ADDED
Binary file (3.25 kB). View file
 
modules/__pycache__/inpaint_worker.cpython-310.pyc ADDED
Binary file (6.99 kB). View file
 
modules/__pycache__/launch_util.cpython-310.pyc ADDED
Binary file (3.14 kB). View file
 
modules/__pycache__/launch_util.cpython-312.pyc ADDED
Binary file (5.21 kB). View file
 
modules/__pycache__/localization.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
modules/__pycache__/lora.cpython-310.pyc ADDED
Binary file (3.26 kB). View file
 
modules/__pycache__/meta_parser.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
modules/__pycache__/model_loader.cpython-310.pyc ADDED
Binary file (1.04 kB). View file
 
modules/__pycache__/ops.cpython-310.pyc ADDED
Binary file (816 Bytes). View file
 
modules/__pycache__/patch.cpython-310.pyc ADDED
Binary file (14.4 kB). View file
 
modules/__pycache__/patch_clip.cpython-310.pyc ADDED
Binary file (6.01 kB). View file
 
modules/__pycache__/patch_precision.cpython-310.pyc ADDED
Binary file (2.03 kB). View file
 
modules/__pycache__/private_logger.cpython-310.pyc ADDED
Binary file (5.2 kB). View file
 
modules/__pycache__/sample_hijack.cpython-310.pyc ADDED
Binary file (6.19 kB). View file
 
modules/__pycache__/sdxl_styles.cpython-310.pyc ADDED
Binary file (3.61 kB). View file
 
modules/__pycache__/sdxl_styles.cpython-312.pyc ADDED
Binary file (6.21 kB). View file
 
modules/__pycache__/style_sorter.cpython-310.pyc ADDED
Binary file (2.42 kB). View file
 
modules/__pycache__/ui_gradio_extensions.cpython-310.pyc ADDED
Binary file (2.55 kB). View file
 
modules/__pycache__/upscaler.cpython-310.pyc ADDED
Binary file (1.17 kB). View file
 
modules/__pycache__/util.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
modules/__pycache__/util.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
modules/anisotropic.py ADDED
@@ -0,0 +1,200 @@
+ import torch
+
+
+ Tensor = torch.Tensor
+ Device = torch.DeviceObjType
+ Dtype = torch.Type
+ pad = torch.nn.functional.pad
+
+
+ def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
+     ky, kx = _unpack_2d_ks(kernel_size)
+     return (ky - 1) // 2, (kx - 1) // 2
+
+
+ def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
+     if isinstance(kernel_size, int):
+         ky = kx = kernel_size
+     else:
+         assert len(kernel_size) == 2, '2D Kernel size should have a length of 2.'
+         ky, kx = kernel_size
+
+     ky = int(ky)
+     kx = int(kx)
+     return ky, kx
+
+
+ def gaussian(
+     window_size: int, sigma: Tensor | float, *, device: Device | None = None, dtype: Dtype | None = None
+ ) -> Tensor:
+
+     batch_size = sigma.shape[0]
+
+     x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+     if window_size % 2 == 0:
+         x = x + 0.5
+
+     gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+     return gauss / gauss.sum(-1, keepdim=True)
+
+
+ def get_gaussian_kernel1d(
+     kernel_size: int,
+     sigma: float | Tensor,
+     force_even: bool = False,
+     *,
+     device: Device | None = None,
+     dtype: Dtype | None = None,
+ ) -> Tensor:
+
+     return gaussian(kernel_size, sigma, device=device, dtype=dtype)
+
+
+ def get_gaussian_kernel2d(
+     kernel_size: tuple[int, int] | int,
+     sigma: tuple[float, float] | Tensor,
+     force_even: bool = False,
+     *,
+     device: Device | None = None,
+     dtype: Dtype | None = None,
+ ) -> Tensor:
+
+     sigma = torch.Tensor([[sigma, sigma]]).to(device=device, dtype=dtype)
+
+     ksize_y, ksize_x = _unpack_2d_ks(kernel_size)
+     sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None]
+
+     kernel_y = get_gaussian_kernel1d(ksize_y, sigma_y, force_even, device=device, dtype=dtype)[..., None]
+     kernel_x = get_gaussian_kernel1d(ksize_x, sigma_x, force_even, device=device, dtype=dtype)[..., None]
+
+     return kernel_y * kernel_x.view(-1, 1, ksize_x)
+
+
+ def _bilateral_blur(
+     input: Tensor,
+     guidance: Tensor | None,
+     kernel_size: tuple[int, int] | int,
+     sigma_color: float | Tensor,
+     sigma_space: tuple[float, float] | Tensor,
+     border_type: str = 'reflect',
+     color_distance_type: str = 'l1',
+ ) -> Tensor:
+
+     if isinstance(sigma_color, Tensor):
+         sigma_color = sigma_color.to(device=input.device, dtype=input.dtype).view(-1, 1, 1, 1, 1)
+
+     ky, kx = _unpack_2d_ks(kernel_size)
+     pad_y, pad_x = _compute_zero_padding(kernel_size)
+
+     padded_input = pad(input, (pad_x, pad_x, pad_y, pad_y), mode=border_type)
+     unfolded_input = padded_input.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2)  # (B, C, H, W, Ky x Kx)
+
+     if guidance is None:
+         guidance = input
+         unfolded_guidance = unfolded_input
+     else:
+         padded_guidance = pad(guidance, (pad_x, pad_x, pad_y, pad_y), mode=border_type)
+         unfolded_guidance = padded_guidance.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2)  # (B, C, H, W, Ky x Kx)
+
+     diff = unfolded_guidance - guidance.unsqueeze(-1)
+     if color_distance_type == "l1":
+         color_distance_sq = diff.abs().sum(1, keepdim=True).square()
+     elif color_distance_type == "l2":
+         color_distance_sq = diff.square().sum(1, keepdim=True)
+     else:
+         raise ValueError("color_distance_type only acceps l1 or l2")
+     color_kernel = (-0.5 / sigma_color**2 * color_distance_sq).exp()  # (B, 1, H, W, Ky x Kx)
+
+     space_kernel = get_gaussian_kernel2d(kernel_size, sigma_space, device=input.device, dtype=input.dtype)
+     space_kernel = space_kernel.view(-1, 1, 1, 1, kx * ky)
+
+     kernel = space_kernel * color_kernel
+     out = (unfolded_input * kernel).sum(-1) / kernel.sum(-1)
+     return out
+
+
+ def bilateral_blur(
+     input: Tensor,
+     kernel_size: tuple[int, int] | int = (13, 13),
+     sigma_color: float | Tensor = 3.0,
+     sigma_space: tuple[float, float] | Tensor = 3.0,
+     border_type: str = 'reflect',
+     color_distance_type: str = 'l1',
+ ) -> Tensor:
+     return _bilateral_blur(input, None, kernel_size, sigma_color, sigma_space, border_type, color_distance_type)
+
+
+ def adaptive_anisotropic_filter(x, g=None):
+     if g is None:
+         g = x
+     s, m = torch.std_mean(g, dim=(1, 2, 3), keepdim=True)
+     s = s + 1e-5
+     guidance = (g - m) / s
+     y = _bilateral_blur(x, guidance,
+                         kernel_size=(13, 13),
+                         sigma_color=3.0,
+                         sigma_space=3.0,
+                         border_type='reflect',
+                         color_distance_type='l1')
+     return y
+
+
+ def joint_bilateral_blur(
+     input: Tensor,
+     guidance: Tensor,
+     kernel_size: tuple[int, int] | int,
+     sigma_color: float | Tensor,
+     sigma_space: tuple[float, float] | Tensor,
+     border_type: str = 'reflect',
+     color_distance_type: str = 'l1',
+ ) -> Tensor:
+     return _bilateral_blur(input, guidance, kernel_size, sigma_color, sigma_space, border_type, color_distance_type)
+
+
+ class _BilateralBlur(torch.nn.Module):
+     def __init__(
+         self,
+         kernel_size: tuple[int, int] | int,
+         sigma_color: float | Tensor,
+         sigma_space: tuple[float, float] | Tensor,
+         border_type: str = 'reflect',
+         color_distance_type: str = "l1",
+     ) -> None:
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.sigma_color = sigma_color
+         self.sigma_space = sigma_space
+         self.border_type = border_type
+         self.color_distance_type = color_distance_type
+
+     def __repr__(self) -> str:
+         return (
+             f"{self.__class__.__name__}"
+             f"(kernel_size={self.kernel_size}, "
+             f"sigma_color={self.sigma_color}, "
+             f"sigma_space={self.sigma_space}, "
+             f"border_type={self.border_type}, "
+             f"color_distance_type={self.color_distance_type})"
+         )
+
+
+ class BilateralBlur(_BilateralBlur):
+     def forward(self, input: Tensor) -> Tensor:
+         return bilateral_blur(
+             input, self.kernel_size, self.sigma_color, self.sigma_space, self.border_type, self.color_distance_type
+         )
+
+
+ class JointBilateralBlur(_BilateralBlur):
+     def forward(self, input: Tensor, guidance: Tensor) -> Tensor:
+         return joint_bilateral_blur(
+             input,
+             guidance,
+             self.kernel_size,
+             self.sigma_color,
+             self.sigma_space,
+             self.border_type,
+             self.color_distance_type,
+         )
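
Note (not part of the committed file): a minimal usage sketch of the module above, assuming it is importable as modules.anisotropic; the tensor size is illustrative only.

import torch
from modules.anisotropic import adaptive_anisotropic_filter, bilateral_blur

x = torch.rand(1, 3, 64, 64)               # one RGB image in BCHW layout, values in [0, 1]
smoothed = adaptive_anisotropic_filter(x)  # guidance defaults to the (normalized) input itself
blurred = bilateral_blur(x)                # plain bilateral blur with the default 13x13 kernel
print(smoothed.shape, blurred.shape)       # both calls preserve the input shape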
modules/async_worker.py ADDED
@@ -0,0 +1,914 @@
+ import threading
+ from modules.patch import PatchSettings, patch_settings, patch_all
+
+ patch_all()
+
+ class AsyncTask:
+     def __init__(self, args):
+         self.args = args
+         self.yields = []
+         self.results = []
+         self.last_stop = False
+         self.processing = False
+
+
+ async_tasks = []
+
+
+ def worker():
+     global async_tasks
+
+     import os
+     import traceback
+     import math
+     import numpy as np
+     import cv2
+     import torch
+     import time
+     import shared
+     import random
+     import copy
+     import modules.default_pipeline as pipeline
+     import modules.core as core
+     import modules.flags as flags
+     import modules.config
+     import modules.patch
+     import ldm_patched.modules.model_management
+     import extras.preprocessors as preprocessors
+     import modules.inpaint_worker as inpaint_worker
+     import modules.constants as constants
+     import extras.ip_adapter as ip_adapter
+     import extras.face_crop
+     import fooocus_version
+     import args_manager
+
+     from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
+     from modules.private_logger import log
+     from extras.expansion import safe_str
+     from modules.util import remove_empty_str, HWC3, resize_image, \
+         get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix
+     from modules.upscaler import perform_upscale
+     from modules.flags import Performance
+     from modules.meta_parser import get_metadata_parser, MetadataScheme
+
+     pid = os.getpid()
+     print(f'Started worker with PID {pid}')
+
+     try:
+         async_gradio_app = shared.gradio_root
+         flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
+         if async_gradio_app.share:
+             flag += f''' or {async_gradio_app.share_url}'''
+         print(flag)
+     except Exception as e:
+         print(e)
+
+     def progressbar(async_task, number, text):
+         print(f'[Fooocus] {text}')
+         async_task.yields.append(['preview', (number, text, None)])
+
+     def yield_result(async_task, imgs, do_not_show_finished_images=False):
+         if not isinstance(imgs, list):
+             imgs = [imgs]
+
+         async_task.results = async_task.results + imgs
+
+         if do_not_show_finished_images:
+             return
+
+         async_task.yields.append(['results', async_task.results])
+         return
+
+     def build_image_wall(async_task):
+         results = []
+
+         if len(async_task.results) < 2:
+             return
+
+         for img in async_task.results:
+             if isinstance(img, str) and os.path.exists(img):
+                 img = cv2.imread(img)
+                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             if not isinstance(img, np.ndarray):
+                 return
+             if img.ndim != 3:
+                 return
+             results.append(img)
+
+         H, W, C = results[0].shape
+
+         for img in results:
+             Hn, Wn, Cn = img.shape
+             if H != Hn:
+                 return
+             if W != Wn:
+                 return
+             if C != Cn:
+                 return
+
+         cols = float(len(results)) ** 0.5
+         cols = int(math.ceil(cols))
+         rows = float(len(results)) / float(cols)
+         rows = int(math.ceil(rows))
+
+         wall = np.zeros(shape=(H * rows, W * cols, C), dtype=np.uint8)
+
+         for y in range(rows):
+             for x in range(cols):
+                 if y * cols + x < len(results):
+                     img = results[y * cols + x]
+                     wall[y * H:y * H + H, x * W:x * W + W, :] = img
+
+         # must use deep copy otherwise gradio is super laggy. Do not use list.append() .
+         async_task.results = async_task.results + [wall]
+         return
+
+     def apply_enabled_loras(loras):
+         enabled_loras = []
+         for lora_enabled, lora_model, lora_weight in loras:
+             if lora_enabled:
+                 enabled_loras.append([lora_model, lora_weight])
+
+         return enabled_loras
+
+     @torch.no_grad()
+     @torch.inference_mode()
+     def handler(async_task):
+         execution_start_time = time.perf_counter()
+         async_task.processing = True
+
+         args = async_task.args
+         args.reverse()
+
+         prompt = args.pop()
+         negative_prompt = args.pop()
+         style_selections = args.pop()
+         performance_selection = Performance(args.pop())
+         aspect_ratios_selection = args.pop()
+         image_number = args.pop()
+         output_format = args.pop()
+         image_seed = args.pop()
+         sharpness = args.pop()
+         guidance_scale = args.pop()
+         base_model_name = args.pop()
+         refiner_model_name = args.pop()
+         refiner_switch = args.pop()
+         loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)])
+         input_image_checkbox = args.pop()
+         current_tab = args.pop()
+         uov_method = args.pop()
+         uov_input_image = args.pop()
+         outpaint_selections = args.pop()
+         inpaint_input_image = args.pop()
+         inpaint_additional_prompt = args.pop()
+         inpaint_mask_image_upload = args.pop()
+
+         disable_preview = args.pop()
+         disable_intermediate_results = args.pop()
+         disable_seed_increment = args.pop()
+         adm_scaler_positive = args.pop()
+         adm_scaler_negative = args.pop()
+         adm_scaler_end = args.pop()
+         adaptive_cfg = args.pop()
+         sampler_name = args.pop()
+         scheduler_name = args.pop()
+         overwrite_step = args.pop()
+         overwrite_switch = args.pop()
+         overwrite_width = args.pop()
+         overwrite_height = args.pop()
+         overwrite_vary_strength = args.pop()
+         overwrite_upscale_strength = args.pop()
+         mixing_image_prompt_and_vary_upscale = args.pop()
+         mixing_image_prompt_and_inpaint = args.pop()
+         debugging_cn_preprocessor = args.pop()
+         skipping_cn_preprocessor = args.pop()
+         canny_low_threshold = args.pop()
+         canny_high_threshold = args.pop()
+         refiner_swap_method = args.pop()
+         controlnet_softness = args.pop()
+         freeu_enabled = args.pop()
+         freeu_b1 = args.pop()
+         freeu_b2 = args.pop()
+         freeu_s1 = args.pop()
+         freeu_s2 = args.pop()
+         debugging_inpaint_preprocessor = args.pop()
+         inpaint_disable_initial_latent = args.pop()
+         inpaint_engine = args.pop()
+         inpaint_strength = args.pop()
+         inpaint_respective_field = args.pop()
+         inpaint_mask_upload_checkbox = args.pop()
+         invert_mask_checkbox = args.pop()
+         inpaint_erode_or_dilate = args.pop()
+
+         save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
+         metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
+
+         cn_tasks = {x: [] for x in flags.ip_list}
+         for _ in range(flags.controlnet_image_count):
+             cn_img = args.pop()
+             cn_stop = args.pop()
+             cn_weight = args.pop()
+             cn_type = args.pop()
+             if cn_img is not None:
+                 cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
+
+         outpaint_selections = [o.lower() for o in outpaint_selections]
+         base_model_additional_loras = []
+         raw_style_selections = copy.deepcopy(style_selections)
+         uov_method = uov_method.lower()
+
+         if fooocus_expansion in style_selections:
+             use_expansion = True
+             style_selections.remove(fooocus_expansion)
+         else:
+             use_expansion = False
+
+         use_style = len(style_selections) > 0
+
+         if base_model_name == refiner_model_name:
+             print(f'Refiner disabled because base model and refiner are same.')
+             refiner_model_name = 'None'
+
+         steps = performance_selection.steps()
+
+         if performance_selection == Performance.EXTREME_SPEED:
+             print('Enter LCM mode.')
+             progressbar(async_task, 1, 'Downloading LCM components ...')
+             loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
+
+             if refiner_model_name != 'None':
+                 print(f'Refiner disabled in LCM mode.')
+
+             refiner_model_name = 'None'
+             sampler_name = 'lcm'
+             scheduler_name = 'lcm'
+             sharpness = 0.0
+             guidance_scale = 1.0
+             adaptive_cfg = 1.0
+             refiner_switch = 1.0
+             adm_scaler_positive = 1.0
+             adm_scaler_negative = 1.0
+             adm_scaler_end = 0.0
+
+         print(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
+         print(f'[Parameters] Sharpness = {sharpness}')
+         print(f'[Parameters] ControlNet Softness = {controlnet_softness}')
+         print(f'[Parameters] ADM Scale = '
+               f'{adm_scaler_positive} : '
+               f'{adm_scaler_negative} : '
+               f'{adm_scaler_end}')
+
+         patch_settings[pid] = PatchSettings(
+             sharpness,
+             adm_scaler_end,
+             adm_scaler_positive,
+             adm_scaler_negative,
+             controlnet_softness,
+             adaptive_cfg
+         )
+
+         cfg_scale = float(guidance_scale)
+         print(f'[Parameters] CFG = {cfg_scale}')
+
+         initial_latent = None
+         denoising_strength = 1.0
+         tiled = False
+
+         width, height = aspect_ratios_selection.replace('×', ' ').split(' ')[:2]
+         width, height = int(width), int(height)
+
+         skip_prompt_processing = False
+
+         inpaint_worker.current_task = None
+         inpaint_parameterized = inpaint_engine != 'None'
+         inpaint_image = None
+         inpaint_mask = None
+         inpaint_head_model_path = None
+
+         use_synthetic_refiner = False
+
+         controlnet_canny_path = None
+         controlnet_cpds_path = None
+         clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None
+
+         seed = int(image_seed)
+         print(f'[Parameters] Seed = {seed}')
+
+         goals = []
+         tasks = []
+
+         if input_image_checkbox:
+             if (current_tab == 'uov' or (
+                     current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \
+                     and uov_method != flags.disabled and uov_input_image is not None:
+                 uov_input_image = HWC3(uov_input_image)
+                 if 'vary' in uov_method:
+                     goals.append('vary')
+                 elif 'upscale' in uov_method:
+                     goals.append('upscale')
+                     if 'fast' in uov_method:
+                         skip_prompt_processing = True
+                     else:
+                         steps = performance_selection.steps_uov()
+
+                     progressbar(async_task, 1, 'Downloading upscale models ...')
+                     modules.config.downloading_upscale_model()
+             if (current_tab == 'inpaint' or (
+                     current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \
+                     and isinstance(inpaint_input_image, dict):
+                 inpaint_image = inpaint_input_image['image']
+                 inpaint_mask = inpaint_input_image['mask'][:, :, 0]
+
+                 if inpaint_mask_upload_checkbox:
+                     if isinstance(inpaint_mask_image_upload, np.ndarray):
+                         if inpaint_mask_image_upload.ndim == 3:
+                             H, W, C = inpaint_image.shape
+                             inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H)
+                             inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2)
+                             inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
+                             inpaint_mask = np.maximum(inpaint_mask, inpaint_mask_image_upload)
+
+                 if int(inpaint_erode_or_dilate) != 0:
+                     inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate)
+
+                 if invert_mask_checkbox:
+                     inpaint_mask = 255 - inpaint_mask
+
+                 inpaint_image = HWC3(inpaint_image)
+                 if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
+                         and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
+                     progressbar(async_task, 1, 'Downloading upscale models ...')
+                     modules.config.downloading_upscale_model()
+                     if inpaint_parameterized:
+                         progressbar(async_task, 1, 'Downloading inpainter ...')
+                         inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
+                             inpaint_engine)
+                         base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
+                         print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
+                         if refiner_model_name == 'None':
+                             use_synthetic_refiner = True
+                             refiner_switch = 0.5
+                     else:
+                         inpaint_head_model_path, inpaint_patch_model_path = None, None
+                         print(f'[Inpaint] Parameterized inpaint is disabled.')
+                     if inpaint_additional_prompt != '':
+                         if prompt == '':
+                             prompt = inpaint_additional_prompt
+                         else:
+                             prompt = inpaint_additional_prompt + '\n' + prompt
+                     goals.append('inpaint')
+             if current_tab == 'ip' or \
+                     mixing_image_prompt_and_vary_upscale or \
+                     mixing_image_prompt_and_inpaint:
+                 goals.append('cn')
+                 progressbar(async_task, 1, 'Downloading control models ...')
+                 if len(cn_tasks[flags.cn_canny]) > 0:
+                     controlnet_canny_path = modules.config.downloading_controlnet_canny()
+                 if len(cn_tasks[flags.cn_cpds]) > 0:
+                     controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
+                 if len(cn_tasks[flags.cn_ip]) > 0:
+                     clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip')
+                 if len(cn_tasks[flags.cn_ip_face]) > 0:
+                     clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
+                         'face')
+                 progressbar(async_task, 1, 'Loading control models ...')
+
+         # Load or unload CNs
+         pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path])
+         ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path)
+         ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_face_path)
+
+         if overwrite_step > 0:
+             steps = overwrite_step
+
+         switch = int(round(steps * refiner_switch))
+
+         if overwrite_switch > 0:
+             switch = overwrite_switch
+
+         if overwrite_width > 0:
+             width = overwrite_width
+
+         if overwrite_height > 0:
+             height = overwrite_height
+
+         print(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
+         print(f'[Parameters] Steps = {steps} - {switch}')
+
+         progressbar(async_task, 1, 'Initializing ...')
+
+         if not skip_prompt_processing:
+
+             prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
+             negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
+
+             prompt = prompts[0]
+             negative_prompt = negative_prompts[0]
+
+             if prompt == '':
+                 # disable expansion when empty since it is not meaningful and influences image prompt
+                 use_expansion = False
+
+             extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
+             extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
+
+             progressbar(async_task, 3, 'Loading models ...')
+             pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
+                                         loras=loras, base_model_additional_loras=base_model_additional_loras,
+                                         use_synthetic_refiner=use_synthetic_refiner)
+
+             progressbar(async_task, 3, 'Processing prompts ...')
+             tasks = []
+
+             for i in range(image_number):
+                 if disable_seed_increment:
+                     task_seed = seed
+                 else:
+                     task_seed = (seed + i) % (constants.MAX_SEED + 1)  # randint is inclusive, % is not
+
+                 task_rng = random.Random(task_seed)  # may bind to inpaint noise in the future
+                 task_prompt = apply_wildcards(prompt, task_rng)
+                 task_prompt = apply_arrays(task_prompt, i)
+                 task_negative_prompt = apply_wildcards(negative_prompt, task_rng)
+                 task_extra_positive_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_positive_prompts]
+                 task_extra_negative_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_negative_prompts]
+
+                 positive_basic_workloads = []
+                 negative_basic_workloads = []
+
+                 if use_style:
+                     for s in style_selections:
+                         p, n = apply_style(s, positive=task_prompt)
+                         positive_basic_workloads = positive_basic_workloads + p
+                         negative_basic_workloads = negative_basic_workloads + n
+                 else:
+                     positive_basic_workloads.append(task_prompt)
+
+                 negative_basic_workloads.append(task_negative_prompt)  # Always use independent workload for negative.
+
+                 positive_basic_workloads = positive_basic_workloads + task_extra_positive_prompts
+                 negative_basic_workloads = negative_basic_workloads + task_extra_negative_prompts
+
+                 positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=task_prompt)
+                 negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=task_negative_prompt)
+
+                 tasks.append(dict(
+                     task_seed=task_seed,
+                     task_prompt=task_prompt,
+                     task_negative_prompt=task_negative_prompt,
+                     positive=positive_basic_workloads,
+                     negative=negative_basic_workloads,
+                     expansion='',
+                     c=None,
+                     uc=None,
+                     positive_top_k=len(positive_basic_workloads),
+                     negative_top_k=len(negative_basic_workloads),
+                     log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
+                     log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
+                 ))
+
+             if use_expansion:
+                 for i, t in enumerate(tasks):
+                     progressbar(async_task, 5, f'Preparing Fooocus text #{i + 1} ...')
+                     expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
+                     print(f'[Prompt Expansion] {expansion}')
+                     t['expansion'] = expansion
+                     t['positive'] = copy.deepcopy(t['positive']) + [expansion]  # Deep copy.
+
+             for i, t in enumerate(tasks):
+                 progressbar(async_task, 7, f'Encoding positive #{i + 1} ...')
+                 t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
+
+             for i, t in enumerate(tasks):
+                 if abs(float(cfg_scale) - 1.0) < 1e-4:
+                     t['uc'] = pipeline.clone_cond(t['c'])
+                 else:
+                     progressbar(async_task, 10, f'Encoding negative #{i + 1} ...')
+                     t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
+
+         if len(goals) > 0:
+             progressbar(async_task, 13, 'Image processing ...')
+
+         if 'vary' in goals:
+             if 'subtle' in uov_method:
+                 denoising_strength = 0.5
+             if 'strong' in uov_method:
+                 denoising_strength = 0.85
+             if overwrite_vary_strength > 0:
+                 denoising_strength = overwrite_vary_strength
+
+             shape_ceil = get_image_shape_ceil(uov_input_image)
+             if shape_ceil < 1024:
+                 print(f'[Vary] Image is resized because it is too small.')
+                 shape_ceil = 1024
+             elif shape_ceil > 2048:
+                 print(f'[Vary] Image is resized because it is too big.')
+                 shape_ceil = 2048
+
+             uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
+
+             initial_pixels = core.numpy_to_pytorch(uov_input_image)
+             progressbar(async_task, 13, 'VAE encoding ...')
+
+             candidate_vae, _ = pipeline.get_candidate_vae(
+                 steps=steps,
+                 switch=switch,
+                 denoise=denoising_strength,
+                 refiner_swap_method=refiner_swap_method
+             )
+
+             initial_latent = core.encode_vae(vae=candidate_vae, pixels=initial_pixels)
+             B, C, H, W = initial_latent['samples'].shape
+             width = W * 8
+             height = H * 8
+             print(f'Final resolution is {str((height, width))}.')
+
+         if 'upscale' in goals:
+             H, W, C = uov_input_image.shape
+             progressbar(async_task, 13, f'Upscaling image from {str((H, W))} ...')
+             uov_input_image = perform_upscale(uov_input_image)
+             print(f'Image upscaled.')
+
+             if '1.5x' in uov_method:
+                 f = 1.5
+             elif '2x' in uov_method:
+                 f = 2.0
+             else:
+                 f = 1.0
+
+             shape_ceil = get_shape_ceil(H * f, W * f)
+
+             if shape_ceil < 1024:
+                 print(f'[Upscale] Image is resized because it is too small.')
+                 uov_input_image = set_image_shape_ceil(uov_input_image, 1024)
+                 shape_ceil = 1024
+             else:
+                 uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f)
+
+             image_is_super_large = shape_ceil > 2800
+
+             if 'fast' in uov_method:
+                 direct_return = True
+             elif image_is_super_large:
+                 print('Image is too large. Directly returned the SR image. '
+                       'Usually directly return SR image at 4K resolution '
+                       'yields better results than SDXL diffusion.')
+                 direct_return = True
+             else:
+                 direct_return = False
+
+             if direct_return:
+                 d = [('Upscale (Fast)', 'upscale_fast', '2x')]
+                 uov_input_image_path = log(uov_input_image, d, output_format=output_format)
+                 yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True)
+                 return
+
+             tiled = True
+             denoising_strength = 0.382
+
+             if overwrite_upscale_strength > 0:
+                 denoising_strength = overwrite_upscale_strength
+
+             initial_pixels = core.numpy_to_pytorch(uov_input_image)
+             progressbar(async_task, 13, 'VAE encoding ...')
+
+             candidate_vae, _ = pipeline.get_candidate_vae(
+                 steps=steps,
+                 switch=switch,
+                 denoise=denoising_strength,
+                 refiner_swap_method=refiner_swap_method
+             )
+
+             initial_latent = core.encode_vae(
+                 vae=candidate_vae,
+                 pixels=initial_pixels, tiled=True)
+             B, C, H, W = initial_latent['samples'].shape
+             width = W * 8
+             height = H * 8
+             print(f'Final resolution is {str((height, width))}.')
+
+         if 'inpaint' in goals:
+             if len(outpaint_selections) > 0:
+                 H, W, C = inpaint_image.shape
+                 if 'top' in outpaint_selections:
+                     inpaint_image = np.pad(inpaint_image, [[int(H * 0.3), 0], [0, 0], [0, 0]], mode='edge')
+                     inpaint_mask = np.pad(inpaint_mask, [[int(H * 0.3), 0], [0, 0]], mode='constant',
+                                           constant_values=255)
+                 if 'bottom' in outpaint_selections:
+                     inpaint_image = np.pad(inpaint_image, [[0, int(H * 0.3)], [0, 0], [0, 0]], mode='edge')
+                     inpaint_mask = np.pad(inpaint_mask, [[0, int(H * 0.3)], [0, 0]], mode='constant',
+                                           constant_values=255)
+
+                 H, W, C = inpaint_image.shape
+                 if 'left' in outpaint_selections:
+                     inpaint_image = np.pad(inpaint_image, [[0, 0], [int(H * 0.3), 0], [0, 0]], mode='edge')
+                     inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(H * 0.3), 0]], mode='constant',
+                                           constant_values=255)
+                 if 'right' in outpaint_selections:
+                     inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(H * 0.3)], [0, 0]], mode='edge')
+                     inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(H * 0.3)]], mode='constant',
+                                           constant_values=255)
+
+                 inpaint_image = np.ascontiguousarray(inpaint_image.copy())
+                 inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
+                 inpaint_strength = 1.0
+                 inpaint_respective_field = 1.0
+
+             denoising_strength = inpaint_strength
+
+             inpaint_worker.current_task = inpaint_worker.InpaintWorker(
+                 image=inpaint_image,
+                 mask=inpaint_mask,
+                 use_fill=denoising_strength > 0.99,
+                 k=inpaint_respective_field
+             )
+
+             if debugging_inpaint_preprocessor:
+                 yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(),
+                              do_not_show_finished_images=True)
+                 return
+
+             progressbar(async_task, 13, 'VAE Inpaint encoding ...')
+
+             inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
+             inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
+             inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask)
+
+             candidate_vae, candidate_vae_swap = pipeline.get_candidate_vae(
+                 steps=steps,
+                 switch=switch,
+                 denoise=denoising_strength,
+                 refiner_swap_method=refiner_swap_method
+             )
+
+             latent_inpaint, latent_mask = core.encode_vae_inpaint(
+                 mask=inpaint_pixel_mask,
+                 vae=candidate_vae,
+                 pixels=inpaint_pixel_image)
+
+             latent_swap = None
+             if candidate_vae_swap is not None:
+                 progressbar(async_task, 13, 'VAE SD15 encoding ...')
+                 latent_swap = core.encode_vae(
+                     vae=candidate_vae_swap,
+                     pixels=inpaint_pixel_fill)['samples']
+
+             progressbar(async_task, 13, 'VAE encoding ...')
+             latent_fill = core.encode_vae(
+                 vae=candidate_vae,
+                 pixels=inpaint_pixel_fill)['samples']
+
+             inpaint_worker.current_task.load_latent(
+                 latent_fill=latent_fill, latent_mask=latent_mask, latent_swap=latent_swap)
+
+             if inpaint_parameterized:
+                 pipeline.final_unet = inpaint_worker.current_task.patch(
+                     inpaint_head_model_path=inpaint_head_model_path,
+                     inpaint_latent=latent_inpaint,
+                     inpaint_latent_mask=latent_mask,
+                     model=pipeline.final_unet
+                 )
+
+             if not inpaint_disable_initial_latent:
+                 initial_latent = {'samples': latent_fill}
+
+             B, C, H, W = latent_fill.shape
+             height, width = H * 8, W * 8
+             final_height, final_width = inpaint_worker.current_task.image.shape[:2]
+             print(f'Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.')
+
+         if 'cn' in goals:
+             for task in cn_tasks[flags.cn_canny]:
+                 cn_img, cn_stop, cn_weight = task
+                 cn_img = resize_image(HWC3(cn_img), width=width, height=height)
+
+                 if not skipping_cn_preprocessor:
+                     cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold)
+
+                 cn_img = HWC3(cn_img)
+                 task[0] = core.numpy_to_pytorch(cn_img)
+                 if debugging_cn_preprocessor:
+                     yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                     return
+             for task in cn_tasks[flags.cn_cpds]:
+                 cn_img, cn_stop, cn_weight = task
+                 cn_img = resize_image(HWC3(cn_img), width=width, height=height)
+
+                 if not skipping_cn_preprocessor:
+                     cn_img = preprocessors.cpds(cn_img)
+
+                 cn_img = HWC3(cn_img)
+                 task[0] = core.numpy_to_pytorch(cn_img)
+                 if debugging_cn_preprocessor:
+                     yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                     return
+             for task in cn_tasks[flags.cn_ip]:
+                 cn_img, cn_stop, cn_weight = task
+                 cn_img = HWC3(cn_img)
+
+                 # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
+                 cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
+
+                 task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
+                 if debugging_cn_preprocessor:
+                     yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                     return
+             for task in cn_tasks[flags.cn_ip_face]:
+                 cn_img, cn_stop, cn_weight = task
+                 cn_img = HWC3(cn_img)
+
+                 if not skipping_cn_preprocessor:
+                     cn_img = extras.face_crop.crop_image(cn_img)
+
+                 # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
+                 cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
+
+                 task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
+                 if debugging_cn_preprocessor:
+                     yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                     return
+
+             all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
+
+             if len(all_ip_tasks) > 0:
+                 pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
+
+         if freeu_enabled:
+             print(f'FreeU is enabled!')
+             pipeline.final_unet = core.apply_freeu(
+                 pipeline.final_unet,
+                 freeu_b1,
+                 freeu_b2,
+                 freeu_s1,
+                 freeu_s2
+             )
+
+         all_steps = steps * image_number
+
+         print(f'[Parameters] Denoising Strength = {denoising_strength}')
+
+         if isinstance(initial_latent, dict) and 'samples' in initial_latent:
+             log_shape = initial_latent['samples'].shape
+         else:
+             log_shape = f'Image Space {(height, width)}'
+
+         print(f'[Parameters] Initial Latent shape: {log_shape}')
+
+         preparation_time = time.perf_counter() - execution_start_time
+         print(f'Preparation time: {preparation_time:.2f} seconds')
+
+         final_sampler_name = sampler_name
+         final_scheduler_name = scheduler_name
+
+         if scheduler_name == 'lcm':
+             final_scheduler_name = 'sgm_uniform'
+             if pipeline.final_unet is not None:
+                 pipeline.final_unet = core.opModelSamplingDiscrete.patch(
+                     pipeline.final_unet,
+                     sampling='lcm',
+                     zsnr=False)[0]
+             if pipeline.final_refiner_unet is not None:
+                 pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
+                     pipeline.final_refiner_unet,
+                     sampling='lcm',
+                     zsnr=False)[0]
+             print('Using lcm scheduler.')
+
+         async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])
+
+         def callback(step, x0, x, total_steps, y):
+             done_steps = current_task_id * steps + step
+             async_task.yields.append(['preview', (
+                 int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
+                 f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)])
+
+         for current_task_id, task in enumerate(tasks):
+             execution_start_time = time.perf_counter()
+
+             try:
+                 if async_task.last_stop is not False:
+                     ldm_patched.model_management.interrupt_current_processing()
+                 positive_cond, negative_cond = task['c'], task['uc']
+
+                 if 'cn' in goals:
+                     for cn_flag, cn_path in [
+                         (flags.cn_canny, controlnet_canny_path),
+                         (flags.cn_cpds, controlnet_cpds_path)
+                     ]:
+                         for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]:
+                             positive_cond, negative_cond = core.apply_controlnet(
+                                 positive_cond, negative_cond,
+                                 pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop)
+
+                 imgs = pipeline.process_diffusion(
+                     positive_cond=positive_cond,
+                     negative_cond=negative_cond,
+                     steps=steps,
+                     switch=switch,
+                     width=width,
+                     height=height,
+                     image_seed=task['task_seed'],
+                     callback=callback,
+                     sampler_name=final_sampler_name,
+                     scheduler_name=final_scheduler_name,
+                     latent=initial_latent,
+                     denoise=denoising_strength,
+                     tiled=tiled,
+                     cfg_scale=cfg_scale,
+                     refiner_swap_method=refiner_swap_method,
+                     disable_preview=disable_preview
+                 )
+
+                 del task['c'], task['uc'], positive_cond, negative_cond  # Save memory
+
+                 if inpaint_worker.current_task is not None:
+                     imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
+
+                 img_paths = []
+                 for x in imgs:
+                     d = [('Prompt', 'prompt', task['log_positive_prompt']),
+                          ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
+                          ('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
+                          ('Styles', 'styles', str(raw_style_selections)),
+                          ('Performance', 'performance', performance_selection.value)]
+
+                     if performance_selection.steps() != steps:
+                         d.append(('Steps', 'steps', steps))
+
+                     d += [('Resolution', 'resolution', str((width, height))),
+                           ('Guidance Scale', 'guidance_scale', guidance_scale),
+                           ('Sharpness', 'sharpness', sharpness),
+                           ('ADM Guidance', 'adm_guidance', str((
+                               modules.patch.patch_settings[pid].positive_adm_scale,
+                               modules.patch.patch_settings[pid].negative_adm_scale,
+                               modules.patch.patch_settings[pid].adm_scaler_end))),
+                           ('Base Model', 'base_model', base_model_name),
+                           ('Refiner Model', 'refiner_model', refiner_model_name),
+                           ('Refiner Switch', 'refiner_switch', refiner_switch)]
+
+                     if refiner_model_name != 'None':
+                         if overwrite_switch > 0:
+                             d.append(('Overwrite Switch', 'overwrite_switch', overwrite_switch))
+                         if refiner_swap_method != flags.refiner_swap_method:
+                             d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
+                     if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
+                         d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
+
+                     d.append(('Sampler', 'sampler', sampler_name))
+                     d.append(('Scheduler', 'scheduler', scheduler_name))
+                     d.append(('Seed', 'seed', task['task_seed']))
+
+                     if freeu_enabled:
+                         d.append(('FreeU', 'freeu', str((freeu_b1, freeu_b2, freeu_s1, freeu_s2))))
+
+                     for li, (n, w) in enumerate(loras):
+                         if n != 'None':
+                             d.append((f'LoRA {li + 1}', f'lora_combined_{li + 1}', f'{n} : {w}'))
+
+                     metadata_parser = None
+                     if save_metadata_to_images:
+                         metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
+                         metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
+                                                  task['log_negative_prompt'], task['negative'],
+                                                  steps, base_model_name, refiner_model_name, loras)
+                     d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
+                     d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
+                     img_paths.append(log(x, d, metadata_parser, output_format))
+
+                 yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
+             except ldm_patched.modules.model_management.InterruptProcessingException as e:
+                 if async_task.last_stop == 'skip':
+                     print('User skipped')
+                     async_task.last_stop = False
+                     continue
+                 else:
+                     print('User stopped')
+                     break
+
+             execution_time = time.perf_counter() - execution_start_time
+             print(f'Generating and saving time: {execution_time:.2f} seconds')
+         async_task.processing = False
+         return
+
+     while True:
+         time.sleep(0.01)
+         if len(async_tasks) > 0:
+             task = async_tasks.pop(0)
+             generate_image_grid = task.args.pop(0)
+
+             try:
+                 handler(task)
+                 if generate_image_grid:
+                     build_image_wall(task)
+                 task.yields.append(['finish', task.results])
+                 pipeline.prepare_text_encoder(async_call=True)
+             except:
+                 traceback.print_exc()
+                 task.yields.append(['finish', task.results])
+             finally:
+                 if pid in modules.patch.patch_settings:
+                     del modules.patch.patch_settings[pid]
+     pass
+
+
+ threading.Thread(target=worker, daemon=True).start()
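
Note (not part of the committed file): a rough sketch of how the queue above is expected to be driven from the UI side. The real argument list is the long flattened Gradio payload, so build_gradio_args() below is only a placeholder; the worker pops args[0] as the image-grid flag before handing the rest to handler().

import time
from modules.async_worker import AsyncTask, async_tasks

task = AsyncTask(args=build_gradio_args())  # placeholder for the flattened UI arguments
async_tasks.append(task)                    # the daemon worker thread picks it up in its polling loop

finished = None
while finished is None:
    time.sleep(0.01)
    while task.yields:
        kind, payload = task.yields.pop(0)
        if kind == 'preview':
            percentage, text, image = payload   # progress tuple produced by progressbar()/callback()
        elif kind == 'finish':
            finished = payload                  # list of finished image paths / arrays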
modules/auth.py ADDED
@@ -0,0 +1,41 @@
+ import json
+ import hashlib
+ import modules.constants as constants
+
+ from os.path import exists
+
+
+ def auth_list_to_dict(auth_list):
+     auth_dict = {}
+     for auth_data in auth_list:
+         if 'user' in auth_data:
+             if 'hash' in auth_data:
+                 auth_dict |= {auth_data['user']: auth_data['hash']}
+             elif 'pass' in auth_data:
+                 auth_dict |= {auth_data['user']: hashlib.sha256(bytes(auth_data['pass'], encoding='utf-8')).hexdigest()}
+     return auth_dict
+
+
+ def load_auth_data(filename=None):
+     auth_dict = None
+     if filename != None and exists(filename):
+         with open(filename, encoding='utf-8') as auth_file:
+             try:
+                 auth_obj = json.load(auth_file)
+                 if isinstance(auth_obj, list) and len(auth_obj) > 0:
+                     auth_dict = auth_list_to_dict(auth_obj)
+             except Exception as e:
+                 print('load_auth_data, e: ' + str(e))
+     return auth_dict
+
+
+ auth_dict = load_auth_data(constants.AUTH_FILENAME)
+
+ auth_enabled = auth_dict != None
+
+
+ def check_auth(user, password):
+     if user not in auth_dict:
+         return False
+     else:
+         return hashlib.sha256(bytes(password, encoding='utf-8')).hexdigest() == auth_dict[user]
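
Note (not part of the committed file): judging from auth_list_to_dict() above, the credential file read by load_auth_data() (its path comes from modules.constants.AUTH_FILENAME, defined in a file not shown in this view) is a JSON list of entries carrying either a plain 'pass' or a pre-computed sha256 'hash'. A small sketch of that format and the check_auth() comparison:

import hashlib, json

entries = [
    {"user": "admin", "pass": "hunter2"},                             # hashed by auth_list_to_dict at load time
    {"user": "guest", "hash": hashlib.sha256(b"guest").hexdigest()},  # already hashed
]
print(json.dumps(entries, indent=2))  # shape of the expected auth file contents

# check_auth() compares sha256 hex digests of the supplied password:
assert hashlib.sha256(b"hunter2").hexdigest() == hashlib.sha256(bytes("hunter2", encoding="utf-8")).hexdigest()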
modules/config.py ADDED
@@ -0,0 +1,607 @@
+ import os
+ import json
+ import math
+ import numbers
+ import args_manager
+ import modules.flags
+ import modules.sdxl_styles
+
+ from modules.model_loader import load_file_from_url
+ from modules.util import get_files_from_folder, makedirs_with_log
+ from modules.flags import Performance, MetadataScheme
+
+ def get_config_path(key, default_value):
+     env = os.getenv(key)
+     if env is not None and isinstance(env, str):
+         print(f"Environment: {key} = {env}")
+         return env
+     else:
+         return os.path.abspath(default_value)
+
+ config_path = get_config_path('config_path', "./config.txt")
+ config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
+ config_dict = {}
+ always_save_keys = []
+ visited_keys = []
+
+ try:
+     with open(os.path.abspath(f'./presets/default.json'), "r", encoding="utf-8") as json_file:
+         config_dict.update(json.load(json_file))
+ except Exception as e:
+     print(f'Load default preset failed.')
+     print(e)
+
+ try:
+     if os.path.exists(config_path):
+         with open(config_path, "r", encoding="utf-8") as json_file:
+             config_dict.update(json.load(json_file))
+             always_save_keys = list(config_dict.keys())
+ except Exception as e:
+     print(f'Failed to load config file "{config_path}" . The reason is: {str(e)}')
+     print('Please make sure that:')
+     print(f'1. The file "{config_path}" is a valid text file, and you have access to read it.')
+     print('2. Use "\\\\" instead of "\\" when describing paths.')
+     print('3. There is no "," before the last "}".')
+     print('4. All key/value formats are correct.')
+
+
+ def try_load_deprecated_user_path_config():
+     global config_dict
+
+     if not os.path.exists('user_path_config.txt'):
+         return
+
+     try:
+         deprecated_config_dict = json.load(open('user_path_config.txt', "r", encoding="utf-8"))
+
+         def replace_config(old_key, new_key):
+             if old_key in deprecated_config_dict:
+                 config_dict[new_key] = deprecated_config_dict[old_key]
+                 del deprecated_config_dict[old_key]
+
+         replace_config('modelfile_path', 'path_checkpoints')
+         replace_config('lorafile_path', 'path_loras')
+         replace_config('embeddings_path', 'path_embeddings')
+         replace_config('vae_approx_path', 'path_vae_approx')
+         replace_config('upscale_models_path', 'path_upscale_models')
+         replace_config('inpaint_models_path', 'path_inpaint')
+         replace_config('controlnet_models_path', 'path_controlnet')
+         replace_config('clip_vision_models_path', 'path_clip_vision')
+         replace_config('fooocus_expansion_path', 'path_fooocus_expansion')
+         replace_config('temp_outputs_path', 'path_outputs')
+
+         if deprecated_config_dict.get("default_model", None) == 'juggernautXL_version6Rundiffusion.safetensors':
+             os.replace('user_path_config.txt', 'user_path_config-deprecated.txt')
+             print('Config updated successfully in silence. '
+                   'A backup of previous config is written to "user_path_config-deprecated.txt".')
+             return
+
+         if input("Newer models and configs are available. "
+                  "Download and update files? [Y/n]:") in ['n', 'N', 'No', 'no', 'NO']:
+             config_dict.update(deprecated_config_dict)
+             print('Loading using deprecated old models and deprecated old configs.')
+             return
+         else:
+             os.replace('user_path_config.txt', 'user_path_config-deprecated.txt')
+             print('Config updated successfully by user. '
+                   'A backup of previous config is written to "user_path_config-deprecated.txt".')
+             return
+     except Exception as e:
+         print('Processing deprecated config failed')
+         print(e)
+         return
+
+
+ try_load_deprecated_user_path_config()
+
+ preset = args_manager.args.preset
+
+ if isinstance(preset, str):
+     preset_path = os.path.abspath(f'./presets/{preset}.json')
+     try:
+         if os.path.exists(preset_path):
+             with open(preset_path, "r", encoding="utf-8") as json_file:
+                 config_dict.update(json.load(json_file))
+             print(f'Loaded preset: {preset_path}')
+         else:
+             raise FileNotFoundError
+     except Exception as e:
+         print(f'Load preset [{preset_path}] failed')
+         print(e)
+
+
+ def get_path_output() -> str:
+     """
+     Checking output path argument and overriding default path.
+     """
+     global config_dict
+     path_output = get_dir_or_set_default('path_outputs', '../outputs/', make_directory=True)
+     if args_manager.args.output_path:
+         print(f'[CONFIG] Overriding config value path_outputs with {args_manager.args.output_path}')
+         config_dict['path_outputs'] = path_output = args_manager.args.output_path
+     return path_output
+
+
+ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=False):
+     global config_dict, visited_keys, always_save_keys
+
+     if key not in visited_keys:
+         visited_keys.append(key)
+
+     if key not in always_save_keys:
+         always_save_keys.append(key)
+
+     v = os.getenv(key)
+     if v is not None:
+         print(f"Environment: {key} = {v}")
+         config_dict[key] = v
+     else:
+         v = config_dict.get(key, None)
+
+     if isinstance(v, str):
+         if make_directory:
+             makedirs_with_log(v)
+         if os.path.exists(v) and os.path.isdir(v):
+             return v if not as_array else [v]
+     elif isinstance(v, list):
+         if make_directory:
+             for d in v:
+                 makedirs_with_log(d)
+         if all([os.path.exists(d) and os.path.isdir(d) for d in v]):
+             return v
+
+     if v is not None:
+         print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.')
+     if isinstance(default_value, list):
+         dp = []
+         for path in default_value:
+             abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path))
+             dp.append(abs_path)
+             os.makedirs(abs_path, exist_ok=True)
+     else:
+         dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
+         os.makedirs(dp, exist_ok=True)
+         if as_array:
+             dp = [dp]
+     config_dict[key] = dp
+     return dp
+
+
+ paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True)
+ paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
+ path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
+ path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
+ path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
+ path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
+ path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
+ path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
+ path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
+ path_outputs = get_path_output()
+
+ def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
+     global config_dict, visited_keys
+
+     if key not in visited_keys:
+         visited_keys.append(key)
+
+     v = os.getenv(key)
+     if v is not None:
+         print(f"Environment: {key} = {v}")
+         config_dict[key] = v
+
+     if key not in config_dict:
+         config_dict[key] = default_value
+         return default_value
+
+     v = config_dict.get(key, None)
+     if not disable_empty_as_none:
+         if v is None or v == '':
+             v = 'None'
+     if validator(v):
+         return v
+     else:
+         if v is not None:
+             print(f'Failed to load config key: {json.dumps({key:v})} is invalid; will use {json.dumps({key:default_value})} instead.')
+         config_dict[key] = default_value
+         return default_value
+
+
+ default_base_model_name = get_config_item_or_set_default(
+     key='default_model',
+     default_value='model.safetensors',
+     validator=lambda x: isinstance(x, str)
+ )
+ previous_default_models = get_config_item_or_set_default(
+     key='previous_default_models',
+     default_value=[],
+     validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x)
+ )
+ default_refiner_model_name = get_config_item_or_set_default(
+     key='default_refiner',
+     default_value='None',
+     validator=lambda x: isinstance(x, str)
+ )
+ default_refiner_switch = get_config_item_or_set_default(
+     key='default_refiner_switch',
+     default_value=0.8,
+     validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1
+ )
+ default_loras_min_weight = get_config_item_or_set_default(
+     key='default_loras_min_weight',
+     default_value=-2,
+     validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
+ )
+ default_loras_max_weight = get_config_item_or_set_default(
235
+ key='default_loras_max_weight',
236
+ default_value=2,
237
+ validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
238
+ )
239
+ default_loras = get_config_item_or_set_default(
240
+ key='default_loras',
241
+ default_value=[
242
+ [
243
+ "None",
244
+ 1.0
245
+ ],
246
+ [
247
+ "None",
248
+ 1.0
249
+ ],
250
+ [
251
+ "None",
252
+ 1.0
253
+ ],
254
+ [
255
+ "None",
256
+ 1.0
257
+ ],
258
+ [
259
+ "None",
260
+ 1.0
261
+ ]
262
+ ],
263
+ validator=lambda x: isinstance(x, list) and all(len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) for y in x)
264
+ )
265
+ default_max_lora_number = get_config_item_or_set_default(
266
+ key='default_max_lora_number',
267
+ default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5,
268
+ validator=lambda x: isinstance(x, int) and x >= 1
269
+ )
270
+ default_cfg_scale = get_config_item_or_set_default(
271
+ key='default_cfg_scale',
272
+ default_value=7.0,
273
+ validator=lambda x: isinstance(x, numbers.Number)
274
+ )
275
+ default_sample_sharpness = get_config_item_or_set_default(
276
+ key='default_sample_sharpness',
277
+ default_value=2.0,
278
+ validator=lambda x: isinstance(x, numbers.Number)
279
+ )
280
+ default_sampler = get_config_item_or_set_default(
281
+ key='default_sampler',
282
+ default_value='dpmpp_2m_sde_gpu',
283
+ validator=lambda x: x in modules.flags.sampler_list
284
+ )
285
+ default_scheduler = get_config_item_or_set_default(
286
+ key='default_scheduler',
287
+ default_value='karras',
288
+ validator=lambda x: x in modules.flags.scheduler_list
289
+ )
290
+ default_styles = get_config_item_or_set_default(
291
+ key='default_styles',
292
+ default_value=[
293
+ "Fooocus V2",
294
+ "Fooocus Enhance",
295
+ "Fooocus Sharp"
296
+ ],
297
+ validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
298
+ )
299
+ default_prompt_negative = get_config_item_or_set_default(
300
+ key='default_prompt_negative',
301
+ default_value='',
302
+ validator=lambda x: isinstance(x, str),
303
+ disable_empty_as_none=True
304
+ )
305
+ default_prompt = get_config_item_or_set_default(
306
+ key='default_prompt',
307
+ default_value='',
308
+ validator=lambda x: isinstance(x, str),
309
+ disable_empty_as_none=True
310
+ )
311
+ default_performance = get_config_item_or_set_default(
312
+ key='default_performance',
313
+ default_value=Performance.SPEED.value,
314
+ validator=lambda x: x in Performance.list()
315
+ )
316
+ default_advanced_checkbox = get_config_item_or_set_default(
317
+ key='default_advanced_checkbox',
318
+ default_value=False,
319
+ validator=lambda x: isinstance(x, bool)
320
+ )
321
+ default_max_image_number = get_config_item_or_set_default(
322
+ key='default_max_image_number',
323
+ default_value=32,
324
+ validator=lambda x: isinstance(x, int) and x >= 1
325
+ )
326
+ default_output_format = get_config_item_or_set_default(
327
+ key='default_output_format',
328
+ default_value='png',
329
+ validator=lambda x: x in modules.flags.output_formats
330
+ )
331
+ default_image_number = get_config_item_or_set_default(
332
+ key='default_image_number',
333
+ default_value=2,
334
+ validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number
335
+ )
336
+ checkpoint_downloads = get_config_item_or_set_default(
337
+ key='checkpoint_downloads',
338
+ default_value={},
339
+ validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
340
+ )
341
+ lora_downloads = get_config_item_or_set_default(
342
+ key='lora_downloads',
343
+ default_value={},
344
+ validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
345
+ )
346
+ embeddings_downloads = get_config_item_or_set_default(
347
+ key='embeddings_downloads',
348
+ default_value={},
349
+ validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
350
+ )
351
+ available_aspect_ratios = get_config_item_or_set_default(
352
+ key='available_aspect_ratios',
353
+ default_value=[
354
+ '704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152',
355
+ '896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960',
356
+ '1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768',
357
+ '1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640',
358
+ '1664*576', '1728*576'
359
+ ],
360
+ validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1
361
+ )
362
+ default_aspect_ratio = get_config_item_or_set_default(
363
+ key='default_aspect_ratio',
364
+ default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0],
365
+ validator=lambda x: x in available_aspect_ratios
366
+ )
367
+ default_inpaint_engine_version = get_config_item_or_set_default(
368
+ key='default_inpaint_engine_version',
369
+ default_value='v2.6',
370
+ validator=lambda x: x in modules.flags.inpaint_engine_versions
371
+ )
372
+ default_cfg_tsnr = get_config_item_or_set_default(
373
+ key='default_cfg_tsnr',
374
+ default_value=7.0,
375
+ validator=lambda x: isinstance(x, numbers.Number)
376
+ )
377
+ default_overwrite_step = get_config_item_or_set_default(
378
+ key='default_overwrite_step',
379
+ default_value=-1,
380
+ validator=lambda x: isinstance(x, int)
381
+ )
382
+ default_overwrite_switch = get_config_item_or_set_default(
383
+ key='default_overwrite_switch',
384
+ default_value=-1,
385
+ validator=lambda x: isinstance(x, int)
386
+ )
387
+ example_inpaint_prompts = get_config_item_or_set_default(
388
+ key='example_inpaint_prompts',
389
+ default_value=[
390
+ 'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes'
391
+ ],
392
+ validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x)
393
+ )
394
+ default_save_metadata_to_images = get_config_item_or_set_default(
395
+ key='default_save_metadata_to_images',
396
+ default_value=False,
397
+ validator=lambda x: isinstance(x, bool)
398
+ )
399
+ default_metadata_scheme = get_config_item_or_set_default(
400
+ key='default_metadata_scheme',
401
+ default_value=MetadataScheme.FOOOCUS.value,
402
+ validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x]
403
+ )
404
+ metadata_created_by = get_config_item_or_set_default(
405
+ key='metadata_created_by',
406
+ default_value='',
407
+ validator=lambda x: isinstance(x, str)
408
+ )
409
+
410
+ example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
411
+
412
+ config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [['None', 1.0] for _ in range(default_max_lora_number - len(default_loras))]
413
+
414
+ possible_preset_keys = [
415
+ "default_model",
416
+ "default_refiner",
417
+ "default_refiner_switch",
418
+ "default_loras_min_weight",
419
+ "default_loras_max_weight",
420
+ "default_loras",
421
+ "default_max_lora_number",
422
+ "default_cfg_scale",
423
+ "default_sample_sharpness",
424
+ "default_sampler",
425
+ "default_scheduler",
426
+ "default_performance",
427
+ "default_prompt",
428
+ "default_prompt_negative",
429
+ "default_styles",
430
+ "default_aspect_ratio",
431
+ "default_save_metadata_to_images",
432
+ "checkpoint_downloads",
433
+ "embeddings_downloads",
434
+ "lora_downloads",
435
+ ]
436
+
437
+
438
+ REWRITE_PRESET = False
439
+
440
+ if REWRITE_PRESET and isinstance(args_manager.args.preset, str):
441
+ save_path = 'presets/' + args_manager.args.preset + '.json'
442
+ with open(save_path, "w", encoding="utf-8") as json_file:
443
+ json.dump({k: config_dict[k] for k in possible_preset_keys}, json_file, indent=4)
444
+ print(f'Preset saved to {save_path}. Exiting ...')
445
+ exit(0)
446
+
447
+
448
+ def add_ratio(x):
449
+ a, b = x.replace('*', ' ').split(' ')[:2]
450
+ a, b = int(a), int(b)
451
+ g = math.gcd(a, b)
452
+ return f'{a}×{b} <span style="color: grey;"> \U00002223 {a // g}:{b // g}</span>'
453
+
454
+
455
+ default_aspect_ratio = add_ratio(default_aspect_ratio)
456
+ available_aspect_ratios = [add_ratio(x) for x in available_aspect_ratios]
457
+
458
+
459
+ # Only write config in the first launch.
460
+ if not os.path.exists(config_path):
461
+ with open(config_path, "w", encoding="utf-8") as json_file:
462
+ json.dump({k: config_dict[k] for k in always_save_keys}, json_file, indent=4)
463
+
464
+
465
+ # Always write tutorials.
466
+ with open(config_example_path, "w", encoding="utf-8") as json_file:
467
+ cpa = config_path.replace("\\", "\\\\")
468
+ json_file.write(f'You can modify your "{cpa}" using the below keys, formats, and examples.\n'
469
+ f'Do not modify this file. Modifications in this file will not take effect.\n'
470
+ f'This file is a tutorial and example. Please edit "{cpa}" to really change any settings.\n'
471
+ + 'Remember to split the paths with "\\\\" rather than "\\", '
472
+ 'and there is no "," before the last "}". \n\n\n')
473
+ json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4)
474
+
475
+ model_filenames = []
476
+ lora_filenames = []
477
+ sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
478
+
479
+
480
+ def get_model_filenames(folder_paths, name_filter=None):
481
+ extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
482
+ files = []
483
+ for folder in folder_paths:
484
+ files += get_files_from_folder(folder, extensions, name_filter)
485
+ return files
486
+
487
+
488
+ def update_all_model_names():
489
+ global model_filenames, lora_filenames
490
+ model_filenames = get_model_filenames(paths_checkpoints)
491
+ lora_filenames = get_model_filenames(paths_loras)
492
+ return
493
+
494
+
495
+ def downloading_inpaint_models(v):
496
+ assert v in modules.flags.inpaint_engine_versions
497
+
498
+ load_file_from_url(
499
+ url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth',
500
+ model_dir=path_inpaint,
501
+ file_name='fooocus_inpaint_head.pth'
502
+ )
503
+ head_file = os.path.join(path_inpaint, 'fooocus_inpaint_head.pth')
504
+ patch_file = None
505
+
506
+ if v == 'v1':
507
+ load_file_from_url(
508
+ url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch',
509
+ model_dir=path_inpaint,
510
+ file_name='inpaint.fooocus.patch'
511
+ )
512
+ patch_file = os.path.join(path_inpaint, 'inpaint.fooocus.patch')
513
+
514
+ if v == 'v2.5':
515
+ load_file_from_url(
516
+ url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch',
517
+ model_dir=path_inpaint,
518
+ file_name='inpaint_v25.fooocus.patch'
519
+ )
520
+ patch_file = os.path.join(path_inpaint, 'inpaint_v25.fooocus.patch')
521
+
522
+ if v == 'v2.6':
523
+ load_file_from_url(
524
+ url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v26.fooocus.patch',
525
+ model_dir=path_inpaint,
526
+ file_name='inpaint_v26.fooocus.patch'
527
+ )
528
+ patch_file = os.path.join(path_inpaint, 'inpaint_v26.fooocus.patch')
529
+
530
+ return head_file, patch_file
531
+
532
+
533
+ def downloading_sdxl_lcm_lora():
534
+ load_file_from_url(
535
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
536
+ model_dir=paths_loras[0],
537
+ file_name=sdxl_lcm_lora
538
+ )
539
+ return sdxl_lcm_lora
540
+
541
+
542
+ def downloading_controlnet_canny():
543
+ load_file_from_url(
544
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors',
545
+ model_dir=path_controlnet,
546
+ file_name='control-lora-canny-rank128.safetensors'
547
+ )
548
+ return os.path.join(path_controlnet, 'control-lora-canny-rank128.safetensors')
549
+
550
+
551
+ def downloading_controlnet_cpds():
552
+ load_file_from_url(
553
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors',
554
+ model_dir=path_controlnet,
555
+ file_name='fooocus_xl_cpds_128.safetensors'
556
+ )
557
+ return os.path.join(path_controlnet, 'fooocus_xl_cpds_128.safetensors')
558
+
559
+
560
+ def downloading_ip_adapters(v):
561
+ assert v in ['ip', 'face']
562
+
563
+ results = []
564
+
565
+ load_file_from_url(
566
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/clip_vision_vit_h.safetensors',
567
+ model_dir=path_clip_vision,
568
+ file_name='clip_vision_vit_h.safetensors'
569
+ )
570
+ results += [os.path.join(path_clip_vision, 'clip_vision_vit_h.safetensors')]
571
+
572
+ load_file_from_url(
573
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_ip_negative.safetensors',
574
+ model_dir=path_controlnet,
575
+ file_name='fooocus_ip_negative.safetensors'
576
+ )
577
+ results += [os.path.join(path_controlnet, 'fooocus_ip_negative.safetensors')]
578
+
579
+ if v == 'ip':
580
+ load_file_from_url(
581
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus_sdxl_vit-h.bin',
582
+ model_dir=path_controlnet,
583
+ file_name='ip-adapter-plus_sdxl_vit-h.bin'
584
+ )
585
+ results += [os.path.join(path_controlnet, 'ip-adapter-plus_sdxl_vit-h.bin')]
586
+
587
+ if v == 'face':
588
+ load_file_from_url(
589
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus-face_sdxl_vit-h.bin',
590
+ model_dir=path_controlnet,
591
+ file_name='ip-adapter-plus-face_sdxl_vit-h.bin'
592
+ )
593
+ results += [os.path.join(path_controlnet, 'ip-adapter-plus-face_sdxl_vit-h.bin')]
594
+
595
+ return results
596
+
597
+
598
+ def downloading_upscale_model():
599
+ load_file_from_url(
600
+ url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin',
601
+ model_dir=path_upscale_models,
602
+ file_name='fooocus_upscaler_s409985e5.bin'
603
+ )
604
+ return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
605
+
606
+
607
+ update_all_model_names()
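As an illustrative aside (a standalone sketch, not one of the uploaded files): add_ratio above turns a 'width*height' string into the aspect-ratio label shown in the UI by reducing the pair with their greatest common divisor. For the default '1152*896' ratio the arithmetic works out as below; ratio_label is a hypothetical stand-in that mirrors the helper without the HTML span.

import math

def ratio_label(x):
    # mirrors the gcd reduction in add_ratio above, minus the HTML markup
    a, b = (int(v) for v in x.split('*'))
    g = math.gcd(a, b)
    return f'{a}x{b} ({a // g}:{b // g})'

print(ratio_label('1152*896'))  # -> 1152x896 (9:7)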
modules/constants.py ADDED
@@ -0,0 +1,5 @@
+ # as in k-diffusion (sampling.py)
+ MIN_SEED = 0
+ MAX_SEED = 2**63 - 1
+
+ AUTH_FILENAME = 'auth.json'
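For reference (an aside, not part of the upload): MAX_SEED above is the largest signed 64-bit integer, the same seed bound used by k-diffusion, which a quick check confirms.

assert 2 ** 63 - 1 == 9223372036854775807  # maximum value of a signed 64-bit integer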
modules/core.py ADDED
@@ -0,0 +1,339 @@
1
+ import os
2
+ import einops
3
+ import torch
4
+ import numpy as np
5
+
6
+ import ldm_patched.modules.model_management
7
+ import ldm_patched.modules.model_detection
8
+ import ldm_patched.modules.model_patcher
9
+ import ldm_patched.modules.utils
10
+ import ldm_patched.modules.controlnet
11
+ import modules.sample_hijack
12
+ import ldm_patched.modules.samplers
13
+ import ldm_patched.modules.latent_formats
14
+
15
+ from ldm_patched.modules.sd import load_checkpoint_guess_config
16
+ from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, \
17
+ ControlNetApplyAdvanced
18
+ from ldm_patched.contrib.external_freelunch import FreeU_V2
19
+ from ldm_patched.modules.sample import prepare_mask
20
+ from modules.lora import match_lora
21
+ from modules.util import get_file_from_folder_list
22
+ from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip
23
+ from modules.config import path_embeddings
24
+ from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete
25
+
26
+
27
+ opEmptyLatentImage = EmptyLatentImage()
28
+ opVAEDecode = VAEDecode()
29
+ opVAEEncode = VAEEncode()
30
+ opVAEDecodeTiled = VAEDecodeTiled()
31
+ opVAEEncodeTiled = VAEEncodeTiled()
32
+ opControlNetApplyAdvanced = ControlNetApplyAdvanced()
33
+ opFreeU = FreeU_V2()
34
+ opModelSamplingDiscrete = ModelSamplingDiscrete()
35
+
36
+
37
+ class StableDiffusionModel:
38
+ def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
39
+ self.unet = unet
40
+ self.vae = vae
41
+ self.clip = clip
42
+ self.clip_vision = clip_vision
43
+ self.filename = filename
44
+ self.unet_with_lora = unet
45
+ self.clip_with_lora = clip
46
+ self.visited_loras = ''
47
+
48
+ self.lora_key_map_unet = {}
49
+ self.lora_key_map_clip = {}
50
+
51
+ if self.unet is not None:
52
+ self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet)
53
+ self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()})
54
+
55
+ if self.clip is not None:
56
+ self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip)
57
+ self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
58
+
59
+ @torch.no_grad()
60
+ @torch.inference_mode()
61
+ def refresh_loras(self, loras):
62
+ assert isinstance(loras, list)
63
+
64
+ if self.visited_loras == str(loras):
65
+ return
66
+
67
+ self.visited_loras = str(loras)
68
+
69
+ if self.unet is None:
70
+ return
71
+
72
+ print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
73
+
74
+ loras_to_load = []
75
+
76
+ for name, weight in loras:
77
+ if name == 'None':
78
+ continue
79
+
80
+ if os.path.exists(name):
81
+ lora_filename = name
82
+ else:
83
+ lora_filename = get_file_from_folder_list(name, modules.config.paths_loras)
84
+
85
+ if not os.path.exists(lora_filename):
86
+ print(f'Lora file not found: {lora_filename}')
87
+ continue
88
+
89
+ loras_to_load.append((lora_filename, weight))
90
+
91
+ self.unet_with_lora = self.unet.clone() if self.unet is not None else None
92
+ self.clip_with_lora = self.clip.clone() if self.clip is not None else None
93
+
94
+ for lora_filename, weight in loras_to_load:
95
+ lora_unmatch = ldm_patched.modules.utils.load_torch_file(lora_filename, safe_load=False)
96
+ lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet)
97
+ lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip)
98
+
99
+ if len(lora_unmatch) > 12:
100
+ # model mismatch
101
+ continue
102
+
103
+ if len(lora_unmatch) > 0:
104
+ print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] '
105
+ f'with unmatched keys {list(lora_unmatch.keys())}')
106
+
107
+ if self.unet_with_lora is not None and len(lora_unet) > 0:
108
+ loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight)
109
+ print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] '
110
+ f'with {len(loaded_keys)} keys at weight {weight}.')
111
+ for item in lora_unet:
112
+ if item not in loaded_keys:
113
+ print("UNet LoRA key skipped: ", item)
114
+
115
+ if self.clip_with_lora is not None and len(lora_clip) > 0:
116
+ loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight)
117
+ print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] '
118
+ f'with {len(loaded_keys)} keys at weight {weight}.')
119
+ for item in lora_clip:
120
+ if item not in loaded_keys:
121
+ print("CLIP LoRA key skipped: ", item)
122
+
123
+
124
+ @torch.no_grad()
125
+ @torch.inference_mode()
126
+ def apply_freeu(model, b1, b2, s1, s2):
127
+ return opFreeU.patch(model=model, b1=b1, b2=b2, s1=s1, s2=s2)[0]
128
+
129
+
130
+ @torch.no_grad()
131
+ @torch.inference_mode()
132
+ def load_controlnet(ckpt_filename):
133
+ return ldm_patched.modules.controlnet.load_controlnet(ckpt_filename)
134
+
135
+
136
+ @torch.no_grad()
137
+ @torch.inference_mode()
138
+ def apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent):
139
+ return opControlNetApplyAdvanced.apply_controlnet(positive=positive, negative=negative, control_net=control_net,
140
+ image=image, strength=strength, start_percent=start_percent, end_percent=end_percent)
141
+
142
+
143
+ @torch.no_grad()
144
+ @torch.inference_mode()
145
+ def load_model(ckpt_filename):
146
+ unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
147
+ return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
148
+
149
+
150
+ @torch.no_grad()
151
+ @torch.inference_mode()
152
+ def generate_empty_latent(width=1024, height=1024, batch_size=1):
153
+ return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0]
154
+
155
+
156
+ @torch.no_grad()
157
+ @torch.inference_mode()
158
+ def decode_vae(vae, latent_image, tiled=False):
159
+ if tiled:
160
+ return opVAEDecodeTiled.decode(samples=latent_image, vae=vae, tile_size=512)[0]
161
+ else:
162
+ return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
163
+
164
+
165
+ @torch.no_grad()
166
+ @torch.inference_mode()
167
+ def encode_vae(vae, pixels, tiled=False):
168
+ if tiled:
169
+ return opVAEEncodeTiled.encode(pixels=pixels, vae=vae, tile_size=512)[0]
170
+ else:
171
+ return opVAEEncode.encode(pixels=pixels, vae=vae)[0]
172
+
173
+
174
+ @torch.no_grad()
175
+ @torch.inference_mode()
176
+ def encode_vae_inpaint(vae, pixels, mask):
177
+ assert mask.ndim == 3 and pixels.ndim == 4
178
+ assert mask.shape[-1] == pixels.shape[-2]
179
+ assert mask.shape[-2] == pixels.shape[-3]
180
+
181
+ w = mask.round()[..., None]
182
+ pixels = pixels * (1 - w) + 0.5 * w
183
+
184
+ latent = vae.encode(pixels)
185
+ B, C, H, W = latent.shape
186
+
187
+ latent_mask = mask[:, None, :, :]
188
+ latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
189
+ latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent)
190
+
191
+ return latent, latent_mask
192
+
193
+
194
+ class VAEApprox(torch.nn.Module):
195
+ def __init__(self):
196
+ super(VAEApprox, self).__init__()
197
+ self.conv1 = torch.nn.Conv2d(4, 8, (7, 7))
198
+ self.conv2 = torch.nn.Conv2d(8, 16, (5, 5))
199
+ self.conv3 = torch.nn.Conv2d(16, 32, (3, 3))
200
+ self.conv4 = torch.nn.Conv2d(32, 64, (3, 3))
201
+ self.conv5 = torch.nn.Conv2d(64, 32, (3, 3))
202
+ self.conv6 = torch.nn.Conv2d(32, 16, (3, 3))
203
+ self.conv7 = torch.nn.Conv2d(16, 8, (3, 3))
204
+ self.conv8 = torch.nn.Conv2d(8, 3, (3, 3))
205
+ self.current_type = None
206
+
207
+ def forward(self, x):
208
+ extra = 11
209
+ x = torch.nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
210
+ x = torch.nn.functional.pad(x, (extra, extra, extra, extra))
211
+ for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]:
212
+ x = layer(x)
213
+ x = torch.nn.functional.leaky_relu(x, 0.1)
214
+ return x
215
+
216
+
217
+ VAE_approx_models = {}
218
+
219
+
220
+ @torch.no_grad()
221
+ @torch.inference_mode()
222
+ def get_previewer(model):
223
+ global VAE_approx_models
224
+
225
+ from modules.config import path_vae_approx
226
+ is_sdxl = isinstance(model.model.latent_format, ldm_patched.modules.latent_formats.SDXL)
227
+ vae_approx_filename = os.path.join(path_vae_approx, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')
228
+
229
+ if vae_approx_filename in VAE_approx_models:
230
+ VAE_approx_model = VAE_approx_models[vae_approx_filename]
231
+ else:
232
+ sd = torch.load(vae_approx_filename, map_location='cpu')
233
+ VAE_approx_model = VAEApprox()
234
+ VAE_approx_model.load_state_dict(sd)
235
+ del sd
236
+ VAE_approx_model.eval()
237
+
238
+ if ldm_patched.modules.model_management.should_use_fp16():
239
+ VAE_approx_model.half()
240
+ VAE_approx_model.current_type = torch.float16
241
+ else:
242
+ VAE_approx_model.float()
243
+ VAE_approx_model.current_type = torch.float32
244
+
245
+ VAE_approx_model.to(ldm_patched.modules.model_management.get_torch_device())
246
+ VAE_approx_models[vae_approx_filename] = VAE_approx_model
247
+
248
+ @torch.no_grad()
249
+ @torch.inference_mode()
250
+ def preview_function(x0, step, total_steps):
251
+ with torch.no_grad():
252
+ x_sample = x0.to(VAE_approx_model.current_type)
253
+ x_sample = VAE_approx_model(x_sample) * 127.5 + 127.5
254
+ x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')[0]
255
+ x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
256
+ return x_sample
257
+
258
+ return preview_function
259
+
260
+
261
+ @torch.no_grad()
262
+ @torch.inference_mode()
263
+ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
264
+ scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
265
+ force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
266
+ previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None, disable_preview=False):
267
+
268
+ if sigmas is not None:
269
+ sigmas = sigmas.clone().to(ldm_patched.modules.model_management.get_torch_device())
270
+
271
+ latent_image = latent["samples"]
272
+
273
+ if disable_noise:
274
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
275
+ else:
276
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
277
+ noise = ldm_patched.modules.sample.prepare_noise(latent_image, seed, batch_inds)
278
+
279
+ if isinstance(noise_mean, torch.Tensor):
280
+ noise = noise + noise_mean - torch.mean(noise, dim=1, keepdim=True)
281
+
282
+ noise_mask = None
283
+ if "noise_mask" in latent:
284
+ noise_mask = latent["noise_mask"]
285
+
286
+ previewer = get_previewer(model)
287
+
288
+ if previewer_start is None:
289
+ previewer_start = 0
290
+
291
+ if previewer_end is None:
292
+ previewer_end = steps
293
+
294
+ def callback(step, x0, x, total_steps):
295
+ ldm_patched.modules.model_management.throw_exception_if_processing_interrupted()
296
+ y = None
297
+ if previewer is not None and not disable_preview:
298
+ y = previewer(x0, previewer_start + step, previewer_end)
299
+ if callback_function is not None:
300
+ callback_function(previewer_start + step, x0, x, previewer_end, y)
301
+
302
+ disable_pbar = False
303
+ modules.sample_hijack.current_refiner = refiner
304
+ modules.sample_hijack.refiner_switch_step = refiner_switch
305
+ ldm_patched.modules.samplers.sample = modules.sample_hijack.sample_hacked
306
+
307
+ try:
308
+ samples = ldm_patched.modules.sample.sample(model,
309
+ noise, steps, cfg, sampler_name, scheduler,
310
+ positive, negative, latent_image,
311
+ denoise=denoise, disable_noise=disable_noise,
312
+ start_step=start_step,
313
+ last_step=last_step,
314
+ force_full_denoise=force_full_denoise, noise_mask=noise_mask,
315
+ callback=callback,
316
+ disable_pbar=disable_pbar, seed=seed, sigmas=sigmas)
317
+
318
+ out = latent.copy()
319
+ out["samples"] = samples
320
+ finally:
321
+ modules.sample_hijack.current_refiner = None
322
+
323
+ return out
324
+
325
+
326
+ @torch.no_grad()
327
+ @torch.inference_mode()
328
+ def pytorch_to_numpy(x):
329
+ return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
330
+
331
+
332
+ @torch.no_grad()
333
+ @torch.inference_mode()
334
+ def numpy_to_pytorch(x):
335
+ y = x.astype(np.float32) / 255.0
336
+ y = y[None]
337
+ y = np.ascontiguousarray(y.copy())
338
+ y = torch.from_numpy(y).float()
339
+ return y
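As an illustrative aside (a standalone sketch, not one of the uploaded files): pytorch_to_numpy and numpy_to_pytorch above assume HWC uint8 images on the numpy side and a batched float tensor in [0, 1] on the torch side. The round trip below restates that convention with a random image.

import numpy as np
import torch

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # HWC uint8, like an image loaded from disk

# same convention as numpy_to_pytorch: scale to [0, 1] and add a batch dimension
y = torch.from_numpy(np.ascontiguousarray(img.astype(np.float32) / 255.0)[None])

# same convention as pytorch_to_numpy: scale back to [0, 255] and clip, one array per batch item
back = [np.clip(255. * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in y]

assert back[0].shape == img.shape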
modules/default_pipeline.py ADDED
@@ -0,0 +1,498 @@
1
+ import modules.core as core
2
+ import os
3
+ import torch
4
+ import modules.patch
5
+ import modules.config
6
+ import ldm_patched.modules.model_management
7
+ import ldm_patched.modules.latent_formats
8
+ import modules.inpaint_worker
9
+ import extras.vae_interpose as vae_interpose
10
+ from extras.expansion import FooocusExpansion
11
+
12
+ from ldm_patched.modules.model_base import SDXL, SDXLRefiner
13
+ from modules.sample_hijack import clip_separate
14
+ from modules.util import get_file_from_folder_list
15
+
16
+
17
+ model_base = core.StableDiffusionModel()
18
+ model_refiner = core.StableDiffusionModel()
19
+
20
+ final_expansion = None
21
+ final_unet = None
22
+ final_clip = None
23
+ final_vae = None
24
+ final_refiner_unet = None
25
+ final_refiner_vae = None
26
+
27
+ loaded_ControlNets = {}
28
+
29
+
30
+ @torch.no_grad()
31
+ @torch.inference_mode()
32
+ def refresh_controlnets(model_paths):
33
+ global loaded_ControlNets
34
+ cache = {}
35
+ for p in model_paths:
36
+ if p is not None:
37
+ if p in loaded_ControlNets:
38
+ cache[p] = loaded_ControlNets[p]
39
+ else:
40
+ cache[p] = core.load_controlnet(p)
41
+ loaded_ControlNets = cache
42
+ return
43
+
44
+
45
+ @torch.no_grad()
46
+ @torch.inference_mode()
47
+ def assert_model_integrity():
48
+ error_message = None
49
+
50
+ if not isinstance(model_base.unet_with_lora.model, SDXL):
51
+ error_message = 'You have selected base model other than SDXL. This is not supported yet.'
52
+
53
+ if error_message is not None:
54
+ raise NotImplementedError(error_message)
55
+
56
+ return True
57
+
58
+
59
+ @torch.no_grad()
60
+ @torch.inference_mode()
61
+ def refresh_base_model(name):
62
+ global model_base
63
+
64
+ filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
65
+
66
+ if model_base.filename == filename:
67
+ return
68
+
69
+ model_base = core.StableDiffusionModel()
70
+ model_base = core.load_model(filename)
71
+ print(f'Base model loaded: {model_base.filename}')
72
+ return
73
+
74
+
75
+ @torch.no_grad()
76
+ @torch.inference_mode()
77
+ def refresh_refiner_model(name):
78
+ global model_refiner
79
+
80
+ filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
81
+
82
+ if model_refiner.filename == filename:
83
+ return
84
+
85
+ model_refiner = core.StableDiffusionModel()
86
+
87
+ if name == 'None':
88
+ print(f'Refiner unloaded.')
89
+ return
90
+
91
+ model_refiner = core.load_model(filename)
92
+ print(f'Refiner model loaded: {model_refiner.filename}')
93
+
94
+ if isinstance(model_refiner.unet.model, SDXL):
95
+ model_refiner.clip = None
96
+ model_refiner.vae = None
97
+ elif isinstance(model_refiner.unet.model, SDXLRefiner):
98
+ model_refiner.clip = None
99
+ model_refiner.vae = None
100
+ else:
101
+ model_refiner.clip = None
102
+
103
+ return
104
+
105
+
106
+ @torch.no_grad()
107
+ @torch.inference_mode()
108
+ def synthesize_refiner_model():
109
+ global model_base, model_refiner
110
+
111
+ print('Synthetic Refiner Activated')
112
+ model_refiner = core.StableDiffusionModel(
113
+ unet=model_base.unet,
114
+ vae=model_base.vae,
115
+ clip=model_base.clip,
116
+ clip_vision=model_base.clip_vision,
117
+ filename=model_base.filename
118
+ )
119
+ model_refiner.vae = None
120
+ model_refiner.clip = None
121
+ model_refiner.clip_vision = None
122
+
123
+ return
124
+
125
+
126
+ @torch.no_grad()
127
+ @torch.inference_mode()
128
+ def refresh_loras(loras, base_model_additional_loras=None):
129
+ global model_base, model_refiner
130
+
131
+ if not isinstance(base_model_additional_loras, list):
132
+ base_model_additional_loras = []
133
+
134
+ model_base.refresh_loras(loras + base_model_additional_loras)
135
+ model_refiner.refresh_loras(loras)
136
+
137
+ return
138
+
139
+
140
+ @torch.no_grad()
141
+ @torch.inference_mode()
142
+ def clip_encode_single(clip, text, verbose=False):
143
+ cached = clip.fcs_cond_cache.get(text, None)
144
+ if cached is not None:
145
+ if verbose:
146
+ print(f'[CLIP Cached] {text}')
147
+ return cached
148
+ tokens = clip.tokenize(text)
149
+ result = clip.encode_from_tokens(tokens, return_pooled=True)
150
+ clip.fcs_cond_cache[text] = result
151
+ if verbose:
152
+ print(f'[CLIP Encoded] {text}')
153
+ return result
154
+
155
+
156
+ @torch.no_grad()
157
+ @torch.inference_mode()
158
+ def clone_cond(conds):
159
+ results = []
160
+
161
+ for c, p in conds:
162
+ p = p["pooled_output"]
163
+
164
+ if isinstance(c, torch.Tensor):
165
+ c = c.clone()
166
+
167
+ if isinstance(p, torch.Tensor):
168
+ p = p.clone()
169
+
170
+ results.append([c, {"pooled_output": p}])
171
+
172
+ return results
173
+
174
+
175
+ @torch.no_grad()
176
+ @torch.inference_mode()
177
+ def clip_encode(texts, pool_top_k=1):
178
+ global final_clip
179
+
180
+ if final_clip is None:
181
+ return None
182
+ if not isinstance(texts, list):
183
+ return None
184
+ if len(texts) == 0:
185
+ return None
186
+
187
+ cond_list = []
188
+ pooled_acc = 0
189
+
190
+ for i, text in enumerate(texts):
191
+ cond, pooled = clip_encode_single(final_clip, text)
192
+ cond_list.append(cond)
193
+ if i < pool_top_k:
194
+ pooled_acc += pooled
195
+
196
+ return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]]
197
+
198
+
199
+ @torch.no_grad()
200
+ @torch.inference_mode()
201
+ def clear_all_caches():
202
+ final_clip.fcs_cond_cache = {}
203
+
204
+
205
+ @torch.no_grad()
206
+ @torch.inference_mode()
207
+ def prepare_text_encoder(async_call=True):
208
+ if async_call:
209
+ # TODO: make sure that this is always called asynchronously so that users do not notice it.
210
+ pass
211
+ assert_model_integrity()
212
+ ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher])
213
+ return
214
+
215
+
216
+ @torch.no_grad()
217
+ @torch.inference_mode()
218
+ def refresh_everything(refiner_model_name, base_model_name, loras,
219
+ base_model_additional_loras=None, use_synthetic_refiner=False):
220
+ global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
221
+
222
+ final_unet = None
223
+ final_clip = None
224
+ final_vae = None
225
+ final_refiner_unet = None
226
+ final_refiner_vae = None
227
+
228
+ if use_synthetic_refiner and refiner_model_name == 'None':
229
+ print('Synthetic Refiner Activated')
230
+ refresh_base_model(base_model_name)
231
+ synthesize_refiner_model()
232
+ else:
233
+ refresh_refiner_model(refiner_model_name)
234
+ refresh_base_model(base_model_name)
235
+
236
+ refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
237
+ assert_model_integrity()
238
+
239
+ final_unet = model_base.unet_with_lora
240
+ final_clip = model_base.clip_with_lora
241
+ final_vae = model_base.vae
242
+
243
+ final_refiner_unet = model_refiner.unet_with_lora
244
+ final_refiner_vae = model_refiner.vae
245
+
246
+ if final_expansion is None:
247
+ final_expansion = FooocusExpansion()
248
+
249
+ prepare_text_encoder(async_call=True)
250
+ clear_all_caches()
251
+ return
252
+
253
+
254
+ refresh_everything(
255
+ refiner_model_name=modules.config.default_refiner_model_name,
256
+ base_model_name=modules.config.default_base_model_name,
257
+ loras=modules.config.default_loras
258
+ )
259
+
260
+
261
+ @torch.no_grad()
262
+ @torch.inference_mode()
263
+ def vae_parse(latent):
264
+ if final_refiner_vae is None:
265
+ return latent
266
+
267
+ result = vae_interpose.parse(latent["samples"])
268
+ return {'samples': result}
269
+
270
+
271
+ @torch.no_grad()
272
+ @torch.inference_mode()
273
+ def calculate_sigmas_all(sampler, model, scheduler, steps):
274
+ from ldm_patched.modules.samplers import calculate_sigmas_scheduler
275
+
276
+ discard_penultimate_sigma = False
277
+ if sampler in ['dpm_2', 'dpm_2_ancestral']:
278
+ steps += 1
279
+ discard_penultimate_sigma = True
280
+
281
+ sigmas = calculate_sigmas_scheduler(model, scheduler, steps)
282
+
283
+ if discard_penultimate_sigma:
284
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
285
+ return sigmas
286
+
287
+
288
+ @torch.no_grad()
289
+ @torch.inference_mode()
290
+ def calculate_sigmas(sampler, model, scheduler, steps, denoise):
291
+ if denoise is None or denoise > 0.9999:
292
+ sigmas = calculate_sigmas_all(sampler, model, scheduler, steps)
293
+ else:
294
+ new_steps = int(steps / denoise)
295
+ sigmas = calculate_sigmas_all(sampler, model, scheduler, new_steps)
296
+ sigmas = sigmas[-(steps + 1):]
297
+ return sigmas
298
+
299
+
300
+ @torch.no_grad()
301
+ @torch.inference_mode()
302
+ def get_candidate_vae(steps, switch, denoise=1.0, refiner_swap_method='joint'):
303
+ assert refiner_swap_method in ['joint', 'separate', 'vae']
304
+
305
+ if final_refiner_vae is not None and final_refiner_unet is not None:
306
+ if denoise > 0.9:
307
+ return final_vae, final_refiner_vae
308
+ else:
309
+ if denoise > (float(steps - switch) / float(steps)) ** 0.834: # karras 0.834
310
+ return final_vae, None
311
+ else:
312
+ return final_refiner_vae, None
313
+
314
+ return final_vae, final_refiner_vae
315
+
316
+
317
+ @torch.no_grad()
318
+ @torch.inference_mode()
319
+ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0, refiner_swap_method='joint', disable_preview=False):
320
+ target_unet, target_vae, target_refiner_unet, target_refiner_vae, target_clip \
321
+ = final_unet, final_vae, final_refiner_unet, final_refiner_vae, final_clip
322
+
323
+ assert refiner_swap_method in ['joint', 'separate', 'vae']
324
+
325
+ if final_refiner_vae is not None and final_refiner_unet is not None:
326
+ # Refiner Use Different VAE (then it is SD15)
327
+ if denoise > 0.9:
328
+ refiner_swap_method = 'vae'
329
+ else:
330
+ refiner_swap_method = 'joint'
331
+ if denoise > (float(steps - switch) / float(steps)) ** 0.834: # karras 0.834
332
+ target_unet, target_vae, target_refiner_unet, target_refiner_vae \
333
+ = final_unet, final_vae, None, None
334
+ print(f'[Sampler] only use Base because of partial denoise.')
335
+ else:
336
+ positive_cond = clip_separate(positive_cond, target_model=final_refiner_unet.model, target_clip=final_clip)
337
+ negative_cond = clip_separate(negative_cond, target_model=final_refiner_unet.model, target_clip=final_clip)
338
+ target_unet, target_vae, target_refiner_unet, target_refiner_vae \
339
+ = final_refiner_unet, final_refiner_vae, None, None
340
+ print(f'[Sampler] only use Refiner because of partial denoise.')
341
+
342
+ print(f'[Sampler] refiner_swap_method = {refiner_swap_method}')
343
+
344
+ if latent is None:
345
+ initial_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
346
+ else:
347
+ initial_latent = latent
348
+
349
+ minmax_sigmas = calculate_sigmas(sampler=sampler_name, scheduler=scheduler_name, model=final_unet.model, steps=steps, denoise=denoise)
350
+ sigma_min, sigma_max = minmax_sigmas[minmax_sigmas > 0].min(), minmax_sigmas.max()
351
+ sigma_min = float(sigma_min.cpu().numpy())
352
+ sigma_max = float(sigma_max.cpu().numpy())
353
+ print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}')
354
+
355
+ modules.patch.BrownianTreeNoiseSamplerPatched.global_init(
356
+ initial_latent['samples'].to(ldm_patched.modules.model_management.get_torch_device()),
357
+ sigma_min, sigma_max, seed=image_seed, cpu=False)
358
+
359
+ decoded_latent = None
360
+
361
+ if refiner_swap_method == 'joint':
362
+ sampled_latent = core.ksampler(
363
+ model=target_unet,
364
+ refiner=target_refiner_unet,
365
+ positive=positive_cond,
366
+ negative=negative_cond,
367
+ latent=initial_latent,
368
+ steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
369
+ seed=image_seed,
370
+ denoise=denoise,
371
+ callback_function=callback,
372
+ cfg=cfg_scale,
373
+ sampler_name=sampler_name,
374
+ scheduler=scheduler_name,
375
+ refiner_switch=switch,
376
+ previewer_start=0,
377
+ previewer_end=steps,
378
+ disable_preview=disable_preview
379
+ )
380
+ decoded_latent = core.decode_vae(vae=target_vae, latent_image=sampled_latent, tiled=tiled)
381
+
382
+ if refiner_swap_method == 'separate':
383
+ sampled_latent = core.ksampler(
384
+ model=target_unet,
385
+ positive=positive_cond,
386
+ negative=negative_cond,
387
+ latent=initial_latent,
388
+ steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=False,
389
+ seed=image_seed,
390
+ denoise=denoise,
391
+ callback_function=callback,
392
+ cfg=cfg_scale,
393
+ sampler_name=sampler_name,
394
+ scheduler=scheduler_name,
395
+ previewer_start=0,
396
+ previewer_end=steps,
397
+ disable_preview=disable_preview
398
+ )
399
+ print('Refiner swapped by changing ksampler. Noise preserved.')
400
+
401
+ target_model = target_refiner_unet
402
+ if target_model is None:
403
+ target_model = target_unet
404
+ print('Use base model to refine itself - this may be because of developer mode.')
405
+
406
+ sampled_latent = core.ksampler(
407
+ model=target_model,
408
+ positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=target_clip),
409
+ negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=target_clip),
410
+ latent=sampled_latent,
411
+ steps=steps, start_step=switch, last_step=steps, disable_noise=True, force_full_denoise=True,
412
+ seed=image_seed,
413
+ denoise=denoise,
414
+ callback_function=callback,
415
+ cfg=cfg_scale,
416
+ sampler_name=sampler_name,
417
+ scheduler=scheduler_name,
418
+ previewer_start=switch,
419
+ previewer_end=steps,
420
+ disable_preview=disable_preview
421
+ )
422
+
423
+ target_model = target_refiner_vae
424
+ if target_model is None:
425
+ target_model = target_vae
426
+ decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
427
+
428
+ if refiner_swap_method == 'vae':
429
+ modules.patch.patch_settings[os.getpid()].eps_record = 'vae'
430
+
431
+ if modules.inpaint_worker.current_task is not None:
432
+ modules.inpaint_worker.current_task.unswap()
433
+
434
+ sampled_latent = core.ksampler(
435
+ model=target_unet,
436
+ positive=positive_cond,
437
+ negative=negative_cond,
438
+ latent=initial_latent,
439
+ steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=True,
440
+ seed=image_seed,
441
+ denoise=denoise,
442
+ callback_function=callback,
443
+ cfg=cfg_scale,
444
+ sampler_name=sampler_name,
445
+ scheduler=scheduler_name,
446
+ previewer_start=0,
447
+ previewer_end=steps,
448
+ disable_preview=disable_preview
449
+ )
450
+ print('Fooocus VAE-based swap.')
451
+
452
+ target_model = target_refiner_unet
453
+ if target_model is None:
454
+ target_model = target_unet
455
+ print('Use base model to refine itself - this may be because of developer mode.')
456
+
457
+ sampled_latent = vae_parse(sampled_latent)
458
+
459
+ k_sigmas = 1.4
460
+ sigmas = calculate_sigmas(sampler=sampler_name,
461
+ scheduler=scheduler_name,
462
+ model=target_model.model,
463
+ steps=steps,
464
+ denoise=denoise)[switch:] * k_sigmas
465
+ len_sigmas = len(sigmas) - 1
466
+
467
+ noise_mean = torch.mean(modules.patch.patch_settings[os.getpid()].eps_record, dim=1, keepdim=True)
468
+
469
+ if modules.inpaint_worker.current_task is not None:
470
+ modules.inpaint_worker.current_task.swap()
471
+
472
+ sampled_latent = core.ksampler(
473
+ model=target_model,
474
+ positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=target_clip),
475
+ negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=target_clip),
476
+ latent=sampled_latent,
477
+ steps=len_sigmas, start_step=0, last_step=len_sigmas, disable_noise=False, force_full_denoise=True,
478
+ seed=image_seed+1,
479
+ denoise=denoise,
480
+ callback_function=callback,
481
+ cfg=cfg_scale,
482
+ sampler_name=sampler_name,
483
+ scheduler=scheduler_name,
484
+ previewer_start=switch,
485
+ previewer_end=steps,
486
+ sigmas=sigmas,
487
+ noise_mean=noise_mean,
488
+ disable_preview=disable_preview
489
+ )
490
+
491
+ target_model = target_refiner_vae
492
+ if target_model is None:
493
+ target_model = target_vae
494
+ decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
495
+
496
+ images = core.pytorch_to_numpy(decoded_latent)
497
+ modules.patch.patch_settings[os.getpid()].eps_record = None
498
+ return images
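As an illustrative aside (a numeric sketch, not one of the uploaded files): calculate_sigmas above handles partial denoise by computing a longer schedule of int(steps / denoise) steps and keeping only the last steps + 1 sigma values, assuming the scheduler returns steps + 1 sigmas including the trailing zero.

steps, denoise = 30, 0.5
new_steps = int(steps / denoise)  # 60: the schedule is computed at this longer length
computed = new_steps + 1          # 61 sigma values in the full schedule
kept = steps + 1                  # 31: only the tail of the schedule is actually sampled
print(f'compute {computed} sigmas, keep the last {kept}')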
modules/flags.py ADDED
@@ -0,0 +1,125 @@
1
+ from enum import IntEnum, Enum
2
+
3
+ disabled = 'Disabled'
4
+ enabled = 'Enabled'
5
+ subtle_variation = 'Vary (Subtle)'
6
+ strong_variation = 'Vary (Strong)'
7
+ upscale_15 = 'Upscale (1.5x)'
8
+ upscale_2 = 'Upscale (2x)'
9
+ upscale_fast = 'Upscale (Fast 2x)'
10
+
11
+ uov_list = [
12
+ disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast
13
+ ]
14
+
15
+ CIVITAI_NO_KARRAS = ["euler", "euler_ancestral", "heun", "dpm_fast", "dpm_adaptive", "ddim", "uni_pc"]
16
+
17
+ # fooocus: a1111 (Civitai)
18
+ KSAMPLER = {
19
+ "euler": "Euler",
20
+ "euler_ancestral": "Euler a",
21
+ "heun": "Heun",
22
+ "heunpp2": "",
23
+ "dpm_2": "DPM2",
24
+ "dpm_2_ancestral": "DPM2 a",
25
+ "lms": "LMS",
26
+ "dpm_fast": "DPM fast",
27
+ "dpm_adaptive": "DPM adaptive",
28
+ "dpmpp_2s_ancestral": "DPM++ 2S a",
29
+ "dpmpp_sde": "DPM++ SDE",
30
+ "dpmpp_sde_gpu": "DPM++ SDE",
31
+ "dpmpp_2m": "DPM++ 2M",
32
+ "dpmpp_2m_sde": "DPM++ 2M SDE",
33
+ "dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
34
+ "dpmpp_3m_sde": "",
35
+ "dpmpp_3m_sde_gpu": "",
36
+ "ddpm": "",
37
+ "lcm": "LCM"
38
+ }
39
+
40
+ SAMPLER_EXTRA = {
41
+ "ddim": "DDIM",
42
+ "uni_pc": "UniPC",
43
+ "uni_pc_bh2": ""
44
+ }
45
+
46
+ SAMPLERS = KSAMPLER | SAMPLER_EXTRA
47
+
48
+ KSAMPLER_NAMES = list(KSAMPLER.keys())
49
+
50
+ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"]
51
+ SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
52
+
53
+ sampler_list = SAMPLER_NAMES
54
+ scheduler_list = SCHEDULER_NAMES
55
+
56
+ refiner_swap_method = 'joint'
57
+
58
+ cn_ip = "ImagePrompt"
59
+ cn_ip_face = "FaceSwap"
60
+ cn_canny = "PyraCanny"
61
+ cn_cpds = "CPDS"
62
+
63
+ ip_list = [cn_ip, cn_canny, cn_cpds, cn_ip_face]
64
+ default_ip = cn_ip
65
+
66
+ default_parameters = {
67
+ cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0)
68
+ } # stop, weight
69
+
70
+ output_formats = ['png', 'jpg', 'webp']
71
+
72
+ inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6']
73
+ inpaint_option_default = 'Inpaint or Outpaint (default)'
74
+ inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)'
75
+ inpaint_option_modify = 'Modify Content (add objects, change background, etc.)'
76
+ inpaint_options = [inpaint_option_default, inpaint_option_detail, inpaint_option_modify]
77
+
78
+ desc_type_photo = 'Photograph'
79
+ desc_type_anime = 'Art/Anime'
80
+
81
+
82
+ class MetadataScheme(Enum):
83
+ FOOOCUS = 'fooocus'
84
+ A1111 = 'a1111'
85
+
86
+
87
+ metadata_scheme = [
88
+ (f'{MetadataScheme.FOOOCUS.value} (json)', MetadataScheme.FOOOCUS.value),
89
+ (f'{MetadataScheme.A1111.value} (plain text)', MetadataScheme.A1111.value),
90
+ ]
91
+
92
+ lora_count = 5
93
+
94
+ controlnet_image_count = 4
95
+
96
+
97
+ class Steps(IntEnum):
98
+ QUALITY = 60
99
+ SPEED = 30
100
+ EXTREME_SPEED = 8
101
+
102
+
103
+ class StepsUOV(IntEnum):
104
+ QUALITY = 36
105
+ SPEED = 18
106
+ EXTREME_SPEED = 8
107
+
108
+
109
+ class Performance(Enum):
110
+ QUALITY = 'Quality'
111
+ SPEED = 'Speed'
112
+ EXTREME_SPEED = 'Extreme Speed'
113
+
114
+ @classmethod
115
+ def list(cls) -> list:
116
+ return list(map(lambda c: c.value, cls))
117
+
118
+ def steps(self) -> int | None:
119
+ return Steps[self.name].value if Steps[self.name] else None
120
+
121
+ def steps_uov(self) -> int | None:
122
+ return StepsUOV[self.name].value if StepsUOV[self.name] else None
123
+
124
+
125
+ performance_selections = Performance.list()
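As an illustrative aside (a standalone sketch, not one of the uploaded files): Performance resolves its step count by looking up the member of the same name in Steps (and StepsUOV for upscale/variation), so the enums only need to agree on member names.

from enum import Enum, IntEnum

class Steps(IntEnum):
    QUALITY = 60
    SPEED = 30
    EXTREME_SPEED = 8

class Performance(Enum):
    QUALITY = 'Quality'
    SPEED = 'Speed'
    EXTREME_SPEED = 'Extreme Speed'

    def steps(self) -> int:
        return Steps[self.name].value  # the shared member name is the link between the enums

print(Performance.SPEED.steps())  # -> 30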
modules/gradio_hijack.py ADDED
@@ -0,0 +1,480 @@
1
+ """gr.Image() component."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import warnings
6
+ from pathlib import Path
7
+ from typing import Any, Literal
8
+
9
+ import numpy as np
10
+ import PIL
11
+ import PIL.ImageOps
12
+ import gradio.routes
13
+ import importlib
14
+
15
+ from gradio_client import utils as client_utils
16
+ from gradio_client.documentation import document, set_documentation_group
17
+ from gradio_client.serializing import ImgSerializable
18
+ from PIL import Image as _Image # using _ to minimize namespace pollution
19
+
20
+ from gradio import processing_utils, utils
21
+ from gradio.components.base import IOComponent, _Keywords, Block
22
+ from gradio.deprecation import warn_style_method_deprecation
23
+ from gradio.events import (
24
+ Changeable,
25
+ Clearable,
26
+ Editable,
27
+ EventListenerMethod,
28
+ Selectable,
29
+ Streamable,
30
+ Uploadable,
31
+ )
32
+ from gradio.interpretation import TokenInterpretable
33
+
34
+ set_documentation_group("component")
35
+ _Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843
36
+
37
+
38
+ @document()
39
+ class Image(
40
+ Editable,
41
+ Clearable,
42
+ Changeable,
43
+ Streamable,
44
+ Selectable,
45
+ Uploadable,
46
+ IOComponent,
47
+ ImgSerializable,
48
+ TokenInterpretable,
49
+ ):
50
+ """
51
+ Creates an image component that can be used to upload/draw images (as an input) or display images (as an output).
52
+ Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch` AND source is one of `upload` or `webcam`. In these cases, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
53
+ Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image.
54
+ Examples-format: a {str} filepath to a local file that contains the image.
55
+ Demos: image_mod, image_mod_default_image
56
+ Guides: image-classification-in-pytorch, image-classification-in-tensorflow, image-classification-with-vision-transformers, building-a-pictionary_app, create-your-own-friends-with-a-gan
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ value: str | _Image.Image | np.ndarray | None = None,
62
+ *,
63
+ shape: tuple[int, int] | None = None,
64
+ height: int | None = None,
65
+ width: int | None = None,
66
+ image_mode: Literal[
67
+ "1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"
68
+ ] = "RGB",
69
+ invert_colors: bool = False,
70
+ source: Literal["upload", "webcam", "canvas"] = "upload",
71
+ tool: Literal["editor", "select", "sketch", "color-sketch"] | None = None,
72
+ type: Literal["numpy", "pil", "filepath"] = "numpy",
73
+ label: str | None = None,
74
+ every: float | None = None,
75
+ show_label: bool | None = None,
76
+ show_download_button: bool = True,
77
+ container: bool = True,
78
+ scale: int | None = None,
79
+ min_width: int = 160,
80
+ interactive: bool | None = None,
81
+ visible: bool = True,
82
+ streaming: bool = False,
83
+ elem_id: str | None = None,
84
+ elem_classes: list[str] | str | None = None,
85
+ mirror_webcam: bool = True,
86
+ brush_radius: float | None = None,
87
+ brush_color: str = "#000000",
88
+ mask_opacity: float = 0.7,
89
+ show_share_button: bool | None = None,
90
+ **kwargs,
91
+ ):
92
+ """
93
+ Parameters:
94
+ value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component.
95
+ shape: (width, height) shape to crop and resize image when passed to function. If None, matches input image size. Pass None for either width or height to only crop and resize the other.
96
+ height: Height of the displayed image in pixels.
97
+ width: Width of the displayed image in pixels.
98
+ image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning.
99
+ invert_colors: whether to invert the image as a preprocessing step.
100
+ source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
101
+ tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to create a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
102
+ type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image.
103
+ label: component name in interface.
104
+ every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
105
+ show_label: if True, will display label.
106
+ show_download_button: If True, will display button to download image.
107
+ container: If True, will place the component in a container - providing some extra padding around the border.
108
+ scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer.
109
+ min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
110
+ interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images. If not provided, this is inferred based on whether the component is used as an input or output.
111
+ visible: If False, component will be hidden.
112
+ streaming: If True when used in a `live` interface, will automatically stream webcam feed. Only valid if source is 'webcam'.
113
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
114
+ elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
115
+ mirror_webcam: If True webcam will be mirrored. Default is True.
116
+ brush_radius: Size of the brush for Sketch. Default is None which chooses a sensible default
117
+ brush_color: Color of the brush for Sketch as hex string. Default is "#000000".
118
+ mask_opacity: Opacity of mask drawn on image, as a value between 0 and 1.
119
+ show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.
120
+ """
121
+ self.brush_radius = brush_radius
122
+ self.brush_color = brush_color
123
+ self.mask_opacity = mask_opacity
124
+ self.mirror_webcam = mirror_webcam
125
+ valid_types = ["numpy", "pil", "filepath"]
126
+ if type not in valid_types:
127
+ raise ValueError(
128
+ f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
129
+ )
130
+ self.type = type
131
+ self.shape = shape
132
+ self.height = height
133
+ self.width = width
134
+ self.image_mode = image_mode
135
+ valid_sources = ["upload", "webcam", "canvas"]
136
+ if source not in valid_sources:
137
+ raise ValueError(
138
+ f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}"
139
+ )
140
+ self.source = source
141
+ if tool is None:
142
+ self.tool = "sketch" if source == "canvas" else "editor"
143
+ else:
144
+ self.tool = tool
145
+ self.invert_colors = invert_colors
146
+ self.streaming = streaming
147
+ self.show_download_button = show_download_button
148
+ if streaming and source != "webcam":
149
+ raise ValueError("Image streaming only available if source is 'webcam'.")
150
+ self.select: EventListenerMethod
151
+ """
152
+ Event listener for when the user clicks on a pixel within the image.
153
+ Uses event data gradio.SelectData to carry `index` to refer to the [x, y] coordinates of the clicked pixel.
154
+ See EventData documentation on how to use this event data.
155
+ """
156
+ self.show_share_button = (
157
+ (utils.get_space() is not None)
158
+ if show_share_button is None
159
+ else show_share_button
160
+ )
161
+ IOComponent.__init__(
162
+ self,
163
+ label=label,
164
+ every=every,
165
+ show_label=show_label,
166
+ container=container,
167
+ scale=scale,
168
+ min_width=min_width,
169
+ interactive=interactive,
170
+ visible=visible,
171
+ elem_id=elem_id,
172
+ elem_classes=elem_classes,
173
+ value=value,
174
+ **kwargs,
175
+ )
176
+ TokenInterpretable.__init__(self)
177
+
178
+ def get_config(self):
179
+ return {
180
+ "image_mode": self.image_mode,
181
+ "shape": self.shape,
182
+ "height": self.height,
183
+ "width": self.width,
184
+ "source": self.source,
185
+ "tool": self.tool,
186
+ "value": self.value,
187
+ "streaming": self.streaming,
188
+ "mirror_webcam": self.mirror_webcam,
189
+ "brush_radius": self.brush_radius,
190
+ "brush_color": self.brush_color,
191
+ "mask_opacity": self.mask_opacity,
192
+ "selectable": self.selectable,
193
+ "show_share_button": self.show_share_button,
194
+ "show_download_button": self.show_download_button,
195
+ **IOComponent.get_config(self),
196
+ }
197
+
198
+ @staticmethod
199
+ def update(
200
+ value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
201
+ height: int | None = None,
202
+ width: int | None = None,
203
+ label: str | None = None,
204
+ show_label: bool | None = None,
205
+ show_download_button: bool | None = None,
206
+ container: bool | None = None,
207
+ scale: int | None = None,
208
+ min_width: int | None = None,
209
+ interactive: bool | None = None,
210
+ visible: bool | None = None,
211
+ brush_radius: float | None = None,
212
+ brush_color: str | None = None,
213
+ mask_opacity: float | None = None,
214
+ show_share_button: bool | None = None,
215
+ ):
216
+ return {
217
+ "height": height,
218
+ "width": width,
219
+ "label": label,
220
+ "show_label": show_label,
221
+ "show_download_button": show_download_button,
222
+ "container": container,
223
+ "scale": scale,
224
+ "min_width": min_width,
225
+ "interactive": interactive,
226
+ "visible": visible,
227
+ "value": value,
228
+ "brush_radius": brush_radius,
229
+ "brush_color": brush_color,
230
+ "mask_opacity": mask_opacity,
231
+ "show_share_button": show_share_button,
232
+ "__type__": "update",
233
+ }
234
+
235
+ def _format_image(
236
+ self, im: _Image.Image | None
237
+ ) -> np.ndarray | _Image.Image | str | None:
238
+ """Helper method to format an image based on self.type"""
239
+ if im is None:
240
+ return im
241
+ fmt = im.format
242
+ if self.type == "pil":
243
+ return im
244
+ elif self.type == "numpy":
245
+ return np.array(im)
246
+ elif self.type == "filepath":
247
+ path = self.pil_to_temp_file(
248
+ im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png"
249
+ )
250
+ self.temp_files.add(path)
251
+ return path
252
+ else:
253
+ raise ValueError(
254
+ "Unknown type: "
255
+ + str(self.type)
256
+ + ". Please choose from: 'numpy', 'pil', 'filepath'."
257
+ )
258
+
259
+ def preprocess(
260
+ self, x: str | dict[str, str]
261
+ ) -> np.ndarray | _Image.Image | str | dict | None:
262
+ """
263
+ Parameters:
264
+ x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data
265
+ Returns:
266
+ image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
267
+ """
268
+ if x is None:
269
+ return x
270
+
271
+ mask = None
272
+
273
+ if self.tool == "sketch" and self.source in ["upload", "webcam"]:
274
+ if isinstance(x, dict):
275
+ x, mask = x["image"], x["mask"]
276
+
277
+ assert isinstance(x, str)
278
+ im = processing_utils.decode_base64_to_image(x)
279
+ with warnings.catch_warnings():
280
+ warnings.simplefilter("ignore")
281
+ im = im.convert(self.image_mode)
282
+ if self.shape is not None:
283
+ im = processing_utils.resize_and_crop(im, self.shape)
284
+ if self.invert_colors:
285
+ im = PIL.ImageOps.invert(im)
286
+ if (
287
+ self.source == "webcam"
288
+ and self.mirror_webcam is True
289
+ and self.tool != "color-sketch"
290
+ ):
291
+ im = PIL.ImageOps.mirror(im)
292
+
293
+ if self.tool == "sketch" and self.source in ["upload", "webcam"]:
294
+ if mask is not None:
295
+ mask_im = processing_utils.decode_base64_to_image(mask)
296
+ if mask_im.mode == "RGBA": # whiten any opaque pixels in the mask
297
+ alpha_data = mask_im.getchannel("A").convert("L")
298
+ mask_im = _Image.merge("RGB", [alpha_data, alpha_data, alpha_data])
299
+ return {
300
+ "image": self._format_image(im),
301
+ "mask": self._format_image(mask_im),
302
+ }
303
+ else:
304
+ return {
305
+ "image": self._format_image(im),
306
+ "mask": None,
307
+ }
308
+
309
+ return self._format_image(im)
310
+
311
+ def postprocess(
312
+ self, y: np.ndarray | _Image.Image | str | Path | None
313
+ ) -> str | None:
314
+ """
315
+ Parameters:
316
+ y: image as a numpy array, PIL Image, string/Path filepath, or string URL
317
+ Returns:
318
+ base64 url data
319
+ """
320
+ if y is None:
321
+ return None
322
+ if isinstance(y, np.ndarray):
323
+ return processing_utils.encode_array_to_base64(y)
324
+ elif isinstance(y, _Image.Image):
325
+ return processing_utils.encode_pil_to_base64(y)
326
+ elif isinstance(y, (str, Path)):
327
+ return client_utils.encode_url_or_file_to_base64(y)
328
+ else:
329
+ raise ValueError("Cannot process this value as an Image")
330
+
331
+ def set_interpret_parameters(self, segments: int = 16):
332
+ """
333
+ Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
334
+ Parameters:
335
+ segments: Number of interpretation segments to split image into.
336
+ """
337
+ self.interpretation_segments = segments
338
+ return self
339
+
340
+ def _segment_by_slic(self, x):
341
+ """
342
+ Helper method that segments an image into superpixels using slic.
343
+ Parameters:
344
+ x: base64 representation of an image
345
+ """
346
+ x = processing_utils.decode_base64_to_image(x)
347
+ if self.shape is not None:
348
+ x = processing_utils.resize_and_crop(x, self.shape)
349
+ resized_and_cropped_image = np.array(x)
350
+ try:
351
+ from skimage.segmentation import slic
352
+ except (ImportError, ModuleNotFoundError) as err:
353
+ raise ValueError(
354
+ "Error: running this interpretation for images requires scikit-image, please install it first."
355
+ ) from err
356
+ try:
357
+ segments_slic = slic(
358
+ resized_and_cropped_image,
359
+ self.interpretation_segments,
360
+ compactness=10,
361
+ sigma=1,
362
+ start_label=1,
363
+ )
364
+ except TypeError: # For skimage 0.16 and older
365
+ segments_slic = slic(
366
+ resized_and_cropped_image,
367
+ self.interpretation_segments,
368
+ compactness=10,
369
+ sigma=1,
370
+ )
371
+ return segments_slic, resized_and_cropped_image
372
+
373
+ def tokenize(self, x):
374
+ """
375
+ Segments image into tokens, masks, and leave-one-out-tokens
376
+ Parameters:
377
+ x: base64 representation of an image
378
+ Returns:
379
+ tokens: list of tokens, used by the get_masked_input() method
380
+ leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
381
+ masks: list of masks, used by the get_interpretation_neighbors() method
382
+ """
383
+ segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
384
+ tokens, masks, leave_one_out_tokens = [], [], []
385
+ replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
386
+ for segment_value in np.unique(segments_slic):
387
+ mask = segments_slic == segment_value
388
+ image_screen = np.copy(resized_and_cropped_image)
389
+ image_screen[segments_slic == segment_value] = replace_color
390
+ leave_one_out_tokens.append(
391
+ processing_utils.encode_array_to_base64(image_screen)
392
+ )
393
+ token = np.copy(resized_and_cropped_image)
394
+ token[segments_slic != segment_value] = 0
395
+ tokens.append(token)
396
+ masks.append(mask)
397
+ return tokens, leave_one_out_tokens, masks
398
+
399
+ def get_masked_inputs(self, tokens, binary_mask_matrix):
400
+ masked_inputs = []
401
+ for binary_mask_vector in binary_mask_matrix:
402
+ masked_input = np.zeros_like(tokens[0], dtype=int)
403
+ for token, b in zip(tokens, binary_mask_vector):
404
+ masked_input = masked_input + token * int(b)
405
+ masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
406
+ return masked_inputs
407
+
408
+ def get_interpretation_scores(
409
+ self, x, neighbors, scores, masks, tokens=None, **kwargs
410
+ ) -> list[list[float]]:
411
+ """
412
+ Returns:
413
+ A 2D array representing the interpretation score of each pixel of the image.
414
+ """
415
+ x = processing_utils.decode_base64_to_image(x)
416
+ if self.shape is not None:
417
+ x = processing_utils.resize_and_crop(x, self.shape)
418
+ x = np.array(x)
419
+ output_scores = np.zeros((x.shape[0], x.shape[1]))
420
+
421
+ for score, mask in zip(scores, masks):
422
+ output_scores += score * mask
423
+
424
+ max_val, min_val = np.max(output_scores), np.min(output_scores)
425
+ if max_val > 0:
426
+ output_scores = (output_scores - min_val) / (max_val - min_val)
427
+ return output_scores.tolist()
428
+
429
+ def style(self, *, height: int | None = None, width: int | None = None, **kwargs):
430
+ """
431
+ This method is deprecated. Please set these arguments in the constructor instead.
432
+ """
433
+ warn_style_method_deprecation()
434
+ if height is not None:
435
+ self.height = height
436
+ if width is not None:
437
+ self.width = width
438
+ return self
439
+
440
+ def check_streamable(self):
441
+ if self.source != "webcam":
442
+ raise ValueError("Image streaming only available if source is 'webcam'.")
443
+
444
+ def as_example(self, input_data: str | None) -> str:
445
+ if input_data is None:
446
+ return ""
447
+ elif (
448
+ self.root_url
449
+ ): # If an externally hosted image, don't convert to absolute path
450
+ return input_data
451
+ return str(utils.abspath(input_data))
452
+
453
+
454
+ all_components = []
455
+
456
+ if not hasattr(Block, 'original_init'):
457
+ Block.original_init = Block.__init__
458
+
459
+
460
+ def blk_ini(self, *args, **kwargs):
461
+ all_components.append(self)
462
+ return Block.original_init(self, *args, **kwargs)
463
+
464
+
465
+ Block.__init__ = blk_ini
466
+
467
+
468
+ gradio.routes.asyncio = importlib.reload(gradio.routes.asyncio)
469
+
470
+ if not hasattr(gradio.routes.asyncio, 'original_wait_for'):
471
+ gradio.routes.asyncio.original_wait_for = gradio.routes.asyncio.wait_for
472
+
473
+
474
+ def patched_wait_for(fut, timeout):
475
+ del timeout
476
+ return gradio.routes.asyncio.original_wait_for(fut, timeout=65535)
477
+
478
+
479
+ gradio.routes.asyncio.wait_for = patched_wait_for
480
+
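The module-level code above does two things: it records every constructed Gradio `Block` in `all_components`, and it lifts the route timeout by patching `asyncio.wait_for`. A minimal sketch of how the component registry can be consumed (illustrative only; the `demo` app below is hypothetical and the real wiring lives in the webui entry point):

import gradio as gr
import modules.gradio_hijack as grh
from modules.localization import dump_english_config

with gr.Blocks() as demo:
    grh.Image(label='Input Image', source='upload', type='numpy')
    gr.Button('Generate')

# every Block constructed above was appended to grh.all_components by blk_ini,
# so its labels and choices can be exported for translation
dump_english_config(grh.all_components)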
modules/html.py ADDED
@@ -0,0 +1,146 @@
1
+ css = '''
2
+ .loader-container {
3
+ display: flex; /* Use flex to align items horizontally */
4
+ align-items: center; /* Center items vertically within the container */
5
+ white-space: nowrap; /* Prevent line breaks within the container */
6
+ }
7
+
8
+ .loader {
9
+ border: 8px solid #f3f3f3; /* Light grey */
10
+ border-top: 8px solid #3498db; /* Blue */
11
+ border-radius: 50%;
12
+ width: 30px;
13
+ height: 30px;
14
+ animation: spin 2s linear infinite;
15
+ }
16
+
17
+ @keyframes spin {
18
+ 0% { transform: rotate(0deg); }
19
+ 100% { transform: rotate(360deg); }
20
+ }
21
+
22
+ /* Style the progress bar */
23
+ progress {
24
+ appearance: none; /* Remove default styling */
25
+ height: 20px; /* Set the height of the progress bar */
26
+ border-radius: 5px; /* Round the corners of the progress bar */
27
+ background-color: #f3f3f3; /* Light grey background */
28
+ width: 100%;
29
+ }
30
+
31
+ /* Style the progress bar container */
32
+ .progress-container {
33
+ margin-left: 20px;
34
+ margin-right: 20px;
35
+ flex-grow: 1; /* Allow the progress container to take up remaining space */
36
+ }
37
+
38
+ /* Set the color of the progress bar fill */
39
+ progress::-webkit-progress-value {
40
+ background-color: #3498db; /* Blue color for the fill */
41
+ }
42
+
43
+ progress::-moz-progress-bar {
44
+ background-color: #3498db; /* Blue color for the fill in Firefox */
45
+ }
46
+
47
+ /* Style the text on the progress bar */
48
+ progress::after {
49
+ content: attr(value) '%'; /* Display the progress value followed by '%' */
50
+ position: absolute;
51
+ top: 50%;
52
+ left: 50%;
53
+ transform: translate(-50%, -50%);
54
+ color: white; /* Set text color */
55
+ font-size: 14px; /* Set font size */
56
+ }
57
+
58
+ /* Style other texts */
59
+ .loader-container > span {
60
+ margin-left: 5px; /* Add spacing between the progress bar and the text */
61
+ }
62
+
63
+ .progress-bar > .generating {
64
+ display: none !important;
65
+ }
66
+
67
+ .progress-bar{
68
+ height: 30px !important;
69
+ }
70
+
71
+ .type_row{
72
+ height: 80px !important;
73
+ }
74
+
75
+ .type_row_half{
76
+ height: 32px !important;
77
+ }
78
+
79
+ .scroll-hide{
80
+ resize: none !important;
81
+ }
82
+
83
+ .refresh_button{
84
+ border: none !important;
85
+ background: none !important;
86
+ font-size: none !important;
87
+ box-shadow: none !important;
88
+ }
89
+
90
+ .advanced_check_row{
91
+ width: 250px !important;
92
+ }
93
+
94
+ .min_check{
95
+ min-width: min(1px, 100%) !important;
96
+ }
97
+
98
+ .resizable_area {
99
+ resize: vertical;
100
+ overflow: auto !important;
101
+ }
102
+
103
+ .aspect_ratios label {
104
+ width: 140px !important;
105
+ }
106
+
107
+ .aspect_ratios label span {
108
+ white-space: nowrap !important;
109
+ }
110
+
111
+ .aspect_ratios label input {
112
+ margin-left: -5px !important;
113
+ }
114
+
115
+ .lora_enable label {
116
+ height: 100%;
117
+ }
118
+
119
+ .lora_enable label input {
120
+ margin: auto;
121
+ }
122
+
123
+ .lora_enable label span {
124
+ display: none;
125
+ }
126
+
127
+ @-moz-document url-prefix() {
128
+ .lora_weight input[type=number] {
129
+ width: 80px;
130
+ }
131
+ }
132
+
133
+ '''
134
+ progress_html = '''
135
+ <div class="loader-container">
136
+ <div class="loader"></div>
137
+ <div class="progress-container">
138
+ <progress value="*number*" max="100"></progress>
139
+ </div>
140
+ <span>*text*</span>
141
+ </div>
142
+ '''
143
+
144
+
145
+ def make_progress_html(number, text):
146
+ return progress_html.replace('*number*', str(number)).replace('*text*', text)
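A short usage sketch for the helper above (illustrative; in the actual UI the snippet is streamed into an HTML progress element during generation):

from modules.html import make_progress_html

snippet = make_progress_html(42, 'Sampling step 13/30 ...')
# yields the loader-container markup with the <progress> bar at 42% and the status text
print(snippet)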
modules/inpaint_worker.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from PIL import Image, ImageFilter
5
+ from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
6
+ from modules.upscaler import perform_upscale
7
+ import cv2
8
+
9
+
10
+ inpaint_head_model = None
11
+
12
+
13
+ class InpaintHead(torch.nn.Module):
14
+ def __init__(self, *args, **kwargs):
15
+ super().__init__(*args, **kwargs)
16
+ self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
17
+
18
+ def __call__(self, x):
19
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
20
+ return torch.nn.functional.conv2d(input=x, weight=self.head)
21
+
22
+
23
+ current_task = None
24
+
25
+
26
+ def box_blur(x, k):
27
+ x = Image.fromarray(x)
28
+ x = x.filter(ImageFilter.BoxBlur(k))
29
+ return np.array(x)
30
+
31
+
32
+ def max_filter_opencv(x, ksize=3):
33
+ # Use OpenCV maximum filter
34
+ # Make sure the input type is int16
35
+ return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16))
36
+
37
+
38
+ def morphological_open(x):
39
+ # Convert array to int16 type via threshold operation
40
+ x_int16 = np.zeros_like(x, dtype=np.int16)
41
+ x_int16[x > 127] = 256
42
+
43
+ for i in range(32):
44
+ # Use int16 type to avoid overflow
45
+ maxed = max_filter_opencv(x_int16, ksize=3) - 8
46
+ x_int16 = np.maximum(maxed, x_int16)
47
+
48
+ # Clip negative values to 0 and convert back to uint8 type
49
+ x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
50
+ return x_uint8
51
+
52
+
53
+ def up255(x, t=0):
54
+ y = np.zeros_like(x).astype(np.uint8)
55
+ y[x > t] = 255
56
+ return y
57
+
58
+
59
+ def imsave(x, path):
60
+ x = Image.fromarray(x)
61
+ x.save(path)
62
+
63
+
64
+ def regulate_abcd(x, a, b, c, d):
65
+ H, W = x.shape[:2]
66
+ if a < 0:
67
+ a = 0
68
+ if a > H:
69
+ a = H
70
+ if b < 0:
71
+ b = 0
72
+ if b > H:
73
+ b = H
74
+ if c < 0:
75
+ c = 0
76
+ if c > W:
77
+ c = W
78
+ if d < 0:
79
+ d = 0
80
+ if d > W:
81
+ d = W
82
+ return int(a), int(b), int(c), int(d)
83
+
84
+
85
+ def compute_initial_abcd(x):
86
+ indices = np.where(x)
87
+ a = np.min(indices[0])
88
+ b = np.max(indices[0])
89
+ c = np.min(indices[1])
90
+ d = np.max(indices[1])
91
+ abp = (b + a) // 2
92
+ abm = (b - a) // 2
93
+ cdp = (d + c) // 2
94
+ cdm = (d - c) // 2
95
+ l = int(max(abm, cdm) * 1.15)
96
+ a = abp - l
97
+ b = abp + l + 1
98
+ c = cdp - l
99
+ d = cdp + l + 1
100
+ a, b, c, d = regulate_abcd(x, a, b, c, d)
101
+ return a, b, c, d
102
+
103
+
104
+ def solve_abcd(x, a, b, c, d, k):
105
+ k = float(k)
106
+ assert 0.0 <= k <= 1.0
107
+
108
+ H, W = x.shape[:2]
109
+ if k == 1.0:
110
+ return 0, H, 0, W
111
+ while True:
112
+ if b - a >= H * k and d - c >= W * k:
113
+ break
114
+
115
+ add_h = (b - a) < (d - c)
116
+ add_w = not add_h
117
+
118
+ if b - a == H:
119
+ add_w = True
120
+
121
+ if d - c == W:
122
+ add_h = True
123
+
124
+ if add_h:
125
+ a -= 1
126
+ b += 1
127
+
128
+ if add_w:
129
+ c -= 1
130
+ d += 1
131
+
132
+ a, b, c, d = regulate_abcd(x, a, b, c, d)
133
+ return a, b, c, d
134
+
135
+
136
+ def fooocus_fill(image, mask):
137
+ current_image = image.copy()
138
+ raw_image = image.copy()
139
+ area = np.where(mask < 127)
140
+ store = raw_image[area]
141
+
142
+ for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
143
+ for _ in range(repeats):
144
+ current_image = box_blur(current_image, k)
145
+ current_image[area] = store
146
+
147
+ return current_image
148
+
149
+
150
+ class InpaintWorker:
151
+ def __init__(self, image, mask, use_fill=True, k=0.618):
152
+ a, b, c, d = compute_initial_abcd(mask > 0)
153
+ a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)
154
+
155
+ # interested area
156
+ self.interested_area = (a, b, c, d)
157
+ self.interested_mask = mask[a:b, c:d]
158
+ self.interested_image = image[a:b, c:d]
159
+
160
+ # super resolution
161
+ if get_image_shape_ceil(self.interested_image) < 1024:
162
+ self.interested_image = perform_upscale(self.interested_image)
163
+
164
+ # resize to make images ready for diffusion
165
+ self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
166
+ self.interested_fill = self.interested_image.copy()
167
+ H, W, C = self.interested_image.shape
168
+
169
+ # process mask
170
+ self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)
171
+
172
+ # compute filling
173
+ if use_fill:
174
+ self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)
175
+
176
+ # soft pixels
177
+ self.mask = morphological_open(mask)
178
+ self.image = image
179
+
180
+ # ending
181
+ self.latent = None
182
+ self.latent_after_swap = None
183
+ self.swapped = False
184
+ self.latent_mask = None
185
+ self.inpaint_head_feature = None
186
+ return
187
+
188
+ def load_latent(self, latent_fill, latent_mask, latent_swap=None):
189
+ self.latent = latent_fill
190
+ self.latent_mask = latent_mask
191
+ self.latent_after_swap = latent_swap
192
+ return
193
+
194
+ def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
195
+ global inpaint_head_model
196
+
197
+ if inpaint_head_model is None:
198
+ inpaint_head_model = InpaintHead()
199
+ sd = torch.load(inpaint_head_model_path, map_location='cpu')
200
+ inpaint_head_model.load_state_dict(sd)
201
+
202
+ feed = torch.cat([
203
+ inpaint_latent_mask,
204
+ model.model.process_latent_in(inpaint_latent)
205
+ ], dim=1)
206
+
207
+ inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
208
+ inpaint_head_feature = inpaint_head_model(feed)
209
+
210
+ def input_block_patch(h, transformer_options):
211
+ if transformer_options["block"][1] == 0:
212
+ h = h + inpaint_head_feature.to(h)
213
+ return h
214
+
215
+ m = model.clone()
216
+ m.set_model_input_block_patch(input_block_patch)
217
+ return m
218
+
219
+ def swap(self):
220
+ if self.swapped:
221
+ return
222
+
223
+ if self.latent is None:
224
+ return
225
+
226
+ if self.latent_after_swap is None:
227
+ return
228
+
229
+ self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
230
+ self.swapped = True
231
+ return
232
+
233
+ def unswap(self):
234
+ if not self.swapped:
235
+ return
236
+
237
+ if self.latent is None:
238
+ return
239
+
240
+ if self.latent_after_swap is None:
241
+ return
242
+
243
+ self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
244
+ self.swapped = False
245
+ return
246
+
247
+ def color_correction(self, img):
248
+ fg = img.astype(np.float32)
249
+ bg = self.image.copy().astype(np.float32)
250
+ w = self.mask[:, :, None].astype(np.float32) / 255.0
251
+ y = fg * w + bg * (1 - w)
252
+ return y.clip(0, 255).astype(np.uint8)
253
+
254
+ def post_process(self, img):
255
+ a, b, c, d = self.interested_area
256
+ content = resample_image(img, d - c, b - a)
257
+ result = self.image.copy()
258
+ result[a:b, c:d] = content
259
+ result = self.color_correction(result)
260
+ return result
261
+
262
+ def visualize_mask_processing(self):
263
+ return [self.interested_fill, self.interested_mask, self.interested_image]
264
+
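A minimal sketch of how the crop window used by `InpaintWorker` is derived from a mask (toy values, illustrative only):

import numpy as np
from modules.inpaint_worker import compute_initial_abcd, solve_abcd

mask = np.zeros((512, 512), dtype=np.uint8)
mask[200:260, 300:380] = 255  # hypothetical user-drawn inpaint region

a, b, c, d = compute_initial_abcd(mask > 0)         # padded square box around the masked pixels
a, b, c, d = solve_abcd(mask, a, b, c, d, k=0.618)  # grow until it spans k of each image dimension
print((a, b, c, d))                                 # bounds of the "interested area" that gets diffused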
modules/launch_util.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+ import importlib
3
+ import importlib.util
4
+ import subprocess
5
+ import sys
6
+ import re
7
+ import logging
8
+ import importlib.metadata
9
+ import packaging.version
10
+ from packaging.requirements import Requirement
11
+
12
+
13
+
14
+
15
+ logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
16
+ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
17
+
18
+ re_requirement = re.compile(r"\s*([-\w]+)\s*(?:==\s*([-+.\w]+))?\s*")
19
+
20
+ python = sys.executable
21
+ default_command_live = (os.environ.get('LAUNCH_LIVE_OUTPUT') == "1")
22
+ index_url = os.environ.get('INDEX_URL', "")
23
+
24
+ modules_path = os.path.dirname(os.path.realpath(__file__))
25
+ script_path = os.path.dirname(modules_path)
26
+
27
+
28
+ def is_installed(package):
29
+ try:
30
+ spec = importlib.util.find_spec(package)
31
+ except ModuleNotFoundError:
32
+ return False
33
+
34
+ return spec is not None
35
+
36
+
37
+ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
38
+ if desc is not None:
39
+ print(desc)
40
+
41
+ run_kwargs = {
42
+ "args": command,
43
+ "shell": True,
44
+ "env": os.environ if custom_env is None else custom_env,
45
+ "encoding": 'utf8',
46
+ "errors": 'ignore',
47
+ }
48
+
49
+ if not live:
50
+ run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
51
+
52
+ result = subprocess.run(**run_kwargs)
53
+
54
+ if result.returncode != 0:
55
+ error_bits = [
56
+ f"{errdesc or 'Error running command'}.",
57
+ f"Command: {command}",
58
+ f"Error code: {result.returncode}",
59
+ ]
60
+ if result.stdout:
61
+ error_bits.append(f"stdout: {result.stdout}")
62
+ if result.stderr:
63
+ error_bits.append(f"stderr: {result.stderr}")
64
+ raise RuntimeError("\n".join(error_bits))
65
+
66
+ return (result.stdout or "")
67
+
68
+
69
+ def run_pip(command, desc=None, live=default_command_live):
70
+ try:
71
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
72
+ return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}",
73
+ errdesc=f"Couldn't install {desc}", live=live)
74
+ except Exception as e:
75
+ print(e)
76
+ print(f'CMD Failed {desc}: {command}')
77
+ return None
78
+
79
+
80
+ def requirements_met(requirements_file):
81
+ with open(requirements_file, "r", encoding="utf8") as file:
82
+ for line in file:
83
+ line = line.strip()
84
+ if line == "" or line.startswith('#'):
85
+ continue
86
+
87
+ requirement = Requirement(line)
88
+ package = requirement.name
89
+
90
+ try:
91
+ version_installed = importlib.metadata.version(package)
92
+ installed_version = packaging.version.parse(version_installed)
93
+
94
+ # Check if the installed version satisfies the requirement
95
+ if installed_version not in requirement.specifier:
96
+ print(f"Version mismatch for {package}: Installed version {version_installed} does not meet requirement {requirement}")
97
+ return False
98
+ except Exception as e:
99
+ print(f"Error checking version for {package}: {e}")
100
+ return False
101
+
102
+ return True
103
+
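A minimal sketch of how these helpers are typically combined at startup (illustrative; the package and requirements file names are placeholders, and the real logic lives in the launcher script):

from modules.launch_util import is_installed, run_pip, requirements_met

if not is_installed('xformers'):
    run_pip('install xformers', desc='xformers')

if not requirements_met('requirements_versions.txt'):
    run_pip('install -r requirements_versions.txt', desc='requirements')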
modules/localization.py ADDED
@@ -0,0 +1,60 @@
1
+ import json
2
+ import os
3
+
4
+
5
+ current_translation = {}
6
+ localization_root = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'language')
7
+
8
+
9
+ def localization_js(filename):
10
+ global current_translation
11
+
12
+ if isinstance(filename, str):
13
+ full_name = os.path.abspath(os.path.join(localization_root, filename + '.json'))
14
+ if os.path.exists(full_name):
15
+ try:
16
+ with open(full_name, encoding='utf-8') as f:
17
+ current_translation = json.load(f)
18
+ assert isinstance(current_translation, dict)
19
+ for k, v in current_translation.items():
20
+ assert isinstance(k, str)
21
+ assert isinstance(v, str)
22
+ except Exception as e:
23
+ print(str(e))
24
+ print(f'Failed to load localization file {full_name}')
25
+
26
+ # current_translation = {k: 'XXX' for k in current_translation.keys()} # use this to see if all texts are covered
27
+
28
+ return f"window.localization = {json.dumps(current_translation)}"
29
+
30
+
31
+ def dump_english_config(components):
32
+ all_texts = []
33
+ for c in components:
34
+ label = getattr(c, 'label', None)
35
+ value = getattr(c, 'value', None)
36
+ choices = getattr(c, 'choices', None)
37
+ info = getattr(c, 'info', None)
38
+
39
+ if isinstance(label, str):
40
+ all_texts.append(label)
41
+ if isinstance(value, str):
42
+ all_texts.append(value)
43
+ if isinstance(info, str):
44
+ all_texts.append(info)
45
+ if isinstance(choices, list):
46
+ for x in choices:
47
+ if isinstance(x, str):
48
+ all_texts.append(x)
49
+ if isinstance(x, tuple):
50
+ for y in x:
51
+ if isinstance(y, str):
52
+ all_texts.append(y)
53
+
54
+ config_dict = {k: k for k in all_texts if k != "" and 'progress-container' not in k}
55
+ full_name = os.path.abspath(os.path.join(localization_root, 'en.json'))
56
+
57
+ with open(full_name, "w", encoding="utf-8") as json_file:
58
+ json.dump(config_dict, json_file, indent=4)
59
+
60
+ return
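A short sketch of the intended use: the returned JavaScript assigns the translation table to `window.localization`, so the UI layer can embed it in the served page (illustrative only):

from modules.localization import localization_js

js = localization_js('en')            # loads language/en.json when present
head_html = f"<script>{js}</script>"  # embedded into the page head by the UI layer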
modules/lora.py ADDED
@@ -0,0 +1,152 @@
1
+ def match_lora(lora, to_load):
2
+ patch_dict = {}
3
+ loaded_keys = set()
4
+ for x in to_load:
5
+ real_load_key = to_load[x]
6
+ if real_load_key in lora:
7
+ patch_dict[real_load_key] = ('fooocus', lora[real_load_key])
8
+ loaded_keys.add(real_load_key)
9
+ continue
10
+
11
+ alpha_name = "{}.alpha".format(x)
12
+ alpha = None
13
+ if alpha_name in lora.keys():
14
+ alpha = lora[alpha_name].item()
15
+ loaded_keys.add(alpha_name)
16
+
17
+ regular_lora = "{}.lora_up.weight".format(x)
18
+ diffusers_lora = "{}_lora.up.weight".format(x)
19
+ transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
20
+ A_name = None
21
+
22
+ if regular_lora in lora.keys():
23
+ A_name = regular_lora
24
+ B_name = "{}.lora_down.weight".format(x)
25
+ mid_name = "{}.lora_mid.weight".format(x)
26
+ elif diffusers_lora in lora.keys():
27
+ A_name = diffusers_lora
28
+ B_name = "{}_lora.down.weight".format(x)
29
+ mid_name = None
30
+ elif transformers_lora in lora.keys():
31
+ A_name = transformers_lora
32
+ B_name = "{}.lora_linear_layer.down.weight".format(x)
33
+ mid_name = None
34
+
35
+ if A_name is not None:
36
+ mid = None
37
+ if mid_name is not None and mid_name in lora.keys():
38
+ mid = lora[mid_name]
39
+ loaded_keys.add(mid_name)
40
+ patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
41
+ loaded_keys.add(A_name)
42
+ loaded_keys.add(B_name)
43
+
44
+
45
+ ######## loha
46
+ hada_w1_a_name = "{}.hada_w1_a".format(x)
47
+ hada_w1_b_name = "{}.hada_w1_b".format(x)
48
+ hada_w2_a_name = "{}.hada_w2_a".format(x)
49
+ hada_w2_b_name = "{}.hada_w2_b".format(x)
50
+ hada_t1_name = "{}.hada_t1".format(x)
51
+ hada_t2_name = "{}.hada_t2".format(x)
52
+ if hada_w1_a_name in lora.keys():
53
+ hada_t1 = None
54
+ hada_t2 = None
55
+ if hada_t1_name in lora.keys():
56
+ hada_t1 = lora[hada_t1_name]
57
+ hada_t2 = lora[hada_t2_name]
58
+ loaded_keys.add(hada_t1_name)
59
+ loaded_keys.add(hada_t2_name)
60
+
61
+ patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
62
+ loaded_keys.add(hada_w1_a_name)
63
+ loaded_keys.add(hada_w1_b_name)
64
+ loaded_keys.add(hada_w2_a_name)
65
+ loaded_keys.add(hada_w2_b_name)
66
+
67
+
68
+ ######## lokr
69
+ lokr_w1_name = "{}.lokr_w1".format(x)
70
+ lokr_w2_name = "{}.lokr_w2".format(x)
71
+ lokr_w1_a_name = "{}.lokr_w1_a".format(x)
72
+ lokr_w1_b_name = "{}.lokr_w1_b".format(x)
73
+ lokr_t2_name = "{}.lokr_t2".format(x)
74
+ lokr_w2_a_name = "{}.lokr_w2_a".format(x)
75
+ lokr_w2_b_name = "{}.lokr_w2_b".format(x)
76
+
77
+ lokr_w1 = None
78
+ if lokr_w1_name in lora.keys():
79
+ lokr_w1 = lora[lokr_w1_name]
80
+ loaded_keys.add(lokr_w1_name)
81
+
82
+ lokr_w2 = None
83
+ if lokr_w2_name in lora.keys():
84
+ lokr_w2 = lora[lokr_w2_name]
85
+ loaded_keys.add(lokr_w2_name)
86
+
87
+ lokr_w1_a = None
88
+ if lokr_w1_a_name in lora.keys():
89
+ lokr_w1_a = lora[lokr_w1_a_name]
90
+ loaded_keys.add(lokr_w1_a_name)
91
+
92
+ lokr_w1_b = None
93
+ if lokr_w1_b_name in lora.keys():
94
+ lokr_w1_b = lora[lokr_w1_b_name]
95
+ loaded_keys.add(lokr_w1_b_name)
96
+
97
+ lokr_w2_a = None
98
+ if lokr_w2_a_name in lora.keys():
99
+ lokr_w2_a = lora[lokr_w2_a_name]
100
+ loaded_keys.add(lokr_w2_a_name)
101
+
102
+ lokr_w2_b = None
103
+ if lokr_w2_b_name in lora.keys():
104
+ lokr_w2_b = lora[lokr_w2_b_name]
105
+ loaded_keys.add(lokr_w2_b_name)
106
+
107
+ lokr_t2 = None
108
+ if lokr_t2_name in lora.keys():
109
+ lokr_t2 = lora[lokr_t2_name]
110
+ loaded_keys.add(lokr_t2_name)
111
+
112
+ if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
113
+ patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
114
+
115
+ #glora
116
+ a1_name = "{}.a1.weight".format(x)
117
+ a2_name = "{}.a2.weight".format(x)
118
+ b1_name = "{}.b1.weight".format(x)
119
+ b2_name = "{}.b2.weight".format(x)
120
+ if a1_name in lora:
121
+ patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
122
+ loaded_keys.add(a1_name)
123
+ loaded_keys.add(a2_name)
124
+ loaded_keys.add(b1_name)
125
+ loaded_keys.add(b2_name)
126
+
127
+ w_norm_name = "{}.w_norm".format(x)
128
+ b_norm_name = "{}.b_norm".format(x)
129
+ w_norm = lora.get(w_norm_name, None)
130
+ b_norm = lora.get(b_norm_name, None)
131
+
132
+ if w_norm is not None:
133
+ loaded_keys.add(w_norm_name)
134
+ patch_dict[to_load[x]] = ("diff", (w_norm,))
135
+ if b_norm is not None:
136
+ loaded_keys.add(b_norm_name)
137
+ patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
138
+
139
+ diff_name = "{}.diff".format(x)
140
+ diff_weight = lora.get(diff_name, None)
141
+ if diff_weight is not None:
142
+ patch_dict[to_load[x]] = ("diff", (diff_weight,))
143
+ loaded_keys.add(diff_name)
144
+
145
+ diff_bias_name = "{}.diff_b".format(x)
146
+ diff_bias = lora.get(diff_bias_name, None)
147
+ if diff_bias is not None:
148
+ patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
149
+ loaded_keys.add(diff_bias_name)
150
+
151
+ remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
152
+ return patch_dict, remaining_dict
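A toy example of the key matching performed by `match_lora` (key names and shapes are hypothetical; only the regular `lora_up`/`lora_down` branch is exercised):

import torch
from modules.lora import match_lora

lora_sd = {
    'unet.block.lora_up.weight': torch.zeros(4, 2),
    'unet.block.lora_down.weight': torch.zeros(2, 4),
    'unet.block.alpha': torch.tensor(8.0),
}
to_load = {'unet.block': 'diffusion_model.block.weight'}  # LoRA key prefix -> model weight key

patches, leftover = match_lora(lora_sd, to_load)
# patches['diffusion_model.block.weight'] == ('lora', (up, down, 8.0, None)); leftover is empty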
modules/meta_parser.py ADDED
@@ -0,0 +1,573 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from abc import ABC, abstractmethod
5
+ from pathlib import Path
6
+
7
+ import gradio as gr
8
+ from PIL import Image
9
+
10
+ import fooocus_version
11
+ import modules.config
12
+ import modules.sdxl_styles
13
+ from modules.flags import MetadataScheme, Performance, Steps
14
+ from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS
15
+ from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, calculate_sha256
16
+
17
+ re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
18
+ re_param = re.compile(re_param_code)
19
+ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
20
+
21
+ hash_cache = {}
22
+
23
+
24
+ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
25
+ loaded_parameter_dict = raw_metadata
26
+ if isinstance(raw_metadata, str):
27
+ loaded_parameter_dict = json.loads(raw_metadata)
28
+ assert isinstance(loaded_parameter_dict, dict)
29
+
30
+ results = [len(loaded_parameter_dict) > 0, 1]
31
+
32
+ get_str('prompt', 'Prompt', loaded_parameter_dict, results)
33
+ get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
34
+ get_list('styles', 'Styles', loaded_parameter_dict, results)
35
+ get_str('performance', 'Performance', loaded_parameter_dict, results)
36
+ get_steps('steps', 'Steps', loaded_parameter_dict, results)
37
+ get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
38
+ get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
39
+ get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results)
40
+ get_float('sharpness', 'Sharpness', loaded_parameter_dict, results)
41
+ get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results)
42
+ get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results)
43
+ get_float('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results)
44
+ get_str('base_model', 'Base Model', loaded_parameter_dict, results)
45
+ get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results)
46
+ get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
47
+ get_str('sampler', 'Sampler', loaded_parameter_dict, results)
48
+ get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
49
+ get_seed('seed', 'Seed', loaded_parameter_dict, results)
50
+
51
+ if is_generating:
52
+ results.append(gr.update())
53
+ else:
54
+ results.append(gr.update(visible=True))
55
+
56
+ results.append(gr.update(visible=False))
57
+
58
+ get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)
59
+
60
+ for i in range(modules.config.default_max_lora_number):
61
+ get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
62
+
63
+ return results
64
+
65
+
66
+ def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
67
+ try:
68
+ h = source_dict.get(key, source_dict.get(fallback, default))
69
+ assert isinstance(h, str)
70
+ results.append(h)
71
+ except:
72
+ results.append(gr.update())
73
+
74
+
75
+ def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
76
+ try:
77
+ h = source_dict.get(key, source_dict.get(fallback, default))
78
+ h = eval(h)
79
+ assert isinstance(h, list)
80
+ results.append(h)
81
+ except:
82
+ results.append(gr.update())
83
+
84
+
85
+ def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
86
+ try:
87
+ h = source_dict.get(key, source_dict.get(fallback, default))
88
+ assert h is not None
89
+ h = float(h)
90
+ results.append(h)
91
+ except:
92
+ results.append(gr.update())
93
+
94
+
95
+ def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
96
+ try:
97
+ h = source_dict.get(key, source_dict.get(fallback, default))
98
+ assert h is not None
99
+ h = int(h)
100
+ # if not in steps or in steps and performance is not the same
101
+ if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', '_').casefold():
102
+ results.append(h)
103
+ return
104
+ results.append(-1)
105
+ except:
106
+ results.append(-1)
107
+
108
+
109
+ def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
110
+ try:
111
+ h = source_dict.get(key, source_dict.get(fallback, default))
112
+ width, height = eval(h)
113
+ formatted = modules.config.add_ratio(f'{width}*{height}')
114
+ if formatted in modules.config.available_aspect_ratios:
115
+ results.append(formatted)
116
+ results.append(-1)
117
+ results.append(-1)
118
+ else:
119
+ results.append(gr.update())
120
+ results.append(int(width))
121
+ results.append(int(height))
122
+ except:
123
+ results.append(gr.update())
124
+ results.append(gr.update())
125
+ results.append(gr.update())
126
+
127
+
128
+ def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
129
+ try:
130
+ h = source_dict.get(key, source_dict.get(fallback, default))
131
+ assert h is not None
132
+ h = int(h)
133
+ results.append(False)
134
+ results.append(h)
135
+ except:
136
+ results.append(gr.update())
137
+ results.append(gr.update())
138
+
139
+
140
+ def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
141
+ try:
142
+ h = source_dict.get(key, source_dict.get(fallback, default))
143
+ p, n, e = eval(h)
144
+ results.append(float(p))
145
+ results.append(float(n))
146
+ results.append(float(e))
147
+ except:
148
+ results.append(gr.update())
149
+ results.append(gr.update())
150
+ results.append(gr.update())
151
+
152
+
153
+ def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
154
+ try:
155
+ h = source_dict.get(key, source_dict.get(fallback, default))
156
+ b1, b2, s1, s2 = eval(h)
157
+ results.append(True)
158
+ results.append(float(b1))
159
+ results.append(float(b2))
160
+ results.append(float(s1))
161
+ results.append(float(s2))
162
+ except:
163
+ results.append(False)
164
+ results.append(gr.update())
165
+ results.append(gr.update())
166
+ results.append(gr.update())
167
+ results.append(gr.update())
168
+
169
+
170
+ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
171
+ try:
172
+ n, w = source_dict.get(key, source_dict.get(fallback)).split(' : ')
173
+ w = float(w)
174
+ results.append(True)
175
+ results.append(n)
176
+ results.append(w)
177
+ except:
178
+ results.append(True)
179
+ results.append('None')
180
+ results.append(1)
181
+
182
+
183
+ def get_sha256(filepath):
184
+ global hash_cache
185
+ if filepath not in hash_cache:
186
+ hash_cache[filepath] = calculate_sha256(filepath)
187
+
188
+ return hash_cache[filepath]
189
+
190
+
191
+ def parse_meta_from_preset(preset_content):
192
+ assert isinstance(preset_content, dict)
193
+ preset_prepared = {}
194
+ items = preset_content
195
+
196
+ for settings_key, meta_key in modules.config.possible_preset_keys.items():
197
+ if settings_key == "default_loras":
198
+ loras = getattr(modules.config, settings_key)
199
+ if settings_key in items:
200
+ loras = items[settings_key]
201
+ for index, lora in enumerate(loras[:5]):
202
+ preset_prepared[f'lora_combined_{index + 1}'] = ' : '.join(map(str, lora))
203
+ elif settings_key == "default_aspect_ratio":
204
+ if settings_key in items and items[settings_key] is not None:
205
+ default_aspect_ratio = items[settings_key]
206
+ width, height = default_aspect_ratio.split('*')
207
+ else:
208
+ default_aspect_ratio = getattr(modules.config, settings_key)
209
+ width, height = default_aspect_ratio.split('×')
210
+ height = height[:height.index(" ")]
211
+ preset_prepared[meta_key] = (width, height)
212
+ else:
213
+ preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[
214
+ settings_key] is not None else getattr(modules.config, settings_key)
215
+
216
+ if settings_key == "default_styles" or settings_key == "default_aspect_ratio":
217
+ preset_prepared[meta_key] = str(preset_prepared[meta_key])
218
+
219
+ return preset_prepared
220
+
221
+
222
+ class MetadataParser(ABC):
223
+ def __init__(self):
224
+ self.raw_prompt: str = ''
225
+ self.full_prompt: str = ''
226
+ self.raw_negative_prompt: str = ''
227
+ self.full_negative_prompt: str = ''
228
+ self.steps: int = 30
229
+ self.base_model_name: str = ''
230
+ self.base_model_hash: str = ''
231
+ self.refiner_model_name: str = ''
232
+ self.refiner_model_hash: str = ''
233
+ self.loras: list = []
234
+
235
+ @abstractmethod
236
+ def get_scheme(self) -> MetadataScheme:
237
+ raise NotImplementedError
238
+
239
+ @abstractmethod
240
+ def parse_json(self, metadata: dict | str) -> dict:
241
+ raise NotImplementedError
242
+
243
+ @abstractmethod
244
+ def parse_string(self, metadata: dict) -> str:
245
+ raise NotImplementedError
246
+
247
+ def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
248
+ refiner_model_name, loras):
249
+ self.raw_prompt = raw_prompt
250
+ self.full_prompt = full_prompt
251
+ self.raw_negative_prompt = raw_negative_prompt
252
+ self.full_negative_prompt = full_negative_prompt
253
+ self.steps = steps
254
+ self.base_model_name = Path(base_model_name).stem
255
+
256
+ base_model_path = get_file_from_folder_list(base_model_name, modules.config.paths_checkpoints)
257
+ self.base_model_hash = get_sha256(base_model_path)
258
+
259
+ if refiner_model_name not in ['', 'None']:
260
+ self.refiner_model_name = Path(refiner_model_name).stem
261
+ refiner_model_path = get_file_from_folder_list(refiner_model_name, modules.config.paths_checkpoints)
262
+ self.refiner_model_hash = get_sha256(refiner_model_path)
263
+
264
+ self.loras = []
265
+ for (lora_name, lora_weight) in loras:
266
+ if lora_name != 'None':
267
+ lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
268
+ lora_hash = get_sha256(lora_path)
269
+ self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
270
+
271
+
272
+ class A1111MetadataParser(MetadataParser):
273
+ def get_scheme(self) -> MetadataScheme:
274
+ return MetadataScheme.A1111
275
+
276
+ fooocus_to_a1111 = {
277
+ 'raw_prompt': 'Raw prompt',
278
+ 'raw_negative_prompt': 'Raw negative prompt',
279
+ 'negative_prompt': 'Negative prompt',
280
+ 'styles': 'Styles',
281
+ 'performance': 'Performance',
282
+ 'steps': 'Steps',
283
+ 'sampler': 'Sampler',
284
+ 'scheduler': 'Scheduler',
285
+ 'guidance_scale': 'CFG scale',
286
+ 'seed': 'Seed',
287
+ 'resolution': 'Size',
288
+ 'sharpness': 'Sharpness',
289
+ 'adm_guidance': 'ADM Guidance',
290
+ 'refiner_swap_method': 'Refiner Swap Method',
291
+ 'adaptive_cfg': 'Adaptive CFG',
292
+ 'overwrite_switch': 'Overwrite Switch',
293
+ 'freeu': 'FreeU',
294
+ 'base_model': 'Model',
295
+ 'base_model_hash': 'Model hash',
296
+ 'refiner_model': 'Refiner',
297
+ 'refiner_model_hash': 'Refiner hash',
298
+ 'lora_hashes': 'Lora hashes',
299
+ 'lora_weights': 'Lora weights',
300
+ 'created_by': 'User',
301
+ 'version': 'Version'
302
+ }
303
+
304
+ def parse_json(self, metadata: str) -> dict:
305
+ metadata_prompt = ''
306
+ metadata_negative_prompt = ''
307
+
308
+ done_with_prompt = False
309
+
310
+ *lines, lastline = metadata.strip().split("\n")
311
+ if len(re_param.findall(lastline)) < 3:
312
+ lines.append(lastline)
313
+ lastline = ''
314
+
315
+ for line in lines:
316
+ line = line.strip()
317
+ if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"):
318
+ done_with_prompt = True
319
+ line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip()
320
+ if done_with_prompt:
321
+ metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line
322
+ else:
323
+ metadata_prompt += ('' if metadata_prompt == '' else "\n") + line
324
+
325
+ found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt)
326
+
327
+ data = {
328
+ 'prompt': prompt,
329
+ 'negative_prompt': negative_prompt
330
+ }
331
+
332
+ for k, v in re_param.findall(lastline):
333
+ try:
334
+ if v != '' and v[0] == '"' and v[-1] == '"':
335
+ v = unquote(v)
336
+
337
+ m = re_imagesize.match(v)
338
+ if m is not None:
339
+ data['resolution'] = str((m.group(1), m.group(2)))
340
+ else:
341
+ data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v
342
+ except Exception:
343
+ print(f"Error parsing \"{k}: {v}\"")
344
+
345
+ # workaround for multiline prompts
346
+ if 'raw_prompt' in data:
347
+ data['prompt'] = data['raw_prompt']
348
+ raw_prompt = data['raw_prompt'].replace("\n", ', ')
349
+ if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles:
350
+ found_styles.append(modules.sdxl_styles.fooocus_expansion)
351
+
352
+ if 'raw_negative_prompt' in data:
353
+ data['negative_prompt'] = data['raw_negative_prompt']
354
+
355
+ data['styles'] = str(found_styles)
356
+
357
+ # try to load performance based on steps, fallback for direct A1111 imports
358
+ if 'steps' in data and 'performance' not in data:
359
+ try:
360
+ data['performance'] = Performance[Steps(int(data['steps'])).name].value
361
+ except (ValueError, KeyError):
362
+ pass
363
+
364
+ if 'sampler' in data:
365
+ data['sampler'] = data['sampler'].replace(' Karras', '')
366
+ # get key
367
+ for k, v in SAMPLERS.items():
368
+ if v == data['sampler']:
369
+ data['sampler'] = k
370
+ break
371
+
372
+ for key in ['base_model', 'refiner_model']:
373
+ if key in data:
374
+ for filename in modules.config.model_filenames:
375
+ path = Path(filename)
376
+ if data[key] == path.stem:
377
+ data[key] = filename
378
+ break
379
+
380
+ if 'lora_hashes' in data:
381
+ lora_filenames = modules.config.lora_filenames.copy()
382
+ if modules.config.sdxl_lcm_lora in lora_filenames:
383
+ lora_filenames.remove(modules.config.sdxl_lcm_lora)
384
+ for li, lora in enumerate(data['lora_hashes'].split(', ')):
385
+ lora_name, lora_hash, lora_weight = lora.split(': ')
386
+ for filename in lora_filenames:
387
+ path = Path(filename)
388
+ if lora_name == path.stem:
389
+ data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
390
+ break
391
+
392
+ return data
393
+
394
+ def parse_string(self, metadata: dict) -> str:
395
+ data = {k: v for _, k, v in metadata}
396
+
397
+ width, height = eval(data['resolution'])
398
+
399
+ sampler = data['sampler']
400
+ scheduler = data['scheduler']
401
+ if sampler in SAMPLERS and SAMPLERS[sampler] != '':
402
+ sampler = SAMPLERS[sampler]
403
+ if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
404
+ sampler += f' Karras'
405
+
406
+ generation_params = {
407
+ self.fooocus_to_a1111['steps']: self.steps,
408
+ self.fooocus_to_a1111['sampler']: sampler,
409
+ self.fooocus_to_a1111['seed']: data['seed'],
410
+ self.fooocus_to_a1111['resolution']: f'{width}x{height}',
411
+ self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
412
+ self.fooocus_to_a1111['sharpness']: data['sharpness'],
413
+ self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'],
414
+ self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
415
+ self.fooocus_to_a1111['base_model_hash']: self.base_model_hash,
416
+
417
+ self.fooocus_to_a1111['performance']: data['performance'],
418
+ self.fooocus_to_a1111['scheduler']: scheduler,
419
+ # workaround for multiline prompts
420
+ self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
421
+ self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
422
+ }
423
+
424
+ if self.refiner_model_name not in ['', 'None']:
425
+ generation_params |= {
426
+ self.fooocus_to_a1111['refiner_model']: self.refiner_model_name,
427
+ self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash
428
+ }
429
+
430
+ for key in ['adaptive_cfg', 'overwrite_switch', 'refiner_swap_method', 'freeu']:
431
+ if key in data:
432
+ generation_params[self.fooocus_to_a1111[key]] = data[key]
433
+
434
+ lora_hashes = []
435
+ for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras):
436
+ # workaround for Fooocus not knowing LoRA name in LoRA metadata
437
+ lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}')
438
+ lora_hashes_string = ', '.join(lora_hashes)
439
+
440
+ generation_params |= {
441
+ self.fooocus_to_a1111['lora_hashes']: lora_hashes_string,
442
+ self.fooocus_to_a1111['version']: data['version']
443
+ }
444
+
445
+ if modules.config.metadata_created_by != '':
446
+ generation_params[self.fooocus_to_a1111['created_by']] = modules.config.metadata_created_by
447
+
448
+ generation_params_text = ", ".join(
449
+ [k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if
450
+ v is not None])
451
+ positive_prompt_resolved = ', '.join(self.full_prompt)
452
+ negative_prompt_resolved = ', '.join(self.full_negative_prompt)
453
+ negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
454
+ return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
455
+
456
+
457
+ class FooocusMetadataParser(MetadataParser):
458
+ def get_scheme(self) -> MetadataScheme:
459
+ return MetadataScheme.FOOOCUS
460
+
461
+ def parse_json(self, metadata: dict) -> dict:
462
+ model_filenames = modules.config.model_filenames.copy()
463
+ lora_filenames = modules.config.lora_filenames.copy()
464
+ if modules.config.sdxl_lcm_lora in lora_filenames:
465
+ lora_filenames.remove(modules.config.sdxl_lcm_lora)
466
+
467
+ for key, value in metadata.items():
468
+ if value in ['', 'None']:
469
+ continue
470
+ if key in ['base_model', 'refiner_model']:
471
+ metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
472
+ elif key.startswith('lora_combined_'):
473
+ metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
474
+ else:
475
+ continue
476
+
477
+ return metadata
478
+
479
+ def parse_string(self, metadata: list) -> str:
480
+ for li, (label, key, value) in enumerate(metadata):
481
+ # remove model folder paths from metadata
482
+ if key.startswith('lora_combined_'):
483
+ name, weight = value.split(' : ')
484
+ name = Path(name).stem
485
+ value = f'{name} : {weight}'
486
+ metadata[li] = (label, key, value)
487
+
488
+ res = {k: v for _, k, v in metadata}
489
+
490
+ res['full_prompt'] = self.full_prompt
491
+ res['full_negative_prompt'] = self.full_negative_prompt
492
+ res['steps'] = self.steps
493
+ res['base_model'] = self.base_model_name
494
+ res['base_model_hash'] = self.base_model_hash
495
+
496
+ if self.refiner_model_name not in ['', 'None']:
497
+ res['refiner_model'] = self.refiner_model_name
498
+ res['refiner_model_hash'] = self.refiner_model_hash
499
+
500
+ res['loras'] = self.loras
501
+
502
+ if modules.config.metadata_created_by != '':
503
+ res['created_by'] = modules.config.metadata_created_by
504
+
505
+ return json.dumps(dict(sorted(res.items())))
506
+
507
+ @staticmethod
508
+ def replace_value_with_filename(key, value, filenames):
509
+ for filename in filenames:
510
+ path = Path(filename)
511
+ if key.startswith('lora_combined_'):
512
+ name, weight = value.split(' : ')
513
+ if name == path.stem:
514
+ return f'{filename} : {weight}'
515
+ elif value == path.stem:
516
+ return filename
517
+
518
+
519
+ def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
520
+ match metadata_scheme:
521
+ case MetadataScheme.FOOOCUS:
522
+ return FooocusMetadataParser()
523
+ case MetadataScheme.A1111:
524
+ return A1111MetadataParser()
525
+ case _:
526
+ raise NotImplementedError
527
+
528
+
529
+ def read_info_from_image(filepath) -> tuple[str | None, MetadataScheme | None]:
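+ # Prefer metadata embedded in the PNG info ('parameters' / 'fooocus_scheme'); otherwise fall back to the EXIF UserComment / MakerNote fields written by get_exif below.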
530
+ with Image.open(filepath) as image:
531
+ items = (image.info or {}).copy()
532
+
533
+ parameters = items.pop('parameters', None)
534
+ metadata_scheme = items.pop('fooocus_scheme', None)
535
+ exif = items.pop('exif', None)
536
+
537
+ if parameters is not None and is_json(parameters):
538
+ parameters = json.loads(parameters)
539
+ elif exif is not None:
540
+ exif = image.getexif()
541
+ # 0x9286 = UserComment
542
+ parameters = exif.get(0x9286, None)
543
+ # 0x927C = MakerNote
544
+ metadata_scheme = exif.get(0x927C, None)
545
+
546
+ if is_json(parameters):
547
+ parameters = json.loads(parameters)
548
+
549
+ try:
550
+ metadata_scheme = MetadataScheme(metadata_scheme)
551
+ except ValueError:
552
+ metadata_scheme = None
553
+
554
+ # broad fallback
555
+ if isinstance(parameters, dict):
556
+ metadata_scheme = MetadataScheme.FOOOCUS
557
+
558
+ if isinstance(parameters, str):
559
+ metadata_scheme = MetadataScheme.A1111
560
+
561
+ return parameters, metadata_scheme
562
+
563
+
564
+ def get_exif(metadata: str | None, metadata_scheme: str):
565
+ exif = Image.Exif()
566
+ # tags: see https://github.com/python-pillow/Pillow/blob/9.2.x/src/PIL/ExifTags.py
567
+ # 0x9286 = UserComment
568
+ exif[0x9286] = metadata
569
+ # 0x0131 = Software
570
+ exif[0x0131] = 'Fooocus v' + fooocus_version.version
571
+ # 0x927C = MakerNote
572
+ exif[0x927C] = metadata_scheme
573
+ return exif
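A minimal usage sketch for the helpers above; the image path is an illustrative assumption:

    parameters, scheme = read_info_from_image('outputs/example.png')
    if scheme is not None:
        parser = get_metadata_parser(scheme)   # FooocusMetadataParser or A1111MetadataParser
        # Fooocus metadata is returned as a dict, A1111 metadata as a flat parameter string.
        print(scheme, parameters)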
modules/model_loader.py ADDED
@@ -0,0 +1,26 @@
1
+ import os
2
+ from urllib.parse import urlparse
3
+ from typing import Optional
4
+
5
+
6
+ def load_file_from_url(
7
+ url: str,
8
+ *,
9
+ model_dir: str,
10
+ progress: bool = True,
11
+ file_name: Optional[str] = None,
12
+ ) -> str:
13
+ """Download a file from `url` into `model_dir`, using the file present if possible.
14
+
15
+ Returns the path to the downloaded file.
16
+ """
17
+ os.makedirs(model_dir, exist_ok=True)
18
+ if not file_name:
19
+ parts = urlparse(url)
20
+ file_name = os.path.basename(parts.path)
21
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
22
+ if not os.path.exists(cached_file):
23
+ print(f'Downloading: "{url}" to {cached_file}\n')
24
+ from torch.hub import download_url_to_file
25
+ download_url_to_file(url, cached_file, progress=progress)
26
+ return cached_file
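A minimal usage sketch; the URL, directory and file name are illustrative assumptions:

    path = load_file_from_url(
        url='https://example.com/some_model.safetensors',
        model_dir='models/checkpoints',
        file_name='some_model.safetensors',
    )
    # A second call with the same arguments returns the cached path without re-downloading.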
modules/ops.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+ import contextlib
3
+
4
+
5
+ @contextlib.contextmanager
6
+ def use_patched_ops(operations):
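+ # Temporarily replaces the torch.nn layer classes with the versions provided by `operations`, restoring the originals on exit.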
7
+ op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
8
+ backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
9
+
10
+ try:
11
+ for op_name in op_names:
12
+ setattr(torch.nn, op_name, getattr(operations, op_name))
13
+
14
+ yield
15
+
16
+ finally:
17
+ for op_name in op_names:
18
+ setattr(torch.nn, op_name, backups[op_name])
19
+ return
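A minimal sketch of the intended use; `patched_ops` stands for a hypothetical object exposing replacement Linear/Conv2d/Conv3d/GroupNorm/LayerNorm classes:

    import torch

    with use_patched_ops(patched_ops):
        layer = torch.nn.Linear(8, 8)   # actually constructs patched_ops.Linear
    # Outside the context manager the original torch.nn classes are restored.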
modules/patch.py ADDED
@@ -0,0 +1,513 @@
1
+ import os
2
+ import torch
3
+ import time
4
+ import math
5
+ import ldm_patched.modules.model_base
6
+ import ldm_patched.ldm.modules.diffusionmodules.openaimodel
7
+ import ldm_patched.modules.model_management
8
+ import modules.anisotropic as anisotropic
9
+ import ldm_patched.ldm.modules.attention
10
+ import ldm_patched.k_diffusion.sampling
11
+ import ldm_patched.modules.sd1_clip
12
+ import modules.inpaint_worker as inpaint_worker
13
+ import ldm_patched.ldm.modules.diffusionmodules.openaimodel
14
+ import ldm_patched.ldm.modules.diffusionmodules.model
15
+ import ldm_patched.modules.sd
16
+ import ldm_patched.controlnet.cldm
17
+ import ldm_patched.modules.model_patcher
18
+ import ldm_patched.modules.samplers
19
+ import ldm_patched.modules.args_parser
20
+ import warnings
21
+ import safetensors.torch
22
+ import modules.constants as constants
23
+
24
+ from ldm_patched.modules.samplers import calc_cond_uncond_batch
25
+ from ldm_patched.k_diffusion.sampling import BatchedBrownianTree
26
+ from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control
27
+ from modules.patch_precision import patch_all_precision
28
+ from modules.patch_clip import patch_all_clip
29
+
30
+
31
+ class PatchSettings:
32
+ def __init__(self,
33
+ sharpness=2.0,
34
+ adm_scaler_end=0.3,
35
+ positive_adm_scale=1.5,
36
+ negative_adm_scale=0.8,
37
+ controlnet_softness=0.25,
38
+ adaptive_cfg=7.0):
39
+ self.sharpness = sharpness
40
+ self.adm_scaler_end = adm_scaler_end
41
+ self.positive_adm_scale = positive_adm_scale
42
+ self.negative_adm_scale = negative_adm_scale
43
+ self.controlnet_softness = controlnet_softness
44
+ self.adaptive_cfg = adaptive_cfg
45
+ self.global_diffusion_progress = 0
46
+ self.eps_record = None
47
+
48
+
49
+ patch_settings = {}
50
+
51
+
52
+ def calculate_weight_patched(self, patches, weight, key):
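+ # Merges weight patches onto the base tensor; each entry is (alpha, value, strength_model) where the value encodes one of the supported formats: diff, lora, fooocus, lokr, loha or glora.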
53
+ for p in patches:
54
+ alpha = p[0]
55
+ v = p[1]
56
+ strength_model = p[2]
57
+
58
+ if strength_model != 1.0:
59
+ weight *= strength_model
60
+
61
+ if isinstance(v, list):
62
+ v = (self.calculate_weight(v[1:], v[0].clone(), key),)
63
+
64
+ if len(v) == 1:
65
+ patch_type = "diff"
66
+ elif len(v) == 2:
67
+ patch_type = v[0]
68
+ v = v[1]
69
+
70
+ if patch_type == "diff":
71
+ w1 = v[0]
72
+ if alpha != 0.0:
73
+ if w1.shape != weight.shape:
74
+ print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
75
+ else:
76
+ weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
77
+ elif patch_type == "lora":
78
+ mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
79
+ mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
80
+ if v[2] is not None:
81
+ alpha *= v[2] / mat2.shape[0]
82
+ if v[3] is not None:
83
+ mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32)
84
+ final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
85
+ mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1),
86
+ mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
87
+ try:
88
+ weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(
89
+ weight.shape).type(weight.dtype)
90
+ except Exception as e:
91
+ print("ERROR", key, e)
92
+ elif patch_type == "fooocus":
93
+ w1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
94
+ w_min = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
95
+ w_max = ldm_patched.modules.model_management.cast_to_device(v[2], weight.device, torch.float32)
96
+ w1 = (w1 / 255.0) * (w_max - w_min) + w_min
97
+ if alpha != 0.0:
98
+ if w1.shape != weight.shape:
99
+ print("WARNING SHAPE MISMATCH {} FOOOCUS WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
100
+ else:
101
+ weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
102
+ elif patch_type == "lokr":
103
+ w1 = v[0]
104
+ w2 = v[1]
105
+ w1_a = v[3]
106
+ w1_b = v[4]
107
+ w2_a = v[5]
108
+ w2_b = v[6]
109
+ t2 = v[7]
110
+ dim = None
111
+
112
+ if w1 is None:
113
+ dim = w1_b.shape[0]
114
+ w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32),
115
+ ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, torch.float32))
116
+ else:
117
+ w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32)
118
+
119
+ if w2 is None:
120
+ dim = w2_b.shape[0]
121
+ if t2 is None:
122
+ w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32),
123
+ ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32))
124
+ else:
125
+ w2 = torch.einsum('i j k l, j r, i p -> p r k l',
126
+ ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
127
+ ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32),
128
+ ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32))
129
+ else:
130
+ w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32)
131
+
132
+ if len(w2.shape) == 4:
133
+ w1 = w1.unsqueeze(2).unsqueeze(2)
134
+ if v[2] is not None and dim is not None:
135
+ alpha *= v[2] / dim
136
+
137
+ try:
138
+ weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
139
+ except Exception as e:
140
+ print("ERROR", key, e)
141
+ elif patch_type == "loha":
142
+ w1a = v[0]
143
+ w1b = v[1]
144
+ if v[2] is not None:
145
+ alpha *= v[2] / w1b.shape[0]
146
+ w2a = v[3]
147
+ w2b = v[4]
148
+ if v[5] is not None: # cp decomposition
149
+ t1 = v[5]
150
+ t2 = v[6]
151
+ m1 = torch.einsum('i j k l, j r, i p -> p r k l',
152
+ ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32),
153
+ ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32),
154
+ ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32))
155
+
156
+ m2 = torch.einsum('i j k l, j r, i p -> p r k l',
157
+ ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
158
+ ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32),
159
+ ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32))
160
+ else:
161
+ m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32),
162
+ ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32))
163
+ m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32),
164
+ ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32))
165
+
166
+ try:
167
+ weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
168
+ except Exception as e:
169
+ print("ERROR", key, e)
170
+ elif patch_type == "glora":
171
+ if v[4] is not None:
172
+ alpha *= v[4] / v[0].shape[0]
173
+
174
+ a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
175
+ a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
176
+ b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
177
+ b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
178
+
179
+ weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
180
+ else:
181
+ print("patch type not recognized", patch_type, key)
182
+
183
+ return weight
184
+
185
+
186
+ class BrownianTreeNoiseSamplerPatched:
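+ # The transform and Brownian tree are stored at class level, so every sampler instance draws noise from the single tree seeded in global_init.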
187
+ transform = None
188
+ tree = None
189
+
190
+ @staticmethod
191
+ def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
192
+ if ldm_patched.modules.model_management.directml_enabled:
193
+ cpu = True
194
+
195
+ t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
196
+
197
+ BrownianTreeNoiseSamplerPatched.transform = transform
198
+ BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
199
+
200
+ def __init__(self, *args, **kwargs):
201
+ pass
202
+
203
+ @staticmethod
204
+ def __call__(sigma, sigma_next):
205
+ transform = BrownianTreeNoiseSamplerPatched.transform
206
+ tree = BrownianTreeNoiseSamplerPatched.tree
207
+
208
+ t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
209
+ return tree(t0, t1) / (t1 - t0).abs().sqrt()
210
+
211
+
212
+ def compute_cfg(uncond, cond, cfg_scale, t):
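+ # Adaptive CFG: when cfg_scale exceeds adaptive_cfg, interpolate between the real CFG result and a lower 'mimicked' CFG, weighted by diffusion progress t (mostly mimicked early, mostly real late).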
213
+ pid = os.getpid()
214
+ mimic_cfg = float(patch_settings[pid].adaptive_cfg)
215
+ real_cfg = float(cfg_scale)
216
+
217
+ real_eps = uncond + real_cfg * (cond - uncond)
218
+
219
+ if cfg_scale > patch_settings[pid].adaptive_cfg:
220
+ mimicked_eps = uncond + mimic_cfg * (cond - uncond)
221
+ return real_eps * t + mimicked_eps * (1 - t)
222
+ else:
223
+ return real_eps
224
+
225
+
226
+ def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
227
+ pid = os.getpid()
228
+
229
+ if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False):
230
+ final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0]
231
+
232
+ if patch_settings[pid].eps_record is not None:
233
+ patch_settings[pid].eps_record = ((x - final_x0) / timestep).cpu()
234
+
235
+ return final_x0
236
+
237
+ positive_x0, negative_x0 = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
238
+
239
+ positive_eps = x - positive_x0
240
+ negative_eps = x - negative_x0
241
+
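+ # Sharpness: blend an anisotropic-filtered (smoothed) positive eps with the raw eps; the blend strength grows with the sharpness setting and with diffusion progress.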
242
+ alpha = 0.001 * patch_settings[pid].sharpness * patch_settings[pid].global_diffusion_progress
243
+
244
+ positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
245
+ positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
246
+
247
+ final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted,
248
+ cfg_scale=cond_scale, t=patch_settings[pid].global_diffusion_progress)
249
+
250
+ if patch_settings[pid].eps_record is not None:
251
+ patch_settings[pid].eps_record = (final_eps / timestep).cpu()
252
+
253
+ return x - final_eps
254
+
255
+
256
+ def round_to_64(x):
257
+ h = float(x)
258
+ h = h / 64.0
259
+ h = round(h)
260
+ h = int(h)
261
+ h = h * 64
262
+ return h
263
+
264
+
265
+ def sdxl_encode_adm_patched(self, **kwargs):
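+ # Scales the reported width/height for positive/negative prompts (ADM guidance), then concatenates an 'emphasized' and a 'consistent' size embedding with the pooled CLIP output.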
266
+ clip_pooled = ldm_patched.modules.model_base.sdxl_pooled(kwargs, self.noise_augmentor)
267
+ width = kwargs.get("width", 1024)
268
+ height = kwargs.get("height", 1024)
269
+ target_width = width
270
+ target_height = height
271
+ pid = os.getpid()
272
+
273
+ if kwargs.get("prompt_type", "") == "negative":
274
+ width = float(width) * patch_settings[pid].negative_adm_scale
275
+ height = float(height) * patch_settings[pid].negative_adm_scale
276
+ elif kwargs.get("prompt_type", "") == "positive":
277
+ width = float(width) * patch_settings[pid].positive_adm_scale
278
+ height = float(height) * patch_settings[pid].positive_adm_scale
279
+
280
+ def embedder(number_list):
281
+ h = self.embedder(torch.tensor(number_list, dtype=torch.float32))
282
+ h = torch.flatten(h).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
283
+ return h
284
+
285
+ width, height = int(width), int(height)
286
+ target_width, target_height = round_to_64(target_width), round_to_64(target_height)
287
+
288
+ adm_emphasized = embedder([height, width, 0, 0, target_height, target_width])
289
+ adm_consistent = embedder([target_height, target_width, 0, 0, target_height, target_width])
290
+
291
+ clip_pooled = clip_pooled.to(adm_emphasized)
292
+ final_adm = torch.cat((clip_pooled, adm_emphasized, clip_pooled, adm_consistent), dim=1)
293
+
294
+ return final_adm
295
+
296
+
297
+ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
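+ # For inpainting, keep x inside the editable mask and inject the known latent plus noise at the current sigma outside it, then paste the known latent back into the output.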
298
+ if inpaint_worker.current_task is not None:
299
+ latent_processor = self.inner_model.inner_model.process_latent_in
300
+ inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
301
+ inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
302
+
303
+ if getattr(self, 'energy_generator', None) is None:
304
+ # avoid bad results by using different seeds.
305
+ self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
306
+
307
+ energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
308
+ current_energy = torch.randn(
309
+ x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma
310
+ x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask)
311
+
312
+ out = self.inner_model(x, sigma,
313
+ cond=cond,
314
+ uncond=uncond,
315
+ cond_scale=cond_scale,
316
+ model_options=model_options,
317
+ seed=seed)
318
+
319
+ out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask)
320
+ else:
321
+ out = self.inner_model(x, sigma,
322
+ cond=cond,
323
+ uncond=uncond,
324
+ cond_scale=cond_scale,
325
+ model_options=model_options,
326
+ seed=seed)
327
+ return out
328
+
329
+
330
+ def timed_adm(y, timesteps):
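+ # A 5632-dim y concatenates two 2816-dim ADM vectors (with and without size emphasis); the emphasized half is used only while timesteps > 999 * (1 - adm_scaler_end), i.e. early in sampling.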
331
+ if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
332
+ y_mask = (timesteps > 999.0 * (1.0 - float(patch_settings[os.getpid()].adm_scaler_end))).to(y)[..., None]
333
+ y_with_adm = y[..., :2816].clone()
334
+ y_without_adm = y[..., 2816:].clone()
335
+ return y_with_adm * y_mask + y_without_adm * (1.0 - y_mask)
336
+ return y
337
+
338
+
339
+ def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
340
+ t_emb = ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
341
+ emb = self.time_embed(t_emb)
342
+ pid = os.getpid()
343
+
344
+ guided_hint = self.input_hint_block(hint, emb, context)
345
+
346
+ y = timed_adm(y, timesteps)
347
+
348
+ outs = []
349
+
350
+ hs = []
351
+ if self.num_classes is not None:
352
+ assert y.shape[0] == x.shape[0]
353
+ emb = emb + self.label_emb(y)
354
+
355
+ h = x
356
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
357
+ if guided_hint is not None:
358
+ h = module(h, emb, context)
359
+ h += guided_hint
360
+ guided_hint = None
361
+ else:
362
+ h = module(h, emb, context)
363
+ outs.append(zero_conv(h, emb, context))
364
+
365
+ h = self.middle_block(h, emb, context)
366
+ outs.append(self.middle_block_out(h, emb, context))
367
+
368
+ if patch_settings[pid].controlnet_softness > 0:
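+ # ControlNet softness: scale down the control residuals, attenuating the shallow outputs the most (k runs from 1.0 down to 0.0 across the ten outputs).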
369
+ for i in range(10):
370
+ k = 1.0 - float(i) / 9.0
371
+ outs[i] = outs[i] * (1.0 - patch_settings[pid].controlnet_softness * k)
372
+
373
+ return outs
374
+
375
+
376
+ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
377
+ self.current_step = 1.0 - timesteps.to(x) / 999.0
378
+ patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
379
+
380
+ y = timed_adm(y, timesteps)
381
+
382
+ transformer_options["original_shape"] = list(x.shape)
383
+ transformer_options["transformer_index"] = 0
384
+ transformer_patches = transformer_options.get("patches", {})
385
+
386
+ num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
387
+ image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
388
+ time_context = kwargs.get("time_context", None)
389
+
390
+ assert (y is not None) == (
391
+ self.num_classes is not None
392
+ ), "must specify y if and only if the model is class-conditional"
393
+ hs = []
394
+ t_emb = ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
395
+ emb = self.time_embed(t_emb)
396
+
397
+ if self.num_classes is not None:
398
+ assert y.shape[0] == x.shape[0]
399
+ emb = emb + self.label_emb(y)
400
+
401
+ h = x
402
+ for id, module in enumerate(self.input_blocks):
403
+ transformer_options["block"] = ("input", id)
404
+ h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
405
+ h = apply_control(h, control, 'input')
406
+ if "input_block_patch" in transformer_patches:
407
+ patch = transformer_patches["input_block_patch"]
408
+ for p in patch:
409
+ h = p(h, transformer_options)
410
+
411
+ hs.append(h)
412
+ if "input_block_patch_after_skip" in transformer_patches:
413
+ patch = transformer_patches["input_block_patch_after_skip"]
414
+ for p in patch:
415
+ h = p(h, transformer_options)
416
+
417
+ transformer_options["block"] = ("middle", 0)
418
+ h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
419
+ h = apply_control(h, control, 'middle')
420
+
421
+ for id, module in enumerate(self.output_blocks):
422
+ transformer_options["block"] = ("output", id)
423
+ hsp = hs.pop()
424
+ hsp = apply_control(hsp, control, 'output')
425
+
426
+ if "output_block_patch" in transformer_patches:
427
+ patch = transformer_patches["output_block_patch"]
428
+ for p in patch:
429
+ h, hsp = p(h, hsp, transformer_options)
430
+
431
+ h = torch.cat([h, hsp], dim=1)
432
+ del hsp
433
+ if len(hs) > 0:
434
+ output_shape = hs[-1].shape
435
+ else:
436
+ output_shape = None
437
+ h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
438
+ h = h.type(x.dtype)
439
+ if self.predict_codebook_ids:
440
+ return self.id_predictor(h)
441
+ else:
442
+ return self.out(h)
443
+
444
+
445
+ def patched_load_models_gpu(*args, **kwargs):
446
+ execution_start_time = time.perf_counter()
447
+ y = ldm_patched.modules.model_management.load_models_gpu_origin(*args, **kwargs)
448
+ moving_time = time.perf_counter() - execution_start_time
449
+ if moving_time > 0.1:
450
+ print(f'[Fooocus Model Management] Moving model(s) has taken {moving_time:.2f} seconds')
451
+ return y
452
+
453
+
454
+ def build_loaded(module, loader_name):
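+ # Wraps a loader (e.g. safetensors load_file or torch.load) so that a failed load renames the offending file to *.corrupted and raises, letting Fooocus re-download the model on the next attempt.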
455
+ original_loader_name = loader_name + '_origin'
456
+
457
+ if not hasattr(module, original_loader_name):
458
+ setattr(module, original_loader_name, getattr(module, loader_name))
459
+
460
+ original_loader = getattr(module, original_loader_name)
461
+
462
+ def loader(*args, **kwargs):
463
+ result = None
464
+ try:
465
+ result = original_loader(*args, **kwargs)
466
+ except Exception as e:
467
+ result = None
468
+ exp = str(e) + '\n'
469
+ for path in list(args) + list(kwargs.values()):
470
+ if isinstance(path, str):
471
+ if os.path.exists(path):
472
+ exp += f'File corrupted: {path} \n'
473
+ corrupted_backup_file = path + '.corrupted'
474
+ if os.path.exists(corrupted_backup_file):
475
+ os.remove(corrupted_backup_file)
476
+ os.replace(path, corrupted_backup_file)
477
+ if os.path.exists(path):
478
+ os.remove(path)
479
+ exp += f'Fooocus has tried to move the corrupted file to {corrupted_backup_file} \n'
480
+ exp += 'You may try again now and Fooocus will download models again. \n'
481
+ raise ValueError(exp)
482
+ return result
483
+
484
+ setattr(module, loader_name, loader)
485
+ return
486
+
487
+
488
+ def patch_all():
489
+ if ldm_patched.modules.model_management.directml_enabled:
490
+ ldm_patched.modules.model_management.lowvram_available = True
491
+ ldm_patched.modules.model_management.OOM_EXCEPTION = Exception
492
+
493
+ patch_all_precision()
494
+ patch_all_clip()
495
+
496
+ if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'):
497
+ ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu
498
+
499
+ ldm_patched.modules.model_management.load_models_gpu = patched_load_models_gpu
500
+ ldm_patched.modules.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
501
+ ldm_patched.controlnet.cldm.ControlNet.forward = patched_cldm_forward
502
+ ldm_patched.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
503
+ ldm_patched.modules.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
504
+ ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
505
+ ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
506
+ ldm_patched.modules.samplers.sampling_function = patched_sampling_function
507
+
508
+ warnings.filterwarnings(action='ignore', module='torchsde')
509
+
510
+ build_loaded(safetensors.torch, 'load_file')
511
+ build_loaded(torch, 'load')
512
+
513
+ return