Spaces:
Runtime error
Runtime error
Upload 59 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- modules/__pycache__/anisotropic.cpython-310.pyc +0 -0
- modules/__pycache__/async_worker.cpython-310.pyc +0 -0
- modules/__pycache__/auth.cpython-310.pyc +0 -0
- modules/__pycache__/config.cpython-310.pyc +0 -0
- modules/__pycache__/config.cpython-312.pyc +0 -0
- modules/__pycache__/constants.cpython-310.pyc +0 -0
- modules/__pycache__/core.cpython-310.pyc +0 -0
- modules/__pycache__/default_pipeline.cpython-310.pyc +0 -0
- modules/__pycache__/flags.cpython-310.pyc +0 -0
- modules/__pycache__/flags.cpython-312.pyc +0 -0
- modules/__pycache__/gradio_hijack.cpython-310.pyc +0 -0
- modules/__pycache__/html.cpython-310.pyc +0 -0
- modules/__pycache__/inpaint_worker.cpython-310.pyc +0 -0
- modules/__pycache__/launch_util.cpython-310.pyc +0 -0
- modules/__pycache__/launch_util.cpython-312.pyc +0 -0
- modules/__pycache__/localization.cpython-310.pyc +0 -0
- modules/__pycache__/lora.cpython-310.pyc +0 -0
- modules/__pycache__/meta_parser.cpython-310.pyc +0 -0
- modules/__pycache__/model_loader.cpython-310.pyc +0 -0
- modules/__pycache__/ops.cpython-310.pyc +0 -0
- modules/__pycache__/patch.cpython-310.pyc +0 -0
- modules/__pycache__/patch_clip.cpython-310.pyc +0 -0
- modules/__pycache__/patch_precision.cpython-310.pyc +0 -0
- modules/__pycache__/private_logger.cpython-310.pyc +0 -0
- modules/__pycache__/sample_hijack.cpython-310.pyc +0 -0
- modules/__pycache__/sdxl_styles.cpython-310.pyc +0 -0
- modules/__pycache__/sdxl_styles.cpython-312.pyc +0 -0
- modules/__pycache__/style_sorter.cpython-310.pyc +0 -0
- modules/__pycache__/ui_gradio_extensions.cpython-310.pyc +0 -0
- modules/__pycache__/upscaler.cpython-310.pyc +0 -0
- modules/__pycache__/util.cpython-310.pyc +0 -0
- modules/__pycache__/util.cpython-312.pyc +0 -0
- modules/anisotropic.py +200 -0
- modules/async_worker.py +914 -0
- modules/auth.py +41 -0
- modules/config.py +607 -0
- modules/constants.py +5 -0
- modules/core.py +339 -0
- modules/default_pipeline.py +498 -0
- modules/flags.py +125 -0
- modules/gradio_hijack.py +480 -0
- modules/html.py +146 -0
- modules/inpaint_worker.py +264 -0
- modules/launch_util.py +103 -0
- modules/localization.py +60 -0
- modules/lora.py +152 -0
- modules/meta_parser.py +573 -0
- modules/model_loader.py +26 -0
- modules/ops.py +19 -0
- modules/patch.py +513 -0
modules/__pycache__/anisotropic.cpython-310.pyc
ADDED
Binary file (5.81 kB). View file
|
|
modules/__pycache__/async_worker.cpython-310.pyc
ADDED
Binary file (21 kB). View file
|
|
modules/__pycache__/auth.cpython-310.pyc
ADDED
Binary file (1.39 kB). View file
|
|
modules/__pycache__/config.cpython-310.pyc
ADDED
Binary file (18.4 kB). View file
|
|
modules/__pycache__/config.cpython-312.pyc
ADDED
Binary file (29.3 kB). View file
|
|
modules/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (309 Bytes). View file
|
|
modules/__pycache__/core.cpython-310.pyc
ADDED
Binary file (10.5 kB). View file
|
|
modules/__pycache__/default_pipeline.cpython-310.pyc
ADDED
Binary file (9.98 kB). View file
|
|
modules/__pycache__/flags.cpython-310.pyc
ADDED
Binary file (3.72 kB). View file
|
|
modules/__pycache__/flags.cpython-312.pyc
ADDED
Binary file (4.89 kB). View file
|
|
modules/__pycache__/gradio_hijack.cpython-310.pyc
ADDED
Binary file (17.3 kB). View file
|
|
modules/__pycache__/html.cpython-310.pyc
ADDED
Binary file (3.25 kB). View file
|
|
modules/__pycache__/inpaint_worker.cpython-310.pyc
ADDED
Binary file (6.99 kB). View file
|
|
modules/__pycache__/launch_util.cpython-310.pyc
ADDED
Binary file (3.14 kB). View file
|
|
modules/__pycache__/launch_util.cpython-312.pyc
ADDED
Binary file (5.21 kB). View file
|
|
modules/__pycache__/localization.cpython-310.pyc
ADDED
Binary file (1.87 kB). View file
|
|
modules/__pycache__/lora.cpython-310.pyc
ADDED
Binary file (3.26 kB). View file
|
|
modules/__pycache__/meta_parser.cpython-310.pyc
ADDED
Binary file (15.3 kB). View file
|
|
modules/__pycache__/model_loader.cpython-310.pyc
ADDED
Binary file (1.04 kB). View file
|
|
modules/__pycache__/ops.cpython-310.pyc
ADDED
Binary file (816 Bytes). View file
|
|
modules/__pycache__/patch.cpython-310.pyc
ADDED
Binary file (14.4 kB). View file
|
|
modules/__pycache__/patch_clip.cpython-310.pyc
ADDED
Binary file (6.01 kB). View file
|
|
modules/__pycache__/patch_precision.cpython-310.pyc
ADDED
Binary file (2.03 kB). View file
|
|
modules/__pycache__/private_logger.cpython-310.pyc
ADDED
Binary file (5.2 kB). View file
|
|
modules/__pycache__/sample_hijack.cpython-310.pyc
ADDED
Binary file (6.19 kB). View file
|
|
modules/__pycache__/sdxl_styles.cpython-310.pyc
ADDED
Binary file (3.61 kB). View file
|
|
modules/__pycache__/sdxl_styles.cpython-312.pyc
ADDED
Binary file (6.21 kB). View file
|
|
modules/__pycache__/style_sorter.cpython-310.pyc
ADDED
Binary file (2.42 kB). View file
|
|
modules/__pycache__/ui_gradio_extensions.cpython-310.pyc
ADDED
Binary file (2.55 kB). View file
|
|
modules/__pycache__/upscaler.cpython-310.pyc
ADDED
Binary file (1.17 kB). View file
|
|
modules/__pycache__/util.cpython-310.pyc
ADDED
Binary file (11.1 kB). View file
|
|
modules/__pycache__/util.cpython-312.pyc
ADDED
Binary file (18.4 kB). View file
|
|
modules/anisotropic.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
Tensor = torch.Tensor
|
5 |
+
Device = torch.DeviceObjType
|
6 |
+
Dtype = torch.Type
|
7 |
+
pad = torch.nn.functional.pad
|
8 |
+
|
9 |
+
|
10 |
+
def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
|
11 |
+
ky, kx = _unpack_2d_ks(kernel_size)
|
12 |
+
return (ky - 1) // 2, (kx - 1) // 2
|
13 |
+
|
14 |
+
|
15 |
+
def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
|
16 |
+
if isinstance(kernel_size, int):
|
17 |
+
ky = kx = kernel_size
|
18 |
+
else:
|
19 |
+
assert len(kernel_size) == 2, '2D Kernel size should have a length of 2.'
|
20 |
+
ky, kx = kernel_size
|
21 |
+
|
22 |
+
ky = int(ky)
|
23 |
+
kx = int(kx)
|
24 |
+
return ky, kx
|
25 |
+
|
26 |
+
|
27 |
+
def gaussian(
|
28 |
+
window_size: int, sigma: Tensor | float, *, device: Device | None = None, dtype: Dtype | None = None
|
29 |
+
) -> Tensor:
|
30 |
+
|
31 |
+
batch_size = sigma.shape[0]
|
32 |
+
|
33 |
+
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
|
34 |
+
|
35 |
+
if window_size % 2 == 0:
|
36 |
+
x = x + 0.5
|
37 |
+
|
38 |
+
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
|
39 |
+
|
40 |
+
return gauss / gauss.sum(-1, keepdim=True)
|
41 |
+
|
42 |
+
|
43 |
+
def get_gaussian_kernel1d(
|
44 |
+
kernel_size: int,
|
45 |
+
sigma: float | Tensor,
|
46 |
+
force_even: bool = False,
|
47 |
+
*,
|
48 |
+
device: Device | None = None,
|
49 |
+
dtype: Dtype | None = None,
|
50 |
+
) -> Tensor:
|
51 |
+
|
52 |
+
return gaussian(kernel_size, sigma, device=device, dtype=dtype)
|
53 |
+
|
54 |
+
|
55 |
+
def get_gaussian_kernel2d(
|
56 |
+
kernel_size: tuple[int, int] | int,
|
57 |
+
sigma: tuple[float, float] | Tensor,
|
58 |
+
force_even: bool = False,
|
59 |
+
*,
|
60 |
+
device: Device | None = None,
|
61 |
+
dtype: Dtype | None = None,
|
62 |
+
) -> Tensor:
|
63 |
+
|
64 |
+
sigma = torch.Tensor([[sigma, sigma]]).to(device=device, dtype=dtype)
|
65 |
+
|
66 |
+
ksize_y, ksize_x = _unpack_2d_ks(kernel_size)
|
67 |
+
sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None]
|
68 |
+
|
69 |
+
kernel_y = get_gaussian_kernel1d(ksize_y, sigma_y, force_even, device=device, dtype=dtype)[..., None]
|
70 |
+
kernel_x = get_gaussian_kernel1d(ksize_x, sigma_x, force_even, device=device, dtype=dtype)[..., None]
|
71 |
+
|
72 |
+
return kernel_y * kernel_x.view(-1, 1, ksize_x)
|
73 |
+
|
74 |
+
|
75 |
+
def _bilateral_blur(
|
76 |
+
input: Tensor,
|
77 |
+
guidance: Tensor | None,
|
78 |
+
kernel_size: tuple[int, int] | int,
|
79 |
+
sigma_color: float | Tensor,
|
80 |
+
sigma_space: tuple[float, float] | Tensor,
|
81 |
+
border_type: str = 'reflect',
|
82 |
+
color_distance_type: str = 'l1',
|
83 |
+
) -> Tensor:
|
84 |
+
|
85 |
+
if isinstance(sigma_color, Tensor):
|
86 |
+
sigma_color = sigma_color.to(device=input.device, dtype=input.dtype).view(-1, 1, 1, 1, 1)
|
87 |
+
|
88 |
+
ky, kx = _unpack_2d_ks(kernel_size)
|
89 |
+
pad_y, pad_x = _compute_zero_padding(kernel_size)
|
90 |
+
|
91 |
+
padded_input = pad(input, (pad_x, pad_x, pad_y, pad_y), mode=border_type)
|
92 |
+
unfolded_input = padded_input.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) # (B, C, H, W, Ky x Kx)
|
93 |
+
|
94 |
+
if guidance is None:
|
95 |
+
guidance = input
|
96 |
+
unfolded_guidance = unfolded_input
|
97 |
+
else:
|
98 |
+
padded_guidance = pad(guidance, (pad_x, pad_x, pad_y, pad_y), mode=border_type)
|
99 |
+
unfolded_guidance = padded_guidance.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) # (B, C, H, W, Ky x Kx)
|
100 |
+
|
101 |
+
diff = unfolded_guidance - guidance.unsqueeze(-1)
|
102 |
+
if color_distance_type == "l1":
|
103 |
+
color_distance_sq = diff.abs().sum(1, keepdim=True).square()
|
104 |
+
elif color_distance_type == "l2":
|
105 |
+
color_distance_sq = diff.square().sum(1, keepdim=True)
|
106 |
+
else:
|
107 |
+
raise ValueError("color_distance_type only acceps l1 or l2")
|
108 |
+
color_kernel = (-0.5 / sigma_color**2 * color_distance_sq).exp() # (B, 1, H, W, Ky x Kx)
|
109 |
+
|
110 |
+
space_kernel = get_gaussian_kernel2d(kernel_size, sigma_space, device=input.device, dtype=input.dtype)
|
111 |
+
space_kernel = space_kernel.view(-1, 1, 1, 1, kx * ky)
|
112 |
+
|
113 |
+
kernel = space_kernel * color_kernel
|
114 |
+
out = (unfolded_input * kernel).sum(-1) / kernel.sum(-1)
|
115 |
+
return out
|
116 |
+
|
117 |
+
|
118 |
+
def bilateral_blur(
|
119 |
+
input: Tensor,
|
120 |
+
kernel_size: tuple[int, int] | int = (13, 13),
|
121 |
+
sigma_color: float | Tensor = 3.0,
|
122 |
+
sigma_space: tuple[float, float] | Tensor = 3.0,
|
123 |
+
border_type: str = 'reflect',
|
124 |
+
color_distance_type: str = 'l1',
|
125 |
+
) -> Tensor:
|
126 |
+
return _bilateral_blur(input, None, kernel_size, sigma_color, sigma_space, border_type, color_distance_type)
|
127 |
+
|
128 |
+
|
129 |
+
def adaptive_anisotropic_filter(x, g=None):
|
130 |
+
if g is None:
|
131 |
+
g = x
|
132 |
+
s, m = torch.std_mean(g, dim=(1, 2, 3), keepdim=True)
|
133 |
+
s = s + 1e-5
|
134 |
+
guidance = (g - m) / s
|
135 |
+
y = _bilateral_blur(x, guidance,
|
136 |
+
kernel_size=(13, 13),
|
137 |
+
sigma_color=3.0,
|
138 |
+
sigma_space=3.0,
|
139 |
+
border_type='reflect',
|
140 |
+
color_distance_type='l1')
|
141 |
+
return y
|
142 |
+
|
143 |
+
|
144 |
+
def joint_bilateral_blur(
|
145 |
+
input: Tensor,
|
146 |
+
guidance: Tensor,
|
147 |
+
kernel_size: tuple[int, int] | int,
|
148 |
+
sigma_color: float | Tensor,
|
149 |
+
sigma_space: tuple[float, float] | Tensor,
|
150 |
+
border_type: str = 'reflect',
|
151 |
+
color_distance_type: str = 'l1',
|
152 |
+
) -> Tensor:
|
153 |
+
return _bilateral_blur(input, guidance, kernel_size, sigma_color, sigma_space, border_type, color_distance_type)
|
154 |
+
|
155 |
+
|
156 |
+
class _BilateralBlur(torch.nn.Module):
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
kernel_size: tuple[int, int] | int,
|
160 |
+
sigma_color: float | Tensor,
|
161 |
+
sigma_space: tuple[float, float] | Tensor,
|
162 |
+
border_type: str = 'reflect',
|
163 |
+
color_distance_type: str = "l1",
|
164 |
+
) -> None:
|
165 |
+
super().__init__()
|
166 |
+
self.kernel_size = kernel_size
|
167 |
+
self.sigma_color = sigma_color
|
168 |
+
self.sigma_space = sigma_space
|
169 |
+
self.border_type = border_type
|
170 |
+
self.color_distance_type = color_distance_type
|
171 |
+
|
172 |
+
def __repr__(self) -> str:
|
173 |
+
return (
|
174 |
+
f"{self.__class__.__name__}"
|
175 |
+
f"(kernel_size={self.kernel_size}, "
|
176 |
+
f"sigma_color={self.sigma_color}, "
|
177 |
+
f"sigma_space={self.sigma_space}, "
|
178 |
+
f"border_type={self.border_type}, "
|
179 |
+
f"color_distance_type={self.color_distance_type})"
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
class BilateralBlur(_BilateralBlur):
|
184 |
+
def forward(self, input: Tensor) -> Tensor:
|
185 |
+
return bilateral_blur(
|
186 |
+
input, self.kernel_size, self.sigma_color, self.sigma_space, self.border_type, self.color_distance_type
|
187 |
+
)
|
188 |
+
|
189 |
+
|
190 |
+
class JointBilateralBlur(_BilateralBlur):
|
191 |
+
def forward(self, input: Tensor, guidance: Tensor) -> Tensor:
|
192 |
+
return joint_bilateral_blur(
|
193 |
+
input,
|
194 |
+
guidance,
|
195 |
+
self.kernel_size,
|
196 |
+
self.sigma_color,
|
197 |
+
self.sigma_space,
|
198 |
+
self.border_type,
|
199 |
+
self.color_distance_type,
|
200 |
+
)
|
modules/async_worker.py
ADDED
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
from modules.patch import PatchSettings, patch_settings, patch_all
|
3 |
+
|
4 |
+
patch_all()
|
5 |
+
|
6 |
+
class AsyncTask:
|
7 |
+
def __init__(self, args):
|
8 |
+
self.args = args
|
9 |
+
self.yields = []
|
10 |
+
self.results = []
|
11 |
+
self.last_stop = False
|
12 |
+
self.processing = False
|
13 |
+
|
14 |
+
|
15 |
+
async_tasks = []
|
16 |
+
|
17 |
+
|
18 |
+
def worker():
|
19 |
+
global async_tasks
|
20 |
+
|
21 |
+
import os
|
22 |
+
import traceback
|
23 |
+
import math
|
24 |
+
import numpy as np
|
25 |
+
import cv2
|
26 |
+
import torch
|
27 |
+
import time
|
28 |
+
import shared
|
29 |
+
import random
|
30 |
+
import copy
|
31 |
+
import modules.default_pipeline as pipeline
|
32 |
+
import modules.core as core
|
33 |
+
import modules.flags as flags
|
34 |
+
import modules.config
|
35 |
+
import modules.patch
|
36 |
+
import ldm_patched.modules.model_management
|
37 |
+
import extras.preprocessors as preprocessors
|
38 |
+
import modules.inpaint_worker as inpaint_worker
|
39 |
+
import modules.constants as constants
|
40 |
+
import extras.ip_adapter as ip_adapter
|
41 |
+
import extras.face_crop
|
42 |
+
import fooocus_version
|
43 |
+
import args_manager
|
44 |
+
|
45 |
+
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
|
46 |
+
from modules.private_logger import log
|
47 |
+
from extras.expansion import safe_str
|
48 |
+
from modules.util import remove_empty_str, HWC3, resize_image, \
|
49 |
+
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix
|
50 |
+
from modules.upscaler import perform_upscale
|
51 |
+
from modules.flags import Performance
|
52 |
+
from modules.meta_parser import get_metadata_parser, MetadataScheme
|
53 |
+
|
54 |
+
pid = os.getpid()
|
55 |
+
print(f'Started worker with PID {pid}')
|
56 |
+
|
57 |
+
try:
|
58 |
+
async_gradio_app = shared.gradio_root
|
59 |
+
flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
|
60 |
+
if async_gradio_app.share:
|
61 |
+
flag += f''' or {async_gradio_app.share_url}'''
|
62 |
+
print(flag)
|
63 |
+
except Exception as e:
|
64 |
+
print(e)
|
65 |
+
|
66 |
+
def progressbar(async_task, number, text):
|
67 |
+
print(f'[Fooocus] {text}')
|
68 |
+
async_task.yields.append(['preview', (number, text, None)])
|
69 |
+
|
70 |
+
def yield_result(async_task, imgs, do_not_show_finished_images=False):
|
71 |
+
if not isinstance(imgs, list):
|
72 |
+
imgs = [imgs]
|
73 |
+
|
74 |
+
async_task.results = async_task.results + imgs
|
75 |
+
|
76 |
+
if do_not_show_finished_images:
|
77 |
+
return
|
78 |
+
|
79 |
+
async_task.yields.append(['results', async_task.results])
|
80 |
+
return
|
81 |
+
|
82 |
+
def build_image_wall(async_task):
|
83 |
+
results = []
|
84 |
+
|
85 |
+
if len(async_task.results) < 2:
|
86 |
+
return
|
87 |
+
|
88 |
+
for img in async_task.results:
|
89 |
+
if isinstance(img, str) and os.path.exists(img):
|
90 |
+
img = cv2.imread(img)
|
91 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
92 |
+
if not isinstance(img, np.ndarray):
|
93 |
+
return
|
94 |
+
if img.ndim != 3:
|
95 |
+
return
|
96 |
+
results.append(img)
|
97 |
+
|
98 |
+
H, W, C = results[0].shape
|
99 |
+
|
100 |
+
for img in results:
|
101 |
+
Hn, Wn, Cn = img.shape
|
102 |
+
if H != Hn:
|
103 |
+
return
|
104 |
+
if W != Wn:
|
105 |
+
return
|
106 |
+
if C != Cn:
|
107 |
+
return
|
108 |
+
|
109 |
+
cols = float(len(results)) ** 0.5
|
110 |
+
cols = int(math.ceil(cols))
|
111 |
+
rows = float(len(results)) / float(cols)
|
112 |
+
rows = int(math.ceil(rows))
|
113 |
+
|
114 |
+
wall = np.zeros(shape=(H * rows, W * cols, C), dtype=np.uint8)
|
115 |
+
|
116 |
+
for y in range(rows):
|
117 |
+
for x in range(cols):
|
118 |
+
if y * cols + x < len(results):
|
119 |
+
img = results[y * cols + x]
|
120 |
+
wall[y * H:y * H + H, x * W:x * W + W, :] = img
|
121 |
+
|
122 |
+
# must use deep copy otherwise gradio is super laggy. Do not use list.append() .
|
123 |
+
async_task.results = async_task.results + [wall]
|
124 |
+
return
|
125 |
+
|
126 |
+
def apply_enabled_loras(loras):
|
127 |
+
enabled_loras = []
|
128 |
+
for lora_enabled, lora_model, lora_weight in loras:
|
129 |
+
if lora_enabled:
|
130 |
+
enabled_loras.append([lora_model, lora_weight])
|
131 |
+
|
132 |
+
return enabled_loras
|
133 |
+
|
134 |
+
@torch.no_grad()
|
135 |
+
@torch.inference_mode()
|
136 |
+
def handler(async_task):
|
137 |
+
execution_start_time = time.perf_counter()
|
138 |
+
async_task.processing = True
|
139 |
+
|
140 |
+
args = async_task.args
|
141 |
+
args.reverse()
|
142 |
+
|
143 |
+
prompt = args.pop()
|
144 |
+
negative_prompt = args.pop()
|
145 |
+
style_selections = args.pop()
|
146 |
+
performance_selection = Performance(args.pop())
|
147 |
+
aspect_ratios_selection = args.pop()
|
148 |
+
image_number = args.pop()
|
149 |
+
output_format = args.pop()
|
150 |
+
image_seed = args.pop()
|
151 |
+
sharpness = args.pop()
|
152 |
+
guidance_scale = args.pop()
|
153 |
+
base_model_name = args.pop()
|
154 |
+
refiner_model_name = args.pop()
|
155 |
+
refiner_switch = args.pop()
|
156 |
+
loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)])
|
157 |
+
input_image_checkbox = args.pop()
|
158 |
+
current_tab = args.pop()
|
159 |
+
uov_method = args.pop()
|
160 |
+
uov_input_image = args.pop()
|
161 |
+
outpaint_selections = args.pop()
|
162 |
+
inpaint_input_image = args.pop()
|
163 |
+
inpaint_additional_prompt = args.pop()
|
164 |
+
inpaint_mask_image_upload = args.pop()
|
165 |
+
|
166 |
+
disable_preview = args.pop()
|
167 |
+
disable_intermediate_results = args.pop()
|
168 |
+
disable_seed_increment = args.pop()
|
169 |
+
adm_scaler_positive = args.pop()
|
170 |
+
adm_scaler_negative = args.pop()
|
171 |
+
adm_scaler_end = args.pop()
|
172 |
+
adaptive_cfg = args.pop()
|
173 |
+
sampler_name = args.pop()
|
174 |
+
scheduler_name = args.pop()
|
175 |
+
overwrite_step = args.pop()
|
176 |
+
overwrite_switch = args.pop()
|
177 |
+
overwrite_width = args.pop()
|
178 |
+
overwrite_height = args.pop()
|
179 |
+
overwrite_vary_strength = args.pop()
|
180 |
+
overwrite_upscale_strength = args.pop()
|
181 |
+
mixing_image_prompt_and_vary_upscale = args.pop()
|
182 |
+
mixing_image_prompt_and_inpaint = args.pop()
|
183 |
+
debugging_cn_preprocessor = args.pop()
|
184 |
+
skipping_cn_preprocessor = args.pop()
|
185 |
+
canny_low_threshold = args.pop()
|
186 |
+
canny_high_threshold = args.pop()
|
187 |
+
refiner_swap_method = args.pop()
|
188 |
+
controlnet_softness = args.pop()
|
189 |
+
freeu_enabled = args.pop()
|
190 |
+
freeu_b1 = args.pop()
|
191 |
+
freeu_b2 = args.pop()
|
192 |
+
freeu_s1 = args.pop()
|
193 |
+
freeu_s2 = args.pop()
|
194 |
+
debugging_inpaint_preprocessor = args.pop()
|
195 |
+
inpaint_disable_initial_latent = args.pop()
|
196 |
+
inpaint_engine = args.pop()
|
197 |
+
inpaint_strength = args.pop()
|
198 |
+
inpaint_respective_field = args.pop()
|
199 |
+
inpaint_mask_upload_checkbox = args.pop()
|
200 |
+
invert_mask_checkbox = args.pop()
|
201 |
+
inpaint_erode_or_dilate = args.pop()
|
202 |
+
|
203 |
+
save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
|
204 |
+
metadata_scheme = MetadataScheme(args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
|
205 |
+
|
206 |
+
cn_tasks = {x: [] for x in flags.ip_list}
|
207 |
+
for _ in range(flags.controlnet_image_count):
|
208 |
+
cn_img = args.pop()
|
209 |
+
cn_stop = args.pop()
|
210 |
+
cn_weight = args.pop()
|
211 |
+
cn_type = args.pop()
|
212 |
+
if cn_img is not None:
|
213 |
+
cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
|
214 |
+
|
215 |
+
outpaint_selections = [o.lower() for o in outpaint_selections]
|
216 |
+
base_model_additional_loras = []
|
217 |
+
raw_style_selections = copy.deepcopy(style_selections)
|
218 |
+
uov_method = uov_method.lower()
|
219 |
+
|
220 |
+
if fooocus_expansion in style_selections:
|
221 |
+
use_expansion = True
|
222 |
+
style_selections.remove(fooocus_expansion)
|
223 |
+
else:
|
224 |
+
use_expansion = False
|
225 |
+
|
226 |
+
use_style = len(style_selections) > 0
|
227 |
+
|
228 |
+
if base_model_name == refiner_model_name:
|
229 |
+
print(f'Refiner disabled because base model and refiner are same.')
|
230 |
+
refiner_model_name = 'None'
|
231 |
+
|
232 |
+
steps = performance_selection.steps()
|
233 |
+
|
234 |
+
if performance_selection == Performance.EXTREME_SPEED:
|
235 |
+
print('Enter LCM mode.')
|
236 |
+
progressbar(async_task, 1, 'Downloading LCM components ...')
|
237 |
+
loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
|
238 |
+
|
239 |
+
if refiner_model_name != 'None':
|
240 |
+
print(f'Refiner disabled in LCM mode.')
|
241 |
+
|
242 |
+
refiner_model_name = 'None'
|
243 |
+
sampler_name = 'lcm'
|
244 |
+
scheduler_name = 'lcm'
|
245 |
+
sharpness = 0.0
|
246 |
+
guidance_scale = 1.0
|
247 |
+
adaptive_cfg = 1.0
|
248 |
+
refiner_switch = 1.0
|
249 |
+
adm_scaler_positive = 1.0
|
250 |
+
adm_scaler_negative = 1.0
|
251 |
+
adm_scaler_end = 0.0
|
252 |
+
|
253 |
+
print(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
|
254 |
+
print(f'[Parameters] Sharpness = {sharpness}')
|
255 |
+
print(f'[Parameters] ControlNet Softness = {controlnet_softness}')
|
256 |
+
print(f'[Parameters] ADM Scale = '
|
257 |
+
f'{adm_scaler_positive} : '
|
258 |
+
f'{adm_scaler_negative} : '
|
259 |
+
f'{adm_scaler_end}')
|
260 |
+
|
261 |
+
patch_settings[pid] = PatchSettings(
|
262 |
+
sharpness,
|
263 |
+
adm_scaler_end,
|
264 |
+
adm_scaler_positive,
|
265 |
+
adm_scaler_negative,
|
266 |
+
controlnet_softness,
|
267 |
+
adaptive_cfg
|
268 |
+
)
|
269 |
+
|
270 |
+
cfg_scale = float(guidance_scale)
|
271 |
+
print(f'[Parameters] CFG = {cfg_scale}')
|
272 |
+
|
273 |
+
initial_latent = None
|
274 |
+
denoising_strength = 1.0
|
275 |
+
tiled = False
|
276 |
+
|
277 |
+
width, height = aspect_ratios_selection.replace('×', ' ').split(' ')[:2]
|
278 |
+
width, height = int(width), int(height)
|
279 |
+
|
280 |
+
skip_prompt_processing = False
|
281 |
+
|
282 |
+
inpaint_worker.current_task = None
|
283 |
+
inpaint_parameterized = inpaint_engine != 'None'
|
284 |
+
inpaint_image = None
|
285 |
+
inpaint_mask = None
|
286 |
+
inpaint_head_model_path = None
|
287 |
+
|
288 |
+
use_synthetic_refiner = False
|
289 |
+
|
290 |
+
controlnet_canny_path = None
|
291 |
+
controlnet_cpds_path = None
|
292 |
+
clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None
|
293 |
+
|
294 |
+
seed = int(image_seed)
|
295 |
+
print(f'[Parameters] Seed = {seed}')
|
296 |
+
|
297 |
+
goals = []
|
298 |
+
tasks = []
|
299 |
+
|
300 |
+
if input_image_checkbox:
|
301 |
+
if (current_tab == 'uov' or (
|
302 |
+
current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \
|
303 |
+
and uov_method != flags.disabled and uov_input_image is not None:
|
304 |
+
uov_input_image = HWC3(uov_input_image)
|
305 |
+
if 'vary' in uov_method:
|
306 |
+
goals.append('vary')
|
307 |
+
elif 'upscale' in uov_method:
|
308 |
+
goals.append('upscale')
|
309 |
+
if 'fast' in uov_method:
|
310 |
+
skip_prompt_processing = True
|
311 |
+
else:
|
312 |
+
steps = performance_selection.steps_uov()
|
313 |
+
|
314 |
+
progressbar(async_task, 1, 'Downloading upscale models ...')
|
315 |
+
modules.config.downloading_upscale_model()
|
316 |
+
if (current_tab == 'inpaint' or (
|
317 |
+
current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \
|
318 |
+
and isinstance(inpaint_input_image, dict):
|
319 |
+
inpaint_image = inpaint_input_image['image']
|
320 |
+
inpaint_mask = inpaint_input_image['mask'][:, :, 0]
|
321 |
+
|
322 |
+
if inpaint_mask_upload_checkbox:
|
323 |
+
if isinstance(inpaint_mask_image_upload, np.ndarray):
|
324 |
+
if inpaint_mask_image_upload.ndim == 3:
|
325 |
+
H, W, C = inpaint_image.shape
|
326 |
+
inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H)
|
327 |
+
inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2)
|
328 |
+
inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
|
329 |
+
inpaint_mask = np.maximum(inpaint_mask, inpaint_mask_image_upload)
|
330 |
+
|
331 |
+
if int(inpaint_erode_or_dilate) != 0:
|
332 |
+
inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate)
|
333 |
+
|
334 |
+
if invert_mask_checkbox:
|
335 |
+
inpaint_mask = 255 - inpaint_mask
|
336 |
+
|
337 |
+
inpaint_image = HWC3(inpaint_image)
|
338 |
+
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
|
339 |
+
and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
|
340 |
+
progressbar(async_task, 1, 'Downloading upscale models ...')
|
341 |
+
modules.config.downloading_upscale_model()
|
342 |
+
if inpaint_parameterized:
|
343 |
+
progressbar(async_task, 1, 'Downloading inpainter ...')
|
344 |
+
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
|
345 |
+
inpaint_engine)
|
346 |
+
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
|
347 |
+
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
348 |
+
if refiner_model_name == 'None':
|
349 |
+
use_synthetic_refiner = True
|
350 |
+
refiner_switch = 0.5
|
351 |
+
else:
|
352 |
+
inpaint_head_model_path, inpaint_patch_model_path = None, None
|
353 |
+
print(f'[Inpaint] Parameterized inpaint is disabled.')
|
354 |
+
if inpaint_additional_prompt != '':
|
355 |
+
if prompt == '':
|
356 |
+
prompt = inpaint_additional_prompt
|
357 |
+
else:
|
358 |
+
prompt = inpaint_additional_prompt + '\n' + prompt
|
359 |
+
goals.append('inpaint')
|
360 |
+
if current_tab == 'ip' or \
|
361 |
+
mixing_image_prompt_and_vary_upscale or \
|
362 |
+
mixing_image_prompt_and_inpaint:
|
363 |
+
goals.append('cn')
|
364 |
+
progressbar(async_task, 1, 'Downloading control models ...')
|
365 |
+
if len(cn_tasks[flags.cn_canny]) > 0:
|
366 |
+
controlnet_canny_path = modules.config.downloading_controlnet_canny()
|
367 |
+
if len(cn_tasks[flags.cn_cpds]) > 0:
|
368 |
+
controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
|
369 |
+
if len(cn_tasks[flags.cn_ip]) > 0:
|
370 |
+
clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip')
|
371 |
+
if len(cn_tasks[flags.cn_ip_face]) > 0:
|
372 |
+
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
|
373 |
+
'face')
|
374 |
+
progressbar(async_task, 1, 'Loading control models ...')
|
375 |
+
|
376 |
+
# Load or unload CNs
|
377 |
+
pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path])
|
378 |
+
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path)
|
379 |
+
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_face_path)
|
380 |
+
|
381 |
+
if overwrite_step > 0:
|
382 |
+
steps = overwrite_step
|
383 |
+
|
384 |
+
switch = int(round(steps * refiner_switch))
|
385 |
+
|
386 |
+
if overwrite_switch > 0:
|
387 |
+
switch = overwrite_switch
|
388 |
+
|
389 |
+
if overwrite_width > 0:
|
390 |
+
width = overwrite_width
|
391 |
+
|
392 |
+
if overwrite_height > 0:
|
393 |
+
height = overwrite_height
|
394 |
+
|
395 |
+
print(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
|
396 |
+
print(f'[Parameters] Steps = {steps} - {switch}')
|
397 |
+
|
398 |
+
progressbar(async_task, 1, 'Initializing ...')
|
399 |
+
|
400 |
+
if not skip_prompt_processing:
|
401 |
+
|
402 |
+
prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
|
403 |
+
negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
|
404 |
+
|
405 |
+
prompt = prompts[0]
|
406 |
+
negative_prompt = negative_prompts[0]
|
407 |
+
|
408 |
+
if prompt == '':
|
409 |
+
# disable expansion when empty since it is not meaningful and influences image prompt
|
410 |
+
use_expansion = False
|
411 |
+
|
412 |
+
extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
|
413 |
+
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
|
414 |
+
|
415 |
+
progressbar(async_task, 3, 'Loading models ...')
|
416 |
+
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
|
417 |
+
loras=loras, base_model_additional_loras=base_model_additional_loras,
|
418 |
+
use_synthetic_refiner=use_synthetic_refiner)
|
419 |
+
|
420 |
+
progressbar(async_task, 3, 'Processing prompts ...')
|
421 |
+
tasks = []
|
422 |
+
|
423 |
+
for i in range(image_number):
|
424 |
+
if disable_seed_increment:
|
425 |
+
task_seed = seed
|
426 |
+
else:
|
427 |
+
task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
|
428 |
+
|
429 |
+
task_rng = random.Random(task_seed) # may bind to inpaint noise in the future
|
430 |
+
task_prompt = apply_wildcards(prompt, task_rng)
|
431 |
+
task_prompt = apply_arrays(task_prompt, i)
|
432 |
+
task_negative_prompt = apply_wildcards(negative_prompt, task_rng)
|
433 |
+
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_positive_prompts]
|
434 |
+
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_negative_prompts]
|
435 |
+
|
436 |
+
positive_basic_workloads = []
|
437 |
+
negative_basic_workloads = []
|
438 |
+
|
439 |
+
if use_style:
|
440 |
+
for s in style_selections:
|
441 |
+
p, n = apply_style(s, positive=task_prompt)
|
442 |
+
positive_basic_workloads = positive_basic_workloads + p
|
443 |
+
negative_basic_workloads = negative_basic_workloads + n
|
444 |
+
else:
|
445 |
+
positive_basic_workloads.append(task_prompt)
|
446 |
+
|
447 |
+
negative_basic_workloads.append(task_negative_prompt) # Always use independent workload for negative.
|
448 |
+
|
449 |
+
positive_basic_workloads = positive_basic_workloads + task_extra_positive_prompts
|
450 |
+
negative_basic_workloads = negative_basic_workloads + task_extra_negative_prompts
|
451 |
+
|
452 |
+
positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=task_prompt)
|
453 |
+
negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=task_negative_prompt)
|
454 |
+
|
455 |
+
tasks.append(dict(
|
456 |
+
task_seed=task_seed,
|
457 |
+
task_prompt=task_prompt,
|
458 |
+
task_negative_prompt=task_negative_prompt,
|
459 |
+
positive=positive_basic_workloads,
|
460 |
+
negative=negative_basic_workloads,
|
461 |
+
expansion='',
|
462 |
+
c=None,
|
463 |
+
uc=None,
|
464 |
+
positive_top_k=len(positive_basic_workloads),
|
465 |
+
negative_top_k=len(negative_basic_workloads),
|
466 |
+
log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
|
467 |
+
log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
|
468 |
+
))
|
469 |
+
|
470 |
+
if use_expansion:
|
471 |
+
for i, t in enumerate(tasks):
|
472 |
+
progressbar(async_task, 5, f'Preparing Fooocus text #{i + 1} ...')
|
473 |
+
expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
|
474 |
+
print(f'[Prompt Expansion] {expansion}')
|
475 |
+
t['expansion'] = expansion
|
476 |
+
t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
|
477 |
+
|
478 |
+
for i, t in enumerate(tasks):
|
479 |
+
progressbar(async_task, 7, f'Encoding positive #{i + 1} ...')
|
480 |
+
t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
|
481 |
+
|
482 |
+
for i, t in enumerate(tasks):
|
483 |
+
if abs(float(cfg_scale) - 1.0) < 1e-4:
|
484 |
+
t['uc'] = pipeline.clone_cond(t['c'])
|
485 |
+
else:
|
486 |
+
progressbar(async_task, 10, f'Encoding negative #{i + 1} ...')
|
487 |
+
t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
|
488 |
+
|
489 |
+
if len(goals) > 0:
|
490 |
+
progressbar(async_task, 13, 'Image processing ...')
|
491 |
+
|
492 |
+
if 'vary' in goals:
|
493 |
+
if 'subtle' in uov_method:
|
494 |
+
denoising_strength = 0.5
|
495 |
+
if 'strong' in uov_method:
|
496 |
+
denoising_strength = 0.85
|
497 |
+
if overwrite_vary_strength > 0:
|
498 |
+
denoising_strength = overwrite_vary_strength
|
499 |
+
|
500 |
+
shape_ceil = get_image_shape_ceil(uov_input_image)
|
501 |
+
if shape_ceil < 1024:
|
502 |
+
print(f'[Vary] Image is resized because it is too small.')
|
503 |
+
shape_ceil = 1024
|
504 |
+
elif shape_ceil > 2048:
|
505 |
+
print(f'[Vary] Image is resized because it is too big.')
|
506 |
+
shape_ceil = 2048
|
507 |
+
|
508 |
+
uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
|
509 |
+
|
510 |
+
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
511 |
+
progressbar(async_task, 13, 'VAE encoding ...')
|
512 |
+
|
513 |
+
candidate_vae, _ = pipeline.get_candidate_vae(
|
514 |
+
steps=steps,
|
515 |
+
switch=switch,
|
516 |
+
denoise=denoising_strength,
|
517 |
+
refiner_swap_method=refiner_swap_method
|
518 |
+
)
|
519 |
+
|
520 |
+
initial_latent = core.encode_vae(vae=candidate_vae, pixels=initial_pixels)
|
521 |
+
B, C, H, W = initial_latent['samples'].shape
|
522 |
+
width = W * 8
|
523 |
+
height = H * 8
|
524 |
+
print(f'Final resolution is {str((height, width))}.')
|
525 |
+
|
526 |
+
if 'upscale' in goals:
|
527 |
+
H, W, C = uov_input_image.shape
|
528 |
+
progressbar(async_task, 13, f'Upscaling image from {str((H, W))} ...')
|
529 |
+
uov_input_image = perform_upscale(uov_input_image)
|
530 |
+
print(f'Image upscaled.')
|
531 |
+
|
532 |
+
if '1.5x' in uov_method:
|
533 |
+
f = 1.5
|
534 |
+
elif '2x' in uov_method:
|
535 |
+
f = 2.0
|
536 |
+
else:
|
537 |
+
f = 1.0
|
538 |
+
|
539 |
+
shape_ceil = get_shape_ceil(H * f, W * f)
|
540 |
+
|
541 |
+
if shape_ceil < 1024:
|
542 |
+
print(f'[Upscale] Image is resized because it is too small.')
|
543 |
+
uov_input_image = set_image_shape_ceil(uov_input_image, 1024)
|
544 |
+
shape_ceil = 1024
|
545 |
+
else:
|
546 |
+
uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f)
|
547 |
+
|
548 |
+
image_is_super_large = shape_ceil > 2800
|
549 |
+
|
550 |
+
if 'fast' in uov_method:
|
551 |
+
direct_return = True
|
552 |
+
elif image_is_super_large:
|
553 |
+
print('Image is too large. Directly returned the SR image. '
|
554 |
+
'Usually directly return SR image at 4K resolution '
|
555 |
+
'yields better results than SDXL diffusion.')
|
556 |
+
direct_return = True
|
557 |
+
else:
|
558 |
+
direct_return = False
|
559 |
+
|
560 |
+
if direct_return:
|
561 |
+
d = [('Upscale (Fast)', 'upscale_fast', '2x')]
|
562 |
+
uov_input_image_path = log(uov_input_image, d, output_format=output_format)
|
563 |
+
yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True)
|
564 |
+
return
|
565 |
+
|
566 |
+
tiled = True
|
567 |
+
denoising_strength = 0.382
|
568 |
+
|
569 |
+
if overwrite_upscale_strength > 0:
|
570 |
+
denoising_strength = overwrite_upscale_strength
|
571 |
+
|
572 |
+
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
573 |
+
progressbar(async_task, 13, 'VAE encoding ...')
|
574 |
+
|
575 |
+
candidate_vae, _ = pipeline.get_candidate_vae(
|
576 |
+
steps=steps,
|
577 |
+
switch=switch,
|
578 |
+
denoise=denoising_strength,
|
579 |
+
refiner_swap_method=refiner_swap_method
|
580 |
+
)
|
581 |
+
|
582 |
+
initial_latent = core.encode_vae(
|
583 |
+
vae=candidate_vae,
|
584 |
+
pixels=initial_pixels, tiled=True)
|
585 |
+
B, C, H, W = initial_latent['samples'].shape
|
586 |
+
width = W * 8
|
587 |
+
height = H * 8
|
588 |
+
print(f'Final resolution is {str((height, width))}.')
|
589 |
+
|
590 |
+
if 'inpaint' in goals:
|
591 |
+
if len(outpaint_selections) > 0:
|
592 |
+
H, W, C = inpaint_image.shape
|
593 |
+
if 'top' in outpaint_selections:
|
594 |
+
inpaint_image = np.pad(inpaint_image, [[int(H * 0.3), 0], [0, 0], [0, 0]], mode='edge')
|
595 |
+
inpaint_mask = np.pad(inpaint_mask, [[int(H * 0.3), 0], [0, 0]], mode='constant',
|
596 |
+
constant_values=255)
|
597 |
+
if 'bottom' in outpaint_selections:
|
598 |
+
inpaint_image = np.pad(inpaint_image, [[0, int(H * 0.3)], [0, 0], [0, 0]], mode='edge')
|
599 |
+
inpaint_mask = np.pad(inpaint_mask, [[0, int(H * 0.3)], [0, 0]], mode='constant',
|
600 |
+
constant_values=255)
|
601 |
+
|
602 |
+
H, W, C = inpaint_image.shape
|
603 |
+
if 'left' in outpaint_selections:
|
604 |
+
inpaint_image = np.pad(inpaint_image, [[0, 0], [int(H * 0.3), 0], [0, 0]], mode='edge')
|
605 |
+
inpaint_mask = np.pad(inpaint_mask, [[0, 0], [int(H * 0.3), 0]], mode='constant',
|
606 |
+
constant_values=255)
|
607 |
+
if 'right' in outpaint_selections:
|
608 |
+
inpaint_image = np.pad(inpaint_image, [[0, 0], [0, int(H * 0.3)], [0, 0]], mode='edge')
|
609 |
+
inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, int(H * 0.3)]], mode='constant',
|
610 |
+
constant_values=255)
|
611 |
+
|
612 |
+
inpaint_image = np.ascontiguousarray(inpaint_image.copy())
|
613 |
+
inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
|
614 |
+
inpaint_strength = 1.0
|
615 |
+
inpaint_respective_field = 1.0
|
616 |
+
|
617 |
+
denoising_strength = inpaint_strength
|
618 |
+
|
619 |
+
inpaint_worker.current_task = inpaint_worker.InpaintWorker(
|
620 |
+
image=inpaint_image,
|
621 |
+
mask=inpaint_mask,
|
622 |
+
use_fill=denoising_strength > 0.99,
|
623 |
+
k=inpaint_respective_field
|
624 |
+
)
|
625 |
+
|
626 |
+
if debugging_inpaint_preprocessor:
|
627 |
+
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(),
|
628 |
+
do_not_show_finished_images=True)
|
629 |
+
return
|
630 |
+
|
631 |
+
progressbar(async_task, 13, 'VAE Inpaint encoding ...')
|
632 |
+
|
633 |
+
inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
|
634 |
+
inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
|
635 |
+
inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask)
|
636 |
+
|
637 |
+
candidate_vae, candidate_vae_swap = pipeline.get_candidate_vae(
|
638 |
+
steps=steps,
|
639 |
+
switch=switch,
|
640 |
+
denoise=denoising_strength,
|
641 |
+
refiner_swap_method=refiner_swap_method
|
642 |
+
)
|
643 |
+
|
644 |
+
latent_inpaint, latent_mask = core.encode_vae_inpaint(
|
645 |
+
mask=inpaint_pixel_mask,
|
646 |
+
vae=candidate_vae,
|
647 |
+
pixels=inpaint_pixel_image)
|
648 |
+
|
649 |
+
latent_swap = None
|
650 |
+
if candidate_vae_swap is not None:
|
651 |
+
progressbar(async_task, 13, 'VAE SD15 encoding ...')
|
652 |
+
latent_swap = core.encode_vae(
|
653 |
+
vae=candidate_vae_swap,
|
654 |
+
pixels=inpaint_pixel_fill)['samples']
|
655 |
+
|
656 |
+
progressbar(async_task, 13, 'VAE encoding ...')
|
657 |
+
latent_fill = core.encode_vae(
|
658 |
+
vae=candidate_vae,
|
659 |
+
pixels=inpaint_pixel_fill)['samples']
|
660 |
+
|
661 |
+
inpaint_worker.current_task.load_latent(
|
662 |
+
latent_fill=latent_fill, latent_mask=latent_mask, latent_swap=latent_swap)
|
663 |
+
|
664 |
+
if inpaint_parameterized:
|
665 |
+
pipeline.final_unet = inpaint_worker.current_task.patch(
|
666 |
+
inpaint_head_model_path=inpaint_head_model_path,
|
667 |
+
inpaint_latent=latent_inpaint,
|
668 |
+
inpaint_latent_mask=latent_mask,
|
669 |
+
model=pipeline.final_unet
|
670 |
+
)
|
671 |
+
|
672 |
+
if not inpaint_disable_initial_latent:
|
673 |
+
initial_latent = {'samples': latent_fill}
|
674 |
+
|
675 |
+
B, C, H, W = latent_fill.shape
|
676 |
+
height, width = H * 8, W * 8
|
677 |
+
final_height, final_width = inpaint_worker.current_task.image.shape[:2]
|
678 |
+
print(f'Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.')
|
679 |
+
|
680 |
+
if 'cn' in goals:
|
681 |
+
for task in cn_tasks[flags.cn_canny]:
|
682 |
+
cn_img, cn_stop, cn_weight = task
|
683 |
+
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
684 |
+
|
685 |
+
if not skipping_cn_preprocessor:
|
686 |
+
cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold)
|
687 |
+
|
688 |
+
cn_img = HWC3(cn_img)
|
689 |
+
task[0] = core.numpy_to_pytorch(cn_img)
|
690 |
+
if debugging_cn_preprocessor:
|
691 |
+
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
692 |
+
return
|
693 |
+
for task in cn_tasks[flags.cn_cpds]:
|
694 |
+
cn_img, cn_stop, cn_weight = task
|
695 |
+
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
696 |
+
|
697 |
+
if not skipping_cn_preprocessor:
|
698 |
+
cn_img = preprocessors.cpds(cn_img)
|
699 |
+
|
700 |
+
cn_img = HWC3(cn_img)
|
701 |
+
task[0] = core.numpy_to_pytorch(cn_img)
|
702 |
+
if debugging_cn_preprocessor:
|
703 |
+
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
704 |
+
return
|
705 |
+
for task in cn_tasks[flags.cn_ip]:
|
706 |
+
cn_img, cn_stop, cn_weight = task
|
707 |
+
cn_img = HWC3(cn_img)
|
708 |
+
|
709 |
+
# https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
|
710 |
+
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
711 |
+
|
712 |
+
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
|
713 |
+
if debugging_cn_preprocessor:
|
714 |
+
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
715 |
+
return
|
716 |
+
for task in cn_tasks[flags.cn_ip_face]:
|
717 |
+
cn_img, cn_stop, cn_weight = task
|
718 |
+
cn_img = HWC3(cn_img)
|
719 |
+
|
720 |
+
if not skipping_cn_preprocessor:
|
721 |
+
cn_img = extras.face_crop.crop_image(cn_img)
|
722 |
+
|
723 |
+
# https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
|
724 |
+
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
725 |
+
|
726 |
+
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
|
727 |
+
if debugging_cn_preprocessor:
|
728 |
+
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
729 |
+
return
|
730 |
+
|
731 |
+
all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
|
732 |
+
|
733 |
+
if len(all_ip_tasks) > 0:
|
734 |
+
pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
|
735 |
+
|
736 |
+
if freeu_enabled:
|
737 |
+
print(f'FreeU is enabled!')
|
738 |
+
pipeline.final_unet = core.apply_freeu(
|
739 |
+
pipeline.final_unet,
|
740 |
+
freeu_b1,
|
741 |
+
freeu_b2,
|
742 |
+
freeu_s1,
|
743 |
+
freeu_s2
|
744 |
+
)
|
745 |
+
|
746 |
+
all_steps = steps * image_number
|
747 |
+
|
748 |
+
print(f'[Parameters] Denoising Strength = {denoising_strength}')
|
749 |
+
|
750 |
+
if isinstance(initial_latent, dict) and 'samples' in initial_latent:
|
751 |
+
log_shape = initial_latent['samples'].shape
|
752 |
+
else:
|
753 |
+
log_shape = f'Image Space {(height, width)}'
|
754 |
+
|
755 |
+
print(f'[Parameters] Initial Latent shape: {log_shape}')
|
756 |
+
|
757 |
+
preparation_time = time.perf_counter() - execution_start_time
|
758 |
+
print(f'Preparation time: {preparation_time:.2f} seconds')
|
759 |
+
|
760 |
+
final_sampler_name = sampler_name
|
761 |
+
final_scheduler_name = scheduler_name
|
762 |
+
|
763 |
+
if scheduler_name == 'lcm':
|
764 |
+
final_scheduler_name = 'sgm_uniform'
|
765 |
+
if pipeline.final_unet is not None:
|
766 |
+
pipeline.final_unet = core.opModelSamplingDiscrete.patch(
|
767 |
+
pipeline.final_unet,
|
768 |
+
sampling='lcm',
|
769 |
+
zsnr=False)[0]
|
770 |
+
if pipeline.final_refiner_unet is not None:
|
771 |
+
pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
|
772 |
+
pipeline.final_refiner_unet,
|
773 |
+
sampling='lcm',
|
774 |
+
zsnr=False)[0]
|
775 |
+
print('Using lcm scheduler.')
|
776 |
+
|
777 |
+
async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])
|
778 |
+
|
779 |
+
def callback(step, x0, x, total_steps, y):
|
780 |
+
done_steps = current_task_id * steps + step
|
781 |
+
async_task.yields.append(['preview', (
|
782 |
+
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
|
783 |
+
f'Step {step}/{total_steps} in the {current_task_id + 1}{ordinal_suffix(current_task_id + 1)} Sampling', y)])
|
784 |
+
|
785 |
+
for current_task_id, task in enumerate(tasks):
|
786 |
+
execution_start_time = time.perf_counter()
|
787 |
+
|
788 |
+
try:
|
789 |
+
if async_task.last_stop is not False:
|
790 |
+
ldm_patched.model_management.interrupt_current_processing()
|
791 |
+
positive_cond, negative_cond = task['c'], task['uc']
|
792 |
+
|
793 |
+
if 'cn' in goals:
|
794 |
+
for cn_flag, cn_path in [
|
795 |
+
(flags.cn_canny, controlnet_canny_path),
|
796 |
+
(flags.cn_cpds, controlnet_cpds_path)
|
797 |
+
]:
|
798 |
+
for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]:
|
799 |
+
positive_cond, negative_cond = core.apply_controlnet(
|
800 |
+
positive_cond, negative_cond,
|
801 |
+
pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop)
|
802 |
+
|
803 |
+
imgs = pipeline.process_diffusion(
|
804 |
+
positive_cond=positive_cond,
|
805 |
+
negative_cond=negative_cond,
|
806 |
+
steps=steps,
|
807 |
+
switch=switch,
|
808 |
+
width=width,
|
809 |
+
height=height,
|
810 |
+
image_seed=task['task_seed'],
|
811 |
+
callback=callback,
|
812 |
+
sampler_name=final_sampler_name,
|
813 |
+
scheduler_name=final_scheduler_name,
|
814 |
+
latent=initial_latent,
|
815 |
+
denoise=denoising_strength,
|
816 |
+
tiled=tiled,
|
817 |
+
cfg_scale=cfg_scale,
|
818 |
+
refiner_swap_method=refiner_swap_method,
|
819 |
+
disable_preview=disable_preview
|
820 |
+
)
|
821 |
+
|
822 |
+
del task['c'], task['uc'], positive_cond, negative_cond # Save memory
|
823 |
+
|
824 |
+
if inpaint_worker.current_task is not None:
|
825 |
+
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
|
826 |
+
|
827 |
+
img_paths = []
|
828 |
+
for x in imgs:
|
829 |
+
d = [('Prompt', 'prompt', task['log_positive_prompt']),
|
830 |
+
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
|
831 |
+
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
|
832 |
+
('Styles', 'styles', str(raw_style_selections)),
|
833 |
+
('Performance', 'performance', performance_selection.value)]
|
834 |
+
|
835 |
+
if performance_selection.steps() != steps:
|
836 |
+
d.append(('Steps', 'steps', steps))
|
837 |
+
|
838 |
+
d += [('Resolution', 'resolution', str((width, height))),
|
839 |
+
('Guidance Scale', 'guidance_scale', guidance_scale),
|
840 |
+
('Sharpness', 'sharpness', sharpness),
|
841 |
+
('ADM Guidance', 'adm_guidance', str((
|
842 |
+
modules.patch.patch_settings[pid].positive_adm_scale,
|
843 |
+
modules.patch.patch_settings[pid].negative_adm_scale,
|
844 |
+
modules.patch.patch_settings[pid].adm_scaler_end))),
|
845 |
+
('Base Model', 'base_model', base_model_name),
|
846 |
+
('Refiner Model', 'refiner_model', refiner_model_name),
|
847 |
+
('Refiner Switch', 'refiner_switch', refiner_switch)]
|
848 |
+
|
849 |
+
if refiner_model_name != 'None':
|
850 |
+
if overwrite_switch > 0:
|
851 |
+
d.append(('Overwrite Switch', 'overwrite_switch', overwrite_switch))
|
852 |
+
if refiner_swap_method != flags.refiner_swap_method:
|
853 |
+
d.append(('Refiner Swap Method', 'refiner_swap_method', refiner_swap_method))
|
854 |
+
if modules.patch.patch_settings[pid].adaptive_cfg != modules.config.default_cfg_tsnr:
|
855 |
+
d.append(('CFG Mimicking from TSNR', 'adaptive_cfg', modules.patch.patch_settings[pid].adaptive_cfg))
|
856 |
+
|
857 |
+
d.append(('Sampler', 'sampler', sampler_name))
|
858 |
+
d.append(('Scheduler', 'scheduler', scheduler_name))
|
859 |
+
d.append(('Seed', 'seed', task['task_seed']))
|
860 |
+
|
861 |
+
if freeu_enabled:
|
862 |
+
d.append(('FreeU', 'freeu', str((freeu_b1, freeu_b2, freeu_s1, freeu_s2))))
|
863 |
+
|
864 |
+
for li, (n, w) in enumerate(loras):
|
865 |
+
if n != 'None':
|
866 |
+
d.append((f'LoRA {li + 1}', f'lora_combined_{li + 1}', f'{n} : {w}'))
|
867 |
+
|
868 |
+
metadata_parser = None
|
869 |
+
if save_metadata_to_images:
|
870 |
+
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
|
871 |
+
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
|
872 |
+
task['log_negative_prompt'], task['negative'],
|
873 |
+
steps, base_model_name, refiner_model_name, loras)
|
874 |
+
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
|
875 |
+
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
|
876 |
+
img_paths.append(log(x, d, metadata_parser, output_format))
|
877 |
+
|
878 |
+
yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
|
879 |
+
except ldm_patched.modules.model_management.InterruptProcessingException as e:
|
880 |
+
if async_task.last_stop == 'skip':
|
881 |
+
print('User skipped')
|
882 |
+
async_task.last_stop = False
|
883 |
+
continue
|
884 |
+
else:
|
885 |
+
print('User stopped')
|
886 |
+
break
|
887 |
+
|
888 |
+
execution_time = time.perf_counter() - execution_start_time
|
889 |
+
print(f'Generating and saving time: {execution_time:.2f} seconds')
|
890 |
+
async_task.processing = False
|
891 |
+
return
|
892 |
+
|
893 |
+
while True:
|
894 |
+
time.sleep(0.01)
|
895 |
+
if len(async_tasks) > 0:
|
896 |
+
task = async_tasks.pop(0)
|
897 |
+
generate_image_grid = task.args.pop(0)
|
898 |
+
|
899 |
+
try:
|
900 |
+
handler(task)
|
901 |
+
if generate_image_grid:
|
902 |
+
build_image_wall(task)
|
903 |
+
task.yields.append(['finish', task.results])
|
904 |
+
pipeline.prepare_text_encoder(async_call=True)
|
905 |
+
except:
|
906 |
+
traceback.print_exc()
|
907 |
+
task.yields.append(['finish', task.results])
|
908 |
+
finally:
|
909 |
+
if pid in modules.patch.patch_settings:
|
910 |
+
del modules.patch.patch_settings[pid]
|
911 |
+
pass
|
912 |
+
|
913 |
+
|
914 |
+
threading.Thread(target=worker, daemon=True).start()
|
modules/auth.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import hashlib
|
3 |
+
import modules.constants as constants
|
4 |
+
|
5 |
+
from os.path import exists
|
6 |
+
|
7 |
+
|
8 |
+
def auth_list_to_dict(auth_list):
|
9 |
+
auth_dict = {}
|
10 |
+
for auth_data in auth_list:
|
11 |
+
if 'user' in auth_data:
|
12 |
+
if 'hash' in auth_data:
|
13 |
+
auth_dict |= {auth_data['user']: auth_data['hash']}
|
14 |
+
elif 'pass' in auth_data:
|
15 |
+
auth_dict |= {auth_data['user']: hashlib.sha256(bytes(auth_data['pass'], encoding='utf-8')).hexdigest()}
|
16 |
+
return auth_dict
|
17 |
+
|
18 |
+
|
19 |
+
def load_auth_data(filename=None):
|
20 |
+
auth_dict = None
|
21 |
+
if filename != None and exists(filename):
|
22 |
+
with open(filename, encoding='utf-8') as auth_file:
|
23 |
+
try:
|
24 |
+
auth_obj = json.load(auth_file)
|
25 |
+
if isinstance(auth_obj, list) and len(auth_obj) > 0:
|
26 |
+
auth_dict = auth_list_to_dict(auth_obj)
|
27 |
+
except Exception as e:
|
28 |
+
print('load_auth_data, e: ' + str(e))
|
29 |
+
return auth_dict
|
30 |
+
|
31 |
+
|
32 |
+
auth_dict = load_auth_data(constants.AUTH_FILENAME)
|
33 |
+
|
34 |
+
auth_enabled = auth_dict != None
|
35 |
+
|
36 |
+
|
37 |
+
def check_auth(user, password):
|
38 |
+
if user not in auth_dict:
|
39 |
+
return False
|
40 |
+
else:
|
41 |
+
return hashlib.sha256(bytes(password, encoding='utf-8')).hexdigest() == auth_dict[user]
|
modules/config.py
ADDED
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
import numbers
|
5 |
+
import args_manager
|
6 |
+
import modules.flags
|
7 |
+
import modules.sdxl_styles
|
8 |
+
|
9 |
+
from modules.model_loader import load_file_from_url
|
10 |
+
from modules.util import get_files_from_folder, makedirs_with_log
|
11 |
+
from modules.flags import Performance, MetadataScheme
|
12 |
+
|
13 |
+
def get_config_path(key, default_value):
|
14 |
+
env = os.getenv(key)
|
15 |
+
if env is not None and isinstance(env, str):
|
16 |
+
print(f"Environment: {key} = {env}")
|
17 |
+
return env
|
18 |
+
else:
|
19 |
+
return os.path.abspath(default_value)
|
20 |
+
|
21 |
+
config_path = get_config_path('config_path', "./config.txt")
|
22 |
+
config_example_path = get_config_path('config_example_path', "config_modification_tutorial.txt")
|
23 |
+
config_dict = {}
|
24 |
+
always_save_keys = []
|
25 |
+
visited_keys = []
|
26 |
+
|
27 |
+
try:
|
28 |
+
with open(os.path.abspath(f'./presets/default.json'), "r", encoding="utf-8") as json_file:
|
29 |
+
config_dict.update(json.load(json_file))
|
30 |
+
except Exception as e:
|
31 |
+
print(f'Load default preset failed.')
|
32 |
+
print(e)
|
33 |
+
|
34 |
+
try:
|
35 |
+
if os.path.exists(config_path):
|
36 |
+
with open(config_path, "r", encoding="utf-8") as json_file:
|
37 |
+
config_dict.update(json.load(json_file))
|
38 |
+
always_save_keys = list(config_dict.keys())
|
39 |
+
except Exception as e:
|
40 |
+
print(f'Failed to load config file "{config_path}" . The reason is: {str(e)}')
|
41 |
+
print('Please make sure that:')
|
42 |
+
print(f'1. The file "{config_path}" is a valid text file, and you have access to read it.')
|
43 |
+
print('2. Use "\\\\" instead of "\\" when describing paths.')
|
44 |
+
print('3. There is no "," before the last "}".')
|
45 |
+
print('4. All key/value formats are correct.')
|
46 |
+
|
47 |
+
|
48 |
+
def try_load_deprecated_user_path_config():
|
49 |
+
global config_dict
|
50 |
+
|
51 |
+
if not os.path.exists('user_path_config.txt'):
|
52 |
+
return
|
53 |
+
|
54 |
+
try:
|
55 |
+
deprecated_config_dict = json.load(open('user_path_config.txt', "r", encoding="utf-8"))
|
56 |
+
|
57 |
+
def replace_config(old_key, new_key):
|
58 |
+
if old_key in deprecated_config_dict:
|
59 |
+
config_dict[new_key] = deprecated_config_dict[old_key]
|
60 |
+
del deprecated_config_dict[old_key]
|
61 |
+
|
62 |
+
replace_config('modelfile_path', 'path_checkpoints')
|
63 |
+
replace_config('lorafile_path', 'path_loras')
|
64 |
+
replace_config('embeddings_path', 'path_embeddings')
|
65 |
+
replace_config('vae_approx_path', 'path_vae_approx')
|
66 |
+
replace_config('upscale_models_path', 'path_upscale_models')
|
67 |
+
replace_config('inpaint_models_path', 'path_inpaint')
|
68 |
+
replace_config('controlnet_models_path', 'path_controlnet')
|
69 |
+
replace_config('clip_vision_models_path', 'path_clip_vision')
|
70 |
+
replace_config('fooocus_expansion_path', 'path_fooocus_expansion')
|
71 |
+
replace_config('temp_outputs_path', 'path_outputs')
|
72 |
+
|
73 |
+
if deprecated_config_dict.get("default_model", None) == 'juggernautXL_version6Rundiffusion.safetensors':
|
74 |
+
os.replace('user_path_config.txt', 'user_path_config-deprecated.txt')
|
75 |
+
print('Config updated successfully in silence. '
|
76 |
+
'A backup of previous config is written to "user_path_config-deprecated.txt".')
|
77 |
+
return
|
78 |
+
|
79 |
+
if input("Newer models and configs are available. "
|
80 |
+
"Download and update files? [Y/n]:") in ['n', 'N', 'No', 'no', 'NO']:
|
81 |
+
config_dict.update(deprecated_config_dict)
|
82 |
+
print('Loading using deprecated old models and deprecated old configs.')
|
83 |
+
return
|
84 |
+
else:
|
85 |
+
os.replace('user_path_config.txt', 'user_path_config-deprecated.txt')
|
86 |
+
print('Config updated successfully by user. '
|
87 |
+
'A backup of previous config is written to "user_path_config-deprecated.txt".')
|
88 |
+
return
|
89 |
+
except Exception as e:
|
90 |
+
print('Processing deprecated config failed')
|
91 |
+
print(e)
|
92 |
+
return
|
93 |
+
|
94 |
+
|
95 |
+
try_load_deprecated_user_path_config()
|
96 |
+
|
97 |
+
preset = args_manager.args.preset
|
98 |
+
|
99 |
+
if isinstance(preset, str):
|
100 |
+
preset_path = os.path.abspath(f'./presets/{preset}.json')
|
101 |
+
try:
|
102 |
+
if os.path.exists(preset_path):
|
103 |
+
with open(preset_path, "r", encoding="utf-8") as json_file:
|
104 |
+
config_dict.update(json.load(json_file))
|
105 |
+
print(f'Loaded preset: {preset_path}')
|
106 |
+
else:
|
107 |
+
raise FileNotFoundError
|
108 |
+
except Exception as e:
|
109 |
+
print(f'Load preset [{preset_path}] failed')
|
110 |
+
print(e)
|
111 |
+
|
112 |
+
|
113 |
+
def get_path_output() -> str:
|
114 |
+
"""
|
115 |
+
Checking output path argument and overriding default path.
|
116 |
+
"""
|
117 |
+
global config_dict
|
118 |
+
path_output = get_dir_or_set_default('path_outputs', '../outputs/', make_directory=True)
|
119 |
+
if args_manager.args.output_path:
|
120 |
+
print(f'[CONFIG] Overriding config value path_outputs with {args_manager.args.output_path}')
|
121 |
+
config_dict['path_outputs'] = path_output = args_manager.args.output_path
|
122 |
+
return path_output
|
123 |
+
|
124 |
+
|
125 |
+
def get_dir_or_set_default(key, default_value, as_array=False, make_directory=False):
|
126 |
+
global config_dict, visited_keys, always_save_keys
|
127 |
+
|
128 |
+
if key not in visited_keys:
|
129 |
+
visited_keys.append(key)
|
130 |
+
|
131 |
+
if key not in always_save_keys:
|
132 |
+
always_save_keys.append(key)
|
133 |
+
|
134 |
+
v = os.getenv(key)
|
135 |
+
if v is not None:
|
136 |
+
print(f"Environment: {key} = {v}")
|
137 |
+
config_dict[key] = v
|
138 |
+
else:
|
139 |
+
v = config_dict.get(key, None)
|
140 |
+
|
141 |
+
if isinstance(v, str):
|
142 |
+
if make_directory:
|
143 |
+
makedirs_with_log(v)
|
144 |
+
if os.path.exists(v) and os.path.isdir(v):
|
145 |
+
return v if not as_array else [v]
|
146 |
+
elif isinstance(v, list):
|
147 |
+
if make_directory:
|
148 |
+
for d in v:
|
149 |
+
makedirs_with_log(d)
|
150 |
+
if all([os.path.exists(d) and os.path.isdir(d) for d in v]):
|
151 |
+
return v
|
152 |
+
|
153 |
+
if v is not None:
|
154 |
+
print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.')
|
155 |
+
if isinstance(default_value, list):
|
156 |
+
dp = []
|
157 |
+
for path in default_value:
|
158 |
+
abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path))
|
159 |
+
dp.append(abs_path)
|
160 |
+
os.makedirs(abs_path, exist_ok=True)
|
161 |
+
else:
|
162 |
+
dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
|
163 |
+
os.makedirs(dp, exist_ok=True)
|
164 |
+
if as_array:
|
165 |
+
dp = [dp]
|
166 |
+
config_dict[key] = dp
|
167 |
+
return dp
|
168 |
+
|
169 |
+
|
170 |
+
paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True)
|
171 |
+
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
|
172 |
+
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
|
173 |
+
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
|
174 |
+
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
|
175 |
+
path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/')
|
176 |
+
path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/')
|
177 |
+
path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
|
178 |
+
path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
|
179 |
+
path_outputs = get_path_output()
|
180 |
+
|
181 |
+
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
|
182 |
+
global config_dict, visited_keys
|
183 |
+
|
184 |
+
if key not in visited_keys:
|
185 |
+
visited_keys.append(key)
|
186 |
+
|
187 |
+
v = os.getenv(key)
|
188 |
+
if v is not None:
|
189 |
+
print(f"Environment: {key} = {v}")
|
190 |
+
config_dict[key] = v
|
191 |
+
|
192 |
+
if key not in config_dict:
|
193 |
+
config_dict[key] = default_value
|
194 |
+
return default_value
|
195 |
+
|
196 |
+
v = config_dict.get(key, None)
|
197 |
+
if not disable_empty_as_none:
|
198 |
+
if v is None or v == '':
|
199 |
+
v = 'None'
|
200 |
+
if validator(v):
|
201 |
+
return v
|
202 |
+
else:
|
203 |
+
if v is not None:
|
204 |
+
print(f'Failed to load config key: {json.dumps({key:v})} is invalid; will use {json.dumps({key:default_value})} instead.')
|
205 |
+
config_dict[key] = default_value
|
206 |
+
return default_value
|
207 |
+
|
208 |
+
|
209 |
+
default_base_model_name = get_config_item_or_set_default(
|
210 |
+
key='default_model',
|
211 |
+
default_value='model.safetensors',
|
212 |
+
validator=lambda x: isinstance(x, str)
|
213 |
+
)
|
214 |
+
previous_default_models = get_config_item_or_set_default(
|
215 |
+
key='previous_default_models',
|
216 |
+
default_value=[],
|
217 |
+
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x)
|
218 |
+
)
|
219 |
+
default_refiner_model_name = get_config_item_or_set_default(
|
220 |
+
key='default_refiner',
|
221 |
+
default_value='None',
|
222 |
+
validator=lambda x: isinstance(x, str)
|
223 |
+
)
|
224 |
+
default_refiner_switch = get_config_item_or_set_default(
|
225 |
+
key='default_refiner_switch',
|
226 |
+
default_value=0.8,
|
227 |
+
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1
|
228 |
+
)
|
229 |
+
default_loras_min_weight = get_config_item_or_set_default(
|
230 |
+
key='default_loras_min_weight',
|
231 |
+
default_value=-2,
|
232 |
+
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
|
233 |
+
)
|
234 |
+
default_loras_max_weight = get_config_item_or_set_default(
|
235 |
+
key='default_loras_max_weight',
|
236 |
+
default_value=2,
|
237 |
+
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
|
238 |
+
)
|
239 |
+
default_loras = get_config_item_or_set_default(
|
240 |
+
key='default_loras',
|
241 |
+
default_value=[
|
242 |
+
[
|
243 |
+
"None",
|
244 |
+
1.0
|
245 |
+
],
|
246 |
+
[
|
247 |
+
"None",
|
248 |
+
1.0
|
249 |
+
],
|
250 |
+
[
|
251 |
+
"None",
|
252 |
+
1.0
|
253 |
+
],
|
254 |
+
[
|
255 |
+
"None",
|
256 |
+
1.0
|
257 |
+
],
|
258 |
+
[
|
259 |
+
"None",
|
260 |
+
1.0
|
261 |
+
]
|
262 |
+
],
|
263 |
+
validator=lambda x: isinstance(x, list) and all(len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) for y in x)
|
264 |
+
)
|
265 |
+
default_max_lora_number = get_config_item_or_set_default(
|
266 |
+
key='default_max_lora_number',
|
267 |
+
default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5,
|
268 |
+
validator=lambda x: isinstance(x, int) and x >= 1
|
269 |
+
)
|
270 |
+
default_cfg_scale = get_config_item_or_set_default(
|
271 |
+
key='default_cfg_scale',
|
272 |
+
default_value=7.0,
|
273 |
+
validator=lambda x: isinstance(x, numbers.Number)
|
274 |
+
)
|
275 |
+
default_sample_sharpness = get_config_item_or_set_default(
|
276 |
+
key='default_sample_sharpness',
|
277 |
+
default_value=2.0,
|
278 |
+
validator=lambda x: isinstance(x, numbers.Number)
|
279 |
+
)
|
280 |
+
default_sampler = get_config_item_or_set_default(
|
281 |
+
key='default_sampler',
|
282 |
+
default_value='dpmpp_2m_sde_gpu',
|
283 |
+
validator=lambda x: x in modules.flags.sampler_list
|
284 |
+
)
|
285 |
+
default_scheduler = get_config_item_or_set_default(
|
286 |
+
key='default_scheduler',
|
287 |
+
default_value='karras',
|
288 |
+
validator=lambda x: x in modules.flags.scheduler_list
|
289 |
+
)
|
290 |
+
default_styles = get_config_item_or_set_default(
|
291 |
+
key='default_styles',
|
292 |
+
default_value=[
|
293 |
+
"Fooocus V2",
|
294 |
+
"Fooocus Enhance",
|
295 |
+
"Fooocus Sharp"
|
296 |
+
],
|
297 |
+
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
|
298 |
+
)
|
299 |
+
default_prompt_negative = get_config_item_or_set_default(
|
300 |
+
key='default_prompt_negative',
|
301 |
+
default_value='',
|
302 |
+
validator=lambda x: isinstance(x, str),
|
303 |
+
disable_empty_as_none=True
|
304 |
+
)
|
305 |
+
default_prompt = get_config_item_or_set_default(
|
306 |
+
key='default_prompt',
|
307 |
+
default_value='',
|
308 |
+
validator=lambda x: isinstance(x, str),
|
309 |
+
disable_empty_as_none=True
|
310 |
+
)
|
311 |
+
default_performance = get_config_item_or_set_default(
|
312 |
+
key='default_performance',
|
313 |
+
default_value=Performance.SPEED.value,
|
314 |
+
validator=lambda x: x in Performance.list()
|
315 |
+
)
|
316 |
+
default_advanced_checkbox = get_config_item_or_set_default(
|
317 |
+
key='default_advanced_checkbox',
|
318 |
+
default_value=False,
|
319 |
+
validator=lambda x: isinstance(x, bool)
|
320 |
+
)
|
321 |
+
default_max_image_number = get_config_item_or_set_default(
|
322 |
+
key='default_max_image_number',
|
323 |
+
default_value=32,
|
324 |
+
validator=lambda x: isinstance(x, int) and x >= 1
|
325 |
+
)
|
326 |
+
default_output_format = get_config_item_or_set_default(
|
327 |
+
key='default_output_format',
|
328 |
+
default_value='png',
|
329 |
+
validator=lambda x: x in modules.flags.output_formats
|
330 |
+
)
|
331 |
+
default_image_number = get_config_item_or_set_default(
|
332 |
+
key='default_image_number',
|
333 |
+
default_value=2,
|
334 |
+
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number
|
335 |
+
)
|
336 |
+
checkpoint_downloads = get_config_item_or_set_default(
|
337 |
+
key='checkpoint_downloads',
|
338 |
+
default_value={},
|
339 |
+
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
|
340 |
+
)
|
341 |
+
lora_downloads = get_config_item_or_set_default(
|
342 |
+
key='lora_downloads',
|
343 |
+
default_value={},
|
344 |
+
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
|
345 |
+
)
|
346 |
+
embeddings_downloads = get_config_item_or_set_default(
|
347 |
+
key='embeddings_downloads',
|
348 |
+
default_value={},
|
349 |
+
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
|
350 |
+
)
|
351 |
+
available_aspect_ratios = get_config_item_or_set_default(
|
352 |
+
key='available_aspect_ratios',
|
353 |
+
default_value=[
|
354 |
+
'704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152',
|
355 |
+
'896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960',
|
356 |
+
'1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768',
|
357 |
+
'1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640',
|
358 |
+
'1664*576', '1728*576'
|
359 |
+
],
|
360 |
+
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1
|
361 |
+
)
|
362 |
+
default_aspect_ratio = get_config_item_or_set_default(
|
363 |
+
key='default_aspect_ratio',
|
364 |
+
default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0],
|
365 |
+
validator=lambda x: x in available_aspect_ratios
|
366 |
+
)
|
367 |
+
default_inpaint_engine_version = get_config_item_or_set_default(
|
368 |
+
key='default_inpaint_engine_version',
|
369 |
+
default_value='v2.6',
|
370 |
+
validator=lambda x: x in modules.flags.inpaint_engine_versions
|
371 |
+
)
|
372 |
+
default_cfg_tsnr = get_config_item_or_set_default(
|
373 |
+
key='default_cfg_tsnr',
|
374 |
+
default_value=7.0,
|
375 |
+
validator=lambda x: isinstance(x, numbers.Number)
|
376 |
+
)
|
377 |
+
default_overwrite_step = get_config_item_or_set_default(
|
378 |
+
key='default_overwrite_step',
|
379 |
+
default_value=-1,
|
380 |
+
validator=lambda x: isinstance(x, int)
|
381 |
+
)
|
382 |
+
default_overwrite_switch = get_config_item_or_set_default(
|
383 |
+
key='default_overwrite_switch',
|
384 |
+
default_value=-1,
|
385 |
+
validator=lambda x: isinstance(x, int)
|
386 |
+
)
|
387 |
+
example_inpaint_prompts = get_config_item_or_set_default(
|
388 |
+
key='example_inpaint_prompts',
|
389 |
+
default_value=[
|
390 |
+
'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes'
|
391 |
+
],
|
392 |
+
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x)
|
393 |
+
)
|
394 |
+
default_save_metadata_to_images = get_config_item_or_set_default(
|
395 |
+
key='default_save_metadata_to_images',
|
396 |
+
default_value=False,
|
397 |
+
validator=lambda x: isinstance(x, bool)
|
398 |
+
)
|
399 |
+
default_metadata_scheme = get_config_item_or_set_default(
|
400 |
+
key='default_metadata_scheme',
|
401 |
+
default_value=MetadataScheme.FOOOCUS.value,
|
402 |
+
validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x]
|
403 |
+
)
|
404 |
+
metadata_created_by = get_config_item_or_set_default(
|
405 |
+
key='metadata_created_by',
|
406 |
+
default_value='',
|
407 |
+
validator=lambda x: isinstance(x, str)
|
408 |
+
)
|
409 |
+
|
410 |
+
example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
|
411 |
+
|
412 |
+
config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [['None', 1.0] for _ in range(default_max_lora_number - len(default_loras))]
|
413 |
+
|
414 |
+
possible_preset_keys = [
|
415 |
+
"default_model",
|
416 |
+
"default_refiner",
|
417 |
+
"default_refiner_switch",
|
418 |
+
"default_loras_min_weight",
|
419 |
+
"default_loras_max_weight",
|
420 |
+
"default_loras",
|
421 |
+
"default_max_lora_number",
|
422 |
+
"default_cfg_scale",
|
423 |
+
"default_sample_sharpness",
|
424 |
+
"default_sampler",
|
425 |
+
"default_scheduler",
|
426 |
+
"default_performance",
|
427 |
+
"default_prompt",
|
428 |
+
"default_prompt_negative",
|
429 |
+
"default_styles",
|
430 |
+
"default_aspect_ratio",
|
431 |
+
"default_save_metadata_to_images",
|
432 |
+
"checkpoint_downloads",
|
433 |
+
"embeddings_downloads",
|
434 |
+
"lora_downloads",
|
435 |
+
]
|
436 |
+
|
437 |
+
|
438 |
+
REWRITE_PRESET = False
|
439 |
+
|
440 |
+
if REWRITE_PRESET and isinstance(args_manager.args.preset, str):
|
441 |
+
save_path = 'presets/' + args_manager.args.preset + '.json'
|
442 |
+
with open(save_path, "w", encoding="utf-8") as json_file:
|
443 |
+
json.dump({k: config_dict[k] for k in possible_preset_keys}, json_file, indent=4)
|
444 |
+
print(f'Preset saved to {save_path}. Exiting ...')
|
445 |
+
exit(0)
|
446 |
+
|
447 |
+
|
448 |
+
def add_ratio(x):
|
449 |
+
a, b = x.replace('*', ' ').split(' ')[:2]
|
450 |
+
a, b = int(a), int(b)
|
451 |
+
g = math.gcd(a, b)
|
452 |
+
return f'{a}×{b} <span style="color: grey;"> \U00002223 {a // g}:{b // g}</span>'
|
453 |
+
|
454 |
+
|
455 |
+
default_aspect_ratio = add_ratio(default_aspect_ratio)
|
456 |
+
available_aspect_ratios = [add_ratio(x) for x in available_aspect_ratios]
|
457 |
+
|
458 |
+
|
459 |
+
# Only write config in the first launch.
|
460 |
+
if not os.path.exists(config_path):
|
461 |
+
with open(config_path, "w", encoding="utf-8") as json_file:
|
462 |
+
json.dump({k: config_dict[k] for k in always_save_keys}, json_file, indent=4)
|
463 |
+
|
464 |
+
|
465 |
+
# Always write tutorials.
|
466 |
+
with open(config_example_path, "w", encoding="utf-8") as json_file:
|
467 |
+
cpa = config_path.replace("\\", "\\\\")
|
468 |
+
json_file.write(f'You can modify your "{cpa}" using the below keys, formats, and examples.\n'
|
469 |
+
f'Do not modify this file. Modifications in this file will not take effect.\n'
|
470 |
+
f'This file is a tutorial and example. Please edit "{cpa}" to really change any settings.\n'
|
471 |
+
+ 'Remember to split the paths with "\\\\" rather than "\\", '
|
472 |
+
'and there is no "," before the last "}". \n\n\n')
|
473 |
+
json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4)
|
474 |
+
|
475 |
+
model_filenames = []
|
476 |
+
lora_filenames = []
|
477 |
+
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
|
478 |
+
|
479 |
+
|
480 |
+
def get_model_filenames(folder_paths, name_filter=None):
|
481 |
+
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
482 |
+
files = []
|
483 |
+
for folder in folder_paths:
|
484 |
+
files += get_files_from_folder(folder, extensions, name_filter)
|
485 |
+
return files
|
486 |
+
|
487 |
+
|
488 |
+
def update_all_model_names():
|
489 |
+
global model_filenames, lora_filenames
|
490 |
+
model_filenames = get_model_filenames(paths_checkpoints)
|
491 |
+
lora_filenames = get_model_filenames(paths_loras)
|
492 |
+
return
|
493 |
+
|
494 |
+
|
495 |
+
def downloading_inpaint_models(v):
|
496 |
+
assert v in modules.flags.inpaint_engine_versions
|
497 |
+
|
498 |
+
load_file_from_url(
|
499 |
+
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth',
|
500 |
+
model_dir=path_inpaint,
|
501 |
+
file_name='fooocus_inpaint_head.pth'
|
502 |
+
)
|
503 |
+
head_file = os.path.join(path_inpaint, 'fooocus_inpaint_head.pth')
|
504 |
+
patch_file = None
|
505 |
+
|
506 |
+
if v == 'v1':
|
507 |
+
load_file_from_url(
|
508 |
+
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch',
|
509 |
+
model_dir=path_inpaint,
|
510 |
+
file_name='inpaint.fooocus.patch'
|
511 |
+
)
|
512 |
+
patch_file = os.path.join(path_inpaint, 'inpaint.fooocus.patch')
|
513 |
+
|
514 |
+
if v == 'v2.5':
|
515 |
+
load_file_from_url(
|
516 |
+
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch',
|
517 |
+
model_dir=path_inpaint,
|
518 |
+
file_name='inpaint_v25.fooocus.patch'
|
519 |
+
)
|
520 |
+
patch_file = os.path.join(path_inpaint, 'inpaint_v25.fooocus.patch')
|
521 |
+
|
522 |
+
if v == 'v2.6':
|
523 |
+
load_file_from_url(
|
524 |
+
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v26.fooocus.patch',
|
525 |
+
model_dir=path_inpaint,
|
526 |
+
file_name='inpaint_v26.fooocus.patch'
|
527 |
+
)
|
528 |
+
patch_file = os.path.join(path_inpaint, 'inpaint_v26.fooocus.patch')
|
529 |
+
|
530 |
+
return head_file, patch_file
|
531 |
+
|
532 |
+
|
533 |
+
def downloading_sdxl_lcm_lora():
|
534 |
+
load_file_from_url(
|
535 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
|
536 |
+
model_dir=paths_loras[0],
|
537 |
+
file_name=sdxl_lcm_lora
|
538 |
+
)
|
539 |
+
return sdxl_lcm_lora
|
540 |
+
|
541 |
+
|
542 |
+
def downloading_controlnet_canny():
|
543 |
+
load_file_from_url(
|
544 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors',
|
545 |
+
model_dir=path_controlnet,
|
546 |
+
file_name='control-lora-canny-rank128.safetensors'
|
547 |
+
)
|
548 |
+
return os.path.join(path_controlnet, 'control-lora-canny-rank128.safetensors')
|
549 |
+
|
550 |
+
|
551 |
+
def downloading_controlnet_cpds():
|
552 |
+
load_file_from_url(
|
553 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors',
|
554 |
+
model_dir=path_controlnet,
|
555 |
+
file_name='fooocus_xl_cpds_128.safetensors'
|
556 |
+
)
|
557 |
+
return os.path.join(path_controlnet, 'fooocus_xl_cpds_128.safetensors')
|
558 |
+
|
559 |
+
|
560 |
+
def downloading_ip_adapters(v):
|
561 |
+
assert v in ['ip', 'face']
|
562 |
+
|
563 |
+
results = []
|
564 |
+
|
565 |
+
load_file_from_url(
|
566 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/clip_vision_vit_h.safetensors',
|
567 |
+
model_dir=path_clip_vision,
|
568 |
+
file_name='clip_vision_vit_h.safetensors'
|
569 |
+
)
|
570 |
+
results += [os.path.join(path_clip_vision, 'clip_vision_vit_h.safetensors')]
|
571 |
+
|
572 |
+
load_file_from_url(
|
573 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_ip_negative.safetensors',
|
574 |
+
model_dir=path_controlnet,
|
575 |
+
file_name='fooocus_ip_negative.safetensors'
|
576 |
+
)
|
577 |
+
results += [os.path.join(path_controlnet, 'fooocus_ip_negative.safetensors')]
|
578 |
+
|
579 |
+
if v == 'ip':
|
580 |
+
load_file_from_url(
|
581 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus_sdxl_vit-h.bin',
|
582 |
+
model_dir=path_controlnet,
|
583 |
+
file_name='ip-adapter-plus_sdxl_vit-h.bin'
|
584 |
+
)
|
585 |
+
results += [os.path.join(path_controlnet, 'ip-adapter-plus_sdxl_vit-h.bin')]
|
586 |
+
|
587 |
+
if v == 'face':
|
588 |
+
load_file_from_url(
|
589 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus-face_sdxl_vit-h.bin',
|
590 |
+
model_dir=path_controlnet,
|
591 |
+
file_name='ip-adapter-plus-face_sdxl_vit-h.bin'
|
592 |
+
)
|
593 |
+
results += [os.path.join(path_controlnet, 'ip-adapter-plus-face_sdxl_vit-h.bin')]
|
594 |
+
|
595 |
+
return results
|
596 |
+
|
597 |
+
|
598 |
+
def downloading_upscale_model():
|
599 |
+
load_file_from_url(
|
600 |
+
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin',
|
601 |
+
model_dir=path_upscale_models,
|
602 |
+
file_name='fooocus_upscaler_s409985e5.bin'
|
603 |
+
)
|
604 |
+
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
605 |
+
|
606 |
+
|
607 |
+
update_all_model_names()
|
modules/constants.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# as in k-diffusion (sampling.py)
|
2 |
+
MIN_SEED = 0
|
3 |
+
MAX_SEED = 2**63 - 1
|
4 |
+
|
5 |
+
AUTH_FILENAME = 'auth.json'
|
modules/core.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import einops
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import ldm_patched.modules.model_management
|
7 |
+
import ldm_patched.modules.model_detection
|
8 |
+
import ldm_patched.modules.model_patcher
|
9 |
+
import ldm_patched.modules.utils
|
10 |
+
import ldm_patched.modules.controlnet
|
11 |
+
import modules.sample_hijack
|
12 |
+
import ldm_patched.modules.samplers
|
13 |
+
import ldm_patched.modules.latent_formats
|
14 |
+
|
15 |
+
from ldm_patched.modules.sd import load_checkpoint_guess_config
|
16 |
+
from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, \
|
17 |
+
ControlNetApplyAdvanced
|
18 |
+
from ldm_patched.contrib.external_freelunch import FreeU_V2
|
19 |
+
from ldm_patched.modules.sample import prepare_mask
|
20 |
+
from modules.lora import match_lora
|
21 |
+
from modules.util import get_file_from_folder_list
|
22 |
+
from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip
|
23 |
+
from modules.config import path_embeddings
|
24 |
+
from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete
|
25 |
+
|
26 |
+
|
27 |
+
opEmptyLatentImage = EmptyLatentImage()
|
28 |
+
opVAEDecode = VAEDecode()
|
29 |
+
opVAEEncode = VAEEncode()
|
30 |
+
opVAEDecodeTiled = VAEDecodeTiled()
|
31 |
+
opVAEEncodeTiled = VAEEncodeTiled()
|
32 |
+
opControlNetApplyAdvanced = ControlNetApplyAdvanced()
|
33 |
+
opFreeU = FreeU_V2()
|
34 |
+
opModelSamplingDiscrete = ModelSamplingDiscrete()
|
35 |
+
|
36 |
+
|
37 |
+
class StableDiffusionModel:
|
38 |
+
def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
|
39 |
+
self.unet = unet
|
40 |
+
self.vae = vae
|
41 |
+
self.clip = clip
|
42 |
+
self.clip_vision = clip_vision
|
43 |
+
self.filename = filename
|
44 |
+
self.unet_with_lora = unet
|
45 |
+
self.clip_with_lora = clip
|
46 |
+
self.visited_loras = ''
|
47 |
+
|
48 |
+
self.lora_key_map_unet = {}
|
49 |
+
self.lora_key_map_clip = {}
|
50 |
+
|
51 |
+
if self.unet is not None:
|
52 |
+
self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet)
|
53 |
+
self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()})
|
54 |
+
|
55 |
+
if self.clip is not None:
|
56 |
+
self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip)
|
57 |
+
self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
@torch.inference_mode()
|
61 |
+
def refresh_loras(self, loras):
|
62 |
+
assert isinstance(loras, list)
|
63 |
+
|
64 |
+
if self.visited_loras == str(loras):
|
65 |
+
return
|
66 |
+
|
67 |
+
self.visited_loras = str(loras)
|
68 |
+
|
69 |
+
if self.unet is None:
|
70 |
+
return
|
71 |
+
|
72 |
+
print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
|
73 |
+
|
74 |
+
loras_to_load = []
|
75 |
+
|
76 |
+
for name, weight in loras:
|
77 |
+
if name == 'None':
|
78 |
+
continue
|
79 |
+
|
80 |
+
if os.path.exists(name):
|
81 |
+
lora_filename = name
|
82 |
+
else:
|
83 |
+
lora_filename = get_file_from_folder_list(name, modules.config.paths_loras)
|
84 |
+
|
85 |
+
if not os.path.exists(lora_filename):
|
86 |
+
print(f'Lora file not found: {lora_filename}')
|
87 |
+
continue
|
88 |
+
|
89 |
+
loras_to_load.append((lora_filename, weight))
|
90 |
+
|
91 |
+
self.unet_with_lora = self.unet.clone() if self.unet is not None else None
|
92 |
+
self.clip_with_lora = self.clip.clone() if self.clip is not None else None
|
93 |
+
|
94 |
+
for lora_filename, weight in loras_to_load:
|
95 |
+
lora_unmatch = ldm_patched.modules.utils.load_torch_file(lora_filename, safe_load=False)
|
96 |
+
lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet)
|
97 |
+
lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip)
|
98 |
+
|
99 |
+
if len(lora_unmatch) > 12:
|
100 |
+
# model mismatch
|
101 |
+
continue
|
102 |
+
|
103 |
+
if len(lora_unmatch) > 0:
|
104 |
+
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] '
|
105 |
+
f'with unmatched keys {list(lora_unmatch.keys())}')
|
106 |
+
|
107 |
+
if self.unet_with_lora is not None and len(lora_unet) > 0:
|
108 |
+
loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight)
|
109 |
+
print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] '
|
110 |
+
f'with {len(loaded_keys)} keys at weight {weight}.')
|
111 |
+
for item in lora_unet:
|
112 |
+
if item not in loaded_keys:
|
113 |
+
print("UNet LoRA key skipped: ", item)
|
114 |
+
|
115 |
+
if self.clip_with_lora is not None and len(lora_clip) > 0:
|
116 |
+
loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight)
|
117 |
+
print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] '
|
118 |
+
f'with {len(loaded_keys)} keys at weight {weight}.')
|
119 |
+
for item in lora_clip:
|
120 |
+
if item not in loaded_keys:
|
121 |
+
print("CLIP LoRA key skipped: ", item)
|
122 |
+
|
123 |
+
|
124 |
+
@torch.no_grad()
|
125 |
+
@torch.inference_mode()
|
126 |
+
def apply_freeu(model, b1, b2, s1, s2):
|
127 |
+
return opFreeU.patch(model=model, b1=b1, b2=b2, s1=s1, s2=s2)[0]
|
128 |
+
|
129 |
+
|
130 |
+
@torch.no_grad()
|
131 |
+
@torch.inference_mode()
|
132 |
+
def load_controlnet(ckpt_filename):
|
133 |
+
return ldm_patched.modules.controlnet.load_controlnet(ckpt_filename)
|
134 |
+
|
135 |
+
|
136 |
+
@torch.no_grad()
|
137 |
+
@torch.inference_mode()
|
138 |
+
def apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent):
|
139 |
+
return opControlNetApplyAdvanced.apply_controlnet(positive=positive, negative=negative, control_net=control_net,
|
140 |
+
image=image, strength=strength, start_percent=start_percent, end_percent=end_percent)
|
141 |
+
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
@torch.inference_mode()
|
145 |
+
def load_model(ckpt_filename):
|
146 |
+
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
|
147 |
+
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
|
148 |
+
|
149 |
+
|
150 |
+
@torch.no_grad()
|
151 |
+
@torch.inference_mode()
|
152 |
+
def generate_empty_latent(width=1024, height=1024, batch_size=1):
|
153 |
+
return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0]
|
154 |
+
|
155 |
+
|
156 |
+
@torch.no_grad()
|
157 |
+
@torch.inference_mode()
|
158 |
+
def decode_vae(vae, latent_image, tiled=False):
|
159 |
+
if tiled:
|
160 |
+
return opVAEDecodeTiled.decode(samples=latent_image, vae=vae, tile_size=512)[0]
|
161 |
+
else:
|
162 |
+
return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
|
163 |
+
|
164 |
+
|
165 |
+
@torch.no_grad()
|
166 |
+
@torch.inference_mode()
|
167 |
+
def encode_vae(vae, pixels, tiled=False):
|
168 |
+
if tiled:
|
169 |
+
return opVAEEncodeTiled.encode(pixels=pixels, vae=vae, tile_size=512)[0]
|
170 |
+
else:
|
171 |
+
return opVAEEncode.encode(pixels=pixels, vae=vae)[0]
|
172 |
+
|
173 |
+
|
174 |
+
@torch.no_grad()
|
175 |
+
@torch.inference_mode()
|
176 |
+
def encode_vae_inpaint(vae, pixels, mask):
|
177 |
+
assert mask.ndim == 3 and pixels.ndim == 4
|
178 |
+
assert mask.shape[-1] == pixels.shape[-2]
|
179 |
+
assert mask.shape[-2] == pixels.shape[-3]
|
180 |
+
|
181 |
+
w = mask.round()[..., None]
|
182 |
+
pixels = pixels * (1 - w) + 0.5 * w
|
183 |
+
|
184 |
+
latent = vae.encode(pixels)
|
185 |
+
B, C, H, W = latent.shape
|
186 |
+
|
187 |
+
latent_mask = mask[:, None, :, :]
|
188 |
+
latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
|
189 |
+
latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent)
|
190 |
+
|
191 |
+
return latent, latent_mask
|
192 |
+
|
193 |
+
|
194 |
+
class VAEApprox(torch.nn.Module):
|
195 |
+
def __init__(self):
|
196 |
+
super(VAEApprox, self).__init__()
|
197 |
+
self.conv1 = torch.nn.Conv2d(4, 8, (7, 7))
|
198 |
+
self.conv2 = torch.nn.Conv2d(8, 16, (5, 5))
|
199 |
+
self.conv3 = torch.nn.Conv2d(16, 32, (3, 3))
|
200 |
+
self.conv4 = torch.nn.Conv2d(32, 64, (3, 3))
|
201 |
+
self.conv5 = torch.nn.Conv2d(64, 32, (3, 3))
|
202 |
+
self.conv6 = torch.nn.Conv2d(32, 16, (3, 3))
|
203 |
+
self.conv7 = torch.nn.Conv2d(16, 8, (3, 3))
|
204 |
+
self.conv8 = torch.nn.Conv2d(8, 3, (3, 3))
|
205 |
+
self.current_type = None
|
206 |
+
|
207 |
+
def forward(self, x):
|
208 |
+
extra = 11
|
209 |
+
x = torch.nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
|
210 |
+
x = torch.nn.functional.pad(x, (extra, extra, extra, extra))
|
211 |
+
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]:
|
212 |
+
x = layer(x)
|
213 |
+
x = torch.nn.functional.leaky_relu(x, 0.1)
|
214 |
+
return x
|
215 |
+
|
216 |
+
|
217 |
+
VAE_approx_models = {}
|
218 |
+
|
219 |
+
|
220 |
+
@torch.no_grad()
|
221 |
+
@torch.inference_mode()
|
222 |
+
def get_previewer(model):
|
223 |
+
global VAE_approx_models
|
224 |
+
|
225 |
+
from modules.config import path_vae_approx
|
226 |
+
is_sdxl = isinstance(model.model.latent_format, ldm_patched.modules.latent_formats.SDXL)
|
227 |
+
vae_approx_filename = os.path.join(path_vae_approx, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')
|
228 |
+
|
229 |
+
if vae_approx_filename in VAE_approx_models:
|
230 |
+
VAE_approx_model = VAE_approx_models[vae_approx_filename]
|
231 |
+
else:
|
232 |
+
sd = torch.load(vae_approx_filename, map_location='cpu')
|
233 |
+
VAE_approx_model = VAEApprox()
|
234 |
+
VAE_approx_model.load_state_dict(sd)
|
235 |
+
del sd
|
236 |
+
VAE_approx_model.eval()
|
237 |
+
|
238 |
+
if ldm_patched.modules.model_management.should_use_fp16():
|
239 |
+
VAE_approx_model.half()
|
240 |
+
VAE_approx_model.current_type = torch.float16
|
241 |
+
else:
|
242 |
+
VAE_approx_model.float()
|
243 |
+
VAE_approx_model.current_type = torch.float32
|
244 |
+
|
245 |
+
VAE_approx_model.to(ldm_patched.modules.model_management.get_torch_device())
|
246 |
+
VAE_approx_models[vae_approx_filename] = VAE_approx_model
|
247 |
+
|
248 |
+
@torch.no_grad()
|
249 |
+
@torch.inference_mode()
|
250 |
+
def preview_function(x0, step, total_steps):
|
251 |
+
with torch.no_grad():
|
252 |
+
x_sample = x0.to(VAE_approx_model.current_type)
|
253 |
+
x_sample = VAE_approx_model(x_sample) * 127.5 + 127.5
|
254 |
+
x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')[0]
|
255 |
+
x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
|
256 |
+
return x_sample
|
257 |
+
|
258 |
+
return preview_function
|
259 |
+
|
260 |
+
|
261 |
+
@torch.no_grad()
|
262 |
+
@torch.inference_mode()
|
263 |
+
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
|
264 |
+
scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
|
265 |
+
force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
|
266 |
+
previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None, disable_preview=False):
|
267 |
+
|
268 |
+
if sigmas is not None:
|
269 |
+
sigmas = sigmas.clone().to(ldm_patched.modules.model_management.get_torch_device())
|
270 |
+
|
271 |
+
latent_image = latent["samples"]
|
272 |
+
|
273 |
+
if disable_noise:
|
274 |
+
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
275 |
+
else:
|
276 |
+
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
277 |
+
noise = ldm_patched.modules.sample.prepare_noise(latent_image, seed, batch_inds)
|
278 |
+
|
279 |
+
if isinstance(noise_mean, torch.Tensor):
|
280 |
+
noise = noise + noise_mean - torch.mean(noise, dim=1, keepdim=True)
|
281 |
+
|
282 |
+
noise_mask = None
|
283 |
+
if "noise_mask" in latent:
|
284 |
+
noise_mask = latent["noise_mask"]
|
285 |
+
|
286 |
+
previewer = get_previewer(model)
|
287 |
+
|
288 |
+
if previewer_start is None:
|
289 |
+
previewer_start = 0
|
290 |
+
|
291 |
+
if previewer_end is None:
|
292 |
+
previewer_end = steps
|
293 |
+
|
294 |
+
def callback(step, x0, x, total_steps):
|
295 |
+
ldm_patched.modules.model_management.throw_exception_if_processing_interrupted()
|
296 |
+
y = None
|
297 |
+
if previewer is not None and not disable_preview:
|
298 |
+
y = previewer(x0, previewer_start + step, previewer_end)
|
299 |
+
if callback_function is not None:
|
300 |
+
callback_function(previewer_start + step, x0, x, previewer_end, y)
|
301 |
+
|
302 |
+
disable_pbar = False
|
303 |
+
modules.sample_hijack.current_refiner = refiner
|
304 |
+
modules.sample_hijack.refiner_switch_step = refiner_switch
|
305 |
+
ldm_patched.modules.samplers.sample = modules.sample_hijack.sample_hacked
|
306 |
+
|
307 |
+
try:
|
308 |
+
samples = ldm_patched.modules.sample.sample(model,
|
309 |
+
noise, steps, cfg, sampler_name, scheduler,
|
310 |
+
positive, negative, latent_image,
|
311 |
+
denoise=denoise, disable_noise=disable_noise,
|
312 |
+
start_step=start_step,
|
313 |
+
last_step=last_step,
|
314 |
+
force_full_denoise=force_full_denoise, noise_mask=noise_mask,
|
315 |
+
callback=callback,
|
316 |
+
disable_pbar=disable_pbar, seed=seed, sigmas=sigmas)
|
317 |
+
|
318 |
+
out = latent.copy()
|
319 |
+
out["samples"] = samples
|
320 |
+
finally:
|
321 |
+
modules.sample_hijack.current_refiner = None
|
322 |
+
|
323 |
+
return out
|
324 |
+
|
325 |
+
|
326 |
+
@torch.no_grad()
|
327 |
+
@torch.inference_mode()
|
328 |
+
def pytorch_to_numpy(x):
|
329 |
+
return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
|
330 |
+
|
331 |
+
|
332 |
+
@torch.no_grad()
|
333 |
+
@torch.inference_mode()
|
334 |
+
def numpy_to_pytorch(x):
|
335 |
+
y = x.astype(np.float32) / 255.0
|
336 |
+
y = y[None]
|
337 |
+
y = np.ascontiguousarray(y.copy())
|
338 |
+
y = torch.from_numpy(y).float()
|
339 |
+
return y
|
modules/default_pipeline.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import modules.core as core
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import modules.patch
|
5 |
+
import modules.config
|
6 |
+
import ldm_patched.modules.model_management
|
7 |
+
import ldm_patched.modules.latent_formats
|
8 |
+
import modules.inpaint_worker
|
9 |
+
import extras.vae_interpose as vae_interpose
|
10 |
+
from extras.expansion import FooocusExpansion
|
11 |
+
|
12 |
+
from ldm_patched.modules.model_base import SDXL, SDXLRefiner
|
13 |
+
from modules.sample_hijack import clip_separate
|
14 |
+
from modules.util import get_file_from_folder_list
|
15 |
+
|
16 |
+
|
17 |
+
model_base = core.StableDiffusionModel()
|
18 |
+
model_refiner = core.StableDiffusionModel()
|
19 |
+
|
20 |
+
final_expansion = None
|
21 |
+
final_unet = None
|
22 |
+
final_clip = None
|
23 |
+
final_vae = None
|
24 |
+
final_refiner_unet = None
|
25 |
+
final_refiner_vae = None
|
26 |
+
|
27 |
+
loaded_ControlNets = {}
|
28 |
+
|
29 |
+
|
30 |
+
@torch.no_grad()
|
31 |
+
@torch.inference_mode()
|
32 |
+
def refresh_controlnets(model_paths):
|
33 |
+
global loaded_ControlNets
|
34 |
+
cache = {}
|
35 |
+
for p in model_paths:
|
36 |
+
if p is not None:
|
37 |
+
if p in loaded_ControlNets:
|
38 |
+
cache[p] = loaded_ControlNets[p]
|
39 |
+
else:
|
40 |
+
cache[p] = core.load_controlnet(p)
|
41 |
+
loaded_ControlNets = cache
|
42 |
+
return
|
43 |
+
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
@torch.inference_mode()
|
47 |
+
def assert_model_integrity():
|
48 |
+
error_message = None
|
49 |
+
|
50 |
+
if not isinstance(model_base.unet_with_lora.model, SDXL):
|
51 |
+
error_message = 'You have selected base model other than SDXL. This is not supported yet.'
|
52 |
+
|
53 |
+
if error_message is not None:
|
54 |
+
raise NotImplementedError(error_message)
|
55 |
+
|
56 |
+
return True
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
@torch.inference_mode()
|
61 |
+
def refresh_base_model(name):
|
62 |
+
global model_base
|
63 |
+
|
64 |
+
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
|
65 |
+
|
66 |
+
if model_base.filename == filename:
|
67 |
+
return
|
68 |
+
|
69 |
+
model_base = core.StableDiffusionModel()
|
70 |
+
model_base = core.load_model(filename)
|
71 |
+
print(f'Base model loaded: {model_base.filename}')
|
72 |
+
return
|
73 |
+
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
@torch.inference_mode()
|
77 |
+
def refresh_refiner_model(name):
|
78 |
+
global model_refiner
|
79 |
+
|
80 |
+
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
|
81 |
+
|
82 |
+
if model_refiner.filename == filename:
|
83 |
+
return
|
84 |
+
|
85 |
+
model_refiner = core.StableDiffusionModel()
|
86 |
+
|
87 |
+
if name == 'None':
|
88 |
+
print(f'Refiner unloaded.')
|
89 |
+
return
|
90 |
+
|
91 |
+
model_refiner = core.load_model(filename)
|
92 |
+
print(f'Refiner model loaded: {model_refiner.filename}')
|
93 |
+
|
94 |
+
if isinstance(model_refiner.unet.model, SDXL):
|
95 |
+
model_refiner.clip = None
|
96 |
+
model_refiner.vae = None
|
97 |
+
elif isinstance(model_refiner.unet.model, SDXLRefiner):
|
98 |
+
model_refiner.clip = None
|
99 |
+
model_refiner.vae = None
|
100 |
+
else:
|
101 |
+
model_refiner.clip = None
|
102 |
+
|
103 |
+
return
|
104 |
+
|
105 |
+
|
106 |
+
@torch.no_grad()
|
107 |
+
@torch.inference_mode()
|
108 |
+
def synthesize_refiner_model():
|
109 |
+
global model_base, model_refiner
|
110 |
+
|
111 |
+
print('Synthetic Refiner Activated')
|
112 |
+
model_refiner = core.StableDiffusionModel(
|
113 |
+
unet=model_base.unet,
|
114 |
+
vae=model_base.vae,
|
115 |
+
clip=model_base.clip,
|
116 |
+
clip_vision=model_base.clip_vision,
|
117 |
+
filename=model_base.filename
|
118 |
+
)
|
119 |
+
model_refiner.vae = None
|
120 |
+
model_refiner.clip = None
|
121 |
+
model_refiner.clip_vision = None
|
122 |
+
|
123 |
+
return
|
124 |
+
|
125 |
+
|
126 |
+
@torch.no_grad()
|
127 |
+
@torch.inference_mode()
|
128 |
+
def refresh_loras(loras, base_model_additional_loras=None):
|
129 |
+
global model_base, model_refiner
|
130 |
+
|
131 |
+
if not isinstance(base_model_additional_loras, list):
|
132 |
+
base_model_additional_loras = []
|
133 |
+
|
134 |
+
model_base.refresh_loras(loras + base_model_additional_loras)
|
135 |
+
model_refiner.refresh_loras(loras)
|
136 |
+
|
137 |
+
return
|
138 |
+
|
139 |
+
|
140 |
+
@torch.no_grad()
|
141 |
+
@torch.inference_mode()
|
142 |
+
def clip_encode_single(clip, text, verbose=False):
|
143 |
+
cached = clip.fcs_cond_cache.get(text, None)
|
144 |
+
if cached is not None:
|
145 |
+
if verbose:
|
146 |
+
print(f'[CLIP Cached] {text}')
|
147 |
+
return cached
|
148 |
+
tokens = clip.tokenize(text)
|
149 |
+
result = clip.encode_from_tokens(tokens, return_pooled=True)
|
150 |
+
clip.fcs_cond_cache[text] = result
|
151 |
+
if verbose:
|
152 |
+
print(f'[CLIP Encoded] {text}')
|
153 |
+
return result
|
154 |
+
|
155 |
+
|
156 |
+
@torch.no_grad()
|
157 |
+
@torch.inference_mode()
|
158 |
+
def clone_cond(conds):
|
159 |
+
results = []
|
160 |
+
|
161 |
+
for c, p in conds:
|
162 |
+
p = p["pooled_output"]
|
163 |
+
|
164 |
+
if isinstance(c, torch.Tensor):
|
165 |
+
c = c.clone()
|
166 |
+
|
167 |
+
if isinstance(p, torch.Tensor):
|
168 |
+
p = p.clone()
|
169 |
+
|
170 |
+
results.append([c, {"pooled_output": p}])
|
171 |
+
|
172 |
+
return results
|
173 |
+
|
174 |
+
|
175 |
+
@torch.no_grad()
|
176 |
+
@torch.inference_mode()
|
177 |
+
def clip_encode(texts, pool_top_k=1):
|
178 |
+
global final_clip
|
179 |
+
|
180 |
+
if final_clip is None:
|
181 |
+
return None
|
182 |
+
if not isinstance(texts, list):
|
183 |
+
return None
|
184 |
+
if len(texts) == 0:
|
185 |
+
return None
|
186 |
+
|
187 |
+
cond_list = []
|
188 |
+
pooled_acc = 0
|
189 |
+
|
190 |
+
for i, text in enumerate(texts):
|
191 |
+
cond, pooled = clip_encode_single(final_clip, text)
|
192 |
+
cond_list.append(cond)
|
193 |
+
if i < pool_top_k:
|
194 |
+
pooled_acc += pooled
|
195 |
+
|
196 |
+
return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]]
|
197 |
+
|
198 |
+
|
199 |
+
@torch.no_grad()
|
200 |
+
@torch.inference_mode()
|
201 |
+
def clear_all_caches():
|
202 |
+
final_clip.fcs_cond_cache = {}
|
203 |
+
|
204 |
+
|
205 |
+
@torch.no_grad()
|
206 |
+
@torch.inference_mode()
|
207 |
+
def prepare_text_encoder(async_call=True):
|
208 |
+
if async_call:
|
209 |
+
# TODO: make sure that this is always called in an async way so that users cannot feel it.
|
210 |
+
pass
|
211 |
+
assert_model_integrity()
|
212 |
+
ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher])
|
213 |
+
return
|
214 |
+
|
215 |
+
|
216 |
+
@torch.no_grad()
|
217 |
+
@torch.inference_mode()
|
218 |
+
def refresh_everything(refiner_model_name, base_model_name, loras,
|
219 |
+
base_model_additional_loras=None, use_synthetic_refiner=False):
|
220 |
+
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
|
221 |
+
|
222 |
+
final_unet = None
|
223 |
+
final_clip = None
|
224 |
+
final_vae = None
|
225 |
+
final_refiner_unet = None
|
226 |
+
final_refiner_vae = None
|
227 |
+
|
228 |
+
if use_synthetic_refiner and refiner_model_name == 'None':
|
229 |
+
print('Synthetic Refiner Activated')
|
230 |
+
refresh_base_model(base_model_name)
|
231 |
+
synthesize_refiner_model()
|
232 |
+
else:
|
233 |
+
refresh_refiner_model(refiner_model_name)
|
234 |
+
refresh_base_model(base_model_name)
|
235 |
+
|
236 |
+
refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
|
237 |
+
assert_model_integrity()
|
238 |
+
|
239 |
+
final_unet = model_base.unet_with_lora
|
240 |
+
final_clip = model_base.clip_with_lora
|
241 |
+
final_vae = model_base.vae
|
242 |
+
|
243 |
+
final_refiner_unet = model_refiner.unet_with_lora
|
244 |
+
final_refiner_vae = model_refiner.vae
|
245 |
+
|
246 |
+
if final_expansion is None:
|
247 |
+
final_expansion = FooocusExpansion()
|
248 |
+
|
249 |
+
prepare_text_encoder(async_call=True)
|
250 |
+
clear_all_caches()
|
251 |
+
return
|
252 |
+
|
253 |
+
|
254 |
+
refresh_everything(
|
255 |
+
refiner_model_name=modules.config.default_refiner_model_name,
|
256 |
+
base_model_name=modules.config.default_base_model_name,
|
257 |
+
loras=modules.config.default_loras
|
258 |
+
)
|
259 |
+
|
260 |
+
|
261 |
+
@torch.no_grad()
|
262 |
+
@torch.inference_mode()
|
263 |
+
def vae_parse(latent):
|
264 |
+
if final_refiner_vae is None:
|
265 |
+
return latent
|
266 |
+
|
267 |
+
result = vae_interpose.parse(latent["samples"])
|
268 |
+
return {'samples': result}
|
269 |
+
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
@torch.inference_mode()
|
273 |
+
def calculate_sigmas_all(sampler, model, scheduler, steps):
|
274 |
+
from ldm_patched.modules.samplers import calculate_sigmas_scheduler
|
275 |
+
|
276 |
+
discard_penultimate_sigma = False
|
277 |
+
if sampler in ['dpm_2', 'dpm_2_ancestral']:
|
278 |
+
steps += 1
|
279 |
+
discard_penultimate_sigma = True
|
280 |
+
|
281 |
+
sigmas = calculate_sigmas_scheduler(model, scheduler, steps)
|
282 |
+
|
283 |
+
if discard_penultimate_sigma:
|
284 |
+
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
285 |
+
return sigmas
|
286 |
+
|
287 |
+
|
288 |
+
@torch.no_grad()
|
289 |
+
@torch.inference_mode()
|
290 |
+
def calculate_sigmas(sampler, model, scheduler, steps, denoise):
|
291 |
+
if denoise is None or denoise > 0.9999:
|
292 |
+
sigmas = calculate_sigmas_all(sampler, model, scheduler, steps)
|
293 |
+
else:
|
294 |
+
new_steps = int(steps / denoise)
|
295 |
+
sigmas = calculate_sigmas_all(sampler, model, scheduler, new_steps)
|
296 |
+
sigmas = sigmas[-(steps + 1):]
|
297 |
+
return sigmas
|
298 |
+
|
299 |
+
|
300 |
+
@torch.no_grad()
|
301 |
+
@torch.inference_mode()
|
302 |
+
def get_candidate_vae(steps, switch, denoise=1.0, refiner_swap_method='joint'):
|
303 |
+
assert refiner_swap_method in ['joint', 'separate', 'vae']
|
304 |
+
|
305 |
+
if final_refiner_vae is not None and final_refiner_unet is not None:
|
306 |
+
if denoise > 0.9:
|
307 |
+
return final_vae, final_refiner_vae
|
308 |
+
else:
|
309 |
+
if denoise > (float(steps - switch) / float(steps)) ** 0.834: # karras 0.834
|
310 |
+
return final_vae, None
|
311 |
+
else:
|
312 |
+
return final_refiner_vae, None
|
313 |
+
|
314 |
+
return final_vae, final_refiner_vae
|
315 |
+
|
316 |
+
|
317 |
+
@torch.no_grad()
|
318 |
+
@torch.inference_mode()
|
319 |
+
def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0, refiner_swap_method='joint', disable_preview=False):
|
320 |
+
target_unet, target_vae, target_refiner_unet, target_refiner_vae, target_clip \
|
321 |
+
= final_unet, final_vae, final_refiner_unet, final_refiner_vae, final_clip
|
322 |
+
|
323 |
+
assert refiner_swap_method in ['joint', 'separate', 'vae']
|
324 |
+
|
325 |
+
if final_refiner_vae is not None and final_refiner_unet is not None:
|
326 |
+
# Refiner Use Different VAE (then it is SD15)
|
327 |
+
if denoise > 0.9:
|
328 |
+
refiner_swap_method = 'vae'
|
329 |
+
else:
|
330 |
+
refiner_swap_method = 'joint'
|
331 |
+
if denoise > (float(steps - switch) / float(steps)) ** 0.834: # karras 0.834
|
332 |
+
target_unet, target_vae, target_refiner_unet, target_refiner_vae \
|
333 |
+
= final_unet, final_vae, None, None
|
334 |
+
print(f'[Sampler] only use Base because of partial denoise.')
|
335 |
+
else:
|
336 |
+
positive_cond = clip_separate(positive_cond, target_model=final_refiner_unet.model, target_clip=final_clip)
|
337 |
+
negative_cond = clip_separate(negative_cond, target_model=final_refiner_unet.model, target_clip=final_clip)
|
338 |
+
target_unet, target_vae, target_refiner_unet, target_refiner_vae \
|
339 |
+
= final_refiner_unet, final_refiner_vae, None, None
|
340 |
+
print(f'[Sampler] only use Refiner because of partial denoise.')
|
341 |
+
|
342 |
+
print(f'[Sampler] refiner_swap_method = {refiner_swap_method}')
|
343 |
+
|
344 |
+
if latent is None:
|
345 |
+
initial_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
|
346 |
+
else:
|
347 |
+
initial_latent = latent
|
348 |
+
|
349 |
+
minmax_sigmas = calculate_sigmas(sampler=sampler_name, scheduler=scheduler_name, model=final_unet.model, steps=steps, denoise=denoise)
|
350 |
+
sigma_min, sigma_max = minmax_sigmas[minmax_sigmas > 0].min(), minmax_sigmas.max()
|
351 |
+
sigma_min = float(sigma_min.cpu().numpy())
|
352 |
+
sigma_max = float(sigma_max.cpu().numpy())
|
353 |
+
print(f'[Sampler] sigma_min = {sigma_min}, sigma_max = {sigma_max}')
|
354 |
+
|
355 |
+
modules.patch.BrownianTreeNoiseSamplerPatched.global_init(
|
356 |
+
initial_latent['samples'].to(ldm_patched.modules.model_management.get_torch_device()),
|
357 |
+
sigma_min, sigma_max, seed=image_seed, cpu=False)
|
358 |
+
|
359 |
+
decoded_latent = None
|
360 |
+
|
361 |
+
if refiner_swap_method == 'joint':
|
362 |
+
sampled_latent = core.ksampler(
|
363 |
+
model=target_unet,
|
364 |
+
refiner=target_refiner_unet,
|
365 |
+
positive=positive_cond,
|
366 |
+
negative=negative_cond,
|
367 |
+
latent=initial_latent,
|
368 |
+
steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
|
369 |
+
seed=image_seed,
|
370 |
+
denoise=denoise,
|
371 |
+
callback_function=callback,
|
372 |
+
cfg=cfg_scale,
|
373 |
+
sampler_name=sampler_name,
|
374 |
+
scheduler=scheduler_name,
|
375 |
+
refiner_switch=switch,
|
376 |
+
previewer_start=0,
|
377 |
+
previewer_end=steps,
|
378 |
+
disable_preview=disable_preview
|
379 |
+
)
|
380 |
+
decoded_latent = core.decode_vae(vae=target_vae, latent_image=sampled_latent, tiled=tiled)
|
381 |
+
|
382 |
+
if refiner_swap_method == 'separate':
|
383 |
+
sampled_latent = core.ksampler(
|
384 |
+
model=target_unet,
|
385 |
+
positive=positive_cond,
|
386 |
+
negative=negative_cond,
|
387 |
+
latent=initial_latent,
|
388 |
+
steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=False,
|
389 |
+
seed=image_seed,
|
390 |
+
denoise=denoise,
|
391 |
+
callback_function=callback,
|
392 |
+
cfg=cfg_scale,
|
393 |
+
sampler_name=sampler_name,
|
394 |
+
scheduler=scheduler_name,
|
395 |
+
previewer_start=0,
|
396 |
+
previewer_end=steps,
|
397 |
+
disable_preview=disable_preview
|
398 |
+
)
|
399 |
+
print('Refiner swapped by changing ksampler. Noise preserved.')
|
400 |
+
|
401 |
+
target_model = target_refiner_unet
|
402 |
+
if target_model is None:
|
403 |
+
target_model = target_unet
|
404 |
+
print('Use base model to refine itself - this may because of developer mode.')
|
405 |
+
|
406 |
+
sampled_latent = core.ksampler(
|
407 |
+
model=target_model,
|
408 |
+
positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=target_clip),
|
409 |
+
negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=target_clip),
|
410 |
+
latent=sampled_latent,
|
411 |
+
steps=steps, start_step=switch, last_step=steps, disable_noise=True, force_full_denoise=True,
|
412 |
+
seed=image_seed,
|
413 |
+
denoise=denoise,
|
414 |
+
callback_function=callback,
|
415 |
+
cfg=cfg_scale,
|
416 |
+
sampler_name=sampler_name,
|
417 |
+
scheduler=scheduler_name,
|
418 |
+
previewer_start=switch,
|
419 |
+
previewer_end=steps,
|
420 |
+
disable_preview=disable_preview
|
421 |
+
)
|
422 |
+
|
423 |
+
target_model = target_refiner_vae
|
424 |
+
if target_model is None:
|
425 |
+
target_model = target_vae
|
426 |
+
decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
|
427 |
+
|
428 |
+
if refiner_swap_method == 'vae':
|
429 |
+
modules.patch.patch_settings[os.getpid()].eps_record = 'vae'
|
430 |
+
|
431 |
+
if modules.inpaint_worker.current_task is not None:
|
432 |
+
modules.inpaint_worker.current_task.unswap()
|
433 |
+
|
434 |
+
sampled_latent = core.ksampler(
|
435 |
+
model=target_unet,
|
436 |
+
positive=positive_cond,
|
437 |
+
negative=negative_cond,
|
438 |
+
latent=initial_latent,
|
439 |
+
steps=steps, start_step=0, last_step=switch, disable_noise=False, force_full_denoise=True,
|
440 |
+
seed=image_seed,
|
441 |
+
denoise=denoise,
|
442 |
+
callback_function=callback,
|
443 |
+
cfg=cfg_scale,
|
444 |
+
sampler_name=sampler_name,
|
445 |
+
scheduler=scheduler_name,
|
446 |
+
previewer_start=0,
|
447 |
+
previewer_end=steps,
|
448 |
+
disable_preview=disable_preview
|
449 |
+
)
|
450 |
+
print('Fooocus VAE-based swap.')
|
451 |
+
|
452 |
+
target_model = target_refiner_unet
|
453 |
+
if target_model is None:
|
454 |
+
target_model = target_unet
|
455 |
+
print('Use base model to refine itself - this may because of developer mode.')
|
456 |
+
|
457 |
+
sampled_latent = vae_parse(sampled_latent)
|
458 |
+
|
459 |
+
k_sigmas = 1.4
|
460 |
+
sigmas = calculate_sigmas(sampler=sampler_name,
|
461 |
+
scheduler=scheduler_name,
|
462 |
+
model=target_model.model,
|
463 |
+
steps=steps,
|
464 |
+
denoise=denoise)[switch:] * k_sigmas
|
465 |
+
len_sigmas = len(sigmas) - 1
|
466 |
+
|
467 |
+
noise_mean = torch.mean(modules.patch.patch_settings[os.getpid()].eps_record, dim=1, keepdim=True)
|
468 |
+
|
469 |
+
if modules.inpaint_worker.current_task is not None:
|
470 |
+
modules.inpaint_worker.current_task.swap()
|
471 |
+
|
472 |
+
sampled_latent = core.ksampler(
|
473 |
+
model=target_model,
|
474 |
+
positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=target_clip),
|
475 |
+
negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=target_clip),
|
476 |
+
latent=sampled_latent,
|
477 |
+
steps=len_sigmas, start_step=0, last_step=len_sigmas, disable_noise=False, force_full_denoise=True,
|
478 |
+
seed=image_seed+1,
|
479 |
+
denoise=denoise,
|
480 |
+
callback_function=callback,
|
481 |
+
cfg=cfg_scale,
|
482 |
+
sampler_name=sampler_name,
|
483 |
+
scheduler=scheduler_name,
|
484 |
+
previewer_start=switch,
|
485 |
+
previewer_end=steps,
|
486 |
+
sigmas=sigmas,
|
487 |
+
noise_mean=noise_mean,
|
488 |
+
disable_preview=disable_preview
|
489 |
+
)
|
490 |
+
|
491 |
+
target_model = target_refiner_vae
|
492 |
+
if target_model is None:
|
493 |
+
target_model = target_vae
|
494 |
+
decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
|
495 |
+
|
496 |
+
images = core.pytorch_to_numpy(decoded_latent)
|
497 |
+
modules.patch.patch_settings[os.getpid()].eps_record = None
|
498 |
+
return images
|
modules/flags.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import IntEnum, Enum
|
2 |
+
|
3 |
+
disabled = 'Disabled'
|
4 |
+
enabled = 'Enabled'
|
5 |
+
subtle_variation = 'Vary (Subtle)'
|
6 |
+
strong_variation = 'Vary (Strong)'
|
7 |
+
upscale_15 = 'Upscale (1.5x)'
|
8 |
+
upscale_2 = 'Upscale (2x)'
|
9 |
+
upscale_fast = 'Upscale (Fast 2x)'
|
10 |
+
|
11 |
+
uov_list = [
|
12 |
+
disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast
|
13 |
+
]
|
14 |
+
|
15 |
+
CIVITAI_NO_KARRAS = ["euler", "euler_ancestral", "heun", "dpm_fast", "dpm_adaptive", "ddim", "uni_pc"]
|
16 |
+
|
17 |
+
# fooocus: a1111 (Civitai)
|
18 |
+
KSAMPLER = {
|
19 |
+
"euler": "Euler",
|
20 |
+
"euler_ancestral": "Euler a",
|
21 |
+
"heun": "Heun",
|
22 |
+
"heunpp2": "",
|
23 |
+
"dpm_2": "DPM2",
|
24 |
+
"dpm_2_ancestral": "DPM2 a",
|
25 |
+
"lms": "LMS",
|
26 |
+
"dpm_fast": "DPM fast",
|
27 |
+
"dpm_adaptive": "DPM adaptive",
|
28 |
+
"dpmpp_2s_ancestral": "DPM++ 2S a",
|
29 |
+
"dpmpp_sde": "DPM++ SDE",
|
30 |
+
"dpmpp_sde_gpu": "DPM++ SDE",
|
31 |
+
"dpmpp_2m": "DPM++ 2M",
|
32 |
+
"dpmpp_2m_sde": "DPM++ 2M SDE",
|
33 |
+
"dpmpp_2m_sde_gpu": "DPM++ 2M SDE",
|
34 |
+
"dpmpp_3m_sde": "",
|
35 |
+
"dpmpp_3m_sde_gpu": "",
|
36 |
+
"ddpm": "",
|
37 |
+
"lcm": "LCM"
|
38 |
+
}
|
39 |
+
|
40 |
+
SAMPLER_EXTRA = {
|
41 |
+
"ddim": "DDIM",
|
42 |
+
"uni_pc": "UniPC",
|
43 |
+
"uni_pc_bh2": ""
|
44 |
+
}
|
45 |
+
|
46 |
+
SAMPLERS = KSAMPLER | SAMPLER_EXTRA
|
47 |
+
|
48 |
+
KSAMPLER_NAMES = list(KSAMPLER.keys())
|
49 |
+
|
50 |
+
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"]
|
51 |
+
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
|
52 |
+
|
53 |
+
sampler_list = SAMPLER_NAMES
|
54 |
+
scheduler_list = SCHEDULER_NAMES
|
55 |
+
|
56 |
+
refiner_swap_method = 'joint'
|
57 |
+
|
58 |
+
cn_ip = "ImagePrompt"
|
59 |
+
cn_ip_face = "FaceSwap"
|
60 |
+
cn_canny = "PyraCanny"
|
61 |
+
cn_cpds = "CPDS"
|
62 |
+
|
63 |
+
ip_list = [cn_ip, cn_canny, cn_cpds, cn_ip_face]
|
64 |
+
default_ip = cn_ip
|
65 |
+
|
66 |
+
default_parameters = {
|
67 |
+
cn_ip: (0.5, 0.6), cn_ip_face: (0.9, 0.75), cn_canny: (0.5, 1.0), cn_cpds: (0.5, 1.0)
|
68 |
+
} # stop, weight
|
69 |
+
|
70 |
+
output_formats = ['png', 'jpg', 'webp']
|
71 |
+
|
72 |
+
inpaint_engine_versions = ['None', 'v1', 'v2.5', 'v2.6']
|
73 |
+
inpaint_option_default = 'Inpaint or Outpaint (default)'
|
74 |
+
inpaint_option_detail = 'Improve Detail (face, hand, eyes, etc.)'
|
75 |
+
inpaint_option_modify = 'Modify Content (add objects, change background, etc.)'
|
76 |
+
inpaint_options = [inpaint_option_default, inpaint_option_detail, inpaint_option_modify]
|
77 |
+
|
78 |
+
desc_type_photo = 'Photograph'
|
79 |
+
desc_type_anime = 'Art/Anime'
|
80 |
+
|
81 |
+
|
82 |
+
class MetadataScheme(Enum):
|
83 |
+
FOOOCUS = 'fooocus'
|
84 |
+
A1111 = 'a1111'
|
85 |
+
|
86 |
+
|
87 |
+
metadata_scheme = [
|
88 |
+
(f'{MetadataScheme.FOOOCUS.value} (json)', MetadataScheme.FOOOCUS.value),
|
89 |
+
(f'{MetadataScheme.A1111.value} (plain text)', MetadataScheme.A1111.value),
|
90 |
+
]
|
91 |
+
|
92 |
+
lora_count = 5
|
93 |
+
|
94 |
+
controlnet_image_count = 4
|
95 |
+
|
96 |
+
|
97 |
+
class Steps(IntEnum):
|
98 |
+
QUALITY = 60
|
99 |
+
SPEED = 30
|
100 |
+
EXTREME_SPEED = 8
|
101 |
+
|
102 |
+
|
103 |
+
class StepsUOV(IntEnum):
|
104 |
+
QUALITY = 36
|
105 |
+
SPEED = 18
|
106 |
+
EXTREME_SPEED = 8
|
107 |
+
|
108 |
+
|
109 |
+
class Performance(Enum):
|
110 |
+
QUALITY = 'Quality'
|
111 |
+
SPEED = 'Speed'
|
112 |
+
EXTREME_SPEED = 'Extreme Speed'
|
113 |
+
|
114 |
+
@classmethod
|
115 |
+
def list(cls) -> list:
|
116 |
+
return list(map(lambda c: c.value, cls))
|
117 |
+
|
118 |
+
def steps(self) -> int | None:
|
119 |
+
return Steps[self.name].value if Steps[self.name] else None
|
120 |
+
|
121 |
+
def steps_uov(self) -> int | None:
|
122 |
+
return StepsUOV[self.name].value if Steps[self.name] else None
|
123 |
+
|
124 |
+
|
125 |
+
performance_selections = Performance.list()
|
modules/gradio_hijack.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""gr.Image() component."""
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import warnings
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Any, Literal
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import PIL
|
11 |
+
import PIL.ImageOps
|
12 |
+
import gradio.routes
|
13 |
+
import importlib
|
14 |
+
|
15 |
+
from gradio_client import utils as client_utils
|
16 |
+
from gradio_client.documentation import document, set_documentation_group
|
17 |
+
from gradio_client.serializing import ImgSerializable
|
18 |
+
from PIL import Image as _Image # using _ to minimize namespace pollution
|
19 |
+
|
20 |
+
from gradio import processing_utils, utils
|
21 |
+
from gradio.components.base import IOComponent, _Keywords, Block
|
22 |
+
from gradio.deprecation import warn_style_method_deprecation
|
23 |
+
from gradio.events import (
|
24 |
+
Changeable,
|
25 |
+
Clearable,
|
26 |
+
Editable,
|
27 |
+
EventListenerMethod,
|
28 |
+
Selectable,
|
29 |
+
Streamable,
|
30 |
+
Uploadable,
|
31 |
+
)
|
32 |
+
from gradio.interpretation import TokenInterpretable
|
33 |
+
|
34 |
+
set_documentation_group("component")
|
35 |
+
_Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843
|
36 |
+
|
37 |
+
|
38 |
+
@document()
|
39 |
+
class Image(
|
40 |
+
Editable,
|
41 |
+
Clearable,
|
42 |
+
Changeable,
|
43 |
+
Streamable,
|
44 |
+
Selectable,
|
45 |
+
Uploadable,
|
46 |
+
IOComponent,
|
47 |
+
ImgSerializable,
|
48 |
+
TokenInterpretable,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
Creates an image component that can be used to upload/draw images (as an input) or display images (as an output).
|
52 |
+
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch` AND source is one of `upload` or `webcam`. In these cases, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
|
53 |
+
Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image.
|
54 |
+
Examples-format: a {str} filepath to a local file that contains the image.
|
55 |
+
Demos: image_mod, image_mod_default_image
|
56 |
+
Guides: image-classification-in-pytorch, image-classification-in-tensorflow, image-classification-with-vision-transformers, building-a-pictionary_app, create-your-own-friends-with-a-gan
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
value: str | _Image.Image | np.ndarray | None = None,
|
62 |
+
*,
|
63 |
+
shape: tuple[int, int] | None = None,
|
64 |
+
height: int | None = None,
|
65 |
+
width: int | None = None,
|
66 |
+
image_mode: Literal[
|
67 |
+
"1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"
|
68 |
+
] = "RGB",
|
69 |
+
invert_colors: bool = False,
|
70 |
+
source: Literal["upload", "webcam", "canvas"] = "upload",
|
71 |
+
tool: Literal["editor", "select", "sketch", "color-sketch"] | None = None,
|
72 |
+
type: Literal["numpy", "pil", "filepath"] = "numpy",
|
73 |
+
label: str | None = None,
|
74 |
+
every: float | None = None,
|
75 |
+
show_label: bool | None = None,
|
76 |
+
show_download_button: bool = True,
|
77 |
+
container: bool = True,
|
78 |
+
scale: int | None = None,
|
79 |
+
min_width: int = 160,
|
80 |
+
interactive: bool | None = None,
|
81 |
+
visible: bool = True,
|
82 |
+
streaming: bool = False,
|
83 |
+
elem_id: str | None = None,
|
84 |
+
elem_classes: list[str] | str | None = None,
|
85 |
+
mirror_webcam: bool = True,
|
86 |
+
brush_radius: float | None = None,
|
87 |
+
brush_color: str = "#000000",
|
88 |
+
mask_opacity: float = 0.7,
|
89 |
+
show_share_button: bool | None = None,
|
90 |
+
**kwargs,
|
91 |
+
):
|
92 |
+
"""
|
93 |
+
Parameters:
|
94 |
+
value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component.
|
95 |
+
shape: (width, height) shape to crop and resize image when passed to function. If None, matches input image size. Pass None for either width or height to only crop and resize the other.
|
96 |
+
height: Height of the displayed image in pixels.
|
97 |
+
width: Width of the displayed image in pixels.
|
98 |
+
image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning.
|
99 |
+
invert_colors: whether to invert the image as a preprocessing step.
|
100 |
+
source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
|
101 |
+
tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to created a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
|
102 |
+
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image.
|
103 |
+
label: component name in interface.
|
104 |
+
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
|
105 |
+
show_label: if True, will display label.
|
106 |
+
show_download_button: If True, will display button to download image.
|
107 |
+
container: If True, will place the component in a container - providing some extra padding around the border.
|
108 |
+
scale: relative width compared to adjacent Components in a Row. For example, if Component A has scale=2, and Component B has scale=1, A will be twice as wide as B. Should be an integer.
|
109 |
+
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
|
110 |
+
interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images. If not provided, this is inferred based on whether the component is used as an input or output.
|
111 |
+
visible: If False, component will be hidden.
|
112 |
+
streaming: If True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'webcam'.
|
113 |
+
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
114 |
+
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
|
115 |
+
mirror_webcam: If True webcam will be mirrored. Default is True.
|
116 |
+
brush_radius: Size of the brush for Sketch. Default is None which chooses a sensible default
|
117 |
+
brush_color: Color of the brush for Sketch as hex string. Default is "#000000".
|
118 |
+
mask_opacity: Opacity of mask drawn on image, as a value between 0 and 1.
|
119 |
+
show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.
|
120 |
+
"""
|
121 |
+
self.brush_radius = brush_radius
|
122 |
+
self.brush_color = brush_color
|
123 |
+
self.mask_opacity = mask_opacity
|
124 |
+
self.mirror_webcam = mirror_webcam
|
125 |
+
valid_types = ["numpy", "pil", "filepath"]
|
126 |
+
if type not in valid_types:
|
127 |
+
raise ValueError(
|
128 |
+
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
|
129 |
+
)
|
130 |
+
self.type = type
|
131 |
+
self.shape = shape
|
132 |
+
self.height = height
|
133 |
+
self.width = width
|
134 |
+
self.image_mode = image_mode
|
135 |
+
valid_sources = ["upload", "webcam", "canvas"]
|
136 |
+
if source not in valid_sources:
|
137 |
+
raise ValueError(
|
138 |
+
f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}"
|
139 |
+
)
|
140 |
+
self.source = source
|
141 |
+
if tool is None:
|
142 |
+
self.tool = "sketch" if source == "canvas" else "editor"
|
143 |
+
else:
|
144 |
+
self.tool = tool
|
145 |
+
self.invert_colors = invert_colors
|
146 |
+
self.streaming = streaming
|
147 |
+
self.show_download_button = show_download_button
|
148 |
+
if streaming and source != "webcam":
|
149 |
+
raise ValueError("Image streaming only available if source is 'webcam'.")
|
150 |
+
self.select: EventListenerMethod
|
151 |
+
"""
|
152 |
+
Event listener for when the user clicks on a pixel within the image.
|
153 |
+
Uses event data gradio.SelectData to carry `index` to refer to the [x, y] coordinates of the clicked pixel.
|
154 |
+
See EventData documentation on how to use this event data.
|
155 |
+
"""
|
156 |
+
self.show_share_button = (
|
157 |
+
(utils.get_space() is not None)
|
158 |
+
if show_share_button is None
|
159 |
+
else show_share_button
|
160 |
+
)
|
161 |
+
IOComponent.__init__(
|
162 |
+
self,
|
163 |
+
label=label,
|
164 |
+
every=every,
|
165 |
+
show_label=show_label,
|
166 |
+
container=container,
|
167 |
+
scale=scale,
|
168 |
+
min_width=min_width,
|
169 |
+
interactive=interactive,
|
170 |
+
visible=visible,
|
171 |
+
elem_id=elem_id,
|
172 |
+
elem_classes=elem_classes,
|
173 |
+
value=value,
|
174 |
+
**kwargs,
|
175 |
+
)
|
176 |
+
TokenInterpretable.__init__(self)
|
177 |
+
|
178 |
+
def get_config(self):
|
179 |
+
return {
|
180 |
+
"image_mode": self.image_mode,
|
181 |
+
"shape": self.shape,
|
182 |
+
"height": self.height,
|
183 |
+
"width": self.width,
|
184 |
+
"source": self.source,
|
185 |
+
"tool": self.tool,
|
186 |
+
"value": self.value,
|
187 |
+
"streaming": self.streaming,
|
188 |
+
"mirror_webcam": self.mirror_webcam,
|
189 |
+
"brush_radius": self.brush_radius,
|
190 |
+
"brush_color": self.brush_color,
|
191 |
+
"mask_opacity": self.mask_opacity,
|
192 |
+
"selectable": self.selectable,
|
193 |
+
"show_share_button": self.show_share_button,
|
194 |
+
"show_download_button": self.show_download_button,
|
195 |
+
**IOComponent.get_config(self),
|
196 |
+
}
|
197 |
+
|
198 |
+
@staticmethod
|
199 |
+
def update(
|
200 |
+
value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
|
201 |
+
height: int | None = None,
|
202 |
+
width: int | None = None,
|
203 |
+
label: str | None = None,
|
204 |
+
show_label: bool | None = None,
|
205 |
+
show_download_button: bool | None = None,
|
206 |
+
container: bool | None = None,
|
207 |
+
scale: int | None = None,
|
208 |
+
min_width: int | None = None,
|
209 |
+
interactive: bool | None = None,
|
210 |
+
visible: bool | None = None,
|
211 |
+
brush_radius: float | None = None,
|
212 |
+
brush_color: str | None = None,
|
213 |
+
mask_opacity: float | None = None,
|
214 |
+
show_share_button: bool | None = None,
|
215 |
+
):
|
216 |
+
return {
|
217 |
+
"height": height,
|
218 |
+
"width": width,
|
219 |
+
"label": label,
|
220 |
+
"show_label": show_label,
|
221 |
+
"show_download_button": show_download_button,
|
222 |
+
"container": container,
|
223 |
+
"scale": scale,
|
224 |
+
"min_width": min_width,
|
225 |
+
"interactive": interactive,
|
226 |
+
"visible": visible,
|
227 |
+
"value": value,
|
228 |
+
"brush_radius": brush_radius,
|
229 |
+
"brush_color": brush_color,
|
230 |
+
"mask_opacity": mask_opacity,
|
231 |
+
"show_share_button": show_share_button,
|
232 |
+
"__type__": "update",
|
233 |
+
}
|
234 |
+
|
235 |
+
def _format_image(
|
236 |
+
self, im: _Image.Image | None
|
237 |
+
) -> np.ndarray | _Image.Image | str | None:
|
238 |
+
"""Helper method to format an image based on self.type"""
|
239 |
+
if im is None:
|
240 |
+
return im
|
241 |
+
fmt = im.format
|
242 |
+
if self.type == "pil":
|
243 |
+
return im
|
244 |
+
elif self.type == "numpy":
|
245 |
+
return np.array(im)
|
246 |
+
elif self.type == "filepath":
|
247 |
+
path = self.pil_to_temp_file(
|
248 |
+
im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png"
|
249 |
+
)
|
250 |
+
self.temp_files.add(path)
|
251 |
+
return path
|
252 |
+
else:
|
253 |
+
raise ValueError(
|
254 |
+
"Unknown type: "
|
255 |
+
+ str(self.type)
|
256 |
+
+ ". Please choose from: 'numpy', 'pil', 'filepath'."
|
257 |
+
)
|
258 |
+
|
259 |
+
def preprocess(
|
260 |
+
self, x: str | dict[str, str]
|
261 |
+
) -> np.ndarray | _Image.Image | str | dict | None:
|
262 |
+
"""
|
263 |
+
Parameters:
|
264 |
+
x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data
|
265 |
+
Returns:
|
266 |
+
image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
|
267 |
+
"""
|
268 |
+
if x is None:
|
269 |
+
return x
|
270 |
+
|
271 |
+
mask = None
|
272 |
+
|
273 |
+
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
|
274 |
+
if isinstance(x, dict):
|
275 |
+
x, mask = x["image"], x["mask"]
|
276 |
+
|
277 |
+
assert isinstance(x, str)
|
278 |
+
im = processing_utils.decode_base64_to_image(x)
|
279 |
+
with warnings.catch_warnings():
|
280 |
+
warnings.simplefilter("ignore")
|
281 |
+
im = im.convert(self.image_mode)
|
282 |
+
if self.shape is not None:
|
283 |
+
im = processing_utils.resize_and_crop(im, self.shape)
|
284 |
+
if self.invert_colors:
|
285 |
+
im = PIL.ImageOps.invert(im)
|
286 |
+
if (
|
287 |
+
self.source == "webcam"
|
288 |
+
and self.mirror_webcam is True
|
289 |
+
and self.tool != "color-sketch"
|
290 |
+
):
|
291 |
+
im = PIL.ImageOps.mirror(im)
|
292 |
+
|
293 |
+
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
|
294 |
+
if mask is not None:
|
295 |
+
mask_im = processing_utils.decode_base64_to_image(mask)
|
296 |
+
if mask_im.mode == "RGBA": # whiten any opaque pixels in the mask
|
297 |
+
alpha_data = mask_im.getchannel("A").convert("L")
|
298 |
+
mask_im = _Image.merge("RGB", [alpha_data, alpha_data, alpha_data])
|
299 |
+
return {
|
300 |
+
"image": self._format_image(im),
|
301 |
+
"mask": self._format_image(mask_im),
|
302 |
+
}
|
303 |
+
else:
|
304 |
+
return {
|
305 |
+
"image": self._format_image(im),
|
306 |
+
"mask": None,
|
307 |
+
}
|
308 |
+
|
309 |
+
return self._format_image(im)
|
310 |
+
|
311 |
+
def postprocess(
|
312 |
+
self, y: np.ndarray | _Image.Image | str | Path | None
|
313 |
+
) -> str | None:
|
314 |
+
"""
|
315 |
+
Parameters:
|
316 |
+
y: image as a numpy array, PIL Image, string/Path filepath, or string URL
|
317 |
+
Returns:
|
318 |
+
base64 url data
|
319 |
+
"""
|
320 |
+
if y is None:
|
321 |
+
return None
|
322 |
+
if isinstance(y, np.ndarray):
|
323 |
+
return processing_utils.encode_array_to_base64(y)
|
324 |
+
elif isinstance(y, _Image.Image):
|
325 |
+
return processing_utils.encode_pil_to_base64(y)
|
326 |
+
elif isinstance(y, (str, Path)):
|
327 |
+
return client_utils.encode_url_or_file_to_base64(y)
|
328 |
+
else:
|
329 |
+
raise ValueError("Cannot process this value as an Image")
|
330 |
+
|
331 |
+
def set_interpret_parameters(self, segments: int = 16):
|
332 |
+
"""
|
333 |
+
Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
|
334 |
+
Parameters:
|
335 |
+
segments: Number of interpretation segments to split image into.
|
336 |
+
"""
|
337 |
+
self.interpretation_segments = segments
|
338 |
+
return self
|
339 |
+
|
340 |
+
def _segment_by_slic(self, x):
|
341 |
+
"""
|
342 |
+
Helper method that segments an image into superpixels using slic.
|
343 |
+
Parameters:
|
344 |
+
x: base64 representation of an image
|
345 |
+
"""
|
346 |
+
x = processing_utils.decode_base64_to_image(x)
|
347 |
+
if self.shape is not None:
|
348 |
+
x = processing_utils.resize_and_crop(x, self.shape)
|
349 |
+
resized_and_cropped_image = np.array(x)
|
350 |
+
try:
|
351 |
+
from skimage.segmentation import slic
|
352 |
+
except (ImportError, ModuleNotFoundError) as err:
|
353 |
+
raise ValueError(
|
354 |
+
"Error: running this interpretation for images requires scikit-image, please install it first."
|
355 |
+
) from err
|
356 |
+
try:
|
357 |
+
segments_slic = slic(
|
358 |
+
resized_and_cropped_image,
|
359 |
+
self.interpretation_segments,
|
360 |
+
compactness=10,
|
361 |
+
sigma=1,
|
362 |
+
start_label=1,
|
363 |
+
)
|
364 |
+
except TypeError: # For skimage 0.16 and older
|
365 |
+
segments_slic = slic(
|
366 |
+
resized_and_cropped_image,
|
367 |
+
self.interpretation_segments,
|
368 |
+
compactness=10,
|
369 |
+
sigma=1,
|
370 |
+
)
|
371 |
+
return segments_slic, resized_and_cropped_image
|
372 |
+
|
373 |
+
def tokenize(self, x):
|
374 |
+
"""
|
375 |
+
Segments image into tokens, masks, and leave-one-out-tokens
|
376 |
+
Parameters:
|
377 |
+
x: base64 representation of an image
|
378 |
+
Returns:
|
379 |
+
tokens: list of tokens, used by the get_masked_input() method
|
380 |
+
leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
|
381 |
+
masks: list of masks, used by the get_interpretation_neighbors() method
|
382 |
+
"""
|
383 |
+
segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
|
384 |
+
tokens, masks, leave_one_out_tokens = [], [], []
|
385 |
+
replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
|
386 |
+
for segment_value in np.unique(segments_slic):
|
387 |
+
mask = segments_slic == segment_value
|
388 |
+
image_screen = np.copy(resized_and_cropped_image)
|
389 |
+
image_screen[segments_slic == segment_value] = replace_color
|
390 |
+
leave_one_out_tokens.append(
|
391 |
+
processing_utils.encode_array_to_base64(image_screen)
|
392 |
+
)
|
393 |
+
token = np.copy(resized_and_cropped_image)
|
394 |
+
token[segments_slic != segment_value] = 0
|
395 |
+
tokens.append(token)
|
396 |
+
masks.append(mask)
|
397 |
+
return tokens, leave_one_out_tokens, masks
|
398 |
+
|
399 |
+
def get_masked_inputs(self, tokens, binary_mask_matrix):
|
400 |
+
masked_inputs = []
|
401 |
+
for binary_mask_vector in binary_mask_matrix:
|
402 |
+
masked_input = np.zeros_like(tokens[0], dtype=int)
|
403 |
+
for token, b in zip(tokens, binary_mask_vector):
|
404 |
+
masked_input = masked_input + token * int(b)
|
405 |
+
masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
|
406 |
+
return masked_inputs
|
407 |
+
|
408 |
+
def get_interpretation_scores(
|
409 |
+
self, x, neighbors, scores, masks, tokens=None, **kwargs
|
410 |
+
) -> list[list[float]]:
|
411 |
+
"""
|
412 |
+
Returns:
|
413 |
+
A 2D array representing the interpretation score of each pixel of the image.
|
414 |
+
"""
|
415 |
+
x = processing_utils.decode_base64_to_image(x)
|
416 |
+
if self.shape is not None:
|
417 |
+
x = processing_utils.resize_and_crop(x, self.shape)
|
418 |
+
x = np.array(x)
|
419 |
+
output_scores = np.zeros((x.shape[0], x.shape[1]))
|
420 |
+
|
421 |
+
for score, mask in zip(scores, masks):
|
422 |
+
output_scores += score * mask
|
423 |
+
|
424 |
+
max_val, min_val = np.max(output_scores), np.min(output_scores)
|
425 |
+
if max_val > 0:
|
426 |
+
output_scores = (output_scores - min_val) / (max_val - min_val)
|
427 |
+
return output_scores.tolist()
|
428 |
+
|
429 |
+
def style(self, *, height: int | None = None, width: int | None = None, **kwargs):
|
430 |
+
"""
|
431 |
+
This method is deprecated. Please set these arguments in the constructor instead.
|
432 |
+
"""
|
433 |
+
warn_style_method_deprecation()
|
434 |
+
if height is not None:
|
435 |
+
self.height = height
|
436 |
+
if width is not None:
|
437 |
+
self.width = width
|
438 |
+
return self
|
439 |
+
|
440 |
+
def check_streamable(self):
|
441 |
+
if self.source != "webcam":
|
442 |
+
raise ValueError("Image streaming only available if source is 'webcam'.")
|
443 |
+
|
444 |
+
def as_example(self, input_data: str | None) -> str:
|
445 |
+
if input_data is None:
|
446 |
+
return ""
|
447 |
+
elif (
|
448 |
+
self.root_url
|
449 |
+
): # If an externally hosted image, don't convert to absolute path
|
450 |
+
return input_data
|
451 |
+
return str(utils.abspath(input_data))
|
452 |
+
|
453 |
+
|
454 |
+
all_components = []
|
455 |
+
|
456 |
+
if not hasattr(Block, 'original__init__'):
|
457 |
+
Block.original_init = Block.__init__
|
458 |
+
|
459 |
+
|
460 |
+
def blk_ini(self, *args, **kwargs):
|
461 |
+
all_components.append(self)
|
462 |
+
return Block.original_init(self, *args, **kwargs)
|
463 |
+
|
464 |
+
|
465 |
+
Block.__init__ = blk_ini
|
466 |
+
|
467 |
+
|
468 |
+
gradio.routes.asyncio = importlib.reload(gradio.routes.asyncio)
|
469 |
+
|
470 |
+
if not hasattr(gradio.routes.asyncio, 'original_wait_for'):
|
471 |
+
gradio.routes.asyncio.original_wait_for = gradio.routes.asyncio.wait_for
|
472 |
+
|
473 |
+
|
474 |
+
def patched_wait_for(fut, timeout):
|
475 |
+
del timeout
|
476 |
+
return gradio.routes.asyncio.original_wait_for(fut, timeout=65535)
|
477 |
+
|
478 |
+
|
479 |
+
gradio.routes.asyncio.wait_for = patched_wait_for
|
480 |
+
|
modules/html.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
css = '''
|
2 |
+
.loader-container {
|
3 |
+
display: flex; /* Use flex to align items horizontally */
|
4 |
+
align-items: center; /* Center items vertically within the container */
|
5 |
+
white-space: nowrap; /* Prevent line breaks within the container */
|
6 |
+
}
|
7 |
+
|
8 |
+
.loader {
|
9 |
+
border: 8px solid #f3f3f3; /* Light grey */
|
10 |
+
border-top: 8px solid #3498db; /* Blue */
|
11 |
+
border-radius: 50%;
|
12 |
+
width: 30px;
|
13 |
+
height: 30px;
|
14 |
+
animation: spin 2s linear infinite;
|
15 |
+
}
|
16 |
+
|
17 |
+
@keyframes spin {
|
18 |
+
0% { transform: rotate(0deg); }
|
19 |
+
100% { transform: rotate(360deg); }
|
20 |
+
}
|
21 |
+
|
22 |
+
/* Style the progress bar */
|
23 |
+
progress {
|
24 |
+
appearance: none; /* Remove default styling */
|
25 |
+
height: 20px; /* Set the height of the progress bar */
|
26 |
+
border-radius: 5px; /* Round the corners of the progress bar */
|
27 |
+
background-color: #f3f3f3; /* Light grey background */
|
28 |
+
width: 100%;
|
29 |
+
}
|
30 |
+
|
31 |
+
/* Style the progress bar container */
|
32 |
+
.progress-container {
|
33 |
+
margin-left: 20px;
|
34 |
+
margin-right: 20px;
|
35 |
+
flex-grow: 1; /* Allow the progress container to take up remaining space */
|
36 |
+
}
|
37 |
+
|
38 |
+
/* Set the color of the progress bar fill */
|
39 |
+
progress::-webkit-progress-value {
|
40 |
+
background-color: #3498db; /* Blue color for the fill */
|
41 |
+
}
|
42 |
+
|
43 |
+
progress::-moz-progress-bar {
|
44 |
+
background-color: #3498db; /* Blue color for the fill in Firefox */
|
45 |
+
}
|
46 |
+
|
47 |
+
/* Style the text on the progress bar */
|
48 |
+
progress::after {
|
49 |
+
content: attr(value '%'); /* Display the progress value followed by '%' */
|
50 |
+
position: absolute;
|
51 |
+
top: 50%;
|
52 |
+
left: 50%;
|
53 |
+
transform: translate(-50%, -50%);
|
54 |
+
color: white; /* Set text color */
|
55 |
+
font-size: 14px; /* Set font size */
|
56 |
+
}
|
57 |
+
|
58 |
+
/* Style other texts */
|
59 |
+
.loader-container > span {
|
60 |
+
margin-left: 5px; /* Add spacing between the progress bar and the text */
|
61 |
+
}
|
62 |
+
|
63 |
+
.progress-bar > .generating {
|
64 |
+
display: none !important;
|
65 |
+
}
|
66 |
+
|
67 |
+
.progress-bar{
|
68 |
+
height: 30px !important;
|
69 |
+
}
|
70 |
+
|
71 |
+
.type_row{
|
72 |
+
height: 80px !important;
|
73 |
+
}
|
74 |
+
|
75 |
+
.type_row_half{
|
76 |
+
height: 32px !important;
|
77 |
+
}
|
78 |
+
|
79 |
+
.scroll-hide{
|
80 |
+
resize: none !important;
|
81 |
+
}
|
82 |
+
|
83 |
+
.refresh_button{
|
84 |
+
border: none !important;
|
85 |
+
background: none !important;
|
86 |
+
font-size: none !important;
|
87 |
+
box-shadow: none !important;
|
88 |
+
}
|
89 |
+
|
90 |
+
.advanced_check_row{
|
91 |
+
width: 250px !important;
|
92 |
+
}
|
93 |
+
|
94 |
+
.min_check{
|
95 |
+
min-width: min(1px, 100%) !important;
|
96 |
+
}
|
97 |
+
|
98 |
+
.resizable_area {
|
99 |
+
resize: vertical;
|
100 |
+
overflow: auto !important;
|
101 |
+
}
|
102 |
+
|
103 |
+
.aspect_ratios label {
|
104 |
+
width: 140px !important;
|
105 |
+
}
|
106 |
+
|
107 |
+
.aspect_ratios label span {
|
108 |
+
white-space: nowrap !important;
|
109 |
+
}
|
110 |
+
|
111 |
+
.aspect_ratios label input {
|
112 |
+
margin-left: -5px !important;
|
113 |
+
}
|
114 |
+
|
115 |
+
.lora_enable label {
|
116 |
+
height: 100%;
|
117 |
+
}
|
118 |
+
|
119 |
+
.lora_enable label input {
|
120 |
+
margin: auto;
|
121 |
+
}
|
122 |
+
|
123 |
+
.lora_enable label span {
|
124 |
+
display: none;
|
125 |
+
}
|
126 |
+
|
127 |
+
@-moz-document url-prefix() {
|
128 |
+
.lora_weight input[type=number] {
|
129 |
+
width: 80px;
|
130 |
+
}
|
131 |
+
}
|
132 |
+
|
133 |
+
'''
|
134 |
+
progress_html = '''
|
135 |
+
<div class="loader-container">
|
136 |
+
<div class="loader"></div>
|
137 |
+
<div class="progress-container">
|
138 |
+
<progress value="*number*" max="100"></progress>
|
139 |
+
</div>
|
140 |
+
<span>*text*</span>
|
141 |
+
</div>
|
142 |
+
'''
|
143 |
+
|
144 |
+
|
145 |
+
def make_progress_html(number, text):
|
146 |
+
return progress_html.replace('*number*', str(number)).replace('*text*', text)
|
modules/inpaint_worker.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from PIL import Image, ImageFilter
|
5 |
+
from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
|
6 |
+
from modules.upscaler import perform_upscale
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
|
10 |
+
inpaint_head_model = None
|
11 |
+
|
12 |
+
|
13 |
+
class InpaintHead(torch.nn.Module):
|
14 |
+
def __init__(self, *args, **kwargs):
|
15 |
+
super().__init__(*args, **kwargs)
|
16 |
+
self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
|
17 |
+
|
18 |
+
def __call__(self, x):
|
19 |
+
x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
|
20 |
+
return torch.nn.functional.conv2d(input=x, weight=self.head)
|
21 |
+
|
22 |
+
|
23 |
+
current_task = None
|
24 |
+
|
25 |
+
|
26 |
+
def box_blur(x, k):
|
27 |
+
x = Image.fromarray(x)
|
28 |
+
x = x.filter(ImageFilter.BoxBlur(k))
|
29 |
+
return np.array(x)
|
30 |
+
|
31 |
+
|
32 |
+
def max_filter_opencv(x, ksize=3):
|
33 |
+
# Use OpenCV maximum filter
|
34 |
+
# Make sure the input type is int16
|
35 |
+
return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16))
|
36 |
+
|
37 |
+
|
38 |
+
def morphological_open(x):
|
39 |
+
# Convert array to int16 type via threshold operation
|
40 |
+
x_int16 = np.zeros_like(x, dtype=np.int16)
|
41 |
+
x_int16[x > 127] = 256
|
42 |
+
|
43 |
+
for i in range(32):
|
44 |
+
# Use int16 type to avoid overflow
|
45 |
+
maxed = max_filter_opencv(x_int16, ksize=3) - 8
|
46 |
+
x_int16 = np.maximum(maxed, x_int16)
|
47 |
+
|
48 |
+
# Clip negative values to 0 and convert back to uint8 type
|
49 |
+
x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
|
50 |
+
return x_uint8
|
51 |
+
|
52 |
+
|
53 |
+
def up255(x, t=0):
|
54 |
+
y = np.zeros_like(x).astype(np.uint8)
|
55 |
+
y[x > t] = 255
|
56 |
+
return y
|
57 |
+
|
58 |
+
|
59 |
+
def imsave(x, path):
|
60 |
+
x = Image.fromarray(x)
|
61 |
+
x.save(path)
|
62 |
+
|
63 |
+
|
64 |
+
def regulate_abcd(x, a, b, c, d):
|
65 |
+
H, W = x.shape[:2]
|
66 |
+
if a < 0:
|
67 |
+
a = 0
|
68 |
+
if a > H:
|
69 |
+
a = H
|
70 |
+
if b < 0:
|
71 |
+
b = 0
|
72 |
+
if b > H:
|
73 |
+
b = H
|
74 |
+
if c < 0:
|
75 |
+
c = 0
|
76 |
+
if c > W:
|
77 |
+
c = W
|
78 |
+
if d < 0:
|
79 |
+
d = 0
|
80 |
+
if d > W:
|
81 |
+
d = W
|
82 |
+
return int(a), int(b), int(c), int(d)
|
83 |
+
|
84 |
+
|
85 |
+
def compute_initial_abcd(x):
|
86 |
+
indices = np.where(x)
|
87 |
+
a = np.min(indices[0])
|
88 |
+
b = np.max(indices[0])
|
89 |
+
c = np.min(indices[1])
|
90 |
+
d = np.max(indices[1])
|
91 |
+
abp = (b + a) // 2
|
92 |
+
abm = (b - a) // 2
|
93 |
+
cdp = (d + c) // 2
|
94 |
+
cdm = (d - c) // 2
|
95 |
+
l = int(max(abm, cdm) * 1.15)
|
96 |
+
a = abp - l
|
97 |
+
b = abp + l + 1
|
98 |
+
c = cdp - l
|
99 |
+
d = cdp + l + 1
|
100 |
+
a, b, c, d = regulate_abcd(x, a, b, c, d)
|
101 |
+
return a, b, c, d
|
102 |
+
|
103 |
+
|
104 |
+
def solve_abcd(x, a, b, c, d, k):
|
105 |
+
k = float(k)
|
106 |
+
assert 0.0 <= k <= 1.0
|
107 |
+
|
108 |
+
H, W = x.shape[:2]
|
109 |
+
if k == 1.0:
|
110 |
+
return 0, H, 0, W
|
111 |
+
while True:
|
112 |
+
if b - a >= H * k and d - c >= W * k:
|
113 |
+
break
|
114 |
+
|
115 |
+
add_h = (b - a) < (d - c)
|
116 |
+
add_w = not add_h
|
117 |
+
|
118 |
+
if b - a == H:
|
119 |
+
add_w = True
|
120 |
+
|
121 |
+
if d - c == W:
|
122 |
+
add_h = True
|
123 |
+
|
124 |
+
if add_h:
|
125 |
+
a -= 1
|
126 |
+
b += 1
|
127 |
+
|
128 |
+
if add_w:
|
129 |
+
c -= 1
|
130 |
+
d += 1
|
131 |
+
|
132 |
+
a, b, c, d = regulate_abcd(x, a, b, c, d)
|
133 |
+
return a, b, c, d
|
134 |
+
|
135 |
+
|
136 |
+
def fooocus_fill(image, mask):
|
137 |
+
current_image = image.copy()
|
138 |
+
raw_image = image.copy()
|
139 |
+
area = np.where(mask < 127)
|
140 |
+
store = raw_image[area]
|
141 |
+
|
142 |
+
for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
|
143 |
+
for _ in range(repeats):
|
144 |
+
current_image = box_blur(current_image, k)
|
145 |
+
current_image[area] = store
|
146 |
+
|
147 |
+
return current_image
|
148 |
+
|
149 |
+
|
150 |
+
class InpaintWorker:
|
151 |
+
def __init__(self, image, mask, use_fill=True, k=0.618):
|
152 |
+
a, b, c, d = compute_initial_abcd(mask > 0)
|
153 |
+
a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)
|
154 |
+
|
155 |
+
# interested area
|
156 |
+
self.interested_area = (a, b, c, d)
|
157 |
+
self.interested_mask = mask[a:b, c:d]
|
158 |
+
self.interested_image = image[a:b, c:d]
|
159 |
+
|
160 |
+
# super resolution
|
161 |
+
if get_image_shape_ceil(self.interested_image) < 1024:
|
162 |
+
self.interested_image = perform_upscale(self.interested_image)
|
163 |
+
|
164 |
+
# resize to make images ready for diffusion
|
165 |
+
self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
|
166 |
+
self.interested_fill = self.interested_image.copy()
|
167 |
+
H, W, C = self.interested_image.shape
|
168 |
+
|
169 |
+
# process mask
|
170 |
+
self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)
|
171 |
+
|
172 |
+
# compute filling
|
173 |
+
if use_fill:
|
174 |
+
self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)
|
175 |
+
|
176 |
+
# soft pixels
|
177 |
+
self.mask = morphological_open(mask)
|
178 |
+
self.image = image
|
179 |
+
|
180 |
+
# ending
|
181 |
+
self.latent = None
|
182 |
+
self.latent_after_swap = None
|
183 |
+
self.swapped = False
|
184 |
+
self.latent_mask = None
|
185 |
+
self.inpaint_head_feature = None
|
186 |
+
return
|
187 |
+
|
188 |
+
def load_latent(self, latent_fill, latent_mask, latent_swap=None):
|
189 |
+
self.latent = latent_fill
|
190 |
+
self.latent_mask = latent_mask
|
191 |
+
self.latent_after_swap = latent_swap
|
192 |
+
return
|
193 |
+
|
194 |
+
def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
|
195 |
+
global inpaint_head_model
|
196 |
+
|
197 |
+
if inpaint_head_model is None:
|
198 |
+
inpaint_head_model = InpaintHead()
|
199 |
+
sd = torch.load(inpaint_head_model_path, map_location='cpu')
|
200 |
+
inpaint_head_model.load_state_dict(sd)
|
201 |
+
|
202 |
+
feed = torch.cat([
|
203 |
+
inpaint_latent_mask,
|
204 |
+
model.model.process_latent_in(inpaint_latent)
|
205 |
+
], dim=1)
|
206 |
+
|
207 |
+
inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
|
208 |
+
inpaint_head_feature = inpaint_head_model(feed)
|
209 |
+
|
210 |
+
def input_block_patch(h, transformer_options):
|
211 |
+
if transformer_options["block"][1] == 0:
|
212 |
+
h = h + inpaint_head_feature.to(h)
|
213 |
+
return h
|
214 |
+
|
215 |
+
m = model.clone()
|
216 |
+
m.set_model_input_block_patch(input_block_patch)
|
217 |
+
return m
|
218 |
+
|
219 |
+
def swap(self):
|
220 |
+
if self.swapped:
|
221 |
+
return
|
222 |
+
|
223 |
+
if self.latent is None:
|
224 |
+
return
|
225 |
+
|
226 |
+
if self.latent_after_swap is None:
|
227 |
+
return
|
228 |
+
|
229 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
230 |
+
self.swapped = True
|
231 |
+
return
|
232 |
+
|
233 |
+
def unswap(self):
|
234 |
+
if not self.swapped:
|
235 |
+
return
|
236 |
+
|
237 |
+
if self.latent is None:
|
238 |
+
return
|
239 |
+
|
240 |
+
if self.latent_after_swap is None:
|
241 |
+
return
|
242 |
+
|
243 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
244 |
+
self.swapped = False
|
245 |
+
return
|
246 |
+
|
247 |
+
def color_correction(self, img):
|
248 |
+
fg = img.astype(np.float32)
|
249 |
+
bg = self.image.copy().astype(np.float32)
|
250 |
+
w = self.mask[:, :, None].astype(np.float32) / 255.0
|
251 |
+
y = fg * w + bg * (1 - w)
|
252 |
+
return y.clip(0, 255).astype(np.uint8)
|
253 |
+
|
254 |
+
def post_process(self, img):
|
255 |
+
a, b, c, d = self.interested_area
|
256 |
+
content = resample_image(img, d - c, b - a)
|
257 |
+
result = self.image.copy()
|
258 |
+
result[a:b, c:d] = content
|
259 |
+
result = self.color_correction(result)
|
260 |
+
return result
|
261 |
+
|
262 |
+
def visualize_mask_processing(self):
|
263 |
+
return [self.interested_fill, self.interested_mask, self.interested_image]
|
264 |
+
|
modules/launch_util.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import importlib
|
3 |
+
import importlib.util
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
import re
|
7 |
+
import logging
|
8 |
+
import importlib.metadata
|
9 |
+
import packaging.version
|
10 |
+
from packaging.requirements import Requirement
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
16 |
+
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
17 |
+
|
18 |
+
re_requirement = re.compile(r"\s*([-\w]+)\s*(?:==\s*([-+.\w]+))?\s*")
|
19 |
+
|
20 |
+
python = sys.executable
|
21 |
+
default_command_live = (os.environ.get('LAUNCH_LIVE_OUTPUT') == "1")
|
22 |
+
index_url = os.environ.get('INDEX_URL', "")
|
23 |
+
|
24 |
+
modules_path = os.path.dirname(os.path.realpath(__file__))
|
25 |
+
script_path = os.path.dirname(modules_path)
|
26 |
+
|
27 |
+
|
28 |
+
def is_installed(package):
|
29 |
+
try:
|
30 |
+
spec = importlib.util.find_spec(package)
|
31 |
+
except ModuleNotFoundError:
|
32 |
+
return False
|
33 |
+
|
34 |
+
return spec is not None
|
35 |
+
|
36 |
+
|
37 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
38 |
+
if desc is not None:
|
39 |
+
print(desc)
|
40 |
+
|
41 |
+
run_kwargs = {
|
42 |
+
"args": command,
|
43 |
+
"shell": True,
|
44 |
+
"env": os.environ if custom_env is None else custom_env,
|
45 |
+
"encoding": 'utf8',
|
46 |
+
"errors": 'ignore',
|
47 |
+
}
|
48 |
+
|
49 |
+
if not live:
|
50 |
+
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
51 |
+
|
52 |
+
result = subprocess.run(**run_kwargs)
|
53 |
+
|
54 |
+
if result.returncode != 0:
|
55 |
+
error_bits = [
|
56 |
+
f"{errdesc or 'Error running command'}.",
|
57 |
+
f"Command: {command}",
|
58 |
+
f"Error code: {result.returncode}",
|
59 |
+
]
|
60 |
+
if result.stdout:
|
61 |
+
error_bits.append(f"stdout: {result.stdout}")
|
62 |
+
if result.stderr:
|
63 |
+
error_bits.append(f"stderr: {result.stderr}")
|
64 |
+
raise RuntimeError("\n".join(error_bits))
|
65 |
+
|
66 |
+
return (result.stdout or "")
|
67 |
+
|
68 |
+
|
69 |
+
def run_pip(command, desc=None, live=default_command_live):
|
70 |
+
try:
|
71 |
+
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
72 |
+
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}",
|
73 |
+
errdesc=f"Couldn't install {desc}", live=live)
|
74 |
+
except Exception as e:
|
75 |
+
print(e)
|
76 |
+
print(f'CMD Failed {desc}: {command}')
|
77 |
+
return None
|
78 |
+
|
79 |
+
|
80 |
+
def requirements_met(requirements_file):
|
81 |
+
with open(requirements_file, "r", encoding="utf8") as file:
|
82 |
+
for line in file:
|
83 |
+
line = line.strip()
|
84 |
+
if line == "" or line.startswith('#'):
|
85 |
+
continue
|
86 |
+
|
87 |
+
requirement = Requirement(line)
|
88 |
+
package = requirement.name
|
89 |
+
|
90 |
+
try:
|
91 |
+
version_installed = importlib.metadata.version(package)
|
92 |
+
installed_version = packaging.version.parse(version_installed)
|
93 |
+
|
94 |
+
# Check if the installed version satisfies the requirement
|
95 |
+
if installed_version not in requirement.specifier:
|
96 |
+
print(f"Version mismatch for {package}: Installed version {version_installed} does not meet requirement {requirement}")
|
97 |
+
return False
|
98 |
+
except Exception as e:
|
99 |
+
print(f"Error checking version for {package}: {e}")
|
100 |
+
return False
|
101 |
+
|
102 |
+
return True
|
103 |
+
|
modules/localization.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
|
5 |
+
current_translation = {}
|
6 |
+
localization_root = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'language')
|
7 |
+
|
8 |
+
|
9 |
+
def localization_js(filename):
|
10 |
+
global current_translation
|
11 |
+
|
12 |
+
if isinstance(filename, str):
|
13 |
+
full_name = os.path.abspath(os.path.join(localization_root, filename + '.json'))
|
14 |
+
if os.path.exists(full_name):
|
15 |
+
try:
|
16 |
+
with open(full_name, encoding='utf-8') as f:
|
17 |
+
current_translation = json.load(f)
|
18 |
+
assert isinstance(current_translation, dict)
|
19 |
+
for k, v in current_translation.items():
|
20 |
+
assert isinstance(k, str)
|
21 |
+
assert isinstance(v, str)
|
22 |
+
except Exception as e:
|
23 |
+
print(str(e))
|
24 |
+
print(f'Failed to load localization file {full_name}')
|
25 |
+
|
26 |
+
# current_translation = {k: 'XXX' for k in current_translation.keys()} # use this to see if all texts are covered
|
27 |
+
|
28 |
+
return f"window.localization = {json.dumps(current_translation)}"
|
29 |
+
|
30 |
+
|
31 |
+
def dump_english_config(components):
|
32 |
+
all_texts = []
|
33 |
+
for c in components:
|
34 |
+
label = getattr(c, 'label', None)
|
35 |
+
value = getattr(c, 'value', None)
|
36 |
+
choices = getattr(c, 'choices', None)
|
37 |
+
info = getattr(c, 'info', None)
|
38 |
+
|
39 |
+
if isinstance(label, str):
|
40 |
+
all_texts.append(label)
|
41 |
+
if isinstance(value, str):
|
42 |
+
all_texts.append(value)
|
43 |
+
if isinstance(info, str):
|
44 |
+
all_texts.append(info)
|
45 |
+
if isinstance(choices, list):
|
46 |
+
for x in choices:
|
47 |
+
if isinstance(x, str):
|
48 |
+
all_texts.append(x)
|
49 |
+
if isinstance(x, tuple):
|
50 |
+
for y in x:
|
51 |
+
if isinstance(y, str):
|
52 |
+
all_texts.append(y)
|
53 |
+
|
54 |
+
config_dict = {k: k for k in all_texts if k != "" and 'progress-container' not in k}
|
55 |
+
full_name = os.path.abspath(os.path.join(localization_root, 'en.json'))
|
56 |
+
|
57 |
+
with open(full_name, "w", encoding="utf-8") as json_file:
|
58 |
+
json.dump(config_dict, json_file, indent=4)
|
59 |
+
|
60 |
+
return
|
modules/lora.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def match_lora(lora, to_load):
|
2 |
+
patch_dict = {}
|
3 |
+
loaded_keys = set()
|
4 |
+
for x in to_load:
|
5 |
+
real_load_key = to_load[x]
|
6 |
+
if real_load_key in lora:
|
7 |
+
patch_dict[real_load_key] = ('fooocus', lora[real_load_key])
|
8 |
+
loaded_keys.add(real_load_key)
|
9 |
+
continue
|
10 |
+
|
11 |
+
alpha_name = "{}.alpha".format(x)
|
12 |
+
alpha = None
|
13 |
+
if alpha_name in lora.keys():
|
14 |
+
alpha = lora[alpha_name].item()
|
15 |
+
loaded_keys.add(alpha_name)
|
16 |
+
|
17 |
+
regular_lora = "{}.lora_up.weight".format(x)
|
18 |
+
diffusers_lora = "{}_lora.up.weight".format(x)
|
19 |
+
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
20 |
+
A_name = None
|
21 |
+
|
22 |
+
if regular_lora in lora.keys():
|
23 |
+
A_name = regular_lora
|
24 |
+
B_name = "{}.lora_down.weight".format(x)
|
25 |
+
mid_name = "{}.lora_mid.weight".format(x)
|
26 |
+
elif diffusers_lora in lora.keys():
|
27 |
+
A_name = diffusers_lora
|
28 |
+
B_name = "{}_lora.down.weight".format(x)
|
29 |
+
mid_name = None
|
30 |
+
elif transformers_lora in lora.keys():
|
31 |
+
A_name = transformers_lora
|
32 |
+
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
33 |
+
mid_name = None
|
34 |
+
|
35 |
+
if A_name is not None:
|
36 |
+
mid = None
|
37 |
+
if mid_name is not None and mid_name in lora.keys():
|
38 |
+
mid = lora[mid_name]
|
39 |
+
loaded_keys.add(mid_name)
|
40 |
+
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
|
41 |
+
loaded_keys.add(A_name)
|
42 |
+
loaded_keys.add(B_name)
|
43 |
+
|
44 |
+
|
45 |
+
######## loha
|
46 |
+
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
47 |
+
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
48 |
+
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
49 |
+
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
50 |
+
hada_t1_name = "{}.hada_t1".format(x)
|
51 |
+
hada_t2_name = "{}.hada_t2".format(x)
|
52 |
+
if hada_w1_a_name in lora.keys():
|
53 |
+
hada_t1 = None
|
54 |
+
hada_t2 = None
|
55 |
+
if hada_t1_name in lora.keys():
|
56 |
+
hada_t1 = lora[hada_t1_name]
|
57 |
+
hada_t2 = lora[hada_t2_name]
|
58 |
+
loaded_keys.add(hada_t1_name)
|
59 |
+
loaded_keys.add(hada_t2_name)
|
60 |
+
|
61 |
+
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
|
62 |
+
loaded_keys.add(hada_w1_a_name)
|
63 |
+
loaded_keys.add(hada_w1_b_name)
|
64 |
+
loaded_keys.add(hada_w2_a_name)
|
65 |
+
loaded_keys.add(hada_w2_b_name)
|
66 |
+
|
67 |
+
|
68 |
+
######## lokr
|
69 |
+
lokr_w1_name = "{}.lokr_w1".format(x)
|
70 |
+
lokr_w2_name = "{}.lokr_w2".format(x)
|
71 |
+
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
72 |
+
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
73 |
+
lokr_t2_name = "{}.lokr_t2".format(x)
|
74 |
+
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
75 |
+
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
76 |
+
|
77 |
+
lokr_w1 = None
|
78 |
+
if lokr_w1_name in lora.keys():
|
79 |
+
lokr_w1 = lora[lokr_w1_name]
|
80 |
+
loaded_keys.add(lokr_w1_name)
|
81 |
+
|
82 |
+
lokr_w2 = None
|
83 |
+
if lokr_w2_name in lora.keys():
|
84 |
+
lokr_w2 = lora[lokr_w2_name]
|
85 |
+
loaded_keys.add(lokr_w2_name)
|
86 |
+
|
87 |
+
lokr_w1_a = None
|
88 |
+
if lokr_w1_a_name in lora.keys():
|
89 |
+
lokr_w1_a = lora[lokr_w1_a_name]
|
90 |
+
loaded_keys.add(lokr_w1_a_name)
|
91 |
+
|
92 |
+
lokr_w1_b = None
|
93 |
+
if lokr_w1_b_name in lora.keys():
|
94 |
+
lokr_w1_b = lora[lokr_w1_b_name]
|
95 |
+
loaded_keys.add(lokr_w1_b_name)
|
96 |
+
|
97 |
+
lokr_w2_a = None
|
98 |
+
if lokr_w2_a_name in lora.keys():
|
99 |
+
lokr_w2_a = lora[lokr_w2_a_name]
|
100 |
+
loaded_keys.add(lokr_w2_a_name)
|
101 |
+
|
102 |
+
lokr_w2_b = None
|
103 |
+
if lokr_w2_b_name in lora.keys():
|
104 |
+
lokr_w2_b = lora[lokr_w2_b_name]
|
105 |
+
loaded_keys.add(lokr_w2_b_name)
|
106 |
+
|
107 |
+
lokr_t2 = None
|
108 |
+
if lokr_t2_name in lora.keys():
|
109 |
+
lokr_t2 = lora[lokr_t2_name]
|
110 |
+
loaded_keys.add(lokr_t2_name)
|
111 |
+
|
112 |
+
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
113 |
+
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
|
114 |
+
|
115 |
+
#glora
|
116 |
+
a1_name = "{}.a1.weight".format(x)
|
117 |
+
a2_name = "{}.a2.weight".format(x)
|
118 |
+
b1_name = "{}.b1.weight".format(x)
|
119 |
+
b2_name = "{}.b2.weight".format(x)
|
120 |
+
if a1_name in lora:
|
121 |
+
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
|
122 |
+
loaded_keys.add(a1_name)
|
123 |
+
loaded_keys.add(a2_name)
|
124 |
+
loaded_keys.add(b1_name)
|
125 |
+
loaded_keys.add(b2_name)
|
126 |
+
|
127 |
+
w_norm_name = "{}.w_norm".format(x)
|
128 |
+
b_norm_name = "{}.b_norm".format(x)
|
129 |
+
w_norm = lora.get(w_norm_name, None)
|
130 |
+
b_norm = lora.get(b_norm_name, None)
|
131 |
+
|
132 |
+
if w_norm is not None:
|
133 |
+
loaded_keys.add(w_norm_name)
|
134 |
+
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
135 |
+
if b_norm is not None:
|
136 |
+
loaded_keys.add(b_norm_name)
|
137 |
+
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
138 |
+
|
139 |
+
diff_name = "{}.diff".format(x)
|
140 |
+
diff_weight = lora.get(diff_name, None)
|
141 |
+
if diff_weight is not None:
|
142 |
+
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
143 |
+
loaded_keys.add(diff_name)
|
144 |
+
|
145 |
+
diff_bias_name = "{}.diff_b".format(x)
|
146 |
+
diff_bias = lora.get(diff_bias_name, None)
|
147 |
+
if diff_bias is not None:
|
148 |
+
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
149 |
+
loaded_keys.add(diff_bias_name)
|
150 |
+
|
151 |
+
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
152 |
+
return patch_dict, remaining_dict
|
modules/meta_parser.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
import fooocus_version
|
11 |
+
import modules.config
|
12 |
+
import modules.sdxl_styles
|
13 |
+
from modules.flags import MetadataScheme, Performance, Steps
|
14 |
+
from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS
|
15 |
+
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, calculate_sha256
|
16 |
+
|
17 |
+
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
18 |
+
re_param = re.compile(re_param_code)
|
19 |
+
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
20 |
+
|
21 |
+
hash_cache = {}
|
22 |
+
|
23 |
+
|
24 |
+
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
|
25 |
+
loaded_parameter_dict = raw_metadata
|
26 |
+
if isinstance(raw_metadata, str):
|
27 |
+
loaded_parameter_dict = json.loads(raw_metadata)
|
28 |
+
assert isinstance(loaded_parameter_dict, dict)
|
29 |
+
|
30 |
+
results = [len(loaded_parameter_dict) > 0, 1]
|
31 |
+
|
32 |
+
get_str('prompt', 'Prompt', loaded_parameter_dict, results)
|
33 |
+
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
|
34 |
+
get_list('styles', 'Styles', loaded_parameter_dict, results)
|
35 |
+
get_str('performance', 'Performance', loaded_parameter_dict, results)
|
36 |
+
get_steps('steps', 'Steps', loaded_parameter_dict, results)
|
37 |
+
get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
|
38 |
+
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
|
39 |
+
get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results)
|
40 |
+
get_float('sharpness', 'Sharpness', loaded_parameter_dict, results)
|
41 |
+
get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results)
|
42 |
+
get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results)
|
43 |
+
get_float('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results)
|
44 |
+
get_str('base_model', 'Base Model', loaded_parameter_dict, results)
|
45 |
+
get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results)
|
46 |
+
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results)
|
47 |
+
get_str('sampler', 'Sampler', loaded_parameter_dict, results)
|
48 |
+
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
|
49 |
+
get_seed('seed', 'Seed', loaded_parameter_dict, results)
|
50 |
+
|
51 |
+
if is_generating:
|
52 |
+
results.append(gr.update())
|
53 |
+
else:
|
54 |
+
results.append(gr.update(visible=True))
|
55 |
+
|
56 |
+
results.append(gr.update(visible=False))
|
57 |
+
|
58 |
+
get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)
|
59 |
+
|
60 |
+
for i in range(modules.config.default_max_lora_number):
|
61 |
+
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
|
62 |
+
|
63 |
+
return results
|
64 |
+
|
65 |
+
|
66 |
+
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
67 |
+
try:
|
68 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
69 |
+
assert isinstance(h, str)
|
70 |
+
results.append(h)
|
71 |
+
except:
|
72 |
+
results.append(gr.update())
|
73 |
+
|
74 |
+
|
75 |
+
def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
76 |
+
try:
|
77 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
78 |
+
h = eval(h)
|
79 |
+
assert isinstance(h, list)
|
80 |
+
results.append(h)
|
81 |
+
except:
|
82 |
+
results.append(gr.update())
|
83 |
+
|
84 |
+
|
85 |
+
def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
86 |
+
try:
|
87 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
88 |
+
assert h is not None
|
89 |
+
h = float(h)
|
90 |
+
results.append(h)
|
91 |
+
except:
|
92 |
+
results.append(gr.update())
|
93 |
+
|
94 |
+
|
95 |
+
def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
96 |
+
try:
|
97 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
98 |
+
assert h is not None
|
99 |
+
h = int(h)
|
100 |
+
# if not in steps or in steps and performance is not the same
|
101 |
+
if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', '_').casefold():
|
102 |
+
results.append(h)
|
103 |
+
return
|
104 |
+
results.append(-1)
|
105 |
+
except:
|
106 |
+
results.append(-1)
|
107 |
+
|
108 |
+
|
109 |
+
def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
110 |
+
try:
|
111 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
112 |
+
width, height = eval(h)
|
113 |
+
formatted = modules.config.add_ratio(f'{width}*{height}')
|
114 |
+
if formatted in modules.config.available_aspect_ratios:
|
115 |
+
results.append(formatted)
|
116 |
+
results.append(-1)
|
117 |
+
results.append(-1)
|
118 |
+
else:
|
119 |
+
results.append(gr.update())
|
120 |
+
results.append(int(width))
|
121 |
+
results.append(int(height))
|
122 |
+
except:
|
123 |
+
results.append(gr.update())
|
124 |
+
results.append(gr.update())
|
125 |
+
results.append(gr.update())
|
126 |
+
|
127 |
+
|
128 |
+
def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
129 |
+
try:
|
130 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
131 |
+
assert h is not None
|
132 |
+
h = int(h)
|
133 |
+
results.append(False)
|
134 |
+
results.append(h)
|
135 |
+
except:
|
136 |
+
results.append(gr.update())
|
137 |
+
results.append(gr.update())
|
138 |
+
|
139 |
+
|
140 |
+
def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
141 |
+
try:
|
142 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
143 |
+
p, n, e = eval(h)
|
144 |
+
results.append(float(p))
|
145 |
+
results.append(float(n))
|
146 |
+
results.append(float(e))
|
147 |
+
except:
|
148 |
+
results.append(gr.update())
|
149 |
+
results.append(gr.update())
|
150 |
+
results.append(gr.update())
|
151 |
+
|
152 |
+
|
153 |
+
def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
|
154 |
+
try:
|
155 |
+
h = source_dict.get(key, source_dict.get(fallback, default))
|
156 |
+
b1, b2, s1, s2 = eval(h)
|
157 |
+
results.append(True)
|
158 |
+
results.append(float(b1))
|
159 |
+
results.append(float(b2))
|
160 |
+
results.append(float(s1))
|
161 |
+
results.append(float(s2))
|
162 |
+
except:
|
163 |
+
results.append(False)
|
164 |
+
results.append(gr.update())
|
165 |
+
results.append(gr.update())
|
166 |
+
results.append(gr.update())
|
167 |
+
results.append(gr.update())
|
168 |
+
|
169 |
+
|
170 |
+
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
|
171 |
+
try:
|
172 |
+
n, w = source_dict.get(key, source_dict.get(fallback)).split(' : ')
|
173 |
+
w = float(w)
|
174 |
+
results.append(True)
|
175 |
+
results.append(n)
|
176 |
+
results.append(w)
|
177 |
+
except:
|
178 |
+
results.append(True)
|
179 |
+
results.append('None')
|
180 |
+
results.append(1)
|
181 |
+
|
182 |
+
|
183 |
+
def get_sha256(filepath):
|
184 |
+
global hash_cache
|
185 |
+
if filepath not in hash_cache:
|
186 |
+
hash_cache[filepath] = calculate_sha256(filepath)
|
187 |
+
|
188 |
+
return hash_cache[filepath]
|
189 |
+
|
190 |
+
|
191 |
+
def parse_meta_from_preset(preset_content):
|
192 |
+
assert isinstance(preset_content, dict)
|
193 |
+
preset_prepared = {}
|
194 |
+
items = preset_content
|
195 |
+
|
196 |
+
for settings_key, meta_key in modules.config.possible_preset_keys.items():
|
197 |
+
if settings_key == "default_loras":
|
198 |
+
loras = getattr(modules.config, settings_key)
|
199 |
+
if settings_key in items:
|
200 |
+
loras = items[settings_key]
|
201 |
+
for index, lora in enumerate(loras[:5]):
|
202 |
+
preset_prepared[f'lora_combined_{index + 1}'] = ' : '.join(map(str, lora))
|
203 |
+
elif settings_key == "default_aspect_ratio":
|
204 |
+
if settings_key in items and items[settings_key] is not None:
|
205 |
+
default_aspect_ratio = items[settings_key]
|
206 |
+
width, height = default_aspect_ratio.split('*')
|
207 |
+
else:
|
208 |
+
default_aspect_ratio = getattr(modules.config, settings_key)
|
209 |
+
width, height = default_aspect_ratio.split('×')
|
210 |
+
height = height[:height.index(" ")]
|
211 |
+
preset_prepared[meta_key] = (width, height)
|
212 |
+
else:
|
213 |
+
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[
|
214 |
+
settings_key] is not None else getattr(modules.config, settings_key)
|
215 |
+
|
216 |
+
if settings_key == "default_styles" or settings_key == "default_aspect_ratio":
|
217 |
+
preset_prepared[meta_key] = str(preset_prepared[meta_key])
|
218 |
+
|
219 |
+
return preset_prepared
|
220 |
+
|
221 |
+
|
222 |
+
class MetadataParser(ABC):
|
223 |
+
def __init__(self):
|
224 |
+
self.raw_prompt: str = ''
|
225 |
+
self.full_prompt: str = ''
|
226 |
+
self.raw_negative_prompt: str = ''
|
227 |
+
self.full_negative_prompt: str = ''
|
228 |
+
self.steps: int = 30
|
229 |
+
self.base_model_name: str = ''
|
230 |
+
self.base_model_hash: str = ''
|
231 |
+
self.refiner_model_name: str = ''
|
232 |
+
self.refiner_model_hash: str = ''
|
233 |
+
self.loras: list = []
|
234 |
+
|
235 |
+
@abstractmethod
|
236 |
+
def get_scheme(self) -> MetadataScheme:
|
237 |
+
raise NotImplementedError
|
238 |
+
|
239 |
+
@abstractmethod
|
240 |
+
def parse_json(self, metadata: dict | str) -> dict:
|
241 |
+
raise NotImplementedError
|
242 |
+
|
243 |
+
@abstractmethod
|
244 |
+
def parse_string(self, metadata: dict) -> str:
|
245 |
+
raise NotImplementedError
|
246 |
+
|
247 |
+
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
|
248 |
+
refiner_model_name, loras):
|
249 |
+
self.raw_prompt = raw_prompt
|
250 |
+
self.full_prompt = full_prompt
|
251 |
+
self.raw_negative_prompt = raw_negative_prompt
|
252 |
+
self.full_negative_prompt = full_negative_prompt
|
253 |
+
self.steps = steps
|
254 |
+
self.base_model_name = Path(base_model_name).stem
|
255 |
+
|
256 |
+
base_model_path = get_file_from_folder_list(base_model_name, modules.config.paths_checkpoints)
|
257 |
+
self.base_model_hash = get_sha256(base_model_path)
|
258 |
+
|
259 |
+
if refiner_model_name not in ['', 'None']:
|
260 |
+
self.refiner_model_name = Path(refiner_model_name).stem
|
261 |
+
refiner_model_path = get_file_from_folder_list(refiner_model_name, modules.config.paths_checkpoints)
|
262 |
+
self.refiner_model_hash = get_sha256(refiner_model_path)
|
263 |
+
|
264 |
+
self.loras = []
|
265 |
+
for (lora_name, lora_weight) in loras:
|
266 |
+
if lora_name != 'None':
|
267 |
+
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
|
268 |
+
lora_hash = get_sha256(lora_path)
|
269 |
+
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
|
270 |
+
|
271 |
+
|
272 |
+
class A1111MetadataParser(MetadataParser):
|
273 |
+
def get_scheme(self) -> MetadataScheme:
|
274 |
+
return MetadataScheme.A1111
|
275 |
+
|
276 |
+
fooocus_to_a1111 = {
|
277 |
+
'raw_prompt': 'Raw prompt',
|
278 |
+
'raw_negative_prompt': 'Raw negative prompt',
|
279 |
+
'negative_prompt': 'Negative prompt',
|
280 |
+
'styles': 'Styles',
|
281 |
+
'performance': 'Performance',
|
282 |
+
'steps': 'Steps',
|
283 |
+
'sampler': 'Sampler',
|
284 |
+
'scheduler': 'Scheduler',
|
285 |
+
'guidance_scale': 'CFG scale',
|
286 |
+
'seed': 'Seed',
|
287 |
+
'resolution': 'Size',
|
288 |
+
'sharpness': 'Sharpness',
|
289 |
+
'adm_guidance': 'ADM Guidance',
|
290 |
+
'refiner_swap_method': 'Refiner Swap Method',
|
291 |
+
'adaptive_cfg': 'Adaptive CFG',
|
292 |
+
'overwrite_switch': 'Overwrite Switch',
|
293 |
+
'freeu': 'FreeU',
|
294 |
+
'base_model': 'Model',
|
295 |
+
'base_model_hash': 'Model hash',
|
296 |
+
'refiner_model': 'Refiner',
|
297 |
+
'refiner_model_hash': 'Refiner hash',
|
298 |
+
'lora_hashes': 'Lora hashes',
|
299 |
+
'lora_weights': 'Lora weights',
|
300 |
+
'created_by': 'User',
|
301 |
+
'version': 'Version'
|
302 |
+
}
|
303 |
+
|
304 |
+
def parse_json(self, metadata: str) -> dict:
|
305 |
+
metadata_prompt = ''
|
306 |
+
metadata_negative_prompt = ''
|
307 |
+
|
308 |
+
done_with_prompt = False
|
309 |
+
|
310 |
+
*lines, lastline = metadata.strip().split("\n")
|
311 |
+
if len(re_param.findall(lastline)) < 3:
|
312 |
+
lines.append(lastline)
|
313 |
+
lastline = ''
|
314 |
+
|
315 |
+
for line in lines:
|
316 |
+
line = line.strip()
|
317 |
+
if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"):
|
318 |
+
done_with_prompt = True
|
319 |
+
line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip()
|
320 |
+
if done_with_prompt:
|
321 |
+
metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line
|
322 |
+
else:
|
323 |
+
metadata_prompt += ('' if metadata_prompt == '' else "\n") + line
|
324 |
+
|
325 |
+
found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt)
|
326 |
+
|
327 |
+
data = {
|
328 |
+
'prompt': prompt,
|
329 |
+
'negative_prompt': negative_prompt
|
330 |
+
}
|
331 |
+
|
332 |
+
for k, v in re_param.findall(lastline):
|
333 |
+
try:
|
334 |
+
if v != '' and v[0] == '"' and v[-1] == '"':
|
335 |
+
v = unquote(v)
|
336 |
+
|
337 |
+
m = re_imagesize.match(v)
|
338 |
+
if m is not None:
|
339 |
+
data['resolution'] = str((m.group(1), m.group(2)))
|
340 |
+
else:
|
341 |
+
data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v
|
342 |
+
except Exception:
|
343 |
+
print(f"Error parsing \"{k}: {v}\"")
|
344 |
+
|
345 |
+
# workaround for multiline prompts
|
346 |
+
if 'raw_prompt' in data:
|
347 |
+
data['prompt'] = data['raw_prompt']
|
348 |
+
raw_prompt = data['raw_prompt'].replace("\n", ', ')
|
349 |
+
if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles:
|
350 |
+
found_styles.append(modules.sdxl_styles.fooocus_expansion)
|
351 |
+
|
352 |
+
if 'raw_negative_prompt' in data:
|
353 |
+
data['negative_prompt'] = data['raw_negative_prompt']
|
354 |
+
|
355 |
+
data['styles'] = str(found_styles)
|
356 |
+
|
357 |
+
# try to load performance based on steps, fallback for direct A1111 imports
|
358 |
+
if 'steps' in data and 'performance' not in data:
|
359 |
+
try:
|
360 |
+
data['performance'] = Performance[Steps(int(data['steps'])).name].value
|
361 |
+
except ValueError | KeyError:
|
362 |
+
pass
|
363 |
+
|
364 |
+
if 'sampler' in data:
|
365 |
+
data['sampler'] = data['sampler'].replace(' Karras', '')
|
366 |
+
# get key
|
367 |
+
for k, v in SAMPLERS.items():
|
368 |
+
if v == data['sampler']:
|
369 |
+
data['sampler'] = k
|
370 |
+
break
|
371 |
+
|
372 |
+
for key in ['base_model', 'refiner_model']:
|
373 |
+
if key in data:
|
374 |
+
for filename in modules.config.model_filenames:
|
375 |
+
path = Path(filename)
|
376 |
+
if data[key] == path.stem:
|
377 |
+
data[key] = filename
|
378 |
+
break
|
379 |
+
|
380 |
+
if 'lora_hashes' in data:
|
381 |
+
lora_filenames = modules.config.lora_filenames.copy()
|
382 |
+
if modules.config.sdxl_lcm_lora in lora_filenames:
|
383 |
+
lora_filenames.remove(modules.config.sdxl_lcm_lora)
|
384 |
+
for li, lora in enumerate(data['lora_hashes'].split(', ')):
|
385 |
+
lora_name, lora_hash, lora_weight = lora.split(': ')
|
386 |
+
for filename in lora_filenames:
|
387 |
+
path = Path(filename)
|
388 |
+
if lora_name == path.stem:
|
389 |
+
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
|
390 |
+
break
|
391 |
+
|
392 |
+
return data
|
393 |
+
|
394 |
+
def parse_string(self, metadata: dict) -> str:
|
395 |
+
data = {k: v for _, k, v in metadata}
|
396 |
+
|
397 |
+
width, height = eval(data['resolution'])
|
398 |
+
|
399 |
+
sampler = data['sampler']
|
400 |
+
scheduler = data['scheduler']
|
401 |
+
if sampler in SAMPLERS and SAMPLERS[sampler] != '':
|
402 |
+
sampler = SAMPLERS[sampler]
|
403 |
+
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
|
404 |
+
sampler += f' Karras'
|
405 |
+
|
406 |
+
generation_params = {
|
407 |
+
self.fooocus_to_a1111['steps']: self.steps,
|
408 |
+
self.fooocus_to_a1111['sampler']: sampler,
|
409 |
+
self.fooocus_to_a1111['seed']: data['seed'],
|
410 |
+
self.fooocus_to_a1111['resolution']: f'{width}x{height}',
|
411 |
+
self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
|
412 |
+
self.fooocus_to_a1111['sharpness']: data['sharpness'],
|
413 |
+
self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'],
|
414 |
+
self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
|
415 |
+
self.fooocus_to_a1111['base_model_hash']: self.base_model_hash,
|
416 |
+
|
417 |
+
self.fooocus_to_a1111['performance']: data['performance'],
|
418 |
+
self.fooocus_to_a1111['scheduler']: scheduler,
|
419 |
+
# workaround for multiline prompts
|
420 |
+
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
|
421 |
+
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
|
422 |
+
}
|
423 |
+
|
424 |
+
if self.refiner_model_name not in ['', 'None']:
|
425 |
+
generation_params |= {
|
426 |
+
self.fooocus_to_a1111['refiner_model']: self.refiner_model_name,
|
427 |
+
self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash
|
428 |
+
}
|
429 |
+
|
430 |
+
for key in ['adaptive_cfg', 'overwrite_switch', 'refiner_swap_method', 'freeu']:
|
431 |
+
if key in data:
|
432 |
+
generation_params[self.fooocus_to_a1111[key]] = data[key]
|
433 |
+
|
434 |
+
lora_hashes = []
|
435 |
+
for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras):
|
436 |
+
# workaround for Fooocus not knowing LoRA name in LoRA metadata
|
437 |
+
lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}')
|
438 |
+
lora_hashes_string = ', '.join(lora_hashes)
|
439 |
+
|
440 |
+
generation_params |= {
|
441 |
+
self.fooocus_to_a1111['lora_hashes']: lora_hashes_string,
|
442 |
+
self.fooocus_to_a1111['version']: data['version']
|
443 |
+
}
|
444 |
+
|
445 |
+
if modules.config.metadata_created_by != '':
|
446 |
+
generation_params[self.fooocus_to_a1111['created_by']] = modules.config.metadata_created_by
|
447 |
+
|
448 |
+
generation_params_text = ", ".join(
|
449 |
+
[k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if
|
450 |
+
v is not None])
|
451 |
+
positive_prompt_resolved = ', '.join(self.full_prompt)
|
452 |
+
negative_prompt_resolved = ', '.join(self.full_negative_prompt)
|
453 |
+
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
|
454 |
+
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
|
455 |
+
|
456 |
+
|
457 |
+
class FooocusMetadataParser(MetadataParser):
|
458 |
+
def get_scheme(self) -> MetadataScheme:
|
459 |
+
return MetadataScheme.FOOOCUS
|
460 |
+
|
461 |
+
def parse_json(self, metadata: dict) -> dict:
|
462 |
+
model_filenames = modules.config.model_filenames.copy()
|
463 |
+
lora_filenames = modules.config.lora_filenames.copy()
|
464 |
+
if modules.config.sdxl_lcm_lora in lora_filenames:
|
465 |
+
lora_filenames.remove(modules.config.sdxl_lcm_lora)
|
466 |
+
|
467 |
+
for key, value in metadata.items():
|
468 |
+
if value in ['', 'None']:
|
469 |
+
continue
|
470 |
+
if key in ['base_model', 'refiner_model']:
|
471 |
+
metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
|
472 |
+
elif key.startswith('lora_combined_'):
|
473 |
+
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)
|
474 |
+
else:
|
475 |
+
continue
|
476 |
+
|
477 |
+
return metadata
|
478 |
+
|
479 |
+
def parse_string(self, metadata: list) -> str:
|
480 |
+
for li, (label, key, value) in enumerate(metadata):
|
481 |
+
# remove model folder paths from metadata
|
482 |
+
if key.startswith('lora_combined_'):
|
483 |
+
name, weight = value.split(' : ')
|
484 |
+
name = Path(name).stem
|
485 |
+
value = f'{name} : {weight}'
|
486 |
+
metadata[li] = (label, key, value)
|
487 |
+
|
488 |
+
res = {k: v for _, k, v in metadata}
|
489 |
+
|
490 |
+
res['full_prompt'] = self.full_prompt
|
491 |
+
res['full_negative_prompt'] = self.full_negative_prompt
|
492 |
+
res['steps'] = self.steps
|
493 |
+
res['base_model'] = self.base_model_name
|
494 |
+
res['base_model_hash'] = self.base_model_hash
|
495 |
+
|
496 |
+
if self.refiner_model_name not in ['', 'None']:
|
497 |
+
res['refiner_model'] = self.refiner_model_name
|
498 |
+
res['refiner_model_hash'] = self.refiner_model_hash
|
499 |
+
|
500 |
+
res['loras'] = self.loras
|
501 |
+
|
502 |
+
if modules.config.metadata_created_by != '':
|
503 |
+
res['created_by'] = modules.config.metadata_created_by
|
504 |
+
|
505 |
+
return json.dumps(dict(sorted(res.items())))
|
506 |
+
|
507 |
+
@staticmethod
|
508 |
+
def replace_value_with_filename(key, value, filenames):
|
509 |
+
for filename in filenames:
|
510 |
+
path = Path(filename)
|
511 |
+
if key.startswith('lora_combined_'):
|
512 |
+
name, weight = value.split(' : ')
|
513 |
+
if name == path.stem:
|
514 |
+
return f'{filename} : {weight}'
|
515 |
+
elif value == path.stem:
|
516 |
+
return filename
|
517 |
+
|
518 |
+
|
519 |
+
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
|
520 |
+
match metadata_scheme:
|
521 |
+
case MetadataScheme.FOOOCUS:
|
522 |
+
return FooocusMetadataParser()
|
523 |
+
case MetadataScheme.A1111:
|
524 |
+
return A1111MetadataParser()
|
525 |
+
case _:
|
526 |
+
raise NotImplementedError
|
527 |
+
|
528 |
+
|
529 |
+
def read_info_from_image(filepath) -> tuple[str | None, MetadataScheme | None]:
|
530 |
+
with Image.open(filepath) as image:
|
531 |
+
items = (image.info or {}).copy()
|
532 |
+
|
533 |
+
parameters = items.pop('parameters', None)
|
534 |
+
metadata_scheme = items.pop('fooocus_scheme', None)
|
535 |
+
exif = items.pop('exif', None)
|
536 |
+
|
537 |
+
if parameters is not None and is_json(parameters):
|
538 |
+
parameters = json.loads(parameters)
|
539 |
+
elif exif is not None:
|
540 |
+
exif = image.getexif()
|
541 |
+
# 0x9286 = UserComment
|
542 |
+
parameters = exif.get(0x9286, None)
|
543 |
+
# 0x927C = MakerNote
|
544 |
+
metadata_scheme = exif.get(0x927C, None)
|
545 |
+
|
546 |
+
if is_json(parameters):
|
547 |
+
parameters = json.loads(parameters)
|
548 |
+
|
549 |
+
try:
|
550 |
+
metadata_scheme = MetadataScheme(metadata_scheme)
|
551 |
+
except ValueError:
|
552 |
+
metadata_scheme = None
|
553 |
+
|
554 |
+
# broad fallback
|
555 |
+
if isinstance(parameters, dict):
|
556 |
+
metadata_scheme = MetadataScheme.FOOOCUS
|
557 |
+
|
558 |
+
if isinstance(parameters, str):
|
559 |
+
metadata_scheme = MetadataScheme.A1111
|
560 |
+
|
561 |
+
return parameters, metadata_scheme
|
562 |
+
|
563 |
+
|
564 |
+
def get_exif(metadata: str | None, metadata_scheme: str):
|
565 |
+
exif = Image.Exif()
|
566 |
+
# tags see see https://github.com/python-pillow/Pillow/blob/9.2.x/src/PIL/ExifTags.py
|
567 |
+
# 0x9286 = UserComment
|
568 |
+
exif[0x9286] = metadata
|
569 |
+
# 0x0131 = Software
|
570 |
+
exif[0x0131] = 'Fooocus v' + fooocus_version.version
|
571 |
+
# 0x927C = MakerNote
|
572 |
+
exif[0x927C] = metadata_scheme
|
573 |
+
return exif
|
modules/model_loader.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from urllib.parse import urlparse
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
|
6 |
+
def load_file_from_url(
|
7 |
+
url: str,
|
8 |
+
*,
|
9 |
+
model_dir: str,
|
10 |
+
progress: bool = True,
|
11 |
+
file_name: Optional[str] = None,
|
12 |
+
) -> str:
|
13 |
+
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
14 |
+
|
15 |
+
Returns the path to the downloaded file.
|
16 |
+
"""
|
17 |
+
os.makedirs(model_dir, exist_ok=True)
|
18 |
+
if not file_name:
|
19 |
+
parts = urlparse(url)
|
20 |
+
file_name = os.path.basename(parts.path)
|
21 |
+
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
22 |
+
if not os.path.exists(cached_file):
|
23 |
+
print(f'Downloading: "{url}" to {cached_file}\n')
|
24 |
+
from torch.hub import download_url_to_file
|
25 |
+
download_url_to_file(url, cached_file, progress=progress)
|
26 |
+
return cached_file
|
modules/ops.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import contextlib
|
3 |
+
|
4 |
+
|
5 |
+
@contextlib.contextmanager
|
6 |
+
def use_patched_ops(operations):
|
7 |
+
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
|
8 |
+
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
|
9 |
+
|
10 |
+
try:
|
11 |
+
for op_name in op_names:
|
12 |
+
setattr(torch.nn, op_name, getattr(operations, op_name))
|
13 |
+
|
14 |
+
yield
|
15 |
+
|
16 |
+
finally:
|
17 |
+
for op_name in op_names:
|
18 |
+
setattr(torch.nn, op_name, backups[op_name])
|
19 |
+
return
|
modules/patch.py
ADDED
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import time
|
4 |
+
import math
|
5 |
+
import ldm_patched.modules.model_base
|
6 |
+
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
|
7 |
+
import ldm_patched.modules.model_management
|
8 |
+
import modules.anisotropic as anisotropic
|
9 |
+
import ldm_patched.ldm.modules.attention
|
10 |
+
import ldm_patched.k_diffusion.sampling
|
11 |
+
import ldm_patched.modules.sd1_clip
|
12 |
+
import modules.inpaint_worker as inpaint_worker
|
13 |
+
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
|
14 |
+
import ldm_patched.ldm.modules.diffusionmodules.model
|
15 |
+
import ldm_patched.modules.sd
|
16 |
+
import ldm_patched.controlnet.cldm
|
17 |
+
import ldm_patched.modules.model_patcher
|
18 |
+
import ldm_patched.modules.samplers
|
19 |
+
import ldm_patched.modules.args_parser
|
20 |
+
import warnings
|
21 |
+
import safetensors.torch
|
22 |
+
import modules.constants as constants
|
23 |
+
|
24 |
+
from ldm_patched.modules.samplers import calc_cond_uncond_batch
|
25 |
+
from ldm_patched.k_diffusion.sampling import BatchedBrownianTree
|
26 |
+
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control
|
27 |
+
from modules.patch_precision import patch_all_precision
|
28 |
+
from modules.patch_clip import patch_all_clip
|
29 |
+
|
30 |
+
|
31 |
+
class PatchSettings:
|
32 |
+
def __init__(self,
|
33 |
+
sharpness=2.0,
|
34 |
+
adm_scaler_end=0.3,
|
35 |
+
positive_adm_scale=1.5,
|
36 |
+
negative_adm_scale=0.8,
|
37 |
+
controlnet_softness=0.25,
|
38 |
+
adaptive_cfg=7.0):
|
39 |
+
self.sharpness = sharpness
|
40 |
+
self.adm_scaler_end = adm_scaler_end
|
41 |
+
self.positive_adm_scale = positive_adm_scale
|
42 |
+
self.negative_adm_scale = negative_adm_scale
|
43 |
+
self.controlnet_softness = controlnet_softness
|
44 |
+
self.adaptive_cfg = adaptive_cfg
|
45 |
+
self.global_diffusion_progress = 0
|
46 |
+
self.eps_record = None
|
47 |
+
|
48 |
+
|
49 |
+
patch_settings = {}
|
50 |
+
|
51 |
+
|
52 |
+
def calculate_weight_patched(self, patches, weight, key):
|
53 |
+
for p in patches:
|
54 |
+
alpha = p[0]
|
55 |
+
v = p[1]
|
56 |
+
strength_model = p[2]
|
57 |
+
|
58 |
+
if strength_model != 1.0:
|
59 |
+
weight *= strength_model
|
60 |
+
|
61 |
+
if isinstance(v, list):
|
62 |
+
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
|
63 |
+
|
64 |
+
if len(v) == 1:
|
65 |
+
patch_type = "diff"
|
66 |
+
elif len(v) == 2:
|
67 |
+
patch_type = v[0]
|
68 |
+
v = v[1]
|
69 |
+
|
70 |
+
if patch_type == "diff":
|
71 |
+
w1 = v[0]
|
72 |
+
if alpha != 0.0:
|
73 |
+
if w1.shape != weight.shape:
|
74 |
+
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
75 |
+
else:
|
76 |
+
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
77 |
+
elif patch_type == "lora":
|
78 |
+
mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
79 |
+
mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
80 |
+
if v[2] is not None:
|
81 |
+
alpha *= v[2] / mat2.shape[0]
|
82 |
+
if v[3] is not None:
|
83 |
+
mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
84 |
+
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
85 |
+
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1),
|
86 |
+
mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
87 |
+
try:
|
88 |
+
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(
|
89 |
+
weight.shape).type(weight.dtype)
|
90 |
+
except Exception as e:
|
91 |
+
print("ERROR", key, e)
|
92 |
+
elif patch_type == "fooocus":
|
93 |
+
w1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
94 |
+
w_min = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
95 |
+
w_max = ldm_patched.modules.model_management.cast_to_device(v[2], weight.device, torch.float32)
|
96 |
+
w1 = (w1 / 255.0) * (w_max - w_min) + w_min
|
97 |
+
if alpha != 0.0:
|
98 |
+
if w1.shape != weight.shape:
|
99 |
+
print("WARNING SHAPE MISMATCH {} FOOOCUS WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
100 |
+
else:
|
101 |
+
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
102 |
+
elif patch_type == "lokr":
|
103 |
+
w1 = v[0]
|
104 |
+
w2 = v[1]
|
105 |
+
w1_a = v[3]
|
106 |
+
w1_b = v[4]
|
107 |
+
w2_a = v[5]
|
108 |
+
w2_b = v[6]
|
109 |
+
t2 = v[7]
|
110 |
+
dim = None
|
111 |
+
|
112 |
+
if w1 is None:
|
113 |
+
dim = w1_b.shape[0]
|
114 |
+
w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
115 |
+
ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
116 |
+
else:
|
117 |
+
w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32)
|
118 |
+
|
119 |
+
if w2 is None:
|
120 |
+
dim = w2_b.shape[0]
|
121 |
+
if t2 is None:
|
122 |
+
w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
123 |
+
ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
124 |
+
else:
|
125 |
+
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
126 |
+
ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
|
127 |
+
ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
128 |
+
ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
129 |
+
else:
|
130 |
+
w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32)
|
131 |
+
|
132 |
+
if len(w2.shape) == 4:
|
133 |
+
w1 = w1.unsqueeze(2).unsqueeze(2)
|
134 |
+
if v[2] is not None and dim is not None:
|
135 |
+
alpha *= v[2] / dim
|
136 |
+
|
137 |
+
try:
|
138 |
+
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
139 |
+
except Exception as e:
|
140 |
+
print("ERROR", key, e)
|
141 |
+
elif patch_type == "loha":
|
142 |
+
w1a = v[0]
|
143 |
+
w1b = v[1]
|
144 |
+
if v[2] is not None:
|
145 |
+
alpha *= v[2] / w1b.shape[0]
|
146 |
+
w2a = v[3]
|
147 |
+
w2b = v[4]
|
148 |
+
if v[5] is not None: # cp decomposition
|
149 |
+
t1 = v[5]
|
150 |
+
t2 = v[6]
|
151 |
+
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
152 |
+
ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32),
|
153 |
+
ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32),
|
154 |
+
ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32))
|
155 |
+
|
156 |
+
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
157 |
+
ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
|
158 |
+
ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32),
|
159 |
+
ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32))
|
160 |
+
else:
|
161 |
+
m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32),
|
162 |
+
ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32))
|
163 |
+
m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32),
|
164 |
+
ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
165 |
+
|
166 |
+
try:
|
167 |
+
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
168 |
+
except Exception as e:
|
169 |
+
print("ERROR", key, e)
|
170 |
+
elif patch_type == "glora":
|
171 |
+
if v[4] is not None:
|
172 |
+
alpha *= v[4] / v[0].shape[0]
|
173 |
+
|
174 |
+
a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
175 |
+
a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
176 |
+
b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
177 |
+
b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
178 |
+
|
179 |
+
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
180 |
+
else:
|
181 |
+
print("patch type not recognized", patch_type, key)
|
182 |
+
|
183 |
+
return weight
|
184 |
+
|
185 |
+
|
186 |
+
class BrownianTreeNoiseSamplerPatched:
|
187 |
+
transform = None
|
188 |
+
tree = None
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
|
192 |
+
if ldm_patched.modules.model_management.directml_enabled:
|
193 |
+
cpu = True
|
194 |
+
|
195 |
+
t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
|
196 |
+
|
197 |
+
BrownianTreeNoiseSamplerPatched.transform = transform
|
198 |
+
BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
|
199 |
+
|
200 |
+
def __init__(self, *args, **kwargs):
|
201 |
+
pass
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
def __call__(sigma, sigma_next):
|
205 |
+
transform = BrownianTreeNoiseSamplerPatched.transform
|
206 |
+
tree = BrownianTreeNoiseSamplerPatched.tree
|
207 |
+
|
208 |
+
t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
|
209 |
+
return tree(t0, t1) / (t1 - t0).abs().sqrt()
|
210 |
+
|
211 |
+
|
212 |
+
def compute_cfg(uncond, cond, cfg_scale, t):
|
213 |
+
pid = os.getpid()
|
214 |
+
mimic_cfg = float(patch_settings[pid].adaptive_cfg)
|
215 |
+
real_cfg = float(cfg_scale)
|
216 |
+
|
217 |
+
real_eps = uncond + real_cfg * (cond - uncond)
|
218 |
+
|
219 |
+
if cfg_scale > patch_settings[pid].adaptive_cfg:
|
220 |
+
mimicked_eps = uncond + mimic_cfg * (cond - uncond)
|
221 |
+
return real_eps * t + mimicked_eps * (1 - t)
|
222 |
+
else:
|
223 |
+
return real_eps
|
224 |
+
|
225 |
+
|
226 |
+
def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
|
227 |
+
pid = os.getpid()
|
228 |
+
|
229 |
+
if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False):
|
230 |
+
final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0]
|
231 |
+
|
232 |
+
if patch_settings[pid].eps_record is not None:
|
233 |
+
patch_settings[pid].eps_record = ((x - final_x0) / timestep).cpu()
|
234 |
+
|
235 |
+
return final_x0
|
236 |
+
|
237 |
+
positive_x0, negative_x0 = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
238 |
+
|
239 |
+
positive_eps = x - positive_x0
|
240 |
+
negative_eps = x - negative_x0
|
241 |
+
|
242 |
+
alpha = 0.001 * patch_settings[pid].sharpness * patch_settings[pid].global_diffusion_progress
|
243 |
+
|
244 |
+
positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
|
245 |
+
positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
|
246 |
+
|
247 |
+
final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted,
|
248 |
+
cfg_scale=cond_scale, t=patch_settings[pid].global_diffusion_progress)
|
249 |
+
|
250 |
+
if patch_settings[pid].eps_record is not None:
|
251 |
+
patch_settings[pid].eps_record = (final_eps / timestep).cpu()
|
252 |
+
|
253 |
+
return x - final_eps
|
254 |
+
|
255 |
+
|
256 |
+
def round_to_64(x):
|
257 |
+
h = float(x)
|
258 |
+
h = h / 64.0
|
259 |
+
h = round(h)
|
260 |
+
h = int(h)
|
261 |
+
h = h * 64
|
262 |
+
return h
|
263 |
+
|
264 |
+
|
265 |
+
def sdxl_encode_adm_patched(self, **kwargs):
|
266 |
+
clip_pooled = ldm_patched.modules.model_base.sdxl_pooled(kwargs, self.noise_augmentor)
|
267 |
+
width = kwargs.get("width", 1024)
|
268 |
+
height = kwargs.get("height", 1024)
|
269 |
+
target_width = width
|
270 |
+
target_height = height
|
271 |
+
pid = os.getpid()
|
272 |
+
|
273 |
+
if kwargs.get("prompt_type", "") == "negative":
|
274 |
+
width = float(width) * patch_settings[pid].negative_adm_scale
|
275 |
+
height = float(height) * patch_settings[pid].negative_adm_scale
|
276 |
+
elif kwargs.get("prompt_type", "") == "positive":
|
277 |
+
width = float(width) * patch_settings[pid].positive_adm_scale
|
278 |
+
height = float(height) * patch_settings[pid].positive_adm_scale
|
279 |
+
|
280 |
+
def embedder(number_list):
|
281 |
+
h = self.embedder(torch.tensor(number_list, dtype=torch.float32))
|
282 |
+
h = torch.flatten(h).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
283 |
+
return h
|
284 |
+
|
285 |
+
width, height = int(width), int(height)
|
286 |
+
target_width, target_height = round_to_64(target_width), round_to_64(target_height)
|
287 |
+
|
288 |
+
adm_emphasized = embedder([height, width, 0, 0, target_height, target_width])
|
289 |
+
adm_consistent = embedder([target_height, target_width, 0, 0, target_height, target_width])
|
290 |
+
|
291 |
+
clip_pooled = clip_pooled.to(adm_emphasized)
|
292 |
+
final_adm = torch.cat((clip_pooled, adm_emphasized, clip_pooled, adm_consistent), dim=1)
|
293 |
+
|
294 |
+
return final_adm
|
295 |
+
|
296 |
+
|
297 |
+
def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
298 |
+
if inpaint_worker.current_task is not None:
|
299 |
+
latent_processor = self.inner_model.inner_model.process_latent_in
|
300 |
+
inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
|
301 |
+
inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
|
302 |
+
|
303 |
+
if getattr(self, 'energy_generator', None) is None:
|
304 |
+
# avoid bad results by using different seeds.
|
305 |
+
self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
|
306 |
+
|
307 |
+
energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
|
308 |
+
current_energy = torch.randn(
|
309 |
+
x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma
|
310 |
+
x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask)
|
311 |
+
|
312 |
+
out = self.inner_model(x, sigma,
|
313 |
+
cond=cond,
|
314 |
+
uncond=uncond,
|
315 |
+
cond_scale=cond_scale,
|
316 |
+
model_options=model_options,
|
317 |
+
seed=seed)
|
318 |
+
|
319 |
+
out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask)
|
320 |
+
else:
|
321 |
+
out = self.inner_model(x, sigma,
|
322 |
+
cond=cond,
|
323 |
+
uncond=uncond,
|
324 |
+
cond_scale=cond_scale,
|
325 |
+
model_options=model_options,
|
326 |
+
seed=seed)
|
327 |
+
return out
|
328 |
+
|
329 |
+
|
330 |
+
def timed_adm(y, timesteps):
|
331 |
+
if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
|
332 |
+
y_mask = (timesteps > 999.0 * (1.0 - float(patch_settings[os.getpid()].adm_scaler_end))).to(y)[..., None]
|
333 |
+
y_with_adm = y[..., :2816].clone()
|
334 |
+
y_without_adm = y[..., 2816:].clone()
|
335 |
+
return y_with_adm * y_mask + y_without_adm * (1.0 - y_mask)
|
336 |
+
return y
|
337 |
+
|
338 |
+
|
339 |
+
def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
340 |
+
t_emb = ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
341 |
+
emb = self.time_embed(t_emb)
|
342 |
+
pid = os.getpid()
|
343 |
+
|
344 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
345 |
+
|
346 |
+
y = timed_adm(y, timesteps)
|
347 |
+
|
348 |
+
outs = []
|
349 |
+
|
350 |
+
hs = []
|
351 |
+
if self.num_classes is not None:
|
352 |
+
assert y.shape[0] == x.shape[0]
|
353 |
+
emb = emb + self.label_emb(y)
|
354 |
+
|
355 |
+
h = x
|
356 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
357 |
+
if guided_hint is not None:
|
358 |
+
h = module(h, emb, context)
|
359 |
+
h += guided_hint
|
360 |
+
guided_hint = None
|
361 |
+
else:
|
362 |
+
h = module(h, emb, context)
|
363 |
+
outs.append(zero_conv(h, emb, context))
|
364 |
+
|
365 |
+
h = self.middle_block(h, emb, context)
|
366 |
+
outs.append(self.middle_block_out(h, emb, context))
|
367 |
+
|
368 |
+
if patch_settings[pid].controlnet_softness > 0:
|
369 |
+
for i in range(10):
|
370 |
+
k = 1.0 - float(i) / 9.0
|
371 |
+
outs[i] = outs[i] * (1.0 - patch_settings[pid].controlnet_softness * k)
|
372 |
+
|
373 |
+
return outs
|
374 |
+
|
375 |
+
|
376 |
+
def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
377 |
+
self.current_step = 1.0 - timesteps.to(x) / 999.0
|
378 |
+
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
|
379 |
+
|
380 |
+
y = timed_adm(y, timesteps)
|
381 |
+
|
382 |
+
transformer_options["original_shape"] = list(x.shape)
|
383 |
+
transformer_options["transformer_index"] = 0
|
384 |
+
transformer_patches = transformer_options.get("patches", {})
|
385 |
+
|
386 |
+
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
387 |
+
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
|
388 |
+
time_context = kwargs.get("time_context", None)
|
389 |
+
|
390 |
+
assert (y is not None) == (
|
391 |
+
self.num_classes is not None
|
392 |
+
), "must specify y if and only if the model is class-conditional"
|
393 |
+
hs = []
|
394 |
+
t_emb = ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
395 |
+
emb = self.time_embed(t_emb)
|
396 |
+
|
397 |
+
if self.num_classes is not None:
|
398 |
+
assert y.shape[0] == x.shape[0]
|
399 |
+
emb = emb + self.label_emb(y)
|
400 |
+
|
401 |
+
h = x
|
402 |
+
for id, module in enumerate(self.input_blocks):
|
403 |
+
transformer_options["block"] = ("input", id)
|
404 |
+
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
405 |
+
h = apply_control(h, control, 'input')
|
406 |
+
if "input_block_patch" in transformer_patches:
|
407 |
+
patch = transformer_patches["input_block_patch"]
|
408 |
+
for p in patch:
|
409 |
+
h = p(h, transformer_options)
|
410 |
+
|
411 |
+
hs.append(h)
|
412 |
+
if "input_block_patch_after_skip" in transformer_patches:
|
413 |
+
patch = transformer_patches["input_block_patch_after_skip"]
|
414 |
+
for p in patch:
|
415 |
+
h = p(h, transformer_options)
|
416 |
+
|
417 |
+
transformer_options["block"] = ("middle", 0)
|
418 |
+
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
419 |
+
h = apply_control(h, control, 'middle')
|
420 |
+
|
421 |
+
for id, module in enumerate(self.output_blocks):
|
422 |
+
transformer_options["block"] = ("output", id)
|
423 |
+
hsp = hs.pop()
|
424 |
+
hsp = apply_control(hsp, control, 'output')
|
425 |
+
|
426 |
+
if "output_block_patch" in transformer_patches:
|
427 |
+
patch = transformer_patches["output_block_patch"]
|
428 |
+
for p in patch:
|
429 |
+
h, hsp = p(h, hsp, transformer_options)
|
430 |
+
|
431 |
+
h = torch.cat([h, hsp], dim=1)
|
432 |
+
del hsp
|
433 |
+
if len(hs) > 0:
|
434 |
+
output_shape = hs[-1].shape
|
435 |
+
else:
|
436 |
+
output_shape = None
|
437 |
+
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
438 |
+
h = h.type(x.dtype)
|
439 |
+
if self.predict_codebook_ids:
|
440 |
+
return self.id_predictor(h)
|
441 |
+
else:
|
442 |
+
return self.out(h)
|
443 |
+
|
444 |
+
|
445 |
+
def patched_load_models_gpu(*args, **kwargs):
|
446 |
+
execution_start_time = time.perf_counter()
|
447 |
+
y = ldm_patched.modules.model_management.load_models_gpu_origin(*args, **kwargs)
|
448 |
+
moving_time = time.perf_counter() - execution_start_time
|
449 |
+
if moving_time > 0.1:
|
450 |
+
print(f'[Fooocus Model Management] Moving model(s) has taken {moving_time:.2f} seconds')
|
451 |
+
return y
|
452 |
+
|
453 |
+
|
454 |
+
def build_loaded(module, loader_name):
|
455 |
+
original_loader_name = loader_name + '_origin'
|
456 |
+
|
457 |
+
if not hasattr(module, original_loader_name):
|
458 |
+
setattr(module, original_loader_name, getattr(module, loader_name))
|
459 |
+
|
460 |
+
original_loader = getattr(module, original_loader_name)
|
461 |
+
|
462 |
+
def loader(*args, **kwargs):
|
463 |
+
result = None
|
464 |
+
try:
|
465 |
+
result = original_loader(*args, **kwargs)
|
466 |
+
except Exception as e:
|
467 |
+
result = None
|
468 |
+
exp = str(e) + '\n'
|
469 |
+
for path in list(args) + list(kwargs.values()):
|
470 |
+
if isinstance(path, str):
|
471 |
+
if os.path.exists(path):
|
472 |
+
exp += f'File corrupted: {path} \n'
|
473 |
+
corrupted_backup_file = path + '.corrupted'
|
474 |
+
if os.path.exists(corrupted_backup_file):
|
475 |
+
os.remove(corrupted_backup_file)
|
476 |
+
os.replace(path, corrupted_backup_file)
|
477 |
+
if os.path.exists(path):
|
478 |
+
os.remove(path)
|
479 |
+
exp += f'Fooocus has tried to move the corrupted file to {corrupted_backup_file} \n'
|
480 |
+
exp += f'You may try again now and Fooocus will download models again. \n'
|
481 |
+
raise ValueError(exp)
|
482 |
+
return result
|
483 |
+
|
484 |
+
setattr(module, loader_name, loader)
|
485 |
+
return
|
486 |
+
|
487 |
+
|
488 |
+
def patch_all():
|
489 |
+
if ldm_patched.modules.model_management.directml_enabled:
|
490 |
+
ldm_patched.modules.model_management.lowvram_available = True
|
491 |
+
ldm_patched.modules.model_management.OOM_EXCEPTION = Exception
|
492 |
+
|
493 |
+
patch_all_precision()
|
494 |
+
patch_all_clip()
|
495 |
+
|
496 |
+
if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'):
|
497 |
+
ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu
|
498 |
+
|
499 |
+
ldm_patched.modules.model_management.load_models_gpu = patched_load_models_gpu
|
500 |
+
ldm_patched.modules.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
|
501 |
+
ldm_patched.controlnet.cldm.ControlNet.forward = patched_cldm_forward
|
502 |
+
ldm_patched.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
|
503 |
+
ldm_patched.modules.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
|
504 |
+
ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
|
505 |
+
ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
|
506 |
+
ldm_patched.modules.samplers.sampling_function = patched_sampling_function
|
507 |
+
|
508 |
+
warnings.filterwarnings(action='ignore', module='torchsde')
|
509 |
+
|
510 |
+
build_loaded(safetensors.torch, 'load_file')
|
511 |
+
build_loaded(torch, 'load')
|
512 |
+
|
513 |
+
return
|